small tweaks and documentation

This commit is contained in:
svlandeg 2019-06-18 18:38:09 +02:00
parent 0d177c1146
commit 478305cd3f
7 changed files with 49 additions and 46 deletions

View File

@ -12,6 +12,10 @@ from thinc.neural._classes.affine import Affine
class EntityEncoder:
"""
Train the embeddings of entity descriptions to fit a fixed-size entity vector (e.g. 64D).
This entity vector will be stored in the KB, and context vectors will be trained to be similar to them.
"""
DROP = 0
EPOCHS = 5
@ -102,6 +106,7 @@ class EntityEncoder:
def _build_network(self, orig_width, hidden_with):
with Model.define_operators({">>": chain}):
# very simple encoder-decoder model
self.encoder = (
Affine(hidden_with, orig_width)
)

View File

@ -10,7 +10,8 @@ from spacy.gold import GoldParse
from bin.wiki_entity_linking import kb_creator, wikipedia_processor as wp
"""
Process Wikipedia interlinks to generate a training dataset for the EL algorithm
Process Wikipedia interlinks to generate a training dataset for the EL algorithm.
Gold-standard entities are stored in one file in standoff format (by character offset).
"""
# ENTITY_FILE = "gold_entities.csv"
@ -321,12 +322,16 @@ def read_training(nlp, training_dir, dev, limit):
current_article_id = article_id
ents_by_offset = dict()
for ent in current_doc.ents:
sent_length = len(ent.sent)
# custom filtering to avoid too long or too short sentences
if 5 < sent_length < 100:
ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
else:
skip_articles.add(current_article_id)
current_doc = None
except Exception as e:
print("Problem parsing article", article_id, e)
skip_articles.add(current_article_id)
# repeat checking this condition in case an exception was thrown
if current_doc and (current_article_id == article_id):

View File

@ -10,7 +10,7 @@ WIKIDATA_JSON = 'C:/Users/Sofie/Documents/data/wikidata/wikidata-20190304-all.js
def read_wikidata_entities_json(limit=None, to_print=False):
""" Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines. """
# Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines.
lang = 'en'
site_filter = 'enwiki'

View File

@ -8,6 +8,7 @@ import datetime
"""
Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions.
Write these results to file for downstream KB and training data generation.
"""
@ -142,7 +143,7 @@ def _capitalize_first(text):
def write_entity_counts(prior_prob_input, count_output, to_print=False):
""" Write entity counts for quick access later """
# Write entity counts for quick access later
entity_to_count = dict()
total_count = 0

View File

@ -195,10 +195,11 @@ def run_pipeline():
print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now())
print()
acc_r, acc_r_by_label, acc_p, acc_p_by_label, acc_o, acc_o_by_label = _measure_baselines(dev_data, kb_2)
print("dev acc oracle:", round(acc_o, 3), [(x, round(y, 3)) for x, y in acc_o_by_label.items()])
print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_by_label.items()])
print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_by_label.items()])
counts, acc_r, acc_r_label, acc_p, acc_p_label, acc_o, acc_o_label = _measure_baselines(dev_data, kb_2)
print("dev counts:", sorted(counts))
print("dev acc oracle:", round(acc_o, 3), [(x, round(y, 3)) for x, y in acc_o_label.items()])
print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_label.items()])
print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_label.items()])
with el_pipe.model.use_params(optimizer.averages):
# measuring combined accuracy (prior + context)
@ -288,6 +289,8 @@ def _measure_accuracy(data, el_pipe):
def _measure_baselines(data, kb):
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
counts_by_label = dict()
random_correct_by_label = dict()
random_incorrect_by_label = dict()
@ -315,6 +318,7 @@ def _measure_baselines(data, kb):
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None:
counts_by_label[ent_label] = counts_by_label.get(ent_label, 0) + 1
candidates = kb.get_candidates(ent.text)
oracle_candidate = ""
best_candidate = ""
@ -353,7 +357,7 @@ def _measure_baselines(data, kb):
acc_random, acc_random_by_label = calculate_acc(random_correct_by_label, random_incorrect_by_label)
acc_oracle, acc_oracle_by_label = calculate_acc(oracle_correct_by_label, oracle_incorrect_by_label)
return acc_random, acc_random_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label
return counts_by_label, acc_random, acc_random_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label
def calculate_acc(correct_by_label, incorrect_by_label):

