redo training data to be independent of KB and entity-level instead of doc-level

This commit is contained in:
svlandeg 2019-06-14 15:55:26 +02:00
parent 0b04d142de
commit b312f2d0e7
4 changed files with 179 additions and 182 deletions

View File

@ -1,8 +1,6 @@
# coding: utf-8 # coding: utf-8
from random import shuffle from random import shuffle
from examples.pipeline.wiki_entity_linking import kb_creator
import numpy as np import numpy as np
from spacy._ml import zero_init, create_default_optimizer from spacy._ml import zero_init, create_default_optimizer

View File

@ -19,17 +19,15 @@ Process Wikipedia interlinks to generate a training dataset for the EL algorithm
ENTITY_FILE = "gold_entities.csv" ENTITY_FILE = "gold_entities.csv"
def create_training(kb, entity_def_input, training_output): def create_training(entity_def_input, training_output):
if not kb:
raise ValueError("kb should be defined")
wp_to_id = kb_creator._get_entity_to_id(entity_def_input) wp_to_id = kb_creator._get_entity_to_id(entity_def_input)
_process_wikipedia_texts(kb, wp_to_id, training_output, limit=100000000) # TODO: full dataset _process_wikipedia_texts(wp_to_id, training_output, limit=100000000) # TODO: full dataset 100000000
def _process_wikipedia_texts(kb, wp_to_id, training_output, limit=None): def _process_wikipedia_texts(wp_to_id, training_output, limit=None):
""" """
Read the XML wikipedia data to parse out training data: Read the XML wikipedia data to parse out training data:
raw text data + positive and negative instances raw text data + positive instances
""" """
title_regex = re.compile(r'(?<=<title>).*(?=</title>)') title_regex = re.compile(r'(?<=<title>).*(?=</title>)')
@ -43,8 +41,9 @@ def _process_wikipedia_texts(kb, wp_to_id, training_output, limit=None):
_write_training_entity(outputfile=entityfile, _write_training_entity(outputfile=entityfile,
article_id="article_id", article_id="article_id",
alias="alias", alias="alias",
entity="entity", entity="WD_id",
correct="correct") start="start",
end="end")
with bz2.open(wp.ENWIKI_DUMP, mode='rb') as file: with bz2.open(wp.ENWIKI_DUMP, mode='rb') as file:
line = file.readline() line = file.readline()
@ -75,14 +74,11 @@ def _process_wikipedia_texts(kb, wp_to_id, training_output, limit=None):
elif clean_line == "</page>": elif clean_line == "</page>":
if article_id: if article_id:
try: try:
_process_wp_text(kb, wp_to_id, entityfile, article_id, article_text.strip(), training_output) _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text.strip(), training_output)
# on a previous run, an error occurred after 46M lines and 2h
except Exception as e: except Exception as e:
print("Error processing article", article_id, article_title, e) print("Error processing article", article_id, article_title, e)
else: else:
print("Done processing a page, but couldn't find an article_id ?") print("Done processing a page, but couldn't find an article_id ?", article_title)
print(article_title)
print(article_text)
article_text = "" article_text = ""
article_title = None article_title = None
article_id = None article_id = None
@ -122,7 +118,14 @@ def _process_wikipedia_texts(kb, wp_to_id, training_output, limit=None):
text_regex = re.compile(r'(?<=<text xml:space=\"preserve\">).*(?=</text)') text_regex = re.compile(r'(?<=<text xml:space=\"preserve\">).*(?=</text)')
def _process_wp_text(kb, wp_to_id, entityfile, article_id, article_text, training_output): def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output):
found_entities = False
# print("Processing", article_id, article_title)
# ignore meta Wikipedia pages
if article_title.startswith("Wikipedia:"):
return
# remove the text tags # remove the text tags
text = text_regex.search(article_text).group(0) text = text_regex.search(article_text).group(0)
@ -130,67 +133,91 @@ def _process_wp_text(kb, wp_to_id, entityfile, article_id, article_text, trainin
if text.startswith("#REDIRECT"): if text.startswith("#REDIRECT"):
return return
# print("WP article", article_id, ":", article_title)
# print() # print()
# print(text) # print(text)
# get the raw text without markup etc # get the raw text without markup etc, keeping only interwiki links
clean_text = _get_clean_wp_text(text) clean_text = _get_clean_wp_text(text)
# print() # print()
# print(clean_text) # print(clean_text)
article_dict = dict() # read the text char by char to get the right offsets of the interwiki links
ambiguous_aliases = set() final_text = ""
aliases, entities, normalizations = wp.get_wp_links(text) open_read = 0
for alias, entity, norm in zip(aliases, entities, normalizations): reading_text = True
if alias not in ambiguous_aliases: reading_entity = False
entity_id = wp_to_id.get(entity) reading_mention = False
if entity_id: reading_special_case = False
# TODO: take care of these conflicts ! Currently they are being removed from the dataset entity_buffer = ""
if article_dict.get(alias) and article_dict[alias] != entity_id: mention_buffer = ""
ambiguous_aliases.add(alias) for index, letter in enumerate(clean_text):
article_dict.pop(alias) if letter == '[':
# print("Found conflicting alias", alias, "in article", article_id, article_title) open_read += 1
else: elif letter == ']':
article_dict[alias] = entity_id open_read -= 1
elif letter == '|':
if reading_text:
final_text += letter
# switch from reading entity to mention in the [[entity|mention]] pattern
elif reading_entity:
reading_text = False
reading_entity = False
reading_mention = True
else:
reading_special_case = True
else:
if reading_entity:
entity_buffer += letter
elif reading_mention:
mention_buffer += letter
elif reading_text:
final_text += letter
else:
raise ValueError("Not sure at point", clean_text[index-2:index+2])
# print("found entities:") if open_read > 2:
for alias, entity in article_dict.items(): reading_special_case = True
# print(alias, "-->", entity)
candidates = kb.get_candidates(alias)
# as training data, we only store entities that are sufficiently ambiguous if open_read == 2 and reading_text:
if len(candidates) > 1: reading_text = False
_write_training_article(article_id=article_id, clean_text=clean_text, training_output=training_output) reading_entity = True
# print("alias", alias) reading_mention = False
# print all incorrect candidates # we just finished reading an entity
for c in candidates: if open_read == 0 and not reading_text:
if entity != c.entity_: if '#' in entity_buffer or entity_buffer.startswith(':'):
reading_special_case = True
# Ignore cases with nested structures like File: handles etc
if not reading_special_case:
if not mention_buffer:
mention_buffer = entity_buffer
start = len(final_text)
end = start + len(mention_buffer)
qid = wp_to_id.get(entity_buffer, None)
if qid:
_write_training_entity(outputfile=entityfile, _write_training_entity(outputfile=entityfile,
article_id=article_id, article_id=article_id,
alias=alias, alias=mention_buffer,
entity=c.entity_, entity=qid,
correct="0") start=start,
end=end)
found_entities = True
final_text += mention_buffer
# print the one correct candidate entity_buffer = ""
_write_training_entity(outputfile=entityfile, mention_buffer = ""
article_id=article_id,
alias=alias,
entity=entity,
correct="1")
# print("gold entity", entity) reading_text = True
# print() reading_entity = False
reading_mention = False
reading_special_case = False
# _run_ner_depr(nlp, clean_text, article_dict) if found_entities:
# print() _write_training_article(article_id=article_id, clean_text=final_text, training_output=training_output)
info_regex = re.compile(r'{[^{]*?}') info_regex = re.compile(r'{[^{]*?}')
interwiki_regex = re.compile(r'\[\[([^|]*?)]]') htlm_regex = re.compile(r'&lt;!--[^-]*--&gt;')
interwiki_2_regex = re.compile(r'\[\[[^|]*?\|([^|]*?)]]')
htlm_regex = re.compile(r'&lt;!--[^!]*--&gt;')
category_regex = re.compile(r'\[\[Category:[^\[]*]]') category_regex = re.compile(r'\[\[Category:[^\[]*]]')
file_regex = re.compile(r'\[\[File:[^[\]]+]]') file_regex = re.compile(r'\[\[File:[^[\]]+]]')
ref_regex = re.compile(r'&lt;ref.*?&gt;') # non-greedy ref_regex = re.compile(r'&lt;ref.*?&gt;') # non-greedy
@ -215,12 +242,6 @@ def _get_clean_wp_text(article_text):
try_again = False try_again = False
previous_length = len(clean_text) previous_length = len(clean_text)
# remove simple interwiki links (no alternative name)
clean_text = interwiki_regex.sub(r'\1', clean_text)
# remove simple interwiki links by picking the alternative name
clean_text = interwiki_2_regex.sub(r'\1', clean_text)
# remove HTML comments # remove HTML comments
clean_text = htlm_regex.sub('', clean_text) clean_text = htlm_regex.sub('', clean_text)
@ -265,43 +286,34 @@ def _write_training_article(article_id, clean_text, training_output):
outputfile.write(clean_text) outputfile.write(clean_text)
def _write_training_entity(outputfile, article_id, alias, entity, correct): def _write_training_entity(outputfile, article_id, alias, entity, start, end):
outputfile.write(article_id + "|" + alias + "|" + entity + "|" + correct + "\n") outputfile.write(article_id + "|" + alias + "|" + entity + "|" + str(start) + "|" + str(end) + "\n")
def read_training_entities(training_output, collect_correct=True, collect_incorrect=False): def read_training_entities(training_output):
entityfile_loc = training_output + "/" + ENTITY_FILE entityfile_loc = training_output + "/" + ENTITY_FILE
incorrect_entries_per_article = dict() entries_per_article = dict()
correct_entries_per_article = dict()
with open(entityfile_loc, mode='r', encoding='utf8') as file: with open(entityfile_loc, mode='r', encoding='utf8') as file:
for line in file: for line in file:
fields = line.replace('\n', "").split(sep='|') fields = line.replace('\n', "").split(sep='|')
article_id = fields[0] article_id = fields[0]
alias = fields[1] alias = fields[1]
entity = fields[2] wp_title = fields[2]
correct = fields[3] start = fields[3]
end = fields[4]
if correct == "1" and collect_correct: entries_by_offset = entries_per_article.get(article_id, dict())
entry_dict = correct_entries_per_article.get(article_id, dict()) entries_by_offset[start + "-" + end] = (alias, wp_title)
if alias in entry_dict:
raise ValueError("Found alias", alias, "multiple times for article", article_id, "in", ENTITY_FILE)
entry_dict[alias] = entity
correct_entries_per_article[article_id] = entry_dict
if correct == "0" and collect_incorrect: entries_per_article[article_id] = entries_by_offset
entry_dict = incorrect_entries_per_article.get(article_id, dict())
entities = entry_dict.get(alias, set())
entities.add(entity)
entry_dict[alias] = entities
incorrect_entries_per_article[article_id] = entry_dict
return correct_entries_per_article, incorrect_entries_per_article return entries_per_article
def read_training(nlp, training_dir, dev, limit, to_print): def read_training(nlp, training_dir, dev, limit, to_print):
correct_entries, incorrect_entries = read_training_entities(training_output=training_dir, # This method will provide training examples that correspond to the entity annotations found by the nlp object
collect_correct=True, entries_per_article = read_training_entities(training_output=training_dir)
collect_incorrect=True)
data = [] data = []
@ -320,36 +332,33 @@ def read_training(nlp, training_dir, dev, limit, to_print):
text = file.read() text = file.read()
article_doc = nlp(text) article_doc = nlp(text)
entries_by_offset = entries_per_article.get(article_id, dict())
gold_entities = list() gold_entities = list()
for ent in article_doc.ents:
start = ent.start_char
end = ent.end_char
# process all positive and negative entities, collect all relevant mentions in this article entity_tuple = entries_by_offset.get(str(start) + "-" + str(end), None)
for mention, entity_pos in correct_entries[article_id].items(): if entity_tuple:
# find all matches in the doc for the mentions alias, wp_title = entity_tuple
# TODO: fix this - doesn't look like all entities are found if ent.text != alias:
matcher = PhraseMatcher(nlp.vocab) print("Non-matching entity in", article_id, start, end)
patterns = list(nlp.tokenizer.pipe([mention])) else:
gold_entities.append((start, end, wp_title))
matcher.add("TerminologyList", None, *patterns) if gold_entities:
matches = matcher(article_doc) gold = GoldParse(doc=article_doc, links=gold_entities)
data.append((article_doc, gold))
# store gold entities
for match_id, start, end in matches:
gold_entities.append((start, end, entity_pos))
gold = GoldParse(doc=article_doc, links=gold_entities)
data.append((article_doc, gold))
cnt += 1 cnt += 1
except Exception as e: except Exception as e:
print("Problem parsing article", article_id) print("Problem parsing article", article_id)
print(e) print(e)
raise e
if to_print: if to_print:
print() print()
print("Processed", cnt, "training articles, dev=" + str(dev)) print("Processed", cnt, "training articles, dev=" + str(dev))
print() print()
return data return data

View File

@ -30,8 +30,8 @@ TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
MAX_CANDIDATES = 10 MAX_CANDIDATES = 10
MIN_PAIR_OCC = 5 MIN_PAIR_OCC = 5
DOC_CHAR_CUTOFF = 300 DOC_SENT_CUTOFF = 2
EPOCHS = 2 EPOCHS = 10
DROPOUT = 0.1 DROPOUT = 0.1
@ -46,14 +46,14 @@ def run_pipeline():
# one-time methods to create KB and write to file # one-time methods to create KB and write to file
to_create_prior_probs = False to_create_prior_probs = False
to_create_entity_counts = False to_create_entity_counts = False
to_create_kb = True to_create_kb = False # TODO: entity_defs should also contain entities not in the KB
# read KB back in from file # read KB back in from file
to_read_kb = True to_read_kb = False
to_test_kb = True to_test_kb = False
# create training dataset # create training dataset
create_wp_training = False create_wp_training = True
# train the EL pipe # train the EL pipe
train_pipe = False train_pipe = False
@ -103,9 +103,6 @@ def run_pipeline():
# STEP 4 : read KB back in from file # STEP 4 : read KB back in from file
if to_read_kb: if to_read_kb:
print("STEP 4: to_read_kb", datetime.datetime.now()) print("STEP 4: to_read_kb", datetime.datetime.now())
# my_vocab = Vocab()
# my_vocab.from_disk(VOCAB_DIR)
# my_kb = KnowledgeBase(vocab=my_vocab, entity_vector_length=64)
nlp_2 = spacy.load(NLP_1_DIR) nlp_2 = spacy.load(NLP_1_DIR)
kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH) kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH)
kb_2.load_bulk(KB_FILE) kb_2.load_bulk(KB_FILE)
@ -121,13 +118,13 @@ def run_pipeline():
# STEP 5: create a training dataset from WP # STEP 5: create a training dataset from WP
if create_wp_training: if create_wp_training:
print("STEP 5: create training dataset", datetime.datetime.now()) print("STEP 5: create training dataset", datetime.datetime.now())
training_set_creator.create_training(kb=kb_2, entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR) training_set_creator.create_training(entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR)
# STEP 6: create the entity linking pipe # STEP 6: create the entity linking pipe
if train_pipe: if train_pipe:
print("STEP 6: training Entity Linking pipe", datetime.datetime.now()) print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
train_limit = 10 train_limit = 50
dev_limit = 5 dev_limit = 10
print("Training on", train_limit, "articles") print("Training on", train_limit, "articles")
print("Dev testing on", dev_limit, "articles") print("Dev testing on", dev_limit, "articles")
print() print()
@ -144,7 +141,7 @@ def run_pipeline():
limit=dev_limit, limit=dev_limit,
to_print=False) to_print=False)
el_pipe = nlp_2.create_pipe(name='entity_linker', config={"doc_cutoff": DOC_CHAR_CUTOFF}) el_pipe = nlp_2.create_pipe(name='entity_linker', config={"doc_cutoff": DOC_SENT_CUTOFF})
el_pipe.set_kb(kb_2) el_pipe.set_kb(kb_2)
nlp_2.add_pipe(el_pipe, last=True) nlp_2.add_pipe(el_pipe, last=True)
@ -199,10 +196,14 @@ def run_pipeline():
el_pipe.prior_weight = 0 el_pipe.prior_weight = 0
dev_acc_1_0 = _measure_accuracy(dev_data, el_pipe) dev_acc_1_0 = _measure_accuracy(dev_data, el_pipe)
train_acc_1_0 = _measure_accuracy(train_data, el_pipe) train_acc_1_0 = _measure_accuracy(train_data, el_pipe)
print("train/dev acc context:", round(train_acc_1_0, 2), round(dev_acc_1_0, 2)) print("train/dev acc context:", round(train_acc_1_0, 2), round(dev_acc_1_0, 2))
print() print()
# reset for follow-up tests
el_pipe.context_weight = 1
el_pipe.prior_weight = 1
if to_test_pipeline: if to_test_pipeline:
print() print()
print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now()) print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now())
@ -215,17 +216,10 @@ def run_pipeline():
print("STEP 9: testing NLP IO", datetime.datetime.now()) print("STEP 9: testing NLP IO", datetime.datetime.now())
print() print()
print("writing to", NLP_2_DIR) print("writing to", NLP_2_DIR)
print(" vocab len nlp_2", len(nlp_2.vocab))
print(" vocab len kb_2", len(kb_2.vocab))
nlp_2.to_disk(NLP_2_DIR) nlp_2.to_disk(NLP_2_DIR)
print() print()
print("reading from", NLP_2_DIR) print("reading from", NLP_2_DIR)
nlp_3 = spacy.load(NLP_2_DIR) nlp_3 = spacy.load(NLP_2_DIR)
print(" vocab len nlp_3", len(nlp_3.vocab))
for pipe_name, pipe in nlp_3.pipeline:
if pipe_name == "entity_linker":
print(" vocab len kb_3", len(pipe.kb.vocab))
print() print()
print("running toy example with NLP 2") print("running toy example with NLP 2")
@ -253,9 +247,10 @@ def _measure_accuracy(data, el_pipe):
for ent in doc.ents: for ent in doc.ents:
if ent.label_ == "PERSON": # TODO: expand to other types if ent.label_ == "PERSON": # TODO: expand to other types
pred_entity = ent.kb_id_ pred_entity = ent.kb_id_
start = ent.start start = ent.start_char
end = ent.end end = ent.end_char
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None) gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None: if gold_entity is not None:
if gold_entity == pred_entity: if gold_entity == pred_entity:
correct += 1 correct += 1
@ -285,7 +280,8 @@ def run_el_toy_example(nlp):
print() print()
# Q4426480 is her husband, Q3568763 her tutor # Q4426480 is her husband, Q3568763 her tutor
text = "Ada Lovelace loved her husband William King dearly. " \ text = "Ada Lovelace was the countess of Lovelace. She is known for her programming work on the analytical engine."\
"Ada Lovelace loved her husband William King dearly. " \
"Ada Lovelace was tutored by her favorite physics tutor William King." "Ada Lovelace was tutored by her favorite physics tutor William King."
doc = nlp(text) doc = nlp(text)

