mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Handle Docs with no entities in EntityLinker (#11640)
* Handle docs with no entities If a whole batch contains no entities it won't make it to the model, but it's possible for individual Docs to have no entities. Before this commit, those Docs would cause an error when attempting to concatenate arrays because the dimensions didn't match. It turns out the process of preparing the Ragged at the end of the span maker forward was a little different from list2ragged, which just uses the flatten function directly. Letting list2ragged do the conversion avoids the dimension issue. This did not come up before because in NEL demo projects it's typical for data with no entities to be discarded before it reaches the NEL component. This includes a simple direct test that shows the issue and checks it's resolved. It doesn't check if there are any downstream changes, so a more complete test could be added. A full run was tested by adding an example with no entities to the Emerson sample project. * Add a blank instance to default training data in tests Rather than adding a specific test, since not failing on instances with no entities is basic functionality, it makes sense to add it to the default set. * Fix without modifying architecture If the architecture is modified this would have to be a new version, but this change isn't big enough to merit that.
This commit is contained in:
parent
6b78135b9e
commit
d61e742960
|
@ -71,11 +71,10 @@ def span_maker_forward(model, docs: List[Doc], is_train) -> Tuple[Ragged, Callab
|
||||||
cands.append((start_token, end_token))
|
cands.append((start_token, end_token))
|
||||||
|
|
||||||
candidates.append(ops.asarray2i(cands))
|
candidates.append(ops.asarray2i(cands))
|
||||||
candlens = ops.asarray1i([len(cands) for cands in candidates])
|
lengths = model.ops.asarray1i([len(cands) for cands in candidates])
|
||||||
candidates = ops.xp.concatenate(candidates)
|
out = Ragged(model.ops.flatten(candidates), lengths)
|
||||||
outputs = Ragged(candidates, candlens)
|
|
||||||
# because this is just rearranging docs, the backprop does nothing
|
# because this is just rearranging docs, the backprop does nothing
|
||||||
return outputs, lambda x: []
|
return out, lambda x: []
|
||||||
|
|
||||||
|
|
||||||
@registry.misc("spacy.KBFromFile.v1")
|
@registry.misc("spacy.KBFromFile.v1")
|
||||||
|
|
|
@ -9,6 +9,7 @@ from spacy.compat import pickle
|
||||||
from spacy.kb import Candidate, InMemoryLookupKB, get_candidates, KnowledgeBase
|
from spacy.kb import Candidate, InMemoryLookupKB, get_candidates, KnowledgeBase
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from spacy.ml import load_kb
|
from spacy.ml import load_kb
|
||||||
|
from spacy.ml.models.entity_linker import build_span_maker
|
||||||
from spacy.pipeline import EntityLinker
|
from spacy.pipeline import EntityLinker
|
||||||
from spacy.pipeline.legacy import EntityLinker_v1
|
from spacy.pipeline.legacy import EntityLinker_v1
|
||||||
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
|
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
|
||||||
|
@ -715,7 +716,11 @@ TRAIN_DATA = [
|
||||||
("Russ Cochran was a member of University of Kentucky's golf team.",
|
("Russ Cochran was a member of University of Kentucky's golf team.",
|
||||||
{"links": {(0, 12): {"Q7381115": 0.0, "Q2146908": 1.0}},
|
{"links": {(0, 12): {"Q7381115": 0.0, "Q2146908": 1.0}},
|
||||||
"entities": [(0, 12, "PERSON"), (43, 51, "LOC")],
|
"entities": [(0, 12, "PERSON"), (43, 51, "LOC")],
|
||||||
"sent_starts": [1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]})
|
"sent_starts": [1, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]}),
|
||||||
|
# having a blank instance shouldn't break things
|
||||||
|
("The weather is nice today.",
|
||||||
|
{"links": {}, "entities": [],
|
||||||
|
"sent_starts": [1, -1, 0, 0, 0, 0]})
|
||||||
]
|
]
|
||||||
GOLD_entities = ["Q2146908", "Q7381115", "Q7381115", "Q2146908"]
|
GOLD_entities = ["Q2146908", "Q7381115", "Q7381115", "Q2146908"]
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
@ -1196,3 +1201,18 @@ def test_threshold(meet_threshold: bool, config: Dict[str, Any]):
|
||||||
|
|
||||||
assert len(doc.ents) == 1
|
assert len(doc.ents) == 1
|
||||||
assert doc.ents[0].kb_id_ == entity_id if meet_threshold else EntityLinker.NIL
|
assert doc.ents[0].kb_id_ == entity_id if meet_threshold else EntityLinker.NIL
|
||||||
|
|
||||||
|
|
||||||
|
def test_span_maker_forward_with_empty():
|
||||||
|
"""The forward pass of the span maker may have a doc with no entities."""
|
||||||
|
nlp = English()
|
||||||
|
doc1 = nlp("a b c")
|
||||||
|
ent = doc1[0:1]
|
||||||
|
ent.label_ = "X"
|
||||||
|
doc1.ents = [ent]
|
||||||
|
# no entities
|
||||||
|
doc2 = nlp("x y z")
|
||||||
|
|
||||||
|
# just to get a model
|
||||||
|
span_maker = build_span_maker()
|
||||||
|
span_maker([doc1, doc2], False)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user