have gold.links correspond exactly to doc.ents

This commit is contained in:
svlandeg 2019-07-19 12:36:15 +02:00
parent e1213eaf6a
commit 21176517a7
4 changed files with 94 additions and 69 deletions

View File

@ -397,33 +397,43 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
current_doc = None current_doc = None
else: else:
sent = found_ent.sent.as_doc() sent = found_ent.sent.as_doc()
# currently feeding the gold data one entity per sentence at a time
gold_start = int(start) - found_ent.sent.start_char gold_start = int(start) - found_ent.sent.start_char
gold_end = int(end) - found_ent.sent.start_char gold_end = int(end) - found_ent.sent.start_char
# add both pos and neg examples (in random order) gold_entities = {}
# this will exclude examples not in the KB found_useful = False
if kb: for ent in sent.ents:
gold_entities = {} if ent.start_char == gold_start and ent.end_char == gold_end:
candidates = kb.get_candidates(alias) # add both pos and neg examples (in random order)
candidate_ids = [c.entity_ for c in candidates] # this will exclude examples not in the KB
random.shuffle(candidate_ids) if kb:
for kb_id in candidate_ids: value_by_id = {}
entry = (gold_start, gold_end, kb_id) candidates = kb.get_candidates(alias)
if kb_id != wd_id: candidate_ids = [c.entity_ for c in candidates]
gold_entities[entry] = 0.0 random.shuffle(candidate_ids)
for kb_id in candidate_ids:
found_useful = True
if kb_id != wd_id:
value_by_id[kb_id] = 0.0
else:
value_by_id[kb_id] = 1.0
gold_entities[(ent.start_char, ent.end_char)] = value_by_id
# if no KB, keep all positive examples
else: else:
gold_entities[entry] = 1.0 found_useful = True
# keep all positive examples value_by_id = {wd_id: 1.0}
else: gold_entities[(ent.start_char, ent.end_char)] = value_by_id
entry = (gold_start, gold_end, wd_id) # currently feeding the gold data one entity per sentence at a time
gold_entities = {entry: 1.0} # setting all other entities to empty gold dictionary
else:
gold = GoldParse(doc=sent, links=gold_entities) gold_entities[(ent.start_char, ent.end_char)] = {}
data.append((sent, gold)) if found_useful:
total_entities += 1 gold = GoldParse(doc=sent, links=gold_entities)
if len(data) % 2500 == 0: data.append((sent, gold))
print(" -read", total_entities, "entities") total_entities += 1
if len(data) % 2500 == 0:
print(" -read", total_entities, "entities")
print(" -read", total_entities, "entities") print(" -read", total_entities, "entities")
return data return data

View File

@ -322,11 +322,12 @@ def _measure_acc(data, el_pipe=None, error_analysis=False):
for doc, gold in zip(docs, golds): for doc, gold in zip(docs, golds):
try: try:
correct_entries_per_article = dict() correct_entries_per_article = dict()
for entity, value in gold.links.items(): for entity, kb_dict in gold.links.items():
start, end = entity
# only evaluating on positive examples # only evaluating on positive examples
if value: for gold_kb, value in kb_dict.items():
start, end, gold_kb = entity if value:
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
for ent in doc.ents: for ent in doc.ents:
ent_label = ent.label_ ent_label = ent.label_
@ -379,11 +380,12 @@ def _measure_baselines(data, kb):
for doc, gold in zip(docs, golds): for doc, gold in zip(docs, golds):
try: try:
correct_entries_per_article = dict() correct_entries_per_article = dict()
for entity, value in gold.links.items(): for entity, kb_dict in gold.links.items():
start, end, gold_kb = entity start, end = entity
# only evaluating on positive examples for gold_kb, value in kb_dict.items():
if value: # only evaluating on positive examples
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb if value:
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
for ent in doc.ents: for ent in doc.ents:
label = ent.label_ label = ent.label_
@ -487,7 +489,7 @@ def run_el_toy_example(nlp):
"In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, "
"Douglas reminds us to always bring our towel, even in China or Brazil. " "Douglas reminds us to always bring our towel, even in China or Brazil. "
"The main character in Doug's novel is the man Arthur Dent, " "The main character in Doug's novel is the man Arthur Dent, "
"but Douglas doesn't write about George Washington or Homer Simpson." "but Dougledydoug doesn't write about George Washington or Homer Simpson."
) )
doc = nlp(text) doc = nlp(text)
print(text) print(text)

View File

@ -450,10 +450,11 @@ cdef class GoldParse:
examples of a label to have the value 0.0. Labels not in the examples of a label to have the value 0.0. Labels not in the
dictionary are treated as missing - the gradient for those labels dictionary are treated as missing - the gradient for those labels
will be zero. will be zero.
links (dict): A dict with `(start_char, end_char, kb_id)` keys, links (dict): A dict with `(start_char, end_char)` keys,
representing the external ID of an entity in a knowledge base, and the values being dicts with kb_id:value entries,
and the values being either 1.0 or 0.0, indicating positive and representing the external IDs in a knowledge base (KB)
negative examples, respectively. mapped to either 1.0 or 0.0, indicating positive and
negative examples respectively.
RETURNS (GoldParse): The newly constructed object. RETURNS (GoldParse): The newly constructed object.
""" """
if words is None: if words is None:

View File

