clean up code, remove old code, move to bin

This commit is contained in:
svlandeg 2019-06-18 13:20:40 +02:00
parent ffae7d3555
commit 0d177c1146
12 changed files with 92 additions and 787 deletions

View File

View File

@ -1,15 +1,13 @@
# coding: utf-8 # coding: utf-8
from __future__ import unicode_literals from __future__ import unicode_literals
import spacy from bin.wiki_entity_linking.train_descriptions import EntityEncoder
from examples.pipeline.wiki_entity_linking.train_descriptions import EntityEncoder
from spacy.kb import KnowledgeBase from spacy.kb import KnowledgeBase
import csv import csv
import datetime import datetime
from . import wikipedia_processor as wp from bin.wiki_entity_linking import wikidata_processor as wd, wikipedia_processor as wp
from . import wikidata_processor as wd
INPUT_DIM = 300 # dimension of pre-trained vectors INPUT_DIM = 300 # dimension of pre-trained vectors
DESC_WIDTH = 64 DESC_WIDTH = 64
@ -34,7 +32,7 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
else: else:
# read the mappings from file # read the mappings from file
title_to_id = _get_entity_to_id(entity_def_output) title_to_id = get_entity_to_id(entity_def_output)
id_to_descr = _get_id_to_description(entity_descr_output) id_to_descr = _get_id_to_description(entity_descr_output)
print() print()
@ -56,7 +54,8 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
frequency_list.append(freq) frequency_list.append(freq)
filtered_title_to_id[title] = entity filtered_title_to_id[title] = entity
print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()), "titles with filter frequency", min_entity_freq) print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()),
"titles with filter frequency", min_entity_freq)
print() print()
print(" * train entity encoder", datetime.datetime.now()) print(" * train entity encoder", datetime.datetime.now())
@ -101,7 +100,7 @@ def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_
descr_file.write(str(qid) + "|" + descr + "\n") descr_file.write(str(qid) + "|" + descr + "\n")
def _get_entity_to_id(entity_def_output): def get_entity_to_id(entity_def_output):
entity_to_id = dict() entity_to_id = dict()
with open(entity_def_output, 'r', encoding='utf8') as csvfile: with open(entity_def_output, 'r', encoding='utf8') as csvfile:
csvreader = csv.reader(csvfile, delimiter='|') csvreader = csv.reader(csvfile, delimiter='|')

View File

@ -55,8 +55,6 @@ class EntityEncoder:
print("Trained on", processed, "entities across", self.EPOCHS, "epochs") print("Trained on", processed, "entities across", self.EPOCHS, "epochs")
print("Final loss:", loss) print("Final loss:", loss)
# self._test_encoder()
def _train_model(self, description_list): def _train_model(self, description_list):
# TODO: when loss gets too low, a 'mean of empty slice' warning is thrown by numpy # TODO: when loss gets too low, a 'mean of empty slice' warning is thrown by numpy
@ -123,40 +121,3 @@ class EntityEncoder:
def _get_loss(golds, scores): def _get_loss(golds, scores):
loss, gradients = get_cossim_loss(scores, golds) loss, gradients = get_cossim_loss(scores, golds)
return loss, gradients return loss, gradients
def _test_encoder(self):
# Test encoder on some dummy examples
desc_A1 = "Fictional character in The Simpsons"
desc_A2 = "Simpsons - fictional human"
desc_A3 = "Fictional character in The Flintstones"
desc_A4 = "Politician from the US"
A1_doc_vector = np.asarray([self._get_doc_embedding(self.nlp(desc_A1))])
A2_doc_vector = np.asarray([self._get_doc_embedding(self.nlp(desc_A2))])
A3_doc_vector = np.asarray([self._get_doc_embedding(self.nlp(desc_A3))])
A4_doc_vector = np.asarray([self._get_doc_embedding(self.nlp(desc_A4))])
loss_a1_a1, _ = get_cossim_loss(A1_doc_vector, A1_doc_vector)
loss_a1_a2, _ = get_cossim_loss(A1_doc_vector, A2_doc_vector)
loss_a1_a3, _ = get_cossim_loss(A1_doc_vector, A3_doc_vector)
loss_a1_a4, _ = get_cossim_loss(A1_doc_vector, A4_doc_vector)
print("sim doc A1 A1", loss_a1_a1)
print("sim doc A1 A2", loss_a1_a2)
print("sim doc A1 A3", loss_a1_a3)
print("sim doc A1 A4", loss_a1_a4)
A1_encoded = self.encoder(A1_doc_vector)
A2_encoded = self.encoder(A2_doc_vector)
A3_encoded = self.encoder(A3_doc_vector)
A4_encoded = self.encoder(A4_doc_vector)
loss_a1_a1, _ = get_cossim_loss(A1_encoded, A1_encoded)
loss_a1_a2, _ = get_cossim_loss(A1_encoded, A2_encoded)
loss_a1_a3, _ = get_cossim_loss(A1_encoded, A3_encoded)
loss_a1_a4, _ = get_cossim_loss(A1_encoded, A4_encoded)
print("sim encoded A1 A1", loss_a1_a1)
print("sim encoded A1 A2", loss_a1_a2)
print("sim encoded A1 A3", loss_a1_a3)
print("sim encoded A1 A4", loss_a1_a4)

View File

@ -7,7 +7,7 @@ import bz2
import datetime import datetime
from spacy.gold import GoldParse from spacy.gold import GoldParse
from . import wikipedia_processor as wp, kb_creator from bin.wiki_entity_linking import kb_creator, wikipedia_processor as wp
""" """
Process Wikipedia interlinks to generate a training dataset for the EL algorithm Process Wikipedia interlinks to generate a training dataset for the EL algorithm
@ -18,7 +18,7 @@ ENTITY_FILE = "gold_entities_1000000.csv" # use this file for faster processin
def create_training(entity_def_input, training_output): def create_training(entity_def_input, training_output):
wp_to_id = kb_creator._get_entity_to_id(entity_def_input) wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
_process_wikipedia_texts(wp_to_id, training_output, limit=None) _process_wikipedia_texts(wp_to_id, training_output, limit=None)
@ -71,7 +71,8 @@ def _process_wikipedia_texts(wp_to_id, training_output, limit=None):
elif clean_line == "</page>": elif clean_line == "</page>":
if article_id: if article_id:
try: try:
_process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text.strip(), training_output) _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text.strip(),
training_output)
except Exception as e: except Exception as e:
print("Error processing article", article_id, article_title, e) print("Error processing article", article_id, article_title, e)
else: else:

