context encoder with Tok2Vec + linking model instead of cosine

svlandeg 2019-06-28 08:29:31 +02:00
parent dbc53b9870
commit 68a0662019
6 changed files with 73 additions and 58 deletions

View File

@@ -33,7 +33,7 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
else:
# read the mappings from file
title_to_id = get_entity_to_id(entity_def_output)
id_to_descr = _get_id_to_description(entity_descr_output)
id_to_descr = get_id_to_description(entity_descr_output)
print()
print(" * _get_entity_frequencies", datetime.datetime.now())
@@ -109,7 +109,7 @@ def get_entity_to_id(entity_def_output):
return entity_to_id
def _get_id_to_description(entity_descr_output):
def get_id_to_description(entity_descr_output):
id_to_desc = dict()
with open(entity_descr_output, 'r', encoding='utf8') as csvfile:
csvreader = csv.reader(csvfile, delimiter='|')
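For orientation, the helper renamed above (from _get_id_to_description to the public get_id_to_description) builds a dict from a '|'-delimited file of entity IDs and descriptions. A minimal sketch of such a reader, under the assumption (not shown in this hunk) that each row is simply id|description:

import csv

def get_id_to_description(entity_descr_output):
    # Map each entity ID to its textual description.
    # Assumes a two-column, '|'-delimited file: id|description.
    id_to_desc = dict()
    with open(entity_descr_output, "r", encoding="utf8") as csvfile:
        csvreader = csv.reader(csvfile, delimiter="|")
        for row in csvreader:
            id_to_desc[row[0]] = row[1]
    return id_to_desc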

View File

@@ -14,7 +14,7 @@ from thinc.neural._classes.affine import Affine
class EntityEncoder:
"""
Train the embeddings of entity descriptions to fit a fixed-size entity vector (e.g. 64D).
This entity vector will be stored in the KB, and context vectors will be trained to be similar to them.
This entity vector will be stored in the KB, for further downstream use in the entity model.
"""
DROP = 0
@@ -97,7 +97,7 @@ class EntityEncoder:
else:
indices[i] = 0
word_vectors = doc.vocab.vectors.data[indices]
doc_vector = np.mean(word_vectors, axis=0) # TODO: min? max?
doc_vector = np.mean(word_vectors, axis=0)
return doc_vector
def _build_network(self, orig_width, hidden_with):
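The hunk above represents an entity description as the element-wise mean of its word vectors before the encoder compresses it to the fixed-size entity vector. A self-contained NumPy illustration of that mean-pooling step:

import numpy as np

def mean_pool(word_vectors: np.ndarray) -> np.ndarray:
    # word_vectors: (n_tokens, vector_width) array of pretrained vectors.
    # The description is represented by the element-wise mean over tokens,
    # which is what np.mean(word_vectors, axis=0) computes in the hunk above.
    return np.mean(word_vectors, axis=0)

# Example: three 4-dimensional token vectors -> one 4-dimensional doc vector.
vectors = np.array([[1.0, 0.0, 2.0, 0.0],
                    [3.0, 0.0, 0.0, 0.0],
                    [2.0, 3.0, 1.0, 0.0]])
doc_vector = mean_pool(vectors)   # array([2., 1., 1., 0.])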

View File

@@ -14,8 +14,7 @@ Process Wikipedia interlinks to generate a training dataset for the EL algorithm
Gold-standard entities are stored in one file in standoff format (by character offset).
"""
# ENTITY_FILE = "gold_entities.csv"
ENTITY_FILE = "gold_entities_1000000.csv" # use this file for faster processing
ENTITY_FILE = "gold_entities.csv"
def create_training(wikipedia_input, entity_def_input, training_output):
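The docstring above says gold entities are stored in standoff format, i.e. as character offsets into the article text rather than inline markup. Purely as a hypothetical illustration (the actual column layout of gold_entities.csv is not shown in this diff), a record and how it maps back onto a mention:

# Hypothetical gold record: a KB ID attached to a character span of the text.
text = "Douglas Adams wrote The Hitchhiker's Guide to the Galaxy."
gold_link = {"start": 0, "end": 13, "kb_id": "Q42"}

mention = text[gold_link["start"]:gold_link["end"]]
assert mention == "Douglas Adams"   # the span the KB ID refers to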

View File

@@ -42,9 +42,10 @@ MIN_PAIR_OCC = 5
# model training parameters
EPOCHS = 10
DROPOUT = 0.1
DROPOUT = 0.2
LEARN_RATE = 0.005
L2 = 1e-6
CONTEXT_WIDTH = 128
def run_pipeline():
@@ -136,7 +137,8 @@ def run_pipeline():
# STEP 6: create and train the entity linking pipe
if train_pipe:
el_pipe = nlp_2.create_pipe(name='entity_linker', config={})
print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
el_pipe = nlp_2.create_pipe(name='entity_linker', config={"context_width": CONTEXT_WIDTH})
el_pipe.set_kb(kb_2)
nlp_2.add_pipe(el_pipe, last=True)
@@ -146,9 +148,8 @@ def run_pipeline():
optimizer.learn_rate = LEARN_RATE
optimizer.L2 = L2
print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
# define the size (nr of entities) of training and dev set
train_limit = 5000
train_limit = 500000
dev_limit = 5000
train_data = training_set_creator.read_training(nlp=nlp_2,
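Taken together, the changes in this file create the entity_linker pipe with an explicit context width and train it with the hyperparameters defined at the top of the script. A condensed sketch of that wiring; nlp_2, kb_2 and train_data are names taken from the hunks above, while the begin_training() call and the single-batch loop are assumptions:

el_pipe = nlp_2.create_pipe(name="entity_linker",
                            config={"context_width": CONTEXT_WIDTH})
el_pipe.set_kb(kb_2)
nlp_2.add_pipe(el_pipe, last=True)

optimizer = nlp_2.begin_training()        # assumed source of the optimizer
optimizer.learn_rate = LEARN_RATE
optimizer.L2 = L2

for epoch in range(EPOCHS):
    losses = {}
    docs, golds = zip(*train_data)        # assumed (doc, gold) pairs
    nlp_2.update(list(docs), list(golds), drop=DROPOUT,
                 sgd=optimizer, losses=losses)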

View File

@@ -652,37 +652,36 @@ def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=False,
return model
def build_nel_encoder(in_width, hidden_width, end_width, **cfg):
def build_nel_encoder(embed_width, hidden_width, **cfg):
# TODO proper error
if "entity_width" not in cfg:
raise ValueError("entity_width not found")
if "context_width" not in cfg:
raise ValueError("context_width not found")
conv_depth = cfg.get("conv_depth", 2)
cnn_maxout_pieces = cfg.get("cnn_maxout_pieces", 3)
pretrained_vectors = cfg.get("pretrained_vectors") # self.nlp.vocab.vectors.name
tok2vec = Tok2Vec(width=hidden_width, embed_size=in_width, pretrained_vectors=pretrained_vectors,
cnn_maxout_pieces=cnn_maxout_pieces, subword_features=False, conv_depth=conv_depth, bilstm_depth=0)
context_width = cfg.get("context_width")
entity_width = cfg.get("entity_width")
with Model.define_operators({">>": chain, "**": clone}):
# convolution = Residual((ExtractWindow(nW=1) >>
# LN(Maxout(hidden_width, hidden_width * 3, pieces=cnn_maxout_pieces))))
model = Affine(1, entity_width+context_width+1, drop_factor=0.0)\
>> logistic
# encoder = SpacyVectors \
# >> with_flatten(Affine(hidden_width, in_width)) \
# >> with_flatten(LN(Maxout(hidden_width, hidden_width)) >> convolution ** conv_depth, pad=conv_depth) \
# >> flatten_add_lengths \
# >> ParametricAttention(hidden_width) \
# >> Pooling(sum_pool) \
# >> Residual(zero_init(Maxout(hidden_width, hidden_width))) \
# >> zero_init(Affine(end_width, hidden_width, drop_factor=0.0))
# context encoder
tok2vec = Tok2Vec(width=hidden_width, embed_size=embed_width, pretrained_vectors=pretrained_vectors,
cnn_maxout_pieces=cnn_maxout_pieces, subword_features=False, conv_depth=conv_depth,
bilstm_depth=0) >> flatten_add_lengths >> Pooling(mean_pool)\
>> Residual(zero_init(Maxout(hidden_width, hidden_width))) \
>> zero_init(Affine(context_width, hidden_width, drop_factor=0.0))
encoder = tok2vec >> flatten_add_lengths >> Pooling(mean_pool)\
>> Residual(zero_init(Maxout(hidden_width, hidden_width))) \
>> zero_init(Affine(end_width, hidden_width, drop_factor=0.0))
model.tok2vec = tok2vec
# TODO: ReLU or LN(Maxout)?
# sum_pool or mean_pool ?
encoder.tok2vec = tok2vec
encoder.nO = end_width
return encoder
model.tok2vec = tok2vec
model.tok2vec.nO = context_width
model.nO = 1
return model
@layerize
def flatten(seqs, drop=0.0):
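Summing up the rewritten build_nel_encoder above: the Tok2Vec-based context encoder (attached as model.tok2vec) turns a Doc into a context_width-dimensional vector, and the returned model itself is only a small linking head, Affine >> logistic, scoring a concatenated context/entity vector as a single probability. A framework-free sketch of that forward pass with made-up weights (the extra +1 input column of the Affine layer is not exercised in the hunks shown, so it is omitted here):

import numpy as np

rng = np.random.default_rng(0)
context_width, entity_width = 128, 64

W = rng.normal(size=(1, context_width + entity_width))   # linking head weights
b = np.zeros(1)

def logistic(x):
    return 1.0 / (1.0 + np.exp(-x))

def link_score(context_vec, entity_vec):
    # Concatenate the context encoding with the KB entity vector and score
    # the pair with a single affine layer followed by a logistic.
    pair = np.concatenate([context_vec, entity_vec])
    return logistic((W @ pair + b).item())

context_vec = rng.normal(size=context_width)   # stand-in for model.tok2vec(doc)
entity_vec = rng.normal(size=entity_width)     # stand-in for candidate.entity_vector
print(link_score(context_vec, entity_vec))     # probability-like score in (0, 1)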

View File

@@ -5,6 +5,7 @@ from __future__ import unicode_literals
import numpy
import srsly
import random
from collections import OrderedDict
from thinc.api import chain
from thinc.v2v import Affine, Maxout, Softmax
@@ -229,7 +230,7 @@ class Tensorizer(Pipe):
vocab (Vocab): A `Vocab` instance. The model must share the same
`Vocab` instance with the `Doc` objects it will process.
model (Model): A `Model` instance or `True` allocate one later.
model (Model): A `Model` instance or `True` to allocate one later.
**cfg: Config parameters.
EXAMPLE:
@@ -386,7 +387,7 @@ class Tagger(Pipe):
def predict(self, docs):
self.require_model()
if not any(len(doc) for doc in docs):
# Handle case where there are no tokens in any docs.
# Handle cases where there are no tokens in any docs.
n_labels = len(self.labels)
guesses = [self.model.ops.allocate((0, n_labels)) for doc in docs]
tokvecs = self.model.ops.allocate((0, self.model.tok2vec.nO))
@@ -1071,22 +1072,20 @@ class EntityLinker(Pipe):
@classmethod
def Model(cls, **cfg):
if "entity_width" not in cfg:
raise ValueError("entity_width not found")
embed_width = cfg.get("embed_width", 300)
hidden_width = cfg.get("hidden_width", 128)
entity_width = cfg.get("entity_width") # this needs to correspond with the KB entity length
model = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=entity_width, **cfg)
model = build_nel_encoder(embed_width=embed_width, hidden_width=hidden_width, **cfg)
return model
def __init__(self, **cfg):
self.model = True
self.kb = None
self.sgd_context = None
self.cfg = dict(cfg)
self.context_weight = cfg.get("context_weight", 1)
self.prior_weight = cfg.get("prior_weight", 1)
self.context_width = cfg.get("context_width")
def set_kb(self, kb):
self.kb = kb
@@ -1107,6 +1106,7 @@ class EntityLinker(Pipe):
if self.model is True:
self.model = self.Model(**self.cfg)
self.sgd_context = self.create_optimizer()
if sgd is None:
sgd = self.create_optimizer()
@@ -1132,35 +1132,55 @@ class EntityLinker(Pipe):
context_docs = []
entity_encodings = []
labels = []
for doc, gold in zip(docs, golds):
for entity in gold.links:
start, end, gold_kb = entity
mention = doc.text[start:end]
candidates = self.kb.get_candidates(mention)
random.shuffle(candidates)
nr_neg = 0
for c in candidates:
kb_id = c.entity_
# Currently only training on the positive instances
if kb_id == gold_kb:
prior_prob = c.prior_prob
entity_encoding = c.entity_vector
entity_encodings.append(entity_encoding)
context_docs.append(doc)
labels.append([1])
else: # elif nr_neg < 1:
nr_neg += 1
entity_encoding = c.entity_vector
entity_encodings.append(entity_encoding)
context_docs.append(doc)
labels.append([0])
if len(entity_encodings) > 0:
context_encodings, bp_context = self.model.begin_update(context_docs, drop=drop)
context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop)
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
loss, d_scores = self.get_loss(scores=context_encodings, golds=entity_encodings, docs=None)
bp_context(d_scores, sgd=sgd)
mention_encodings = [list(context_encodings[i]) + list(entity_encodings[i]) for i in range(len(entity_encodings))]
pred, bp_mention = self.model.begin_update(self.model.ops.asarray(mention_encodings, dtype="float32"), drop=drop)
labels = self.model.ops.asarray(labels, dtype="float32")
loss, d_scores = self.get_loss(prediction=pred, golds=labels, docs=None)
mention_gradient = bp_mention(d_scores, sgd=sgd)
context_gradients = [list(x[0:self.context_width]) for x in mention_gradient]
bp_context(self.model.ops.asarray(context_gradients, dtype="float32"), sgd=self.sgd_context)
if losses is not None:
losses[self.name] += loss
return loss
return 0
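The reworked update step above runs two chained backward passes: bp_mention yields the gradient of the loss with respect to the linking head's concatenated input, the first context_width columns of that gradient are sliced out, and only that slice is pushed back through the Tok2Vec context encoder with its own optimizer (self.sgd_context); the entity vectors, which come straight from the KB, receive no update. A small shape-level sketch of that slicing:

import numpy as np

context_width, entity_width = 128, 64
batch_size = 4

# One row per mention: [d_context | d_entity], as returned by bp_mention.
mention_gradient = np.zeros((batch_size, context_width + entity_width))

# Only the context part flows back into the context encoder.
context_gradients = mention_gradient[:, :context_width]
assert context_gradients.shape == (batch_size, context_width)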
def get_loss(self, docs, golds, scores):
def get_loss(self, docs, golds, prediction):
d_scores = (prediction - golds)
loss = (d_scores ** 2).sum()
loss = loss / len(golds)
return loss, d_scores
def get_loss_old(self, docs, golds, scores):
# this loss function assumes we're only using positive examples
loss, gradients = get_cossim_loss(yh=scores, y=golds)
loss = loss / len(golds)
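The new get_loss above is a mean squared error between the linking head's sigmoid outputs and the 0/1 candidate labels, with prediction - golds used as the backpropagated error signal; the cosine-similarity objective survives as get_loss_old for the positive-only setup. A tiny numeric example of the new loss:

import numpy as np

prediction = np.array([[0.9], [0.2], [0.6]])   # sigmoid outputs of the linking head
golds = np.array([[1.0], [0.0], [1.0]])        # 1 = gold candidate, 0 = negative

d_scores = prediction - golds                  # error signal passed to bp_mention
loss = (d_scores ** 2).sum() / len(golds)      # (0.01 + 0.04 + 0.16) / 3 ≈ 0.07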
@@ -1191,30 +1211,26 @@ class EntityLinker(Pipe):
if isinstance(docs, Doc):
docs = [docs]
context_encodings = self.model(docs)
context_encodings = self.model.tok2vec(docs)
xp = get_array_module(context_encodings)
for i, doc in enumerate(docs):
if len(doc) > 0:
context_encoding = context_encodings[i]
context_enc_t = context_encoding.T
norm_1 = xp.linalg.norm(context_enc_t)
for ent in doc.ents:
candidates = self.kb.get_candidates(ent.text)
if candidates:
prior_probs = xp.asarray([c.prior_prob for c in candidates])
random.shuffle(candidates)
prior_probs = xp.asarray([[c.prior_prob] for c in candidates])
prior_probs *= self.prior_weight
entity_encodings = xp.asarray([c.entity_vector for c in candidates])
norm_2 = xp.linalg.norm(entity_encodings, axis=1)
# cosine similarity
sims = xp.dot(entity_encodings, context_enc_t) / (norm_1 * norm_2)
sims *= self.context_weight
scores = prior_probs + sims - (prior_probs*sims)
best_index = scores.argmax()
mention_encodings = [list(context_encoding) + list(entity_encodings[i]) for i in range(len(entity_encodings))]
predictions = self.model(self.model.ops.asarray(mention_encodings, dtype="float32"))
scores = (prior_probs + predictions - (xp.dot(prior_probs.T, predictions)))
# TODO: thresholding
best_index = scores.argmax()
best_candidate = candidates[best_index]
final_entities.append(ent)
final_kb_ids.append(best_candidate.entity_)
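The removed scoring line combined the KB prior probability p with the cosine similarity s as p + s - p*s, i.e. a noisy-OR (1 - (1-p)(1-s)); the rewritten predict step feeds the linking model's predictions into a similar combination with the priors. A small numeric illustration of that combination rule:

import numpy as np

prior_probs = np.array([[0.70], [0.25], [0.05]])   # from the KB, per candidate
sims = np.array([[0.10], [0.90], [0.30]])          # context/entity scores

# p + s - p*s == 1 - (1-p)(1-s): high if either signal is confident.
scores = prior_probs + sims - prior_probs * sims
best_index = int(scores.argmax())                  # -> 1 (scores ≈ 0.73, 0.93, 0.34)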