have gold.links correspond exactly to doc.ents

This commit is contained in:
svlandeg 2019-07-19 12:36:15 +02:00
parent e1213eaf6a
commit 21176517a7
4 changed files with 94 additions and 69 deletions

View File

@ -397,28 +397,38 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
current_doc = None current_doc = None
else: else:
sent = found_ent.sent.as_doc() sent = found_ent.sent.as_doc()
# currently feeding the gold data one entity per sentence at a time
gold_start = int(start) - found_ent.sent.start_char gold_start = int(start) - found_ent.sent.start_char
gold_end = int(end) - found_ent.sent.start_char gold_end = int(end) - found_ent.sent.start_char
gold_entities = {}
found_useful = False
for ent in sent.ents:
if ent.start_char == gold_start and ent.end_char == gold_end:
# add both pos and neg examples (in random order) # add both pos and neg examples (in random order)
# this will exclude examples not in the KB # this will exclude examples not in the KB
if kb: if kb:
gold_entities = {} value_by_id = {}
candidates = kb.get_candidates(alias) candidates = kb.get_candidates(alias)
candidate_ids = [c.entity_ for c in candidates] candidate_ids = [c.entity_ for c in candidates]
random.shuffle(candidate_ids) random.shuffle(candidate_ids)
for kb_id in candidate_ids: for kb_id in candidate_ids:
entry = (gold_start, gold_end, kb_id) found_useful = True
if kb_id != wd_id: if kb_id != wd_id:
gold_entities[entry] = 0.0 value_by_id[kb_id] = 0.0
else: else:
gold_entities[entry] = 1.0 value_by_id[kb_id] = 1.0
# keep all positive examples gold_entities[(ent.start_char, ent.end_char)] = value_by_id
# if no KB, keep all positive examples
else: else:
entry = (gold_start, gold_end, wd_id) found_useful = True
gold_entities = {entry: 1.0} value_by_id = {wd_id: 1.0}
gold_entities[(ent.start_char, ent.end_char)] = value_by_id
# currently feeding the gold data one entity per sentence at a time
# setting all other entities to empty gold dictionary
else:
gold_entities[(ent.start_char, ent.end_char)] = {}
if found_useful:
gold = GoldParse(doc=sent, links=gold_entities) gold = GoldParse(doc=sent, links=gold_entities)
data.append((sent, gold)) data.append((sent, gold))
total_entities += 1 total_entities += 1

View File

@ -322,10 +322,11 @@ def _measure_acc(data, el_pipe=None, error_analysis=False):
for doc, gold in zip(docs, golds): for doc, gold in zip(docs, golds):
try: try:
correct_entries_per_article = dict() correct_entries_per_article = dict()
for entity, value in gold.links.items(): for entity, kb_dict in gold.links.items():
start, end = entity
# only evaluating on positive examples # only evaluating on positive examples
for gold_kb, value in kb_dict.items():
if value: if value:
start, end, gold_kb = entity
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
for ent in doc.ents: for ent in doc.ents:
@ -379,8 +380,9 @@ def _measure_baselines(data, kb):
for doc, gold in zip(docs, golds): for doc, gold in zip(docs, golds):
try: try:
correct_entries_per_article = dict() correct_entries_per_article = dict()
for entity, value in gold.links.items(): for entity, kb_dict in gold.links.items():
start, end, gold_kb = entity start, end = entity
for gold_kb, value in kb_dict.items():
# only evaluating on positive examples # only evaluating on positive examples
if value: if value:
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
@ -487,7 +489,7 @@ def run_el_toy_example(nlp):
"In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, "
"Douglas reminds us to always bring our towel, even in China or Brazil. " "Douglas reminds us to always bring our towel, even in China or Brazil. "
"The main character in Doug's novel is the man Arthur Dent, " "The main character in Doug's novel is the man Arthur Dent, "
"but Douglas doesn't write about George Washington or Homer Simpson." "but Dougledydoug doesn't write about George Washington or Homer Simpson."
) )
doc = nlp(text) doc = nlp(text)
print(text) print(text)

View File

@ -450,10 +450,11 @@ cdef class GoldParse:
examples of a label to have the value 0.0. Labels not in the examples of a label to have the value 0.0. Labels not in the
dictionary are treated as missing - the gradient for those labels dictionary are treated as missing - the gradient for those labels
will be zero. will be zero.
links (dict): A dict with `(start_char, end_char, kb_id)` keys, links (dict): A dict with `(start_char, end_char)` keys,
representing the external ID of an entity in a knowledge base, and the values being dicts with kb_id:value entries,
and the values being either 1.0 or 0.0, indicating positive and representing the external IDs in a knowledge base (KB)
negative examples, respectively. mapped to either 1.0 or 0.0, indicating positive and
negative examples respectively.
RETURNS (GoldParse): The newly constructed object. RETURNS (GoldParse): The newly constructed object.
""" """
if words is None: if words is None:

