Mirror of https://github.com/explosion/spaCy.git (synced 2024-11-10 19:57:17 +03:00)

Commit d8b435ceff: pretraining description vectors and storing them in the KB
Parent commit: 5c723c32c3
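The overall flow introduced by this commit: pretrain an encoder on the entity descriptions, compress every description to a DESC_WIDTH-sized vector, and store those vectors in the KnowledgeBase. A rough sketch of that flow using only names that appear in this diff (the toy entity IDs, frequencies and descriptions are made up, and the examples package is assumed to be importable from this branch):

import spacy
from spacy.kb import KnowledgeBase
from examples.pipeline.wiki_entity_linking.train_descriptions import EntityEncoder

INPUT_DIM = 300   # width of the pre-trained word vectors in the nlp model
DESC_WIDTH = 64   # width of the description vectors stored in the KB

nlp = spacy.load("en_core_web_lg")

# toy data; in the commit these come from the Wikidata/Wikipedia processing steps
entity_ids = ["Q42", "Q1"]
frequencies = [83, 12]
descriptions = ["Douglas Adams was an English author and humorist.",
                "The universe is all of space and time and their contents."]

# 1. train the description encoder (INPUT_DIM-sized doc vectors -> DESC_WIDTH-sized codes)
encoder = EntityEncoder(nlp, INPUT_DIM, DESC_WIDTH)
encoder.train(description_list=descriptions, to_print=True)

# 2. encode every description and store the resulting vectors in the KB
embeddings = encoder.apply_encoder(descriptions)
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH)
kb.set_entities(entity_list=entity_ids, prob_list=frequencies, vector_list=embeddings)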
@@ -2,6 +2,7 @@
 from __future__ import unicode_literals

 import spacy
+from examples.pipeline.wiki_entity_linking.train_descriptions import EntityEncoder
 from spacy.kb import KnowledgeBase

 import csv
@@ -10,25 +11,47 @@ import datetime
 from . import wikipedia_processor as wp
 from . import wikidata_processor as wd

+INPUT_DIM = 300  # dimension of pre-trained vectors
+DESC_WIDTH = 64
+
-def create_kb(vocab, max_entities_per_alias, min_occ,
+def create_kb(nlp, max_entities_per_alias, min_occ,
               entity_def_output, entity_descr_output,
-              count_input, prior_prob_input,
-              to_print=False, write_entity_defs=True):
+              count_input, prior_prob_input, to_print=False):
     """ Create the knowledge base from Wikidata entries """
-    kb = KnowledgeBase(vocab=vocab, entity_vector_length=64)  # TODO: entity vectors !
+    kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH)

-    print()
-    print("1. _read_wikidata_entities", datetime.datetime.now())
-    print()
-    title_to_id, id_to_descr = wd.read_wikidata_entities_json(limit=None)
+    # disable parts of the pipeline when rerunning
+    read_raw_data = False

-    # write the title-ID and ID-description mappings to file
-    if write_entity_defs:
+    if read_raw_data:
+        print()
+        print("1. _read_wikidata_entities", datetime.datetime.now())
+        print()
+        title_to_id, id_to_descr = wd.read_wikidata_entities_json(limit=None)
+
+        # write the title-ID and ID-description mappings to file
         _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr)

+    else:
+        # read the mappings from file
+        title_to_id = _get_entity_to_id(entity_def_output)
+        id_to_descr = _get_id_to_description(entity_descr_output)
+
     title_list = list(title_to_id.keys())

+    # TODO: remove this filter (just for quicker testing of code)
+    title_list = title_list[0:34200]
+    title_to_id = {t: title_to_id[t] for t in title_list}
+
+    # print("title_list", len(title_list), title_list[0:3])
+
     entity_list = [title_to_id[x] for x in title_list]
+    # print("entity_list", len(entity_list), entity_list[0:3])
+
+    # TODO: should we remove entities from the KB where there is no description ?
+    description_list = [id_to_descr.get(x, "No description defined") for x in entity_list]
+    # print("description_list", len(description_list), description_list[0:3])
+

     print()
     print("2. _get_entity_frequencies", datetime.datetime.now())
@@ -36,13 +59,27 @@ def create_kb(vocab, max_entities_per_alias, min_occ,
     entity_frequencies = wp.get_entity_frequencies(count_input=count_input, entities=title_list)

     print()
-    print("3. adding", len(entity_list), "entities", datetime.datetime.now())
+    print("3. train entity encoder", datetime.datetime.now())
     print()
-    # TODO: vector_list !
-    kb.set_entities(entity_list=entity_list, prob_list=entity_frequencies, vector_list=None)
-
+    encoder = EntityEncoder(nlp, INPUT_DIM, DESC_WIDTH)
+    encoder.train(description_list=description_list, to_print=True)
+    print()
+
+    print("4. get entity embeddings", datetime.datetime.now())
+    print()
+    embeddings = encoder.apply_encoder(description_list)
+    # print("descriptions", description_list[0:3])
+    # print("embeddings", len(embeddings), embeddings[0:3])
+    #print("embeddings[0]", len(embeddings[0]), embeddings[0][0:3])
+
     print()
-    print("4. adding aliases", datetime.datetime.now())
+    print("5. adding", len(entity_list), "entities", datetime.datetime.now())
     print()
+    kb.set_entities(entity_list=entity_list, prob_list=entity_frequencies, vector_list=embeddings)
+
+    print()
+    print("6. adding aliases", datetime.datetime.now())
+    print()
     _add_aliases(kb, title_to_id=title_to_id,
                  max_entities_per_alias=max_entities_per_alias, min_occ=min_occ,
@@ -67,7 +104,6 @@ def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_
         for qid, descr in id_to_descr.items():
             descr_file.write(str(qid) + "|" + descr + "\n")

-

 def _get_entity_to_id(entity_def_output):
     entity_to_id = dict()
     with open(entity_def_output, 'r', encoding='utf8') as csvfile:
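The description file written above is a plain pipe-delimited text file with one qid|description line per entity. The body of _get_id_to_description is not shown in this hunk; a plausible reader (the function name below and the header-skipping are assumptions) could look like this:

import csv

def read_id_to_description(entity_descr_path):
    # hypothetical counterpart of _get_id_to_description: parse "qid|description" lines
    id_to_descr = dict()
    with open(entity_descr_path, "r", encoding="utf8") as csvfile:
        csvreader = csv.reader(csvfile, delimiter="|")
        next(csvreader)  # assume a header row, as in the prior-probability file below
        for row in csvreader:
            id_to_descr[row[0]] = row[1]
    return id_to_descr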
@@ -99,11 +135,11 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
         print("wp titles:", wp_titles)

     # adding aliases with prior probabilities
-    # we can read this file sequentially, it's sorted by alias, and then by count
     with open(prior_prob_input, mode='r', encoding='utf8') as prior_file:
         # skip header
         prior_file.readline()
         line = prior_file.readline()
+        # we can read this file sequentially, it's sorted by alias, and then by count
         previous_alias = None
         total_count = 0
         counts = list()
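Because the prior-probability file is sorted by alias and then by count, _add_aliases can stream through it and flush one alias group at a time into kb.add_alias. A rough sketch of that grouping pattern; the alias|count|entity column layout and the exact filtering are assumptions, only the sortedness and the add_alias call are taken from this diff:

def _flush_alias(kb, alias, counts, entities, title_to_id, total_count, max_entities_per_alias, min_occ):
    # keep the most frequent entities for this alias and turn raw counts into prior probabilities
    pairs = [(c, e) for c, e in zip(counts, entities) if c >= min_occ and e in title_to_id]
    pairs = sorted(pairs, reverse=True)[:max_entities_per_alias]
    if pairs:
        kb.add_alias(alias=alias,
                     entities=[title_to_id[e] for _, e in pairs],
                     probabilities=[c / total_count for c, _ in pairs])


def add_aliases_from_sorted_file(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_path):
    with open(prior_prob_path, mode="r", encoding="utf8") as prior_file:
        prior_file.readline()  # skip header
        previous_alias, total_count = None, 0
        counts, entities = [], []
        for line in prior_file:
            alias, count, entity = line.rstrip("\n").split("|")
            count = int(count)
            if previous_alias is not None and alias != previous_alias:
                # the sorted order guarantees this alias group is complete
                _flush_alias(kb, previous_alias, counts, entities, title_to_id,
                             total_count, max_entities_per_alias, min_occ)
                total_count, counts, entities = 0, [], []
            total_count += count
            counts.append(count)
            entities.append(entity)
            previous_alias = alias
        if previous_alias is not None:
            _flush_alias(kb, previous_alias, counts, entities, title_to_id,
                         total_count, max_entities_per_alias, min_occ)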
@@ -12,6 +12,15 @@ from examples.pipeline.wiki_entity_linking import training_set_creator
 # import neuralcoref


+def run_kb_toy_example(kb):
+    for mention in ("Bush", "President", "Homer"):
+        candidates = kb.get_candidates(mention)
+
+        print("generating candidates for " + mention + " :")
+        for c in candidates:
+            print(" ", c.prior_prob, c.alias_, "-->", c.entity_ + " (freq=" + str(c.entity_freq) + ")")
+        print()
+
 def run_el_toy_example(nlp, kb):
     _prepare_pipeline(nlp, kb)

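run_kb_toy_example only needs a KB that already holds entities and aliases. For reference, a tiny self-contained KB built against this dev version of the API and queried the same way (the entity IDs, frequencies, vectors and prior probabilities are invented):

import spacy
from spacy.kb import KnowledgeBase

nlp = spacy.load("en_core_web_sm")
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3)

# two made-up entities with frequencies and 3-d description vectors
kb.set_entities(entity_list=["Q1", "Q2"],
                prob_list=[20, 5],
                vector_list=[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])

# one alias that can refer to either entity, with prior probabilities
kb.add_alias(alias="Bush", entities=["Q1", "Q2"], probabilities=[0.8, 0.2])

for c in kb.get_candidates("Bush"):
    print(c.prior_prob, c.alias_, "-->", c.entity_, "(freq=" + str(c.entity_freq) + ")")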
@@ -14,72 +14,83 @@ from thinc.neural._classes.affine import Affine

 class EntityEncoder:

-    INPUT_DIM = 300  # dimension of pre-trained vectors
-    DESC_WIDTH = 64
-
     DROP = 0
     EPOCHS = 5
-    STOP_THRESHOLD = 0.1
+    STOP_THRESHOLD = 0.9  # 0.1

     BATCH_SIZE = 1000

-    def __init__(self, kb, nlp):
+    def __init__(self, nlp, input_dim, desc_width):
         self.nlp = nlp
-        self.kb = kb
+        self.input_dim = input_dim
+        self.desc_width = desc_width

-    def run(self, entity_descr_output):
-        id_to_descr = kb_creator._get_id_to_description(entity_descr_output)
+    def apply_encoder(self, description_list):
+        if self.encoder is None:
+            raise ValueError("Can not apply encoder before training it")

-        processed, loss = self._train_model(entity_descr_output, id_to_descr)
-        print("Trained on", processed, "entities across", self.EPOCHS, "epochs")
-        print("Final loss:", loss)
-        print()
+        print("Encoding", len(description_list), "entities")

-        # TODO: apply and write to file afterwards !
-        # self._apply_encoder(id_to_descr)
+        batch_size = 10000

-        self._test_encoder()
+        start = 0
+        stop = min(batch_size, len(description_list))
+        encodings = []

-    def _train_model(self, entity_descr_output, id_to_descr):
+        while start < len(description_list):
+            docs = list(self.nlp.pipe(description_list[start:stop]))
+            doc_embeddings = [self._get_doc_embedding(doc) for doc in docs]
+            enc = self.encoder(np.asarray(doc_embeddings))
+            encodings.extend(enc.tolist())
+
+            start = start + batch_size
+            stop = min(stop + batch_size, len(description_list))
+        print("encoded :", len(encodings))
+
+        return encodings
+
+    def train(self, description_list, to_print=False):
+        processed, loss = self._train_model(description_list)
+
+        if to_print:
+            print("Trained on", processed, "entities across", self.EPOCHS, "epochs")
+            print("Final loss:", loss)
+
+        # self._test_encoder()
+
+    def _train_model(self, description_list):
         # TODO: when loss gets too low, a 'mean of empty slice' warning is thrown by numpy

-        self._build_network(self.INPUT_DIM, self.DESC_WIDTH)
+        self._build_network(self.input_dim, self.desc_width)

         processed = 0
         loss = 1
+        descriptions = description_list.copy()  # copy this list so that shuffling does not affect other functions

         for i in range(self.EPOCHS):
-            entity_keys = list(id_to_descr.keys())
-            shuffle(entity_keys)
+            shuffle(descriptions)

             batch_nr = 0
             start = 0
-            stop = min(self.BATCH_SIZE, len(entity_keys))
+            stop = min(self.BATCH_SIZE, len(descriptions))

-            while loss > self.STOP_THRESHOLD and start < len(entity_keys):
+            while loss > self.STOP_THRESHOLD and start < len(descriptions):
                 batch = []
-                for e in entity_keys[start:stop]:
-                    descr = id_to_descr[e]
+                for descr in descriptions[start:stop]:
                     doc = self.nlp(descr)
                     doc_vector = self._get_doc_embedding(doc)
                     batch.append(doc_vector)

-                loss = self.update(batch)
+                loss = self._update(batch)
                 print(i, batch_nr, loss)
                 processed += len(batch)

                 batch_nr += 1
                 start = start + self.BATCH_SIZE
-                stop = min(stop + self.BATCH_SIZE, len(entity_keys))
+                stop = min(stop + self.BATCH_SIZE, len(descriptions))

         return processed, loss

-    def _apply_encoder(self, id_to_descr):
-        for id, descr in id_to_descr.items():
-            doc = self.nlp(descr)
-            doc_vector = self._get_doc_embedding(doc)
-            encoding = self.encoder(np.asarray([doc_vector]))
-
     @staticmethod
     def _get_doc_embedding(doc):
         indices = np.zeros((len(doc),), dtype="i")
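_get_doc_embedding is cut off at the end of this hunk; from the surrounding code it reduces a Doc of a description to a single INPUT_DIM-sized vector, which the encoder then compresses to DESC_WIDTH. A plausible stand-in is a simple mean of the token vectors (the real implementation may instead index into the vocab's vector table, as the indices array suggests):

import numpy as np

def get_doc_embedding(doc):
    # average the pre-trained word vectors of all tokens in the description
    # (assumes the nlp model ships vectors of width INPUT_DIM)
    return np.mean([token.vector for token in doc], axis=0)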
@@ -101,16 +112,16 @@ class EntityEncoder:

         self.sgd = create_default_optimizer(self.model.ops)

-    def update(self, vectors):
+    def _update(self, vectors):
         predictions, bp_model = self.model.begin_update(np.asarray(vectors), drop=self.DROP)

-        loss, d_scores = self.get_loss(scores=predictions, golds=np.asarray(vectors))
+        loss, d_scores = self._get_loss(scores=predictions, golds=np.asarray(vectors))
         bp_model(d_scores, sgd=self.sgd)

         return loss / len(vectors)

     @staticmethod
-    def get_loss(golds, scores):
+    def _get_loss(golds, scores):
         loss, gradients = get_cossim_loss(scores, golds)
         return loss, gradients

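The encoder is trained to reproduce its own input: the model's predictions are compared against the original doc vectors with a cosine loss, and the resulting gradient is pushed back through bp_model. For illustration only, a numpy sketch of a cosine-distance loss and its gradient for a single vector pair (this is not spaCy's get_cossim_loss, just the underlying idea):

import numpy as np

def cosine_loss(prediction, gold):
    # loss = 1 - cos(prediction, gold); gradient is taken w.r.t. the prediction
    pred_norm = np.linalg.norm(prediction)
    gold_norm = np.linalg.norm(gold)
    cos = prediction.dot(gold) / (pred_norm * gold_norm)
    d_cos = gold / (pred_norm * gold_norm) - cos * prediction / (pred_norm ** 2)
    return 1.0 - cos, -d_cos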
@@ -2,7 +2,6 @@
 from __future__ import unicode_literals

 from examples.pipeline.wiki_entity_linking import wikipedia_processor as wp, kb_creator, training_set_creator, run_el
-from examples.pipeline.wiki_entity_linking.train_descriptions import EntityEncoder
 from examples.pipeline.wiki_entity_linking.train_el import EL_Model

 import spacy
@@ -28,6 +27,7 @@ TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
 if __name__ == "__main__":
     print("START", datetime.datetime.now())
     print()
+    nlp = spacy.load('en_core_web_lg')
     my_kb = None

     # one-time methods to create KB and write to file
@@ -37,10 +37,7 @@ if __name__ == "__main__":

     # read KB back in from file
     to_read_kb = True
-    to_test_kb = False
-
-    # run entity description pre-training
-    run_desc_training = True
+    to_test_kb = True

     # create training dataset
     create_wp_training = False
@@ -51,6 +48,8 @@ if __name__ == "__main__":
     # apply named entity linking to the dev dataset
     apply_to_dev = False

+    to_test_pipeline = False
+
     # STEP 1 : create prior probabilities from WP
     # run only once !
     if to_create_prior_probs:
@@ -69,9 +68,7 @@ if __name__ == "__main__":
     # run only once !
     if to_create_kb:
         print("STEP 3a: to_create_kb", datetime.datetime.now())
-        my_nlp = spacy.load('en_core_web_sm')
-        my_vocab = my_nlp.vocab
-        my_kb = kb_creator.create_kb(my_vocab,
+        my_kb = kb_creator.create_kb(nlp,
                                      max_entities_per_alias=10,
                                      min_occ=5,
                                      entity_def_output=ENTITY_DEFS,
@@ -85,7 +82,7 @@ if __name__ == "__main__":

         print("STEP 3b: write KB", datetime.datetime.now())
         my_kb.dump(KB_FILE)
-        my_vocab.to_disk(VOCAB_DIR)
+        nlp.vocab.to_disk(VOCAB_DIR)
         print()

     # STEP 4 : read KB back in from file
@@ -101,18 +98,9 @@ if __name__ == "__main__":

     # test KB
     if to_test_kb:
-        my_nlp = spacy.load('en_core_web_sm')
-        run_el.run_el_toy_example(kb=my_kb, nlp=my_nlp)
+        run_el.run_kb_toy_example(kb=my_kb)
         print()

-    # STEP 4b : read KB back in from file, create entity descriptions
-    # TODO: write back to file
-    if run_desc_training:
-        print("STEP 4b: training entity descriptions", datetime.datetime.now())
-        my_nlp = spacy.load('en_core_web_md')
-        EntityEncoder(my_kb, my_nlp).run(entity_descr_output=ENTITY_DESCR)
-        print()
-
     # STEP 5: create a training dataset from WP
     if create_wp_training:
         print("STEP 5: create training dataset", datetime.datetime.now())
@@ -121,15 +109,18 @@ if __name__ == "__main__":
     # STEP 6: apply the EL algorithm on the training dataset
     if run_el_training:
         print("STEP 6: training", datetime.datetime.now())
-        my_nlp = spacy.load('en_core_web_md')
-        trainer = EL_Model(kb=my_kb, nlp=my_nlp)
+        trainer = EL_Model(kb=my_kb, nlp=nlp)
         trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=10000, devlimit=500)
         print()

-    # STEP 7: apply the EL algorithm on the dev dataset
+    # STEP 7: apply the EL algorithm on the dev dataset (TODO: overlaps with code from run_el_training ?)
     if apply_to_dev:
-        my_nlp = spacy.load('en_core_web_md')
-        run_el.run_el_dev(kb=my_kb, nlp=my_nlp, training_dir=TRAINING_DIR, limit=2000)
+        run_el.run_el_dev(kb=my_kb, nlp=nlp, training_dir=TRAINING_DIR, limit=2000)
         print()

+    # test KB
+    if to_test_pipeline:
+        run_el.run_el_toy_example(kb=my_kb, nlp=nlp)
+        print()
+
     # TODO coreference resolution
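For reference, a full call to the refactored create_kb would look roughly like this; the keyword arguments are taken from the new signature in the first file of this diff, and the file paths are placeholders standing in for ENTITY_DEFS, ENTITY_DESCR and the count and prior-probability inputs:

import spacy
from examples.pipeline.wiki_entity_linking import kb_creator

nlp = spacy.load("en_core_web_lg")  # one large model is now passed around everywhere

my_kb = kb_creator.create_kb(nlp,
                             max_entities_per_alias=10,
                             min_occ=5,
                             entity_def_output="entity_defs.csv",            # placeholder path
                             entity_descr_output="entity_descriptions.csv",  # placeholder path
                             count_input="entity_counts.csv",                # placeholder path
                             prior_prob_input="prior_probabilities.csv",     # placeholder path
                             to_print=True)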
14  spacy/kb.pyx
@@ -124,6 +124,14 @@ cdef class KnowledgeBase:
         return entity_hash

     cpdef set_entities(self, entity_list, prob_list, vector_list):
+        if len(entity_list) != len(prob_list):
+            # TODO: proper error
+            raise ValueError("Entity list and prob list should have the same length")
+
+        if len(entity_list) != len(vector_list):
+            # TODO: proper error
+            raise ValueError("Entity list and vector list should have the same length")
+
         nr_entities = len(entity_list)
         self._entry_index = PreshMap(nr_entities+1)
         self._entries = entry_vec(nr_entities+1)
@@ -131,12 +139,12 @@ cdef class KnowledgeBase:
         i = 0
         cdef EntryC entry
         while i < nr_entities:
-            entity_vector = entity_list[i]
+            entity_vector = vector_list[i]
             if len(entity_vector) != self.entity_vector_length:
                 # TODO: proper error
-                raise ValueError("Entity vector length should have been", self.entity_vector_length)
+                raise ValueError("Entity vector is", len(entity_vector), "length but should have been", self.entity_vector_length)

-            entity_hash = self.vocab.strings.add(entity_vector)
+            entity_hash = self.vocab.strings.add(entity_list[i])
             entry.entity_hash = entity_hash
             entry.prob = prob_list[i]
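The new length checks mean the three lists passed to set_entities must be aligned, and every vector must match the KB's entity_vector_length. A hypothetical snippet against this dev version of the API:

from spacy.kb import KnowledgeBase
from spacy.vocab import Vocab

kb = KnowledgeBase(vocab=Vocab(), entity_vector_length=2)

# fine: three aligned lists, every vector has length 2
kb.set_entities(entity_list=["Q1", "Q2"], prob_list=[0.5, 0.1],
                vector_list=[[1.0, 0.0], [0.0, 1.0]])

# raises ValueError: the vector has length 3 but entity_vector_length is 2
kb.set_entities(entity_list=["Q3"], prob_list=[0.2], vector_list=[[1.0, 2.0, 3.0]])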
@@ -117,7 +117,7 @@ class Language(object):
         "tagger": lambda nlp, **cfg: Tagger(nlp.vocab, **cfg),
         "parser": lambda nlp, **cfg: DependencyParser(nlp.vocab, **cfg),
         "ner": lambda nlp, **cfg: EntityRecognizer(nlp.vocab, **cfg),
-        "entity_linker": lambda nlp, **cfg: EntityLinker(nlp.vocab, **cfg),
+        "entity_linker": lambda nlp, **cfg: EntityLinker(**cfg),
         "similarity": lambda nlp, **cfg: SimilarityHook(nlp.vocab, **cfg),
         "textcat": lambda nlp, **cfg: TextCategorizer(nlp.vocab, **cfg),
         "sentencizer": lambda nlp, **cfg: Sentencizer(**cfg),
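Language.factories is the table nlp.create_pipe consults, so after this change the entity_linker component is built from its config alone, without nlp.vocab. A hypothetical usage sketch against this dev branch:

import spacy

nlp = spacy.load("en_core_web_sm")

# the factory lambda above is what runs here; EntityLinker now receives **cfg only
el_pipe = nlp.create_pipe("entity_linker")
nlp.add_pipe(el_pipe, last=True)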