fix entity linker

svlandeg 2020-06-17 21:12:25 +02:00
parent be5934b827
commit 6d73e139b0
5 changed files with 42 additions and 59 deletions

View File

@@ -113,6 +113,8 @@ class Warnings(object):
"ignored during training.")
# TODO: fix numbering after merging develop into master
W093 = ("Could not find any data to train the {name} on. Is your "
"input data correctly formatted ?")
W094 = ("Model '{model}' ({model_version}) specifies an under-constrained "
"spaCy version requirement: {version}. This can lead to compatibility "
"problems with older versions, or as new spaCy versions are "
@@ -556,9 +558,6 @@ class Errors(object):
"{obj}.{attr}\nAttribute '{attr}' does not exist on {obj}.")
E186 = ("'{tok_a}' and '{tok_b}' are different texts.")
E187 = ("Only unicode strings are supported as labels.")
E188 = ("Could not match the gold entity links to entities in the doc - "
"make sure the gold EL data refers to valid results of the "
"named entity recognizer in the `nlp` pipeline.")
E189 = ("Each argument to `get_doc` should be of equal length.")
E190 = ("Token head out of range in `Doc.from_array()` for token index "
"'{index}' with value '{value}' (equivalent to relative head "

View File

@@ -109,7 +109,7 @@ cdef class Example:
prev_j = -1
prev_value = value
if field in ["ENT_IOB", "ENT_TYPE"]:
if field in ["ENT_IOB", "ENT_TYPE", "ENT_KB_ID"]:
# Assign one-to-many NER tags
for j, cand_j in enumerate(gold_to_cand):
if cand_j is None:
@@ -196,7 +196,7 @@ def _annot2array(vocab, tok_annot, doc_annot):
entities = doc_annot.get("entities", {})
if value and not entities:
raise ValueError(Errors.E981)
ent_kb_ids = _parse_links(vocab, words, value, entities)
ent_kb_ids = _parse_links(vocab, tok_annot["ORTH"], value, entities)
tok_annot["ENT_KB_ID"] = ent_kb_ids
elif key == "cats":
pass
@@ -308,7 +308,6 @@ def _parse_ner_tags(biluo_or_offsets, vocab, words, spaces):
def _parse_links(vocab, words, links, entities):
reference = Doc(vocab, words=words)
starts = {token.idx: token.i for token in reference}
ends = {token.idx + len(token): token.i for token in reference}
ent_kb_ids = ["" for _ in reference]

View File

@@ -2,7 +2,6 @@
import numpy
import srsly
import random
from ast import literal_eval
from thinc.api import CosineDistance, to_categorical, get_array_module
from thinc.api import set_dropout_rate, SequenceCategoricalCrossentropy
@@ -92,15 +91,6 @@ class Pipe(object):
"""Modify a batch of documents, using pre-computed scores."""
raise NotImplementedError
def update(self, docs, set_annotations=False, drop=0.0, sgd=None, losses=None):
"""Learn from a batch of documents and gold-standard information,
updating the pipe's model.
Delegates to predict() and get_loss().
"""
if set_annotations:
docs = list(self.pipe(docs))
def rehearse(self, examples, sgd=None, losses=None, **config):
pass
@@ -1094,31 +1084,19 @@ class EntityLinker(Pipe):
predictions = self.model.predict(docs)
for eg in examples:
doc = eg.predicted
ents_by_offset = dict()
for ent in doc.ents:
ents_by_offset[(ent.start_char, ent.end_char)] = ent
links = self._get_links_from_doc(eg.reference)
for entity, kb_dict in links.items():
if isinstance(entity, str):
entity = literal_eval(entity)
start, end = entity
mention = doc.text[start:end]
# the gold annotations should link to proper entities - if this fails, the dataset is likely corrupt
if not (start, end) in ents_by_offset:
raise RuntimeError(Errors.E188)
ent = ents_by_offset[(start, end)]
for kb_id, value in kb_dict.items():
# Currently only training on the positive instances - we assume there is at least 1 per doc/gold
if value:
try:
sentence_docs.append(ent.sent.as_doc())
except AttributeError:
# Catch the exception when ent.sent is None and provide a user-friendly warning
raise RuntimeError(Errors.E030)
kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
for ent in eg.doc.ents:
kb_id = kb_ids[ent.start] # KB ID of the first token is the same as the whole span
if kb_id:
try:
sentence_docs.append(ent.sent.as_doc())
except AttributeError:
# Catch the exception when ent.sent is None and provide a user-friendly warning
raise RuntimeError(Errors.E030)
set_dropout_rate(self.model, drop)
if not sentence_docs:
warnings.warn(Warnings.W093.format(name="Entity Linker"))
return 0.0
sentence_encodings, bp_context = self.model.begin_update(sentence_docs)
loss, d_scores = self.get_similarity_loss(
scores=sentence_encodings,
@@ -1137,13 +1115,12 @@ class EntityLinker(Pipe):
def get_similarity_loss(self, examples, scores):
entity_encodings = []
for eg in examples:
links = self._get_links_from_doc(eg.reference)
for entity, kb_dict in links.items():
for kb_id, value in kb_dict.items():
# this loss function assumes we're only using positive examples
if value:
entity_encoding = self.kb.get_vector(kb_id)
entity_encodings.append(entity_encoding)
kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
for ent in eg.doc.ents:
kb_id = kb_ids[ent.start]
if kb_id:
entity_encoding = self.kb.get_vector(kb_id)
entity_encodings.append(entity_encoding)
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
@@ -1158,10 +1135,11 @@ class EntityLinker(Pipe):
def get_loss(self, examples, scores):
cats = []
for eg in examples:
links = self._get_links_from_doc(eg.reference)
for entity, kb_dict in links.items():
for kb_id, value in kb_dict.items():
cats.append([value])
kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
for ent in eg.doc.ents:
kb_id = kb_ids[ent.start]
if kb_id:
cats.append([1.0])
cats = self.model.ops.asarray(cats, dtype="float32")
if len(scores) != len(cats):
@@ -1172,9 +1150,6 @@ class EntityLinker(Pipe):
loss = loss / len(cats)
return loss, d_scores
def _get_links_from_doc(self, doc):
return {}
def __call__(self, doc):
kb_ids, tensors = self.predict([doc])
self.set_annotations([doc], kb_ids, tensors=tensors)
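Note (not part of the diff): across the hunks above, update(), get_similarity_loss() and get_loss() stop parsing the raw gold "links" dict (the removed _get_links_from_doc / literal_eval path) and instead read KB IDs that Example has aligned to the predicted tokenization, taking the first token's ID for the whole span. The following is a hedged end-to-end sketch of that lookup, assuming the Example API as used elsewhere in this diff; the spacy.gold import path and the hand-built PERSON span are assumptions for illustration.

from spacy.gold import Example  # import path assumed for this development branch
from spacy.tokens import Doc, Span
from spacy.vocab import Vocab

vocab = Vocab()
words = ["Russ", "Cochran", "publishes", "comic", "art", "."]
predicted = Doc(vocab, words=words)
# Pretend the NER component already marked the span, as update() expects.
predicted.ents = [Span(predicted, 0, 2, label="PERSON")]
annots = {
    "entities": [(0, 12, "PERSON")],
    "links": {(0, 12): {"Q7381115": 1.0, "Q2146908": 0.0}},
}
eg = Example.from_dict(predicted, annots)
kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
for ent in eg.doc.ents:
    # The KB ID of the first token stands in for the whole span, as in update().
    print(ent.text, "->", kb_ids[ent.start])
# Russ Cochran -> Q7381115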

View File

@@ -252,10 +252,18 @@ def test_preserving_links_ents_2(nlp):
# fmt: off
TRAIN_DATA = [
("Russ Cochran captured his first major title with his son as caddie.", {"links": {(0, 12): {"Q7381115": 0.0, "Q2146908": 1.0}}}),
("Russ Cochran his reprints include EC Comics.", {"links": {(0, 12): {"Q7381115": 1.0, "Q2146908": 0.0}}}),
("Russ Cochran has been publishing comic art.", {"links": {(0, 12): {"Q7381115": 1.0, "Q2146908": 0.0}}}),
("Russ Cochran was a member of University of Kentucky's golf team.", {"links": {(0, 12): {"Q7381115": 0.0, "Q2146908": 1.0}}}),
("Russ Cochran captured his first major title with his son as caddie.",
{"links": {(0, 12): {"Q7381115": 0.0, "Q2146908": 1.0}},
"entities": [(0, 12, "PERSON")]}),
("Russ Cochran his reprints include EC Comics.",
{"links": {(0, 12): {"Q7381115": 1.0, "Q2146908": 0.0}},
"entities": [(0, 12, "PERSON")]}),
("Russ Cochran has been publishing comic art.",
{"links": {(0, 12): {"Q7381115": 1.0, "Q2146908": 0.0}},
"entities": [(0, 12, "PERSON")]}),
("Russ Cochran was a member of University of Kentucky's golf team.",
{"links": {(0, 12): {"Q7381115": 0.0, "Q2146908": 1.0}},
"entities": [(0, 12, "PERSON")]}),
]
GOLD_entities = ["Q2146908", "Q7381115", "Q7381115", "Q2146908"]
# fmt: on
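Note (not part of the diff): converting "links" without "entities" raises Errors.E981 in the _annot2array hunk above, which is why each training entry now carries both. A hedged sketch, assuming the Example API from this branch, of one typical way such entries can be turned into Example objects before updating the pipeline; the English() instance stands in for the test's nlp fixture.

from spacy.lang.en import English
from spacy.gold import Example  # import path assumed for this development branch

nlp = English()  # stand-in for the test's nlp fixture
train_examples = []
for text, annotation in TRAIN_DATA:
    # The predicted-side tokenization comes from the pipeline; the gold
    # "links" and "entities" are aligned against it by Example.from_dict.
    train_examples.append(Example.from_dict(nlp.make_doc(text), annotation))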

View File

@@ -176,10 +176,12 @@ def test_gold_biluo_different_tokenization(en_vocab, en_tokenizer):
spaces = [True, True, True, False, False]
doc = Doc(en_vocab, words=words, spaces=spaces)
entities = [(len("I flew to "), len("I flew to San Francisco Valley"), "LOC")]
links = {(len("I flew to "), len("I flew to San Francisco Valley")): {"Q816843": 1.0}}
gold_words = ["I", "flew to", "San", "Francisco Valley", "."]
example = Example.from_dict(doc, {"words": gold_words, "entities": entities})
example = Example.from_dict(doc, {"words": gold_words, "entities": entities, "links": links})
assert example.get_aligned("ENT_IOB") == [2, 2, 3, 1, 2]
assert example.get_aligned("ENT_TYPE", as_string=True) == ["", "", "LOC", "LOC", ""]
assert example.get_aligned("ENT_KB_ID", as_string=True) == ["", "", "Q816843", "Q816843", ""]
# additional whitespace tokens in GoldParse words
words, spaces = get_words_and_spaces(