mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 09:26:27 +03:00
small tweaks and documentation
This commit is contained in:
parent
0d177c1146
commit
478305cd3f
|
@ -12,6 +12,10 @@ from thinc.neural._classes.affine import Affine
|
|||
|
||||
|
||||
class EntityEncoder:
|
||||
"""
|
||||
Train the embeddings of entity descriptions to fit a fixed-size entity vector (e.g. 64D).
|
||||
This entity vector will be stored in the KB, and context vectors will be trained to be similar to them.
|
||||
"""
|
||||
|
||||
DROP = 0
|
||||
EPOCHS = 5
|
||||
|
@ -102,6 +106,7 @@ class EntityEncoder:
|
|||
|
||||
def _build_network(self, orig_width, hidden_with):
|
||||
with Model.define_operators({">>": chain}):
|
||||
# very simple encoder-decoder model
|
||||
self.encoder = (
|
||||
Affine(hidden_with, orig_width)
|
||||
)
|
||||
|
|
|
@ -10,7 +10,8 @@ from spacy.gold import GoldParse
|
|||
from bin.wiki_entity_linking import kb_creator, wikipedia_processor as wp
|
||||
|
||||
"""
|
||||
Process Wikipedia interlinks to generate a training dataset for the EL algorithm
|
||||
Process Wikipedia interlinks to generate a training dataset for the EL algorithm.
|
||||
Gold-standard entities are stored in one file in standoff format (by character offset).
|
||||
"""
|
||||
|
||||
# ENTITY_FILE = "gold_entities.csv"
|
||||
|
@ -321,12 +322,16 @@ def read_training(nlp, training_dir, dev, limit):
|
|||
current_article_id = article_id
|
||||
ents_by_offset = dict()
|
||||
for ent in current_doc.ents:
|
||||
ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
|
||||
sent_length = len(ent.sent)
|
||||
# custom filtering to avoid too long or too short sentences
|
||||
if 5 < sent_length < 100:
|
||||
ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
|
||||
else:
|
||||
skip_articles.add(current_article_id)
|
||||
current_doc = None
|
||||
except Exception as e:
|
||||
print("Problem parsing article", article_id, e)
|
||||
skip_articles.add(current_article_id)
|
||||
|
||||
# repeat checking this condition in case an exception was thrown
|
||||
if current_doc and (current_article_id == article_id):
|
||||
|
|
|
@ -10,7 +10,7 @@ WIKIDATA_JSON = 'C:/Users/Sofie/Documents/data/wikidata/wikidata-20190304-all.js
|
|||
|
||||
|
||||
def read_wikidata_entities_json(limit=None, to_print=False):
|
||||
""" Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines. """
|
||||
# Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines.
|
||||
|
||||
lang = 'en'
|
||||
site_filter = 'enwiki'
|
||||
|
|
|
@ -8,6 +8,7 @@ import datetime
|
|||
|
||||
"""
|
||||
Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions.
|
||||
Write these results to file for downstream KB and training data generation.
|
||||
"""
|
||||
|
||||
|
||||
|
@ -142,7 +143,7 @@ def _capitalize_first(text):
|
|||
|
||||
|
||||
def write_entity_counts(prior_prob_input, count_output, to_print=False):
|
||||
""" Write entity counts for quick access later """
|
||||
# Write entity counts for quick access later
|
||||
entity_to_count = dict()
|
||||
total_count = 0
|
||||
|
||||
|
|
|
@ -195,10 +195,11 @@ def run_pipeline():
|
|||
print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now())
|
||||
print()
|
||||
|
||||
acc_r, acc_r_by_label, acc_p, acc_p_by_label, acc_o, acc_o_by_label = _measure_baselines(dev_data, kb_2)
|
||||
print("dev acc oracle:", round(acc_o, 3), [(x, round(y, 3)) for x, y in acc_o_by_label.items()])
|
||||
print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_by_label.items()])
|
||||
print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_by_label.items()])
|
||||
counts, acc_r, acc_r_label, acc_p, acc_p_label, acc_o, acc_o_label = _measure_baselines(dev_data, kb_2)
|
||||
print("dev counts:", sorted(counts))
|
||||
print("dev acc oracle:", round(acc_o, 3), [(x, round(y, 3)) for x, y in acc_o_label.items()])
|
||||
print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_label.items()])
|
||||
print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_label.items()])
|
||||
|
||||
with el_pipe.model.use_params(optimizer.averages):
|
||||
# measuring combined accuracy (prior + context)
|
||||
|
@ -288,6 +289,8 @@ def _measure_accuracy(data, el_pipe):
|
|||
|
||||
def _measure_baselines(data, kb):
|
||||
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
|
||||
counts_by_label = dict()
|
||||
|
||||
random_correct_by_label = dict()
|
||||
random_incorrect_by_label = dict()
|
||||
|
||||
|
@ -315,6 +318,7 @@ def _measure_baselines(data, kb):
|
|||
|
||||
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
||||
if gold_entity is not None:
|
||||
counts_by_label[ent_label] = counts_by_label.get(ent_label, 0) + 1
|
||||
candidates = kb.get_candidates(ent.text)
|
||||
oracle_candidate = ""
|
||||
best_candidate = ""
|
||||
|
@ -353,7 +357,7 @@ def _measure_baselines(data, kb):
|
|||
acc_random, acc_random_by_label = calculate_acc(random_correct_by_label, random_incorrect_by_label)
|
||||
acc_oracle, acc_oracle_by_label = calculate_acc(oracle_correct_by_label, oracle_incorrect_by_label)
|
||||
|
||||
return acc_random, acc_random_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label
|
||||
return counts_by_label, acc_random, acc_random_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label
|
||||
|
||||
|
||||
def calculate_acc(correct_by_label, incorrect_by_label):
|
||||
|
|
|
@ -11,7 +11,7 @@ from copy import copy, deepcopy
|
|||
from thinc.neural import Model
|
||||
import srsly
|
||||
|
||||
from spacy.kb import KnowledgeBase
|
||||
from .kb import KnowledgeBase
|
||||
from .tokenizer import Tokenizer
|
||||
from .vocab import Vocab
|
||||
from .lemmatizer import Lemmatizer
|
||||
|
|
|
@ -14,7 +14,6 @@ from thinc.misc import LayerNorm
|
|||
from thinc.neural.util import to_categorical
|
||||
from thinc.neural.util import get_array_module
|
||||
|
||||
from spacy.kb import KnowledgeBase
|
||||
from ..tokens.doc cimport Doc
|
||||
from ..syntax.nn_parser cimport Parser
|
||||
from ..syntax.ner cimport BiluoPushDown
|
||||
|
@ -1081,9 +1080,9 @@ class EntityLinker(Pipe):
|
|||
hidden_width = cfg.get("hidden_width", 128)
|
||||
|
||||
# no default because this needs to correspond with the KB entity length
|
||||
sent_width = cfg.get("entity_width")
|
||||
entity_width = cfg.get("entity_width")
|
||||
|
||||
model = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=sent_width, **cfg)
|
||||
model = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=entity_width, **cfg)
|
||||
|
||||
return model
|
||||
|
||||
|
@ -1135,21 +1134,13 @@ class EntityLinker(Pipe):
|
|||
docs = [docs]
|
||||
golds = [golds]
|
||||
|
||||
# article_docs = list()
|
||||
sentence_docs = list()
|
||||
context_docs = list()
|
||||
entity_encodings = list()
|
||||
|
||||
for doc, gold in zip(docs, golds):
|
||||
for entity in gold.links:
|
||||
start, end, gold_kb = entity
|
||||
mention = doc.text[start:end]
|
||||
sent_start = 0
|
||||
sent_end = len(doc)
|
||||
for index, sent in enumerate(doc.sents):
|
||||
if start >= sent.start_char and end <= sent.end_char:
|
||||
sent_start = sent.start
|
||||
sent_end = sent.end
|
||||
sentence = doc[sent_start:sent_end].as_doc()
|
||||
|
||||
candidates = self.kb.get_candidates(mention)
|
||||
for c in candidates:
|
||||
|
@ -1159,14 +1150,14 @@ class EntityLinker(Pipe):
|
|||
prior_prob = c.prior_prob
|
||||
entity_encoding = c.entity_vector
|
||||
entity_encodings.append(entity_encoding)
|
||||
sentence_docs.append(sentence)
|
||||
context_docs.append(doc)
|
||||
|
||||
if len(entity_encodings) > 0:
|
||||
sent_encodings, bp_sent = self.model.begin_update(sentence_docs, drop=drop)
|
||||
context_encodings, bp_context = self.model.begin_update(context_docs, drop=drop)
|
||||
entity_encodings = np.asarray(entity_encodings, dtype=np.float32)
|
||||
|
||||
loss, d_scores = self.get_loss(scores=sent_encodings, golds=entity_encodings, docs=None)
|
||||
bp_sent(d_scores, sgd=sgd)
|
||||
loss, d_scores = self.get_loss(scores=context_encodings, golds=entity_encodings, docs=None)
|
||||
bp_context(d_scores, sgd=sgd)
|
||||
|
||||
if losses is not None:
|
||||
losses[self.name] += loss
|
||||
|
@ -1222,28 +1213,25 @@ class EntityLinker(Pipe):
|
|||
|
||||
for i, doc in enumerate(docs):
|
||||
if len(doc) > 0:
|
||||
context_encoding = self.model([doc])
|
||||
context_enc_t = np.transpose(context_encoding)
|
||||
for ent in doc.ents:
|
||||
sent_doc = ent.sent.as_doc()
|
||||
if len(sent_doc) > 0:
|
||||
sent_encoding = self.model([sent_doc])
|
||||
sent_enc_t = np.transpose(sent_encoding)
|
||||
candidates = self.kb.get_candidates(ent.text)
|
||||
if candidates:
|
||||
scores = list()
|
||||
for c in candidates:
|
||||
prior_prob = c.prior_prob * self.prior_weight
|
||||
kb_id = c.entity_
|
||||
entity_encoding = c.entity_vector
|
||||
sim = float(cosine(np.asarray([entity_encoding]), context_enc_t)) * self.context_weight
|
||||
score = prior_prob + sim - (prior_prob*sim) # put weights on the different factors ?
|
||||
scores.append(score)
|
||||
|
||||
candidates = self.kb.get_candidates(ent.text)
|
||||
if candidates:
|
||||
scores = list()
|
||||
for c in candidates:
|
||||
prior_prob = c.prior_prob * self.prior_weight
|
||||
kb_id = c.entity_
|
||||
entity_encoding = c.entity_vector
|
||||
sim = float(cosine(np.asarray([entity_encoding]), sent_enc_t)) * self.context_weight
|
||||
score = prior_prob + sim - (prior_prob*sim) # put weights on the different factors ?
|
||||
scores.append(score)
|
||||
|
||||
# TODO: thresholding
|
||||
best_index = scores.index(max(scores))
|
||||
best_candidate = candidates[best_index]
|
||||
final_entities.append(ent)
|
||||
final_kb_ids.append(best_candidate.entity_)
|
||||
# TODO: thresholding
|
||||
best_index = scores.index(max(scores))
|
||||
best_candidate = candidates[best_index]
|
||||
final_entities.append(ent)
|
||||
final_kb_ids.append(best_candidate.entity_)
|
||||
|
||||
return final_entities, final_kb_ids
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user