context encoder combining sentence and article

This commit is contained in:
svlandeg 2019-05-28 18:14:49 +02:00
parent 992fa92b66
commit a761929fa5
2 changed files with 138 additions and 121 deletions

View File

@ -11,11 +11,11 @@ from thinc.neural._classes.convolution import ExtractWindow
from examples.pipeline.wiki_entity_linking import run_el, training_set_creator, kb_creator from examples.pipeline.wiki_entity_linking import run_el, training_set_creator, kb_creator
from spacy._ml import SpacyVectors, create_default_optimizer, zero_init, logistic, Tok2Vec, cosine from spacy._ml import SpacyVectors, create_default_optimizer, zero_init, cosine
from thinc.api import chain, concatenate, flatten_add_lengths, clone, with_flatten from thinc.api import chain, concatenate, flatten_add_lengths, clone, with_flatten
from thinc.v2v import Model, Maxout, Affine, ReLu from thinc.v2v import Model, Maxout, Affine
from thinc.t2v import Pooling, mean_pool, sum_pool from thinc.t2v import Pooling, mean_pool
from thinc.t2t import ParametricAttention from thinc.t2t import ParametricAttention
from thinc.misc import Residual from thinc.misc import Residual
from thinc.misc import LayerNorm as LN from thinc.misc import LayerNorm as LN
@ -30,24 +30,21 @@ from spacy.tokens import Doc
class EL_Model: class EL_Model:
PRINT_INSPECT = False PRINT_INSPECT = False
PRINT_TRAIN = True PRINT_BATCH_LOSS = False
EPS = 0.0000000005 EPS = 0.0000000005
CUTOFF = 0.5
BATCH_SIZE = 5 BATCH_SIZE = 5
# UPSAMPLE = True
DOC_CUTOFF = 300 # number of characters from the doc context DOC_CUTOFF = 300 # number of characters from the doc context
INPUT_DIM = 300 # dimension of pre-trained vectors INPUT_DIM = 300 # dimension of pre-trained vectors
HIDDEN_1_WIDTH = 32 HIDDEN_1_WIDTH = 32
# HIDDEN_2_WIDTH = 32 # 6
DESC_WIDTH = 64 DESC_WIDTH = 64
ARTICLE_WIDTH = 64 ARTICLE_WIDTH = 128
SENT_WIDTH = 64 SENT_WIDTH = 64
DROP = 0.1 DROP = 0.1
LEARN_RATE = 0.0001 LEARN_RATE = 0.001
EPOCHS = 10 EPOCHS = 10
L2 = 1e-6 L2 = 1e-6
@ -61,13 +58,10 @@ class EL_Model:
self._build_cnn(embed_width=self.INPUT_DIM, self._build_cnn(embed_width=self.INPUT_DIM,
desc_width=self.DESC_WIDTH, desc_width=self.DESC_WIDTH,
article_width=self.ARTICLE_WIDTH, article_width=self.ARTICLE_WIDTH,
sent_width=self.SENT_WIDTH, hidden_1_width=self.HIDDEN_1_WIDTH) sent_width=self.SENT_WIDTH,
hidden_1_width=self.HIDDEN_1_WIDTH)
def train_model(self, training_dir, entity_descr_output, trainlimit=None, devlimit=None, to_print=True): def train_model(self, training_dir, entity_descr_output, trainlimit=None, devlimit=None, to_print=True):
# raise errors instead of runtime warnings in case of int/float overflow
# (not sure if we need this. set L2 to 0 because it throws an error otherwsise)
# np.seterr(all='raise')
# alternative:
np.seterr(divide="raise", over="warn", under="ignore", invalid="raise") np.seterr(divide="raise", over="warn", under="ignore", invalid="raise")
train_ent, train_gold, train_desc, train_art, train_art_texts, train_sent, train_sent_texts = \ train_ent, train_gold, train_desc, train_art, train_art_texts, train_sent, train_sent_texts = \
@ -101,21 +95,6 @@ class EL_Model:
train_pos_count = len(train_pos_entities) train_pos_count = len(train_pos_entities)
train_neg_count = len(train_neg_entities) train_neg_count = len(train_neg_entities)
# if self.UPSAMPLE:
# if to_print:
# print()
# print("Upsampling, original training instances pos/neg:", train_pos_count, train_neg_count)
#
# # upsample positives to 50-50 distribution
# while train_pos_count < train_neg_count:
# train_ent.append(random.choice(train_pos_entities))
# train_pos_count += 1
#
# upsample negatives to 50-50 distribution
# while train_neg_count < train_pos_count:
# train_ent.append(random.choice(train_neg_entities))
# train_neg_count += 1
self._begin_training() self._begin_training()
if to_print: if to_print:
@ -126,24 +105,25 @@ class EL_Model:
print("Dev test on", len(dev_clusters), "entity clusters in", len(dev_art_texts), "articles") print("Dev test on", len(dev_clusters), "entity clusters in", len(dev_art_texts), "articles")
print("Dev instances pos/neg:", dev_pos_count, dev_neg_count) print("Dev instances pos/neg:", dev_pos_count, dev_neg_count)
print() print()
print(" CUTOFF", self.CUTOFF)
print(" DOC_CUTOFF", self.DOC_CUTOFF) print(" DOC_CUTOFF", self.DOC_CUTOFF)
print(" INPUT_DIM", self.INPUT_DIM) print(" INPUT_DIM", self.INPUT_DIM)
# print(" HIDDEN_1_WIDTH", self.HIDDEN_1_WIDTH) print(" HIDDEN_1_WIDTH", self.HIDDEN_1_WIDTH)
print(" DESC_WIDTH", self.DESC_WIDTH) print(" DESC_WIDTH", self.DESC_WIDTH)
print(" ARTICLE_WIDTH", self.ARTICLE_WIDTH) print(" ARTICLE_WIDTH", self.ARTICLE_WIDTH)
print(" SENT_WIDTH", self.SENT_WIDTH) print(" SENT_WIDTH", self.SENT_WIDTH)
# print(" HIDDEN_2_WIDTH", self.HIDDEN_2_WIDTH)
print(" DROP", self.DROP) print(" DROP", self.DROP)
print(" LEARNING RATE", self.LEARN_RATE) print(" LEARNING RATE", self.LEARN_RATE)
print(" UPSAMPLE", self.UPSAMPLE) print(" BATCH SIZE", self.BATCH_SIZE)
print() print()
self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts, dev_random = self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts,
print_string="dev_random", calc_random=True) calc_random=True)
print("acc", "dev_random", round(dev_random, 2))
self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts, dev_pre = self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts,
print_string="dev_pre", avg=True) avg=True)
print("acc", "dev_pre", round(dev_pre, 2))
print()
processed = 0 processed = 0
for i in range(self.EPOCHS): for i in range(self.EPOCHS):
@ -163,45 +143,58 @@ class EL_Model:
start = start + self.BATCH_SIZE start = start + self.BATCH_SIZE
stop = min(stop + self.BATCH_SIZE, len(train_clusters)) stop = min(stop + self.BATCH_SIZE, len(train_clusters))
if self.PRINT_TRAIN: train_acc = self._test_dev(train_ent, train_gold, train_desc, train_art, train_art_texts, train_sent, train_sent_texts, avg=True)
print() dev_acc = self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts, avg=True)
self._test_dev(train_ent, train_gold, train_desc, train_art, train_art_texts, train_sent, train_sent_texts,
print_string="train_inter_epoch " + str(i), avg=True)
self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts, print(i, "acc train/dev", round(train_acc, 2), round(dev_acc, 2))
print_string="dev_inter_epoch " + str(i), avg=True)
if to_print: if to_print:
print() print()
print("Trained on", processed, "entity clusters across", self.EPOCHS, "epochs") print("Trained on", processed, "entity clusters across", self.EPOCHS, "epochs")
def _test_dev(self, entity_clusters, golds, descs, arts, art_texts, sents, sent_texts, def _test_dev(self, entity_clusters, golds, descs, arts, art_texts, sents, sent_texts, avg=True, calc_random=False):
print_string, avg=True, calc_random=False):
correct = 0 correct = 0
incorrect = 0 incorrect = 0
for cluster, entities in entity_clusters.items(): if calc_random:
correct_entities = [e for e in entities if golds[e]] for cluster, entities in entity_clusters.items():
incorrect_entities = [e for e in entities if not golds[e]] correct_entities = [e for e in entities if golds[e]]
assert len(correct_entities) == 1 assert len(correct_entities) == 1
entities = list(entities) entities = list(entities)
shuffle(entities) shuffle(entities)
if calc_random: if calc_random:
predicted_entity = random.choice(entities) predicted_entity = random.choice(entities)
if predicted_entity in correct_entities: if predicted_entity in correct_entities:
correct += 1 correct += 1
else: else:
incorrect += 1 incorrect += 1
else:
all_clusters = list()
arts_list = list()
sents_list = list()
for cluster in entity_clusters.keys():
all_clusters.append(cluster)
arts_list.append(art_texts[arts[cluster]])
sents_list.append(sent_texts[sents[cluster]])
art_docs = list(self.nlp.pipe(arts_list))
sent_docs = list(self.nlp.pipe(sents_list))
for i, cluster in enumerate(all_clusters):
entities = entity_clusters[cluster]
correct_entities = [e for e in entities if golds[e]]
assert len(correct_entities) == 1
entities = list(entities)
shuffle(entities)
else:
desc_docs = self.nlp.pipe([descs[e] for e in entities]) desc_docs = self.nlp.pipe([descs[e] for e in entities])
# article_texts = [art_texts[arts[e]] for e in entities] sent_doc = sent_docs[i]
article_doc = art_docs[i]
sent_doc = self.nlp(sent_texts[sents[cluster]])
article_doc = self.nlp(art_texts[arts[cluster]])
predicted_index = self._predict(article_doc=article_doc, sent_doc=sent_doc, predicted_index = self._predict(article_doc=article_doc, sent_doc=sent_doc,
desc_docs=desc_docs, avg=avg) desc_docs=desc_docs, avg=avg)
@ -211,52 +204,56 @@ class EL_Model:
incorrect += 1 incorrect += 1
if correct == incorrect == 0: if correct == incorrect == 0:
print("acc", print_string, "NA")
return 0 return 0
acc = correct / (correct + incorrect) acc = correct / (correct + incorrect)
print("acc", print_string, round(acc, 2))
return acc return acc
def _predict(self, article_doc, sent_doc, desc_docs, avg=True, apply_threshold=True): def _predict(self, article_doc, sent_doc, desc_docs, avg=True, apply_threshold=True):
# print()
# print("predicting article")
if avg: if avg:
with self.article_encoder.use_params(self.sgd_article.averages) \ with self.article_encoder.use_params(self.sgd_article.averages) \
and self.desc_encoder.use_params(self.sgd_desc.averages)\ and self.desc_encoder.use_params(self.sgd_desc.averages)\
and self.sent_encoder.use_params(self.sgd_sent.averages): and self.sent_encoder.use_params(self.sgd_sent.averages)\
# doc_encoding = self.article_encoder(article_doc) and self.cont_encoder.use_params(self.sgd_cont.averages):
desc_encodings = self.desc_encoder(desc_docs) desc_encodings = self.desc_encoder(desc_docs)
doc_encoding = self.article_encoder([article_doc])
sent_encoding = self.sent_encoder([sent_doc]) sent_encoding = self.sent_encoder([sent_doc])
else: else:
# doc_encodings = self.article_encoder(article_docs)
desc_encodings = self.desc_encoder(desc_docs) desc_encodings = self.desc_encoder(desc_docs)
doc_encoding = self.article_encoder([article_doc])
sent_encoding = self.sent_encoder([sent_doc]) sent_encoding = self.sent_encoder([sent_doc])
sent_enc = np.transpose(sent_encoding) # print("desc_encodings", desc_encodings)
# print("doc_encoding", doc_encoding)
# print("sent_encoding", sent_encoding)
concat_encoding = [list(doc_encoding[0]) + list(sent_encoding[0])]
# print("concat_encoding", concat_encoding)
cont_encodings = self.cont_encoder(np.asarray([concat_encoding[0]]))
# print("cont_encodings", cont_encodings)
context_enc = np.transpose(cont_encodings)
# print("context_enc", context_enc)
highest_sim = -5 highest_sim = -5
best_i = -1 best_i = -1
for i, desc_enc in enumerate(desc_encodings): for i, desc_enc in enumerate(desc_encodings):
sim = cosine(desc_enc, sent_enc) sim = cosine(desc_enc, context_enc)
if sim >= highest_sim: if sim >= highest_sim:
best_i = i best_i = i
highest_sim = sim highest_sim = sim
return best_i return best_i
def _predict_random(self, entities, apply_threshold=True):
if not apply_threshold:
return [float(random.uniform(0, 1)) for _ in entities]
else:
return [float(1.0) if random.uniform(0, 1) > self.CUTOFF else float(0.0) for _ in entities]
def _build_cnn(self, embed_width, desc_width, article_width, sent_width, hidden_1_width): def _build_cnn(self, embed_width, desc_width, article_width, sent_width, hidden_1_width):
with Model.define_operators({">>": chain, "|": concatenate, "**": clone}): self.desc_encoder = self._encoder(in_width=embed_width, hidden_with=hidden_1_width, end_width=desc_width)
self.desc_encoder = self._encoder(in_width=embed_width, hidden_with=hidden_1_width, self.cont_encoder = self._context_encoder(embed_width=embed_width, article_width=article_width,
end_width=desc_width) sent_width=sent_width, hidden_width=hidden_1_width,
self.article_encoder = self._encoder(in_width=embed_width, hidden_with=hidden_1_width, end_width=desc_width)
end_width=article_width)
self.sent_encoder = self._encoder(in_width=embed_width, hidden_with=hidden_1_width,
end_width=sent_width)
# def _encoder(self, width): # def _encoder(self, width):
# tok2vec = Tok2Vec(width=width, embed_size=2000, pretrained_vectors=self.nlp.vocab.vectors.name, cnn_maxout_pieces=3, # tok2vec = Tok2Vec(width=width, embed_size=2000, pretrained_vectors=self.nlp.vocab.vectors.name, cnn_maxout_pieces=3,
@ -264,12 +261,19 @@ class EL_Model:
# #
# return tok2vec >> flatten_add_lengths >> Pooling(mean_pool) # return tok2vec >> flatten_add_lengths >> Pooling(mean_pool)
def _context_encoder(self, embed_width, article_width, sent_width, hidden_width, end_width):
    """Create the article and sentence encoders and return the layer that merges them.

    Side effect: the two encoders are stored on ``self.article_encoder`` and
    ``self.sent_encoder``; only the merging layer is returned.

    The returned ``Affine`` layer is constructed with ``end_width`` and
    ``article_width + sent_width`` as its two size arguments and with
    ``drop_factor=0.0``.
    """
    # "hidden_with" (sic) is the keyword name declared by self._encoder.
    shared = {"in_width": embed_width, "hidden_with": hidden_width}
    self.article_encoder = self._encoder(end_width=article_width, **shared)
    self.sent_encoder = self._encoder(end_width=sent_width, **shared)

    # Width of an article encoding and a sentence encoding placed side by side.
    merged_width = article_width + sent_width
    return Affine(end_width, merged_width, drop_factor=0.0)