View File

@ -13,9 +13,12 @@ def read_wikidata_entities_json(limit=None, to_print=False):
""" Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines. """ """ Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines. """
lang = 'en' lang = 'en'
# prop_filter = {'P31': {'Q5', 'Q15632617'}} # currently defined as OR: one property suffices to be selected
site_filter = 'enwiki' site_filter = 'enwiki'
# filter currently disabled to get ALL data
prop_filter = dict()
# prop_filter = {'P31': {'Q5', 'Q15632617'}} # currently defined as OR: one property suffices to be selected
title_to_id = dict() title_to_id = dict()
id_to_descr = dict() id_to_descr = dict()
@ -25,6 +28,7 @@ def read_wikidata_entities_json(limit=None, to_print=False):
parse_labels = False parse_labels = False
parse_descriptions = True parse_descriptions = True
parse_aliases = False parse_aliases = False
parse_claims = False
with bz2.open(WIKIDATA_JSON, mode='rb') as file: with bz2.open(WIKIDATA_JSON, mode='rb') as file:
line = file.readline() line = file.readline()
@ -45,14 +49,15 @@ def read_wikidata_entities_json(limit=None, to_print=False):
keep = True keep = True
claims = obj["claims"] claims = obj["claims"]
# for prop, value_set in prop_filter.items(): if parse_claims:
# claim_property = claims.get(prop, None) for prop, value_set in prop_filter.items():
# if claim_property: claim_property = claims.get(prop, None)
# for cp in claim_property: if claim_property:
# cp_id = cp['mainsnak'].get('datavalue', {}).get('value', {}).get('id') for cp in claim_property:
# cp_rank = cp['rank'] cp_id = cp['mainsnak'].get('datavalue', {}).get('value', {}).get('id')
# if cp_rank != "deprecated" and cp_id in value_set: cp_rank = cp['rank']
# keep = True if cp_rank != "deprecated" and cp_id in value_set:
keep = True
if keep: if keep:
unique_id = obj["id"] unique_id = obj["id"]
@ -64,8 +69,10 @@ def read_wikidata_entities_json(limit=None, to_print=False):
# parsing all properties that refer to other entities # parsing all properties that refer to other entities
if parse_properties: if parse_properties:
for prop, claim_property in claims.items(): for prop, claim_property in claims.items():
cp_dicts = [cp['mainsnak']['datavalue'].get('value') for cp in claim_property if cp['mainsnak'].get('datavalue')] cp_dicts = [cp['mainsnak']['datavalue'].get('value') for cp in claim_property
cp_values = [cp_dict.get('id') for cp_dict in cp_dicts if isinstance(cp_dict, dict) if cp_dict.get('id') is not None] if cp['mainsnak'].get('datavalue')]
cp_values = [cp_dict.get('id') for cp_dict in cp_dicts if isinstance(cp_dict, dict)
if cp_dict.get('id') is not None]
if cp_values: if cp_values:
if to_print: if to_print:
print("prop:", prop, cp_values) print("prop:", prop, cp_values)
@ -104,7 +111,7 @@ def read_wikidata_entities_json(limit=None, to_print=False):
if lang_aliases: if lang_aliases:
for item in lang_aliases: for item in lang_aliases:
if to_print: if to_print:
print("alias (" + lang + "):", item["value"]) print("alias (" + lang + "):", item["value"])
if to_print: if to_print:
print() print()

View File

@ -26,8 +26,8 @@ wiki_namespaces = ["b", "betawikiversity", "Book", "c", "Category", "Commons",
"mw", "n", "nost", "oldwikisource", "outreach", "outreachwiki", "otrs", "OTRSwiki", "mw", "n", "nost", "oldwikisource", "outreach", "outreachwiki", "otrs", "OTRSwiki",
"Portal", "phab", "Phabricator", "Project", "q", "quality", "rev", "Portal", "phab", "Phabricator", "Project", "q", "quality", "rev",
"s", "spcom", "Special", "species", "Strategy", "sulutil", "svn", "s", "spcom", "Special", "species", "Strategy", "sulutil", "svn",
"Talk", "Template", "Template talk", "Testwiki", "ticket", "TimedText", "Toollabs", "tools", "tswiki", "Talk", "Template", "Template talk", "Testwiki", "ticket", "TimedText", "Toollabs", "tools",
"User", "User talk", "v", "voy", "tswiki", "User", "User talk", "v", "voy",
"w", "Wikibooks", "Wikidata", "wikiHow", "Wikinvest", "wikilivres", "Wikimedia", "Wikinews", "w", "Wikibooks", "Wikidata", "wikiHow", "Wikinvest", "wikilivres", "Wikimedia", "Wikinews",
"Wikipedia", "Wikipedia talk", "Wikiquote", "Wikisource", "Wikispecies", "Wikitech", "Wikipedia", "Wikipedia talk", "Wikiquote", "Wikisource", "Wikispecies", "Wikitech",
"Wikiversity", "Wikivoyage", "wikt", "wiktionary", "wmf", "wmania", "WP"] "Wikiversity", "Wikivoyage", "wikt", "wiktionary", "wmf", "wmania", "WP"]

View File

