mirror of https://github.com/explosion/spaCy.git
storing NEL training data in GoldParse objects

This commit is contained in:
parent 61f0e2af65
commit a5c061f506
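The idea in one place: for each Wikipedia training article, the gold entity links are packed into the cats dictionary of a GoldParse, keyed by (start_token, end_token, entity_id) and set to 1.0 for a correct link. Below is a minimal sketch of that pattern, assuming spaCy v2-era APIs; the helper name make_gold and its arguments are illustrative and not part of the commit.

    from spacy.gold import GoldParse
    from spacy.matcher import PhraseMatcher

    def make_gold(nlp, article_doc, mention, entity_id):
        # locate the mention in the article with a PhraseMatcher, as the new read_training() does
        matcher = PhraseMatcher(nlp.vocab)
        matcher.add("MENTION", None, nlp.tokenizer(mention))
        gold_entities = {}
        for match_id, start, end in matcher(article_doc):
            gold_entities[(start, end, entity_id)] = 1.0  # 1.0 marks a correct link
        # the NEL annotations ride along in GoldParse.cats
        return GoldParse(doc=article_doc, cats=gold_entities)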
@@ -1,3 +1,4 @@
+# coding: utf-8
 from random import shuffle
 
 from examples.pipeline.wiki_entity_linking import kb_creator
@@ -1,11 +1,15 @@
 # coding: utf-8
 from __future__ import unicode_literals
 
+import os
 import re
-import csv
 import bz2
 import datetime
+from os import listdir
 
+from examples.pipeline.wiki_entity_linking import run_el
+from spacy.gold import GoldParse
+from spacy.matcher import PhraseMatcher
 from . import wikipedia_processor as wp, kb_creator
 
 """
@@ -294,5 +298,62 @@ def read_training_entities(training_output, collect_correct=True, collect_incorr
     return correct_entries_per_article, incorrect_entries_per_article
 
 
+def read_training(nlp, training_dir, id_to_descr, doc_cutoff, dev, limit, to_print):
+    correct_entries, incorrect_entries = read_training_entities(training_output=training_dir,
+                                                                collect_correct=True,
+                                                                collect_incorrect=True)
+
+    docs = list()
+    golds = list()
+
+    cnt = 0
+    next_entity_nr = 1
+    files = listdir(training_dir)
+    for f in files:
+        if not limit or cnt < limit:
+            if dev == run_el.is_dev(f):
+                article_id = f.replace(".txt", "")
+                if cnt % 500 == 0 and to_print:
+                    print(datetime.datetime.now(), "processed", cnt, "files in the training dataset")
+
+                try:
+                    # parse the article text
+                    with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
+                        text = file.read()
+                        article_doc = nlp(text)
+                        truncated_text = text[0:min(doc_cutoff, len(text))]
+
+                    gold_entities = dict()
+
+                    # process all positive and negative entities, collect all relevant mentions in this article
+                    for mention, entity_pos in correct_entries[article_id].items():
+                        # find all matches in the doc for the mentions
+                        # TODO: fix this - doesn't look like all entities are found
+                        matcher = PhraseMatcher(nlp.vocab)
+                        patterns = list(nlp.tokenizer.pipe([mention]))
+
+                        matcher.add("TerminologyList", None, *patterns)
+                        matches = matcher(article_doc)
+
+                        # store gold entities
+                        for match_id, start, end in matches:
+                            gold_entities[(start, end, entity_pos)] = 1.0
+
+                    gold = GoldParse(doc=article_doc, cats=gold_entities)
+                    docs.append(article_doc)
+                    golds.append(gold)
+
+                    cnt += 1
+                except Exception as e:
+                    print("Problem parsing article", article_id)
+                    print(e)
+
+    if to_print:
+        print()
+        print("Processed", cnt, "training articles, dev=" + str(dev))
+        print()
+    return docs, golds
+
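For orientation, a minimal sketch of how the new helper could be called; the model name, paths, limit and the empty id_to_descr dict are placeholders, not values taken from the diff.

    import spacy
    from examples.pipeline.wiki_entity_linking import training_set_creator

    nlp = spacy.load('en_core_web_lg')
    docs, golds = training_set_creator.read_training(nlp=nlp,
                                                     training_dir='/path/to/training_data_nel/',
                                                     id_to_descr=dict(),
                                                     doc_cutoff=300,
                                                     dev=False,
                                                     limit=5,
                                                     to_print=True)
    for doc, gold in zip(docs, golds):
        for (start, end, entity), value in gold.cats.items():
            print(doc[start:end].text, '->', entity, value)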
@@ -23,6 +23,9 @@ VOCAB_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/vocab'
 
 TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
 
+MAX_CANDIDATES=10
+MIN_PAIR_OCC=5
+DOC_CHAR_CUTOFF=300
 
 if __name__ == "__main__":
     print("START", datetime.datetime.now())
@@ -71,8 +74,8 @@ if __name__ == "__main__":
     if to_create_kb:
         print("STEP 3a: to_create_kb", datetime.datetime.now())
         my_kb = kb_creator.create_kb(nlp,
-                                     max_entities_per_alias=10,
-                                     min_occ=5,
+                                     max_entities_per_alias=MAX_CANDIDATES,
+                                     min_occ=MIN_PAIR_OCC,
                                      entity_def_output=ENTITY_DEFS,
                                      entity_descr_output=ENTITY_DESCR,
                                      count_input=ENTITY_COUNTS,
@@ -110,10 +113,29 @@ if __name__ == "__main__":
 
     # STEP 6: create the entity linking pipe
     if train_pipe:
-        # TODO: the vocab objects are now different between nlp and kb - will be fixed when KB is written as part of NLP IO
+        id_to_descr = kb_creator._get_id_to_description(ENTITY_DESCR)
+
+        docs, golds = training_set_creator.read_training(nlp=nlp,
+                                                         training_dir=TRAINING_DIR,
+                                                         id_to_descr=id_to_descr,
+                                                         doc_cutoff=DOC_CHAR_CUTOFF,
+                                                         dev=False,
+                                                         limit=10,
+                                                         to_print=False)
+
+        # for doc, gold in zip(docs, golds):
+        #     print("doc", doc)
+        #     for entity, label in gold.cats.items():
+        #         print("entity", entity, label)
+        #     print()
+
         el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": my_kb})
         nlp.add_pipe(el_pipe, last=True)
 
+        other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "entity_linker"]
+        with nlp.disable_pipes(*other_pipes):  # only train Entity Linking
+            nlp.begin_training()
+
     ### BELOW CODE IS DEPRECATED ###
 
     # STEP 6: apply the EL algorithm on the training dataset - TODO deprecated - code moved to pipes.pyx
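The pipeline setup above stops at begin_training(). A hedged sketch of what a training loop over the collected (docs, golds) pairs could look like, assuming the entity_linker pipe's update() eventually accepts plain Doc/GoldParse pairs (at this point in the branch it still expects the (entity_docs, article_docs, sentence_docs) tuple shown further down); batch sizes, dropout and epoch count are illustrative:

    import random
    from spacy.util import minibatch, compounding

    with nlp.disable_pipes(*other_pipes):
        optimizer = nlp.begin_training()
        train_data = list(zip(docs, golds))
        for epoch in range(10):
            random.shuffle(train_data)
            losses = {}
            for batch in minibatch(train_data, size=compounding(4.0, 32.0, 1.001)):
                batch_docs, batch_golds = zip(*batch)
                nlp.update(batch_docs, batch_golds, drop=0.2, sgd=optimizer, losses=losses)
            print('epoch', epoch, 'losses', losses)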
@@ -82,6 +82,11 @@ cdef class KnowledgeBase:
         self.vocab.strings.add("")
         self._create_empty_vectors(dummy_hash=self.vocab.strings[""])
 
+    @property
+    def entity_vector_length(self):
+        """RETURNS (uint64): length of the entity vectors"""
+        return self.entity_vector_length
+
     def __len__(self):
         return self.get_size_entities()
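A tiny illustration of the new property from the consumer side; the constructor call mirrors the KnowledgeBase API as it later shipped in spaCy v2.2, so the exact signature at this point in the branch may differ:

    from spacy.vocab import Vocab
    from spacy.kb import KnowledgeBase

    kb = KnowledgeBase(vocab=Vocab(), entity_vector_length=64)
    print(kb.entity_vector_length)   # 64, picked up below to size the entity encoder
    print(len(kb))                   # number of entities, via __len__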
@@ -1081,8 +1081,8 @@ class EntityLinker(Pipe):
         sent_width = cfg.get("sent_width", 64)
         entity_width = cfg["kb"].entity_vector_length
 
-        article_encoder = build_nel_encoder(in_width=embed_width, hidden_with=hidden_width, end_width=article_width)
-        sent_encoder = build_nel_encoder(in_width=embed_width, hidden_with=hidden_width, end_width=sent_width)
+        article_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=article_width, **cfg)
+        sent_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=sent_width, **cfg)
 
         # dimension of the mention encoder needs to match the dimension of the entity encoder
         mention_width = article_width + sent_width
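To make the width bookkeeping concrete: with the sent_width default of 64 visible above and an assumed article_width of 128 (its default is not shown in this hunk), the mention encoder ends up at 192 dimensions, which the entity encoder's output has to match:

    article_width = 128                            # assumed for illustration; not shown in this hunk
    sent_width = 64                                # default from cfg.get("sent_width", 64)
    mention_width = article_width + sent_width     # 192; must equal the entity encoder's output width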
@@ -1118,6 +1118,10 @@ class EntityLinker(Pipe):
         """ docs should be a tuple of (entity_docs, article_docs, sentence_docs) TODO """
         self.require_model()
 
+        if len(docs) != len(golds):
+            raise ValueError(Errors.E077.format(value="loss", n_docs=len(docs),
+                                                n_golds=len(golds)))
+
         entity_docs, article_docs, sentence_docs = docs
         assert len(entity_docs) == len(article_docs) == len(sentence_docs)