small tweaks and documentation

svlandeg 2019-06-18 18:38:09 +02:00
parent 0d177c1146
commit 478305cd3f
7 changed files with 49 additions and 46 deletions

View File

@@ -12,6 +12,10 @@ from thinc.neural._classes.affine import Affine
 class EntityEncoder:
+    """
+    Train the embeddings of entity descriptions to fit a fixed-size entity vector (e.g. 64D).
+    This entity vector will be stored in the KB, and context vectors will be trained to be similar to them.
+    """
     DROP = 0
     EPOCHS = 5
@@ -102,6 +106,7 @@ class EntityEncoder:
     def _build_network(self, orig_width, hidden_with):
         with Model.define_operators({">>": chain}):
+            # very simple encoder-decoder model
             self.encoder = (
                 Affine(hidden_with, orig_width)
             )
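Note: only the encoder's Affine layer is visible in this hunk. As a rough, non-authoritative sketch of the encoder-decoder idea the new comment refers to (the decoder, loss and training loop below are assumptions for illustration, not code from this commit), a description vector can be compressed to the fixed entity width and trained by reconstruction:

import numpy as np

rng = np.random.default_rng(0)
orig_width, hidden_width = 300, 64           # e.g. 300D description vectors -> 64D entity vectors
W_enc = rng.normal(scale=0.1, size=(hidden_width, orig_width))   # encoder weights (cf. Affine above)
W_dec = rng.normal(scale=0.1, size=(orig_width, hidden_width))   # hypothetical decoder weights

desc_vector = rng.normal(size=orig_width)    # stand-in for one entity-description vector
entity_vector = W_enc @ desc_vector          # fixed-size vector that would be stored in the KB
reconstruction = W_dec @ entity_vector       # decoder output, used only while training
loss = float(np.mean((reconstruction - desc_vector) ** 2))
print(entity_vector.shape, round(loss, 4))   # (64,) and a reconstruction error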

View File

@@ -10,7 +10,8 @@ from spacy.gold import GoldParse
 from bin.wiki_entity_linking import kb_creator, wikipedia_processor as wp

 """
-Process Wikipedia interlinks to generate a training dataset for the EL algorithm
+Process Wikipedia interlinks to generate a training dataset for the EL algorithm.
+Gold-standard entities are stored in one file in standoff format (by character offset).
 """

 # ENTITY_FILE = "gold_entities.csv"
@@ -321,12 +322,16 @@ def read_training(nlp, training_dir, dev, limit):
                 current_article_id = article_id
                 ents_by_offset = dict()
                 for ent in current_doc.ents:
-                    ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
+                    sent_length = len(ent.sent)
+                    # custom filtering to avoid too long or too short sentences
+                    if 5 < sent_length < 100:
+                        ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
             else:
                 skip_articles.add(current_article_id)
                 current_doc = None
         except Exception as e:
             print("Problem parsing article", article_id, e)
+            skip_articles.add(current_article_id)

         # repeat checking this condition in case an exception was thrown
         if current_doc and (current_article_id == article_id):
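For illustration only (a tiny, self-contained sketch; the helper name and toy values are hypothetical and not part of the commit), this is the pattern the new filter applies when building the offset-keyed lookup of gold entities:

ents_by_offset = dict()

def register(start_char, end_char, ent_text, sent_length):
    # mirrors the filter above: skip entities whose sentence is very short or very long
    if 5 < sent_length < 100:
        ents_by_offset[str(start_char) + "_" + str(end_char)] = ent_text

register(0, 5, "Paris", sent_length=12)    # kept
register(10, 16, "France", sent_length=3)  # dropped: sentence too short
print(ents_by_offset)                      # {'0_5': 'Paris'}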

View File

@@ -10,7 +10,7 @@ WIKIDATA_JSON = 'C:/Users/Sofie/Documents/data/wikidata/wikidata-20190304-all.js
 def read_wikidata_entities_json(limit=None, to_print=False):
-    """ Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines. """
+    # Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines.
     lang = 'en'
     site_filter = 'enwiki'

View File

@@ -8,6 +8,7 @@ import datetime
 """
 Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions.
+Write these results to file for downstream KB and training data generation.
 """
@@ -142,7 +143,7 @@ def _capitalize_first(text):
 def write_entity_counts(prior_prob_input, count_output, to_print=False):
-    """ Write entity counts for quick access later """
+    # Write entity counts for quick access later
     entity_to_count = dict()
     total_count = 0

View File

@@ -195,10 +195,11 @@ def run_pipeline():
         print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now())
         print()

-        acc_r, acc_r_by_label, acc_p, acc_p_by_label, acc_o, acc_o_by_label = _measure_baselines(dev_data, kb_2)
-        print("dev acc oracle:", round(acc_o, 3), [(x, round(y, 3)) for x, y in acc_o_by_label.items()])
-        print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_by_label.items()])
-        print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_by_label.items()])
+        counts, acc_r, acc_r_label, acc_p, acc_p_label, acc_o, acc_o_label = _measure_baselines(dev_data, kb_2)
+        print("dev counts:", sorted(counts))
+        print("dev acc oracle:", round(acc_o, 3), [(x, round(y, 3)) for x, y in acc_o_label.items()])
+        print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_label.items()])
+        print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_label.items()])

         with el_pipe.model.use_params(optimizer.averages):
             # measuring combined accuracy (prior + context)
@@ -288,6 +289,8 @@ def _measure_accuracy(data, el_pipe):
 def _measure_baselines(data, kb):
     # Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
+    counts_by_label = dict()
     random_correct_by_label = dict()
     random_incorrect_by_label = dict()
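As a hedged aside on how the three baselines named in the comment are usually produced (an assumption for illustration; the actual selection code sits further down in this function and is not changed by this commit):

import random

def baseline_predictions(candidates, gold_kb_id):
    # candidates: list of (kb_id, prior_prob) pairs, as a KB lookup might return them
    random_pred = random.choice(candidates)[0]             # random selection
    prior_pred = max(candidates, key=lambda c: c[1])[0]    # highest prior probability
    # 'oracle' upper bound: correct whenever the gold entity is among the candidates
    oracle_pred = gold_kb_id if any(c[0] == gold_kb_id for c in candidates) else ""
    return random_pred, prior_pred, oracle_pred

print(baseline_predictions([("Q1", 0.8), ("Q2", 0.2)], gold_kb_id="Q2"))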
@@ -315,6 +318,7 @@ def _measure_baselines(data, kb):
             # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
             if gold_entity is not None:
+                counts_by_label[ent_label] = counts_by_label.get(ent_label, 0) + 1
                 candidates = kb.get_candidates(ent.text)
                 oracle_candidate = ""
                 best_candidate = ""
@@ -353,7 +357,7 @@ def _measure_baselines(data, kb):
     acc_random, acc_random_by_label = calculate_acc(random_correct_by_label, random_incorrect_by_label)
     acc_oracle, acc_oracle_by_label = calculate_acc(oracle_correct_by_label, oracle_incorrect_by_label)

-    return acc_random, acc_random_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label
+    return counts_by_label, acc_random, acc_random_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label


 def calculate_acc(correct_by_label, incorrect_by_label):
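calculate_acc itself is unchanged and its body is outside this hunk; a plausible sketch of what such a helper computes from the correct/incorrect dictionaries (an assumption, shown only to clarify the return values used above):

def calculate_acc(correct_by_label, incorrect_by_label):
    # assumed behaviour: micro accuracy plus a per-label breakdown
    acc_by_label = dict()
    total_correct = 0
    total_incorrect = 0
    for label in set(correct_by_label) | set(incorrect_by_label):
        correct = correct_by_label.get(label, 0)
        incorrect = incorrect_by_label.get(label, 0)
        total_correct += correct
        total_incorrect += incorrect
        if correct + incorrect:
            acc_by_label[label] = correct / (correct + incorrect)
    total = total_correct + total_incorrect
    acc = total_correct / total if total else 0.0
    return acc, acc_by_label

print(calculate_acc({"PERSON": 8}, {"PERSON": 2}))  # (0.8, {'PERSON': 0.8})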

View File

@@ -11,7 +11,7 @@ from copy import copy, deepcopy
 from thinc.neural import Model
 import srsly

-from spacy.kb import KnowledgeBase
+from .kb import KnowledgeBase
 from .tokenizer import Tokenizer
 from .vocab import Vocab
 from .lemmatizer import Lemmatizer

View File

@@ -14,7 +14,6 @@ from thinc.misc import LayerNorm
 from thinc.neural.util import to_categorical
 from thinc.neural.util import get_array_module

-from spacy.kb import KnowledgeBase
 from ..tokens.doc cimport Doc
 from ..syntax.nn_parser cimport Parser
 from ..syntax.ner cimport BiluoPushDown
@@ -1081,9 +1080,9 @@ class EntityLinker(Pipe):
         hidden_width = cfg.get("hidden_width", 128)
         # no default because this needs to correspond with the KB entity length
-        sent_width = cfg.get("entity_width")
+        entity_width = cfg.get("entity_width")

-        model = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=sent_width, **cfg)
+        model = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=entity_width, **cfg)
         return model
@@ -1135,21 +1134,13 @@ class EntityLinker(Pipe):
             docs = [docs]
             golds = [golds]

-        # article_docs = list()
-        sentence_docs = list()
+        context_docs = list()
         entity_encodings = list()

         for doc, gold in zip(docs, golds):
             for entity in gold.links:
                 start, end, gold_kb = entity
                 mention = doc.text[start:end]
-                sent_start = 0
-                sent_end = len(doc)
-                for index, sent in enumerate(doc.sents):
-                    if start >= sent.start_char and end <= sent.end_char:
-                        sent_start = sent.start
-                        sent_end = sent.end
-                sentence = doc[sent_start:sent_end].as_doc()

                 candidates = self.kb.get_candidates(mention)
                 for c in candidates:
@@ -1159,14 +1150,14 @@ class EntityLinker(Pipe):
                     prior_prob = c.prior_prob
                     entity_encoding = c.entity_vector

                     entity_encodings.append(entity_encoding)
-                    sentence_docs.append(sentence)
+                    context_docs.append(doc)

         if len(entity_encodings) > 0:
-            sent_encodings, bp_sent = self.model.begin_update(sentence_docs, drop=drop)
+            context_encodings, bp_context = self.model.begin_update(context_docs, drop=drop)
             entity_encodings = np.asarray(entity_encodings, dtype=np.float32)

-            loss, d_scores = self.get_loss(scores=sent_encodings, golds=entity_encodings, docs=None)
-            bp_sent(d_scores, sgd=sgd)
+            loss, d_scores = self.get_loss(scores=context_encodings, golds=entity_encodings, docs=None)
+            bp_context(d_scores, sgd=sgd)

             if losses is not None:
                 losses[self.name] += loss
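get_loss is not part of this diff; as a hedged sketch of the pairing it implies (context encodings trained towards the gold entity vectors, here with a simple L2-style loss as an assumption about the objective):

import numpy as np

def l2_loss_sketch(scores, golds):
    # scores: context encodings (n_mentions x entity_width); golds: gold entity vectors
    d_scores = scores - golds                     # gradient of 0.5 * ||scores - golds||^2
    loss = float((d_scores ** 2).sum()) / scores.shape[0]
    return loss, d_scores

scores = np.asarray([[0.1, 0.2], [0.3, 0.4]], dtype=np.float32)
golds = np.asarray([[0.0, 0.2], [0.4, 0.4]], dtype=np.float32)
print(l2_loss_sketch(scores, golds))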
@@ -1222,28 +1213,25 @@ class EntityLinker(Pipe):
         for i, doc in enumerate(docs):
             if len(doc) > 0:
+                context_encoding = self.model([doc])
+                context_enc_t = np.transpose(context_encoding)
                 for ent in doc.ents:
-                    sent_doc = ent.sent.as_doc()
-                    if len(sent_doc) > 0:
-                        sent_encoding = self.model([sent_doc])
-                        sent_enc_t = np.transpose(sent_encoding)
-
-                        candidates = self.kb.get_candidates(ent.text)
-                        if candidates:
-                            scores = list()
-                            for c in candidates:
-                                prior_prob = c.prior_prob * self.prior_weight
-                                kb_id = c.entity_
-                                entity_encoding = c.entity_vector
-                                sim = float(cosine(np.asarray([entity_encoding]), sent_enc_t)) * self.context_weight
-                                score = prior_prob + sim - (prior_prob*sim)  # put weights on the different factors ?
-                                scores.append(score)
-
-                            # TODO: thresholding
-                            best_index = scores.index(max(scores))
-                            best_candidate = candidates[best_index]
-                            final_entities.append(ent)
-                            final_kb_ids.append(best_candidate.entity_)
+                    candidates = self.kb.get_candidates(ent.text)
+                    if candidates:
+                        scores = list()
+                        for c in candidates:
+                            prior_prob = c.prior_prob * self.prior_weight
+                            kb_id = c.entity_
+                            entity_encoding = c.entity_vector
+                            sim = float(cosine(np.asarray([entity_encoding]), context_enc_t)) * self.context_weight
+                            score = prior_prob + sim - (prior_prob*sim)  # put weights on the different factors ?
+                            scores.append(score)
+
+                        # TODO: thresholding
+                        best_index = scores.index(max(scores))
+                        best_candidate = candidates[best_index]
+                        final_entities.append(ent)
+                        final_kb_ids.append(best_candidate.entity_)

         return final_entities, final_kb_ids
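A short worked example of the score combination used above: with the weighted prior and the weighted similarity both scaled into [0, 1], prior_prob + sim - prior_prob*sim behaves like a probabilistic OR, so either a strong prior or a strong context similarity can carry the score, and the result never exceeds 1.0.

prior_prob = 0.6   # weighted prior probability from the KB
sim = 0.5          # weighted cosine similarity between entity and context encodings
score = prior_prob + sim - (prior_prob * sim)
print(score)       # 0.8 -- higher than either signal alone, capped at 1.0 for inputs in [0, 1]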