diff --git a/examples/pipeline/wiki_entity_linking/train_descriptions.py b/examples/pipeline/wiki_entity_linking/train_descriptions.py
index 92859fd84..bf4bcbc3d 100644
--- a/examples/pipeline/wiki_entity_linking/train_descriptions.py
+++ b/examples/pipeline/wiki_entity_linking/train_descriptions.py
@@ -1,8 +1,6 @@
# coding: utf-8
from random import shuffle
-from examples.pipeline.wiki_entity_linking import kb_creator
-
import numpy as np
from spacy._ml import zero_init, create_default_optimizer
diff --git a/examples/pipeline/wiki_entity_linking/training_set_creator.py b/examples/pipeline/wiki_entity_linking/training_set_creator.py
index 38a86058d..fc620a1d3 100644
--- a/examples/pipeline/wiki_entity_linking/training_set_creator.py
+++ b/examples/pipeline/wiki_entity_linking/training_set_creator.py
@@ -19,17 +19,15 @@ Process Wikipedia interlinks to generate a training dataset for the EL algorithm
ENTITY_FILE = "gold_entities.csv"
-def create_training(kb, entity_def_input, training_output):
- if not kb:
- raise ValueError("kb should be defined")
+def create_training(entity_def_input, training_output):
wp_to_id = kb_creator._get_entity_to_id(entity_def_input)
- _process_wikipedia_texts(kb, wp_to_id, training_output, limit=100000000) # TODO: full dataset
+ _process_wikipedia_texts(wp_to_id, training_output, limit=100000000) # TODO: full dataset 100000000
-def _process_wikipedia_texts(kb, wp_to_id, training_output, limit=None):
+def _process_wikipedia_texts(wp_to_id, training_output, limit=None):
"""
Read the XML wikipedia data to parse out training data:
- raw text data + positive and negative instances
+ raw text data + positive instances
"""
title_regex = re.compile(r'(?<=
).*(?=)')
@@ -43,8 +41,9 @@ def _process_wikipedia_texts(kb, wp_to_id, training_output, limit=None):
_write_training_entity(outputfile=entityfile,
article_id="article_id",
alias="alias",
- entity="entity",
- correct="correct")
+ entity="WD_id",
+ start="start",
+ end="end")
with bz2.open(wp.ENWIKI_DUMP, mode='rb') as file:
line = file.readline()
@@ -75,14 +74,11 @@ def _process_wikipedia_texts(kb, wp_to_id, training_output, limit=None):
elif clean_line == "":
if article_id:
try:
- _process_wp_text(kb, wp_to_id, entityfile, article_id, article_text.strip(), training_output)
- # on a previous run, an error occurred after 46M lines and 2h
+ _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text.strip(), training_output)
except Exception as e:
print("Error processing article", article_id, article_title, e)
else:
- print("Done processing a page, but couldn't find an article_id ?")
- print(article_title)
- print(article_text)
+ print("Done processing a page, but couldn't find an article_id ?", article_title)
article_text = ""
article_title = None
article_id = None
@@ -122,7 +118,14 @@ def _process_wikipedia_texts(kb, wp_to_id, training_output, limit=None):
text_regex = re.compile(r'(?<=).*(?=", entity)
- candidates = kb.get_candidates(alias)
+ if open_read > 2:
+ reading_special_case = True
- # as training data, we only store entities that are sufficiently ambiguous
- if len(candidates) > 1:
- _write_training_article(article_id=article_id, clean_text=clean_text, training_output=training_output)
- # print("alias", alias)
+ if open_read == 2 and reading_text:
+ reading_text = False
+ reading_entity = True
+ reading_mention = False
- # print all incorrect candidates
- for c in candidates:
- if entity != c.entity_:
+ # we just finished reading an entity
+ if open_read == 0 and not reading_text:
+ if '#' in entity_buffer or entity_buffer.startswith(':'):
+ reading_special_case = True
+ # Ignore cases with nested structures like File: handles etc
+ if not reading_special_case:
+ if not mention_buffer:
+ mention_buffer = entity_buffer
+ start = len(final_text)
+ end = start + len(mention_buffer)
+ qid = wp_to_id.get(entity_buffer, None)
+ if qid:
_write_training_entity(outputfile=entityfile,
article_id=article_id,
- alias=alias,
- entity=c.entity_,
- correct="0")
+ alias=mention_buffer,
+ entity=qid,
+ start=start,
+ end=end)
+ found_entities = True
+ final_text += mention_buffer
- # print the one correct candidate
- _write_training_entity(outputfile=entityfile,
- article_id=article_id,
- alias=alias,
- entity=entity,
- correct="1")
+ entity_buffer = ""
+ mention_buffer = ""
- # print("gold entity", entity)
- # print()
+ reading_text = True
+ reading_entity = False
+ reading_mention = False
+ reading_special_case = False
- # _run_ner_depr(nlp, clean_text, article_dict)
- # print()
+ if found_entities:
+ _write_training_article(article_id=article_id, clean_text=final_text, training_output=training_output)
info_regex = re.compile(r'{[^{]*?}')
-interwiki_regex = re.compile(r'\[\[([^|]*?)]]')
-interwiki_2_regex = re.compile(r'\[\[[^|]*?\|([^|]*?)]]')
-htlm_regex = re.compile(r'<!--[^!]*-->')
+htlm_regex = re.compile(r'<!--[^-]*-->')
category_regex = re.compile(r'\[\[Category:[^\[]*]]')
file_regex = re.compile(r'\[\[File:[^[\]]+]]')
ref_regex = re.compile(r'<ref.*?>') # non-greedy
@@ -215,12 +242,6 @@ def _get_clean_wp_text(article_text):
try_again = False
previous_length = len(clean_text)
- # remove simple interwiki links (no alternative name)
- clean_text = interwiki_regex.sub(r'\1', clean_text)
-
- # remove simple interwiki links by picking the alternative name
- clean_text = interwiki_2_regex.sub(r'\1', clean_text)
-
# remove HTML comments
clean_text = htlm_regex.sub('', clean_text)
@@ -265,43 +286,34 @@ def _write_training_article(article_id, clean_text, training_output):
outputfile.write(clean_text)
-def _write_training_entity(outputfile, article_id, alias, entity, correct):
- outputfile.write(article_id + "|" + alias + "|" + entity + "|" + correct + "\n")
+def _write_training_entity(outputfile, article_id, alias, entity, start, end):
+ outputfile.write(article_id + "|" + alias + "|" + entity + "|" + str(start) + "|" + str(end) + "\n")
-def read_training_entities(training_output, collect_correct=True, collect_incorrect=False):
+def read_training_entities(training_output):
entityfile_loc = training_output + "/" + ENTITY_FILE
- incorrect_entries_per_article = dict()
- correct_entries_per_article = dict()
+ entries_per_article = dict()
+
with open(entityfile_loc, mode='r', encoding='utf8') as file:
for line in file:
fields = line.replace('\n', "").split(sep='|')
article_id = fields[0]
alias = fields[1]
- entity = fields[2]
- correct = fields[3]
+ wp_title = fields[2]
+ start = fields[3]
+ end = fields[4]
- if correct == "1" and collect_correct:
- entry_dict = correct_entries_per_article.get(article_id, dict())
- if alias in entry_dict:
- raise ValueError("Found alias", alias, "multiple times for article", article_id, "in", ENTITY_FILE)
- entry_dict[alias] = entity
- correct_entries_per_article[article_id] = entry_dict
+ entries_by_offset = entries_per_article.get(article_id, dict())
+ entries_by_offset[start + "-" + end] = (alias, wp_title)
- if correct == "0" and collect_incorrect:
- entry_dict = incorrect_entries_per_article.get(article_id, dict())
- entities = entry_dict.get(alias, set())
- entities.add(entity)
- entry_dict[alias] = entities
- incorrect_entries_per_article[article_id] = entry_dict
+ entries_per_article[article_id] = entries_by_offset
- return correct_entries_per_article, incorrect_entries_per_article
+ return entries_per_article
def read_training(nlp, training_dir, dev, limit, to_print):
- correct_entries, incorrect_entries = read_training_entities(training_output=training_dir,
- collect_correct=True,
- collect_incorrect=True)
+ # This method will provide training examples that correspond to the entity annotations found by the nlp object
+ entries_per_article = read_training_entities(training_output=training_dir)
data = []
@@ -320,36 +332,33 @@ def read_training(nlp, training_dir, dev, limit, to_print):
text = file.read()
article_doc = nlp(text)
+ entries_by_offset = entries_per_article.get(article_id, dict())
+
gold_entities = list()
+ for ent in article_doc.ents:
+ start = ent.start_char
+ end = ent.end_char
- # process all positive and negative entities, collect all relevant mentions in this article
- for mention, entity_pos in correct_entries[article_id].items():
- # find all matches in the doc for the mentions
- # TODO: fix this - doesn't look like all entities are found
- matcher = PhraseMatcher(nlp.vocab)
- patterns = list(nlp.tokenizer.pipe([mention]))
+ entity_tuple = entries_by_offset.get(str(start) + "-" + str(end), None)
+ if entity_tuple:
+ alias, wp_title = entity_tuple
+ if ent.text != alias:
+ print("Non-matching entity in", article_id, start, end)
+ else:
+ gold_entities.append((start, end, wp_title))
- matcher.add("TerminologyList", None, *patterns)
- matches = matcher(article_doc)
-
- # store gold entities
- for match_id, start, end in matches:
- gold_entities.append((start, end, entity_pos))
-
- gold = GoldParse(doc=article_doc, links=gold_entities)
- data.append((article_doc, gold))
+ if gold_entities:
+ gold = GoldParse(doc=article_doc, links=gold_entities)
+ data.append((article_doc, gold))
cnt += 1
except Exception as e:
print("Problem parsing article", article_id)
print(e)
+ raise e
if to_print:
print()
print("Processed", cnt, "training articles, dev=" + str(dev))
print()
return data
-
-
-
-
diff --git a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
index d5002e26f..faea93f53 100644
--- a/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
+++ b/examples/pipeline/wiki_entity_linking/wiki_nel_pipeline.py
@@ -30,8 +30,8 @@ TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
MAX_CANDIDATES = 10
MIN_PAIR_OCC = 5
-DOC_CHAR_CUTOFF = 300
-EPOCHS = 2
+DOC_SENT_CUTOFF = 2
+EPOCHS = 10
DROPOUT = 0.1
@@ -46,14 +46,14 @@ def run_pipeline():
# one-time methods to create KB and write to file
to_create_prior_probs = False
to_create_entity_counts = False
- to_create_kb = True
+ to_create_kb = False # TODO: entity_defs should also contain entities not in the KB
# read KB back in from file
- to_read_kb = True
- to_test_kb = True
+ to_read_kb = False
+ to_test_kb = False
# create training dataset
- create_wp_training = False
+ create_wp_training = True
# train the EL pipe
train_pipe = False
@@ -103,9 +103,6 @@ def run_pipeline():
# STEP 4 : read KB back in from file
if to_read_kb:
print("STEP 4: to_read_kb", datetime.datetime.now())
- # my_vocab = Vocab()
- # my_vocab.from_disk(VOCAB_DIR)
- # my_kb = KnowledgeBase(vocab=my_vocab, entity_vector_length=64)
nlp_2 = spacy.load(NLP_1_DIR)
kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH)
kb_2.load_bulk(KB_FILE)
@@ -121,13 +118,13 @@ def run_pipeline():
# STEP 5: create a training dataset from WP
if create_wp_training:
print("STEP 5: create training dataset", datetime.datetime.now())
- training_set_creator.create_training(kb=kb_2, entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR)
+ training_set_creator.create_training(entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR)
# STEP 6: create the entity linking pipe
if train_pipe:
print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
- train_limit = 10
- dev_limit = 5
+ train_limit = 50
+ dev_limit = 10
print("Training on", train_limit, "articles")
print("Dev testing on", dev_limit, "articles")
print()
@@ -144,7 +141,7 @@ def run_pipeline():
limit=dev_limit,
to_print=False)
- el_pipe = nlp_2.create_pipe(name='entity_linker', config={"doc_cutoff": DOC_CHAR_CUTOFF})
+ el_pipe = nlp_2.create_pipe(name='entity_linker', config={"doc_cutoff": DOC_SENT_CUTOFF})
el_pipe.set_kb(kb_2)
nlp_2.add_pipe(el_pipe, last=True)
@@ -199,10 +196,14 @@ def run_pipeline():
el_pipe.prior_weight = 0
dev_acc_1_0 = _measure_accuracy(dev_data, el_pipe)
train_acc_1_0 = _measure_accuracy(train_data, el_pipe)
-
print("train/dev acc context:", round(train_acc_1_0, 2), round(dev_acc_1_0, 2))
print()
+ # reset for follow-up tests
+ el_pipe.context_weight = 1
+ el_pipe.prior_weight = 1
+
+
if to_test_pipeline:
print()
print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now())
@@ -215,17 +216,10 @@ def run_pipeline():
print("STEP 9: testing NLP IO", datetime.datetime.now())
print()
print("writing to", NLP_2_DIR)
- print(" vocab len nlp_2", len(nlp_2.vocab))
- print(" vocab len kb_2", len(kb_2.vocab))
nlp_2.to_disk(NLP_2_DIR)
print()
print("reading from", NLP_2_DIR)
nlp_3 = spacy.load(NLP_2_DIR)
- print(" vocab len nlp_3", len(nlp_3.vocab))
-
- for pipe_name, pipe in nlp_3.pipeline:
- if pipe_name == "entity_linker":
- print(" vocab len kb_3", len(pipe.kb.vocab))
print()
print("running toy example with NLP 2")
@@ -253,9 +247,10 @@ def _measure_accuracy(data, el_pipe):
for ent in doc.ents:
if ent.label_ == "PERSON": # TODO: expand to other types
pred_entity = ent.kb_id_
- start = ent.start
- end = ent.end
+ start = ent.start_char
+ end = ent.end_char
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
+ # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None:
if gold_entity == pred_entity:
correct += 1
@@ -285,7 +280,8 @@ def run_el_toy_example(nlp):
print()
# Q4426480 is her husband, Q3568763 her tutor
- text = "Ada Lovelace loved her husband William King dearly. " \
+ text = "Ada Lovelace was the countess of Lovelace. She is known for her programming work on the analytical engine."\
+ "Ada Lovelace loved her husband William King dearly. " \
"Ada Lovelace was tutored by her favorite physics tutor William King."
doc = nlp(text)
diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx
index e73ff6a0e..5d82da7ee 100644
--- a/spacy/pipeline/pipes.pyx
+++ b/spacy/pipeline/pipes.pyx
@@ -1074,6 +1074,9 @@ class EntityLinker(Pipe):
@classmethod
def Model(cls, **cfg):
+ if "entity_width" not in cfg:
+ raise ValueError("entity_width not found")
+
embed_width = cfg.get("embed_width", 300)
hidden_width = cfg.get("hidden_width", 32)
article_width = cfg.get("article_width", 128)
@@ -1095,7 +1098,10 @@ class EntityLinker(Pipe):
self.mention_encoder = True
self.kb = None
self.cfg = dict(cfg)
- self.doc_cutoff = self.cfg.get("doc_cutoff", 150)
+ self.doc_cutoff = self.cfg.get("doc_cutoff", 5)
+ self.sgd_article = None
+ self.sgd_sent = None
+ self.sgd_mention = None
def set_kb(self, kb):
self.kb = kb
@@ -1126,6 +1132,12 @@ class EntityLinker(Pipe):
self.require_model()
self.require_kb()
+ if losses is not None:
+ losses.setdefault(self.name, 0.0)
+
+ if not docs or not golds:
+ return 0
+
if len(docs) != len(golds):
raise ValueError(Errors.E077.format(value="EL training", n_docs=len(docs),
n_golds=len(golds)))
@@ -1141,21 +1153,30 @@ class EntityLinker(Pipe):
for doc, gold in zip(docs, golds):
for entity in gold.links:
start, end, gold_kb = entity
- mention = doc[start:end]
- sentence = mention.sent
- first_par = doc[0:self.doc_cutoff].as_doc()
+ mention = doc.text[start:end]
+ sent_start = 0
+ sent_end = len(doc)
+ first_par_end = len(doc)
+ for index, sent in enumerate(doc.sents):
+ if start >= sent.start_char and end <= sent.end_char:
+ sent_start = sent.start
+ sent_end = sent.end
+ if index == self.doc_cutoff-1:
+ first_par_end = sent.end
+ sentence = doc[sent_start:sent_end].as_doc()
+ first_par = doc[0:first_par_end].as_doc()
- candidates = self.kb.get_candidates(mention.text)
+ candidates = self.kb.get_candidates(mention)
for c in candidates:
kb_id = c.entity_
- # TODO: currently only training on the positive instances
+ # Currently only training on the positive instances
if kb_id == gold_kb:
prior_prob = c.prior_prob
entity_encoding = c.entity_vector
entity_encodings.append(entity_encoding)
article_docs.append(first_par)
- sentence_docs.append(sentence.as_doc())
+ sentence_docs.append(sentence)
if len(entity_encodings) > 0:
doc_encodings, bp_doc = self.article_encoder.begin_update(article_docs, drop=drop)
@@ -1168,11 +1189,6 @@ class EntityLinker(Pipe):
entity_encodings = np.asarray(entity_encodings, dtype=np.float32)
loss, d_scores = self.get_loss(scores=mention_encodings, golds=entity_encodings, docs=None)
- # print("scores", mention_encodings)
- # print("golds", entity_encodings)
- # print("loss", loss)
- # print("d_scores", d_scores)
-
mention_gradient = bp_mention(d_scores, sgd=self.sgd_mention)
# gradient : concat (doc+sent) vs. desc
@@ -1187,7 +1203,6 @@ class EntityLinker(Pipe):
bp_sent(sent_gradients, sgd=self.sgd_sent)
if losses is not None:
- losses.setdefault(self.name, 0.0)
losses[self.name] += loss
return loss
@@ -1230,16 +1245,25 @@ class EntityLinker(Pipe):
self.require_model()
self.require_kb()
- if isinstance(docs, Doc):
- docs = [docs]
-
final_entities = list()
final_kb_ids = list()
- for i, article_doc in enumerate(docs):
- if len(article_doc) > 0:
- doc_encoding = self.article_encoder([article_doc])
- for ent in article_doc.ents:
+ if not docs:
+ return final_entities, final_kb_ids
+
+ if isinstance(docs, Doc):
+ docs = [docs]
+
+ for i, doc in enumerate(docs):
+ if len(doc) > 0:
+ first_par_end = len(doc)
+ for index, sent in enumerate(doc.sents):
+ if index == self.doc_cutoff-1:
+ first_par_end = sent.end
+ first_par = doc[0:first_par_end].as_doc()
+
+ doc_encoding = self.article_encoder([first_par])
+ for ent in doc.ents:
sent_doc = ent.sent.as_doc()
if len(sent_doc) > 0:
sent_encoding = self.sent_encoder([sent_doc])
@@ -1254,7 +1278,7 @@ class EntityLinker(Pipe):
prior_prob = c.prior_prob * self.prior_weight
kb_id = c.entity_
entity_encoding = c.entity_vector
- sim = cosine(np.asarray([entity_encoding]), mention_enc_t) * self.context_weight
+ sim = float(cosine(np.asarray([entity_encoding]), mention_enc_t)) * self.context_weight
score = prior_prob + sim - (prior_prob*sim) # put weights on the different factors ?
scores.append(score)
@@ -1271,36 +1295,7 @@ class EntityLinker(Pipe):
for token in entity:
token.ent_kb_id_ = kb_id
- def to_bytes(self, exclude=tuple(), **kwargs):
- """Serialize the pipe to a bytestring.
-
- exclude (list): String names of serialization fields to exclude.
- RETURNS (bytes): The serialized object.
- """
- serialize = OrderedDict()
- serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
- serialize["kb"] = self.kb.to_bytes # TODO
- if self.mention_encoder not in (True, False, None):
- serialize["article_encoder"] = self.article_encoder.to_bytes
- serialize["sent_encoder"] = self.sent_encoder.to_bytes
- serialize["mention_encoder"] = self.mention_encoder.to_bytes
- exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
- return util.to_bytes(serialize, exclude)
-
- def from_bytes(self, bytes_data, exclude=tuple(), **kwargs):
- """Load the pipe from a bytestring."""
- deserialize = OrderedDict()
- deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
- deserialize["kb"] = lambda b: self.kb.from_bytes(b) # TODO
- deserialize["article_encoder"] = lambda b: self.article_encoder.from_bytes(b)
- deserialize["sent_encoder"] = lambda b: self.sent_encoder.from_bytes(b)
- deserialize["mention_encoder"] = lambda b: self.mention_encoder.from_bytes(b)
- exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
- util.from_bytes(bytes_data, deserialize, exclude)
- return self
-
def to_disk(self, path, exclude=tuple(), **kwargs):
- """Serialize the pipe to disk."""
serialize = OrderedDict()
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
serialize["kb"] = lambda p: self.kb.dump(p)
@@ -1312,7 +1307,6 @@ class EntityLinker(Pipe):
util.to_disk(path, serialize, exclude)
def from_disk(self, path, exclude=tuple(), **kwargs):
- """Load the pipe from disk."""
def load_article_encoder(p):
if self.article_encoder is True:
self.article_encoder, _, _ = self.Model(**self.cfg)