diff --git a/bin/wiki_entity_linking/training_set_creator.py b/bin/wiki_entity_linking/training_set_creator.py
index eb961b9da..e5530ecc7 100644
--- a/bin/wiki_entity_linking/training_set_creator.py
+++ b/bin/wiki_entity_linking/training_set_creator.py
@@ -397,33 +397,43 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
                         current_doc = None
                 else:
                     sent = found_ent.sent.as_doc()
-                    # currently feeding the gold data one entity per sentence at a time
+
                     gold_start = int(start) - found_ent.sent.start_char
                     gold_end = int(end) - found_ent.sent.start_char
 
-                    # add both pos and neg examples (in random order)
-                    # this will exclude examples not in the KB
-                    if kb:
-                        gold_entities = {}
-                        candidates = kb.get_candidates(alias)
-                        candidate_ids = [c.entity_ for c in candidates]
-                        random.shuffle(candidate_ids)
-                        for kb_id in candidate_ids:
-                            entry = (gold_start, gold_end, kb_id)
-                            if kb_id != wd_id:
-                                gold_entities[entry] = 0.0
+                    gold_entities = {}
+                    found_useful = False
+                    for ent in sent.ents:
+                        if ent.start_char == gold_start and ent.end_char == gold_end:
+                            # add both pos and neg examples (in random order)
+                            # this will exclude examples not in the KB
+                            if kb:
+                                value_by_id = {}
+                                candidates = kb.get_candidates(alias)
+                                candidate_ids = [c.entity_ for c in candidates]
+                                random.shuffle(candidate_ids)
+                                for kb_id in candidate_ids:
+                                    found_useful = True
+                                    if kb_id != wd_id:
+                                        value_by_id[kb_id] = 0.0
+                                    else:
+                                        value_by_id[kb_id] = 1.0
+                                gold_entities[(ent.start_char, ent.end_char)] = value_by_id
+                            # if no KB, keep all positive examples
                             else:
-                                gold_entities[entry] = 1.0
-                    # keep all positive examples
-                    else:
-                        entry = (gold_start, gold_end, wd_id)
-                        gold_entities = {entry: 1.0}
-
-                    gold = GoldParse(doc=sent, links=gold_entities)
-                    data.append((sent, gold))
-                    total_entities += 1
-                    if len(data) % 2500 == 0:
-                        print(" -read", total_entities, "entities")
+                                found_useful = True
+                                value_by_id = {wd_id: 1.0}
+                                gold_entities[(ent.start_char, ent.end_char)] = value_by_id
+                        # currently feeding the gold data one entity per sentence at a time
+                        # setting all other entities to empty gold dictionary
+                        else:
+                            gold_entities[(ent.start_char, ent.end_char)] = {}
+                    if found_useful:
+                        gold = GoldParse(doc=sent, links=gold_entities)
+                        data.append((sent, gold))
+                        total_entities += 1
+                        if len(data) % 2500 == 0:
+                            print(" -read", total_entities, "entities")
 
     print(" -read", total_entities, "entities")
     return data
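Note: `read_training` now attaches gold links for every entity in the sentence, not only the annotated mention. A minimal sketch (not part of the diff) of the `gold_entities` dict the loop above builds; the character offsets and Wikidata QIDs are made up for illustration:

```python
# Hypothetical result for one sentence: each (start_char, end_char) key maps
# to a kb_id -> value dict. The annotated mention gets its gold QID with 1.0
# and the KB's other candidates with 0.0; every other entity in the sentence
# gets an empty dict, i.e. no gold signal either way.
gold_entities = {
    (0, 13): {"Q42": 1.0, "Q463035": 0.0},  # the annotated mention
    (25, 31): {},                           # another entity in the same sentence
}
```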
diff --git a/examples/pipeline/wikidata_entity_linking.py b/examples/pipeline/wikidata_entity_linking.py
index ab9aa51fd..478d35111 100644
--- a/examples/pipeline/wikidata_entity_linking.py
+++ b/examples/pipeline/wikidata_entity_linking.py
@@ -322,11 +322,12 @@ def _measure_acc(data, el_pipe=None, error_analysis=False):
     for doc, gold in zip(docs, golds):
         try:
             correct_entries_per_article = dict()
-            for entity, value in gold.links.items():
+            for entity, kb_dict in gold.links.items():
+                start, end = entity
                 # only evaluating on positive examples
-                if value:
-                    start, end, gold_kb = entity
-                    correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
+                for gold_kb, value in kb_dict.items():
+                    if value:
+                        correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
 
             for ent in doc.ents:
                 ent_label = ent.label_
@@ -379,11 +380,12 @@ def _measure_baselines(data, kb):
     for doc, gold in zip(docs, golds):
         try:
             correct_entries_per_article = dict()
-            for entity, value in gold.links.items():
-                start, end, gold_kb = entity
-                # only evaluating on positive examples
-                if value:
-                    correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
+            for entity, kb_dict in gold.links.items():
+                start, end = entity
+                for gold_kb, value in kb_dict.items():
+                    # only evaluating on positive examples
+                    if value:
+                        correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
 
             for ent in doc.ents:
                 label = ent.label_
@@ -487,7 +489,7 @@ def run_el_toy_example(nlp):
         "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, "
         "Douglas reminds us to always bring our towel, even in China or Brazil. "
         "The main character in Doug's novel is the man Arthur Dent, "
-        "but Douglas doesn't write about George Washington or Homer Simpson."
+        "but Dougledydoug doesn't write about George Washington or Homer Simpson."
     )
     doc = nlp(text)
     print(text)
diff --git a/spacy/gold.pyx b/spacy/gold.pyx
index 81feb55a4..5459d5424 100644
--- a/spacy/gold.pyx
+++ b/spacy/gold.pyx
@@ -450,10 +450,11 @@ cdef class GoldParse:
             examples of a label to have the value 0.0. Labels not in the
             dictionary are treated as missing - the gradient for those labels
             will be zero.
-        links (dict): A dict with `(start_char, end_char, kb_id)` keys,
-            representing the external ID of an entity in a knowledge base,
-            and the values being either 1.0 or 0.0, indicating positive and
-            negative examples, respectively.
+        links (dict): A dict with `(start_char, end_char)` keys,
+            and the values being dicts with kb_id:value entries,
+            representing the external IDs in a knowledge base (KB)
+            mapped to either 1.0 or 0.0, indicating positive and
+            negative examples respectively.
         RETURNS (GoldParse): The newly constructed object.
         """
         if words is None:
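The updated `GoldParse` docstring above describes the new two-level `links` format. A minimal construction sketch, assuming `nlp` is a loaded pipeline; the offsets and QIDs are illustrative:

```python
from spacy.gold import GoldParse

doc = nlp("Douglas Adams wrote the book.")
# Keys are (start_char, end_char) offsets of a mention; each value maps a
# candidate KB ID to 1.0 (positive example) or 0.0 (negative example).
links = {(0, 13): {"Q42": 1.0, "Q40": 0.0}}
gold = GoldParse(doc=doc, links=links)
```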
diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx
index 7b6bd0ea0..a8746c73d 100644
--- a/spacy/pipeline/pipes.pyx
+++ b/spacy/pipeline/pipes.pyx
@@ -1076,6 +1076,7 @@ class EntityLinker(Pipe):
     DOCS: TODO
     """
     name = 'entity_linker'
+    NIL = "NIL"  # string used to refer to a non-existing link
 
     @classmethod
     def Model(cls, **cfg):
@@ -1151,27 +1152,28 @@ class EntityLinker(Pipe):
             ents_by_offset = dict()
             for ent in doc.ents:
                 ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
-            for entity, value in gold.links.items():
-                start, end, kb_id = entity
+            for entity, kb_dict in gold.links.items():
+                start, end = entity
                 mention = doc.text[start:end]
-                entity_encoding = self.kb.get_vector(kb_id)
-                prior_prob = self.kb.get_prior_prob(kb_id, mention)
+                for kb_id, value in kb_dict.items():
+                    entity_encoding = self.kb.get_vector(kb_id)
+                    prior_prob = self.kb.get_prior_prob(kb_id, mention)
 
-                gold_ent = ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)]
-                assert gold_ent is not None
-                type_vector = [0 for i in range(len(type_to_int))]
-                if len(type_to_int) > 0:
-                    type_vector[type_to_int[gold_ent.label_]] = 1
+                    gold_ent = ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)]
+                    assert gold_ent is not None
+                    type_vector = [0 for i in range(len(type_to_int))]
+                    if len(type_to_int) > 0:
+                        type_vector[type_to_int[gold_ent.label_]] = 1
 
-                # store data
-                entity_encodings.append(entity_encoding)
-                context_docs.append(doc)
-                type_vectors.append(type_vector)
+                    # store data
+                    entity_encodings.append(entity_encoding)
+                    context_docs.append(doc)
+                    type_vectors.append(type_vector)
 
-                if self.cfg.get("prior_weight", 1) > 0:
-                    priors.append([prior_prob])
-                else:
-                    priors.append([0])
+                    if self.cfg.get("prior_weight", 1) > 0:
+                        priors.append([prior_prob])
+                    else:
+                        priors.append([0])
 
         if len(entity_encodings) > 0:
             assert len(priors) == len(entity_encodings) == len(context_docs) == len(type_vectors)
@@ -1197,8 +1199,9 @@ class EntityLinker(Pipe):
     def get_loss(self, docs, golds, scores):
         cats = []
         for gold in golds:
-            for entity, value in gold.links.items():
-                cats.append([value])
+            for entity, kb_dict in gold.links.items():
+                for kb_id, value in kb_dict.items():
+                    cats.append([value])
 
         cats = self.model.ops.asarray(cats, dtype="float32")
         assert len(scores) == len(cats)
@@ -1209,26 +1212,27 @@ class EntityLinker(Pipe):
         return loss, d_scores
 
     def __call__(self, doc):
-        entities, kb_ids = self.predict([doc])
-        self.set_annotations([doc], entities, kb_ids)
+        kb_ids = self.predict([doc])
+        self.set_annotations([doc], kb_ids)
         return doc
 
     def pipe(self, stream, batch_size=128, n_threads=-1):
         for docs in util.minibatch(stream, size=batch_size):
             docs = list(docs)
-            entities, kb_ids = self.predict(docs)
-            self.set_annotations(docs, entities, kb_ids)
+            kb_ids = self.predict(docs)
+            self.set_annotations(docs, kb_ids)
             yield from docs
 
     def predict(self, docs):
+        """ Return the KB IDs for each entity in each doc, including NIL if there is no prediction """
         self.require_model()
        self.require_kb()
 
-        final_entities = []
+        entity_count = 0
         final_kb_ids = []
 
         if not docs:
-            return final_entities, final_kb_ids
+            return final_kb_ids
 
         if isinstance(docs, Doc):
             docs = [docs]
@@ -1242,12 +1246,15 @@ class EntityLinker(Pipe):
             if len(doc) > 0:
                 context_encoding = context_encodings[i]
                 for ent in doc.ents:
+                    entity_count += 1
                     type_vector = [0 for i in range(len(type_to_int))]
                     if len(type_to_int) > 0:
                         type_vector[type_to_int[ent.label_]] = 1
 
                     candidates = self.kb.get_candidates(ent.text)
-                    if candidates:
+                    if not candidates:
+                        final_kb_ids.append(self.NIL)  # no prediction possible for this entity
+                    else:
                         random.shuffle(candidates)
 
                         # this will set the prior probabilities to 0 (just like in training) if their weight is 0
@@ -1266,15 +1273,20 @@ class EntityLinker(Pipe):
                         # TODO: thresholding
                         best_index = scores.argmax()
                         best_candidate = candidates[best_index]
-                        final_entities.append(ent)
                         final_kb_ids.append(best_candidate.entity_)
 
-        return final_entities, final_kb_ids
+        assert len(final_kb_ids) == entity_count
 
-    def set_annotations(self, docs, entities, kb_ids=None):
-        for entity, kb_id in zip(entities, kb_ids):
-            for token in entity:
-                token.ent_kb_id_ = kb_id
+        return final_kb_ids
+
+    def set_annotations(self, docs, kb_ids, tensors=None):
+        i = 0
+        for doc in docs:
+            for ent in doc.ents:
+                kb_id = kb_ids[i]
+                i += 1
+                for token in ent:
+                    token.ent_kb_id_ = kb_id
 
     def to_disk(self, path, exclude=tuple(), **kwargs):
         serialize = OrderedDict()
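With this change, `predict` returns exactly one KB ID per entity in the batch, appending `self.NIL` when the KB offers no candidate, and `set_annotations` walks the docs' entities in the same order. A usage sketch, assuming `nlp` contains a trained `entity_linker` component with a KB attached:

```python
doc = nlp("Douglas Adams wrote The Hitchhiker's Guide to the Galaxy.")
# Each entity token now carries the predicted KB ID, or "NIL" when the KB
# had no candidate for the mention.
for token in doc:
    if token.ent_type_:
        print(token.text, token.ent_type_, token.ent_kb_id_)
```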