have gold.links correspond exactly to doc.ents

This commit is contained in:
svlandeg 2019-07-19 12:36:15 +02:00
parent e1213eaf6a
commit 21176517a7
4 changed files with 94 additions and 69 deletions

View File

@ -397,28 +397,38 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
current_doc = None
else:
sent = found_ent.sent.as_doc()
# currently feeding the gold data one entity per sentence at a time
gold_start = int(start) - found_ent.sent.start_char
gold_end = int(end) - found_ent.sent.start_char
gold_entities = {}
found_useful = False
for ent in sent.ents:
if ent.start_char == gold_start and ent.end_char == gold_end:
# add both pos and neg examples (in random order)
# this will exclude examples not in the KB
if kb:
gold_entities = {}
value_by_id = {}
candidates = kb.get_candidates(alias)
candidate_ids = [c.entity_ for c in candidates]
random.shuffle(candidate_ids)
for kb_id in candidate_ids:
entry = (gold_start, gold_end, kb_id)
found_useful = True
if kb_id != wd_id:
gold_entities[entry] = 0.0
value_by_id[kb_id] = 0.0
else:
gold_entities[entry] = 1.0
# keep all positive examples
value_by_id[kb_id] = 1.0
gold_entities[(ent.start_char, ent.end_char)] = value_by_id
# if no KB, keep all positive examples
else:
entry = (gold_start, gold_end, wd_id)
gold_entities = {entry: 1.0}
found_useful = True
value_by_id = {wd_id: 1.0}
gold_entities[(ent.start_char, ent.end_char)] = value_by_id
# currently feeding the gold data one entity per sentence at a time
# setting all other entities to empty gold dictionary
else:
gold_entities[(ent.start_char, ent.end_char)] = {}
if found_useful:
gold = GoldParse(doc=sent, links=gold_entities)
data.append((sent, gold))
total_entities += 1

View File

@ -322,10 +322,11 @@ def _measure_acc(data, el_pipe=None, error_analysis=False):
for doc, gold in zip(docs, golds):
try:
correct_entries_per_article = dict()
for entity, value in gold.links.items():
for entity, kb_dict in gold.links.items():
start, end = entity
# only evaluating on positive examples
for gold_kb, value in kb_dict.items():
if value:
start, end, gold_kb = entity
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
for ent in doc.ents:
@ -379,8 +380,9 @@ def _measure_baselines(data, kb):
for doc, gold in zip(docs, golds):
try:
correct_entries_per_article = dict()
for entity, value in gold.links.items():
start, end, gold_kb = entity
for entity, kb_dict in gold.links.items():
start, end = entity
for gold_kb, value in kb_dict.items():
# only evaluating on positive examples
if value:
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
@ -487,7 +489,7 @@ def run_el_toy_example(nlp):
"In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, "
"Douglas reminds us to always bring our towel, even in China or Brazil. "
"The main character in Doug's novel is the man Arthur Dent, "
"but Douglas doesn't write about George Washington or Homer Simpson."
"but Dougledydoug doesn't write about George Washington or Homer Simpson."
)
doc = nlp(text)
print(text)

View File

@ -450,10 +450,11 @@ cdef class GoldParse:
examples of a label to have the value 0.0. Labels not in the
dictionary are treated as missing - the gradient for those labels
will be zero.
links (dict): A dict with `(start_char, end_char, kb_id)` keys,
representing the external ID of an entity in a knowledge base,
and the values being either 1.0 or 0.0, indicating positive and
negative examples, respectively.
links (dict): A dict with `(start_char, end_char)` keys,
and the values being dicts with kb_id:value entries,
representing the external IDs in a knowledge base (KB)
mapped to either 1.0 or 0.0, indicating positive and
negative examples respectively.
RETURNS (GoldParse): The newly constructed object.
"""
if words is None:

View File

@ -1076,6 +1076,7 @@ class EntityLinker(Pipe):
DOCS: TODO
"""
name = 'entity_linker'
NIL = "NIL" # string used to refer to a non-existing link
@classmethod
def Model(cls, **cfg):
@ -1151,9 +1152,10 @@ class EntityLinker(Pipe):
ents_by_offset = dict()
for ent in doc.ents:
ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
for entity, value in gold.links.items():
start, end, kb_id = entity
for entity, kb_dict in gold.links.items():
start, end = entity
mention = doc.text[start:end]
for kb_id, value in kb_dict.items():
entity_encoding = self.kb.get_vector(kb_id)
prior_prob = self.kb.get_prior_prob(kb_id, mention)
@ -1197,7 +1199,8 @@ class EntityLinker(Pipe):
def get_loss(self, docs, golds, scores):
cats = []
for gold in golds:
for entity, value in gold.links.items():
for entity, kb_dict in gold.links.items():
for kb_id, value in kb_dict.items():
cats.append([value])
cats = self.model.ops.asarray(cats, dtype="float32")
@ -1209,26 +1212,27 @@ class EntityLinker(Pipe):
return loss, d_scores
def __call__(self, doc):
entities, kb_ids = self.predict([doc])
self.set_annotations([doc], entities, kb_ids)
kb_ids = self.predict([doc])
self.set_annotations([doc], kb_ids)
return doc
def pipe(self, stream, batch_size=128, n_threads=-1):
for docs in util.minibatch(stream, size=batch_size):
docs = list(docs)
entities, kb_ids = self.predict(docs)
self.set_annotations(docs, entities, kb_ids)
kb_ids = self.predict(docs)
self.set_annotations(docs, kb_ids)
yield from docs
def predict(self, docs):
""" Return the KB IDs for each entity in each doc, including NIL if there is no prediction """
self.require_model()
self.require_kb()
final_entities = []
entity_count = 0
final_kb_ids = []
if not docs:
return final_entities, final_kb_ids
return final_kb_ids
if isinstance(docs, Doc):
docs = [docs]
@ -1242,12 +1246,15 @@ class EntityLinker(Pipe):
if len(doc) > 0:
context_encoding = context_encodings[i]
for ent in doc.ents:
entity_count += 1
type_vector = [0 for i in range(len(type_to_int))]
if len(type_to_int) > 0:
type_vector[type_to_int[ent.label_]] = 1
candidates = self.kb.get_candidates(ent.text)
if candidates:
if not candidates:
final_kb_ids.append(self.NIL) # no prediction possible for this entity
else:
random.shuffle(candidates)
# this will set the prior probabilities to 0 (just like in training) if their weight is 0
@ -1266,14 +1273,19 @@ class EntityLinker(Pipe):
# TODO: thresholding
best_index = scores.argmax()
best_candidate = candidates[best_index]
final_entities.append(ent)
final_kb_ids.append(best_candidate.entity_)
return final_entities, final_kb_ids
assert len(final_kb_ids) == entity_count
def set_annotations(self, docs, entities, kb_ids=None):
for entity, kb_id in zip(entities, kb_ids):
for token in entity:
return final_kb_ids
def set_annotations(self, docs, kb_ids, tensors=None):
i=0
for doc in docs:
for ent in doc.ents:
kb_id = kb_ids[i]
i += 1
for token in ent:
token.ent_kb_id_ = kb_id
def to_disk(self, path, exclude=tuple(), **kwargs):