Mirror of https://github.com/explosion/spaCy.git (synced 2024-11-10 19:57:17 +03:00)
context encoder combining sentence and article

parent 992fa92b66
commit a761929fa5
@@ -11,11 +11,11 @@ from thinc.neural._classes.convolution import ExtractWindow
 from examples.pipeline.wiki_entity_linking import run_el, training_set_creator, kb_creator
 
-from spacy._ml import SpacyVectors, create_default_optimizer, zero_init, logistic, Tok2Vec, cosine
+from spacy._ml import SpacyVectors, create_default_optimizer, zero_init, cosine
 
 from thinc.api import chain, concatenate, flatten_add_lengths, clone, with_flatten
-from thinc.v2v import Model, Maxout, Affine, ReLu
-from thinc.t2v import Pooling, mean_pool, sum_pool
+from thinc.v2v import Model, Maxout, Affine
+from thinc.t2v import Pooling, mean_pool
 from thinc.t2t import ParametricAttention
 from thinc.misc import Residual
 from thinc.misc import LayerNorm as LN
@@ -30,24 +30,21 @@ from spacy.tokens import Doc
 class EL_Model:
 
     PRINT_INSPECT = False
     PRINT_TRAIN = True
     PRINT_BATCH_LOSS = False
     EPS = 0.0000000005
     CUTOFF = 0.5
 
     BATCH_SIZE = 5
-    # UPSAMPLE = True
 
     DOC_CUTOFF = 300  # number of characters from the doc context
     INPUT_DIM = 300  # dimension of pre-trained vectors
 
     HIDDEN_1_WIDTH = 32
     # HIDDEN_2_WIDTH = 32  # 6
     DESC_WIDTH = 64
-    ARTICLE_WIDTH = 64
+    ARTICLE_WIDTH = 128
     SENT_WIDTH = 64
 
     DROP = 0.1
-    LEARN_RATE = 0.0001
+    LEARN_RATE = 0.001
     EPOCHS = 10
     L2 = 1e-6
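A quick orientation note on the widths above (not part of the diff): the context encoder introduced further down takes the concatenated article and sentence encodings as input, so its input width is ARTICLE_WIDTH + SENT_WIDTH, and it projects down to DESC_WIDTH so the result can be compared against the description encodings. A trivial sanity check of those dimensions, using the constants as set in this commit:

# Illustrative only; values copied from the class constants above.
DESC_WIDTH = 64
ARTICLE_WIDTH = 128
SENT_WIDTH = 64

context_input_width = ARTICLE_WIDTH + SENT_WIDTH   # 192: concat of article + sentence encodings
context_output_width = DESC_WIDTH                  # 64: must match the description encoder output
print(context_input_width, "->", context_output_width)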
@@ -61,13 +58,10 @@ class EL_Model:
         self._build_cnn(embed_width=self.INPUT_DIM,
                         desc_width=self.DESC_WIDTH,
                         article_width=self.ARTICLE_WIDTH,
-                        sent_width=self.SENT_WIDTH, hidden_1_width=self.HIDDEN_1_WIDTH)
+                        sent_width=self.SENT_WIDTH,
+                        hidden_1_width=self.HIDDEN_1_WIDTH)
 
     def train_model(self, training_dir, entity_descr_output, trainlimit=None, devlimit=None, to_print=True):
         # raise errors instead of runtime warnings in case of int/float overflow
-        # (not sure if we need this. set L2 to 0 because it throws an error otherwsise)
-        # np.seterr(all='raise')
-        # alternative:
         np.seterr(divide="raise", over="warn", under="ignore", invalid="raise")
 
         train_ent, train_gold, train_desc, train_art, train_art_texts, train_sent, train_sent_texts = \
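Side note on the np.seterr call kept above: with divide and invalid set to "raise", NumPy turns those floating-point conditions into FloatingPointError exceptions, while overflow stays a warning and underflow is silently ignored. A minimal, standalone illustration (not from the repo):

import numpy as np

np.seterr(divide="raise", over="warn", under="ignore", invalid="raise")

try:
    np.array([1.0]) / np.array([0.0])  # division by zero now raises instead of warning
except FloatingPointError as err:
    print("caught:", err)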
@@ -101,21 +95,6 @@ class EL_Model:
         train_pos_count = len(train_pos_entities)
         train_neg_count = len(train_neg_entities)
 
-        # if self.UPSAMPLE:
-        #     if to_print:
-        #         print()
-        #         print("Upsampling, original training instances pos/neg:", train_pos_count, train_neg_count)
-        #
-        #     # upsample positives to 50-50 distribution
-        #     while train_pos_count < train_neg_count:
-        #         train_ent.append(random.choice(train_pos_entities))
-        #         train_pos_count += 1
-        #
-        #     # upsample negatives to 50-50 distribution
-        #     while train_neg_count < train_pos_count:
-        #         train_ent.append(random.choice(train_neg_entities))
-        #         train_neg_count += 1
-
         self._begin_training()
 
         if to_print:
@@ -126,24 +105,25 @@ class EL_Model:
             print("Dev test on", len(dev_clusters), "entity clusters in", len(dev_art_texts), "articles")
             print("Dev instances pos/neg:", dev_pos_count, dev_neg_count)
             print()
             print(" CUTOFF", self.CUTOFF)
             print(" DOC_CUTOFF", self.DOC_CUTOFF)
             print(" INPUT_DIM", self.INPUT_DIM)
-            # print(" HIDDEN_1_WIDTH", self.HIDDEN_1_WIDTH)
+            print(" HIDDEN_1_WIDTH", self.HIDDEN_1_WIDTH)
             print(" DESC_WIDTH", self.DESC_WIDTH)
             print(" ARTICLE_WIDTH", self.ARTICLE_WIDTH)
             print(" SENT_WIDTH", self.SENT_WIDTH)
             # print(" HIDDEN_2_WIDTH", self.HIDDEN_2_WIDTH)
             print(" DROP", self.DROP)
             print(" LEARNING RATE", self.LEARN_RATE)
-            print(" UPSAMPLE", self.UPSAMPLE)
             print(" BATCH SIZE", self.BATCH_SIZE)
             print()
 
-        self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts,
-                       print_string="dev_random", calc_random=True)
+        dev_random = self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts,
+                                    calc_random=True)
+        print("acc", "dev_random", round(dev_random, 2))
 
-        self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts,
-                       print_string="dev_pre", avg=True)
+        dev_pre = self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts,
+                                 avg=True)
+        print("acc", "dev_pre", round(dev_pre, 2))
         print()
 
         processed = 0
         for i in range(self.EPOCHS):
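The dev_random call above uses _test_dev with calc_random=True as a baseline: it simply picks one candidate per entity cluster at random, so its expected accuracy is roughly one over the average number of candidates. A self-contained sketch of that baseline with hypothetical toy data (the cluster and entity ids here are made up):

import random

# toy clusters: each mention maps to its candidate entity ids;
# golds marks which candidate is the correct one
entity_clusters = {"m1": ["e1", "e2", "e3"], "m2": ["e4", "e5"]}
golds = {"e1": True, "e2": False, "e3": False, "e4": False, "e5": True}

correct = incorrect = 0
for cluster, entities in entity_clusters.items():
    predicted = random.choice(list(entities))
    if golds[predicted]:
        correct += 1
    else:
        incorrect += 1

print("random baseline acc", correct / (correct + incorrect))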
@@ -163,45 +143,58 @@ class EL_Model:
                 start = start + self.BATCH_SIZE
                 stop = min(stop + self.BATCH_SIZE, len(train_clusters))
 
             if self.PRINT_TRAIN:
-                print()
-                self._test_dev(train_ent, train_gold, train_desc, train_art, train_art_texts, train_sent, train_sent_texts,
-                               print_string="train_inter_epoch " + str(i), avg=True)
-                self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts,
-                               print_string="dev_inter_epoch " + str(i), avg=True)
+                train_acc = self._test_dev(train_ent, train_gold, train_desc, train_art, train_art_texts, train_sent, train_sent_texts, avg=True)
+                dev_acc = self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts, avg=True)
+                print(i, "acc train/dev", round(train_acc, 2), round(dev_acc, 2))
 
         if to_print:
             print()
             print("Trained on", processed, "entity clusters across", self.EPOCHS, "epochs")
 
-    def _test_dev(self, entity_clusters, golds, descs, arts, art_texts, sents, sent_texts,
-                  print_string, avg=True, calc_random=False):
+    def _test_dev(self, entity_clusters, golds, descs, arts, art_texts, sents, sent_texts, avg=True, calc_random=False):
         correct = 0
         incorrect = 0
 
-        for cluster, entities in entity_clusters.items():
-            correct_entities = [e for e in entities if golds[e]]
-            incorrect_entities = [e for e in entities if not golds[e]]
-            assert len(correct_entities) == 1
+        if calc_random:
+            for cluster, entities in entity_clusters.items():
+                correct_entities = [e for e in entities if golds[e]]
+                assert len(correct_entities) == 1
 
-            entities = list(entities)
-            shuffle(entities)
+                entities = list(entities)
+                shuffle(entities)
 
-            if calc_random:
                 predicted_entity = random.choice(entities)
                 if predicted_entity in correct_entities:
                     correct += 1
                 else:
                     incorrect += 1
 
-            else:
-                desc_docs = self.nlp.pipe([descs[e] for e in entities])
-                # article_texts = [art_texts[arts[e]] for e in entities]
+        else:
+            all_clusters = list()
+            arts_list = list()
+            sents_list = list()
 
-                sent_doc = self.nlp(sent_texts[sents[cluster]])
-                article_doc = self.nlp(art_texts[arts[cluster]])
+            for cluster in entity_clusters.keys():
+                all_clusters.append(cluster)
+                arts_list.append(art_texts[arts[cluster]])
+                sents_list.append(sent_texts[sents[cluster]])
+
+            art_docs = list(self.nlp.pipe(arts_list))
+            sent_docs = list(self.nlp.pipe(sents_list))
+
+            for i, cluster in enumerate(all_clusters):
+                entities = entity_clusters[cluster]
+                correct_entities = [e for e in entities if golds[e]]
+                assert len(correct_entities) == 1
+
+                entities = list(entities)
+                shuffle(entities)
+
+                desc_docs = self.nlp.pipe([descs[e] for e in entities])
+
+                sent_doc = sent_docs[i]
+                article_doc = art_docs[i]
 
                 predicted_index = self._predict(article_doc=article_doc, sent_doc=sent_doc,
                                                 desc_docs=desc_docs, avg=avg)
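The restructured _test_dev above now collects all article and sentence texts first and runs them through nlp.pipe once, instead of calling self.nlp(...) separately per cluster, and then indexes into the resulting doc lists. A minimal sketch of that batching pattern, assuming a spaCy pipeline such as en_core_web_md is installed (the texts are placeholders):

import spacy

nlp = spacy.load("en_core_web_md")

arts_list = ["First article text ...", "Second article text ..."]
sents_list = ["First mention sentence.", "Second mention sentence."]

# one pipe() call per text type instead of one nlp() call per cluster
art_docs = list(nlp.pipe(arts_list))
sent_docs = list(nlp.pipe(sents_list))

for i in range(len(arts_list)):
    article_doc, sent_doc = art_docs[i], sent_docs[i]
    print(len(article_doc), len(sent_doc))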
@@ -211,52 +204,56 @@ class EL_Model:
                     incorrect += 1
 
         if correct == incorrect == 0:
-            print("acc", print_string, "NA")
             return 0
 
         acc = correct / (correct + incorrect)
-        print("acc", print_string, round(acc, 2))
         return acc
 
     def _predict(self, article_doc, sent_doc, desc_docs, avg=True, apply_threshold=True):
-        # print()
-        # print("predicting article")
-
         if avg:
             with self.article_encoder.use_params(self.sgd_article.averages) \
                  and self.desc_encoder.use_params(self.sgd_desc.averages)\
-                 and self.sent_encoder.use_params(self.sgd_sent.averages):
-                # doc_encoding = self.article_encoder(article_doc)
+                 and self.sent_encoder.use_params(self.sgd_sent.averages)\
+                 and self.cont_encoder.use_params(self.sgd_cont.averages):
                 desc_encodings = self.desc_encoder(desc_docs)
                 doc_encoding = self.article_encoder([article_doc])
                 sent_encoding = self.sent_encoder([sent_doc])
 
         else:
-            # doc_encodings = self.article_encoder(article_docs)
             desc_encodings = self.desc_encoder(desc_docs)
             doc_encoding = self.article_encoder([article_doc])
             sent_encoding = self.sent_encoder([sent_doc])
 
-        sent_enc = np.transpose(sent_encoding)
+        # print("desc_encodings", desc_encodings)
+        # print("doc_encoding", doc_encoding)
+        # print("sent_encoding", sent_encoding)
+        concat_encoding = [list(doc_encoding[0]) + list(sent_encoding[0])]
+        # print("concat_encoding", concat_encoding)
+
+        cont_encodings = self.cont_encoder(np.asarray([concat_encoding[0]]))
+        # print("cont_encodings", cont_encodings)
+        context_enc = np.transpose(cont_encodings)
+        # print("context_enc", context_enc)
 
         highest_sim = -5
         best_i = -1
         for i, desc_enc in enumerate(desc_encodings):
-            sim = cosine(desc_enc, sent_enc)
+            sim = cosine(desc_enc, context_enc)
             if sim >= highest_sim:
                 best_i = i
                 highest_sim = sim
 
         return best_i
 
     def _predict_random(self, entities, apply_threshold=True):
         if not apply_threshold:
             return [float(random.uniform(0, 1)) for _ in entities]
         else:
             return [float(1.0) if random.uniform(0, 1) > self.CUTOFF else float(0.0) for _ in entities]
 
     def _build_cnn(self, embed_width, desc_width, article_width, sent_width, hidden_1_width):
         with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
-            self.desc_encoder = self._encoder(in_width=embed_width, hidden_with=hidden_1_width,
-                                              end_width=desc_width)
-            self.article_encoder = self._encoder(in_width=embed_width, hidden_with=hidden_1_width,
-                                                 end_width=article_width)
-            self.sent_encoder = self._encoder(in_width=embed_width, hidden_with=hidden_1_width,
-                                              end_width=sent_width)
+            self.desc_encoder = self._encoder(in_width=embed_width, hidden_with=hidden_1_width, end_width=desc_width)
+            self.cont_encoder = self._context_encoder(embed_width=embed_width, article_width=article_width,
+                                                      sent_width=sent_width, hidden_width=hidden_1_width,
+                                                      end_width=desc_width)
 
     # def _encoder(self, width):
     #    tok2vec = Tok2Vec(width=width, embed_size=2000, pretrained_vectors=self.nlp.vocab.vectors.name, cnn_maxout_pieces=3,
|
|||
#
|
||||
# return tok2vec >> flatten_add_lengths >> Pooling(mean_pool)
|
||||
|
||||
def _context_encoder(self, embed_width, article_width, sent_width, hidden_width, end_width):
|
||||
self.article_encoder = self._encoder(in_width=embed_width, hidden_with=hidden_width, end_width=article_width)
|
||||
self.sent_encoder = self._encoder(in_width=embed_width, hidden_with=hidden_width, end_width=sent_width)
|
||||
|
||||
model = Affine(end_width, article_width+sent_width, drop_factor=0.0)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _encoder(in_width, hidden_with, end_width):
|
||||
conv_depth = 2
|
||||
cnn_maxout_pieces = 3
|
||||
|
||||
with Model.define_operators({">>": chain}):
|
||||
with Model.define_operators({">>": chain, "**": clone}):
|
||||
convolution = Residual((ExtractWindow(nW=1) >>
|
||||
LN(Maxout(hidden_with, hidden_with * 3, pieces=cnn_maxout_pieces))))
|
||||
|
||||
|
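_context_encoder above builds the article and sentence encoders and returns a single Affine layer that maps the concatenated (article_width + sent_width) vector down to end_width, so the combined context ends up in the same space as the description encodings. A plain-NumPy sketch of what that affine combination computes; the weights here are random stand-ins, not the trained layer:

import numpy as np

ARTICLE_WIDTH, SENT_WIDTH, DESC_WIDTH = 128, 64, 64

rng = np.random.default_rng(0)
W = rng.normal(scale=0.01, size=(DESC_WIDTH, ARTICLE_WIDTH + SENT_WIDTH))  # Affine weights
b = np.zeros(DESC_WIDTH)                                                   # Affine bias

doc_encoding = rng.normal(size=ARTICLE_WIDTH)  # article encoder output
sent_encoding = rng.normal(size=SENT_WIDTH)    # sentence encoder output

concat_encoding = np.concatenate([doc_encoding, sent_encoding])  # width 192
context_enc = W @ concat_encoding + b                            # width 64, comparable to desc encodings

print(concat_encoding.shape, "->", context_enc.shape)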
@@ -295,62 +299,75 @@ class EL_Model:
         self.sgd_sent.learn_rate = self.LEARN_RATE
         self.sgd_sent.L2 = self.L2
 
+        self.sgd_cont = create_default_optimizer(self.cont_encoder.ops)
+        self.sgd_cont.learn_rate = self.LEARN_RATE
+        self.sgd_cont.L2 = self.L2
+
         self.sgd_desc = create_default_optimizer(self.desc_encoder.ops)
         self.sgd_desc.learn_rate = self.LEARN_RATE
         self.sgd_desc.L2 = self.L2
 
         # self.sgd = create_default_optimizer(self.model.ops)
         # self.sgd.learn_rate = self.LEARN_RATE
         # self.sgd.L2 = self.L2
 
     @staticmethod
     def get_loss(predictions, golds):
         loss, gradients = get_cossim_loss(predictions, golds)
         return loss, gradients
 
     def update(self, entity_clusters, golds, descs, art_texts, arts, sent_texts, sents):
+        all_clusters = list(entity_clusters.keys())
+
+        arts_list = list()
+        sents_list = list()
+        descs_list = list()
+
         for cluster, entities in entity_clusters.items():
             correct_entities = [e for e in entities if golds[e]]
             incorrect_entities = [e for e in entities if not golds[e]]
 
             assert len(correct_entities) == 1
             entities = list(entities)
             shuffle(entities)
 
-            # article_text = art_texts[arts[cluster]]
-            cluster_sent = sent_texts[sents[cluster]]
-
-            # art_docs = self.nlp.pipe(article_text)
-            sent_doc = self.nlp(cluster_sent)
-
+            art = art_texts[arts[cluster]]
+            sent = sent_texts[sents[cluster]]
+
+            # TODO: more appropriate loss for the whole cluster (currently only pos entities)
             for e in entities:
                 if golds[e]:
-                    # TODO: more appropriate loss for the whole cluster (currently only pos entities)
-                    # TODO: speed up
-                    desc_doc = self.nlp(descs[e])
-
-                    # doc_encodings, bp_doc = self.article_encoder.begin_update(art_docs, drop=self.DROP)
-                    sent_encodings, bp_sent = self.sent_encoder.begin_update([sent_doc], drop=self.DROP)
-                    desc_encodings, bp_desc = self.desc_encoder.begin_update([desc_doc], drop=self.DROP)
-
-                    sent_encoding = sent_encodings[0]
-                    desc_encoding = desc_encodings[0]
-
-                    sent_enc = self.sent_encoder.ops.asarray([sent_encoding])
-                    desc_enc = self.sent_encoder.ops.asarray([desc_encoding])
-
-                    # print("sent_encoding", type(sent_encoding), sent_encoding)
-                    # print("desc_encoding", type(desc_encoding), desc_encoding)
-                    # print("getting los for entity", e)
-
-                    loss, gradient = self.get_loss(sent_enc, desc_enc)
-
-                    # print("gradient", gradient)
-                    # print("loss", loss)
-
-                    bp_sent(gradient, sgd=self.sgd_sent)
-                    # bp_desc(desc_gradients, sgd=self.sgd_desc) TODO
-                    # print()
-                    # print("gradient", gradient)
-                    if self.PRINT_BATCH_LOSS:
-                        print("batch loss", loss)
+                    arts_list.append(art)
+                    sents_list.append(sent)
+                    descs_list.append(descs[e])
+
+        desc_docs = self.nlp.pipe(descs_list)
+        desc_encodings, bp_desc = self.desc_encoder.begin_update(desc_docs, drop=self.DROP)
+
+        art_docs = self.nlp.pipe(arts_list)
+        sent_docs = self.nlp.pipe(sents_list)
+
+        doc_encodings, bp_doc = self.article_encoder.begin_update(art_docs, drop=self.DROP)
+        sent_encodings, bp_sent = self.sent_encoder.begin_update(sent_docs, drop=self.DROP)
+
+        concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in
+                            range(len(all_clusters))]
+        cont_encodings, bp_cont = self.cont_encoder.begin_update(np.asarray(concat_encodings), drop=self.DROP)
+
+        # print("sent_encodings", type(sent_encodings), sent_encodings)
+        # print("desc_encodings", type(desc_encodings), desc_encodings)
+        # print("doc_encodings", type(doc_encodings), doc_encodings)
+        # print("getting los for", len(arts_list), "entities")
+
+        loss, gradient = self.get_loss(cont_encodings, desc_encodings)
+
+        if self.PRINT_BATCH_LOSS:
+            print("batch loss", loss)
+
+        context_gradient = bp_cont(gradient, sgd=self.sgd_cont)
+
+        # gradient : concat (doc+sent) vs. desc
+        sent_start = self.ARTICLE_WIDTH
+        sent_gradients = list()
+        doc_gradients = list()
+        for x in context_gradient:
+            doc_gradients.append(list(x[0:sent_start]))
+            sent_gradients.append(list(x[sent_start:]))
+
+        # print("doc_gradients", doc_gradients)
+        # print("sent_gradients", sent_gradients)
+
+        bp_doc(doc_gradients, sgd=self.sgd_article)
+        bp_sent(sent_gradients, sgd=self.sgd_sent)
 
     def _get_training_data(self, training_dir, entity_descr_output, dev, limit, to_print):
         id_to_descr = kb_creator._get_id_to_description(entity_descr_output)
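In the rewritten update above, the cosine loss is computed between the batched context encodings and the description encodings; bp_cont then returns a gradient with respect to the concatenated [article; sentence] input, which is sliced at ARTICLE_WIDTH and backpropagated into the two encoders separately. A NumPy sketch of that slicing step with placeholder gradient values:

import numpy as np

ARTICLE_WIDTH, SENT_WIDTH = 128, 64

# stand-in for the gradient returned by bp_cont: one row per cluster
context_gradient = np.random.default_rng(0).normal(size=(5, ARTICLE_WIDTH + SENT_WIDTH))

sent_start = ARTICLE_WIDTH
doc_gradients = [list(x[:sent_start]) for x in context_gradient]   # fed to bp_doc (article encoder)
sent_gradients = [list(x[sent_start:]) for x in context_gradient]  # fed to bp_sent (sentence encoder)

print(len(doc_gradients[0]), len(sent_gradients[0]))  # 128 64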
@@ -111,7 +111,7 @@ if __name__ == "__main__":
     print("STEP 6: training", datetime.datetime.now())
     my_nlp = spacy.load('en_core_web_md')
     trainer = EL_Model(kb=my_kb, nlp=my_nlp)
-    trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=1000, devlimit=100)
+    trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, trainlimit=5000, devlimit=100)
     print()
 
     # STEP 7: apply the EL algorithm on the dev dataset