mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
storing NEL training data in GoldParse objects
This commit is contained in:
parent
61f0e2af65
commit
a5c061f506
|
@ -1,3 +1,4 @@
|
|||
# coding: utf-8
|
||||
from random import shuffle
|
||||
|
||||
from examples.pipeline.wiki_entity_linking import kb_creator
|
||||
|
|
|
@ -1,11 +1,15 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import os
|
||||
import re
|
||||
import csv
|
||||
import bz2
|
||||
import datetime
|
||||
from os import listdir
|
||||
|
||||
from examples.pipeline.wiki_entity_linking import run_el
|
||||
from spacy.gold import GoldParse
|
||||
from spacy.matcher import PhraseMatcher
|
||||
from . import wikipedia_processor as wp, kb_creator
|
||||
|
||||
"""
|
||||
|
@ -294,5 +298,62 @@ def read_training_entities(training_output, collect_correct=True, collect_incorr
|
|||
return correct_entries_per_article, incorrect_entries_per_article
|
||||
|
||||
|
||||
def read_training(nlp, training_dir, id_to_descr, doc_cutoff, dev, limit, to_print):
|
||||
correct_entries, incorrect_entries = read_training_entities(training_output=training_dir,
|
||||
collect_correct=True,
|
||||
collect_incorrect=True)
|
||||
|
||||
docs = list()
|
||||
golds = list()
|
||||
|
||||
cnt = 0
|
||||
next_entity_nr = 1
|
||||
files = listdir(training_dir)
|
||||
for f in files:
|
||||
if not limit or cnt < limit:
|
||||
if dev == run_el.is_dev(f):
|
||||
article_id = f.replace(".txt", "")
|
||||
if cnt % 500 == 0 and to_print:
|
||||
print(datetime.datetime.now(), "processed", cnt, "files in the training dataset")
|
||||
|
||||
try:
|
||||
# parse the article text
|
||||
with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
|
||||
text = file.read()
|
||||
article_doc = nlp(text)
|
||||
truncated_text = text[0:min(doc_cutoff, len(text))]
|
||||
|
||||
gold_entities = dict()
|
||||
|
||||
# process all positive and negative entities, collect all relevant mentions in this article
|
||||
for mention, entity_pos in correct_entries[article_id].items():
|
||||
# find all matches in the doc for the mentions
|
||||
# TODO: fix this - doesn't look like all entities are found
|
||||
matcher = PhraseMatcher(nlp.vocab)
|
||||
patterns = list(nlp.tokenizer.pipe([mention]))
|
||||
|
||||
matcher.add("TerminologyList", None, *patterns)
|
||||
matches = matcher(article_doc)
|
||||
|
||||
# store gold entities
|
||||
for match_id, start, end in matches:
|
||||
gold_entities[(start, end, entity_pos)] = 1.0
|
||||
|
||||
gold = GoldParse(doc=article_doc, cats=gold_entities)
|
||||
docs.append(article_doc)
|
||||
golds.append(gold)
|
||||
|
||||
cnt += 1
|
||||
except Exception as e:
|
||||
print("Problem parsing article", article_id)
|
||||
print(e)
|
||||
|
||||
if to_print:
|
||||
print()
|
||||
print("Processed", cnt, "training articles, dev=" + str(dev))
|
||||
print()
|
||||
return docs, golds
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -23,6 +23,9 @@ VOCAB_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/vocab'
|
|||
|
||||
TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
|
||||
|
||||
MAX_CANDIDATES=10
|
||||
MIN_PAIR_OCC=5
|
||||
DOC_CHAR_CUTOFF=300
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("START", datetime.datetime.now())
|
||||
|
@ -71,8 +74,8 @@ if __name__ == "__main__":
|
|||
if to_create_kb:
|
||||
print("STEP 3a: to_create_kb", datetime.datetime.now())
|
||||
my_kb = kb_creator.create_kb(nlp,
|
||||
max_entities_per_alias=10,
|
||||
min_occ=5,
|
||||
max_entities_per_alias=MAX_CANDIDATES,
|
||||
min_occ=MIN_PAIR_OCC,
|
||||
entity_def_output=ENTITY_DEFS,
|
||||
entity_descr_output=ENTITY_DESCR,
|
||||
count_input=ENTITY_COUNTS,
|
||||
|
@ -110,10 +113,29 @@ if __name__ == "__main__":
|
|||
|
||||
# STEP 6: create the entity linking pipe
|
||||
if train_pipe:
|
||||
# TODO: the vocab objects are now different between nlp and kb - will be fixed when KB is written as part of NLP IO
|
||||
id_to_descr = kb_creator._get_id_to_description(ENTITY_DESCR)
|
||||
|
||||
docs, golds = training_set_creator.read_training(nlp=nlp,
|
||||
training_dir=TRAINING_DIR,
|
||||
id_to_descr=id_to_descr,
|
||||
doc_cutoff=DOC_CHAR_CUTOFF,
|
||||
dev=False,
|
||||
limit=10,
|
||||
to_print=False)
|
||||
|
||||
# for doc, gold in zip(docs, golds):
|
||||
# print("doc", doc)
|
||||
# for entity, label in gold.cats.items():
|
||||
# print("entity", entity, label)
|
||||
# print()
|
||||
|
||||
el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": my_kb})
|
||||
nlp.add_pipe(el_pipe, last=True)
|
||||
|
||||
other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "entity_linker"]
|
||||
with nlp.disable_pipes(*other_pipes): # only train Entity Linking
|
||||
nlp.begin_training()
|
||||
|
||||
### BELOW CODE IS DEPRECATED ###
|
||||
|
||||
# STEP 6: apply the EL algorithm on the training dataset - TODO deprecated - code moved to pipes.pyx
|
||||
|
|
|
@ -82,6 +82,11 @@ cdef class KnowledgeBase:
|
|||
self.vocab.strings.add("")
|
||||
self._create_empty_vectors(dummy_hash=self.vocab.strings[""])
|
||||
|
||||
@property
|
||||
def entity_vector_length(self):
|
||||
"""RETURNS (uint64): length of the entity vectors"""
|
||||
return self.entity_vector_length
|
||||
|
||||
def __len__(self):
|
||||
return self.get_size_entities()
|
||||
|
||||
|
|
|
@ -1081,8 +1081,8 @@ class EntityLinker(Pipe):
|
|||
sent_width = cfg.get("sent_width", 64)
|
||||
entity_width = cfg["kb"].entity_vector_length
|
||||
|
||||
article_encoder = build_nel_encoder(in_width=embed_width, hidden_with=hidden_width, end_width=article_width)
|
||||
sent_encoder = build_nel_encoder(in_width=embed_width, hidden_with=hidden_width, end_width=sent_width)
|
||||
article_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=article_width, **cfg)
|
||||
sent_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=sent_width, **cfg)
|
||||
|
||||
# dimension of the mention encoder needs to match the dimension of the entity encoder
|
||||
mention_width = article_width + sent_width
|
||||
|
@ -1118,6 +1118,10 @@ class EntityLinker(Pipe):
|
|||
""" docs should be a tuple of (entity_docs, article_docs, sentence_docs) TODO """
|
||||
self.require_model()
|
||||
|
||||
if len(docs) != len(golds):
|
||||
raise ValueError(Errors.E077.format(value="loss", n_docs=len(docs),
|
||||
n_golds=len(golds)))
|
||||
|
||||
entity_docs, article_docs, sentence_docs = docs
|
||||
assert len(entity_docs) == len(article_docs) == len(sentence_docs)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user