diff --git a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py index 8753450bb..90218edda 100644 --- a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py +++ b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py @@ -115,6 +115,7 @@ def run_pipeline(): # STEP 6: create the entity linking pipe if train_pipe: + print("STEP 6: training Entity Linking pipe", datetime.datetime.now()) train_limit = 100 dev_limit = 20 print("Training on", train_limit, "articles") @@ -131,7 +132,7 @@ def run_pipeline(): training_dir=TRAINING_DIR, dev=True, limit=dev_limit, - to_print=False) + to_print=False) el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": my_kb, "doc_cutoff": DOC_CHAR_CUTOFF}) nlp.add_pipe(el_pipe, last=True) @@ -147,35 +148,40 @@ def run_pipeline(): with nlp.disable_pipes(*other_pipes): for batch in batches: - docs, golds = zip(*batch) - nlp.update( - docs, - golds, - drop=DROPOUT, - losses=losses, - ) + try: + docs, golds = zip(*batch) + nlp.update( + docs, + golds, + drop=DROPOUT, + losses=losses, + ) + except Exception as e: + print("Error updating batch", e) - # print(" measuring accuracy 1-1") - el_pipe.context_weight = 1 - el_pipe.prior_weight = 1 - dev_acc_1_1 = _measure_accuracy(dev_data, el_pipe) - train_acc_1_1 = _measure_accuracy(train_data, el_pipe) + print("Epoch, train loss", itn, round(losses['entity_linker'], 2)) - # print(" measuring accuracy 0-1") - el_pipe.context_weight = 0 - el_pipe.prior_weight = 1 - dev_acc_0_1 = _measure_accuracy(dev_data, el_pipe) - train_acc_0_1 = _measure_accuracy(train_data, el_pipe) + # baseline using only prior probabilities + el_pipe.context_weight = 0 + el_pipe.prior_weight = 1 + dev_acc_0_1 = _measure_accuracy(dev_data, el_pipe) + train_acc_0_1 = _measure_accuracy(train_data, el_pipe) - # print(" measuring accuracy 1-0") - el_pipe.context_weight = 1 - el_pipe.prior_weight = 0 - dev_acc_1_0 = _measure_accuracy(dev_data, el_pipe) - train_acc_1_0 = _measure_accuracy(train_data, el_pipe) + # print(" measuring accuracy 1-1") + el_pipe.context_weight = 1 + el_pipe.prior_weight = 1 + dev_acc_1_1 = _measure_accuracy(dev_data, el_pipe) + train_acc_1_1 = _measure_accuracy(train_data, el_pipe) - print("Epoch, train loss, train/dev acc, 1-1, 0-1, 1-0:", itn, round(losses['entity_linker'], 2), - round(train_acc_1_1, 2), round(train_acc_0_1, 2), round(train_acc_1_0, 2), "/", - round(dev_acc_1_1, 2), round(dev_acc_0_1, 2), round(dev_acc_1_0, 2)) + # print(" measuring accuracy 1-0") + el_pipe.context_weight = 1 + el_pipe.prior_weight = 0 + dev_acc_1_0 = _measure_accuracy(dev_data, el_pipe) + train_acc_1_0 = _measure_accuracy(train_data, el_pipe) + + print("train/dev acc, 1-1, 0-1, 1-0:" , + round(train_acc_1_1, 2), round(train_acc_0_1, 2), round(train_acc_1_0, 2), "/", + round(dev_acc_1_1, 2), round(dev_acc_0_1, 2), round(dev_acc_1_0, 2)) # test Entity Linker if to_test_pipeline: @@ -193,26 +199,29 @@ def _measure_accuracy(data, el_pipe): docs = [d for d, g in data] docs = el_pipe.pipe(docs) - golds = [g for d, g in data] for doc, gold in zip(docs, golds): - correct_entries_per_article = dict() - for entity in gold.links: - start, end, gold_kb = entity - correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb + try: + correct_entries_per_article = dict() + for entity in gold.links: + start, end, gold_kb = entity + correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb - for ent in doc.ents: - if ent.label_ == "PERSON": # TODO: expand to other types - pred_entity = ent.kb_id_ - start = ent.start - end = ent.end - gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None) - if gold_entity is not None: - if gold_entity == pred_entity: - correct += 1 - else: - incorrect += 1 + for ent in doc.ents: + if ent.label_ == "PERSON": # TODO: expand to other types + pred_entity = ent.kb_id_ + start = ent.start + end = ent.end + gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None) + if gold_entity is not None: + if gold_entity == pred_entity: + correct += 1 + else: + incorrect += 1 + + except Exception as e: + print("Error assessing accuracy", e) if correct == incorrect == 0: return 0 @@ -243,4 +252,4 @@ def run_el_toy_example(nlp, kb): if __name__ == "__main__": - run_pipeline() + run_pipeline() \ No newline at end of file diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index 9ef9df601..deaab0a19 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -1220,8 +1220,13 @@ class EntityLinker(Pipe): def predict(self, docs): self.require_model() + + if isinstance(docs, Doc): + docs = [docs] + final_entities = list() final_kb_ids = list() + for i, article_doc in enumerate(docs): doc_encoding = self.article_encoder([article_doc]) for ent in article_doc.ents: