mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-25 00:34:20 +03:00
redo training data to be independent of KB and entity-level instead of doc-level
This commit is contained in:
parent
0b04d142de
commit
b312f2d0e7
|
@ -1,8 +1,6 @@
|
|||
# coding: utf-8
|
||||
from random import shuffle
|
||||
|
||||
from examples.pipeline.wiki_entity_linking import kb_creator
|
||||
|
||||
import numpy as np
|
||||
|
||||
from spacy._ml import zero_init, create_default_optimizer
|
||||
|
|
|
@ -19,17 +19,15 @@ Process Wikipedia interlinks to generate a training dataset for the EL algorithm
|
|||
ENTITY_FILE = "gold_entities.csv"
|
||||
|
||||
|
||||
def create_training(kb, entity_def_input, training_output):
|
||||
if not kb:
|
||||
raise ValueError("kb should be defined")
|
||||
def create_training(entity_def_input, training_output):
|
||||
wp_to_id = kb_creator._get_entity_to_id(entity_def_input)
|
||||
_process_wikipedia_texts(kb, wp_to_id, training_output, limit=100000000) # TODO: full dataset
|
||||
_process_wikipedia_texts(wp_to_id, training_output, limit=100000000) # TODO: full dataset 100000000
|
||||
|
||||
|
||||
def _process_wikipedia_texts(kb, wp_to_id, training_output, limit=None):
|
||||
def _process_wikipedia_texts(wp_to_id, training_output, limit=None):
|
||||
"""
|
||||
Read the XML wikipedia data to parse out training data:
|
||||
raw text data + positive and negative instances
|
||||
raw text data + positive instances
|
||||
"""
|
||||
|
||||
title_regex = re.compile(r'(?<=<title>).*(?=</title>)')
|
||||
|
@ -43,8 +41,9 @@ def _process_wikipedia_texts(kb, wp_to_id, training_output, limit=None):
|
|||
_write_training_entity(outputfile=entityfile,
|
||||
article_id="article_id",
|
||||
alias="alias",
|
||||
entity="entity",
|
||||
correct="correct")
|
||||
entity="WD_id",
|
||||
start="start",
|
||||
end="end")
|
||||
|
||||
with bz2.open(wp.ENWIKI_DUMP, mode='rb') as file:
|
||||
line = file.readline()
|
||||
|
@ -75,14 +74,11 @@ def _process_wikipedia_texts(kb, wp_to_id, training_output, limit=None):
|
|||
elif clean_line == "</page>":
|
||||
if article_id:
|
||||
try:
|
||||
_process_wp_text(kb, wp_to_id, entityfile, article_id, article_text.strip(), training_output)
|
||||
# on a previous run, an error occurred after 46M lines and 2h
|
||||
_process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text.strip(), training_output)
|
||||
except Exception as e:
|
||||
print("Error processing article", article_id, article_title, e)
|
||||
else:
|
||||
print("Done processing a page, but couldn't find an article_id ?")
|
||||
print(article_title)
|
||||
print(article_text)
|
||||
print("Done processing a page, but couldn't find an article_id ?", article_title)
|
||||
article_text = ""
|
||||
article_title = None
|
||||
article_id = None
|
||||
|
@ -122,7 +118,14 @@ def _process_wikipedia_texts(kb, wp_to_id, training_output, limit=None):
|
|||
text_regex = re.compile(r'(?<=<text xml:space=\"preserve\">).*(?=</text)')
|
||||
|
||||
|
||||
def _process_wp_text(kb, wp_to_id, entityfile, article_id, article_text, training_output):
|
||||
def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output):
|
||||
found_entities = False
|
||||
# print("Processing", article_id, article_title)
|
||||
|
||||
# ignore meta Wikipedia pages
|
||||
if article_title.startswith("Wikipedia:"):
|
||||
return
|
||||
|
||||
# remove the text tags
|
||||
text = text_regex.search(article_text).group(0)
|
||||
|
||||
|
@ -130,67 +133,91 @@ def _process_wp_text(kb, wp_to_id, entityfile, article_id, article_text, trainin
|
|||
if text.startswith("#REDIRECT"):
|
||||
return
|
||||
|
||||
# print("WP article", article_id, ":", article_title)
|
||||
# print()
|
||||
# print(text)
|
||||
|
||||
# get the raw text without markup etc
|
||||
# get the raw text without markup etc, keeping only interwiki links
|
||||
clean_text = _get_clean_wp_text(text)
|
||||
# print()
|
||||
# print(clean_text)
|
||||
|
||||
article_dict = dict()
|
||||
ambiguous_aliases = set()
|
||||
aliases, entities, normalizations = wp.get_wp_links(text)
|
||||
for alias, entity, norm in zip(aliases, entities, normalizations):
|
||||
if alias not in ambiguous_aliases:
|
||||
entity_id = wp_to_id.get(entity)
|
||||
if entity_id:
|
||||
# TODO: take care of these conflicts ! Currently they are being removed from the dataset
|
||||
if article_dict.get(alias) and article_dict[alias] != entity_id:
|
||||
ambiguous_aliases.add(alias)
|
||||
article_dict.pop(alias)
|
||||
# print("Found conflicting alias", alias, "in article", article_id, article_title)
|
||||
# read the text char by char to get the right offsets of the interwiki links
|
||||
final_text = ""
|
||||
open_read = 0
|
||||
reading_text = True
|
||||
reading_entity = False
|
||||
reading_mention = False
|
||||
reading_special_case = False
|
||||
entity_buffer = ""
|
||||
mention_buffer = ""
|
||||
for index, letter in enumerate(clean_text):
|
||||
if letter == '[':
|
||||
open_read += 1
|
||||
elif letter == ']':
|
||||
open_read -= 1
|
||||
elif letter == '|':
|
||||
if reading_text:
|
||||
final_text += letter
|
||||
# switch from reading entity to mention in the [[entity|mention]] pattern
|
||||
elif reading_entity:
|
||||
reading_text = False
|
||||
reading_entity = False
|
||||
reading_mention = True
|
||||
else:
|
||||
article_dict[alias] = entity_id
|
||||
reading_special_case = True
|
||||
else:
|
||||
if reading_entity:
|
||||
entity_buffer += letter
|
||||
elif reading_mention:
|
||||
mention_buffer += letter
|
||||
elif reading_text:
|
||||
final_text += letter
|
||||
else:
|
||||
raise ValueError("Not sure at point", clean_text[index-2:index+2])
|
||||
|
||||
# print("found entities:")
|
||||
for alias, entity in article_dict.items():
|
||||
# print(alias, "-->", entity)
|
||||
candidates = kb.get_candidates(alias)
|
||||
if open_read > 2:
|
||||
reading_special_case = True
|
||||
|
||||
# as training data, we only store entities that are sufficiently ambiguous
|
||||
if len(candidates) > 1:
|
||||
_write_training_article(article_id=article_id, clean_text=clean_text, training_output=training_output)
|
||||
# print("alias", alias)
|
||||
if open_read == 2 and reading_text:
|
||||
reading_text = False
|
||||
reading_entity = True
|
||||
reading_mention = False
|
||||
|
||||
# print all incorrect candidates
|
||||
for c in candidates:
|
||||
if entity != c.entity_:
|
||||
# we just finished reading an entity
|
||||
if open_read == 0 and not reading_text:
|
||||
if '#' in entity_buffer or entity_buffer.startswith(':'):
|
||||
reading_special_case = True
|
||||
# Ignore cases with nested structures like File: handles etc
|
||||
if not reading_special_case:
|
||||
if not mention_buffer:
|
||||
mention_buffer = entity_buffer
|
||||
start = len(final_text)
|
||||
end = start + len(mention_buffer)
|
||||
qid = wp_to_id.get(entity_buffer, None)
|
||||
if qid:
|
||||
_write_training_entity(outputfile=entityfile,
|
||||
article_id=article_id,
|
||||
alias=alias,
|
||||
entity=c.entity_,
|
||||
correct="0")
|
||||
alias=mention_buffer,
|
||||
entity=qid,
|
||||
start=start,
|
||||
end=end)
|
||||
found_entities = True
|
||||
final_text += mention_buffer
|
||||
|
||||
# print the one correct candidate
|
||||
_write_training_entity(outputfile=entityfile,
|
||||
article_id=article_id,
|
||||
alias=alias,
|
||||
entity=entity,
|
||||
correct="1")
|
||||
entity_buffer = ""
|
||||
mention_buffer = ""
|
||||
|
||||
# print("gold entity", entity)
|
||||
# print()
|
||||
reading_text = True
|
||||
reading_entity = False
|
||||
reading_mention = False
|
||||
reading_special_case = False
|
||||
|
||||
# _run_ner_depr(nlp, clean_text, article_dict)
|
||||
# print()
|
||||
if found_entities:
|
||||
_write_training_article(article_id=article_id, clean_text=final_text, training_output=training_output)
|
||||
|
||||
|
||||
info_regex = re.compile(r'{[^{]*?}')
|
||||
interwiki_regex = re.compile(r'\[\[([^|]*?)]]')
|
||||
interwiki_2_regex = re.compile(r'\[\[[^|]*?\|([^|]*?)]]')
|
||||
htlm_regex = re.compile(r'<!--[^!]*-->')
|
||||
htlm_regex = re.compile(r'<!--[^-]*-->')
|
||||
category_regex = re.compile(r'\[\[Category:[^\[]*]]')
|
||||
file_regex = re.compile(r'\[\[File:[^[\]]+]]')
|
||||
ref_regex = re.compile(r'<ref.*?>') # non-greedy
|
||||
|
@ -215,12 +242,6 @@ def _get_clean_wp_text(article_text):
|
|||
try_again = False
|
||||
previous_length = len(clean_text)
|
||||
|
||||
# remove simple interwiki links (no alternative name)
|
||||
clean_text = interwiki_regex.sub(r'\1', clean_text)
|
||||
|
||||
# remove simple interwiki links by picking the alternative name
|
||||
clean_text = interwiki_2_regex.sub(r'\1', clean_text)
|
||||
|
||||
# remove HTML comments
|
||||
clean_text = htlm_regex.sub('', clean_text)
|
||||
|
||||
|
@ -265,43 +286,34 @@ def _write_training_article(article_id, clean_text, training_output):
|
|||
outputfile.write(clean_text)
|
||||
|
||||
|
||||
def _write_training_entity(outputfile, article_id, alias, entity, correct):
|
||||
outputfile.write(article_id + "|" + alias + "|" + entity + "|" + correct + "\n")
|
||||
def _write_training_entity(outputfile, article_id, alias, entity, start, end):
|
||||
outputfile.write(article_id + "|" + alias + "|" + entity + "|" + str(start) + "|" + str(end) + "\n")
|
||||
|
||||
|
||||
def read_training_entities(training_output, collect_correct=True, collect_incorrect=False):
|
||||
def read_training_entities(training_output):
|
||||
entityfile_loc = training_output + "/" + ENTITY_FILE
|
||||
incorrect_entries_per_article = dict()
|
||||
correct_entries_per_article = dict()
|
||||
entries_per_article = dict()
|
||||
|
||||
with open(entityfile_loc, mode='r', encoding='utf8') as file:
|
||||
for line in file:
|
||||
fields = line.replace('\n', "").split(sep='|')
|
||||
article_id = fields[0]
|
||||
alias = fields[1]
|
||||
entity = fields[2]
|
||||
correct = fields[3]
|
||||
wp_title = fields[2]
|
||||
start = fields[3]
|
||||
end = fields[4]
|
||||
|
||||
if correct == "1" and collect_correct:
|
||||
entry_dict = correct_entries_per_article.get(article_id, dict())
|
||||
if alias in entry_dict:
|
||||
raise ValueError("Found alias", alias, "multiple times for article", article_id, "in", ENTITY_FILE)
|
||||
entry_dict[alias] = entity
|
||||
correct_entries_per_article[article_id] = entry_dict
|
||||
entries_by_offset = entries_per_article.get(article_id, dict())
|
||||
entries_by_offset[start + "-" + end] = (alias, wp_title)
|
||||
|
||||
if correct == "0" and collect_incorrect:
|
||||
entry_dict = incorrect_entries_per_article.get(article_id, dict())
|
||||
entities = entry_dict.get(alias, set())
|
||||
entities.add(entity)
|
||||
entry_dict[alias] = entities
|
||||
incorrect_entries_per_article[article_id] = entry_dict
|
||||
entries_per_article[article_id] = entries_by_offset
|
||||
|
||||
return correct_entries_per_article, incorrect_entries_per_article
|
||||
return entries_per_article
|
||||
|
||||
|
||||
def read_training(nlp, training_dir, dev, limit, to_print):
|
||||
correct_entries, incorrect_entries = read_training_entities(training_output=training_dir,
|
||||
collect_correct=True,
|
||||
collect_incorrect=True)
|
||||
# This method will provide training examples that correspond to the entity annotations found by the nlp object
|
||||
entries_per_article = read_training_entities(training_output=training_dir)
|
||||
|
||||
data = []
|
||||
|
||||
|
@ -320,22 +332,22 @@ def read_training(nlp, training_dir, dev, limit, to_print):
|
|||
text = file.read()
|
||||
article_doc = nlp(text)
|
||||
|
||||
entries_by_offset = entries_per_article.get(article_id, dict())
|
||||
|
||||
gold_entities = list()
|
||||
for ent in article_doc.ents:
|
||||
start = ent.start_char
|
||||
end = ent.end_char
|
||||
|
||||
# process all positive and negative entities, collect all relevant mentions in this article
|
||||
for mention, entity_pos in correct_entries[article_id].items():
|
||||
# find all matches in the doc for the mentions
|
||||
# TODO: fix this - doesn't look like all entities are found
|
||||
matcher = PhraseMatcher(nlp.vocab)
|
||||
patterns = list(nlp.tokenizer.pipe([mention]))
|
||||
|
||||
matcher.add("TerminologyList", None, *patterns)
|
||||
matches = matcher(article_doc)
|
||||
|
||||
# store gold entities
|
||||
for match_id, start, end in matches:
|
||||
gold_entities.append((start, end, entity_pos))
|
||||
entity_tuple = entries_by_offset.get(str(start) + "-" + str(end), None)
|
||||
if entity_tuple:
|
||||
alias, wp_title = entity_tuple
|
||||
if ent.text != alias:
|
||||
print("Non-matching entity in", article_id, start, end)
|
||||
else:
|
||||
gold_entities.append((start, end, wp_title))
|
||||
|
||||
if gold_entities:
|
||||
gold = GoldParse(doc=article_doc, links=gold_entities)
|
||||
data.append((article_doc, gold))
|
||||
|
||||
|
@ -343,13 +355,10 @@ def read_training(nlp, training_dir, dev, limit, to_print):
|
|||
except Exception as e:
|
||||
print("Problem parsing article", article_id)
|
||||
print(e)
|
||||
raise e
|
||||
|
||||
if to_print:
|
||||
print()
|
||||
print("Processed", cnt, "training articles, dev=" + str(dev))
|
||||
print()
|
||||
return data
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
|
@ -30,8 +30,8 @@ TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
|
|||
|
||||
MAX_CANDIDATES = 10
|
||||
MIN_PAIR_OCC = 5
|
||||
DOC_CHAR_CUTOFF = 300
|
||||
EPOCHS = 2
|
||||
DOC_SENT_CUTOFF = 2
|
||||
EPOCHS = 10
|
||||
DROPOUT = 0.1
|
||||
|
||||
|
||||
|
@ -46,14 +46,14 @@ def run_pipeline():
|
|||
# one-time methods to create KB and write to file
|
||||
to_create_prior_probs = False
|
||||
to_create_entity_counts = False
|
||||
to_create_kb = True
|
||||
to_create_kb = False # TODO: entity_defs should also contain entities not in the KB
|
||||
|
||||
# read KB back in from file
|
||||
to_read_kb = True
|
||||
to_test_kb = True
|
||||
to_read_kb = False
|
||||
to_test_kb = False
|
||||
|
||||
# create training dataset
|
||||
create_wp_training = False
|
||||
create_wp_training = True
|
||||
|
||||
# train the EL pipe
|
||||
train_pipe = False
|
||||
|
@ -103,9 +103,6 @@ def run_pipeline():
|
|||
# STEP 4 : read KB back in from file
|
||||
if to_read_kb:
|
||||
print("STEP 4: to_read_kb", datetime.datetime.now())
|
||||
# my_vocab = Vocab()
|
||||
# my_vocab.from_disk(VOCAB_DIR)
|
||||
# my_kb = KnowledgeBase(vocab=my_vocab, entity_vector_length=64)
|
||||
nlp_2 = spacy.load(NLP_1_DIR)
|
||||
kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH)
|
||||
kb_2.load_bulk(KB_FILE)
|
||||
|
@ -121,13 +118,13 @@ def run_pipeline():
|
|||
# STEP 5: create a training dataset from WP
|
||||
if create_wp_training:
|
||||
print("STEP 5: create training dataset", datetime.datetime.now())
|
||||
training_set_creator.create_training(kb=kb_2, entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR)
|
||||
training_set_creator.create_training(entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR)
|
||||
|
||||
# STEP 6: create the entity linking pipe
|
||||
if train_pipe:
|
||||
print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
|
||||
train_limit = 10
|
||||
dev_limit = 5
|
||||
train_limit = 50
|
||||
dev_limit = 10
|
||||
print("Training on", train_limit, "articles")
|
||||
print("Dev testing on", dev_limit, "articles")
|
||||
print()
|
||||
|
@ -144,7 +141,7 @@ def run_pipeline():
|
|||
limit=dev_limit,
|
||||
to_print=False)
|
||||
|
||||
el_pipe = nlp_2.create_pipe(name='entity_linker', config={"doc_cutoff": DOC_CHAR_CUTOFF})
|
||||
el_pipe = nlp_2.create_pipe(name='entity_linker', config={"doc_cutoff": DOC_SENT_CUTOFF})
|
||||
el_pipe.set_kb(kb_2)
|
||||
nlp_2.add_pipe(el_pipe, last=True)
|
||||
|
||||
|
@ -199,10 +196,14 @@ def run_pipeline():
|
|||
el_pipe.prior_weight = 0
|
||||
dev_acc_1_0 = _measure_accuracy(dev_data, el_pipe)
|
||||
train_acc_1_0 = _measure_accuracy(train_data, el_pipe)
|
||||
|
||||
print("train/dev acc context:", round(train_acc_1_0, 2), round(dev_acc_1_0, 2))
|
||||
print()
|
||||
|
||||
# reset for follow-up tests
|
||||
el_pipe.context_weight = 1
|
||||
el_pipe.prior_weight = 1
|
||||
|
||||
|
||||
if to_test_pipeline:
|
||||
print()
|
||||
print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now())
|
||||
|
@ -215,17 +216,10 @@ def run_pipeline():
|
|||
print("STEP 9: testing NLP IO", datetime.datetime.now())
|
||||
print()
|
||||
print("writing to", NLP_2_DIR)
|
||||
print(" vocab len nlp_2", len(nlp_2.vocab))
|
||||
print(" vocab len kb_2", len(kb_2.vocab))
|
||||
nlp_2.to_disk(NLP_2_DIR)
|
||||
print()
|
||||
print("reading from", NLP_2_DIR)
|
||||
nlp_3 = spacy.load(NLP_2_DIR)
|
||||
print(" vocab len nlp_3", len(nlp_3.vocab))
|
||||
|
||||
for pipe_name, pipe in nlp_3.pipeline:
|
||||
if pipe_name == "entity_linker":
|
||||
print(" vocab len kb_3", len(pipe.kb.vocab))
|
||||
|
||||
print()
|
||||
print("running toy example with NLP 2")
|
||||
|
@ -253,9 +247,10 @@ def _measure_accuracy(data, el_pipe):
|
|||
for ent in doc.ents:
|
||||
if ent.label_ == "PERSON": # TODO: expand to other types
|
||||
pred_entity = ent.kb_id_
|
||||
start = ent.start
|
||||
end = ent.end
|
||||
start = ent.start_char
|
||||
end = ent.end_char
|
||||
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
|
||||
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
||||
if gold_entity is not None:
|
||||
if gold_entity == pred_entity:
|
||||
correct += 1
|
||||
|
@ -285,7 +280,8 @@ def run_el_toy_example(nlp):
|
|||
print()
|
||||
|
||||
# Q4426480 is her husband, Q3568763 her tutor
|
||||
text = "Ada Lovelace loved her husband William King dearly. " \
|
||||
text = "Ada Lovelace was the countess of Lovelace. She is known for her programming work on the analytical engine."\
|
||||
"Ada Lovelace loved her husband William King dearly. " \
|
||||
"Ada Lovelace was tutored by her favorite physics tutor William King."
|
||||
doc = nlp(text)
|
||||
|
||||
|
|
|
@ -1074,6 +1074,9 @@ class EntityLinker(Pipe):
|
|||
|
||||
@classmethod
|
||||
def Model(cls, **cfg):
|
||||
if "entity_width" not in cfg:
|
||||
raise ValueError("entity_width not found")
|
||||
|
||||
embed_width = cfg.get("embed_width", 300)
|
||||
hidden_width = cfg.get("hidden_width", 32)
|
||||
article_width = cfg.get("article_width", 128)
|
||||
|
@ -1095,7 +1098,10 @@ class EntityLinker(Pipe):
|
|||
self.mention_encoder = True
|
||||
self.kb = None
|
||||
self.cfg = dict(cfg)
|
||||
self.doc_cutoff = self.cfg.get("doc_cutoff", 150)
|
||||
self.doc_cutoff = self.cfg.get("doc_cutoff", 5)
|
||||
self.sgd_article = None
|
||||
self.sgd_sent = None
|
||||
self.sgd_mention = None
|
||||
|
||||
def set_kb(self, kb):
|
||||
self.kb = kb
|
||||
|
@ -1126,6 +1132,12 @@ class EntityLinker(Pipe):
|
|||
self.require_model()
|
||||
self.require_kb()
|
||||
|
||||
if losses is not None:
|
||||
losses.setdefault(self.name, 0.0)
|
||||
|
||||
if not docs or not golds:
|
||||
return 0
|
||||
|
||||
if len(docs) != len(golds):
|
||||
raise ValueError(Errors.E077.format(value="EL training", n_docs=len(docs),
|
||||
n_golds=len(golds)))
|
||||
|
@ -1141,21 +1153,30 @@ class EntityLinker(Pipe):
|
|||
for doc, gold in zip(docs, golds):
|
||||
for entity in gold.links:
|
||||
start, end, gold_kb = entity
|
||||
mention = doc[start:end]
|
||||
sentence = mention.sent
|
||||
first_par = doc[0:self.doc_cutoff].as_doc()
|
||||
mention = doc.text[start:end]
|
||||
sent_start = 0
|
||||
sent_end = len(doc)
|
||||
first_par_end = len(doc)
|
||||
for index, sent in enumerate(doc.sents):
|
||||
if start >= sent.start_char and end <= sent.end_char:
|
||||
sent_start = sent.start
|
||||
sent_end = sent.end
|
||||
if index == self.doc_cutoff-1:
|
||||
first_par_end = sent.end
|
||||
sentence = doc[sent_start:sent_end].as_doc()
|
||||
first_par = doc[0:first_par_end].as_doc()
|
||||
|
||||
candidates = self.kb.get_candidates(mention.text)
|
||||
candidates = self.kb.get_candidates(mention)
|
||||
for c in candidates:
|
||||
kb_id = c.entity_
|
||||
# TODO: currently only training on the positive instances
|
||||
# Currently only training on the positive instances
|
||||
if kb_id == gold_kb:
|
||||
prior_prob = c.prior_prob
|
||||
entity_encoding = c.entity_vector
|
||||
|
||||
entity_encodings.append(entity_encoding)
|
||||
article_docs.append(first_par)
|
||||
sentence_docs.append(sentence.as_doc())
|
||||
sentence_docs.append(sentence)
|
||||
|
||||
if len(entity_encodings) > 0:
|
||||
doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
|
||||
|
@ -1168,11 +1189,6 @@ class EntityLinker(Pipe):
|
|||
entity_encodings = np.asarray(entity_encodings, dtype=np.float32)
|
||||
|
||||
loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None)
|
||||
# print("scores", mention_encodings)
|
||||
# print("golds", entity_encodings)
|
||||
# print("loss", loss)
|
||||
# print("d_scores", d_scores)
|
||||
|
||||
mention_gradient = bp_mention(d_scores, sgd=self.sgd_mention)
|
||||
|
||||
# gradient : concat (doc+sent) vs. desc
|
||||
|
@ -1187,7 +1203,6 @@ class EntityLinker(Pipe):
|
|||
bp_sent(sent_gradients, sgd=self.sgd_sent)
|
||||
|
||||
if losses is not None:
|
||||
losses.setdefault(self.name, 0.0)
|
||||
losses[self.name] += loss
|
||||
return loss
|
||||
|
||||
|
@ -1230,16 +1245,25 @@ class EntityLinker(Pipe):
|
|||
self.require_model()
|
||||
self.require_kb()
|
||||
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
|
||||
final_entities = list()
|
||||
final_kb_ids = list()
|
||||
|
||||
for i, article_doc in enumerate(docs):
|
||||
if len(article_doc) > 0:
|
||||
doc_encoding = self.article_encoder([article_doc])
|
||||
for ent in article_doc.ents:
|
||||
if not docs:
|
||||
return final_entities, final_kb_ids
|
||||
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
|
||||
for i, doc in enumerate(docs):
|
||||
if len(doc) > 0:
|
||||
first_par_end = len(doc)
|
||||
for index, sent in enumerate(doc.sents):
|
||||
if index == self.doc_cutoff-1:
|
||||
first_par_end = sent.end
|
||||
first_par = doc[0:first_par_end].as_doc()
|
||||
|
||||
doc_encoding = self.article_encoder([first_par])
|
||||
for ent in doc.ents:
|
||||
sent_doc = ent.sent.as_doc()
|
||||
if len(sent_doc) > 0:
|
||||
sent_encoding = self.sent_encoder([sent_doc])
|
||||
|
@ -1254,7 +1278,7 @@ class EntityLinker(Pipe):
|
|||
prior_prob = c.prior_prob * self.prior_weight
|
||||
kb_id = c.entity_
|
||||
entity_encoding = c.entity_vector
|
||||
sim = cosine(np.asarray([entity_encoding]), mention_enc_t) * self.context_weight
|
||||
sim = float(cosine(np.asarray([entity_encoding]), mention_enc_t)) * self.context_weight
|
||||
score = prior_prob + sim - (prior_prob*sim) # put weights on the different factors ?
|
||||
scores.append(score)
|
||||
|
||||
|
@ -1271,36 +1295,7 @@ class EntityLinker(Pipe):
|
|||
for token in entity:
|
||||
token.ent_kb_id_ = kb_id
|
||||
|
||||
def to_bytes(self, exclude=tuple(), **kwargs):
|
||||
"""Serialize the pipe to a bytestring.
|
||||
|
||||
exclude (list): String names of serialization fields to exclude.
|
||||
RETURNS (bytes): The serialized object.
|
||||
"""
|
||||
serialize = OrderedDict()
|
||||
serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
|
||||
serialize["kb"] = self.kb.to_bytes # TODO
|
||||
if self.mention_encoder not in (True, False, None):
|
||||
serialize["article_encoder"] = self.article_encoder.to_bytes
|
||||
serialize["sent_encoder"] = self.sent_encoder.to_bytes
|
||||
serialize["mention_encoder"] = self.mention_encoder.to_bytes
|
||||
exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
|
||||
return util.to_bytes(serialize, exclude)
|
||||
|
||||
def from_bytes(self, bytes_data, exclude=tuple(), **kwargs):
|
||||
"""Load the pipe from a bytestring."""
|
||||
deserialize = OrderedDict()
|
||||
deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
|
||||
deserialize["kb"] = lambda b: self.kb.from_bytes(b) # TODO
|
||||
deserialize["article_encoder"] = lambda b: self.article_encoder.from_bytes(b)
|
||||
deserialize["sent_encoder"] = lambda b: self.sent_encoder.from_bytes(b)
|
||||
deserialize["mention_encoder"] = lambda b: self.mention_encoder.from_bytes(b)
|
||||
exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
|
||||
util.from_bytes(bytes_data, deserialize, exclude)
|
||||
return self
|
||||
|
||||
def to_disk(self, path, exclude=tuple(), **kwargs):
|
||||
"""Serialize the pipe to disk."""
|
||||
serialize = OrderedDict()
|
||||
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
|
||||
serialize["kb"] = lambda p: self.kb.dump(p)
|
||||
|
@ -1312,7 +1307,6 @@ class EntityLinker(Pipe):
|
|||
util.to_disk(path, serialize, exclude)
|
||||
|
||||
def from_disk(self, path, exclude=tuple(), **kwargs):
|
||||
"""Load the pipe from disk."""
|
||||
def load_article_encoder(p):
|
||||
if self.article_encoder is True:
|
||||
self.article_encoder, _, _ = self.Model(**self.cfg)
|
||||
|
|
Loading…
Reference in New Issue
Block a user