storing NEL training data in GoldParse objects

svlandeg 2019-06-07 12:58:42 +02:00
parent 61f0e2af65
commit a5c061f506
5 changed files with 99 additions and 6 deletions
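The core change: gold entity-linking annotations are now carried in the cats dict of a GoldParse, keyed by token offsets plus the gold entity ID. A minimal standalone sketch of that representation (the sentence and the "Q42" ID are invented for illustration):

import spacy
from spacy.gold import GoldParse

nlp = spacy.blank("en")
doc = nlp("Douglas Adams wrote the Hitchhiker's Guide.")

# keys are (start_token, end_token, gold_entity_id); 1.0 marks the correct link
gold_entities = {(0, 2, "Q42"): 1.0}
gold = GoldParse(doc=doc, cats=gold_entities)
print(gold.cats)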

View File

@ -1,3 +1,4 @@
# coding: utf-8
from random import shuffle
from examples.pipeline.wiki_entity_linking import kb_creator

View File

@ -1,11 +1,15 @@
# coding: utf-8
from __future__ import unicode_literals
import os
import re
import csv
import bz2
import datetime
from os import listdir
from examples.pipeline.wiki_entity_linking import run_el
from spacy.gold import GoldParse
from spacy.matcher import PhraseMatcher
from . import wikipedia_processor as wp, kb_creator
"""
@ -294,5 +298,62 @@ def read_training_entities(training_output, collect_correct=True, collect_incorr
    return correct_entries_per_article, incorrect_entries_per_article


def read_training(nlp, training_dir, id_to_descr, doc_cutoff, dev, limit, to_print):
    correct_entries, incorrect_entries = read_training_entities(training_output=training_dir,
                                                                collect_correct=True,
                                                                collect_incorrect=True)

    docs = list()
    golds = list()

    cnt = 0
    next_entity_nr = 1
    files = listdir(training_dir)
    for f in files:
        if not limit or cnt < limit:
            if dev == run_el.is_dev(f):
                article_id = f.replace(".txt", "")
                if cnt % 500 == 0 and to_print:
                    print(datetime.datetime.now(), "processed", cnt, "files in the training dataset")

                try:
                    # parse the article text
                    with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
                        text = file.read()
                        article_doc = nlp(text)
                        truncated_text = text[0:min(doc_cutoff, len(text))]

                    gold_entities = dict()

                    # process all positive and negative entities, collect all relevant mentions in this article
                    for mention, entity_pos in correct_entries[article_id].items():
                        # find all matches in the doc for the mentions
                        # TODO: fix this - doesn't look like all entities are found
                        matcher = PhraseMatcher(nlp.vocab)
                        patterns = list(nlp.tokenizer.pipe([mention]))
                        matcher.add("TerminologyList", None, *patterns)
                        matches = matcher(article_doc)

                        # store gold entities
                        for match_id, start, end in matches:
                            gold_entities[(start, end, entity_pos)] = 1.0

                    gold = GoldParse(doc=article_doc, cats=gold_entities)
                    docs.append(article_doc)
                    golds.append(gold)

                    cnt += 1
                except Exception as e:
                    print("Problem parsing article", article_id)
                    print(e)

    if to_print:
        print()
        print("Processed", cnt, "training articles, dev=" + str(dev))
        print()

    return docs, golds
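A note on the mention matching above: PhraseMatcher matches exact token sequences (the ORTH attribute) by default, so casing or tokenization differences between a mention string and the article text will produce misses, which is one plausible cause of the TODO. A standalone sketch of the matching step, with an invented sentence and mention:

import spacy
from spacy.matcher import PhraseMatcher

nlp = spacy.blank("en")
matcher = PhraseMatcher(nlp.vocab)

# patterns are Doc objects produced by the same tokenizer that processed the article
patterns = list(nlp.tokenizer.pipe(["Douglas Adams"]))
matcher.add("TerminologyList", None, *patterns)

doc = nlp("The novel was written by Douglas Adams in 1979.")
for match_id, start, end in matcher(doc):
    print(start, end, doc[start:end].text)  # token offsets, as stored in gold_entities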

View File

@ -23,6 +23,9 @@ VOCAB_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/vocab'
TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
MAX_CANDIDATES = 10
MIN_PAIR_OCC = 5
DOC_CHAR_CUTOFF = 300

if __name__ == "__main__":
    print("START", datetime.datetime.now())
@ -71,8 +74,8 @@ if __name__ == "__main__":
    if to_create_kb:
        print("STEP 3a: to_create_kb", datetime.datetime.now())
        my_kb = kb_creator.create_kb(nlp,
                                     max_entities_per_alias=10,
                                     min_occ=5,
                                     max_entities_per_alias=MAX_CANDIDATES,
                                     min_occ=MIN_PAIR_OCC,
                                     entity_def_output=ENTITY_DEFS,
                                     entity_descr_output=ENTITY_DESCR,
                                     count_input=ENTITY_COUNTS,
@ -110,10 +113,29 @@ if __name__ == "__main__":
    # STEP 6: create the entity linking pipe
    if train_pipe:
        # TODO: the vocab objects are now different between nlp and kb - will be fixed when KB is written as part of NLP IO
        id_to_descr = kb_creator._get_id_to_description(ENTITY_DESCR)

        docs, golds = training_set_creator.read_training(nlp=nlp,
                                                         training_dir=TRAINING_DIR,
                                                         id_to_descr=id_to_descr,
                                                         doc_cutoff=DOC_CHAR_CUTOFF,
                                                         dev=False,
                                                         limit=10,
                                                         to_print=False)

        # for doc, gold in zip(docs, golds):
        #     print("doc", doc)
        #     for entity, label in gold.cats.items():
        #         print("entity", entity, label)
        #     print()

        el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": my_kb})
        nlp.add_pipe(el_pipe, last=True)

        other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "entity_linker"]
        with nlp.disable_pipes(*other_pipes):  # only train Entity Linking
            nlp.begin_training()

    ### BELOW CODE IS DEPRECATED ###

    # STEP 6: apply the EL algorithm on the training dataset - TODO deprecated - code moved to pipes.pyx
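The pipeline setup above stops at nlp.begin_training(). Purely as an illustration of where docs and golds are headed, a generic spaCy v2-style loop over the collected pairs is sketched below, reusing docs, golds and other_pipes from the step above; this is not code from this branch, and the EntityLinker's update/get_loss interface is still in flux (see the pipes.pyx hunk further down):

with nlp.disable_pipes(*other_pipes):  # only train the entity linker
    optimizer = nlp.begin_training()
    for itn in range(10):
        losses = {}
        nlp.update(docs, golds, drop=0.1, sgd=optimizer, losses=losses)
        print(itn, losses.get("entity_linker"))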

View File

@ -82,6 +82,11 @@ cdef class KnowledgeBase:
        self.vocab.strings.add("")
        self._create_empty_vectors(dummy_hash=self.vocab.strings[""])

    @property
    def entity_vector_length(self):
        """RETURNS (uint64): length of the entity vectors"""
        return self.entity_vector_length

    def __len__(self):
        return self.get_size_entities()
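This read-only property is what the EntityLinker below relies on (entity_width = cfg["kb"].entity_vector_length) to size its layers from the KB instead of a hard-coded dimension. A small usage sketch, assuming the KnowledgeBase constructor keyword as in the later 2.2 release:

from spacy.vocab import Vocab
from spacy.kb import KnowledgeBase

kb = KnowledgeBase(vocab=Vocab(), entity_vector_length=64)
print(kb.entity_vector_length)  # 64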

View File

@ -1081,8 +1081,8 @@ class EntityLinker(Pipe):
        sent_width = cfg.get("sent_width", 64)
        entity_width = cfg["kb"].entity_vector_length

        article_encoder = build_nel_encoder(in_width=embed_width, hidden_with=hidden_width, end_width=article_width)
        sent_encoder = build_nel_encoder(in_width=embed_width, hidden_with=hidden_width, end_width=sent_width)
        article_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=article_width, **cfg)
        sent_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=sent_width, **cfg)

        # dimension of the mention encoder needs to match the dimension of the entity encoder
        mention_width = article_width + sent_width
@ -1118,6 +1118,10 @@ class EntityLinker(Pipe):
""" docs should be a tuple of (entity_docs, article_docs, sentence_docs) TODO """
self.require_model()
if len(docs) != len(golds):
raise ValueError(Errors.E077.format(value="loss", n_docs=len(docs),
n_golds=len(golds)))
entity_docs, article_docs, sentence_docs = docs
assert len(entity_docs) == len(article_docs) == len(sentence_docs)