Fix beam NER resizing (#6834)

* move label check to sub methods

* add tests
This commit is contained in:
Sofie Van Landeghem 2021-01-27 13:39:14 +01:00 committed by GitHub
parent 5ed51c9dd2
commit 6b68ad027b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 26 additions and 3 deletions

View File

@ -205,7 +205,6 @@ cdef class Parser(TrainablePipe):
def predict(self, docs):
if isinstance(docs, Doc):
docs = [docs]
self._ensure_labels_are_added(docs)
if not any(len(doc) for doc in docs):
result = self.moves.init_batch(docs)
return result
@ -222,6 +221,7 @@ cdef class Parser(TrainablePipe):
def greedy_parse(self, docs, drop=0.):
cdef vector[StateC*] states
cdef StateClass state
self._ensure_labels_are_added(docs)
set_dropout_rate(self.model, drop)
batch = self.moves.init_batch(docs)
model = self.model.predict(docs)
@ -240,6 +240,7 @@ cdef class Parser(TrainablePipe):
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
cdef Beam beam
cdef Doc doc
self._ensure_labels_are_added(docs)
batch = _beam_utils.BeamBatch(
self.moves,
self.moves.init_batch(docs),

View File

@ -138,6 +138,28 @@ def test_ner_labels_added_implicitly_on_predict():
assert "D" in ner.labels
def test_ner_labels_added_implicitly_on_beam_parse():
nlp = Language()
ner = nlp.add_pipe("beam_ner")
for label in ["A", "B", "C"]:
ner.add_label(label)
nlp.initialize()
doc = Doc(nlp.vocab, words=["hello", "world"], ents=["B-D", "O"])
ner.beam_parse([doc], beam_width=32)
assert "D" in ner.labels
def test_ner_labels_added_implicitly_on_greedy_parse():
nlp = Language()
ner = nlp.add_pipe("beam_ner")
for label in ["A", "B", "C"]:
ner.add_label(label)
nlp.initialize()
doc = Doc(nlp.vocab, words=["hello", "world"], ents=["B-D", "O"])
ner.greedy_parse([doc])
assert "D" in ner.labels
def test_ner_labels_added_implicitly_on_update():
nlp = Language()
ner = nlp.add_pipe("ner")

View File

@ -303,14 +303,14 @@ def test_issue4313():
doc = nlp("What do you think about Apple ?")
assert len(ner.labels) == 1
assert "SOME_LABEL" in ner.labels
ner.add_label("MY_ORG") # TODO: not sure if we want this to be necessary...
apple_ent = Span(doc, 5, 6, label="MY_ORG")
doc.ents = list(doc.ents) + [apple_ent]
# ensure the beam_parse still works with the new label
docs = [doc]
ner = nlp.get_pipe("beam_ner")
ner.beam_parse(docs, drop=0.0, beam_width=beam_width, beam_density=beam_density)
assert len(ner.labels) == 2
assert "MY_ORG" in ner.labels
def test_issue4348():