View File

@ -1076,6 +1076,7 @@ class EntityLinker(Pipe):
DOCS: TODO DOCS: TODO
""" """
name = 'entity_linker' name = 'entity_linker'
NIL = "NIL" # string used to refer to a non-existing link
@classmethod @classmethod
def Model(cls, **cfg): def Model(cls, **cfg):
@ -1151,9 +1152,10 @@ class EntityLinker(Pipe):
ents_by_offset = dict() ents_by_offset = dict()
for ent in doc.ents: for ent in doc.ents:
ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
for entity, value in gold.links.items(): for entity, kb_dict in gold.links.items():
start, end, kb_id = entity start, end = entity
mention = doc.text[start:end] mention = doc.text[start:end]
for kb_id, value in kb_dict.items():
entity_encoding = self.kb.get_vector(kb_id) entity_encoding = self.kb.get_vector(kb_id)
prior_prob = self.kb.get_prior_prob(kb_id, mention) prior_prob = self.kb.get_prior_prob(kb_id, mention)
@ -1197,7 +1199,8 @@ class EntityLinker(Pipe):
def get_loss(self, docs, golds, scores): def get_loss(self, docs, golds, scores):
cats = [] cats = []
for gold in golds: for gold in golds:
for entity, value in gold.links.items(): for entity, kb_dict in gold.links.items():
for kb_id, value in kb_dict.items():
cats.append([value]) cats.append([value])
cats = self.model.ops.asarray(cats, dtype="float32") cats = self.model.ops.asarray(cats, dtype="float32")
@ -1209,26 +1212,27 @@ class EntityLinker(Pipe):
return loss, d_scores return loss, d_scores
def __call__(self, doc): def __call__(self, doc):
entities, kb_ids = self.predict([doc]) kb_ids = self.predict([doc])
self.set_annotations([doc], entities, kb_ids) self.set_annotations([doc], kb_ids)
return doc return doc
def pipe(self, stream, batch_size=128, n_threads=-1): def pipe(self, stream, batch_size=128, n_threads=-1):
for docs in util.minibatch(stream, size=batch_size): for docs in util.minibatch(stream, size=batch_size):
docs = list(docs) docs = list(docs)
entities, kb_ids = self.predict(docs) kb_ids = self.predict(docs)
self.set_annotations(docs, entities, kb_ids) self.set_annotations(docs, kb_ids)
yield from docs yield from docs
def predict(self, docs): def predict(self, docs):
""" Return the KB IDs for each entity in each doc, including NIL if there is no prediction """
self.require_model() self.require_model()
self.require_kb() self.require_kb()
final_entities = [] entity_count = 0
final_kb_ids = [] final_kb_ids = []
if not docs: if not docs:
return final_entities, final_kb_ids return final_kb_ids
if isinstance(docs, Doc): if isinstance(docs, Doc):
docs = [docs] docs = [docs]
@ -1242,12 +1246,15 @@ class EntityLinker(Pipe):
if len(doc) > 0: if len(doc) > 0:
context_encoding = context_encodings[i] context_encoding = context_encodings[i]
for ent in doc.ents: for ent in doc.ents:
entity_count += 1
type_vector = [0 for i in range(len(type_to_int))] type_vector = [0 for i in range(len(type_to_int))]
if len(type_to_int) > 0: if len(type_to_int) > 0:
type_vector[type_to_int[ent.label_]] = 1 type_vector[type_to_int[ent.label_]] = 1
candidates = self.kb.get_candidates(ent.text) candidates = self.kb.get_candidates(ent.text)
if candidates: if not candidates:
final_kb_ids.append(self.NIL) # no prediction possible for this entity
else:
random.shuffle(candidates) random.shuffle(candidates)
# this will set the prior probabilities to 0 (just like in training) if their weight is 0 # this will set the prior probabilities to 0 (just like in training) if their weight is 0
@ -1266,14 +1273,19 @@ class EntityLinker(Pipe):
# TODO: thresholding # TODO: thresholding
best_index = scores.argmax() best_index = scores.argmax()
best_candidate = candidates[best_index] best_candidate = candidates[best_index]
final_entities.append(ent)
final_kb_ids.append(best_candidate.entity_) final_kb_ids.append(best_candidate.entity_)
return final_entities, final_kb_ids assert len(final_kb_ids) == entity_count
def set_annotations(self, docs, entities, kb_ids=None): return final_kb_ids
for entity, kb_id in zip(entities, kb_ids):
for token in entity: def set_annotations(self, docs, kb_ids, tensors=None):
i=0
for doc in docs:
for ent in doc.ents:
kb_id = kb_ids[i]
i += 1
for token in ent:
token.ent_kb_id_ = kb_id token.ent_kb_id_ = kb_id
def to_disk(self, path, exclude=tuple(), **kwargs): def to_disk(self, path, exclude=tuple(), **kwargs):