Commit: 6d73e139b0 ("fix entity linker")
Parent: be5934b827
Repository: https://github.com/explosion/spaCy.git
@@ -113,6 +113,8 @@ class Warnings(object):
             "ignored during training.")

     # TODO: fix numbering after merging develop into master
+    W093 = ("Could not find any data to train the {name} on. Is your "
+            "input data correctly formatted ?")
     W094 = ("Model '{model}' ({model_version}) specifies an under-constrained "
             "spaCy version requirement: {version}. This can lead to compatibility "
             "problems with older versions, or as new spaCy versions are "
@@ -556,9 +558,6 @@ class Errors(object):
             "{obj}.{attr}\nAttribute '{attr}' does not exist on {obj}.")
     E186 = ("'{tok_a}' and '{tok_b}' are different texts.")
     E187 = ("Only unicode strings are supported as labels.")
-    E188 = ("Could not match the gold entity links to entities in the doc - "
-            "make sure the gold EL data refers to valid results of the "
-            "named entity recognizer in the `nlp` pipeline.")
     E189 = ("Each argument to `get_doc` should be of equal length.")
     E190 = ("Token head out of range in `Doc.from_array()` for token index "
            "'{index}' with value '{value}' (equivalent to relative head "
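These catalog entries are plain format strings; as a minimal sketch, this is how the new W093 gets emitted (mirroring the warnings.warn call added to EntityLinker.update further down; only the stdlib warnings module is assumed):

import warnings

# The message added above, with its {name} placeholder:
W093 = ("Could not find any data to train the {name} on. Is your "
        "input data correctly formatted ?")

# Callers fill the placeholder at emit time, as EntityLinker.update does:
warnings.warn(W093.format(name="Entity Linker"))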
@@ -109,7 +109,7 @@ cdef class Example:
                 prev_j = -1
                 prev_value = value

-        if field in ["ENT_IOB", "ENT_TYPE"]:
+        if field in ["ENT_IOB", "ENT_TYPE", "ENT_KB_ID"]:
             # Assign one-to-many NER tags
             for j, cand_j in enumerate(gold_to_cand):
                 if cand_j is None:
@@ -196,7 +196,7 @@ def _annot2array(vocab, tok_annot, doc_annot):
                 entities = doc_annot.get("entities", {})
                 if value and not entities:
                     raise ValueError(Errors.E981)
-                ent_kb_ids = _parse_links(vocab, words, value, entities)
+                ent_kb_ids = _parse_links(vocab, tok_annot["ORTH"], value, entities)
                 tok_annot["ENT_KB_ID"] = ent_kb_ids
             elif key == "cats":
                 pass
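For reference, the shape of the annotation dict this branch consumes, using the offsets from the test data later in this commit; the "links" values are gold probabilities per candidate KB ID, and "entities" must be present alongside them or E981 is raised:

doc_annot = {
    # (start_char, end_char) -> {kb_id: gold probability}
    "links": {(0, 12): {"Q7381115": 0.0, "Q2146908": 1.0}},
    # character-offset NER spans; required alongside "links"
    "entities": [(0, 12, "PERSON")],
}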
@@ -308,7 +308,6 @@ def _parse_ner_tags(biluo_or_offsets, vocab, words, spaces):

 def _parse_links(vocab, words, links, entities):
     reference = Doc(vocab, words=words)
-
     starts = {token.idx: token.i for token in reference}
     ends = {token.idx + len(token): token.i for token in reference}
     ent_kb_ids = ["" for _ in reference]
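The body of _parse_links is truncated above; what follows is a rough standalone sketch of the conversion it performs, assuming the elided part keeps only positively labelled KB IDs and writes them onto every token the span covers (the starts/ends dicts are copied from the visible lines, the function name is hypothetical):

from spacy.tokens import Doc

def parse_links_sketch(vocab, words, links):
    reference = Doc(vocab, words=words)
    # Map character offsets onto token indices of the reference tokenization.
    starts = {token.idx: token.i for token in reference}
    ends = {token.idx + len(token): token.i for token in reference}
    ent_kb_ids = ["" for _ in reference]
    for (start_char, end_char), kb_dict in links.items():
        if start_char in starts and end_char in ends:
            for kb_id, value in kb_dict.items():
                if value:
                    # Tag every token inside the span with the gold KB ID.
                    for i in range(starts[start_char], ends[end_char] + 1):
                        ent_kb_ids[i] = kb_id
    return ent_kb_ids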
@@ -2,7 +2,6 @@
 import numpy
 import srsly
 import random
-from ast import literal_eval

 from thinc.api import CosineDistance, to_categorical, get_array_module
 from thinc.api import set_dropout_rate, SequenceCategoricalCrossentropy
@@ -92,15 +91,6 @@ class Pipe(object):
         """Modify a batch of documents, using pre-computed scores."""
         raise NotImplementedError

-    def update(self, docs, set_annotations=False, drop=0.0, sgd=None, losses=None):
-        """Learn from a batch of documents and gold-standard information,
-        updating the pipe's model.
-
-        Delegates to predict() and get_loss().
-        """
-        if set_annotations:
-            docs = list(self.pipe(docs))
-
     def rehearse(self, examples, sgd=None, losses=None, **config):
         pass

@@ -1094,31 +1084,19 @@ class EntityLinker(Pipe):
         predictions = self.model.predict(docs)

         for eg in examples:
-            doc = eg.predicted
-            ents_by_offset = dict()
-            for ent in doc.ents:
-                ents_by_offset[(ent.start_char, ent.end_char)] = ent
-            links = self._get_links_from_doc(eg.reference)
-            for entity, kb_dict in links.items():
-                if isinstance(entity, str):
-                    entity = literal_eval(entity)
-                start, end = entity
-                mention = doc.text[start:end]
-
-                # the gold annotations should link to proper entities - if this fails, the dataset is likely corrupt
-                if not (start, end) in ents_by_offset:
-                    raise RuntimeError(Errors.E188)
-                ent = ents_by_offset[(start, end)]
-
-                for kb_id, value in kb_dict.items():
-                    # Currently only training on the positive instances - we assume there is at least 1 per doc/gold
-                    if value:
-                        try:
-                            sentence_docs.append(ent.sent.as_doc())
-                        except AttributeError:
-                            # Catch the exception when ent.sent is None and provide a user-friendly warning
-                            raise RuntimeError(Errors.E030)
+            kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
+            for ent in eg.doc.ents:
+                kb_id = kb_ids[ent.start]  # KB ID of the first token is the same as the whole span
+                if kb_id:
+                    try:
+                        sentence_docs.append(ent.sent.as_doc())
+                    except AttributeError:
+                        # Catch the exception when ent.sent is None and provide a user-friendly warning
+                        raise RuntimeError(Errors.E030)
         set_dropout_rate(self.model, drop)
+        if not sentence_docs:
+            warnings.warn(Warnings.W093.format(name="Entity Linker"))
+            return 0.0
         sentence_encodings, bp_context = self.model.begin_update(sentence_docs)
         loss, d_scores = self.get_similarity_loss(
             scores=sentence_encodings,
@@ -1137,13 +1115,12 @@ class EntityLinker(Pipe):
     def get_similarity_loss(self, examples, scores):
         entity_encodings = []
         for eg in examples:
-            links = self._get_links_from_doc(eg.reference)
-            for entity, kb_dict in links.items():
-                for kb_id, value in kb_dict.items():
-                    # this loss function assumes we're only using positive examples
-                    if value:
-                        entity_encoding = self.kb.get_vector(kb_id)
-                        entity_encodings.append(entity_encoding)
+            kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
+            for ent in eg.doc.ents:
+                kb_id = kb_ids[ent.start]
+                if kb_id:
+                    entity_encoding = self.kb.get_vector(kb_id)
+                    entity_encodings.append(entity_encoding)

         entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")

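The tail of get_similarity_loss is cut off above; presumably it pairs each sentence encoding with the KB vector of its gold KB ID and differentiates a cosine distance, per the CosineDistance import at the top of this file. A sketch under that assumption (the function name and the normalize choice are guesses; get_grad/get_loss are thinc's Loss API):

from thinc.api import CosineDistance

def similarity_loss_sketch(sentence_encodings, entity_encodings):
    # Row i of each array belongs to the same gold-linked entity.
    distance = CosineDistance(normalize=False)
    d_scores = distance.get_grad(sentence_encodings, entity_encodings)
    loss = distance.get_loss(sentence_encodings, entity_encodings)
    return loss / len(entity_encodings), d_scores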
@@ -1158,10 +1135,11 @@ class EntityLinker(Pipe):
     def get_loss(self, examples, scores):
         cats = []
         for eg in examples:
-            links = self._get_links_from_doc(eg.reference)
-            for entity, kb_dict in links.items():
-                for kb_id, value in kb_dict.items():
-                    cats.append([value])
+            kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
+            for ent in eg.doc.ents:
+                kb_id = kb_ids[ent.start]
+                if kb_id:
+                    cats.append([1.0])

         cats = self.model.ops.asarray(cats, dtype="float32")
         if len(scores) != len(cats):
@@ -1172,9 +1150,6 @@ class EntityLinker(Pipe):
         loss = loss / len(cats)
         return loss, d_scores

-    def _get_links_from_doc(self, doc):
-        return {}
-
     def __call__(self, doc):
         kb_ids, tensors = self.predict([doc])
         self.set_annotations([doc], kb_ids, tensors=tensors)
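All three rewritten methods share one lookup pattern, distilled here as a standalone sketch (the function name is invented; get_aligned and eg.doc.ents are the real calls used above):

def iter_gold_linked_ents(eg):
    # Gold KB IDs projected onto the predicted tokenization.
    kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
    for ent in eg.doc.ents:
        kb_id = kb_ids[ent.start]  # the first token carries the span's KB ID
        if kb_id:
            yield ent, kb_id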
@@ -252,10 +252,18 @@ def test_preserving_links_ents_2(nlp):

 # fmt: off
 TRAIN_DATA = [
-    ("Russ Cochran captured his first major title with his son as caddie.", {"links": {(0, 12): {"Q7381115": 0.0, "Q2146908": 1.0}}}),
-    ("Russ Cochran his reprints include EC Comics.", {"links": {(0, 12): {"Q7381115": 1.0, "Q2146908": 0.0}}}),
-    ("Russ Cochran has been publishing comic art.", {"links": {(0, 12): {"Q7381115": 1.0, "Q2146908": 0.0}}}),
-    ("Russ Cochran was a member of University of Kentucky's golf team.", {"links": {(0, 12): {"Q7381115": 0.0, "Q2146908": 1.0}}}),
+    ("Russ Cochran captured his first major title with his son as caddie.",
+     {"links": {(0, 12): {"Q7381115": 0.0, "Q2146908": 1.0}},
+      "entities": [(0, 12, "PERSON")]}),
+    ("Russ Cochran his reprints include EC Comics.",
+     {"links": {(0, 12): {"Q7381115": 1.0, "Q2146908": 0.0}},
+      "entities": [(0, 12, "PERSON")]}),
+    ("Russ Cochran has been publishing comic art.",
+     {"links": {(0, 12): {"Q7381115": 1.0, "Q2146908": 0.0}},
+      "entities": [(0, 12, "PERSON")]}),
+    ("Russ Cochran was a member of University of Kentucky's golf team.",
+     {"links": {(0, 12): {"Q7381115": 0.0, "Q2146908": 1.0}},
+      "entities": [(0, 12, "PERSON")]}),
 ]
 GOLD_entities = ["Q2146908", "Q7381115", "Q7381115", "Q2146908"]
 # fmt: on
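Roughly how data in this shape becomes Example objects for training; a sketch assuming a pipeline object `nlp` with an entity linker, and the Example import path as of this branch (make_doc and Example.from_dict are the standard calls):

from spacy.gold import Example

train_examples = []
for text, annotation in TRAIN_DATA:
    # Tokenize the raw text, then attach the gold links/entities to it.
    doc = nlp.make_doc(text)
    train_examples.append(Example.from_dict(doc, annotation))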
@@ -176,10 +176,12 @@ def test_gold_biluo_different_tokenization(en_vocab, en_tokenizer):
     spaces = [True, True, True, False, False]
     doc = Doc(en_vocab, words=words, spaces=spaces)
     entities = [(len("I flew to "), len("I flew to San Francisco Valley"), "LOC")]
+    links = {(len("I flew to "), len("I flew to San Francisco Valley")): {"Q816843": 1.0}}
     gold_words = ["I", "flew to", "San", "Francisco Valley", "."]
-    example = Example.from_dict(doc, {"words": gold_words, "entities": entities})
+    example = Example.from_dict(doc, {"words": gold_words, "entities": entities, "links": links})
     assert example.get_aligned("ENT_IOB") == [2, 2, 3, 1, 2]
     assert example.get_aligned("ENT_TYPE", as_string=True) == ["", "", "LOC", "LOC", ""]
+    assert example.get_aligned("ENT_KB_ID", as_string=True) == ["", "", "Q816843", "Q816843", ""]

     # additional whitespace tokens in GoldParse words
     words, spaces = get_words_and_spaces(
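At inference time the linked IDs surface on the entity spans themselves; a small usage sketch assuming `nlp` contains a trained entity linker:

doc = nlp("Russ Cochran captured his first major title with his son as caddie.")
for ent in doc.ents:
    # kb_id_ holds the resolved knowledge-base identifier, e.g. "Q2146908"
    print(ent.text, ent.label_, ent.kb_id_)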