@ -1076,6 +1076,7 @@ class EntityLinker(Pipe):
DOCS: TODO DOCS: TODO
""" """
name = 'entity_linker' name = 'entity_linker'
NIL = "NIL" # string used to refer to a non-existing link
@classmethod @classmethod
def Model(cls, **cfg): def Model(cls, **cfg):
@ -1151,27 +1152,28 @@ class EntityLinker(Pipe):
ents_by_offset = dict() ents_by_offset = dict()
for ent in doc.ents: for ent in doc.ents:
ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
for entity, value in gold.links.items(): for entity, kb_dict in gold.links.items():
start, end, kb_id = entity start, end = entity
mention = doc.text[start:end] mention = doc.text[start:end]
entity_encoding = self.kb.get_vector(kb_id) for kb_id, value in kb_dict.items():
prior_prob = self.kb.get_prior_prob(kb_id, mention) entity_encoding = self.kb.get_vector(kb_id)
prior_prob = self.kb.get_prior_prob(kb_id, mention)
gold_ent = ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] gold_ent = ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)]
assert gold_ent is not None assert gold_ent is not None
type_vector = [0 for i in range(len(type_to_int))] type_vector = [0 for i in range(len(type_to_int))]
if len(type_to_int) > 0: if len(type_to_int) > 0:
type_vector[type_to_int[gold_ent.label_]] = 1 type_vector[type_to_int[gold_ent.label_]] = 1
# store data # store data
entity_encodings.append(entity_encoding) entity_encodings.append(entity_encoding)
context_docs.append(doc) context_docs.append(doc)
type_vectors.append(type_vector) type_vectors.append(type_vector)
if self.cfg.get("prior_weight", 1) > 0: if self.cfg.get("prior_weight", 1) > 0:
priors.append([prior_prob]) priors.append([prior_prob])
else: else:
priors.append([0]) priors.append([0])
if len(entity_encodings) > 0: if len(entity_encodings) > 0:
assert len(priors) == len(entity_encodings) == len(context_docs) == len(type_vectors) assert len(priors) == len(entity_encodings) == len(context_docs) == len(type_vectors)
@ -1197,8 +1199,9 @@ class EntityLinker(Pipe):
def get_loss(self, docs, golds, scores): def get_loss(self, docs, golds, scores):
cats = [] cats = []
for gold in golds: for gold in golds:
for entity, value in gold.links.items(): for entity, kb_dict in gold.links.items():
cats.append([value]) for kb_id, value in kb_dict.items():
cats.append([value])
cats = self.model.ops.asarray(cats, dtype="float32") cats = self.model.ops.asarray(cats, dtype="float32")
assert len(scores) == len(cats) assert len(scores) == len(cats)
@ -1209,26 +1212,27 @@ class EntityLinker(Pipe):
return loss, d_scores return loss, d_scores
def __call__(self, doc): def __call__(self, doc):
entities, kb_ids = self.predict([doc]) kb_ids = self.predict([doc])
self.set_annotations([doc], entities, kb_ids) self.set_annotations([doc], kb_ids)
return doc return doc
def pipe(self, stream, batch_size=128, n_threads=-1): def pipe(self, stream, batch_size=128, n_threads=-1):
for docs in util.minibatch(stream, size=batch_size): for docs in util.minibatch(stream, size=batch_size):
docs = list(docs) docs = list(docs)
entities, kb_ids = self.predict(docs) kb_ids = self.predict(docs)
self.set_annotations(docs, entities, kb_ids) self.set_annotations(docs, kb_ids)
yield from docs yield from docs
def predict(self, docs): def predict(self, docs):
""" Return the KB IDs for each entity in each doc, including NIL if there is no prediction """
self.require_model() self.require_model()
self.require_kb() self.require_kb()
final_entities = [] entity_count = 0
final_kb_ids = [] final_kb_ids = []
if not docs: if not docs:
return final_entities, final_kb_ids return final_kb_ids
if isinstance(docs, Doc): if isinstance(docs, Doc):
docs = [docs] docs = [docs]
@ -1242,12 +1246,15 @@ class EntityLinker(Pipe):
if len(doc) > 0: if len(doc) > 0:
context_encoding = context_encodings[i] context_encoding = context_encodings[i]
for ent in doc.ents: for ent in doc.ents:
entity_count += 1
type_vector = [0 for i in range(len(type_to_int))] type_vector = [0 for i in range(len(type_to_int))]
if len(type_to_int) > 0: if len(type_to_int) > 0:
type_vector[type_to_int[ent.label_]] = 1 type_vector[type_to_int[ent.label_]] = 1
candidates = self.kb.get_candidates(ent.text) candidates = self.kb.get_candidates(ent.text)
if candidates: if not candidates:
final_kb_ids.append(self.NIL) # no prediction possible for this entity
else:
random.shuffle(candidates) random.shuffle(candidates)
# this will set the prior probabilities to 0 (just like in training) if their weight is 0 # this will set the prior probabilities to 0 (just like in training) if their weight is 0
@ -1266,15 +1273,20 @@ class EntityLinker(Pipe):
# TODO: thresholding # TODO: thresholding
best_index = scores.argmax() best_index = scores.argmax()
best_candidate = candidates[best_index] best_candidate = candidates[best_index]
final_entities.append(ent)
final_kb_ids.append(best_candidate.entity_) final_kb_ids.append(best_candidate.entity_)
return final_entities, final_kb_ids assert len(final_kb_ids) == entity_count
def set_annotations(self, docs, entities, kb_ids=None): return final_kb_ids
for entity, kb_id in zip(entities, kb_ids):
for token in entity: def set_annotations(self, docs, kb_ids, tensors=None):
token.ent_kb_id_ = kb_id i=0
for doc in docs:
for ent in doc.ents:
kb_id = kb_ids[i]
i += 1
for token in ent:
token.ent_kb_id_ = kb_id
def to_disk(self, path, exclude=tuple(), **kwargs): def to_disk(self, path, exclude=tuple(), **kwargs):
serialize = OrderedDict() serialize = OrderedDict()