@ -1,136 +0,0 @@
# coding: utf-8
from __future__ import unicode_literals
import os
import spacy
import datetime
from os import listdir
from examples.pipeline.wiki_entity_linking import training_set_creator
# requires: pip install neuralcoref --no-binary neuralcoref
# import neuralcoref
def run_kb_toy_example(kb):
for mention in ("Bush", "Douglas Adams", "Homer"):
candidates = kb.get_candidates(mention)
print("generating candidates for " + mention + " :")
for c in candidates:
print(" ", c.prior_prob, c.alias_, "-->", c.entity_ + " (freq=" + str(c.entity_freq) + ")")
print()
def run_el_dev(nlp, kb, training_dir, limit=None):
correct_entries_per_article, _ = training_set_creator.read_training_entities(training_output=training_dir)
predictions = list()
golds = list()
cnt = 0
for f in listdir(training_dir):
if not limit or cnt < limit:
if is_dev(f):
article_id = f.replace(".txt", "")
if cnt % 500 == 0:
print(datetime.datetime.now(), "processed", cnt, "files in the dev dataset")
cnt += 1
with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
text = file.read()
doc = nlp(text)
for ent in doc.ents:
if ent.label_ == "PERSON": # TODO: expand to other types
gold_entity = correct_entries_per_article[article_id].get(ent.text, None)
# only evaluating gold entities we know, because the training data is not complete
if gold_entity:
predictions.append(ent.kb_id_)
golds.append(gold_entity)
print("Processed", cnt, "dev articles")
print()
evaluate(predictions, golds)
def is_dev(file_name):
return file_name.endswith("3.txt")
def evaluate(predictions, golds, to_print=True, times_hundred=True):
if len(predictions) != len(golds):
raise ValueError("predictions and gold entities should have the same length")
tp = 0
fp = 0
fn = 0
corrects = 0
incorrects = 0
for pred, gold in zip(predictions, golds):
is_correct = pred == gold
if is_correct:
corrects += 1
else:
incorrects += 1
if not pred:
if not is_correct: # we don't care about tn
fn += 1
elif is_correct:
tp += 1
else:
fp += 1
if to_print:
print("Evaluating", len(golds), "entities")
print("tp", tp)
print("fp", fp)
print("fn", fn)
precision = tp / (tp + fp + 0.0000001)
recall = tp / (tp + fn + 0.0000001)
if times_hundred:
precision = precision*100
recall = recall*100
fscore = 2 * recall * precision / (recall + precision + 0.0000001)
accuracy = corrects / (corrects + incorrects)
if to_print:
print("precision", round(precision, 1), "%")
print("recall", round(recall, 1), "%")
print("Fscore", round(fscore, 1), "%")
print("Accuracy", round(accuracy, 1), "%")
return precision, recall, fscore, accuracy
# TODO
def add_coref(nlp):
""" Add coreference resolution to our model """
# TODO: this doesn't work yet
# neuralcoref.add_to_pipe(nlp)
print("done adding to pipe")
doc = nlp(u'My sister has a dog. She loves him.')
print("done doc")
print(doc._.has_coref)
print(doc._.coref_clusters)
# TODO
def _run_ner_depr(nlp, clean_text, article_dict):
doc = nlp(clean_text)
for ent in doc.ents:
if ent.label_ == "PERSON": # TODO: expand to non-persons
ent_id = article_dict.get(ent.text)
if ent_id:
print(" -", ent.text, ent.label_, ent_id)
else:
print(" -", ent.text, ent.label_, '???') # TODO: investigate these cases

View File

