diff --git a/bin/wiki_entity_linking/kb_creator.py b/bin/wiki_entity_linking/kb_creator.py
index e8e081cef..5b25475b2 100644
--- a/bin/wiki_entity_linking/kb_creator.py
+++ b/bin/wiki_entity_linking/kb_creator.py
@@ -13,9 +13,17 @@ INPUT_DIM = 300 # dimension of pre-trained input vectors
DESC_WIDTH = 64 # dimension of output entity vectors
-def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
- entity_def_output, entity_descr_output,
- count_input, prior_prob_input, wikidata_input):
+def create_kb(
+ nlp,
+ max_entities_per_alias,
+ min_entity_freq,
+ min_occ,
+ entity_def_output,
+ entity_descr_output,
+ count_input,
+ prior_prob_input,
+ wikidata_input,
+):
# Create the knowledge base from Wikidata entries
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH)
@@ -28,7 +36,9 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
title_to_id, id_to_descr = wd.read_wikidata_entities_json(wikidata_input)
# write the title-ID and ID-description mappings to file
- _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr)
+ _write_entity_files(
+ entity_def_output, entity_descr_output, title_to_id, id_to_descr
+ )
else:
# read the mappings from file
@@ -54,8 +64,8 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
frequency_list.append(freq)
filtered_title_to_id[title] = entity
- print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()),
- "titles with filter frequency", min_entity_freq)
+ print(len(title_to_id.keys()), "original titles")
+ print("kept", len(filtered_title_to_id.keys()), " with frequency", min_entity_freq)
print()
print(" * train entity encoder", datetime.datetime.now())
@@ -70,14 +80,20 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
print()
print(" * adding", len(entity_list), "entities", datetime.datetime.now())
- kb.set_entities(entity_list=entity_list, prob_list=frequency_list, vector_list=embeddings)
+ kb.set_entities(
+ entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
+ )
print()
print(" * adding aliases", datetime.datetime.now())
print()
- _add_aliases(kb, title_to_id=filtered_title_to_id,
- max_entities_per_alias=max_entities_per_alias, min_occ=min_occ,
- prior_prob_input=prior_prob_input)
+ _add_aliases(
+ kb,
+ title_to_id=filtered_title_to_id,
+ max_entities_per_alias=max_entities_per_alias,
+ min_occ=min_occ,
+ prior_prob_input=prior_prob_input,
+ )
print()
print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases())
@@ -86,13 +102,15 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
return kb
-def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr):
- with open(entity_def_output, mode='w', encoding='utf8') as id_file:
+def _write_entity_files(
+ entity_def_output, entity_descr_output, title_to_id, id_to_descr
+):
+ with entity_def_output.open("w", encoding="utf8") as id_file:
id_file.write("WP_title" + "|" + "WD_id" + "\n")
for title, qid in title_to_id.items():
id_file.write(title + "|" + str(qid) + "\n")
- with open(entity_descr_output, mode='w', encoding='utf8') as descr_file:
+ with entity_descr_output.open("w", encoding="utf8") as descr_file:
descr_file.write("WD_id" + "|" + "description" + "\n")
for qid, descr in id_to_descr.items():
descr_file.write(str(qid) + "|" + descr + "\n")
@@ -100,8 +118,8 @@ def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_
def get_entity_to_id(entity_def_output):
entity_to_id = dict()
- with open(entity_def_output, 'r', encoding='utf8') as csvfile:
- csvreader = csv.reader(csvfile, delimiter='|')
+ with entity_def_output.open("r", encoding="utf8") as csvfile:
+ csvreader = csv.reader(csvfile, delimiter="|")
# skip header
next(csvreader)
for row in csvreader:
@@ -111,8 +129,8 @@ def get_entity_to_id(entity_def_output):
def get_id_to_description(entity_descr_output):
id_to_desc = dict()
- with open(entity_descr_output, 'r', encoding='utf8') as csvfile:
- csvreader = csv.reader(csvfile, delimiter='|')
+ with entity_descr_output.open("r", encoding="utf8") as csvfile:
+ csvreader = csv.reader(csvfile, delimiter="|")
# skip header
next(csvreader)
for row in csvreader:
@@ -125,7 +143,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
# adding aliases with prior probabilities
# we can read this file sequentially, it's sorted by alias, and then by count
- with open(prior_prob_input, mode='r', encoding='utf8') as prior_file:
+ with prior_prob_input.open("r", encoding="utf8") as prior_file:
# skip header
prior_file.readline()
line = prior_file.readline()
@@ -134,7 +152,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
counts = []
entities = []
while line:
- splits = line.replace('\n', "").split(sep='|')
+ splits = line.replace("\n", "").split(sep="|")
new_alias = splits[0]
count = int(splits[1])
entity = splits[2]
@@ -153,7 +171,11 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
if selected_entities:
try:
- kb.add_alias(alias=previous_alias, entities=selected_entities, probabilities=prior_probs)
+ kb.add_alias(
+ alias=previous_alias,
+ entities=selected_entities,
+ probabilities=prior_probs,
+ )
except ValueError as e:
print(e)
total_count = 0
@@ -168,4 +190,3 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
previous_alias = new_alias
line = prior_file.readline()
-
diff --git a/bin/wiki_entity_linking/training_set_creator.py b/bin/wiki_entity_linking/training_set_creator.py
index 5d401bb3f..b090d7659 100644
--- a/bin/wiki_entity_linking/training_set_creator.py
+++ b/bin/wiki_entity_linking/training_set_creator.py
@@ -1,7 +1,7 @@
# coding: utf-8
from __future__ import unicode_literals
-import os
+import random
import re
import bz2
import datetime
@@ -17,6 +17,10 @@ Gold-standard entities are stored in one file in standoff format (by character o
ENTITY_FILE = "gold_entities.csv"
+def now():
+ return datetime.datetime.now()
+
+
def create_training(wikipedia_input, entity_def_input, training_output):
wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
_process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None)
@@ -27,21 +31,23 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
Read the XML wikipedia data to parse out training data:
raw text data + positive instances
"""
- title_regex = re.compile(r'(?<=
).*(?=)')
- id_regex = re.compile(r'(?<=)\d*(?=)')
+ title_regex = re.compile(r"(?<=).*(?=)")
+ id_regex = re.compile(r"(?<=)\d*(?=)")
read_ids = set()
entityfile_loc = training_output / ENTITY_FILE
- with open(entityfile_loc, mode="w", encoding='utf8') as entityfile:
+ with entityfile_loc.open("w", encoding="utf8") as entityfile:
# write entity training header file
- _write_training_entity(outputfile=entityfile,
- article_id="article_id",
- alias="alias",
- entity="WD_id",
- start="start",
- end="end")
+ _write_training_entity(
+ outputfile=entityfile,
+ article_id="article_id",
+ alias="alias",
+ entity="WD_id",
+ start="start",
+ end="end",
+ )
- with bz2.open(wikipedia_input, mode='rb') as file:
+ with bz2.open(wikipedia_input, mode="rb") as file:
line = file.readline()
cnt = 0
article_text = ""
@@ -51,7 +57,7 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
reading_revision = False
while line and (not limit or cnt < limit):
if cnt % 1000000 == 0:
- print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump")
+ print(now(), "processed", cnt, "lines of Wikipedia dump")
clean_line = line.strip().decode("utf-8")
if clean_line == "":
@@ -69,12 +75,23 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
elif clean_line == "":
if article_id:
try:
- _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text.strip(),
- training_output)
+ _process_wp_text(
+ wp_to_id,
+ entityfile,
+ article_id,
+ article_title,
+ article_text.strip(),
+ training_output,
+ )
except Exception as e:
- print("Error processing article", article_id, article_title, e)
+ print(
+ "Error processing article", article_id, article_title, e
+ )
else:
- print("Done processing a page, but couldn't find an article_id ?", article_title)
+ print(
+ "Done processing a page, but couldn't find an article_id ?",
+ article_title,
+ )
article_text = ""
article_title = None
article_id = None
@@ -98,7 +115,9 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
if ids:
article_id = ids[0]
if article_id in read_ids:
- print("Found duplicate article ID", article_id, clean_line) # This should never happen ...
+ print(
+ "Found duplicate article ID", article_id, clean_line
+ ) # This should never happen ...
read_ids.add(article_id)
# read the title of this article (outside the revision portion of the document)
@@ -111,10 +130,12 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
cnt += 1
-text_regex = re.compile(r'(?<=).*(?=).*(?= 2:
reading_special_case = True
@@ -175,7 +196,7 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
# we just finished reading an entity
if open_read == 0 and not reading_text:
- if '#' in entity_buffer or entity_buffer.startswith(':'):
+ if "#" in entity_buffer or entity_buffer.startswith(":"):
reading_special_case = True
# Ignore cases with nested structures like File: handles etc
if not reading_special_case:
@@ -185,12 +206,14 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
end = start + len(mention_buffer)
qid = wp_to_id.get(entity_buffer, None)
if qid:
- _write_training_entity(outputfile=entityfile,
- article_id=article_id,
- alias=mention_buffer,
- entity=qid,
- start=start,
- end=end)
+ _write_training_entity(
+ outputfile=entityfile,
+ article_id=article_id,
+ alias=mention_buffer,
+ entity=qid,
+ start=start,
+ end=end,
+ )
found_entities = True
final_text += mention_buffer
@@ -203,29 +226,35 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
reading_special_case = False
if found_entities:
- _write_training_article(article_id=article_id, clean_text=final_text, training_output=training_output)
+ _write_training_article(
+ article_id=article_id,
+ clean_text=final_text,
+ training_output=training_output,
+ )
-info_regex = re.compile(r'{[^{]*?}')
-htlm_regex = re.compile(r'<!--[^-]*-->')
-category_regex = re.compile(r'\[\[Category:[^\[]*]]')
-file_regex = re.compile(r'\[\[File:[^[\]]+]]')
-ref_regex = re.compile(r'<ref.*?>') # non-greedy
-ref_2_regex = re.compile(r'</ref.*?>') # non-greedy
+info_regex = re.compile(r"{[^{]*?}")
+htlm_regex = re.compile(r"<!--[^-]*-->")
+category_regex = re.compile(r"\[\[Category:[^\[]*]]")
+file_regex = re.compile(r"\[\[File:[^[\]]+]]")
+ref_regex = re.compile(r"<ref.*?>") # non-greedy
+ref_2_regex = re.compile(r"</ref.*?>") # non-greedy
def _get_clean_wp_text(article_text):
clean_text = article_text.strip()
# remove bolding & italic markup
- clean_text = clean_text.replace('\'\'\'', '')
- clean_text = clean_text.replace('\'\'', '')
+ clean_text = clean_text.replace("'''", "")
+ clean_text = clean_text.replace("''", "")
# remove nested {{info}} statements by removing the inner/smallest ones first and iterating
try_again = True
previous_length = len(clean_text)
while try_again:
- clean_text = info_regex.sub('', clean_text) # non-greedy match excluding a nested {
+ clean_text = info_regex.sub(
+ "", clean_text
+ ) # non-greedy match excluding a nested {
if len(clean_text) < previous_length:
try_again = True
else:
@@ -233,14 +262,14 @@ def _get_clean_wp_text(article_text):
previous_length = len(clean_text)
# remove HTML comments
- clean_text = htlm_regex.sub('', clean_text)
+ clean_text = htlm_regex.sub("", clean_text)
# remove Category and File statements
- clean_text = category_regex.sub('', clean_text)
- clean_text = file_regex.sub('', clean_text)
+ clean_text = category_regex.sub("", clean_text)
+ clean_text = file_regex.sub("", clean_text)
# remove multiple =
- while '==' in clean_text:
+ while "==" in clean_text:
clean_text = clean_text.replace("==", "=")
clean_text = clean_text.replace(". =", ".")
@@ -249,43 +278,47 @@ def _get_clean_wp_text(article_text):
clean_text = clean_text.replace(" =", "")
# remove refs (non-greedy match)
- clean_text = ref_regex.sub('', clean_text)
- clean_text = ref_2_regex.sub('', clean_text)
+ clean_text = ref_regex.sub("", clean_text)
+ clean_text = ref_2_regex.sub("", clean_text)
# remove additional wikiformatting
- clean_text = re.sub(r'<blockquote>', '', clean_text)
- clean_text = re.sub(r'</blockquote>', '', clean_text)
+ clean_text = re.sub(r"<blockquote>", "", clean_text)
+ clean_text = re.sub(r"</blockquote>", "", clean_text)
# change special characters back to normal ones
- clean_text = clean_text.replace(r'<', '<')
- clean_text = clean_text.replace(r'>', '>')
- clean_text = clean_text.replace(r'"', '"')
- clean_text = clean_text.replace(r' ', ' ')
- clean_text = clean_text.replace(r'&', '&')
+ clean_text = clean_text.replace(r"<", "<")
+ clean_text = clean_text.replace(r">", ">")
+ clean_text = clean_text.replace(r""", '"')
+ clean_text = clean_text.replace(r" ", " ")
+ clean_text = clean_text.replace(r"&", "&")
# remove multiple spaces
- while ' ' in clean_text:
- clean_text = clean_text.replace(' ', ' ')
+ while " " in clean_text:
+ clean_text = clean_text.replace(" ", " ")
return clean_text.strip()
def _write_training_article(article_id, clean_text, training_output):
- file_loc = training_output / str(article_id) + ".txt"
- with open(file_loc, mode='w', encoding='utf8') as outputfile:
+ file_loc = training_output / "{}.txt".format(article_id)
+ with file_loc.open("w", encoding="utf8") as outputfile:
outputfile.write(clean_text)
def _write_training_entity(outputfile, article_id, alias, entity, start, end):
- outputfile.write(article_id + "|" + alias + "|" + entity + "|" + str(start) + "|" + str(end) + "\n")
+ line = "{}|{}|{}|{}|{}\n".format(article_id, alias, entity, start, end)
+ outputfile.write(line)
def is_dev(article_id):
return article_id.endswith("3")
-def read_training(nlp, training_dir, dev, limit):
- # This method provides training examples that correspond to the entity annotations found by the nlp object
+def read_training(nlp, training_dir, dev, limit, kb=None):
+ """ This method provides training examples that correspond to the entity annotations found by the nlp object.
+ When kb is provided (for training), it will include negative training examples by using the candidate generator,
+ and it will only keep positive training examples that can be found in the KB.
+ When kb=None (for testing), it will include all positive examples only."""
entityfile_loc = training_dir / ENTITY_FILE
data = []
@@ -296,24 +329,30 @@ def read_training(nlp, training_dir, dev, limit):
skip_articles = set()
total_entities = 0
- with open(entityfile_loc, mode='r', encoding='utf8') as file:
+ with entityfile_loc.open("r", encoding="utf8") as file:
for line in file:
if not limit or len(data) < limit:
- fields = line.replace('\n', "").split(sep='|')
+ fields = line.replace("\n", "").split(sep="|")
article_id = fields[0]
alias = fields[1]
- wp_title = fields[2]
+ wd_id = fields[2]
start = fields[3]
end = fields[4]
- if dev == is_dev(article_id) and article_id != "article_id" and article_id not in skip_articles:
+ if (
+ dev == is_dev(article_id)
+ and article_id != "article_id"
+ and article_id not in skip_articles
+ ):
if not current_doc or (current_article_id != article_id):
# parse the new article text
file_name = article_id + ".txt"
try:
- with open(os.path.join(training_dir, file_name), mode="r", encoding='utf8') as f:
+ training_file = training_dir / file_name
+ with training_file.open("r", encoding="utf8") as f:
text = f.read()
- if len(text) < 30000: # threshold for convenience / speed of processing
+ # threshold for convenience / speed of processing
+ if len(text) < 30000:
current_doc = nlp(text)
current_article_id = article_id
ents_by_offset = dict()
@@ -321,33 +360,69 @@ def read_training(nlp, training_dir, dev, limit):
sent_length = len(ent.sent)
# custom filtering to avoid too long or too short sentences
if 5 < sent_length < 100:
- ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
+ offset = "{}_{}".format(
+ ent.start_char, ent.end_char
+ )
+ ents_by_offset[offset] = ent
else:
skip_articles.add(article_id)
current_doc = None
except Exception as e:
print("Problem parsing article", article_id, e)
skip_articles.add(article_id)
- raise e
# repeat checking this condition in case an exception was thrown
if current_doc and (current_article_id == article_id):
- found_ent = ents_by_offset.get(start + "_" + end, None)
+ offset = "{}_{}".format(start, end)
+ found_ent = ents_by_offset.get(offset, None)
if found_ent:
if found_ent.text != alias:
skip_articles.add(article_id)
current_doc = None
else:
sent = found_ent.sent.as_doc()
- # currently feeding the gold data one entity per sentence at a time
+
gold_start = int(start) - found_ent.sent.start_char
gold_end = int(end) - found_ent.sent.start_char
- gold_entities = [(gold_start, gold_end, wp_title)]
- gold = GoldParse(doc=sent, links=gold_entities)
- data.append((sent, gold))
- total_entities += 1
- if len(data) % 2500 == 0:
- print(" -read", total_entities, "entities")
+
+ gold_entities = {}
+ found_useful = False
+ for ent in sent.ents:
+ entry = (ent.start_char, ent.end_char)
+ gold_entry = (gold_start, gold_end)
+ if entry == gold_entry:
+ # add both pos and neg examples (in random order)
+ # this will exclude examples not in the KB
+ if kb:
+ value_by_id = {}
+ candidates = kb.get_candidates(alias)
+ candidate_ids = [
+ c.entity_ for c in candidates
+ ]
+ random.shuffle(candidate_ids)
+ for kb_id in candidate_ids:
+ found_useful = True
+ if kb_id != wd_id:
+ value_by_id[kb_id] = 0.0
+ else:
+ value_by_id[kb_id] = 1.0
+ gold_entities[entry] = value_by_id
+ # if no KB, keep all positive examples
+ else:
+ found_useful = True
+ value_by_id = {wd_id: 1.0}
+
+ gold_entities[entry] = value_by_id
+ # currently feeding the gold data one entity per sentence at a time
+ # setting all other entities to empty gold dictionary
+ else:
+ gold_entities[entry] = {}
+ if found_useful:
+ gold = GoldParse(doc=sent, links=gold_entities)
+ data.append((sent, gold))
+ total_entities += 1
+ if len(data) % 2500 == 0:
+ print(" -read", total_entities, "entities")
print(" -read", total_entities, "entities")
return data
diff --git a/bin/wiki_entity_linking/wikipedia_processor.py b/bin/wiki_entity_linking/wikipedia_processor.py
index c02e472bc..80d75b013 100644
--- a/bin/wiki_entity_linking/wikipedia_processor.py
+++ b/bin/wiki_entity_linking/wikipedia_processor.py
@@ -14,22 +14,97 @@ Write these results to file for downstream KB and training data generation.
map_alias_to_link = dict()
# these will/should be matched ignoring case
-wiki_namespaces = ["b", "betawikiversity", "Book", "c", "Category", "Commons",
- "d", "dbdump", "download", "Draft", "Education", "Foundation",
- "Gadget", "Gadget definition", "gerrit", "File", "Help", "Image", "Incubator",
- "m", "mail", "mailarchive", "media", "MediaWiki", "MediaWiki talk", "Mediawikiwiki",
- "MediaZilla", "Meta", "Metawikipedia", "Module",
- "mw", "n", "nost", "oldwikisource", "outreach", "outreachwiki", "otrs", "OTRSwiki",
- "Portal", "phab", "Phabricator", "Project", "q", "quality", "rev",
- "s", "spcom", "Special", "species", "Strategy", "sulutil", "svn",
- "Talk", "Template", "Template talk", "Testwiki", "ticket", "TimedText", "Toollabs", "tools",
- "tswiki", "User", "User talk", "v", "voy",
- "w", "Wikibooks", "Wikidata", "wikiHow", "Wikinvest", "wikilivres", "Wikimedia", "Wikinews",
- "Wikipedia", "Wikipedia talk", "Wikiquote", "Wikisource", "Wikispecies", "Wikitech",
- "Wikiversity", "Wikivoyage", "wikt", "wiktionary", "wmf", "wmania", "WP"]
+wiki_namespaces = [
+ "b",
+ "betawikiversity",
+ "Book",
+ "c",
+ "Category",
+ "Commons",
+ "d",
+ "dbdump",
+ "download",
+ "Draft",
+ "Education",
+ "Foundation",
+ "Gadget",
+ "Gadget definition",
+ "gerrit",
+ "File",
+ "Help",
+ "Image",
+ "Incubator",
+ "m",
+ "mail",
+ "mailarchive",
+ "media",
+ "MediaWiki",
+ "MediaWiki talk",
+ "Mediawikiwiki",
+ "MediaZilla",
+ "Meta",
+ "Metawikipedia",
+ "Module",
+ "mw",
+ "n",
+ "nost",
+ "oldwikisource",
+ "outreach",
+ "outreachwiki",
+ "otrs",
+ "OTRSwiki",
+ "Portal",
+ "phab",
+ "Phabricator",
+ "Project",
+ "q",
+ "quality",
+ "rev",
+ "s",
+ "spcom",
+ "Special",
+ "species",
+ "Strategy",
+ "sulutil",
+ "svn",
+ "Talk",
+ "Template",
+ "Template talk",
+ "Testwiki",
+ "ticket",
+ "TimedText",
+ "Toollabs",
+ "tools",
+ "tswiki",
+ "User",
+ "User talk",
+ "v",
+ "voy",
+ "w",
+ "Wikibooks",
+ "Wikidata",
+ "wikiHow",
+ "Wikinvest",
+ "wikilivres",
+ "Wikimedia",
+ "Wikinews",
+ "Wikipedia",
+ "Wikipedia talk",
+ "Wikiquote",
+ "Wikisource",
+ "Wikispecies",
+ "Wikitech",
+ "Wikiversity",
+ "Wikivoyage",
+ "wikt",
+ "wiktionary",
+ "wmf",
+ "wmania",
+ "WP",
+]
# find the links
-link_regex = re.compile(r'\[\[[^\[\]]*\]\]')
+link_regex = re.compile(r"\[\[[^\[\]]*\]\]")
# match on interwiki links, e.g. `en:` or `:fr:`
ns_regex = r":?" + "[a-z][a-z]" + ":"
@@ -41,18 +116,22 @@ for ns in wiki_namespaces:
ns_regex = re.compile(ns_regex, re.IGNORECASE)
-def read_wikipedia_prior_probs(wikipedia_input, prior_prob_output):
+def now():
+ return datetime.datetime.now()
+
+
+def read_prior_probs(wikipedia_input, prior_prob_output):
"""
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
The full file takes about 2h to parse 1100M lines.
It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from.
"""
- with bz2.open(wikipedia_input, mode='rb') as file:
+ with bz2.open(wikipedia_input, mode="rb") as file:
line = file.readline()
cnt = 0
while line:
if cnt % 5000000 == 0:
- print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump")
+ print(now(), "processed", cnt, "lines of Wikipedia dump")
clean_line = line.strip().decode("utf-8")
aliases, entities, normalizations = get_wp_links(clean_line)
@@ -64,10 +143,11 @@ def read_wikipedia_prior_probs(wikipedia_input, prior_prob_output):
cnt += 1
# write all aliases and their entities and count occurrences to file
- with open(prior_prob_output, mode='w', encoding='utf8') as outputfile:
+ with prior_prob_output.open("w", encoding="utf8") as outputfile:
outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n")
for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]):
- for entity, count in sorted(alias_dict.items(), key=lambda x: x[1], reverse=True):
+ s_dict = sorted(alias_dict.items(), key=lambda x: x[1], reverse=True)
+ for entity, count in s_dict:
outputfile.write(alias + "|" + str(count) + "|" + entity + "\n")
@@ -140,13 +220,13 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
entity_to_count = dict()
total_count = 0
- with open(prior_prob_input, mode='r', encoding='utf8') as prior_file:
+ with prior_prob_input.open("r", encoding="utf8") as prior_file:
# skip header
prior_file.readline()
line = prior_file.readline()
while line:
- splits = line.replace('\n', "").split(sep='|')
+ splits = line.replace("\n", "").split(sep="|")
# alias = splits[0]
count = int(splits[1])
entity = splits[2]
@@ -158,7 +238,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
line = prior_file.readline()
- with open(count_output, mode='w', encoding='utf8') as entity_file:
+ with count_output.open("w", encoding="utf8") as entity_file:
entity_file.write("entity" + "|" + "count" + "\n")
for entity, count in entity_to_count.items():
entity_file.write(entity + "|" + str(count) + "\n")
@@ -171,12 +251,11 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
def get_all_frequencies(count_input):
entity_to_count = dict()
- with open(count_input, 'r', encoding='utf8') as csvfile:
- csvreader = csv.reader(csvfile, delimiter='|')
+ with count_input.open("r", encoding="utf8") as csvfile:
+ csvreader = csv.reader(csvfile, delimiter="|")
# skip header
next(csvreader)
for row in csvreader:
entity_to_count[row[0]] = int(row[1])
return entity_to_count
-
diff --git a/examples/pipeline/dummy_entity_linking.py b/examples/pipeline/dummy_entity_linking.py
index 0e59db304..6dde616b8 100644
--- a/examples/pipeline/dummy_entity_linking.py
+++ b/examples/pipeline/dummy_entity_linking.py
@@ -14,15 +14,15 @@ def create_kb(vocab):
# adding entities
entity_0 = "Q1004791_Douglas"
print("adding entity", entity_0)
- kb.add_entity(entity=entity_0, prob=0.5, entity_vector=[0])
+ kb.add_entity(entity=entity_0, freq=0.5, entity_vector=[0])
entity_1 = "Q42_Douglas_Adams"
print("adding entity", entity_1)
- kb.add_entity(entity=entity_1, prob=0.5, entity_vector=[1])
+ kb.add_entity(entity=entity_1, freq=0.5, entity_vector=[1])
entity_2 = "Q5301561_Douglas_Haig"
print("adding entity", entity_2)
- kb.add_entity(entity=entity_2, prob=0.5, entity_vector=[2])
+ kb.add_entity(entity=entity_2, freq=0.5, entity_vector=[2])
# adding aliases
print()
diff --git a/examples/pipeline/wikidata_entity_linking.py b/examples/pipeline/wikidata_entity_linking.py
index 17c2976dd..04e5bce6d 100644
--- a/examples/pipeline/wikidata_entity_linking.py
+++ b/examples/pipeline/wikidata_entity_linking.py
@@ -1,11 +1,14 @@
# coding: utf-8
from __future__ import unicode_literals
+import os
+from os import path
import random
import datetime
from pathlib import Path
-from bin.wiki_entity_linking import training_set_creator, kb_creator, wikipedia_processor as wp
+from bin.wiki_entity_linking import wikipedia_processor as wp
+from bin.wiki_entity_linking import training_set_creator, kb_creator
from bin.wiki_entity_linking.kb_creator import DESC_WIDTH
import spacy
@@ -17,23 +20,26 @@ Demonstrate how to build a knowledge base from WikiData and run an Entity Linkin
"""
ROOT_DIR = Path("C:/Users/Sofie/Documents/data/")
-OUTPUT_DIR = ROOT_DIR / 'wikipedia'
-TRAINING_DIR = OUTPUT_DIR / 'training_data_nel'
+OUTPUT_DIR = ROOT_DIR / "wikipedia"
+TRAINING_DIR = OUTPUT_DIR / "training_data_nel"
-PRIOR_PROB = OUTPUT_DIR / 'prior_prob.csv'
-ENTITY_COUNTS = OUTPUT_DIR / 'entity_freq.csv'
-ENTITY_DEFS = OUTPUT_DIR / 'entity_defs.csv'
-ENTITY_DESCR = OUTPUT_DIR / 'entity_descriptions.csv'
+PRIOR_PROB = OUTPUT_DIR / "prior_prob.csv"
+ENTITY_COUNTS = OUTPUT_DIR / "entity_freq.csv"
+ENTITY_DEFS = OUTPUT_DIR / "entity_defs.csv"
+ENTITY_DESCR = OUTPUT_DIR / "entity_descriptions.csv"
-KB_FILE = OUTPUT_DIR / 'kb_1' / 'kb'
-NLP_1_DIR = OUTPUT_DIR / 'nlp_1'
-NLP_2_DIR = OUTPUT_DIR / 'nlp_2'
+KB_DIR = OUTPUT_DIR / "kb_1"
+KB_FILE = "kb"
+NLP_1_DIR = OUTPUT_DIR / "nlp_1"
+NLP_2_DIR = OUTPUT_DIR / "nlp_2"
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
-WIKIDATA_JSON = ROOT_DIR / 'wikidata' / 'wikidata-20190304-all.json.bz2'
+WIKIDATA_JSON = ROOT_DIR / "wikidata" / "wikidata-20190304-all.json.bz2"
# get enwiki-latest-pages-articles-multistream.xml.bz2 from https://dumps.wikimedia.org/enwiki/latest/
-ENWIKI_DUMP = ROOT_DIR / 'wikipedia' / 'enwiki-20190320-pages-articles-multistream.xml.bz2'
+ENWIKI_DUMP = (
+ ROOT_DIR / "wikipedia" / "enwiki-20190320-pages-articles-multistream.xml.bz2"
+)
# KB construction parameters
MAX_CANDIDATES = 10
@@ -48,11 +54,15 @@ L2 = 1e-6
CONTEXT_WIDTH = 128
+def now():
+ return datetime.datetime.now()
+
+
def run_pipeline():
# set the appropriate booleans to define which parts of the pipeline should be re(run)
- print("START", datetime.datetime.now())
+ print("START", now())
print()
- nlp_1 = spacy.load('en_core_web_lg')
+ nlp_1 = spacy.load("en_core_web_lg")
nlp_2 = None
kb_2 = None
@@ -82,43 +92,48 @@ def run_pipeline():
# STEP 1 : create prior probabilities from WP (run only once)
if to_create_prior_probs:
- print("STEP 1: to_create_prior_probs", datetime.datetime.now())
- wp.read_wikipedia_prior_probs(wikipedia_input=ENWIKI_DUMP, prior_prob_output=PRIOR_PROB)
+ print("STEP 1: to_create_prior_probs", now())
+ wp.read_prior_probs(ENWIKI_DUMP, PRIOR_PROB)
print()
# STEP 2 : deduce entity frequencies from WP (run only once)
if to_create_entity_counts:
- print("STEP 2: to_create_entity_counts", datetime.datetime.now())
- wp.write_entity_counts(prior_prob_input=PRIOR_PROB, count_output=ENTITY_COUNTS, to_print=False)
+ print("STEP 2: to_create_entity_counts", now())
+ wp.write_entity_counts(PRIOR_PROB, ENTITY_COUNTS, to_print=False)
print()
# STEP 3 : create KB and write to file (run only once)
if to_create_kb:
- print("STEP 3a: to_create_kb", datetime.datetime.now())
- kb_1 = kb_creator.create_kb(nlp_1,
- max_entities_per_alias=MAX_CANDIDATES,
- min_entity_freq=MIN_ENTITY_FREQ,
- min_occ=MIN_PAIR_OCC,
- entity_def_output=ENTITY_DEFS,
- entity_descr_output=ENTITY_DESCR,
- count_input=ENTITY_COUNTS,
- prior_prob_input=PRIOR_PROB,
- wikidata_input=WIKIDATA_JSON)
+ print("STEP 3a: to_create_kb", now())
+ kb_1 = kb_creator.create_kb(
+ nlp=nlp_1,
+ max_entities_per_alias=MAX_CANDIDATES,
+ min_entity_freq=MIN_ENTITY_FREQ,
+ min_occ=MIN_PAIR_OCC,
+ entity_def_output=ENTITY_DEFS,
+ entity_descr_output=ENTITY_DESCR,
+ count_input=ENTITY_COUNTS,
+ prior_prob_input=PRIOR_PROB,
+ wikidata_input=WIKIDATA_JSON,
+ )
print("kb entities:", kb_1.get_size_entities())
print("kb aliases:", kb_1.get_size_aliases())
print()
- print("STEP 3b: write KB and NLP", datetime.datetime.now())
- kb_1.dump(KB_FILE)
+ print("STEP 3b: write KB and NLP", now())
+
+ if not path.exists(KB_DIR):
+ os.makedirs(KB_DIR)
+ kb_1.dump(KB_DIR / KB_FILE)
nlp_1.to_disk(NLP_1_DIR)
print()
# STEP 4 : read KB back in from file
if to_read_kb:
- print("STEP 4: to_read_kb", datetime.datetime.now())
+ print("STEP 4: to_read_kb", now())
nlp_2 = spacy.load(NLP_1_DIR)
kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH)
- kb_2.load_bulk(KB_FILE)
+ kb_2.load_bulk(KB_DIR / KB_FILE)
print("kb entities:", kb_2.get_size_entities())
print("kb aliases:", kb_2.get_size_aliases())
print()
@@ -130,20 +145,26 @@ def run_pipeline():
# STEP 5: create a training dataset from WP
if create_wp_training:
- print("STEP 5: create training dataset", datetime.datetime.now())
- training_set_creator.create_training(wikipedia_input=ENWIKI_DUMP,
- entity_def_input=ENTITY_DEFS,
- training_output=TRAINING_DIR)
+ print("STEP 5: create training dataset", now())
+ training_set_creator.create_training(
+ wikipedia_input=ENWIKI_DUMP,
+ entity_def_input=ENTITY_DEFS,
+ training_output=TRAINING_DIR,
+ )
# STEP 6: create and train the entity linking pipe
if train_pipe:
- print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
+ print("STEP 6: training Entity Linking pipe", now())
type_to_int = {label: i for i, label in enumerate(nlp_2.entity.labels)}
print(" -analysing", len(type_to_int), "different entity types")
- el_pipe = nlp_2.create_pipe(name='entity_linker',
- config={"context_width": CONTEXT_WIDTH,
- "pretrained_vectors": nlp_2.vocab.vectors.name,
- "type_to_int": type_to_int})
+ el_pipe = nlp_2.create_pipe(
+ name="entity_linker",
+ config={
+ "context_width": CONTEXT_WIDTH,
+ "pretrained_vectors": nlp_2.vocab.vectors.name,
+ "type_to_int": type_to_int,
+ },
+ )
el_pipe.set_kb(kb_2)
nlp_2.add_pipe(el_pipe, last=True)
@@ -157,18 +178,22 @@ def run_pipeline():
train_limit = 5000
dev_limit = 5000
- train_data = training_set_creator.read_training(nlp=nlp_2,
- training_dir=TRAINING_DIR,
- dev=False,
- limit=train_limit)
+ # for training, get pos & neg instances that correspond to entries in the kb
+ train_data = training_set_creator.read_training(
+ nlp=nlp_2,
+ training_dir=TRAINING_DIR,
+ dev=False,
+ limit=train_limit,
+ kb=el_pipe.kb,
+ )
print("Training on", len(train_data), "articles")
print()
- dev_data = training_set_creator.read_training(nlp=nlp_2,
- training_dir=TRAINING_DIR,
- dev=True,
- limit=dev_limit)
+ # for testing, get all pos instances, whether or not they are in the kb
+ dev_data = training_set_creator.read_training(
+ nlp=nlp_2, training_dir=TRAINING_DIR, dev=True, limit=dev_limit, kb=None
+ )
print("Dev testing on", len(dev_data), "articles")
print()
@@ -187,8 +212,8 @@ def run_pipeline():
try:
docs, golds = zip(*batch)
nlp_2.update(
- docs,
- golds,
+ docs=docs,
+ golds=golds,
sgd=optimizer,
drop=DROPOUT,
losses=losses,
@@ -200,48 +225,61 @@ def run_pipeline():
if batchnr > 0:
el_pipe.cfg["context_weight"] = 1
el_pipe.cfg["prior_weight"] = 1
- dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
- losses['entity_linker'] = losses['entity_linker'] / batchnr
- print("Epoch, train loss", itn, round(losses['entity_linker'], 2),
- " / dev acc avg", round(dev_acc_context, 3))
+ dev_acc_context, _ = _measure_acc(dev_data, el_pipe)
+ losses["entity_linker"] = losses["entity_linker"] / batchnr
+ print(
+ "Epoch, train loss",
+ itn,
+ round(losses["entity_linker"], 2),
+ " / dev acc avg",
+ round(dev_acc_context, 3),
+ )
# STEP 7: measure the performance of our trained pipe on an independent dev set
if len(dev_data) and measure_performance:
print()
- print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now())
+ print("STEP 7: performance measurement of Entity Linking pipe", now())
print()
- counts, acc_r, acc_r_label, acc_p, acc_p_label, acc_o, acc_o_label = _measure_baselines(dev_data, kb_2)
+ counts, acc_r, acc_r_d, acc_p, acc_p_d, acc_o, acc_o_d = _measure_baselines(
+ dev_data, kb_2
+ )
print("dev counts:", sorted(counts.items(), key=lambda x: x[0]))
- print("dev acc oracle:", round(acc_o, 3), [(x, round(y, 3)) for x, y in acc_o_label.items()])
- print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_label.items()])
- print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_label.items()])
+
+ oracle_by_label = [(x, round(y, 3)) for x, y in acc_o_d.items()]
+ print("dev acc oracle:", round(acc_o, 3), oracle_by_label)
+
+ random_by_label = [(x, round(y, 3)) for x, y in acc_r_d.items()]
+ print("dev acc random:", round(acc_r, 3), random_by_label)
+
+ prior_by_label = [(x, round(y, 3)) for x, y in acc_p_d.items()]
+ print("dev acc prior:", round(acc_p, 3), prior_by_label)
# using only context
el_pipe.cfg["context_weight"] = 1
el_pipe.cfg["prior_weight"] = 0
- dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
- print("dev acc context avg:", round(dev_acc_context, 3),
- [(x, round(y, 3)) for x, y in dev_acc_context_dict.items()])
+ dev_acc_context, dev_acc_cont_d = _measure_acc(dev_data, el_pipe)
+ context_by_label = [(x, round(y, 3)) for x, y in dev_acc_cont_d.items()]
+ print("dev acc context avg:", round(dev_acc_context, 3), context_by_label)
# measuring combined accuracy (prior + context)
el_pipe.cfg["context_weight"] = 1
el_pipe.cfg["prior_weight"] = 1
- dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe, error_analysis=False)
- print("dev acc combo avg:", round(dev_acc_combo, 3),
- [(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
+ dev_acc_combo, dev_acc_combo_d = _measure_acc(dev_data, el_pipe)
+ combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_d.items()]
+ print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label)
# STEP 8: apply the EL pipe on a toy example
if to_test_pipeline:
print()
- print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now())
+ print("STEP 8: applying Entity Linking to toy example", now())
print()
run_el_toy_example(nlp=nlp_2)
# STEP 9: write the NLP pipeline (including entity linker) to file
if to_write_nlp:
print()
- print("STEP 9: testing NLP IO", datetime.datetime.now())
+ print("STEP 9: testing NLP IO", now())
print()
print("writing to", NLP_2_DIR)
nlp_2.to_disk(NLP_2_DIR)
@@ -262,23 +300,22 @@ def run_pipeline():
el_pipe = nlp_3.get_pipe("entity_linker")
dev_limit = 5000
- dev_data = training_set_creator.read_training(nlp=nlp_2,
- training_dir=TRAINING_DIR,
- dev=True,
- limit=dev_limit)
+ dev_data = training_set_creator.read_training(
+ nlp=nlp_2, training_dir=TRAINING_DIR, dev=True, limit=dev_limit, kb=None
+ )
print("Dev testing from file on", len(dev_data), "articles")
print()
- dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe=el_pipe, error_analysis=False)
- print("dev acc combo avg:", round(dev_acc_combo, 3),
- [(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
+ dev_acc_combo, dev_acc_combo_dict = _measure_acc(dev_data, el_pipe)
+ combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()]
+ print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label)
print()
- print("STOP", datetime.datetime.now())
+ print("STOP", now())
-def _measure_accuracy(data, el_pipe=None, error_analysis=False):
+def _measure_acc(data, el_pipe=None, error_analysis=False):
# If the docs in the data require further processing with an entity linker, set el_pipe
correct_by_label = dict()
incorrect_by_label = dict()
@@ -291,16 +328,21 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
for doc, gold in zip(docs, golds):
try:
correct_entries_per_article = dict()
- for entity in gold.links:
- start, end, gold_kb = entity
- correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
+ for entity, kb_dict in gold.links.items():
+ start, end = entity
+ # only evaluating on positive examples
+ for gold_kb, value in kb_dict.items():
+ if value:
+ offset = _offset(start, end)
+ correct_entries_per_article[offset] = gold_kb
for ent in doc.ents:
ent_label = ent.label_
pred_entity = ent.kb_id_
start = ent.start_char
end = ent.end_char
- gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
+ offset = _offset(start, end)
+ gold_entity = correct_entries_per_article.get(offset, None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None:
if gold_entity == pred_entity:
@@ -311,28 +353,33 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
incorrect_by_label[ent_label] = incorrect + 1
if error_analysis:
print(ent.text, "in", doc)
- print("Predicted", pred_entity, "should have been", gold_entity)
+ print(
+ "Predicted",
+ pred_entity,
+ "should have been",
+ gold_entity,
+ )
print()
except Exception as e:
print("Error assessing accuracy", e)
- acc, acc_by_label = calculate_acc(correct_by_label, incorrect_by_label)
+ acc, acc_by_label = calculate_acc(correct_by_label, incorrect_by_label)
return acc, acc_by_label
def _measure_baselines(data, kb):
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
- counts_by_label = dict()
+ counts_d = dict()
- random_correct_by_label = dict()
- random_incorrect_by_label = dict()
+ random_correct_d = dict()
+ random_incorrect_d = dict()
- oracle_correct_by_label = dict()
- oracle_incorrect_by_label = dict()
+ oracle_correct_d = dict()
+ oracle_incorrect_d = dict()
- prior_correct_by_label = dict()
- prior_incorrect_by_label = dict()
+ prior_correct_d = dict()
+ prior_incorrect_d = dict()
docs = [d for d, g in data if len(d) > 0]
golds = [g for d, g in data if len(d) > 0]
@@ -340,19 +387,24 @@ def _measure_baselines(data, kb):
for doc, gold in zip(docs, golds):
try:
correct_entries_per_article = dict()
- for entity in gold.links:
- start, end, gold_kb = entity
- correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
+ for entity, kb_dict in gold.links.items():
+ start, end = entity
+ for gold_kb, value in kb_dict.items():
+ # only evaluating on positive examples
+ if value:
+ offset = _offset(start, end)
+ correct_entries_per_article[offset] = gold_kb
for ent in doc.ents:
- ent_label = ent.label_
+ label = ent.label_
start = ent.start_char
end = ent.end_char
- gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
+ offset = _offset(start, end)
+ gold_entity = correct_entries_per_article.get(offset, None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None:
- counts_by_label[ent_label] = counts_by_label.get(ent_label, 0) + 1
+ counts_d[label] = counts_d.get(label, 0) + 1
candidates = kb.get_candidates(ent.text)
oracle_candidate = ""
best_candidate = ""
@@ -370,28 +422,40 @@ def _measure_baselines(data, kb):
random_candidate = random.choice(candidates).entity_
if gold_entity == best_candidate:
- prior_correct_by_label[ent_label] = prior_correct_by_label.get(ent_label, 0) + 1
+ prior_correct_d[label] = prior_correct_d.get(label, 0) + 1
else:
- prior_incorrect_by_label[ent_label] = prior_incorrect_by_label.get(ent_label, 0) + 1
+ prior_incorrect_d[label] = prior_incorrect_d.get(label, 0) + 1
if gold_entity == random_candidate:
- random_correct_by_label[ent_label] = random_correct_by_label.get(ent_label, 0) + 1
+ random_correct_d[label] = random_correct_d.get(label, 0) + 1
else:
- random_incorrect_by_label[ent_label] = random_incorrect_by_label.get(ent_label, 0) + 1
+ random_incorrect_d[label] = random_incorrect_d.get(label, 0) + 1
if gold_entity == oracle_candidate:
- oracle_correct_by_label[ent_label] = oracle_correct_by_label.get(ent_label, 0) + 1
+ oracle_correct_d[label] = oracle_correct_d.get(label, 0) + 1
else:
- oracle_incorrect_by_label[ent_label] = oracle_incorrect_by_label.get(ent_label, 0) + 1
+ oracle_incorrect_d[label] = oracle_incorrect_d.get(label, 0) + 1
except Exception as e:
print("Error assessing accuracy", e)
- acc_prior, acc_prior_by_label = calculate_acc(prior_correct_by_label, prior_incorrect_by_label)
- acc_rand, acc_rand_by_label = calculate_acc(random_correct_by_label, random_incorrect_by_label)
- acc_oracle, acc_oracle_by_label = calculate_acc(oracle_correct_by_label, oracle_incorrect_by_label)
+ acc_prior, acc_prior_d = calculate_acc(prior_correct_d, prior_incorrect_d)
+ acc_rand, acc_rand_d = calculate_acc(random_correct_d, random_incorrect_d)
+ acc_oracle, acc_oracle_d = calculate_acc(oracle_correct_d, oracle_incorrect_d)
- return counts_by_label, acc_rand, acc_rand_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label
+ return (
+ counts_d,
+ acc_rand,
+ acc_rand_d,
+ acc_prior,
+ acc_prior_d,
+ acc_oracle,
+ acc_oracle_d,
+ )
+
+
+def _offset(start, end):
+ return "{}_{}".format(start, end)
def calculate_acc(correct_by_label, incorrect_by_label):
@@ -422,15 +486,23 @@ def check_kb(kb):
print("generating candidates for " + mention + " :")
for c in candidates:
- print(" ", c.prior_prob, c.alias_, "-->", c.entity_ + " (freq=" + str(c.entity_freq) + ")")
+ print(
+ " ",
+ c.prior_prob,
+ c.alias_,
+ "-->",
+ c.entity_ + " (freq=" + str(c.entity_freq) + ")",
+ )
print()
def run_el_toy_example(nlp):
- text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \
- "Douglas reminds us to always bring our towel, even in China or Brazil. " \
- "The main character in Doug's novel is the man Arthur Dent, " \
- "but Douglas doesn't write about George Washington or Homer Simpson."
+ text = (
+ "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, "
+ "Douglas reminds us to always bring our towel, even in China or Brazil. "
+ "The main character in Doug's novel is the man Arthur Dent, "
+ "but Dougledydoug doesn't write about George Washington or Homer Simpson."
+ )
doc = nlp(text)
print(text)
for ent in doc.ents:
diff --git a/spacy/_ml.py b/spacy/_ml.py
index 4d9bb4c2b..dedd1bee5 100644
--- a/spacy/_ml.py
+++ b/spacy/_ml.py
@@ -663,15 +663,14 @@ def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=False,
def build_nel_encoder(embed_width, hidden_width, ner_types, **cfg):
- # TODO proper error
if "entity_width" not in cfg:
- raise ValueError("entity_width not found")
+ raise ValueError(Errors.E144.format(param="entity_width"))
if "context_width" not in cfg:
- raise ValueError("context_width not found")
+ raise ValueError(Errors.E144.format(param="context_width"))
conv_depth = cfg.get("conv_depth", 2)
cnn_maxout_pieces = cfg.get("cnn_maxout_pieces", 3)
- pretrained_vectors = cfg.get("pretrained_vectors") # self.nlp.vocab.vectors.name
+ pretrained_vectors = cfg.get("pretrained_vectors", None)
context_width = cfg.get("context_width")
entity_width = cfg.get("entity_width")
diff --git a/spacy/errors.py b/spacy/errors.py
index ed3d6afb9..4af8b756c 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -406,6 +406,13 @@ class Errors(object):
E141 = ("Entity vectors should be of length {required} instead of the provided {found}.")
E142 = ("Unsupported loss_function '{loss_func}'. Use either 'L2' or 'cosine'")
E143 = ("Labels for component '{name}' not initialized. Did you forget to call add_label()?")
+ E144 = ("Could not find parameter `{param}` when building the entity linker model.")
+ E145 = ("Error reading `{param}` from input file.")
+ E146 = ("Could not access `{path}`.")
+ E147 = ("Unexpected error in the {method} functionality of the EntityLinker: {msg}. "
+ "This is likely a bug in spaCy, so feel free to open an issue.")
+ E148 = ("Expected {ents} KB identifiers but got {ids}. Make sure that each entity in `doc.ents` "
+ "is assigned to a KB identifier.")
@add_codes
diff --git a/spacy/gold.pxd b/spacy/gold.pxd
index 8943a155a..a3123f7fa 100644
--- a/spacy/gold.pxd
+++ b/spacy/gold.pxd
@@ -31,7 +31,7 @@ cdef class GoldParse:
cdef public list ents
cdef public dict brackets
cdef public object cats
- cdef public list links
+ cdef public dict links
cdef readonly list cand_to_gold
cdef readonly list gold_to_cand
diff --git a/spacy/gold.pyx b/spacy/gold.pyx
index 8ef1fe123..11428a776 100644
--- a/spacy/gold.pyx
+++ b/spacy/gold.pyx
@@ -468,8 +468,11 @@ cdef class GoldParse:
examples of a label to have the value 0.0. Labels not in the
dictionary are treated as missing - the gradient for those labels
will be zero.
- links (iterable): A sequence of `(start_char, end_char, kb_id)` tuples,
- representing the external ID of an entity in a knowledge base.
+ links (dict): A dict with `(start_char, end_char)` keys,
+ and the values being dicts with kb_id:value entries,
+ representing the external IDs in a knowledge base (KB)
+ mapped to either 1.0 or 0.0, indicating positive and
+ negative examples respectively.
RETURNS (GoldParse): The newly constructed object.
"""
if words is None:
diff --git a/spacy/kb.pxd b/spacy/kb.pxd
index 40b22b275..d5aa382b1 100644
--- a/spacy/kb.pxd
+++ b/spacy/kb.pxd
@@ -79,7 +79,7 @@ cdef class KnowledgeBase:
return new_index
- cdef inline int64_t c_add_entity(self, hash_t entity_hash, float prob,
+ cdef inline int64_t c_add_entity(self, hash_t entity_hash, float freq,
int32_t vector_index, int feats_row) nogil:
"""Add an entry to the vector of entries.
After calling this method, make sure to update also the _entry_index using the return value"""
@@ -92,7 +92,7 @@ cdef class KnowledgeBase:
entry.entity_hash = entity_hash
entry.vector_index = vector_index
entry.feats_row = feats_row
- entry.prob = prob
+ entry.freq = freq
self._entries.push_back(entry)
return new_index
@@ -125,7 +125,7 @@ cdef class KnowledgeBase:
entry.entity_hash = dummy_hash
entry.vector_index = dummy_value
entry.feats_row = dummy_value
- entry.prob = dummy_value
+ entry.freq = dummy_value
# Avoid struct initializer to enable nogil
cdef vector[int64_t] dummy_entry_indices
@@ -141,7 +141,7 @@ cdef class KnowledgeBase:
self._aliases_table.push_back(alias)
cpdef load_bulk(self, loc)
- cpdef set_entities(self, entity_list, prob_list, vector_list)
+ cpdef set_entities(self, entity_list, freq_list, vector_list)
cdef class Writer:
@@ -149,7 +149,7 @@ cdef class Writer:
cdef int write_header(self, int64_t nr_entries, int64_t entity_vector_length) except -1
cdef int write_vector_element(self, float element) except -1
- cdef int write_entry(self, hash_t entry_hash, float entry_prob, int32_t vector_index) except -1
+ cdef int write_entry(self, hash_t entry_hash, float entry_freq, int32_t vector_index) except -1
cdef int write_alias_length(self, int64_t alias_length) except -1
cdef int write_alias_header(self, hash_t alias_hash, int64_t candidate_length) except -1
@@ -162,7 +162,7 @@ cdef class Reader:
cdef int read_header(self, int64_t* nr_entries, int64_t* entity_vector_length) except -1
cdef int read_vector_element(self, float* element) except -1
- cdef int read_entry(self, hash_t* entity_hash, float* prob, int32_t* vector_index) except -1
+ cdef int read_entry(self, hash_t* entity_hash, float* freq, int32_t* vector_index) except -1
cdef int read_alias_length(self, int64_t* alias_length) except -1
cdef int read_alias_header(self, hash_t* alias_hash, int64_t* candidate_length) except -1
diff --git a/spacy/kb.pyx b/spacy/kb.pyx
index 7c2daa659..28e762653 100644
--- a/spacy/kb.pyx
+++ b/spacy/kb.pyx
@@ -94,7 +94,7 @@ cdef class KnowledgeBase:
def get_alias_strings(self):
return [self.vocab.strings[x] for x in self._alias_index]
- def add_entity(self, unicode entity, float prob, vector[float] entity_vector):
+ def add_entity(self, unicode entity, float freq, vector[float] entity_vector):
"""
Add an entity to the KB, optionally specifying its log probability based on corpus frequency
Return the hash of the entity ID/name at the end.
@@ -113,15 +113,15 @@ cdef class KnowledgeBase:
vector_index = self.c_add_vector(entity_vector=entity_vector)
new_index = self.c_add_entity(entity_hash=entity_hash,
- prob=prob,
+ freq=freq,
vector_index=vector_index,
feats_row=-1) # Features table currently not implemented
self._entry_index[entity_hash] = new_index
return entity_hash
- cpdef set_entities(self, entity_list, prob_list, vector_list):
- if len(entity_list) != len(prob_list) or len(entity_list) != len(vector_list):
+ cpdef set_entities(self, entity_list, freq_list, vector_list):
+ if len(entity_list) != len(freq_list) or len(entity_list) != len(vector_list):
raise ValueError(Errors.E140)
nr_entities = len(entity_list)
@@ -137,7 +137,7 @@ cdef class KnowledgeBase:
entity_hash = self.vocab.strings.add(entity_list[i])
entry.entity_hash = entity_hash
- entry.prob = prob_list[i]
+ entry.freq = freq_list[i]
vector_index = self.c_add_vector(entity_vector=vector_list[i])
entry.vector_index = vector_index
@@ -196,13 +196,42 @@ cdef class KnowledgeBase:
return [Candidate(kb=self,
entity_hash=self._entries[entry_index].entity_hash,
- entity_freq=self._entries[entry_index].prob,
+ entity_freq=self._entries[entry_index].freq,
entity_vector=self._vectors_table[self._entries[entry_index].vector_index],
alias_hash=alias_hash,
- prior_prob=prob)
- for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs)
+ prior_prob=prior_prob)
+ for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs)
if entry_index != 0]
+ def get_vector(self, unicode entity):
+ cdef hash_t entity_hash = self.vocab.strings[entity]
+
+ # Return an empty list if this entity is unknown in this KB
+ if entity_hash not in self._entry_index:
+ return [0] * self.entity_vector_length
+ entry_index = self._entry_index[entity_hash]
+
+ return self._vectors_table[self._entries[entry_index].vector_index]
+
+ def get_prior_prob(self, unicode entity, unicode alias):
+ """ Return the prior probability of a given alias being linked to a given entity,
+ or return 0.0 when this combination is not known in the knowledge base"""
+ cdef hash_t alias_hash = self.vocab.strings[alias]
+ cdef hash_t entity_hash = self.vocab.strings[entity]
+
+ if entity_hash not in self._entry_index or alias_hash not in self._alias_index:
+ return 0.0
+
+ alias_index = self._alias_index.get(alias_hash)
+ entry_index = self._entry_index[entity_hash]
+
+ alias_entry = self._aliases_table[alias_index]
+ for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs):
+ if self._entries[entry_index].entity_hash == entity_hash:
+ return prior_prob
+
+ return 0.0
+
def dump(self, loc):
cdef Writer writer = Writer(loc)
@@ -222,7 +251,7 @@ cdef class KnowledgeBase:
entry = self._entries[entry_index]
assert entry.entity_hash == entry_hash
assert entry_index == i
- writer.write_entry(entry.entity_hash, entry.prob, entry.vector_index)
+ writer.write_entry(entry.entity_hash, entry.freq, entry.vector_index)
i = i+1
writer.write_alias_length(self.get_size_aliases())
@@ -248,7 +277,7 @@ cdef class KnowledgeBase:
cdef hash_t entity_hash
cdef hash_t alias_hash
cdef int64_t entry_index
- cdef float prob
+ cdef float freq, prob
cdef int32_t vector_index
cdef KBEntryC entry
cdef AliasC alias
@@ -284,10 +313,10 @@ cdef class KnowledgeBase:
# index 0 is a dummy object not stored in the _entry_index and can be ignored.
i = 1
while i <= nr_entities:
- reader.read_entry(&entity_hash, &prob, &vector_index)
+ reader.read_entry(&entity_hash, &freq, &vector_index)
entry.entity_hash = entity_hash
- entry.prob = prob
+ entry.freq = freq
entry.vector_index = vector_index
entry.feats_row = -1 # Features table currently not implemented
@@ -343,7 +372,8 @@ cdef class Writer:
loc = bytes(loc)
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
self._fp = fopen(bytes_loc, 'wb')
- assert self._fp != NULL
+ if not self._fp:
+ raise IOError(Errors.E146.format(path=loc))
fseek(self._fp, 0, 0)
def close(self):
@@ -357,9 +387,9 @@ cdef class Writer:
cdef int write_vector_element(self, float element) except -1:
self._write(&element, sizeof(element))
- cdef int write_entry(self, hash_t entry_hash, float entry_prob, int32_t vector_index) except -1:
+ cdef int write_entry(self, hash_t entry_hash, float entry_freq, int32_t vector_index) except -1:
self._write(&entry_hash, sizeof(entry_hash))
- self._write(&entry_prob, sizeof(entry_prob))
+ self._write(&entry_freq, sizeof(entry_freq))
self._write(&vector_index, sizeof(vector_index))
# Features table currently not implemented and not written to file
@@ -399,39 +429,39 @@ cdef class Reader:
if status < 1:
if feof(self._fp):
return 0 # end of file
- raise IOError("error reading header from input file")
+ raise IOError(Errors.E145.format(param="header"))
status = self._read(entity_vector_length, sizeof(int64_t))
if status < 1:
if feof(self._fp):
return 0 # end of file
- raise IOError("error reading header from input file")
+ raise IOError(Errors.E145.format(param="vector length"))
cdef int read_vector_element(self, float* element) except -1:
status = self._read(element, sizeof(float))
if status < 1:
if feof(self._fp):
return 0 # end of file
- raise IOError("error reading entity vector from input file")
+ raise IOError(Errors.E145.format(param="vector element"))
- cdef int read_entry(self, hash_t* entity_hash, float* prob, int32_t* vector_index) except -1:
+ cdef int read_entry(self, hash_t* entity_hash, float* freq, int32_t* vector_index) except -1:
status = self._read(entity_hash, sizeof(hash_t))
if status < 1:
if feof(self._fp):
return 0 # end of file
- raise IOError("error reading entity hash from input file")
+ raise IOError(Errors.E145.format(param="entity hash"))
- status = self._read(prob, sizeof(float))
+ status = self._read(freq, sizeof(float))
if status < 1:
if feof(self._fp):
return 0 # end of file
- raise IOError("error reading entity prob from input file")
+ raise IOError(Errors.E145.format(param="entity freq"))
status = self._read(vector_index, sizeof(int32_t))
if status < 1:
if feof(self._fp):
return 0 # end of file
- raise IOError("error reading entity vector from input file")
+ raise IOError(Errors.E145.format(param="vector index"))
if feof(self._fp):
return 0
@@ -443,33 +473,33 @@ cdef class Reader:
if status < 1:
if feof(self._fp):
return 0 # end of file
- raise IOError("error reading alias length from input file")
+ raise IOError(Errors.E145.format(param="alias length"))
cdef int read_alias_header(self, hash_t* alias_hash, int64_t* candidate_length) except -1:
status = self._read(alias_hash, sizeof(hash_t))
if status < 1:
if feof(self._fp):
return 0 # end of file
- raise IOError("error reading alias hash from input file")
+ raise IOError(Errors.E145.format(param="alias hash"))
status = self._read(candidate_length, sizeof(int64_t))
if status < 1:
if feof(self._fp):
return 0 # end of file
- raise IOError("error reading candidate length from input file")
+ raise IOError(Errors.E145.format(param="candidate length"))
cdef int read_alias(self, int64_t* entry_index, float* prob) except -1:
status = self._read(entry_index, sizeof(int64_t))
if status < 1:
if feof(self._fp):
return 0 # end of file
- raise IOError("error reading entry index for alias from input file")
+ raise IOError(Errors.E145.format(param="entry index"))
status = self._read(prob, sizeof(float))
if status < 1:
if feof(self._fp):
return 0 # end of file
- raise IOError("error reading prob for entity/alias from input file")
+ raise IOError(Errors.E145.format(param="prior probability"))
cdef int _read(self, void* value, size_t size) except -1:
status = fread(value, size, 1, self._fp)
diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx
index 891e8d4e3..609c4e852 100644
--- a/spacy/pipeline/pipes.pyx
+++ b/spacy/pipeline/pipes.pyx
@@ -14,7 +14,6 @@ from thinc.neural.util import to_categorical
from thinc.neural.util import get_array_module
from spacy.kb import KnowledgeBase
-from ..cli.pretrain import get_cossim_loss
from .functions import merge_subtokens
from ..tokens.doc cimport Doc
from ..syntax.nn_parser cimport Parser
@@ -1077,6 +1076,7 @@ class EntityLinker(Pipe):
DOCS: TODO
"""
name = 'entity_linker'
+ NIL = "NIL" # string used to refer to a non-existing link
@classmethod
def Model(cls, **cfg):
@@ -1093,6 +1093,8 @@ class EntityLinker(Pipe):
self.kb = None
self.cfg = dict(cfg)
self.sgd_context = None
+ if not self.cfg.get("context_width"):
+ self.cfg["context_width"] = 128
def set_kb(self, kb):
self.kb = kb
@@ -1140,7 +1142,7 @@ class EntityLinker(Pipe):
context_docs = []
entity_encodings = []
- cats = []
+
priors = []
type_vectors = []
@@ -1149,50 +1151,44 @@ class EntityLinker(Pipe):
for doc, gold in zip(docs, golds):
ents_by_offset = dict()
for ent in doc.ents:
- ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
- for entity in gold.links:
- start, end, gold_kb = entity
+ ents_by_offset["{}_{}".format(ent.start_char, ent.end_char)] = ent
+ for entity, kb_dict in gold.links.items():
+ start, end = entity
mention = doc.text[start:end]
+ for kb_id, value in kb_dict.items():
+ entity_encoding = self.kb.get_vector(kb_id)
+ prior_prob = self.kb.get_prior_prob(kb_id, mention)
- gold_ent = ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)]
- assert gold_ent is not None
- type_vector = [0 for i in range(len(type_to_int))]
- if len(type_to_int) > 0:
- type_vector[type_to_int[gold_ent.label_]] = 1
+ gold_ent = ents_by_offset["{}_{}".format(start, end)]
+ if gold_ent is None:
+ raise RuntimeError(Errors.E147.format(method="update", msg="gold entity not found"))
- candidates = self.kb.get_candidates(mention)
- random.shuffle(candidates)
- nr_neg = 0
- for c in candidates:
- kb_id = c.entity_
- entity_encoding = c.entity_vector
+ type_vector = [0 for i in range(len(type_to_int))]
+ if len(type_to_int) > 0:
+ type_vector[type_to_int[gold_ent.label_]] = 1
+
+ # store data
entity_encodings.append(entity_encoding)
context_docs.append(doc)
type_vectors.append(type_vector)
if self.cfg.get("prior_weight", 1) > 0:
- priors.append([c.prior_prob])
+ priors.append([prior_prob])
else:
priors.append([0])
- if kb_id == gold_kb:
- cats.append([1])
- else:
- nr_neg += 1
- cats.append([0])
-
if len(entity_encodings) > 0:
- assert len(priors) == len(entity_encodings) == len(context_docs) == len(cats) == len(type_vectors)
+ if not (len(priors) == len(entity_encodings) == len(context_docs) == len(type_vectors)):
+ raise RuntimeError(Errors.E147.format(method="update", msg="vector lengths not equal"))
- context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop)
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
+ context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop)
mention_encodings = [list(context_encodings[i]) + list(entity_encodings[i]) + priors[i] + type_vectors[i]
for i in range(len(entity_encodings))]
pred, bp_mention = self.model.begin_update(self.model.ops.asarray(mention_encodings, dtype="float32"), drop=drop)
- cats = self.model.ops.asarray(cats, dtype="float32")
- loss, d_scores = self.get_loss(prediction=pred, golds=cats, docs=None)
+ loss, d_scores = self.get_loss(scores=pred, golds=golds, docs=docs)
mention_gradient = bp_mention(d_scores, sgd=sgd)
context_gradients = [list(x[0:self.cfg.get("context_width")]) for x in mention_gradient]
@@ -1203,39 +1199,45 @@ class EntityLinker(Pipe):
return loss
return 0
- def get_loss(self, docs, golds, prediction):
- d_scores = (prediction - golds)
+ def get_loss(self, docs, golds, scores):
+ cats = []
+ for gold in golds:
+ for entity, kb_dict in gold.links.items():
+ for kb_id, value in kb_dict.items():
+ cats.append([value])
+
+ cats = self.model.ops.asarray(cats, dtype="float32")
+ if len(scores) != len(cats):
+ raise RuntimeError(Errors.E147.format(method="get_loss", msg="gold entities do not match up"))
+
+ d_scores = (scores - cats)
loss = (d_scores ** 2).sum()
- loss = loss / len(golds)
+ loss = loss / len(cats)
return loss, d_scores
- def get_loss_old(self, docs, golds, scores):
- # this loss function assumes we're only using positive examples
- loss, gradients = get_cossim_loss(yh=scores, y=golds)
- loss = loss / len(golds)
- return loss, gradients
-
def __call__(self, doc):
- entities, kb_ids = self.predict([doc])
- self.set_annotations([doc], entities, kb_ids)
+ kb_ids, tensors = self.predict([doc])
+ self.set_annotations([doc], kb_ids, tensors=tensors)
return doc
def pipe(self, stream, batch_size=128, n_threads=-1):
for docs in util.minibatch(stream, size=batch_size):
docs = list(docs)
- entities, kb_ids = self.predict(docs)
- self.set_annotations(docs, entities, kb_ids)
+ kb_ids, tensors = self.predict(docs)
+ self.set_annotations(docs, kb_ids, tensors=tensors)
yield from docs
def predict(self, docs):
+ """ Return the KB IDs for each entity in each doc, including NIL if there is no prediction """
self.require_model()
self.require_kb()
- final_entities = []
+ entity_count = 0
final_kb_ids = []
+ final_tensors = []
if not docs:
- return final_entities, final_kb_ids
+ return final_kb_ids, final_tensors
if isinstance(docs, Doc):
docs = [docs]
@@ -1247,14 +1249,19 @@ class EntityLinker(Pipe):
for i, doc in enumerate(docs):
if len(doc) > 0:
+ # currently, the context is the same for each entity in a sentence (should be refined)
context_encoding = context_encodings[i]
for ent in doc.ents:
+ entity_count += 1
type_vector = [0 for i in range(len(type_to_int))]
if len(type_to_int) > 0:
type_vector[type_to_int[ent.label_]] = 1
candidates = self.kb.get_candidates(ent.text)
- if candidates:
+ if not candidates:
+ final_kb_ids.append(self.NIL) # no prediction possible for this entity
+ final_tensors.append(context_encoding)
+ else:
random.shuffle(candidates)
# this will set the prior probabilities to 0 (just like in training) if their weight is 0
@@ -1264,7 +1271,9 @@ class EntityLinker(Pipe):
if self.cfg.get("context_weight", 1) > 0:
entity_encodings = xp.asarray([c.entity_vector for c in candidates])
- assert len(entity_encodings) == len(prior_probs)
+ if len(entity_encodings) != len(prior_probs):
+ raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length"))
+
mention_encodings = [list(context_encoding) + list(entity_encodings[i])
+ list(prior_probs[i]) + type_vector
for i in range(len(entity_encodings))]
@@ -1273,15 +1282,26 @@ class EntityLinker(Pipe):
# TODO: thresholding
best_index = scores.argmax()
best_candidate = candidates[best_index]
- final_entities.append(ent)
final_kb_ids.append(best_candidate.entity_)
+ final_tensors.append(context_encoding)
- return final_entities, final_kb_ids
+ if not (len(final_tensors) == len(final_kb_ids) == entity_count):
+ raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length"))
- def set_annotations(self, docs, entities, kb_ids=None):
- for entity, kb_id in zip(entities, kb_ids):
- for token in entity:
- token.ent_kb_id_ = kb_id
+ return final_kb_ids, final_tensors
+
+ def set_annotations(self, docs, kb_ids, tensors=None):
+ count_ents = len([ent for doc in docs for ent in doc.ents])
+ if count_ents != len(kb_ids):
+ raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))
+
+ i=0
+ for doc in docs:
+ for ent in doc.ents:
+ kb_id = kb_ids[i]
+ i += 1
+ for token in ent:
+ token.ent_kb_id_ = kb_id
def to_disk(self, path, exclude=tuple(), **kwargs):
serialize = OrderedDict()
diff --git a/spacy/structs.pxd b/spacy/structs.pxd
index e80b1b4d6..6c643b4cd 100644
--- a/spacy/structs.pxd
+++ b/spacy/structs.pxd
@@ -93,7 +93,7 @@ cdef struct KBEntryC:
int32_t feats_row
# log probability of entity, based on corpus frequency
- float prob
+ float freq
# Each alias struct stores a list of Entry pointers with their prior probabilities
diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py
index cafc380ba..ca6bf2b6c 100644
--- a/spacy/tests/pipeline/test_entity_linker.py
+++ b/spacy/tests/pipeline/test_entity_linker.py
@@ -13,22 +13,38 @@ def nlp():
return English()
+def assert_almost_equal(a, b):
+ delta = 0.0001
+ assert a - delta <= b <= a + delta
+
+
def test_kb_valid_entities(nlp):
"""Test the valid construction of a KB with 3 entities and two aliases"""
- mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+ mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
# adding entities
- mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
- mykb.add_entity(entity='Q2', prob=0.5, entity_vector=[2])
- mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
+ mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[8, 4, 3])
+ mykb.add_entity(entity="Q2", freq=0.5, entity_vector=[2, 1, 0])
+ mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[-1, -6, 5])
# adding aliases
- mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.2])
- mykb.add_alias(alias='adam', entities=['Q2'], probabilities=[0.9])
+ mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.2])
+ mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
# test the size of the corresponding KB
- assert(mykb.get_size_entities() == 3)
- assert(mykb.get_size_aliases() == 2)
+ assert mykb.get_size_entities() == 3
+ assert mykb.get_size_aliases() == 2
+
+ # test retrieval of the entity vectors
+ assert mykb.get_vector("Q1") == [8, 4, 3]
+ assert mykb.get_vector("Q2") == [2, 1, 0]
+ assert mykb.get_vector("Q3") == [-1, -6, 5]
+
+ # test retrieval of prior probabilities
+ assert_almost_equal(mykb.get_prior_prob(entity="Q2", alias="douglas"), 0.8)
+ assert_almost_equal(mykb.get_prior_prob(entity="Q3", alias="douglas"), 0.2)
+ assert_almost_equal(mykb.get_prior_prob(entity="Q342", alias="douglas"), 0.0)
+ assert_almost_equal(mykb.get_prior_prob(entity="Q3", alias="douglassssss"), 0.0)
def test_kb_invalid_entities(nlp):
@@ -36,13 +52,15 @@ def test_kb_invalid_entities(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
- mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
- mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
- mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
+ mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1])
+ mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
+ mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3])
# adding aliases - should fail because one of the given IDs is not valid
with pytest.raises(ValueError):
- mykb.add_alias(alias='douglas', entities=['Q2', 'Q342'], probabilities=[0.8, 0.2])
+ mykb.add_alias(
+ alias="douglas", entities=["Q2", "Q342"], probabilities=[0.8, 0.2]
+ )
def test_kb_invalid_probabilities(nlp):
@@ -50,13 +68,13 @@ def test_kb_invalid_probabilities(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
- mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
- mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
- mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
+ mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1])
+ mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
+ mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3])
# adding aliases - should fail because the sum of the probabilities exceeds 1
with pytest.raises(ValueError):
- mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.4])
+ mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.4])
def test_kb_invalid_combination(nlp):
@@ -64,13 +82,15 @@ def test_kb_invalid_combination(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
- mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
- mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
- mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
+ mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1])
+ mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
+ mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3])
# adding aliases - should fail because the entities and probabilities vectors are not of equal length
with pytest.raises(ValueError):
- mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.3, 0.4, 0.1])
+ mykb.add_alias(
+ alias="douglas", entities=["Q2", "Q3"], probabilities=[0.3, 0.4, 0.1]
+ )
def test_kb_invalid_entity_vector(nlp):
@@ -78,11 +98,11 @@ def test_kb_invalid_entity_vector(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
# adding entities
- mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1, 2, 3])
+ mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1, 2, 3])
# this should fail because the kb's expected entity vector length is 3
with pytest.raises(ValueError):
- mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
+ mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
def test_candidate_generation(nlp):
@@ -90,18 +110,24 @@ def test_candidate_generation(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
- mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
- mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
- mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
+ mykb.add_entity(entity="Q1", freq=0.7, entity_vector=[1])
+ mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
+ mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3])
# adding aliases
- mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.2])
- mykb.add_alias(alias='adam', entities=['Q2'], probabilities=[0.9])
+ mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1])
+ mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
# test the size of the relevant candidates
- assert(len(mykb.get_candidates('douglas')) == 2)
- assert(len(mykb.get_candidates('adam')) == 1)
- assert(len(mykb.get_candidates('shrubbery')) == 0)
+ assert len(mykb.get_candidates("douglas")) == 2
+ assert len(mykb.get_candidates("adam")) == 1
+ assert len(mykb.get_candidates("shrubbery")) == 0
+
+ # test the content of the candidates
+ assert mykb.get_candidates("adam")[0].entity_ == "Q2"
+ assert mykb.get_candidates("adam")[0].alias_ == "adam"
+ assert_almost_equal(mykb.get_candidates("adam")[0].entity_freq, 0.2)
+ assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9)
def test_preserving_links_asdoc(nlp):
@@ -109,24 +135,26 @@ def test_preserving_links_asdoc(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
- mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
- mykb.add_entity(entity='Q2', prob=0.8, entity_vector=[1])
+ mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1])
+ mykb.add_entity(entity="Q2", freq=0.8, entity_vector=[1])
# adding aliases
- mykb.add_alias(alias='Boston', entities=['Q1'], probabilities=[0.7])
- mykb.add_alias(alias='Denver', entities=['Q2'], probabilities=[0.6])
+ mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7])
+ mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6])
# set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained)
sentencizer = nlp.create_pipe("sentencizer")
nlp.add_pipe(sentencizer)
ruler = EntityRuler(nlp)
- patterns = [{"label": "GPE", "pattern": "Boston"},
- {"label": "GPE", "pattern": "Denver"}]
+ patterns = [
+ {"label": "GPE", "pattern": "Boston"},
+ {"label": "GPE", "pattern": "Denver"},
+ ]
ruler.add_patterns(patterns)
nlp.add_pipe(ruler)
- el_pipe = nlp.create_pipe(name='entity_linker', config={"context_width": 64})
+ el_pipe = nlp.create_pipe(name="entity_linker", config={"context_width": 64})
el_pipe.set_kb(mykb)
el_pipe.begin_training()
el_pipe.context_weight = 0
diff --git a/spacy/tests/serialize/test_serialize_kb.py b/spacy/tests/serialize/test_serialize_kb.py
index fa7253fa1..1752abda2 100644
--- a/spacy/tests/serialize/test_serialize_kb.py
+++ b/spacy/tests/serialize/test_serialize_kb.py
@@ -30,10 +30,10 @@ def test_serialize_kb_disk(en_vocab):
def _get_dummy_kb(vocab):
kb = KnowledgeBase(vocab=vocab, entity_vector_length=3)
- kb.add_entity(entity='Q53', prob=0.33, entity_vector=[0, 5, 3])
- kb.add_entity(entity='Q17', prob=0.2, entity_vector=[7, 1, 0])
- kb.add_entity(entity='Q007', prob=0.7, entity_vector=[0, 0, 7])
- kb.add_entity(entity='Q44', prob=0.4, entity_vector=[4, 4, 4])
+ kb.add_entity(entity='Q53', freq=0.33, entity_vector=[0, 5, 3])
+ kb.add_entity(entity='Q17', freq=0.2, entity_vector=[7, 1, 0])
+ kb.add_entity(entity='Q007', freq=0.7, entity_vector=[0, 0, 7])
+ kb.add_entity(entity='Q44', freq=0.4, entity_vector=[4, 4, 4])
kb.add_alias(alias='double07', entities=['Q17', 'Q007'], probabilities=[0.1, 0.9])
kb.add_alias(alias='guy', entities=['Q53', 'Q007', 'Q17', 'Q44'], probabilities=[0.3, 0.3, 0.2, 0.1])
diff --git a/spacy/tokenizer.pyx b/spacy/tokenizer.pyx
index 70a693ba1..f19f851c7 100644
--- a/spacy/tokenizer.pyx
+++ b/spacy/tokenizer.pyx
@@ -348,7 +348,7 @@ cdef class Tokenizer:
"""Add a special-case tokenization rule.
string (unicode): The string to specially tokenize.
- token_attrs (iterable): A sequence of dicts, where each dict describes
+ substrings (iterable): A sequence of dicts, where each dict describes
a token and its attributes. The `ORTH` fields of the attributes
must exactly match the string when they are concatenated.