mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Fix bug when too many entity types. Fixes #2800
This commit is contained in:
parent
8809dc4514
commit
96fe314d8d
|
@ -747,7 +747,8 @@ cdef class Parser:
|
|||
|
||||
def transition_batch(self, states, float[:, ::1] scores):
|
||||
cdef StateClass state
|
||||
cdef int[500] is_valid # TODO: Unhack
|
||||
cdef Pool mem = Pool()
|
||||
is_valid = <int*>mem.alloc(self.moves.n_moves, sizeof(int))
|
||||
cdef float* c_scores = &scores[0, 0]
|
||||
for state in states:
|
||||
self.moves.set_valid(is_valid, state.c)
|
||||
|
|
34
spacy/tests/regression/test_issue2800.py
Normal file
34
spacy/tests/regression/test_issue2800.py
Normal file
|
@ -0,0 +1,34 @@
|
|||
'''Test issue that arises when too many labels are added to NER model.'''
|
||||
import random
|
||||
from ...lang.en import English
|
||||
|
||||
def train_model(train_data, entity_types):
|
||||
nlp = English(pipeline=[])
|
||||
|
||||
ner = nlp.create_pipe("ner")
|
||||
nlp.add_pipe(ner)
|
||||
|
||||
for entity_type in list(entity_types):
|
||||
ner.add_label(entity_type)
|
||||
|
||||
optimizer = nlp.begin_training()
|
||||
|
||||
# Start training
|
||||
for i in range(20):
|
||||
losses = {}
|
||||
index = 0
|
||||
random.shuffle(train_data)
|
||||
|
||||
for statement, entities in train_data:
|
||||
nlp.update([statement], [entities], sgd=optimizer, losses=losses, drop=0.5)
|
||||
return nlp
|
||||
|
||||
|
||||
def test_train_with_many_entity_types():
|
||||
train_data = []
|
||||
train_data.extend([("One sentence", {"entities": []})])
|
||||
entity_types = [str(i) for i in range(1000)]
|
||||
|
||||
model = train_model(train_data, entity_types)
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user