@ -1,490 +0,0 @@
# coding: utf-8
from __future__ import unicode_literals
import os
import datetime
from os import listdir
import numpy as np
import random
from random import shuffle
from thinc.neural._classes.convolution import ExtractWindow
from thinc.neural.util import get_array_module
from examples.pipeline.wiki_entity_linking import run_el, training_set_creator, kb_creator
from spacy._ml import SpacyVectors, create_default_optimizer, zero_init, cosine
from thinc.api import chain, concatenate, flatten_add_lengths, clone, with_flatten
from thinc.v2v import Model, Maxout, Affine
from thinc.t2v import Pooling, mean_pool
from thinc.t2t import ParametricAttention
from thinc.misc import Residual
from thinc.misc import LayerNorm as LN
# from spacy.cli.pretrain import get_cossim_loss
from spacy.matcher import PhraseMatcher
class EL_Model:
PRINT_INSPECT = False
PRINT_BATCH_LOSS = False
EPS = 0.0000000005
BATCH_SIZE = 100
DOC_CUTOFF = 300 # number of characters from the doc context
INPUT_DIM = 300 # dimension of pre-trained vectors
HIDDEN_1_WIDTH = 32
DESC_WIDTH = 64
ARTICLE_WIDTH = 128
SENT_WIDTH = 64
DROP = 0.4
LEARN_RATE = 0.005
EPOCHS = 10
L2 = 1e-6
name = "entity_linker"
def __init__(self, kb, nlp):
run_el._prepare_pipeline(nlp, kb)
self.nlp = nlp
self.kb = kb
self._build_cnn(embed_width=self.INPUT_DIM,
desc_width=self.DESC_WIDTH,
article_width=self.ARTICLE_WIDTH,
sent_width=self.SENT_WIDTH,
hidden_1_width=self.HIDDEN_1_WIDTH)
def train_model(self, training_dir, entity_descr_output, trainlimit=None, devlimit=None, to_print=True):
np.seterr(divide="raise", over="warn", under="ignore", invalid="raise")
id_to_descr = kb_creator._get_id_to_description(entity_descr_output)
train_ent, train_gold, train_desc, train_art, train_art_texts, train_sent, train_sent_texts = \
self._get_training_data(training_dir, id_to_descr, False, trainlimit, to_print=False)
train_clusters = list(train_ent.keys())
dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts = \
self._get_training_data(training_dir, id_to_descr, True, devlimit, to_print=False)
dev_clusters = list(dev_ent.keys())
dev_pos_count = len([g for g in dev_gold.values() if g])
dev_neg_count = len([g for g in dev_gold.values() if not g])
# inspect data
if self.PRINT_INSPECT:
for cluster, entities in train_ent.items():
print()
for entity in entities:
print("entity", entity)
print("gold", train_gold[entity])
print("desc", train_desc[entity])
print("sentence ID", train_sent[entity])
print("sentence text", train_sent_texts[train_sent[entity]])
print("article ID", train_art[entity])
print("article text", train_art_texts[train_art[entity]])
print()
train_pos_entities = [k for k, v in train_gold.items() if v]
train_neg_entities = [k for k, v in train_gold.items() if not v]
train_pos_count = len(train_pos_entities)
train_neg_count = len(train_neg_entities)
self._begin_training()
if to_print:
print()
print("Training on", len(train_clusters), "entity clusters in", len(train_art_texts), "articles")
print("Training instances pos/neg:", train_pos_count, train_neg_count)
print()
print("Dev test on", len(dev_clusters), "entity clusters in", len(dev_art_texts), "articles")
print("Dev instances pos/neg:", dev_pos_count, dev_neg_count)
print()
print(" DOC_CUTOFF", self.DOC_CUTOFF)
print(" INPUT_DIM", self.INPUT_DIM)
print(" HIDDEN_1_WIDTH", self.HIDDEN_1_WIDTH)
print(" DESC_WIDTH", self.DESC_WIDTH)
print(" ARTICLE_WIDTH", self.ARTICLE_WIDTH)
print(" SENT_WIDTH", self.SENT_WIDTH)
print(" DROP", self.DROP)
print(" LEARNING RATE", self.LEARN_RATE)
print(" BATCH SIZE", self.BATCH_SIZE)
print()
dev_random = self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts,
calc_random=True)
print("acc", "dev_random", round(dev_random, 2))
dev_pre = self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts,
avg=True)
print("acc", "dev_pre", round(dev_pre, 2))
print()
processed = 0
for i in range(self.EPOCHS):
shuffle(train_clusters)
start = 0
stop = min(self.BATCH_SIZE, len(train_clusters))
while start < len(train_clusters):
next_batch = {c: train_ent[c] for c in train_clusters[start:stop]}
processed += len(next_batch.keys())
self.update(entity_clusters=next_batch, golds=train_gold, descs=train_desc,
art_texts=train_art_texts, arts=train_art,
sent_texts=train_sent_texts, sents=train_sent)
start = start + self.BATCH_SIZE
stop = min(stop + self.BATCH_SIZE, len(train_clusters))
train_acc = self._test_dev(train_ent, train_gold, train_desc, train_art, train_art_texts, train_sent, train_sent_texts, avg=True)
dev_acc = self._test_dev(dev_ent, dev_gold, dev_desc, dev_art, dev_art_texts, dev_sent, dev_sent_texts, avg=True)
print(i, "acc train/dev", round(train_acc, 2), round(dev_acc, 2))
if to_print:
print()
print("Trained on", processed, "entity clusters across", self.EPOCHS, "epochs")
def _test_dev(self, entity_clusters, golds, descs, arts, art_texts, sents, sent_texts, avg=True, calc_random=False):
correct = 0
incorrect = 0
if calc_random:
for cluster, entities in entity_clusters.items():
correct_entities = [e for e in entities if golds[e]]
assert len(correct_entities) == 1
entities = list(entities)
shuffle(entities)
if calc_random:
predicted_entity = random.choice(entities)
if predicted_entity in correct_entities:
correct += 1
else:
incorrect += 1
else:
all_clusters = list()
arts_list = list()
sents_list = list()
for cluster in entity_clusters.keys():
all_clusters.append(cluster)
arts_list.append(art_texts[arts[cluster]])
sents_list.append(sent_texts[sents[cluster]])
art_docs = list(self.nlp.pipe(arts_list))
sent_docs = list(self.nlp.pipe(sents_list))
for i, cluster in enumerate(all_clusters):
entities = entity_clusters[cluster]
correct_entities = [e for e in entities if golds[e]]
assert len(correct_entities) == 1
entities = list(entities)
shuffle(entities)
desc_docs = self.nlp.pipe([descs[e] for e in entities])
sent_doc = sent_docs[i]
article_doc = art_docs[i]
predicted_index = self._predict(article_doc=article_doc, sent_doc=sent_doc,
desc_docs=desc_docs, avg=avg)
if entities[predicted_index] in correct_entities:
correct += 1
else:
incorrect += 1
if correct == incorrect == 0:
return 0
acc = correct / (correct + incorrect)
return acc
def _predict(self, article_doc, sent_doc, desc_docs, avg=True, apply_threshold=True):
if avg:
with self.article_encoder.use_params(self.sgd_article.averages) \
and self.desc_encoder.use_params(self.sgd_desc.averages)\
and self.sent_encoder.use_params(self.sgd_sent.averages):
desc_encodings = self.desc_encoder(desc_docs)
doc_encoding = self.article_encoder([article_doc])
sent_encoding = self.sent_encoder([sent_doc])
else:
desc_encodings = self.desc_encoder(desc_docs)
doc_encoding = self.article_encoder([article_doc])
sent_encoding = self.sent_encoder([sent_doc])
concat_encoding = [list(doc_encoding[0]) + list(sent_encoding[0])]
if avg:
with self.cont_encoder.use_params(self.sgd_cont.averages):
cont_encodings = self.cont_encoder(np.asarray([concat_encoding[0]]))
else:
cont_encodings = self.cont_encoder(np.asarray([concat_encoding[0]]))
context_enc = np.transpose(cont_encodings)
highest_sim = -5
best_i = -1
for i, desc_enc in enumerate(desc_encodings):
sim = cosine(desc_enc, context_enc)
if sim >= highest_sim:
best_i = i
highest_sim = sim
return best_i
def _build_cnn(self, embed_width, desc_width, article_width, sent_width, hidden_1_width):
self.desc_encoder = self._encoder(in_width=embed_width, hidden_with=hidden_1_width, end_width=desc_width)
self.cont_encoder = self._context_encoder(embed_width=embed_width, article_width=article_width,
sent_width=sent_width, hidden_width=hidden_1_width,
end_width=desc_width)
# def _encoder(self, width):
# tok2vec = Tok2Vec(width=width, embed_size=2000, pretrained_vectors=self.nlp.vocab.vectors.name, cnn_maxout_pieces=3,
# subword_features=False, conv_depth=4, bilstm_depth=0)
#
# return tok2vec >> flatten_add_lengths >> Pooling(mean_pool)
def _context_encoder(self, embed_width, article_width, sent_width, hidden_width, end_width):
self.article_encoder = self._encoder(in_width=embed_width, hidden_with=hidden_width, end_width=article_width)
self.sent_encoder = self._encoder(in_width=embed_width, hidden_with=hidden_width, end_width=sent_width)
model = Affine(end_width, article_width+sent_width, drop_factor=0.0)
return model
@staticmethod
def _encoder(in_width, hidden_with, end_width):
conv_depth = 2
cnn_maxout_pieces = 3
with Model.define_operators({">>": chain, "**": clone}):
convolution = Residual((ExtractWindow(nW=1) >>
LN(Maxout(hidden_with, hidden_with * 3, pieces=cnn_maxout_pieces))))
encoder = SpacyVectors \
>> with_flatten(LN(Maxout(hidden_with, in_width)) >> convolution ** conv_depth, pad=conv_depth) \
>> flatten_add_lengths \
>> ParametricAttention(hidden_with)\
>> Pooling(mean_pool) \
>> Residual(zero_init(Maxout(hidden_with, hidden_with))) \
>> zero_init(Affine(end_width, hidden_with, drop_factor=0.0))
# TODO: ReLu or LN(Maxout) ?
# sum_pool or mean_pool ?
return encoder
def _begin_training(self):
self.sgd_article = create_default_optimizer(self.article_encoder.ops)
self.sgd_article.learn_rate = self.LEARN_RATE
self.sgd_article.L2 = self.L2
self.sgd_sent = create_default_optimizer(self.sent_encoder.ops)
self.sgd_sent.learn_rate = self.LEARN_RATE
self.sgd_sent.L2 = self.L2
self.sgd_cont = create_default_optimizer(self.cont_encoder.ops)
self.sgd_cont.learn_rate = self.LEARN_RATE
self.sgd_cont.L2 = self.L2
self.sgd_desc = create_default_optimizer(self.desc_encoder.ops)
self.sgd_desc.learn_rate = self.LEARN_RATE
self.sgd_desc.L2 = self.L2
def get_loss(self, pred, gold, targets):
loss, gradients = self.get_cossim_loss(pred, gold, targets)
return loss, gradients
def get_cossim_loss(self, yh, y, t):
# Add a small constant to avoid 0 vectors
# print()
# print("yh", yh)
# print("y", y)
# print("t", t)
yh = yh + 1e-8
y = y + 1e-8
# https://math.stackexchange.com/questions/1923613/partial-derivative-of-cosine-similarity
xp = get_array_module(yh)
norm_yh = xp.linalg.norm(yh, axis=1, keepdims=True)
norm_y = xp.linalg.norm(y, axis=1, keepdims=True)
mul_norms = norm_yh * norm_y
cos = (yh * y).sum(axis=1, keepdims=True) / mul_norms
# print("cos", cos)
d_yh = (y / mul_norms) - (cos * (yh / norm_yh ** 2))
# print("abs", xp.abs(cos - t))
loss = xp.abs(cos - t).sum()
# print("loss", loss)
# print("d_yh", d_yh)
inverse = np.asarray([int(t[i][0]) * d_yh[i] for i in range(len(t))])
# print("inverse", inverse)
return loss, -inverse
def update(self, entity_clusters, golds, descs, art_texts, arts, sent_texts, sents):
arts_list = list()
sents_list = list()
descs_list = list()
targets = list()
for cluster, entities in entity_clusters.items():
art = art_texts[arts[cluster]]
sent = sent_texts[sents[cluster]]
for e in entities:
if golds[e]:
arts_list.append(art)
sents_list.append(sent)
descs_list.append(descs[e])
targets.append([1])
# else:
# arts_list.append(art)
# sents_list.append(sent)
# descs_list.append(descs[e])
# targets.append([-1])
desc_docs = self.nlp.pipe(descs_list)
desc_encodings, bp_desc = self.desc_encoder.begin_update(desc_docs, drop=self.DROP)
art_docs = self.nlp.pipe(arts_list)
sent_docs = self.nlp.pipe(sents_list)
doc_encodings, bp_doc = self.article_encoder.begin_update(art_docs, drop=self.DROP)
sent_encodings, bp_sent = self.sent_encoder.begin_update(sent_docs, drop=self.DROP)
concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in
range(len(targets))]
cont_encodings, bp_cont = self.cont_encoder.begin_update(np.asarray(concat_encodings), drop=self.DROP)
loss, cont_gradient = self.get_loss(cont_encodings, desc_encodings, targets)
# loss, desc_gradient = self.get_loss(desc_encodings, cont_encodings, targets)
# cont_gradient = cont_gradient / 2
# desc_gradient = desc_gradient / 2
# bp_desc(desc_gradient, sgd=self.sgd_desc)
if self.PRINT_BATCH_LOSS:
print("batch loss", loss)
context_gradient = bp_cont(cont_gradient, sgd=self.sgd_cont)
# gradient : concat (doc+sent) vs. desc
sent_start = self.ARTICLE_WIDTH
sent_gradients = list()
doc_gradients = list()
for x in context_gradient:
doc_gradients.append(list(x[0:sent_start]))
sent_gradients.append(list(x[sent_start:]))
bp_doc(doc_gradients, sgd=self.sgd_article)
bp_sent(sent_gradients, sgd=self.sgd_sent)
def _get_training_data(self, training_dir, id_to_descr, dev, limit, to_print):
correct_entries, incorrect_entries = training_set_creator.read_training_entities(training_output=training_dir)
entities_by_cluster = dict()
gold_by_entity = dict()
desc_by_entity = dict()
article_by_cluster = dict()
text_by_article = dict()
sentence_by_cluster = dict()
text_by_sentence = dict()
sentence_by_text = dict()
cnt = 0
next_entity_nr = 1
next_sent_nr = 1
files = listdir(training_dir)
shuffle(files)
for f in files:
if not limit or cnt < limit:
if dev == run_el.is_dev(f):
article_id = f.replace(".txt", "")
if cnt % 500 == 0 and to_print:
print(datetime.datetime.now(), "processed", cnt, "files in the training dataset")
try:
# parse the article text
with open(os.path.join(training_dir, f), mode="r", encoding='utf8') as file:
text = file.read()
article_doc = self.nlp(text)
truncated_text = text[0:min(self.DOC_CUTOFF, len(text))]
text_by_article[article_id] = truncated_text
# process all positive and negative entities, collect all relevant mentions in this article
for mention, entity_pos in correct_entries[article_id].items():
cluster = article_id + "_" + mention
descr = id_to_descr.get(entity_pos)
entities = set()
if descr:
entity = "E_" + str(next_entity_nr) + "_" + cluster
next_entity_nr += 1
gold_by_entity[entity] = 1
desc_by_entity[entity] = descr
entities.add(entity)
entity_negs = incorrect_entries[article_id][mention]
for entity_neg in entity_negs:
descr = id_to_descr.get(entity_neg)
if descr:
entity = "E_" + str(next_entity_nr) + "_" + cluster
next_entity_nr += 1
gold_by_entity[entity] = 0
desc_by_entity[entity] = descr
entities.add(entity)
found_matches = 0
if len(entities) > 1:
entities_by_cluster[cluster] = entities
# find all matches in the doc for the mentions
# TODO: fix this - doesn't look like all entities are found
matcher = PhraseMatcher(self.nlp.vocab)
patterns = list(self.nlp.tokenizer.pipe([mention]))
matcher.add("TerminologyList", None, *patterns)
matches = matcher(article_doc)
# store sentences
for match_id, start, end in matches:
span = article_doc[start:end]
if mention == span.text:
found_matches += 1
sent_text = span.sent.text
sent_nr = sentence_by_text.get(sent_text, None)
if sent_nr is None:
sent_nr = "S_" + str(next_sent_nr) + article_id
next_sent_nr += 1
text_by_sentence[sent_nr] = sent_text
sentence_by_text[sent_text] = sent_nr
article_by_cluster[cluster] = article_id
sentence_by_cluster[cluster] = sent_nr
if found_matches == 0:
# print("Could not find neg instances or sentence matches for", mention, "in", article_id)
entities_by_cluster.pop(cluster, None)
article_by_cluster.pop(cluster, None)
sentence_by_cluster.pop(cluster, None)
for entity in entities:
gold_by_entity.pop(entity, None)
desc_by_entity.pop(entity, None)
cnt += 1
except:
print("Problem parsing article", article_id)
if to_print:
print()
print("Processed", cnt, "training articles, dev=" + str(dev))
print()
return entities_by_cluster, gold_by_entity, desc_by_entity, article_by_cluster, text_by_article, \
sentence_by_cluster, text_by_sentence

