Bugfix/nel crossing sentence (#7630)

* ensure each entity gets a KB ID, even when it's not within a sentence
* cleanup

This commit is contained in:
parent 673e2bc4c0
commit 27dbbb9903
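Why the fix is needed: Span.ents only yields entities that lie completely within the span, so the old per-sentence loop in the entity linker's predict() never visited an entity that crosses a sentence boundary, and predict() then produced fewer KB IDs than the doc has entities. The first hunk below replaces that loop with a loop over doc.ents, using ent.sent to recover the surrounding context. A minimal sketch of the failure mode, using a hand-built Doc (hypothetical example that mirrors the regression test added below):

from spacy.lang.en import English
from spacy.tokens import Doc, Span

nlp = English()
words = ["Mahler", "'s", "Symphony", "No", ".", "8", "was", "beautiful", "."]
# Force a sentence boundary inside the entity "Symphony No . 8"
sent_starts = [True, False, False, False, False, True, False, False, False]
doc = Doc(nlp.vocab, words=words, sent_starts=sent_starts)
doc.ents = [Span(doc, 2, 6, label="WORK")]  # tokens 2-5 cross the boundary
ent = doc.ents[0]
sentences = list(doc.sents)
# Span.ents only returns entities fully contained in the span, so the old
# per-sentence loop never visited this entity:
assert all(ent not in sent.ents for sent in sentences)
# ent.sent still resolves to the sentence holding the entity's first token
# (as the existing test_issue7065 asserts), which is what the fixed loop
# over doc.ents relies on:
assert sentences.index(ent.sent) == 0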
@@ -300,77 +300,77 @@ class EntityLinker(TrainablePipe):
         for i, doc in enumerate(docs):
             sentences = [s for s in doc.sents]
             if len(doc) > 0:
-                # Looping through each sentence and each entity
-                # This may go wrong if there are entities across sentences - which shouldn't happen normally.
-                for sent_index, sent in enumerate(sentences):
-                    if sent.ents:
-                        # get n_neightbour sentences, clipped to the length of the document
-                        start_sentence = max(0, sent_index - self.n_sents)
-                        end_sentence = min(
-                            len(sentences) - 1, sent_index + self.n_sents
-                        )
-                        start_token = sentences[start_sentence].start
-                        end_token = sentences[end_sentence].end
-                        sent_doc = doc[start_token:end_token].as_doc()
-                        # currently, the context is the same for each entity in a sentence (should be refined)
-                        xp = self.model.ops.xp
-                        if self.incl_context:
-                            sentence_encoding = self.model.predict([sent_doc])[0]
-                            sentence_encoding_t = sentence_encoding.T
-                            sentence_norm = xp.linalg.norm(sentence_encoding_t)
-                        for ent in sent.ents:
-                            entity_count += 1
-                            if ent.label_ in self.labels_discard:
-                                # ignoring this entity - setting to NIL
-                                final_kb_ids.append(self.NIL)
-                            else:
-                                candidates = self.get_candidates(self.kb, ent)
-                                if not candidates:
-                                    # no prediction possible for this entity - setting to NIL
-                                    final_kb_ids.append(self.NIL)
-                                elif len(candidates) == 1:
-                                    # shortcut for efficiency reasons: take the 1 candidate
-                                    # TODO: thresholding
-                                    final_kb_ids.append(candidates[0].entity_)
-                                else:
-                                    random.shuffle(candidates)
-                                    # set all prior probabilities to 0 if incl_prior=False
-                                    prior_probs = xp.asarray(
-                                        [c.prior_prob for c in candidates]
-                                    )
-                                    if not self.incl_prior:
-                                        prior_probs = xp.asarray(
-                                            [0.0 for _ in candidates]
-                                        )
-                                    scores = prior_probs
-                                    # add in similarity from the context
-                                    if self.incl_context:
-                                        entity_encodings = xp.asarray(
-                                            [c.entity_vector for c in candidates]
-                                        )
-                                        entity_norm = xp.linalg.norm(
-                                            entity_encodings, axis=1
-                                        )
-                                        if len(entity_encodings) != len(prior_probs):
-                                            raise RuntimeError(
-                                                Errors.E147.format(
-                                                    method="predict",
-                                                    msg="vectors not of equal length",
-                                                )
-                                            )
-                                        # cosine similarity
-                                        sims = xp.dot(
-                                            entity_encodings, sentence_encoding_t
-                                        ) / (sentence_norm * entity_norm)
-                                        if sims.shape != prior_probs.shape:
-                                            raise ValueError(Errors.E161)
-                                        scores = (
-                                            prior_probs + sims - (prior_probs * sims)
-                                        )
-                                        # TODO: thresholding
-                                        best_index = scores.argmax().item()
-                                        best_candidate = candidates[best_index]
-                                        final_kb_ids.append(best_candidate.entity_)
+                # Looping through each entity (TODO: rewrite)
+                for ent in doc.ents:
+                    sent = ent.sent
+                    sent_index = sentences.index(sent)
+                    assert sent_index >= 0
+                    # get n_neightbour sentences, clipped to the length of the document
+                    start_sentence = max(0, sent_index - self.n_sents)
+                    end_sentence = min(
+                        len(sentences) - 1, sent_index + self.n_sents
+                    )
+                    start_token = sentences[start_sentence].start
+                    end_token = sentences[end_sentence].end
+                    sent_doc = doc[start_token:end_token].as_doc()
+                    # currently, the context is the same for each entity in a sentence (should be refined)
+                    xp = self.model.ops.xp
+                    if self.incl_context:
+                        sentence_encoding = self.model.predict([sent_doc])[0]
+                        sentence_encoding_t = sentence_encoding.T
+                        sentence_norm = xp.linalg.norm(sentence_encoding_t)
+                    entity_count += 1
+                    if ent.label_ in self.labels_discard:
+                        # ignoring this entity - setting to NIL
+                        final_kb_ids.append(self.NIL)
+                    else:
+                        candidates = self.get_candidates(self.kb, ent)
+                        if not candidates:
+                            # no prediction possible for this entity - setting to NIL
+                            final_kb_ids.append(self.NIL)
+                        elif len(candidates) == 1:
+                            # shortcut for efficiency reasons: take the 1 candidate
+                            # TODO: thresholding
+                            final_kb_ids.append(candidates[0].entity_)
+                        else:
+                            random.shuffle(candidates)
+                            # set all prior probabilities to 0 if incl_prior=False
+                            prior_probs = xp.asarray(
+                                [c.prior_prob for c in candidates]
+                            )
+                            if not self.incl_prior:
+                                prior_probs = xp.asarray(
+                                    [0.0 for _ in candidates]
+                                )
+                            scores = prior_probs
+                            # add in similarity from the context
+                            if self.incl_context:
+                                entity_encodings = xp.asarray(
+                                    [c.entity_vector for c in candidates]
+                                )
+                                entity_norm = xp.linalg.norm(
+                                    entity_encodings, axis=1
+                                )
+                                if len(entity_encodings) != len(prior_probs):
+                                    raise RuntimeError(
+                                        Errors.E147.format(
+                                            method="predict",
+                                            msg="vectors not of equal length",
+                                        )
+                                    )
+                                # cosine similarity
+                                sims = xp.dot(
+                                    entity_encodings, sentence_encoding_t
+                                ) / (sentence_norm * entity_norm)
+                                if sims.shape != prior_probs.shape:
+                                    raise ValueError(Errors.E161)
+                                scores = (
+                                    prior_probs + sims - (prior_probs * sims)
+                                )
+                                # TODO: thresholding
+                                best_index = scores.argmax().item()
+                                best_candidate = candidates[best_index]
+                                final_kb_ids.append(best_candidate.entity_)
         if not (len(final_kb_ids) == entity_count):
             err = Errors.E147.format(
                 method="predict", msg="result variables not of equal length"
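An aside on the candidate scoring this hunk leaves unchanged: scores = prior_probs + sims - (prior_probs * sims) combines the KB prior probability with the context cosine similarity like a probabilistic OR, so a candidate ranks high when either signal is high. An illustrative numpy sketch (toy values, not spaCy API):

import numpy as np

prior_probs = np.array([0.7, 0.1])  # P(entity | alias) from the KB
entity_encodings = np.array([[6.0, -4.0, 3.0], [9.0, 1.0, -7.0]])
sentence_encoding = np.array([1.0, 0.5, -0.5])

# cosine similarity between each candidate vector and the sentence context
sims = entity_encodings @ sentence_encoding / (
    np.linalg.norm(sentence_encoding) * np.linalg.norm(entity_encodings, axis=1)
)
# probabilistic-OR combination: a strong prior or a strong context match wins
scores = prior_probs + sims - prior_probs * sims
best_index = scores.argmax().item()  # index of the winning candidate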
@@ -1,4 +1,6 @@
+from spacy.kb import KnowledgeBase
 from spacy.lang.en import English
+from spacy.training import Example
 
 
 def test_issue7065():
@@ -16,3 +18,58 @@ def test_issue7065():
     ent = doc.ents[0]
     assert ent.start < sent0.end < ent.end
     assert sentences.index(ent.sent) == 0
+
+
+def test_issue7065_b():
+    # Test that the NEL doesn't crash when an entity crosses a sentence boundary
+    nlp = English()
+    vector_length = 3
+    nlp.add_pipe("sentencizer")
+
+    text = "Mahler 's Symphony No. 8 was beautiful."
+    entities = [(0, 6, "PERSON"), (10, 24, "WORK")]
+    links = {(0, 6): {"Q7304": 1.0, "Q270853": 0.0},
+             (10, 24): {"Q7304": 0.0, "Q270853": 1.0}}
+    sent_starts = [1, -1, 0, 0, 0, 0, 0, 0, 0]
+    doc = nlp(text)
+    example = Example.from_dict(doc, {"entities": entities, "links": links, "sent_starts": sent_starts})
+    train_examples = [example]
+
+    def create_kb(vocab):
+        # create artificial KB
+        mykb = KnowledgeBase(vocab, entity_vector_length=vector_length)
+        mykb.add_entity(entity="Q270853", freq=12, entity_vector=[9, 1, -7])
+        mykb.add_alias(
+            alias="No. 8",
+            entities=["Q270853"],
+            probabilities=[1.0],
+        )
+        mykb.add_entity(entity="Q7304", freq=12, entity_vector=[6, -4, 3])
+        mykb.add_alias(
+            alias="Mahler",
+            entities=["Q7304"],
+            probabilities=[1.0],
+        )
+        return mykb
+
+    # Create the Entity Linker component and add it to the pipeline
+    entity_linker = nlp.add_pipe("entity_linker", last=True)
+    entity_linker.set_kb(create_kb)
+
+    # train the NEL pipe
+    optimizer = nlp.initialize(get_examples=lambda: train_examples)
+    for i in range(2):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
+
+    # Add a custom rule-based component to mimick NER
+    patterns = [
+        {"label": "PERSON", "pattern": [{"LOWER": "mahler"}]},
+        {"label": "WORK", "pattern": [{"LOWER": "symphony"}, {"LOWER": "no"}, {"LOWER": "."}, {"LOWER": "8"}]}
+    ]
+    ruler = nlp.add_pipe("entity_ruler", before="entity_linker")
+    ruler.add_patterns(patterns)
+
+    # test the trained model - this should not throw E148
+    doc = nlp(text)
+    assert doc