@staticmethod @staticmethod
def _encoder(in_width, hidden_with, end_width): def _encoder(in_width, hidden_with, end_width):
conv_depth = 2 conv_depth = 2
cnn_maxout_pieces = 3 cnn_maxout_pieces = 3
with Model.define_operators({">>": chain}): with Model.define_operators({">>": chain, "**": clone}):
convolution = Residual((ExtractWindow(nW=1) >> convolution = Residual((ExtractWindow(nW=1) >>
LN(Maxout(hidden_with, hidden_with * 3, pieces=cnn_maxout_pieces)))) LN(Maxout(hidden_with, hidden_with * 3, pieces=cnn_maxout_pieces))))
@ -295,62 +299,75 @@ class EL_Model:
self.sgd_sent.learn_rate = self.LEARN_RATE self.sgd_sent.learn_rate = self.LEARN_RATE
self.sgd_sent.L2 = self.L2 self.sgd_sent.L2 = self.L2
self.sgd_cont = create_default_optimizer(self.cont_encoder.ops)
self.sgd_cont.learn_rate = self.LEARN_RATE
self.sgd_cont.L2 = self.L2
self.sgd_desc = create_default_optimizer(self.desc_encoder.ops) self.sgd_desc = create_default_optimizer(self.desc_encoder.ops)
self.sgd_desc.learn_rate = self.LEARN_RATE self.sgd_desc.learn_rate = self.LEARN_RATE
self.sgd_desc.L2 = self.L2 self.sgd_desc.L2 = self.L2
# self.sgd = create_default_optimizer(self.model.ops)
# self.sgd.learn_rate = self.LEARN_RATE
# self.sgd.L2 = self.L2
@staticmethod @staticmethod
def get_loss(predictions, golds): def get_loss(predictions, golds):
loss, gradients = get_cossim_loss(predictions, golds) loss, gradients = get_cossim_loss(predictions, golds)
return loss, gradients return loss, gradients
def update(self, entity_clusters, golds, descs, art_texts, arts, sent_texts, sents): def update(self, entity_clusters, golds, descs, art_texts, arts, sent_texts, sents):
all_clusters = list(entity_clusters.keys())
arts_list = list()
sents_list = list()
descs_list = list()
for cluster, entities in entity_clusters.items(): for cluster, entities in entity_clusters.items():
correct_entities = [e for e in entities if golds[e]] art = art_texts[arts[cluster]]
incorrect_entities = [e for e in entities if not golds[e]] sent = sent_texts[sents[cluster]]
assert len(correct_entities) == 1
entities = list(entities)
shuffle(entities)
# article_text = art_texts[arts[cluster]]
cluster_sent = sent_texts[sents[cluster]]
# art_docs = self.nlp.pipe(article_text)
sent_doc = self.nlp(cluster_sent)
for e in entities: for e in entities:
# TODO: more appropriate loss for the whole cluster (currently only pos entities)
if golds[e]: if golds[e]:
# TODO: more appropriate loss for the whole cluster (currently only pos entities) arts_list.append(art)
# TODO: speed up sents_list.append(sent)
desc_doc = self.nlp(descs[e]) descs_list.append(descs[e])
# doc_encodings, bp_doc = self.article_encoder.begin_update(art_docs, drop=self.DROP) desc_docs = self.nlp.pipe(descs_list)
sent_encodings, bp_sent = self.sent_encoder.begin_update([sent_doc], drop=self.DROP) desc_encodings, bp_desc = self.desc_encoder.begin_update(desc_docs, drop=self.DROP)
desc_encodings, bp_desc = self.desc_encoder.begin_update([desc_doc], drop=self.DROP)
sent_encoding = sent_encodings[0] art_docs = self.nlp.pipe(arts_list)
desc_encoding = desc_encodings[0] sent_docs = self.nlp.pipe(sents_list)
sent_enc = self.sent_encoder.ops.asarray([sent_encoding]) doc_encodings, bp_doc = self.article_encoder.begin_update(art_docs, drop=self.DROP)
desc_enc = self.sent_encoder.ops.asarray([desc_encoding]) sent_encodings, bp_sent = self.sent_encoder.begin_update(sent_docs, drop=self.DROP)
# print("sent_encoding", type(sent_encoding), sent_encoding) concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in
# print("desc_encoding", type(desc_encoding), desc_encoding) range(len(all_clusters))]
# print("getting los for entity", e) cont_encodings, bp_cont = self.cont_encoder.begin_update(np.asarray(concat_encodings), drop=self.DROP)
loss, gradient = self.get_loss(sent_enc, desc_enc) # print("sent_encodings", type(sent_encodings), sent_encodings)
# print("desc_encodings", type(desc_encodings), desc_encodings)
# print("doc_encodings", type(doc_encodings), doc_encodings)
# print("getting los for", len(arts_list), "entities")
# print("gradient", gradient) loss, gradient = self.get_loss(cont_encodings, desc_encodings)
# print("loss", loss)
bp_sent(gradient, sgd=self.sgd_sent) # print("gradient", gradient)
# bp_desc(desc_gradients, sgd=self.sgd_desc) TODO if self.PRINT_BATCH_LOSS:
# print() print("batch loss", loss)
context_gradient = bp_cont(gradient, sgd=self.sgd_cont)
# gradient : concat (doc+sent) vs. desc
sent_start = self.ARTICLE_WIDTH
sent_gradients = list()
doc_gradients = list()
for x in context_gradient:
doc_gradients.append(list(x[0:sent_start]))
sent_gradients.append(list(x[sent_start:]))
# print("doc_gradients", doc_gradients)
# print("sent_gradients", sent_gradients)
bp_doc(doc_gradients, sgd=self.sgd_article)
bp_sent(sent_gradients, sgd=self.sgd_sent)
def _get_training_data(self, training_dir, entity_descr_output, dev, limit, to_print): def _get_training_data(self, training_dir, entity_descr_output, dev, limit, to_print):
id_to_descr = kb_creator._get_id_to_description(entity_descr_output) id_to_descr = kb_creator._get_id_to_description(entity_descr_output)

View File

@ -111,7 +111,7 @@ if __name__ == "__main__":
print("STEP 6: training", datetime.datetime.now()) print("STEP 6: training", datetime.datetime.now())
my_nlp = spacy.load('en_core_web_md') my_nlp = spacy.load('en_core_web_md')
trainer = EL_Model(kb=my_kb, nlp=my_nlp) trainer = EL_Model(kb=my_kb, nlp=my_nlp)
trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=1000, devlimit=100) trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=5000, devlimit=100)
print() print()
# STEP 7: apply the EL algorithm on the dev dataset # STEP 7: apply the EL algorithm on the dev dataset