View File

@ -5,8 +5,8 @@ import random
from spacy.util import minibatch, compounding from spacy.util import minibatch, compounding
from examples.pipeline.wiki_entity_linking import wikipedia_processor as wp, kb_creator, training_set_creator, run_el from bin.wiki_entity_linking import training_set_creator, kb_creator, wikipedia_processor as wp
from examples.pipeline.wiki_entity_linking.kb_creator import DESC_WIDTH from bin.wiki_entity_linking.kb_creator import DESC_WIDTH
import spacy import spacy
from spacy.kb import KnowledgeBase from spacy.kb import KnowledgeBase
@ -30,9 +30,11 @@ TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
MAX_CANDIDATES = 10 MAX_CANDIDATES = 10
MIN_ENTITY_FREQ = 20 MIN_ENTITY_FREQ = 20
MIN_PAIR_OCC = 5 MIN_PAIR_OCC = 5
DOC_SENT_CUTOFF = 2
EPOCHS = 10 EPOCHS = 10
DROPOUT = 0.1 DROPOUT = 0.1
LEARN_RATE = 0.005
L2 = 1e-6
def run_pipeline(): def run_pipeline():
@ -40,7 +42,6 @@ def run_pipeline():
print() print()
nlp_1 = spacy.load('en_core_web_lg') nlp_1 = spacy.load('en_core_web_lg')
nlp_2 = None nlp_2 = None
kb_1 = None
kb_2 = None kb_2 = None
# one-time methods to create KB and write to file # one-time methods to create KB and write to file
@ -114,7 +115,7 @@ def run_pipeline():
# test KB # test KB
if to_test_kb: if to_test_kb:
test_kb(kb_2) check_kb(kb_2)
print() print()
# STEP 5: create a training dataset from WP # STEP 5: create a training dataset from WP
@ -122,19 +123,21 @@ def run_pipeline():
print("STEP 5: create training dataset", datetime.datetime.now()) print("STEP 5: create training dataset", datetime.datetime.now())
training_set_creator.create_training(entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR) training_set_creator.create_training(entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR)
# STEP 6: create the entity linking pipe # STEP 6: create and train the entity linking pipe
el_pipe = nlp_2.create_pipe(name='entity_linker', config={"doc_cutoff": DOC_SENT_CUTOFF}) el_pipe = nlp_2.create_pipe(name='entity_linker', config={})
el_pipe.set_kb(kb_2) el_pipe.set_kb(kb_2)
nlp_2.add_pipe(el_pipe, last=True) nlp_2.add_pipe(el_pipe, last=True)
other_pipes = [pipe for pipe in nlp_2.pipe_names if pipe != "entity_linker"] other_pipes = [pipe for pipe in nlp_2.pipe_names if pipe != "entity_linker"]
with nlp_2.disable_pipes(*other_pipes): # only train Entity Linking with nlp_2.disable_pipes(*other_pipes): # only train Entity Linking
nlp_2.begin_training() optimizer = nlp_2.begin_training()
optimizer.learn_rate = LEARN_RATE
optimizer.L2 = L2
if train_pipe: if train_pipe:
print("STEP 6: training Entity Linking pipe", datetime.datetime.now()) print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
train_limit = 25000 train_limit = 25000
dev_limit = 1000 dev_limit = 5000
train_data = training_set_creator.read_training(nlp=nlp_2, train_data = training_set_creator.read_training(nlp=nlp_2,
training_dir=TRAINING_DIR, training_dir=TRAINING_DIR,
@ -144,6 +147,14 @@ def run_pipeline():
print("Training on", len(train_data), "articles") print("Training on", len(train_data), "articles")
print() print()
dev_data = training_set_creator.read_training(nlp=nlp_2,
training_dir=TRAINING_DIR,
dev=True,
limit=dev_limit)
print("Dev testing on", len(dev_data), "articles")
print()
if not train_data: if not train_data:
print("Did not find any training data") print("Did not find any training data")
@ -161,53 +172,55 @@ def run_pipeline():
nlp_2.update( nlp_2.update(
docs, docs,
golds, golds,
sgd=optimizer,
drop=DROPOUT, drop=DROPOUT,
losses=losses, losses=losses,
) )
batchnr += 1 batchnr += 1
except Exception as e: except Exception as e:
print("Error updating batch:", e) print("Error updating batch:", e)
raise(e)
if batchnr > 0: if batchnr > 0:
losses['entity_linker'] = losses['entity_linker'] / batchnr with el_pipe.model.use_params(optimizer.averages):
print("Epoch, train loss", itn, round(losses['entity_linker'], 2)) el_pipe.context_weight = 1
el_pipe.prior_weight = 0
dev_data = training_set_creator.read_training(nlp=nlp_2, dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
training_dir=TRAINING_DIR, losses['entity_linker'] = losses['entity_linker'] / batchnr
dev=True, print("Epoch, train loss", itn, round(losses['entity_linker'], 2),
limit=dev_limit) " / dev acc context avg", round(dev_acc_context, 3))
print()
print("Dev testing on", len(dev_data), "articles")
# STEP 7: measure the performance of our trained pipe on an independent dev set
if len(dev_data) and measure_performance: if len(dev_data) and measure_performance:
print() print()
print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now()) print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now())
print() print()
acc_random, acc_random_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label = _measure_baselines(dev_data, kb_2) acc_r, acc_r_by_label, acc_p, acc_p_by_label, acc_o, acc_o_by_label = _measure_baselines(dev_data, kb_2)
print("dev acc oracle:", round(acc_oracle, 3), [(x, round(y, 3)) for x, y in acc_oracle_by_label.items()]) print("dev acc oracle:", round(acc_o, 3), [(x, round(y, 3)) for x, y in acc_o_by_label.items()])
print("dev acc random:", round(acc_random, 3), [(x, round(y, 3)) for x, y in acc_random_by_label.items()]) print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_by_label.items()])
print("dev acc prior:", round(acc_prior, 3), [(x, round(y, 3)) for x, y in acc_prior_by_label.items()]) print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_by_label.items()])
# print(" measuring accuracy 1-1") with el_pipe.model.use_params(optimizer.averages):
el_pipe.context_weight = 1 # measuring combined accuracy (prior + context)
el_pipe.prior_weight = 1 el_pipe.context_weight = 1
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe) el_pipe.prior_weight = 1
print("dev acc combo:", round(dev_acc_combo, 3), [(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()]) dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe)
print("dev acc combo avg:", round(dev_acc_combo, 3),
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
# using only context # using only context
el_pipe.context_weight = 1 el_pipe.context_weight = 1
el_pipe.prior_weight = 0 el_pipe.prior_weight = 0
dev_acc_context, dev_acc_1_0_dict = _measure_accuracy(dev_data, el_pipe) dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
print("dev acc context:", round(dev_acc_context, 3), [(x, round(y, 3)) for x, y in dev_acc_1_0_dict.items()]) print("dev acc context avg:", round(dev_acc_context, 3),
print() [(x, round(y, 3)) for x, y in dev_acc_context_dict.items()])
print()
# reset for follow-up tests # reset for follow-up tests
el_pipe.context_weight = 1 el_pipe.context_weight = 1
el_pipe.prior_weight = 1 el_pipe.prior_weight = 1
# STEP 8: apply the EL pipe on a toy example
if to_test_pipeline: if to_test_pipeline:
print() print()
print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now()) print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now())
@ -215,6 +228,7 @@ def run_pipeline():
run_el_toy_example(nlp=nlp_2) run_el_toy_example(nlp=nlp_2)
print() print()
# STEP 9: write the NLP pipeline (including entity linker) to file
if to_write_nlp: if to_write_nlp:
print() print()
print("STEP 9: testing NLP IO", datetime.datetime.now()) print("STEP 9: testing NLP IO", datetime.datetime.now())
@ -225,6 +239,7 @@ def run_pipeline():
print("reading from", NLP_2_DIR) print("reading from", NLP_2_DIR)
nlp_3 = spacy.load(NLP_2_DIR) nlp_3 = spacy.load(NLP_2_DIR)
# verify that the IO has gone correctly
if to_read_nlp: if to_read_nlp:
print() print()
print("running toy example with NLP 2") print("running toy example with NLP 2")
@ -272,6 +287,7 @@ def _measure_accuracy(data, el_pipe):
def _measure_baselines(data, kb): def _measure_baselines(data, kb):
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
random_correct_by_label = dict() random_correct_by_label = dict()
random_incorrect_by_label = dict() random_incorrect_by_label = dict()
@ -362,7 +378,7 @@ def calculate_acc(correct_by_label, incorrect_by_label):
return acc, acc_by_label return acc, acc_by_label
def test_kb(kb): def check_kb(kb):
for mention in ("Bush", "Douglas Adams", "Homer", "Brazil", "China"): for mention in ("Bush", "Douglas Adams", "Homer", "Brazil", "China"):
candidates = kb.get_candidates(mention) candidates = kb.get_candidates(mention)
@ -384,7 +400,7 @@ def run_el_toy_example(nlp):
print() print()
# Q4426480 is her husband # Q4426480 is her husband
text = "Ada Lovelace was the countess of Lovelace. She is known for her programming work on the analytical engine. "\ text = "Ada Lovelace was the countess of Lovelace. She's known for her programming work on the analytical engine. "\
"She loved her husband William King dearly. " "She loved her husband William King dearly. "
doc = nlp(text) doc = nlp(text)
print(text) print(text)
@ -393,7 +409,7 @@ def run_el_toy_example(nlp):
print() print()
# Q3568763 is her tutor # Q3568763 is her tutor
text = "Ada Lovelace was the countess of Lovelace. She is known for her programming work on the analytical engine. "\ text = "Ada Lovelace was the countess of Lovelace. She's known for her programming work on the analytical engine. "\
"She was tutored by her favorite physics tutor William King." "She was tutored by her favorite physics tutor William King."
doc = nlp(text) doc = nlp(text)
print(text) print(text)