View File

@ -1074,6 +1074,9 @@ class EntityLinker(Pipe):
@classmethod @classmethod
def Model(cls, **cfg): def Model(cls, **cfg):
if "entity_width" not in cfg:
raise ValueError("entity_width not found")
embed_width = cfg.get("embed_width", 300) embed_width = cfg.get("embed_width", 300)
hidden_width = cfg.get("hidden_width", 32) hidden_width = cfg.get("hidden_width", 32)
article_width = cfg.get("article_width", 128) article_width = cfg.get("article_width", 128)
@ -1095,7 +1098,10 @@ class EntityLinker(Pipe):
self.mention_encoder = True self.mention_encoder = True
self.kb = None self.kb = None
self.cfg = dict(cfg) self.cfg = dict(cfg)
self.doc_cutoff = self.cfg.get("doc_cutoff", 150) self.doc_cutoff = self.cfg.get("doc_cutoff", 5)
self.sgd_article = None
self.sgd_sent = None
self.sgd_mention = None
def set_kb(self, kb): def set_kb(self, kb):
self.kb = kb self.kb = kb
@ -1126,6 +1132,12 @@ class EntityLinker(Pipe):
self.require_model() self.require_model()
self.require_kb() self.require_kb()
if losses is not None:
losses.setdefault(self.name, 0.0)
if not docs or not golds:
return 0
if len(docs) != len(golds): if len(docs) != len(golds):
raise ValueError(Errors.E077.format(value="EL training", n_docs=len(docs), raise ValueError(Errors.E077.format(value="EL training", n_docs=len(docs),
n_golds=len(golds))) n_golds=len(golds)))
@ -1141,21 +1153,30 @@ class EntityLinker(Pipe):
for doc, gold in zip(docs, golds): for doc, gold in zip(docs, golds):
for entity in gold.links: for entity in gold.links:
start, end, gold_kb = entity start, end, gold_kb = entity
mention = doc[start:end] mention = doc.text[start:end]
sentence = mention.sent sent_start = 0
first_par = doc[0:self.doc_cutoff].as_doc() sent_end = len(doc)
first_par_end = len(doc)
for index, sent in enumerate(doc.sents):
if start >= sent.start_char and end <= sent.end_char:
sent_start = sent.start
sent_end = sent.end
if index == self.doc_cutoff-1:
first_par_end = sent.end
sentence = doc[sent_start:sent_end].as_doc()
first_par = doc[0:first_par_end].as_doc()
candidates = self.kb.get_candidates(mention.text) candidates = self.kb.get_candidates(mention)
for c in candidates: for c in candidates:
kb_id = c.entity_ kb_id = c.entity_
# TODO: currently only training on the positive instances # Currently only training on the positive instances
if kb_id == gold_kb: if kb_id == gold_kb:
prior_prob = c.prior_prob prior_prob = c.prior_prob
entity_encoding = c.entity_vector entity_encoding = c.entity_vector
entity_encodings.append(entity_encoding) entity_encodings.append(entity_encoding)
article_docs.append(first_par) article_docs.append(first_par)
sentence_docs.append(sentence.as_doc()) sentence_docs.append(sentence)
if len(entity_encodings) > 0: if len(entity_encodings) > 0:
doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop) doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
@ -1168,11 +1189,6 @@ class EntityLinker(Pipe):
entity_encodings = np.asarray(entity_encodings, dtype=np.float32) entity_encodings = np.asarray(entity_encodings, dtype=np.float32)
loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None) loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None)
# print("scores", mention_encodings)
# print("golds", entity_encodings)
# print("loss", loss)
# print("d_scores", d_scores)
mention_gradient = bp_mention(d_scores, sgd=self.sgd_mention) mention_gradient = bp_mention(d_scores, sgd=self.sgd_mention)
# gradient : concat (doc+sent) vs. desc # gradient : concat (doc+sent) vs. desc
@ -1187,7 +1203,6 @@ class EntityLinker(Pipe):
bp_sent(sent_gradients, sgd=self.sgd_sent) bp_sent(sent_gradients, sgd=self.sgd_sent)
if losses is not None: if losses is not None:
losses.setdefault(self.name, 0.0)
losses[self.name] += loss losses[self.name] += loss
return loss return loss
@ -1230,16 +1245,25 @@ class EntityLinker(Pipe):
self.require_model() self.require_model()
self.require_kb() self.require_kb()
if isinstance(docs, Doc):
docs = [docs]
final_entities = list() final_entities = list()
final_kb_ids = list() final_kb_ids = list()
for i, article_doc in enumerate(docs): if not docs:
if len(article_doc) > 0: return final_entities, final_kb_ids
doc_encoding = self.article_encoder([article_doc])
for ent in article_doc.ents: if isinstance(docs, Doc):
docs = [docs]
for i, doc in enumerate(docs):
if len(doc) > 0:
first_par_end = len(doc)
for index, sent in enumerate(doc.sents):
if index == self.doc_cutoff-1:
first_par_end = sent.end
first_par = doc[0:first_par_end].as_doc()
doc_encoding = self.article_encoder([first_par])
for ent in doc.ents:
sent_doc = ent.sent.as_doc() sent_doc = ent.sent.as_doc()
if len(sent_doc) > 0: if len(sent_doc) > 0:
sent_encoding = self.sent_encoder([sent_doc]) sent_encoding = self.sent_encoder([sent_doc])
@ -1254,7 +1278,7 @@ class EntityLinker(Pipe):
prior_prob = c.prior_prob * self.prior_weight prior_prob = c.prior_prob * self.prior_weight
kb_id = c.entity_ kb_id = c.entity_
entity_encoding = c.entity_vector entity_encoding = c.entity_vector
sim = cosine(np.asarray([entity_encoding]), mention_enc_t) * self.context_weight sim = float(cosine(np.asarray([entity_encoding]), mention_enc_t)) * self.context_weight
score = prior_prob + sim - (prior_prob*sim) # put weights on the different factors ? score = prior_prob + sim - (prior_prob*sim) # put weights on the different factors ?
scores.append(score) scores.append(score)
@ -1271,36 +1295,7 @@ class EntityLinker(Pipe):
for token in entity: for token in entity:
token.ent_kb_id_ = kb_id token.ent_kb_id_ = kb_id
def to_bytes(self, exclude=tuple(), **kwargs):
"""Serialize the pipe to a bytestring.
exclude (list): String names of serialization fields to exclude.
RETURNS (bytes): The serialized object.
"""
serialize = OrderedDict()
serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
serialize["kb"] = self.kb.to_bytes # TODO
if self.mention_encoder not in (True, False, None):
serialize["article_encoder"] = self.article_encoder.to_bytes
serialize["sent_encoder"] = self.sent_encoder.to_bytes
serialize["mention_encoder"] = self.mention_encoder.to_bytes
exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
return util.to_bytes(serialize, exclude)
def from_bytes(self, bytes_data, exclude=tuple(), **kwargs):
"""Load the pipe from a bytestring."""
deserialize = OrderedDict()
deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
deserialize["kb"] = lambda b: self.kb.from_bytes(b) # TODO
deserialize["article_encoder"] = lambda b: self.article_encoder.from_bytes(b)
deserialize["sent_encoder"] = lambda b: self.sent_encoder.from_bytes(b)
deserialize["mention_encoder"] = lambda b: self.mention_encoder.from_bytes(b)
exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
util.from_bytes(bytes_data, deserialize, exclude)
return self
def to_disk(self, path, exclude=tuple(), **kwargs): def to_disk(self, path, exclude=tuple(), **kwargs):
"""Serialize the pipe to disk."""
serialize = OrderedDict() serialize = OrderedDict()
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg) serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
serialize["kb"] = lambda p: self.kb.dump(p) serialize["kb"] = lambda p: self.kb.dump(p)
@ -1312,7 +1307,6 @@ class EntityLinker(Pipe):
util.to_disk(path, serialize, exclude) util.to_disk(path, serialize, exclude)
def from_disk(self, path, exclude=tuple(), **kwargs): def from_disk(self, path, exclude=tuple(), **kwargs):
"""Load the pipe from disk."""
def load_article_encoder(p): def load_article_encoder(p):
if self.article_encoder is True: if self.article_encoder is True:
self.article_encoder, _, _ = self.Model(**self.cfg) self.article_encoder, _, _ = self.Model(**self.cfg)