From e1213eaf6af1a19b00e9140105982f1a587ae4a6 Mon Sep 17 00:00:00 2001 From: svlandeg Date: Thu, 18 Jul 2019 13:35:10 +0200 Subject: [PATCH] use original gold object in get_loss function --- examples/pipeline/wikidata_entity_linking.py | 12 +++++------ spacy/pipeline/pipes.pyx | 21 ++++++++++++-------- 2 files changed, 18 insertions(+), 15 deletions(-) diff --git a/examples/pipeline/wikidata_entity_linking.py b/examples/pipeline/wikidata_entity_linking.py index 341dc94ed..ab9aa51fd 100644 --- a/examples/pipeline/wikidata_entity_linking.py +++ b/examples/pipeline/wikidata_entity_linking.py @@ -295,11 +295,7 @@ def run_pipeline(): dev_limit = 5000 dev_data = training_set_creator.read_training( - nlp=nlp_2, - training_dir=TRAINING_DIR, - dev=True, - limit=dev_limit, - kb=el_pipe.kb, + nlp=nlp_2, training_dir=TRAINING_DIR, dev=True, limit=dev_limit, kb=None ) print("Dev testing from file on", len(dev_data), "articles") @@ -383,9 +379,11 @@ def _measure_baselines(data, kb): for doc, gold in zip(docs, golds): try: correct_entries_per_article = dict() - for entity in gold.links: + for entity, value in gold.links.items(): start, end, gold_kb = entity - correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb + # only evaluating on positive examples + if value: + correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb for ent in doc.ents: label = ent.label_ diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index b3f384437..7b6bd0ea0 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -1141,7 +1141,7 @@ class EntityLinker(Pipe): context_docs = [] entity_encodings = [] - cats = [] + priors = [] type_vectors = [] @@ -1173,12 +1173,9 @@ class EntityLinker(Pipe): else: priors.append([0]) - cats.append([value]) - if len(entity_encodings) > 0: - assert len(priors) == len(entity_encodings) == len(context_docs) == len(cats) == len(type_vectors) + assert len(priors) == len(entity_encodings) == len(context_docs) == len(type_vectors) - cats = self.model.ops.asarray(cats, dtype="float32") entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32") context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop) @@ -1186,7 +1183,7 @@ class EntityLinker(Pipe): for i in range(len(entity_encodings))] pred, bp_mention = self.model.begin_update(self.model.ops.asarray(mention_encodings, dtype="float32"), drop=drop) - loss, d_scores = self.get_loss(scores=pred, golds=cats, docs=docs) + loss, d_scores = self.get_loss(scores=pred, golds=golds, docs=docs) mention_gradient = bp_mention(d_scores, sgd=sgd) context_gradients = [list(x[0:self.cfg.get("context_width")]) for x in mention_gradient] @@ -1198,9 +1195,17 @@ class EntityLinker(Pipe): return 0 def get_loss(self, docs, golds, scores): - d_scores = (scores - golds) + cats = [] + for gold in golds: + for entity, value in gold.links.items(): + cats.append([value]) + + cats = self.model.ops.asarray(cats, dtype="float32") + assert len(scores) == len(cats) + + d_scores = (scores - cats) loss = (d_scores ** 2).sum() - loss = loss / len(golds) + loss = loss / len(cats) return loss, d_scores def __call__(self, doc):