View File

@ -661,10 +661,11 @@ def build_nel_encoder(in_width, hidden_width, end_width, **cfg):
LN(Maxout(hidden_width, hidden_width * 3, pieces=cnn_maxout_pieces)))) LN(Maxout(hidden_width, hidden_width * 3, pieces=cnn_maxout_pieces))))
encoder = SpacyVectors \ encoder = SpacyVectors \
>> with_flatten(LN(Maxout(hidden_width, in_width)) >> convolution ** conv_depth, pad=conv_depth) \ >> with_flatten(Affine(hidden_width, in_width))\
>> with_flatten(LN(Maxout(hidden_width, hidden_width)) >> convolution ** conv_depth, pad=conv_depth) \
>> flatten_add_lengths \ >> flatten_add_lengths \
>> ParametricAttention(hidden_width) \ >> ParametricAttention(hidden_width) \
>> Pooling(mean_pool) \ >> Pooling(sum_pool) \
>> Residual(zero_init(Maxout(hidden_width, hidden_width))) \ >> Residual(zero_init(Maxout(hidden_width, hidden_width))) \
>> zero_init(Affine(end_width, hidden_width, drop_factor=0.0)) >> zero_init(Affine(end_width, hidden_width, drop_factor=0.0))

View File

@ -1078,33 +1078,19 @@ class EntityLinker(Pipe):
raise ValueError("entity_width not found") raise ValueError("entity_width not found")
embed_width = cfg.get("embed_width", 300) embed_width = cfg.get("embed_width", 300)
hidden_width = cfg.get("hidden_width", 32) hidden_width = cfg.get("hidden_width", 128)
entity_width = cfg.get("entity_width") # no default because this needs to correspond with the KB
sent_width = entity_width # no default because this needs to correspond with the KB entity length
sent_width = cfg.get("entity_width")
model = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=sent_width, **cfg) model = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=sent_width, **cfg)
# dimension of the mention encoder needs to match the dimension of the entity encoder
# article_width = cfg.get("article_width", 128)
# sent_width = cfg.get("sent_width", 64)
# article_encoder = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=article_width, **cfg)
# mention_width = article_width + sent_width
# mention_encoder = Affine(entity_width, mention_width, drop_factor=0.0)
# return article_encoder, sent_encoder, mention_encoder
return model return model
def __init__(self, **cfg): def __init__(self, **cfg):
# self.article_encoder = True
# self.sent_encoder = True
# self.mention_encoder = True
self.model = True self.model = True
self.kb = None self.kb = None
self.cfg = dict(cfg) self.cfg = dict(cfg)
self.doc_cutoff = self.cfg.get("doc_cutoff", 5)
# self.sgd_article = None
# self.sgd_sent = None
# self.sgd_mention = None
def set_kb(self, kb): def set_kb(self, kb):
self.kb = kb self.kb = kb
@ -1131,13 +1117,6 @@ class EntityLinker(Pipe):
sgd = self.create_optimizer() sgd = self.create_optimizer()
return sgd return sgd
# if self.mention_encoder is True:
# self.article_encoder, self.sent_encoder, self.mention_encoder = self.Model(**self.cfg)
# self.sgd_article = create_default_optimizer(self.article_encoder.ops)
# self.sgd_sent = create_default_optimizer(self.sent_encoder.ops)
# self.sgd_mention = create_default_optimizer(self.mention_encoder.ops)
# return self.sgd_article
def update(self, docs, golds, state=None, drop=0.0, sgd=None, losses=None): def update(self, docs, golds, state=None, drop=0.0, sgd=None, losses=None):
self.require_model() self.require_model()
self.require_kb() self.require_kb()
@ -1166,15 +1145,11 @@ class EntityLinker(Pipe):
mention = doc.text[start:end] mention = doc.text[start:end]
sent_start = 0 sent_start = 0
sent_end = len(doc) sent_end = len(doc)
first_par_end = len(doc)
for index, sent in enumerate(doc.sents): for index, sent in enumerate(doc.sents):
if start >= sent.start_char and end <= sent.end_char: if start >= sent.start_char and end <= sent.end_char:
sent_start = sent.start sent_start = sent.start
sent_end = sent.end sent_end = sent.end
if index == self.doc_cutoff-1:
first_par_end = sent.end
sentence = doc[sent_start:sent_end].as_doc() sentence = doc[sent_start:sent_end].as_doc()
first_par = doc[0:first_par_end].as_doc()
candidates = self.kb.get_candidates(mention) candidates = self.kb.get_candidates(mention)
for c in candidates: for c in candidates:
@ -1184,32 +1159,15 @@ class EntityLinker(Pipe):
prior_prob = c.prior_prob prior_prob = c.prior_prob
entity_encoding = c.entity_vector entity_encoding = c.entity_vector
entity_encodings.append(entity_encoding) entity_encodings.append(entity_encoding)
# article_docs.append(first_par)
sentence_docs.append(sentence) sentence_docs.append(sentence)
if len(entity_encodings) > 0: if len(entity_encodings) > 0:
# doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
# sent_encodings, bp_sent = self.sent_encoder.begin_update(sentence_docs, drop=drop)
# concat_encodings = [list(doc_encodings[i]) + list(sent_encodings[i]) for i in range(len(article_docs))]
# mention_encodings, bp_mention = self.mention_encoder.begin_update(np.asarray(concat_encodings), drop=drop)
sent_encodings, bp_sent = self.model.begin_update(sentence_docs, drop=drop) sent_encodings, bp_sent = self.model.begin_update(sentence_docs, drop=drop)
entity_encodings = np.asarray(entity_encodings, dtype=np.float32) entity_encodings = np.asarray(entity_encodings, dtype=np.float32)
loss, d_scores = self.get_loss(scores=sent_encodings, golds=entity_encodings, docs=None) loss, d_scores = self.get_loss(scores=sent_encodings, golds=entity_encodings, docs=None)
bp_sent(d_scores, sgd=sgd) bp_sent(d_scores, sgd=sgd)
# gradient : concat (doc+sent) vs. desc
# sent_start = self.article_encoder.nO
# sent_gradients = list()
# doc_gradients = list()
# for x in mention_gradient:
# doc_gradients.append(list(x[0:sent_start]))
# sent_gradients.append(list(x[sent_start:]))
# bp_doc(doc_gradients, sgd=self.sgd_article)
# bp_sent(sent_gradients, sgd=self.sgd_sent)
if losses is not None: if losses is not None:
losses[self.name] += loss losses[self.name] += loss
return loss return loss
@ -1264,21 +1222,9 @@ class EntityLinker(Pipe):
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
if len(doc) > 0: if len(doc) > 0:
first_par_end = len(doc)
for index, sent in enumerate(doc.sents):
if index == self.doc_cutoff-1:
first_par_end = sent.end
first_par = doc[0:first_par_end].as_doc()
# doc_encoding = self.article_encoder([first_par])
for ent in doc.ents: for ent in doc.ents:
sent_doc = ent.sent.as_doc() sent_doc = ent.sent.as_doc()
if len(sent_doc) > 0: if len(sent_doc) > 0:
# sent_encoding = self.sent_encoder([sent_doc])
# concat_encoding = [list(doc_encoding[0]) + list(sent_encoding[0])]
# mention_encoding = self.mention_encoder(np.asarray([concat_encoding[0]]))
# mention_enc_t = np.transpose(mention_encoding)
sent_encoding = self.model([sent_doc]) sent_encoding = self.model([sent_doc])
sent_enc_t = np.transpose(sent_encoding) sent_enc_t = np.transpose(sent_encoding)