mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
Fix beam NER resizing (#6834)
* move label check to sub methods * add tests
This commit is contained in:
parent
5ed51c9dd2
commit
6b68ad027b
|
@ -205,7 +205,6 @@ cdef class Parser(TrainablePipe):
|
||||||
def predict(self, docs):
|
def predict(self, docs):
|
||||||
if isinstance(docs, Doc):
|
if isinstance(docs, Doc):
|
||||||
docs = [docs]
|
docs = [docs]
|
||||||
self._ensure_labels_are_added(docs)
|
|
||||||
if not any(len(doc) for doc in docs):
|
if not any(len(doc) for doc in docs):
|
||||||
result = self.moves.init_batch(docs)
|
result = self.moves.init_batch(docs)
|
||||||
return result
|
return result
|
||||||
|
@ -222,6 +221,7 @@ cdef class Parser(TrainablePipe):
|
||||||
def greedy_parse(self, docs, drop=0.):
|
def greedy_parse(self, docs, drop=0.):
|
||||||
cdef vector[StateC*] states
|
cdef vector[StateC*] states
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
|
self._ensure_labels_are_added(docs)
|
||||||
set_dropout_rate(self.model, drop)
|
set_dropout_rate(self.model, drop)
|
||||||
batch = self.moves.init_batch(docs)
|
batch = self.moves.init_batch(docs)
|
||||||
model = self.model.predict(docs)
|
model = self.model.predict(docs)
|
||||||
|
@ -240,6 +240,7 @@ cdef class Parser(TrainablePipe):
|
||||||
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
|
def beam_parse(self, docs, int beam_width, float drop=0., beam_density=0.):
|
||||||
cdef Beam beam
|
cdef Beam beam
|
||||||
cdef Doc doc
|
cdef Doc doc
|
||||||
|
self._ensure_labels_are_added(docs)
|
||||||
batch = _beam_utils.BeamBatch(
|
batch = _beam_utils.BeamBatch(
|
||||||
self.moves,
|
self.moves,
|
||||||
self.moves.init_batch(docs),
|
self.moves.init_batch(docs),
|
||||||
|
|
|
@ -138,6 +138,28 @@ def test_ner_labels_added_implicitly_on_predict():
|
||||||
assert "D" in ner.labels
|
assert "D" in ner.labels
|
||||||
|
|
||||||
|
|
||||||
|
def test_ner_labels_added_implicitly_on_beam_parse():
|
||||||
|
nlp = Language()
|
||||||
|
ner = nlp.add_pipe("beam_ner")
|
||||||
|
for label in ["A", "B", "C"]:
|
||||||
|
ner.add_label(label)
|
||||||
|
nlp.initialize()
|
||||||
|
doc = Doc(nlp.vocab, words=["hello", "world"], ents=["B-D", "O"])
|
||||||
|
ner.beam_parse([doc], beam_width=32)
|
||||||
|
assert "D" in ner.labels
|
||||||
|
|
||||||
|
|
||||||
|
def test_ner_labels_added_implicitly_on_greedy_parse():
|
||||||
|
nlp = Language()
|
||||||
|
ner = nlp.add_pipe("beam_ner")
|
||||||
|
for label in ["A", "B", "C"]:
|
||||||
|
ner.add_label(label)
|
||||||
|
nlp.initialize()
|
||||||
|
doc = Doc(nlp.vocab, words=["hello", "world"], ents=["B-D", "O"])
|
||||||
|
ner.greedy_parse([doc])
|
||||||
|
assert "D" in ner.labels
|
||||||
|
|
||||||
|
|
||||||
def test_ner_labels_added_implicitly_on_update():
|
def test_ner_labels_added_implicitly_on_update():
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
ner = nlp.add_pipe("ner")
|
ner = nlp.add_pipe("ner")
|
||||||
|
|
|
@ -303,14 +303,14 @@ def test_issue4313():
|
||||||
doc = nlp("What do you think about Apple ?")
|
doc = nlp("What do you think about Apple ?")
|
||||||
assert len(ner.labels) == 1
|
assert len(ner.labels) == 1
|
||||||
assert "SOME_LABEL" in ner.labels
|
assert "SOME_LABEL" in ner.labels
|
||||||
ner.add_label("MY_ORG") # TODO: not sure if we want this to be necessary...
|
|
||||||
apple_ent = Span(doc, 5, 6, label="MY_ORG")
|
apple_ent = Span(doc, 5, 6, label="MY_ORG")
|
||||||
doc.ents = list(doc.ents) + [apple_ent]
|
doc.ents = list(doc.ents) + [apple_ent]
|
||||||
|
|
||||||
# ensure the beam_parse still works with the new label
|
# ensure the beam_parse still works with the new label
|
||||||
docs = [doc]
|
docs = [doc]
|
||||||
ner = nlp.get_pipe("beam_ner")
|
|
||||||
ner.beam_parse(docs, drop=0.0, beam_width=beam_width, beam_density=beam_density)
|
ner.beam_parse(docs, drop=0.0, beam_width=beam_width, beam_density=beam_density)
|
||||||
|
assert len(ner.labels) == 2
|
||||||
|
assert "MY_ORG" in ner.labels
|
||||||
|
|
||||||
|
|
||||||
def test_issue4348():
|
def test_issue4348():
|
||||||
|
|
Loading…
Reference in New Issue
Block a user