diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx index ad92e41b2..dc96ae9e4 100644 --- a/spacy/pipeline/transition_parser.pyx +++ b/spacy/pipeline/transition_parser.pyx @@ -205,7 +205,6 @@ cdef class Parser(TrainablePipe): def predict(self, docs): if isinstance(docs, Doc): docs = [docs] - self._ensure_labels_are_added(docs) if not any(len(doc) for doc in docs): result = self.moves.init_batch(docs) return result @@ -222,6 +221,7 @@ cdef class Parser(TrainablePipe): def greedy_parse(self, docs, drop=0.): cdef vector[StateC*] states cdef StateClass state + self._ensure_labels_are_added(docs) set_dropout_rate(self.model, drop) batch = self.moves.init_batch(docs) model = self.model.predict(docs) @@ -240,6 +240,7 @@ cdef class Parser(TrainablePipe): def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.): cdef Beam beam cdef Doc doc + self._ensure_labels_are_added(docs) batch = _beam_utils.BeamBatch( self.moves, self.moves.init_batch(docs), diff --git a/spacy/tests/parser/test_add_label.py b/spacy/tests/parser/test_add_label.py index 7c96f654b..e955a12a8 100644 --- a/spacy/tests/parser/test_add_label.py +++ b/spacy/tests/parser/test_add_label.py @@ -138,6 +138,28 @@ def test_ner_labels_added_implicitly_on_predict(): assert "D" in ner.labels +def test_ner_labels_added_implicitly_on_beam_parse(): + nlp = Language() + ner = nlp.add_pipe("beam_ner") + for label in ["A", "B", "C"]: + ner.add_label(label) + nlp.initialize() + doc = Doc(nlp.vocab, words=["hello", "world"], ents=["B-D", "O"]) + ner.beam_parse([doc], beam_width=32) + assert "D" in ner.labels + + +def test_ner_labels_added_implicitly_on_greedy_parse(): + nlp = Language() + ner = nlp.add_pipe("beam_ner") + for label in ["A", "B", "C"]: + ner.add_label(label) + nlp.initialize() + doc = Doc(nlp.vocab, words=["hello", "world"], ents=["B-D", "O"]) + ner.greedy_parse([doc]) + assert "D" in ner.labels + + def test_ner_labels_added_implicitly_on_update(): nlp = Language() ner = nlp.add_pipe("ner") diff --git a/spacy/tests/regression/test_issue4001-4500.py b/spacy/tests/regression/test_issue4001-4500.py index 25982623f..a4c15dac2 100644 --- a/spacy/tests/regression/test_issue4001-4500.py +++ b/spacy/tests/regression/test_issue4001-4500.py @@ -303,14 +303,14 @@ def test_issue4313(): doc = nlp("What do you think about Apple ?") assert len(ner.labels) == 1 assert "SOME_LABEL" in ner.labels - ner.add_label("MY_ORG") # TODO: not sure if we want this to be necessary... apple_ent = Span(doc, 5, 6, label="MY_ORG") doc.ents = list(doc.ents) + [apple_ent] # ensure the beam_parse still works with the new label docs = [doc] - ner = nlp.get_pipe("beam_ner") ner.beam_parse(docs, drop=0.0, beam_width=beam_width, beam_density=beam_density) + assert len(ner.labels) == 2 + assert "MY_ORG" in ner.labels def test_issue4348():