View File

@ -11,7 +11,7 @@ from copy import copy, deepcopy
from thinc.neural import Model
import srsly
from spacy.kb import KnowledgeBase
from .kb import KnowledgeBase
from .tokenizer import Tokenizer
from .vocab import Vocab
from .lemmatizer import Lemmatizer

View File

@ -14,7 +14,6 @@ from thinc.misc import LayerNorm
from thinc.neural.util import to_categorical
from thinc.neural.util import get_array_module
from spacy.kb import KnowledgeBase
from ..tokens.doc cimport Doc
from ..syntax.nn_parser cimport Parser
from ..syntax.ner cimport BiluoPushDown
@ -1081,9 +1080,9 @@ class EntityLinker(Pipe):
hidden_width = cfg.get("hidden_width", 128)
# no default because this needs to correspond with the KB entity length
sent_width = cfg.get("entity_width")
entity_width = cfg.get("entity_width")
model = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=sent_width, **cfg)
model = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=entity_width, **cfg)
return model
@ -1135,21 +1134,13 @@ class EntityLinker(Pipe):
docs = [docs]
golds = [golds]
# article_docs = list()
sentence_docs = list()
context_docs = list()
entity_encodings = list()
for doc, gold in zip(docs, golds):
for entity in gold.links:
start, end, gold_kb = entity
mention = doc.text[start:end]
sent_start = 0
sent_end = len(doc)
for index, sent in enumerate(doc.sents):
if start >= sent.start_char and end <= sent.end_char:
sent_start = sent.start
sent_end = sent.end
sentence = doc[sent_start:sent_end].as_doc()
candidates = self.kb.get_candidates(mention)
for c in candidates:
@ -1159,14 +1150,14 @@ class EntityLinker(Pipe):
prior_prob = c.prior_prob
entity_encoding = c.entity_vector
entity_encodings.append(entity_encoding)
sentence_docs.append(sentence)
context_docs.append(doc)
if len(entity_encodings) > 0:
sent_encodings, bp_sent = self.model.begin_update(sentence_docs, drop=drop)
context_encodings, bp_context = self.model.begin_update(context_docs, drop=drop)
entity_encodings = np.asarray(entity_encodings, dtype=np.float32)
loss, d_scores = self.get_loss(scores=sent_encodings, golds=entity_encodings, docs=None)
bp_sent(d_scores, sgd=sgd)
loss, d_scores = self.get_loss(scores=context_encodings, golds=entity_encodings, docs=None)
bp_context(d_scores, sgd=sgd)
if losses is not None:
losses[self.name] += loss
@ -1222,12 +1213,9 @@ class EntityLinker(Pipe):
for i, doc in enumerate(docs):
if len(doc) > 0:
context_encoding = self.model([doc])
context_enc_t = np.transpose(context_encoding)
for ent in doc.ents:
sent_doc = ent.sent.as_doc()
if len(sent_doc) > 0:
sent_encoding = self.model([sent_doc])
sent_enc_t = np.transpose(sent_encoding)
candidates = self.kb.get_candidates(ent.text)
if candidates:
scores = list()
@ -1235,7 +1223,7 @@ class EntityLinker(Pipe):
prior_prob = c.prior_prob * self.prior_weight
kb_id = c.entity_
entity_encoding = c.entity_vector
sim = float(cosine(np.asarray([entity_encoding]), sent_enc_t)) * self.context_weight
sim = float(cosine(np.asarray([entity_encoding]), context_enc_t)) * self.context_weight
score = prior_prob + sim - (prior_prob*sim) # put weights on the different factors ?
scores.append(score)