mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-05 22:20:34 +03:00
train and predict per article (saving time for doc encoding)
This commit is contained in:
parent
3b81b00954
commit
4142e8dd1b
|
@ -46,11 +46,11 @@ class EL_Model():
|
||||||
dev_instances, dev_pos, dev_neg, dev_doc = self._get_training_data(training_dir,
|
dev_instances, dev_pos, dev_neg, dev_doc = self._get_training_data(training_dir,
|
||||||
entity_descr_output,
|
entity_descr_output,
|
||||||
True,
|
True,
|
||||||
limit, to_print)
|
limit / 10, to_print)
|
||||||
|
|
||||||
if to_print:
|
if to_print:
|
||||||
print("Training on", len(train_instances), "instance clusters")
|
print("Training on", len(train_instances.values()), "articles")
|
||||||
print("Dev test on", len(dev_instances), "instance clusters")
|
print("Dev test on", len(dev_instances.values()), "articles")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
self.sgd_entity = self.begin_training(self.entity_encoder)
|
self.sgd_entity = self.begin_training(self.entity_encoder)
|
||||||
|
@ -60,49 +60,51 @@ class EL_Model():
|
||||||
|
|
||||||
losses = {}
|
losses = {}
|
||||||
|
|
||||||
for inst_cluster in train_instances:
|
instance_count = 0
|
||||||
pos_ex = train_pos.get(inst_cluster)
|
|
||||||
neg_exs = train_neg.get(inst_cluster, [])
|
for article_id, inst_cluster_set in train_instances.items():
|
||||||
|
article_doc = train_doc[article_id]
|
||||||
|
pos_ex_list = list()
|
||||||
|
neg_exs_list = list()
|
||||||
|
for inst_cluster in inst_cluster_set:
|
||||||
|
instance_count += 1
|
||||||
|
pos_ex_list.append(train_pos.get(inst_cluster))
|
||||||
|
neg_exs_list.append(train_neg.get(inst_cluster, []))
|
||||||
|
|
||||||
|
self.update(article_doc, pos_ex_list, neg_exs_list, losses=losses)
|
||||||
|
p, r, fscore = self._test_dev(dev_instances, dev_pos, dev_neg, dev_doc)
|
||||||
|
print(round(fscore, 1))
|
||||||
|
|
||||||
|
if to_print:
|
||||||
|
print("Trained on", instance_count, "instance clusters")
|
||||||
|
|
||||||
if pos_ex and neg_exs:
|
|
||||||
article = inst_cluster.split(sep="_")[0]
|
|
||||||
entity_id = inst_cluster.split(sep="_")[1]
|
|
||||||
article_doc = train_doc[article]
|
|
||||||
self.update(article_doc, pos_ex, neg_exs, losses=losses)
|
|
||||||
p, r, fscore = self._test_dev(dev_instances, dev_pos, dev_neg, dev_doc)
|
|
||||||
print(round(fscore, 1))
|
|
||||||
# TODO
|
|
||||||
# elif not pos_ex:
|
|
||||||
# print("Weird. Couldn't find pos example for", inst_cluster)
|
|
||||||
# elif not neg_exs:
|
|
||||||
# print("Weird. Couldn't find neg examples for", inst_cluster)
|
|
||||||
|
|
||||||
def _test_dev(self, dev_instances, dev_pos, dev_neg, dev_doc):
|
def _test_dev(self, dev_instances, dev_pos, dev_neg, dev_doc):
|
||||||
predictions = list()
|
predictions = list()
|
||||||
golds = list()
|
golds = list()
|
||||||
|
|
||||||
for inst_cluster in dev_instances:
|
for article_id, inst_cluster_set in dev_instances.items():
|
||||||
pos_ex = dev_pos.get(inst_cluster)
|
for inst_cluster in inst_cluster_set:
|
||||||
neg_exs = dev_neg.get(inst_cluster, [])
|
pos_ex = dev_pos.get(inst_cluster)
|
||||||
ex_to_id = dict()
|
neg_exs = dev_neg.get(inst_cluster, [])
|
||||||
|
ex_to_id = dict()
|
||||||
|
|
||||||
if pos_ex and neg_exs:
|
if pos_ex and neg_exs:
|
||||||
ex_to_id[pos_ex] = pos_ex._.entity_id
|
ex_to_id[pos_ex] = pos_ex._.entity_id
|
||||||
for neg_ex in neg_exs:
|
for neg_ex in neg_exs:
|
||||||
ex_to_id[neg_ex] = neg_ex._.entity_id
|
ex_to_id[neg_ex] = neg_ex._.entity_id
|
||||||
|
|
||||||
article = inst_cluster.split(sep="_")[0]
|
article = inst_cluster.split(sep="_")[0]
|
||||||
entity_id = inst_cluster.split(sep="_")[1]
|
entity_id = inst_cluster.split(sep="_")[1]
|
||||||
article_doc = dev_doc[article]
|
article_doc = dev_doc[article]
|
||||||
|
|
||||||
examples = list(neg_exs)
|
examples = list(neg_exs)
|
||||||
examples.append(pos_ex)
|
examples.append(pos_ex)
|
||||||
shuffle(examples)
|
shuffle(examples)
|
||||||
|
|
||||||
best_entity, lowest_mse = self._predict(examples, article_doc)
|
|
||||||
predictions.append(ex_to_id[best_entity])
|
|
||||||
golds.append(ex_to_id[pos_ex])
|
|
||||||
|
|
||||||
|
best_entity, lowest_mse = self._predict(examples, article_doc)
|
||||||
|
predictions.append(ex_to_id[best_entity])
|
||||||
|
golds.append(ex_to_id[pos_ex])
|
||||||
|
|
||||||
# TODO: use lowest_mse and combine with prior probability
|
# TODO: use lowest_mse and combine with prior probability
|
||||||
p, r, F = run_el.evaluate(predictions, golds, to_print=False)
|
p, r, F = run_el.evaluate(predictions, golds, to_print=False)
|
||||||
|
@ -161,60 +163,79 @@ class EL_Model():
|
||||||
sgd = create_default_optimizer(model.ops)
|
sgd = create_default_optimizer(model.ops)
|
||||||
return sgd
|
return sgd
|
||||||
|
|
||||||
def update(self, article_doc, true_entity, false_entities, drop=0., losses=None):
|
def update(self, article_doc, true_entity_list, false_entities_list, drop=0., losses=None):
|
||||||
|
# TODO: one call only to begin_update ?
|
||||||
|
|
||||||
|
entity_diffs = None
|
||||||
|
doc_diffs = None
|
||||||
|
|
||||||
doc_encoding, article_bp = self.article_encoder.begin_update([article_doc], drop=drop)
|
doc_encoding, article_bp = self.article_encoder.begin_update([article_doc], drop=drop)
|
||||||
|
|
||||||
true_entity_encoding, true_entity_bp = self.entity_encoder.begin_update([true_entity], drop=drop)
|
for i, true_entity in enumerate(true_entity_list):
|
||||||
# print("encoding dim", len(true_entity_encoding[0]))
|
false_entities = false_entities_list[i]
|
||||||
|
|
||||||
consensus_encoding = self._calculate_consensus(doc_encoding, true_entity_encoding)
|
true_entity_encoding, true_entity_bp = self.entity_encoder.begin_update([true_entity], drop=drop)
|
||||||
consensus_encoding_t = consensus_encoding.transpose()
|
# print("encoding dim", len(true_entity_encoding[0]))
|
||||||
|
|
||||||
doc_mse, doc_diffs = self._calculate_similarity(doc_encoding, consensus_encoding)
|
consensus_encoding = self._calculate_consensus(doc_encoding, true_entity_encoding)
|
||||||
|
# consensus_encoding_t = consensus_encoding.transpose()
|
||||||
|
|
||||||
entity_mses = list()
|
doc_mse, doc_diff = self._calculate_similarity(doc_encoding, consensus_encoding)
|
||||||
|
|
||||||
true_mse, true_diffs = self._calculate_similarity(true_entity_encoding, consensus_encoding)
|
entity_mses = list()
|
||||||
# print("true_mse", true_mse)
|
|
||||||
# print("true_diffs", true_diffs)
|
|
||||||
entity_mses.append(true_mse)
|
|
||||||
# true_exp = np.exp(true_entity_encoding.dot(consensus_encoding_t))
|
|
||||||
# print("true_exp", true_exp)
|
|
||||||
|
|
||||||
# false_exp_sum = 0
|
true_mse, true_diffs = self._calculate_similarity(true_entity_encoding, consensus_encoding)
|
||||||
|
# print("true_mse", true_mse)
|
||||||
|
# print("true_diffs", true_diffs)
|
||||||
|
entity_mses.append(true_mse)
|
||||||
|
# true_exp = np.exp(true_entity_encoding.dot(consensus_encoding_t))
|
||||||
|
# print("true_exp", true_exp)
|
||||||
|
|
||||||
for false_entity in false_entities:
|
# false_exp_sum = 0
|
||||||
false_entity_encoding, false_entity_bp = self.entity_encoder.begin_update([false_entity], drop=drop)
|
|
||||||
false_mse, false_diffs = self._calculate_similarity(false_entity_encoding, consensus_encoding)
|
|
||||||
# print("false_mse", false_mse)
|
|
||||||
# false_exp = np.exp(false_entity_encoding.dot(consensus_encoding_t))
|
|
||||||
# print("false_exp", false_exp)
|
|
||||||
# print("false_diffs", false_diffs)
|
|
||||||
entity_mses.append(false_mse)
|
|
||||||
# if false_mse > true_mse:
|
|
||||||
# true_diffs = true_diffs - false_diffs ???
|
|
||||||
# false_exp_sum += false_exp
|
|
||||||
|
|
||||||
# prob = true_exp / false_exp_sum
|
if doc_diffs is not None:
|
||||||
# print("prob", prob)
|
doc_diffs += doc_diff
|
||||||
|
entity_diffs += true_diffs
|
||||||
|
else:
|
||||||
|
doc_diffs = doc_diff
|
||||||
|
entity_diffs = true_diffs
|
||||||
|
|
||||||
entity_mses = sorted(entity_mses)
|
for false_entity in false_entities:
|
||||||
# mse_sum = sum(entity_mses)
|
false_entity_encoding, false_entity_bp = self.entity_encoder.begin_update([false_entity], drop=drop)
|
||||||
# entity_probs = [1 - x/mse_sum for x in entity_mses]
|
false_mse, false_diffs = self._calculate_similarity(false_entity_encoding, consensus_encoding)
|
||||||
# print("entity_mses", entity_mses)
|
# print("false_mse", false_mse)
|
||||||
# print("entity_probs", entity_probs)
|
# false_exp = np.exp(false_entity_encoding.dot(consensus_encoding_t))
|
||||||
true_index = entity_mses.index(true_mse)
|
# print("false_exp", false_exp)
|
||||||
# print("true index", true_index)
|
# print("false_diffs", false_diffs)
|
||||||
# print("true prob", entity_probs[true_index])
|
entity_mses.append(false_mse)
|
||||||
|
# if false_mse > true_mse:
|
||||||
|
# true_diffs = true_diffs - false_diffs ???
|
||||||
|
# false_exp_sum += false_exp
|
||||||
|
|
||||||
# print("training loss", true_mse)
|
# prob = true_exp / false_exp_sum
|
||||||
|
# print("prob", prob)
|
||||||
|
|
||||||
# print()
|
entity_mses = sorted(entity_mses)
|
||||||
|
# mse_sum = sum(entity_mses)
|
||||||
|
# entity_probs = [1 - x/mse_sum for x in entity_mses]
|
||||||
|
# print("entity_mses", entity_mses)
|
||||||
|
# print("entity_probs", entity_probs)
|
||||||
|
true_index = entity_mses.index(true_mse)
|
||||||
|
# print("true index", true_index)
|
||||||
|
# print("true prob", entity_probs[true_index])
|
||||||
|
|
||||||
|
# print("training loss", true_mse)
|
||||||
|
|
||||||
|
# print()
|
||||||
|
|
||||||
# TODO: proper backpropagation taking ranking of elements into account ?
|
# TODO: proper backpropagation taking ranking of elements into account ?
|
||||||
# TODO backpropagation also for negative examples
|
# TODO backpropagation also for negative examples
|
||||||
true_entity_bp(true_diffs, sgd=self.sgd_entity)
|
|
||||||
article_bp(doc_diffs, sgd=self.sgd_article)
|
if doc_diffs is not None:
|
||||||
|
doc_diffs = doc_diffs / len(true_entity_list)
|
||||||
|
|
||||||
|
true_entity_bp(entity_diffs, sgd=self.sgd_entity)
|
||||||
|
article_bp(doc_diffs, sgd=self.sgd_article)
|
||||||
|
|
||||||
|
|
||||||
# TODO delete ?
|
# TODO delete ?
|
||||||
|
@ -268,7 +289,7 @@ class EL_Model():
|
||||||
collect_incorrect=True)
|
collect_incorrect=True)
|
||||||
|
|
||||||
|
|
||||||
instances = list()
|
instance_by_doc = dict()
|
||||||
local_vectors = list() # TODO: local vectors
|
local_vectors = list() # TODO: local vectors
|
||||||
doc_by_article = dict()
|
doc_by_article = dict()
|
||||||
pos_entities = dict()
|
pos_entities = dict()
|
||||||
|
@ -280,18 +301,19 @@ class EL_Model():
|
||||||
if dev == run_el.is_dev(f):
|
if dev == run_el.is_dev(f):
|
||||||
article_id = f.replace(".txt", "")
|
article_id = f.replace(".txt", "")
|
||||||
if cnt % 500 == 0 and to_print:
|
if cnt % 500 == 0 and to_print:
|
||||||
print(datetime.datetime.now(), "processed", cnt, "files in the dev dataset")
|
print(datetime.datetime.now(), "processed", cnt, "files in the training dataset")
|
||||||
cnt += 1
|
cnt += 1
|
||||||
if article_id not in doc_by_article:
|
if article_id not in doc_by_article:
|
||||||
with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
|
with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
|
||||||
text = file.read()
|
text = file.read()
|
||||||
doc = self.nlp(text)
|
doc = self.nlp(text)
|
||||||
doc_by_article[article_id] = doc
|
doc_by_article[article_id] = doc
|
||||||
|
instance_by_doc[article_id] = set()
|
||||||
|
|
||||||
for mention, entity_pos in correct_entries[article_id].items():
|
for mention, entity_pos in correct_entries[article_id].items():
|
||||||
descr = id_to_descr.get(entity_pos)
|
descr = id_to_descr.get(entity_pos)
|
||||||
if descr:
|
if descr:
|
||||||
instances.append(article_id + "_" + mention)
|
instance_by_doc[article_id].add(article_id + "_" + mention)
|
||||||
doc_descr = self.nlp(descr)
|
doc_descr = self.nlp(descr)
|
||||||
doc_descr._.entity_id = entity_pos
|
doc_descr._.entity_id = entity_pos
|
||||||
pos_entities[article_id + "_" + mention] = doc_descr
|
pos_entities[article_id + "_" + mention] = doc_descr
|
||||||
|
@ -308,6 +330,6 @@ class EL_Model():
|
||||||
|
|
||||||
if to_print:
|
if to_print:
|
||||||
print()
|
print()
|
||||||
print("Processed", cnt, "dev articles")
|
print("Processed", cnt, "training articles, dev=" + str(dev))
|
||||||
print()
|
print()
|
||||||
return instances, pos_entities, neg_entities, doc_by_article
|
return instance_by_doc, pos_entities, neg_entities, doc_by_article
|
||||||
|
|
|
@ -111,7 +111,7 @@ if __name__ == "__main__":
|
||||||
print("STEP 6: training ", datetime.datetime.now())
|
print("STEP 6: training ", datetime.datetime.now())
|
||||||
my_nlp = spacy.load('en_core_web_md')
|
my_nlp = spacy.load('en_core_web_md')
|
||||||
trainer = EL_Model(kb=my_kb, nlp=my_nlp)
|
trainer = EL_Model(kb=my_kb, nlp=my_nlp)
|
||||||
trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, limit=50)
|
trainer.train_model(training_dir=TRAINING_DIR, entity_descr_output=ENTITY_DESCR, limit=500)
|
||||||
print()
|
print()
|
||||||
|
|
||||||
# STEP 7: apply the EL algorithm on the dev dataset
|
# STEP 7: apply the EL algorithm on the dev dataset
|
||||||
|
|
Loading…
Reference in New Issue
Block a user