Mirror of https://github.com/explosion/spaCy.git
commit 6d73e139b0
parent be5934b827

    fix entity linker
@@ -113,6 +113,8 @@ class Warnings(object):
             "ignored during training.")
 
     # TODO: fix numbering after merging develop into master
+    W093 = ("Could not find any data to train the {name} on. Is your "
+            "input data correctly formatted ?")
     W094 = ("Model '{model}' ({model_version}) specifies an under-constrained "
             "spaCy version requirement: {version}. This can lead to compatibility "
             "problems with older versions, or as new spaCy versions are "
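Note: the new W093 warning is consumed later in this same commit by `EntityLinker.update()`. A minimal sketch of how it fires (the `spacy.errors` module path is an assumption inferred from the hunk header above):

```python
import warnings

from spacy.errors import Warnings  # module path assumed; the hunk header shows class Warnings(object)

# Mirrors the call added to EntityLinker.update() further down in this diff:
warnings.warn(Warnings.W093.format(name="Entity Linker"))
```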
@@ -556,9 +558,6 @@ class Errors(object):
             "{obj}.{attr}\nAttribute '{attr}' does not exist on {obj}.")
     E186 = ("'{tok_a}' and '{tok_b}' are different texts.")
     E187 = ("Only unicode strings are supported as labels.")
-    E188 = ("Could not match the gold entity links to entities in the doc - "
-            "make sure the gold EL data refers to valid results of the "
-            "named entity recognizer in the `nlp` pipeline.")
     E189 = ("Each argument to `get_doc` should be of equal length.")
     E190 = ("Token head out of range in `Doc.from_array()` for token index "
             "'{index}' with value '{value}' (equivalent to relative head "
@@ -109,7 +109,7 @@ cdef class Example:
                 prev_j = -1
             prev_value = value
 
-        if field in ["ENT_IOB", "ENT_TYPE"]:
+        if field in ["ENT_IOB", "ENT_TYPE", "ENT_KB_ID"]:
             # Assign one-to-many NER tags
             for j, cand_j in enumerate(gold_to_cand):
                 if cand_j is None:
@@ -196,7 +196,7 @@ def _annot2array(vocab, tok_annot, doc_annot):
             entities = doc_annot.get("entities", {})
             if value and not entities:
                 raise ValueError(Errors.E981)
-            ent_kb_ids = _parse_links(vocab, words, value, entities)
+            ent_kb_ids = _parse_links(vocab, tok_annot["ORTH"], value, entities)
             tok_annot["ENT_KB_ID"] = ent_kb_ids
         elif key == "cats":
             pass
@@ -308,7 +308,6 @@ def _parse_ner_tags(biluo_or_offsets, vocab, words, spaces):
 
 def _parse_links(vocab, words, links, entities):
     reference = Doc(vocab, words=words)
-
     starts = {token.idx: token.i for token in reference}
     ends = {token.idx + len(token): token.i for token in reference}
     ent_kb_ids = ["" for _ in reference]
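Note: the gist of `_parse_links` is to project character-offset link annotations onto per-token `ENT_KB_ID` values, which is what makes the `get_aligned("ENT_KB_ID")` calls below possible. A minimal pure-Python sketch of that projection, using plain lists instead of a `Doc` and assuming offsets that fall exactly on token boundaries (all names here are illustrative, not spaCy API):

```python
# Sketch: map {(start_char, end_char): {kb_id: prob}} onto per-token KB IDs,
# assuming tokens are joined by single spaces and offsets match token bounds.
def project_links(words, links):
    starts, ends, idx = {}, {}, 0
    for i, word in enumerate(words):
        starts[idx] = i                 # char offset where token i starts
        ends[idx + len(word)] = i       # char offset where token i ends
        idx += len(word) + 1            # +1 for the joining space
    ent_kb_ids = ["" for _ in words]
    for (start_char, end_char), kb_dict in links.items():
        positives = [kb for kb, prob in kb_dict.items() if prob == 1.0]
        if positives and start_char in starts and end_char in ends:
            for i in range(starts[start_char], ends[end_char] + 1):
                ent_kb_ids[i] = positives[0]
    return ent_kb_ids

words = ["I", "flew", "to", "San", "Francisco", "Valley"]
links = {(10, 30): {"Q816843": 1.0}}
assert project_links(words, links) == ["", "", "", "Q816843", "Q816843", "Q816843"]
```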
@@ -2,7 +2,6 @@
 import numpy
 import srsly
 import random
-from ast import literal_eval
 
 from thinc.api import CosineDistance, to_categorical, get_array_module
 from thinc.api import set_dropout_rate, SequenceCategoricalCrossentropy
@@ -92,15 +91,6 @@ class Pipe(object):
         """Modify a batch of documents, using pre-computed scores."""
         raise NotImplementedError
 
-    def update(self, docs, set_annotations=False, drop=0.0, sgd=None, losses=None):
-        """Learn from a batch of documents and gold-standard information,
-        updating the pipe's model.
-
-        Delegates to predict() and get_loss().
-        """
-        if set_annotations:
-            docs = list(self.pipe(docs))
-
     def rehearse(self, examples, sgd=None, losses=None, **config):
         pass
 
@@ -1094,31 +1084,19 @@ class EntityLinker(Pipe):
         predictions = self.model.predict(docs)
 
         for eg in examples:
-            doc = eg.predicted
-            ents_by_offset = dict()
-            for ent in doc.ents:
-                ents_by_offset[(ent.start_char, ent.end_char)] = ent
-            links = self._get_links_from_doc(eg.reference)
-            for entity, kb_dict in links.items():
-                if isinstance(entity, str):
-                    entity = literal_eval(entity)
-                start, end = entity
-                mention = doc.text[start:end]
-
-                # the gold annotations should link to proper entities - if this fails, the dataset is likely corrupt
-                if not (start, end) in ents_by_offset:
-                    raise RuntimeError(Errors.E188)
-                ent = ents_by_offset[(start, end)]
-
-                for kb_id, value in kb_dict.items():
-                    # Currently only training on the positive instances - we assume there is at least 1 per doc/gold
-                    if value:
-                        try:
-                            sentence_docs.append(ent.sent.as_doc())
-                        except AttributeError:
-                            # Catch the exception when ent.sent is None and provide a user-friendly warning
-                            raise RuntimeError(Errors.E030)
+            kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
+            for ent in eg.doc.ents:
+                kb_id = kb_ids[ent.start]  # KB ID of the first token is the same as the whole span
+                if kb_id:
+                    try:
+                        sentence_docs.append(ent.sent.as_doc())
+                    except AttributeError:
+                        # Catch the exception when ent.sent is None and provide a user-friendly warning
+                        raise RuntimeError(Errors.E030)
         set_dropout_rate(self.model, drop)
+        if not sentence_docs:
+            warnings.warn(Warnings.W093.format(name="Entity Linker"))
+            return 0.0
         sentence_encodings, bp_context = self.model.begin_update(sentence_docs)
         loss, d_scores = self.get_similarity_loss(
             scores=sentence_encodings,
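Note: the rewrite above replaces the offset-dictionary bookkeeping (and the E188 consistency check) with a direct read of the aligned per-token KB IDs. A self-contained sketch of the new selection logic, using made-up stand-in values rather than real `Example`/`Doc` objects:

```python
# Stand-ins: aligned gold KB IDs per token, plus the first-token index of each
# predicted entity span (illustrative values, not spaCy API).
kb_ids = ["", "", "Q816843", "Q816843", ""]   # as from eg.get_aligned("ENT_KB_ID", as_string=True)
ent_starts = [2]                              # ent.start for each ent in eg.doc.ents

positive = []
for start in ent_starts:
    kb_id = kb_ids[start]   # the first token's KB ID stands in for the whole span
    if kb_id:               # "" means no gold link: negatives are skipped
        positive.append(kb_id)
assert positive == ["Q816843"]
```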
@@ -1137,11 +1115,10 @@ class EntityLinker(Pipe):
     def get_similarity_loss(self, examples, scores):
         entity_encodings = []
         for eg in examples:
-            links = self._get_links_from_doc(eg.reference)
-            for entity, kb_dict in links.items():
-                for kb_id, value in kb_dict.items():
-                    # this loss function assumes we're only using positive examples
-                    if value:
-                        entity_encoding = self.kb.get_vector(kb_id)
-                        entity_encodings.append(entity_encoding)
+            kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
+            for ent in eg.doc.ents:
+                kb_id = kb_ids[ent.start]
+                if kb_id:
+                    entity_encoding = self.kb.get_vector(kb_id)
+                    entity_encodings.append(entity_encoding)
 
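Note: `get_similarity_loss` pairs each collected sentence encoding with the KB vector of its gold entity. A rough sketch of the distance computation, assuming thinc's `CosineDistance` loss (imported at the top of this file) and made-up 64-dimensional encodings; the constructor arguments are an assumption:

```python
import numpy
from thinc.api import CosineDistance

distance = CosineDistance(normalize=False)  # constructor arguments assumed
# One row per positive gold instance (made-up data):
sentence_encodings = numpy.random.rand(4, 64).astype("float32")  # model output
entity_encodings = numpy.random.rand(4, 64).astype("float32")    # kb.get_vector(...) per kb_id

d_scores = distance.get_grad(sentence_encodings, entity_encodings)
loss = distance.get_loss(sentence_encodings, entity_encodings)
print(float(loss), d_scores.shape)
```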
@@ -1158,10 +1135,11 @@ class EntityLinker(Pipe):
     def get_loss(self, examples, scores):
         cats = []
         for eg in examples:
-            links = self._get_links_from_doc(eg.reference)
-            for entity, kb_dict in links.items():
-                for kb_id, value in kb_dict.items():
-                    cats.append([value])
+            kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
+            for ent in eg.doc.ents:
+                kb_id = kb_ids[ent.start]
+                if kb_id:
+                    cats.append([1.0])
 
         cats = self.model.ops.asarray(cats, dtype="float32")
         if len(scores) != len(cats):
@@ -1172,9 +1150,6 @@ class EntityLinker(Pipe):
         loss = loss / len(cats)
         return loss, d_scores
 
-    def _get_links_from_doc(self, doc):
-        return {}
-
     def __call__(self, doc):
         kb_ids, tensors = self.predict([doc])
         self.set_annotations([doc], kb_ids, tensors=tensors)
@@ -252,10 +252,18 @@ def test_preserving_links_ents_2(nlp):
 
 # fmt: off
 TRAIN_DATA = [
-    ("Russ Cochran captured his first major title with his son as caddie.", {"links": {(0, 12): {"Q7381115": 0.0, "Q2146908": 1.0}}}),
-    ("Russ Cochran his reprints include EC Comics.", {"links": {(0, 12): {"Q7381115": 1.0, "Q2146908": 0.0}}}),
-    ("Russ Cochran has been publishing comic art.", {"links": {(0, 12): {"Q7381115": 1.0, "Q2146908": 0.0}}}),
-    ("Russ Cochran was a member of University of Kentucky's golf team.", {"links": {(0, 12): {"Q7381115": 0.0, "Q2146908": 1.0}}}),
+    ("Russ Cochran captured his first major title with his son as caddie.",
+        {"links": {(0, 12): {"Q7381115": 0.0, "Q2146908": 1.0}},
+         "entities": [(0, 12, "PERSON")]}),
+    ("Russ Cochran his reprints include EC Comics.",
+        {"links": {(0, 12): {"Q7381115": 1.0, "Q2146908": 0.0}},
+         "entities": [(0, 12, "PERSON")]}),
+    ("Russ Cochran has been publishing comic art.",
+        {"links": {(0, 12): {"Q7381115": 1.0, "Q2146908": 0.0}},
+         "entities": [(0, 12, "PERSON")]}),
+    ("Russ Cochran was a member of University of Kentucky's golf team.",
+        {"links": {(0, 12): {"Q7381115": 0.0, "Q2146908": 1.0}},
+         "entities": [(0, 12, "PERSON")]}),
 ]
 GOLD_entities = ["Q2146908", "Q7381115", "Q7381115", "Q2146908"]
 # fmt: on
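Note: each TRAIN_DATA entry now carries both the gold link probabilities and the NER span, so the gold entities no longer have to be recovered from a separately run NER component. A hedged sketch of turning one entry into a training `Example` (assumes an existing `nlp` pipeline as in these tests; the import path is an assumption for this develop-era branch, and the `Example.from_dict` usage mirrors the test change below):

```python
from spacy.gold import Example  # import path assumed for this branch

text, annotations = TRAIN_DATA[0]
doc = nlp.make_doc(text)        # assumes an `nlp` object, as in these tests
example = Example.from_dict(doc, annotations)
```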
@@ -176,10 +176,12 @@ def test_gold_biluo_different_tokenization(en_vocab, en_tokenizer):
     spaces = [True, True, True, False, False]
     doc = Doc(en_vocab, words=words, spaces=spaces)
     entities = [(len("I flew to "), len("I flew to San Francisco Valley"), "LOC")]
+    links = {(len("I flew to "), len("I flew to San Francisco Valley")): {"Q816843": 1.0}}
     gold_words = ["I", "flew to", "San", "Francisco Valley", "."]
-    example = Example.from_dict(doc, {"words": gold_words, "entities": entities})
+    example = Example.from_dict(doc, {"words": gold_words, "entities": entities, "links": links})
     assert example.get_aligned("ENT_IOB") == [2, 2, 3, 1, 2]
     assert example.get_aligned("ENT_TYPE", as_string=True) == ["", "", "LOC", "LOC", ""]
+    assert example.get_aligned("ENT_KB_ID", as_string=True) == ["", "", "Q816843", "Q816843", ""]
 
     # additional whitespace tokens in GoldParse words
     words, spaces = get_words_and_spaces(