Merge pull request #4003 from svlandeg/feature/nel-fixes

API changes for Entity linking functionality
This commit is contained in:
Ines Montani 2019-07-23 23:17:07 +02:00 committed by GitHub
commit 87fcf3141c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 704 additions and 370 deletions

View File

@ -13,9 +13,17 @@ INPUT_DIM = 300 # dimension of pre-trained input vectors
DESC_WIDTH = 64 # dimension of output entity vectors
def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
entity_def_output, entity_descr_output,
count_input, prior_prob_input, wikidata_input):
def create_kb(
nlp,
max_entities_per_alias,
min_entity_freq,
min_occ,
entity_def_output,
entity_descr_output,
count_input,
prior_prob_input,
wikidata_input,
):
# Create the knowledge base from Wikidata entries
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH)
@ -28,7 +36,9 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
title_to_id, id_to_descr = wd.read_wikidata_entities_json(wikidata_input)
# write the title-ID and ID-description mappings to file
_write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr)
_write_entity_files(
entity_def_output, entity_descr_output, title_to_id, id_to_descr
)
else:
# read the mappings from file
@ -54,8 +64,8 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
frequency_list.append(freq)
filtered_title_to_id[title] = entity
print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()),
"titles with filter frequency", min_entity_freq)
print(len(title_to_id.keys()), "original titles")
print("kept", len(filtered_title_to_id.keys()), " with frequency", min_entity_freq)
print()
print(" * train entity encoder", datetime.datetime.now())
@ -70,14 +80,20 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
print()
print(" * adding", len(entity_list), "entities", datetime.datetime.now())
kb.set_entities(entity_list=entity_list, prob_list=frequency_list, vector_list=embeddings)
kb.set_entities(
entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
)
print()
print(" * adding aliases", datetime.datetime.now())
print()
_add_aliases(kb, title_to_id=filtered_title_to_id,
max_entities_per_alias=max_entities_per_alias, min_occ=min_occ,
prior_prob_input=prior_prob_input)
_add_aliases(
kb,
title_to_id=filtered_title_to_id,
max_entities_per_alias=max_entities_per_alias,
min_occ=min_occ,
prior_prob_input=prior_prob_input,
)
print()
print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases())
@ -86,13 +102,15 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
return kb
def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr):
with open(entity_def_output, mode='w', encoding='utf8') as id_file:
def _write_entity_files(
entity_def_output, entity_descr_output, title_to_id, id_to_descr
):
with entity_def_output.open("w", encoding="utf8") as id_file:
id_file.write("WP_title" + "|" + "WD_id" + "\n")
for title, qid in title_to_id.items():
id_file.write(title + "|" + str(qid) + "\n")
with open(entity_descr_output, mode='w', encoding='utf8') as descr_file:
with entity_descr_output.open("w", encoding="utf8") as descr_file:
descr_file.write("WD_id" + "|" + "description" + "\n")
for qid, descr in id_to_descr.items():
descr_file.write(str(qid) + "|" + descr + "\n")
@ -100,8 +118,8 @@ def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_
def get_entity_to_id(entity_def_output):
entity_to_id = dict()
with open(entity_def_output, 'r', encoding='utf8') as csvfile:
csvreader = csv.reader(csvfile, delimiter='|')
with entity_def_output.open("r", encoding="utf8") as csvfile:
csvreader = csv.reader(csvfile, delimiter="|")
# skip header
next(csvreader)
for row in csvreader:
@ -111,8 +129,8 @@ def get_entity_to_id(entity_def_output):
def get_id_to_description(entity_descr_output):
id_to_desc = dict()
with open(entity_descr_output, 'r', encoding='utf8') as csvfile:
csvreader = csv.reader(csvfile, delimiter='|')
with entity_descr_output.open("r", encoding="utf8") as csvfile:
csvreader = csv.reader(csvfile, delimiter="|")
# skip header
next(csvreader)
for row in csvreader:
@ -125,7 +143,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
# adding aliases with prior probabilities
# we can read this file sequentially, it's sorted by alias, and then by count
with open(prior_prob_input, mode='r', encoding='utf8') as prior_file:
with prior_prob_input.open("r", encoding="utf8") as prior_file:
# skip header
prior_file.readline()
line = prior_file.readline()
@ -134,7 +152,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
counts = []
entities = []
while line:
splits = line.replace('\n', "").split(sep='|')
splits = line.replace("\n", "").split(sep="|")
new_alias = splits[0]
count = int(splits[1])
entity = splits[2]
@ -153,7 +171,11 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
if selected_entities:
try:
kb.add_alias(alias=previous_alias, entities=selected_entities, probabilities=prior_probs)
kb.add_alias(
alias=previous_alias,
entities=selected_entities,
probabilities=prior_probs,
)
except ValueError as e:
print(e)
total_count = 0
@ -168,4 +190,3 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
previous_alias = new_alias
line = prior_file.readline()

View File

@ -1,7 +1,7 @@
# coding: utf-8
from __future__ import unicode_literals
import os
import random
import re
import bz2
import datetime
@ -17,6 +17,10 @@ Gold-standard entities are stored in one file in standoff format (by character o
ENTITY_FILE = "gold_entities.csv"
def now():
return datetime.datetime.now()
def create_training(wikipedia_input, entity_def_input, training_output):
wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
_process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None)
@ -27,21 +31,23 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
Read the XML wikipedia data to parse out training data:
raw text data + positive instances
"""
title_regex = re.compile(r'(?<=<title>).*(?=</title>)')
id_regex = re.compile(r'(?<=<id>)\d*(?=</id>)')
title_regex = re.compile(r"(?<=<title>).*(?=</title>)")
id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
read_ids = set()
entityfile_loc = training_output / ENTITY_FILE
with open(entityfile_loc, mode="w", encoding='utf8') as entityfile:
with entityfile_loc.open("w", encoding="utf8") as entityfile:
# write entity training header file
_write_training_entity(outputfile=entityfile,
_write_training_entity(
outputfile=entityfile,
article_id="article_id",
alias="alias",
entity="WD_id",
start="start",
end="end")
end="end",
)
with bz2.open(wikipedia_input, mode='rb') as file:
with bz2.open(wikipedia_input, mode="rb") as file:
line = file.readline()
cnt = 0
article_text = ""
@ -51,7 +57,7 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
reading_revision = False
while line and (not limit or cnt < limit):
if cnt % 1000000 == 0:
print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump")
print(now(), "processed", cnt, "lines of Wikipedia dump")
clean_line = line.strip().decode("utf-8")
if clean_line == "<revision>":
@ -69,12 +75,23 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
elif clean_line == "</page>":
if article_id:
try:
_process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text.strip(),
training_output)
_process_wp_text(
wp_to_id,
entityfile,
article_id,
article_title,
article_text.strip(),
training_output,
)
except Exception as e:
print("Error processing article", article_id, article_title, e)
print(
"Error processing article", article_id, article_title, e
)
else:
print("Done processing a page, but couldn't find an article_id ?", article_title)
print(
"Done processing a page, but couldn't find an article_id ?",
article_title,
)
article_text = ""
article_title = None
article_id = None
@ -98,7 +115,9 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
if ids:
article_id = ids[0]
if article_id in read_ids:
print("Found duplicate article ID", article_id, clean_line) # This should never happen ...
print(
"Found duplicate article ID", article_id, clean_line
) # This should never happen ...
read_ids.add(article_id)
# read the title of this article (outside the revision portion of the document)
@ -111,10 +130,12 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
cnt += 1
text_regex = re.compile(r'(?<=<text xml:space=\"preserve\">).*(?=</text)')
text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")
def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output):
def _process_wp_text(
wp_to_id, entityfile, article_id, article_title, article_text, training_output
):
found_entities = False
# ignore meta Wikipedia pages
@ -141,11 +162,11 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
entity_buffer = ""
mention_buffer = ""
for index, letter in enumerate(clean_text):
if letter == '[':
if letter == "[":
open_read += 1
elif letter == ']':
elif letter == "]":
open_read -= 1
elif letter == '|':
elif letter == "|":
if reading_text:
final_text += letter
# switch from reading entity to mention in the [[entity|mention]] pattern
@ -175,7 +196,7 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
# we just finished reading an entity
if open_read == 0 and not reading_text:
if '#' in entity_buffer or entity_buffer.startswith(':'):
if "#" in entity_buffer or entity_buffer.startswith(":"):
reading_special_case = True
# Ignore cases with nested structures like File: handles etc
if not reading_special_case:
@ -185,12 +206,14 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
end = start + len(mention_buffer)
qid = wp_to_id.get(entity_buffer, None)
if qid:
_write_training_entity(outputfile=entityfile,
_write_training_entity(
outputfile=entityfile,
article_id=article_id,
alias=mention_buffer,
entity=qid,
start=start,
end=end)
end=end,
)
found_entities = True
final_text += mention_buffer
@ -203,29 +226,35 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
reading_special_case = False
if found_entities:
_write_training_article(article_id=article_id, clean_text=final_text, training_output=training_output)
_write_training_article(
article_id=article_id,
clean_text=final_text,
training_output=training_output,
)
info_regex = re.compile(r'{[^{]*?}')
htlm_regex = re.compile(r'&lt;!--[^-]*--&gt;')
category_regex = re.compile(r'\[\[Category:[^\[]*]]')
file_regex = re.compile(r'\[\[File:[^[\]]+]]')
ref_regex = re.compile(r'&lt;ref.*?&gt;') # non-greedy
ref_2_regex = re.compile(r'&lt;/ref.*?&gt;') # non-greedy
info_regex = re.compile(r"{[^{]*?}")
htlm_regex = re.compile(r"&lt;!--[^-]*--&gt;")
category_regex = re.compile(r"\[\[Category:[^\[]*]]")
file_regex = re.compile(r"\[\[File:[^[\]]+]]")
ref_regex = re.compile(r"&lt;ref.*?&gt;") # non-greedy
ref_2_regex = re.compile(r"&lt;/ref.*?&gt;") # non-greedy
def _get_clean_wp_text(article_text):
clean_text = article_text.strip()
# remove bolding & italic markup
clean_text = clean_text.replace('\'\'\'', '')
clean_text = clean_text.replace('\'\'', '')
clean_text = clean_text.replace("'''", "")
clean_text = clean_text.replace("''", "")
# remove nested {{info}} statements by removing the inner/smallest ones first and iterating
try_again = True
previous_length = len(clean_text)
while try_again:
clean_text = info_regex.sub('', clean_text) # non-greedy match excluding a nested {
clean_text = info_regex.sub(
"", clean_text
) # non-greedy match excluding a nested {
if len(clean_text) < previous_length:
try_again = True
else:
@ -233,14 +262,14 @@ def _get_clean_wp_text(article_text):
previous_length = len(clean_text)
# remove HTML comments
clean_text = htlm_regex.sub('', clean_text)
clean_text = htlm_regex.sub("", clean_text)
# remove Category and File statements
clean_text = category_regex.sub('', clean_text)
clean_text = file_regex.sub('', clean_text)
clean_text = category_regex.sub("", clean_text)
clean_text = file_regex.sub("", clean_text)
# remove multiple =
while '==' in clean_text:
while "==" in clean_text:
clean_text = clean_text.replace("==", "=")
clean_text = clean_text.replace(". =", ".")
@ -249,43 +278,47 @@ def _get_clean_wp_text(article_text):
clean_text = clean_text.replace(" =", "")
# remove refs (non-greedy match)
clean_text = ref_regex.sub('', clean_text)
clean_text = ref_2_regex.sub('', clean_text)
clean_text = ref_regex.sub("", clean_text)
clean_text = ref_2_regex.sub("", clean_text)
# remove additional wikiformatting
clean_text = re.sub(r'&lt;blockquote&gt;', '', clean_text)
clean_text = re.sub(r'&lt;/blockquote&gt;', '', clean_text)
clean_text = re.sub(r"&lt;blockquote&gt;", "", clean_text)
clean_text = re.sub(r"&lt;/blockquote&gt;", "", clean_text)
# change special characters back to normal ones
clean_text = clean_text.replace(r'&lt;', '<')
clean_text = clean_text.replace(r'&gt;', '>')
clean_text = clean_text.replace(r'&quot;', '"')
clean_text = clean_text.replace(r'&amp;nbsp;', ' ')
clean_text = clean_text.replace(r'&amp;', '&')
clean_text = clean_text.replace(r"&lt;", "<")
clean_text = clean_text.replace(r"&gt;", ">")
clean_text = clean_text.replace(r"&quot;", '"')
clean_text = clean_text.replace(r"&amp;nbsp;", " ")
clean_text = clean_text.replace(r"&amp;", "&")
# remove multiple spaces
while ' ' in clean_text:
clean_text = clean_text.replace(' ', ' ')
while " " in clean_text:
clean_text = clean_text.replace(" ", " ")
return clean_text.strip()
def _write_training_article(article_id, clean_text, training_output):
file_loc = training_output / str(article_id) + ".txt"
with open(file_loc, mode='w', encoding='utf8') as outputfile:
file_loc = training_output / "{}.txt".format(article_id)
with file_loc.open("w", encoding="utf8") as outputfile:
outputfile.write(clean_text)
def _write_training_entity(outputfile, article_id, alias, entity, start, end):
outputfile.write(article_id + "|" + alias + "|" + entity + "|" + str(start) + "|" + str(end) + "\n")
line = "{}|{}|{}|{}|{}\n".format(article_id, alias, entity, start, end)
outputfile.write(line)
def is_dev(article_id):
return article_id.endswith("3")
def read_training(nlp, training_dir, dev, limit):
# This method provides training examples that correspond to the entity annotations found by the nlp object
def read_training(nlp, training_dir, dev, limit, kb=None):
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
When kb is provided (for training), it will include negative training examples by using the candidate generator,
and it will only keep positive training examples that can be found in the KB.
When kb=None (for testing), it will include all positive examples only."""
entityfile_loc = training_dir / ENTITY_FILE
data = []
@ -296,24 +329,30 @@ def read_training(nlp, training_dir, dev, limit):
skip_articles = set()
total_entities = 0
with open(entityfile_loc, mode='r', encoding='utf8') as file:
with entityfile_loc.open("r", encoding="utf8") as file:
for line in file:
if not limit or len(data) < limit:
fields = line.replace('\n', "").split(sep='|')
fields = line.replace("\n", "").split(sep="|")
article_id = fields[0]
alias = fields[1]
wp_title = fields[2]
wd_id = fields[2]
start = fields[3]
end = fields[4]
if dev == is_dev(article_id) and article_id != "article_id" and article_id not in skip_articles:
if (
dev == is_dev(article_id)
and article_id != "article_id"
and article_id not in skip_articles
):
if not current_doc or (current_article_id != article_id):
# parse the new article text
file_name = article_id + ".txt"
try:
with open(os.path.join(training_dir, file_name), mode="r", encoding='utf8') as f:
training_file = training_dir / file_name
with training_file.open("r", encoding="utf8") as f:
text = f.read()
if len(text) < 30000: # threshold for convenience / speed of processing
# threshold for convenience / speed of processing
if len(text) < 30000:
current_doc = nlp(text)
current_article_id = article_id
ents_by_offset = dict()
@ -321,28 +360,64 @@ def read_training(nlp, training_dir, dev, limit):
sent_length = len(ent.sent)
# custom filtering to avoid too long or too short sentences
if 5 < sent_length < 100:
ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
offset = "{}_{}".format(
ent.start_char, ent.end_char
)
ents_by_offset[offset] = ent
else:
skip_articles.add(article_id)
current_doc = None
except Exception as e:
print("Problem parsing article", article_id, e)
skip_articles.add(article_id)
raise e
# repeat checking this condition in case an exception was thrown
if current_doc and (current_article_id == article_id):
found_ent = ents_by_offset.get(start + "_" + end, None)
offset = "{}_{}".format(start, end)
found_ent = ents_by_offset.get(offset, None)
if found_ent:
if found_ent.text != alias:
skip_articles.add(article_id)
current_doc = None
else:
sent = found_ent.sent.as_doc()
# currently feeding the gold data one entity per sentence at a time
gold_start = int(start) - found_ent.sent.start_char
gold_end = int(end) - found_ent.sent.start_char
gold_entities = [(gold_start, gold_end, wp_title)]
gold_entities = {}
found_useful = False
for ent in sent.ents:
entry = (ent.start_char, ent.end_char)
gold_entry = (gold_start, gold_end)
if entry == gold_entry:
# add both pos and neg examples (in random order)
# this will exclude examples not in the KB
if kb:
value_by_id = {}
candidates = kb.get_candidates(alias)
candidate_ids = [
c.entity_ for c in candidates
]
random.shuffle(candidate_ids)
for kb_id in candidate_ids:
found_useful = True
if kb_id != wd_id:
value_by_id[kb_id] = 0.0
else:
value_by_id[kb_id] = 1.0
gold_entities[entry] = value_by_id
# if no KB, keep all positive examples
else:
found_useful = True
value_by_id = {wd_id: 1.0}
gold_entities[entry] = value_by_id
# currently feeding the gold data one entity per sentence at a time
# setting all other entities to empty gold dictionary
else:
gold_entities[entry] = {}
if found_useful:
gold = GoldParse(doc=sent, links=gold_entities)
data.append((sent, gold))
total_entities += 1

View File

@ -14,22 +14,97 @@ Write these results to file for downstream KB and training data generation.
map_alias_to_link = dict()
# these will/should be matched ignoring case
wiki_namespaces = ["b", "betawikiversity", "Book", "c", "Category", "Commons",
"d", "dbdump", "download", "Draft", "Education", "Foundation",
"Gadget", "Gadget definition", "gerrit", "File", "Help", "Image", "Incubator",
"m", "mail", "mailarchive", "media", "MediaWiki", "MediaWiki talk", "Mediawikiwiki",
"MediaZilla", "Meta", "Metawikipedia", "Module",
"mw", "n", "nost", "oldwikisource", "outreach", "outreachwiki", "otrs", "OTRSwiki",
"Portal", "phab", "Phabricator", "Project", "q", "quality", "rev",
"s", "spcom", "Special", "species", "Strategy", "sulutil", "svn",
"Talk", "Template", "Template talk", "Testwiki", "ticket", "TimedText", "Toollabs", "tools",
"tswiki", "User", "User talk", "v", "voy",
"w", "Wikibooks", "Wikidata", "wikiHow", "Wikinvest", "wikilivres", "Wikimedia", "Wikinews",
"Wikipedia", "Wikipedia talk", "Wikiquote", "Wikisource", "Wikispecies", "Wikitech",
"Wikiversity", "Wikivoyage", "wikt", "wiktionary", "wmf", "wmania", "WP"]
wiki_namespaces = [
"b",
"betawikiversity",
"Book",
"c",
"Category",
"Commons",
"d",
"dbdump",
"download",
"Draft",
"Education",
"Foundation",
"Gadget",
"Gadget definition",
"gerrit",
"File",
"Help",
"Image",
"Incubator",
"m",
"mail",
"mailarchive",
"media",
"MediaWiki",
"MediaWiki talk",
"Mediawikiwiki",
"MediaZilla",
"Meta",
"Metawikipedia",
"Module",
"mw",
"n",
"nost",
"oldwikisource",
"outreach",
"outreachwiki",
"otrs",
"OTRSwiki",
"Portal",
"phab",
"Phabricator",
"Project",
"q",
"quality",
"rev",
"s",
"spcom",
"Special",
"species",
"Strategy",
"sulutil",
"svn",
"Talk",
"Template",
"Template talk",
"Testwiki",
"ticket",
"TimedText",
"Toollabs",
"tools",
"tswiki",
"User",
"User talk",
"v",
"voy",
"w",
"Wikibooks",
"Wikidata",
"wikiHow",
"Wikinvest",
"wikilivres",
"Wikimedia",
"Wikinews",
"Wikipedia",
"Wikipedia talk",
"Wikiquote",
"Wikisource",
"Wikispecies",
"Wikitech",
"Wikiversity",
"Wikivoyage",
"wikt",
"wiktionary",
"wmf",
"wmania",
"WP",
]
# find the links
link_regex = re.compile(r'\[\[[^\[\]]*\]\]')
link_regex = re.compile(r"\[\[[^\[\]]*\]\]")
# match on interwiki links, e.g. `en:` or `:fr:`
ns_regex = r":?" + "[a-z][a-z]" + ":"
@ -41,18 +116,22 @@ for ns in wiki_namespaces:
ns_regex = re.compile(ns_regex, re.IGNORECASE)
def read_wikipedia_prior_probs(wikipedia_input, prior_prob_output):
def now():
return datetime.datetime.now()
def read_prior_probs(wikipedia_input, prior_prob_output):
"""
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
The full file takes about 2h to parse 1100M lines.
It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from.
"""
with bz2.open(wikipedia_input, mode='rb') as file:
with bz2.open(wikipedia_input, mode="rb") as file:
line = file.readline()
cnt = 0
while line:
if cnt % 5000000 == 0:
print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump")
print(now(), "processed", cnt, "lines of Wikipedia dump")
clean_line = line.strip().decode("utf-8")
aliases, entities, normalizations = get_wp_links(clean_line)
@ -64,10 +143,11 @@ def read_wikipedia_prior_probs(wikipedia_input, prior_prob_output):
cnt += 1
# write all aliases and their entities and count occurrences to file
with open(prior_prob_output, mode='w', encoding='utf8') as outputfile:
with prior_prob_output.open("w", encoding="utf8") as outputfile:
outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n")
for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]):
for entity, count in sorted(alias_dict.items(), key=lambda x: x[1], reverse=True):
s_dict = sorted(alias_dict.items(), key=lambda x: x[1], reverse=True)
for entity, count in s_dict:
outputfile.write(alias + "|" + str(count) + "|" + entity + "\n")
@ -140,13 +220,13 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
entity_to_count = dict()
total_count = 0
with open(prior_prob_input, mode='r', encoding='utf8') as prior_file:
with prior_prob_input.open("r", encoding="utf8") as prior_file:
# skip header
prior_file.readline()
line = prior_file.readline()
while line:
splits = line.replace('\n', "").split(sep='|')
splits = line.replace("\n", "").split(sep="|")
# alias = splits[0]
count = int(splits[1])
entity = splits[2]
@ -158,7 +238,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
line = prior_file.readline()
with open(count_output, mode='w', encoding='utf8') as entity_file:
with count_output.open("w", encoding="utf8") as entity_file:
entity_file.write("entity" + "|" + "count" + "\n")
for entity, count in entity_to_count.items():
entity_file.write(entity + "|" + str(count) + "\n")
@ -171,12 +251,11 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
def get_all_frequencies(count_input):
entity_to_count = dict()
with open(count_input, 'r', encoding='utf8') as csvfile:
csvreader = csv.reader(csvfile, delimiter='|')
with count_input.open("r", encoding="utf8") as csvfile:
csvreader = csv.reader(csvfile, delimiter="|")
# skip header
next(csvreader)
for row in csvreader:
entity_to_count[row[0]] = int(row[1])
return entity_to_count

View File

@ -14,15 +14,15 @@ def create_kb(vocab):
# adding entities
entity_0 = "Q1004791_Douglas"
print("adding entity", entity_0)
kb.add_entity(entity=entity_0, prob=0.5, entity_vector=[0])
kb.add_entity(entity=entity_0, freq=0.5, entity_vector=[0])
entity_1 = "Q42_Douglas_Adams"
print("adding entity", entity_1)
kb.add_entity(entity=entity_1, prob=0.5, entity_vector=[1])
kb.add_entity(entity=entity_1, freq=0.5, entity_vector=[1])
entity_2 = "Q5301561_Douglas_Haig"
print("adding entity", entity_2)
kb.add_entity(entity=entity_2, prob=0.5, entity_vector=[2])
kb.add_entity(entity=entity_2, freq=0.5, entity_vector=[2])
# adding aliases
print()

View File

@ -1,11 +1,14 @@
# coding: utf-8
from __future__ import unicode_literals
import os
from os import path
import random
import datetime
from pathlib import Path
from bin.wiki_entity_linking import training_set_creator, kb_creator, wikipedia_processor as wp
from bin.wiki_entity_linking import wikipedia_processor as wp
from bin.wiki_entity_linking import training_set_creator, kb_creator
from bin.wiki_entity_linking.kb_creator import DESC_WIDTH
import spacy
@ -17,23 +20,26 @@ Demonstrate how to build a knowledge base from WikiData and run an Entity Linkin
"""
ROOT_DIR = Path("C:/Users/Sofie/Documents/data/")
OUTPUT_DIR = ROOT_DIR / 'wikipedia'
TRAINING_DIR = OUTPUT_DIR / 'training_data_nel'
OUTPUT_DIR = ROOT_DIR / "wikipedia"
TRAINING_DIR = OUTPUT_DIR / "training_data_nel"
PRIOR_PROB = OUTPUT_DIR / 'prior_prob.csv'
ENTITY_COUNTS = OUTPUT_DIR / 'entity_freq.csv'
ENTITY_DEFS = OUTPUT_DIR / 'entity_defs.csv'
ENTITY_DESCR = OUTPUT_DIR / 'entity_descriptions.csv'
PRIOR_PROB = OUTPUT_DIR / "prior_prob.csv"
ENTITY_COUNTS = OUTPUT_DIR / "entity_freq.csv"
ENTITY_DEFS = OUTPUT_DIR / "entity_defs.csv"
ENTITY_DESCR = OUTPUT_DIR / "entity_descriptions.csv"
KB_FILE = OUTPUT_DIR / 'kb_1' / 'kb'
NLP_1_DIR = OUTPUT_DIR / 'nlp_1'
NLP_2_DIR = OUTPUT_DIR / 'nlp_2'
KB_DIR = OUTPUT_DIR / "kb_1"
KB_FILE = "kb"
NLP_1_DIR = OUTPUT_DIR / "nlp_1"
NLP_2_DIR = OUTPUT_DIR / "nlp_2"
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
WIKIDATA_JSON = ROOT_DIR / 'wikidata' / 'wikidata-20190304-all.json.bz2'
WIKIDATA_JSON = ROOT_DIR / "wikidata" / "wikidata-20190304-all.json.bz2"
# get enwiki-latest-pages-articles-multistream.xml.bz2 from https://dumps.wikimedia.org/enwiki/latest/
ENWIKI_DUMP = ROOT_DIR / 'wikipedia' / 'enwiki-20190320-pages-articles-multistream.xml.bz2'
ENWIKI_DUMP = (
ROOT_DIR / "wikipedia" / "enwiki-20190320-pages-articles-multistream.xml.bz2"
)
# KB construction parameters
MAX_CANDIDATES = 10
@ -48,11 +54,15 @@ L2 = 1e-6
CONTEXT_WIDTH = 128
def now():
return datetime.datetime.now()
def run_pipeline():
# set the appropriate booleans to define which parts of the pipeline should be re(run)
print("START", datetime.datetime.now())
print("START", now())
print()
nlp_1 = spacy.load('en_core_web_lg')
nlp_1 = spacy.load("en_core_web_lg")
nlp_2 = None
kb_2 = None
@ -82,20 +92,21 @@ def run_pipeline():
# STEP 1 : create prior probabilities from WP (run only once)
if to_create_prior_probs:
print("STEP 1: to_create_prior_probs", datetime.datetime.now())
wp.read_wikipedia_prior_probs(wikipedia_input=ENWIKI_DUMP, prior_prob_output=PRIOR_PROB)
print("STEP 1: to_create_prior_probs", now())
wp.read_prior_probs(ENWIKI_DUMP, PRIOR_PROB)
print()
# STEP 2 : deduce entity frequencies from WP (run only once)
if to_create_entity_counts:
print("STEP 2: to_create_entity_counts", datetime.datetime.now())
wp.write_entity_counts(prior_prob_input=PRIOR_PROB, count_output=ENTITY_COUNTS, to_print=False)
print("STEP 2: to_create_entity_counts", now())
wp.write_entity_counts(PRIOR_PROB, ENTITY_COUNTS, to_print=False)
print()
# STEP 3 : create KB and write to file (run only once)
if to_create_kb:
print("STEP 3a: to_create_kb", datetime.datetime.now())
kb_1 = kb_creator.create_kb(nlp_1,
print("STEP 3a: to_create_kb", now())
kb_1 = kb_creator.create_kb(
nlp=nlp_1,
max_entities_per_alias=MAX_CANDIDATES,
min_entity_freq=MIN_ENTITY_FREQ,
min_occ=MIN_PAIR_OCC,
@ -103,22 +114,26 @@ def run_pipeline():
entity_descr_output=ENTITY_DESCR,
count_input=ENTITY_COUNTS,
prior_prob_input=PRIOR_PROB,
wikidata_input=WIKIDATA_JSON)
wikidata_input=WIKIDATA_JSON,
)
print("kb entities:", kb_1.get_size_entities())
print("kb aliases:", kb_1.get_size_aliases())
print()
print("STEP 3b: write KB and NLP", datetime.datetime.now())
kb_1.dump(KB_FILE)
print("STEP 3b: write KB and NLP", now())
if not path.exists(KB_DIR):
os.makedirs(KB_DIR)
kb_1.dump(KB_DIR / KB_FILE)
nlp_1.to_disk(NLP_1_DIR)
print()
# STEP 4 : read KB back in from file
if to_read_kb:
print("STEP 4: to_read_kb", datetime.datetime.now())
print("STEP 4: to_read_kb", now())
nlp_2 = spacy.load(NLP_1_DIR)
kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH)
kb_2.load_bulk(KB_FILE)
kb_2.load_bulk(KB_DIR / KB_FILE)
print("kb entities:", kb_2.get_size_entities())
print("kb aliases:", kb_2.get_size_aliases())
print()
@ -130,20 +145,26 @@ def run_pipeline():
# STEP 5: create a training dataset from WP
if create_wp_training:
print("STEP 5: create training dataset", datetime.datetime.now())
training_set_creator.create_training(wikipedia_input=ENWIKI_DUMP,
print("STEP 5: create training dataset", now())
training_set_creator.create_training(
wikipedia_input=ENWIKI_DUMP,
entity_def_input=ENTITY_DEFS,
training_output=TRAINING_DIR)
training_output=TRAINING_DIR,
)
# STEP 6: create and train the entity linking pipe
if train_pipe:
print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
print("STEP 6: training Entity Linking pipe", now())
type_to_int = {label: i for i, label in enumerate(nlp_2.entity.labels)}
print(" -analysing", len(type_to_int), "different entity types")
el_pipe = nlp_2.create_pipe(name='entity_linker',
config={"context_width": CONTEXT_WIDTH,
el_pipe = nlp_2.create_pipe(
name="entity_linker",
config={
"context_width": CONTEXT_WIDTH,
"pretrained_vectors": nlp_2.vocab.vectors.name,
"type_to_int": type_to_int})
"type_to_int": type_to_int,
},
)
el_pipe.set_kb(kb_2)
nlp_2.add_pipe(el_pipe, last=True)
@ -157,18 +178,22 @@ def run_pipeline():
train_limit = 5000
dev_limit = 5000
train_data = training_set_creator.read_training(nlp=nlp_2,
# for training, get pos & neg instances that correspond to entries in the kb
train_data = training_set_creator.read_training(
nlp=nlp_2,
training_dir=TRAINING_DIR,
dev=False,
limit=train_limit)
limit=train_limit,
kb=el_pipe.kb,
)
print("Training on", len(train_data), "articles")
print()
dev_data = training_set_creator.read_training(nlp=nlp_2,
training_dir=TRAINING_DIR,
dev=True,
limit=dev_limit)
# for testing, get all pos instances, whether or not they are in the kb
dev_data = training_set_creator.read_training(
nlp=nlp_2, training_dir=TRAINING_DIR, dev=True, limit=dev_limit, kb=None
)
print("Dev testing on", len(dev_data), "articles")
print()
@ -187,8 +212,8 @@ def run_pipeline():
try:
docs, golds = zip(*batch)
nlp_2.update(
docs,
golds,
docs=docs,
golds=golds,
sgd=optimizer,
drop=DROPOUT,
losses=losses,
@ -200,48 +225,61 @@ def run_pipeline():
if batchnr > 0:
el_pipe.cfg["context_weight"] = 1
el_pipe.cfg["prior_weight"] = 1
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
losses['entity_linker'] = losses['entity_linker'] / batchnr
print("Epoch, train loss", itn, round(losses['entity_linker'], 2),
" / dev acc avg", round(dev_acc_context, 3))
dev_acc_context, _ = _measure_acc(dev_data, el_pipe)
losses["entity_linker"] = losses["entity_linker"] / batchnr
print(
"Epoch, train loss",
itn,
round(losses["entity_linker"], 2),
" / dev acc avg",
round(dev_acc_context, 3),
)
# STEP 7: measure the performance of our trained pipe on an independent dev set
if len(dev_data) and measure_performance:
print()
print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now())
print("STEP 7: performance measurement of Entity Linking pipe", now())
print()
counts, acc_r, acc_r_label, acc_p, acc_p_label, acc_o, acc_o_label = _measure_baselines(dev_data, kb_2)
counts, acc_r, acc_r_d, acc_p, acc_p_d, acc_o, acc_o_d = _measure_baselines(
dev_data, kb_2
)
print("dev counts:", sorted(counts.items(), key=lambda x: x[0]))
print("dev acc oracle:", round(acc_o, 3), [(x, round(y, 3)) for x, y in acc_o_label.items()])
print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_label.items()])
print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_label.items()])
oracle_by_label = [(x, round(y, 3)) for x, y in acc_o_d.items()]
print("dev acc oracle:", round(acc_o, 3), oracle_by_label)
random_by_label = [(x, round(y, 3)) for x, y in acc_r_d.items()]
print("dev acc random:", round(acc_r, 3), random_by_label)
prior_by_label = [(x, round(y, 3)) for x, y in acc_p_d.items()]
print("dev acc prior:", round(acc_p, 3), prior_by_label)
# using only context
el_pipe.cfg["context_weight"] = 1
el_pipe.cfg["prior_weight"] = 0
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
print("dev acc context avg:", round(dev_acc_context, 3),
[(x, round(y, 3)) for x, y in dev_acc_context_dict.items()])
dev_acc_context, dev_acc_cont_d = _measure_acc(dev_data, el_pipe)
context_by_label = [(x, round(y, 3)) for x, y in dev_acc_cont_d.items()]
print("dev acc context avg:", round(dev_acc_context, 3), context_by_label)
# measuring combined accuracy (prior + context)
el_pipe.cfg["context_weight"] = 1
el_pipe.cfg["prior_weight"] = 1
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe, error_analysis=False)
print("dev acc combo avg:", round(dev_acc_combo, 3),
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
dev_acc_combo, dev_acc_combo_d = _measure_acc(dev_data, el_pipe)
combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_d.items()]
print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label)
# STEP 8: apply the EL pipe on a toy example
if to_test_pipeline:
print()
print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now())
print("STEP 8: applying Entity Linking to toy example", now())
print()
run_el_toy_example(nlp=nlp_2)
# STEP 9: write the NLP pipeline (including entity linker) to file
if to_write_nlp:
print()
print("STEP 9: testing NLP IO", datetime.datetime.now())
print("STEP 9: testing NLP IO", now())
print()
print("writing to", NLP_2_DIR)
nlp_2.to_disk(NLP_2_DIR)
@ -262,23 +300,22 @@ def run_pipeline():
el_pipe = nlp_3.get_pipe("entity_linker")
dev_limit = 5000
dev_data = training_set_creator.read_training(nlp=nlp_2,
training_dir=TRAINING_DIR,
dev=True,
limit=dev_limit)
dev_data = training_set_creator.read_training(
nlp=nlp_2, training_dir=TRAINING_DIR, dev=True, limit=dev_limit, kb=None
)
print("Dev testing from file on", len(dev_data), "articles")
print()
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe=el_pipe, error_analysis=False)
print("dev acc combo avg:", round(dev_acc_combo, 3),
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
dev_acc_combo, dev_acc_combo_dict = _measure_acc(dev_data, el_pipe)
combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()]
print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label)
print()
print("STOP", datetime.datetime.now())
print("STOP", now())
def _measure_accuracy(data, el_pipe=None, error_analysis=False):
def _measure_acc(data, el_pipe=None, error_analysis=False):
# If the docs in the data require further processing with an entity linker, set el_pipe
correct_by_label = dict()
incorrect_by_label = dict()
@ -291,16 +328,21 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
for doc, gold in zip(docs, golds):
try:
correct_entries_per_article = dict()
for entity in gold.links:
start, end, gold_kb = entity
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
for entity, kb_dict in gold.links.items():
start, end = entity
# only evaluating on positive examples
for gold_kb, value in kb_dict.items():
if value:
offset = _offset(start, end)
correct_entries_per_article[offset] = gold_kb
for ent in doc.ents:
ent_label = ent.label_
pred_entity = ent.kb_id_
start = ent.start_char
end = ent.end_char
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
offset = _offset(start, end)
gold_entity = correct_entries_per_article.get(offset, None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None:
if gold_entity == pred_entity:
@ -311,7 +353,12 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
incorrect_by_label[ent_label] = incorrect + 1
if error_analysis:
print(ent.text, "in", doc)
print("Predicted", pred_entity, "should have been", gold_entity)
print(
"Predicted",
pred_entity,
"should have been",
gold_entity,
)
print()
except Exception as e:
@ -323,16 +370,16 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
def _measure_baselines(data, kb):
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
counts_by_label = dict()
counts_d = dict()
random_correct_by_label = dict()
random_incorrect_by_label = dict()
random_correct_d = dict()
random_incorrect_d = dict()
oracle_correct_by_label = dict()
oracle_incorrect_by_label = dict()
oracle_correct_d = dict()
oracle_incorrect_d = dict()
prior_correct_by_label = dict()
prior_incorrect_by_label = dict()
prior_correct_d = dict()
prior_incorrect_d = dict()
docs = [d for d, g in data if len(d) > 0]
golds = [g for d, g in data if len(d) > 0]
@ -340,19 +387,24 @@ def _measure_baselines(data, kb):
for doc, gold in zip(docs, golds):
try:
correct_entries_per_article = dict()
for entity in gold.links:
start, end, gold_kb = entity
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
for entity, kb_dict in gold.links.items():
start, end = entity
for gold_kb, value in kb_dict.items():
# only evaluating on positive examples
if value:
offset = _offset(start, end)
correct_entries_per_article[offset] = gold_kb
for ent in doc.ents:
ent_label = ent.label_
label = ent.label_
start = ent.start_char
end = ent.end_char
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
offset = _offset(start, end)
gold_entity = correct_entries_per_article.get(offset, None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None:
counts_by_label[ent_label] = counts_by_label.get(ent_label, 0) + 1
counts_d[label] = counts_d.get(label, 0) + 1
candidates = kb.get_candidates(ent.text)
oracle_candidate = ""
best_candidate = ""
@ -370,28 +422,40 @@ def _measure_baselines(data, kb):
random_candidate = random.choice(candidates).entity_
if gold_entity == best_candidate:
prior_correct_by_label[ent_label] = prior_correct_by_label.get(ent_label, 0) + 1
prior_correct_d[label] = prior_correct_d.get(label, 0) + 1
else:
prior_incorrect_by_label[ent_label] = prior_incorrect_by_label.get(ent_label, 0) + 1
prior_incorrect_d[label] = prior_incorrect_d.get(label, 0) + 1
if gold_entity == random_candidate:
random_correct_by_label[ent_label] = random_correct_by_label.get(ent_label, 0) + 1
random_correct_d[label] = random_correct_d.get(label, 0) + 1
else:
random_incorrect_by_label[ent_label] = random_incorrect_by_label.get(ent_label, 0) + 1
random_incorrect_d[label] = random_incorrect_d.get(label, 0) + 1
if gold_entity == oracle_candidate:
oracle_correct_by_label[ent_label] = oracle_correct_by_label.get(ent_label, 0) + 1
oracle_correct_d[label] = oracle_correct_d.get(label, 0) + 1
else:
oracle_incorrect_by_label[ent_label] = oracle_incorrect_by_label.get(ent_label, 0) + 1
oracle_incorrect_d[label] = oracle_incorrect_d.get(label, 0) + 1
except Exception as e:
print("Error assessing accuracy", e)
acc_prior, acc_prior_by_label = calculate_acc(prior_correct_by_label, prior_incorrect_by_label)
acc_rand, acc_rand_by_label = calculate_acc(random_correct_by_label, random_incorrect_by_label)
acc_oracle, acc_oracle_by_label = calculate_acc(oracle_correct_by_label, oracle_incorrect_by_label)
acc_prior, acc_prior_d = calculate_acc(prior_correct_d, prior_incorrect_d)
acc_rand, acc_rand_d = calculate_acc(random_correct_d, random_incorrect_d)
acc_oracle, acc_oracle_d = calculate_acc(oracle_correct_d, oracle_incorrect_d)
return counts_by_label, acc_rand, acc_rand_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label
return (
counts_d,
acc_rand,
acc_rand_d,
acc_prior,
acc_prior_d,
acc_oracle,
acc_oracle_d,
)
def _offset(start, end):
return "{}_{}".format(start, end)
def calculate_acc(correct_by_label, incorrect_by_label):
@ -422,15 +486,23 @@ def check_kb(kb):
print("generating candidates for " + mention + " :")
for c in candidates:
print(" ", c.prior_prob, c.alias_, "-->", c.entity_ + " (freq=" + str(c.entity_freq) + ")")
print(
" ",
c.prior_prob,
c.alias_,
"-->",
c.entity_ + " (freq=" + str(c.entity_freq) + ")",
)
print()
def run_el_toy_example(nlp):
text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \
"Douglas reminds us to always bring our towel, even in China or Brazil. " \
"The main character in Doug's novel is the man Arthur Dent, " \
"but Douglas doesn't write about George Washington or Homer Simpson."
text = (
"In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, "
"Douglas reminds us to always bring our towel, even in China or Brazil. "
"The main character in Doug's novel is the man Arthur Dent, "
"but Dougledydoug doesn't write about George Washington or Homer Simpson."
)
doc = nlp(text)
print(text)
for ent in doc.ents:

View File

@ -663,15 +663,14 @@ def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=False,
def build_nel_encoder(embed_width, hidden_width, ner_types, **cfg):
# TODO proper error
if "entity_width" not in cfg:
raise ValueError("entity_width not found")
raise ValueError(Errors.E144.format(param="entity_width"))
if "context_width" not in cfg:
raise ValueError("context_width not found")
raise ValueError(Errors.E144.format(param="context_width"))
conv_depth = cfg.get("conv_depth", 2)
cnn_maxout_pieces = cfg.get("cnn_maxout_pieces", 3)
pretrained_vectors = cfg.get("pretrained_vectors") # self.nlp.vocab.vectors.name
pretrained_vectors = cfg.get("pretrained_vectors", None)
context_width = cfg.get("context_width")
entity_width = cfg.get("entity_width")

View File

@ -406,6 +406,13 @@ class Errors(object):
E141 = ("Entity vectors should be of length {required} instead of the provided {found}.")
E142 = ("Unsupported loss_function '{loss_func}'. Use either 'L2' or 'cosine'")
E143 = ("Labels for component '{name}' not initialized. Did you forget to call add_label()?")
E144 = ("Could not find parameter `{param}` when building the entity linker model.")
E145 = ("Error reading `{param}` from input file.")
E146 = ("Could not access `{path}`.")
E147 = ("Unexpected error in the {method} functionality of the EntityLinker: {msg}. "
"This is likely a bug in spaCy, so feel free to open an issue.")
E148 = ("Expected {ents} KB identifiers but got {ids}. Make sure that each entity in `doc.ents` "
"is assigned to a KB identifier.")
@add_codes

View File

@ -31,7 +31,7 @@ cdef class GoldParse:
cdef public list ents
cdef public dict brackets
cdef public object cats
cdef public list links
cdef public dict links
cdef readonly list cand_to_gold
cdef readonly list gold_to_cand

View File

@ -468,8 +468,11 @@ cdef class GoldParse:
examples of a label to have the value 0.0. Labels not in the
dictionary are treated as missing - the gradient for those labels
will be zero.
links (iterable): A sequence of `(start_char, end_char, kb_id)` tuples,
representing the external ID of an entity in a knowledge base.
links (dict): A dict with `(start_char, end_char)` keys,
and the values being dicts with kb_id:value entries,
representing the external IDs in a knowledge base (KB)
mapped to either 1.0 or 0.0, indicating positive and
negative examples respectively.
RETURNS (GoldParse): The newly constructed object.
"""
if words is None:

View File

@ -79,7 +79,7 @@ cdef class KnowledgeBase:
return new_index
cdef inline int64_t c_add_entity(self, hash_t entity_hash, float prob,
cdef inline int64_t c_add_entity(self, hash_t entity_hash, float freq,
int32_t vector_index, int feats_row) nogil:
"""Add an entry to the vector of entries.
After calling this method, make sure to update also the _entry_index using the return value"""
@ -92,7 +92,7 @@ cdef class KnowledgeBase:
entry.entity_hash = entity_hash
entry.vector_index = vector_index
entry.feats_row = feats_row
entry.prob = prob
entry.freq = freq
self._entries.push_back(entry)
return new_index
@ -125,7 +125,7 @@ cdef class KnowledgeBase:
entry.entity_hash = dummy_hash
entry.vector_index = dummy_value
entry.feats_row = dummy_value
entry.prob = dummy_value
entry.freq = dummy_value
# Avoid struct initializer to enable nogil
cdef vector[int64_t] dummy_entry_indices
@ -141,7 +141,7 @@ cdef class KnowledgeBase:
self._aliases_table.push_back(alias)
cpdef load_bulk(self, loc)
cpdef set_entities(self, entity_list, prob_list, vector_list)
cpdef set_entities(self, entity_list, freq_list, vector_list)
cdef class Writer:
@ -149,7 +149,7 @@ cdef class Writer:
cdef int write_header(self, int64_t nr_entries, int64_t entity_vector_length) except -1
cdef int write_vector_element(self, float element) except -1
cdef int write_entry(self, hash_t entry_hash, float entry_prob, int32_t vector_index) except -1
cdef int write_entry(self, hash_t entry_hash, float entry_freq, int32_t vector_index) except -1
cdef int write_alias_length(self, int64_t alias_length) except -1
cdef int write_alias_header(self, hash_t alias_hash, int64_t candidate_length) except -1
@ -162,7 +162,7 @@ cdef class Reader:
cdef int read_header(self, int64_t* nr_entries, int64_t* entity_vector_length) except -1
cdef int read_vector_element(self, float* element) except -1
cdef int read_entry(self, hash_t* entity_hash, float* prob, int32_t* vector_index) except -1
cdef int read_entry(self, hash_t* entity_hash, float* freq, int32_t* vector_index) except -1
cdef int read_alias_length(self, int64_t* alias_length) except -1
cdef int read_alias_header(self, hash_t* alias_hash, int64_t* candidate_length) except -1

View File

@ -94,7 +94,7 @@ cdef class KnowledgeBase:
def get_alias_strings(self):
return [self.vocab.strings[x] for x in self._alias_index]
def add_entity(self, unicode entity, float prob, vector[float] entity_vector):
def add_entity(self, unicode entity, float freq, vector[float] entity_vector):
"""
Add an entity to the KB, optionally specifying its log probability based on corpus frequency
Return the hash of the entity ID/name at the end.
@ -113,15 +113,15 @@ cdef class KnowledgeBase:
vector_index = self.c_add_vector(entity_vector=entity_vector)
new_index = self.c_add_entity(entity_hash=entity_hash,
prob=prob,
freq=freq,
vector_index=vector_index,
feats_row=-1) # Features table currently not implemented
self._entry_index[entity_hash] = new_index
return entity_hash
cpdef set_entities(self, entity_list, prob_list, vector_list):
if len(entity_list) != len(prob_list) or len(entity_list) != len(vector_list):
cpdef set_entities(self, entity_list, freq_list, vector_list):
if len(entity_list) != len(freq_list) or len(entity_list) != len(vector_list):
raise ValueError(Errors.E140)
nr_entities = len(entity_list)
@ -137,7 +137,7 @@ cdef class KnowledgeBase:
entity_hash = self.vocab.strings.add(entity_list[i])
entry.entity_hash = entity_hash
entry.prob = prob_list[i]
entry.freq = freq_list[i]
vector_index = self.c_add_vector(entity_vector=vector_list[i])
entry.vector_index = vector_index
@ -196,13 +196,42 @@ cdef class KnowledgeBase:
return [Candidate(kb=self,
entity_hash=self._entries[entry_index].entity_hash,
entity_freq=self._entries[entry_index].prob,
entity_freq=self._entries[entry_index].freq,
entity_vector=self._vectors_table[self._entries[entry_index].vector_index],
alias_hash=alias_hash,
prior_prob=prob)
for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs)
prior_prob=prior_prob)
for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs)
if entry_index != 0]
def get_vector(self, unicode entity):
cdef hash_t entity_hash = self.vocab.strings[entity]
# Return an empty list if this entity is unknown in this KB
if entity_hash not in self._entry_index:
return [0] * self.entity_vector_length
entry_index = self._entry_index[entity_hash]
return self._vectors_table[self._entries[entry_index].vector_index]
def get_prior_prob(self, unicode entity, unicode alias):
""" Return the prior probability of a given alias being linked to a given entity,
or return 0.0 when this combination is not known in the knowledge base"""
cdef hash_t alias_hash = self.vocab.strings[alias]
cdef hash_t entity_hash = self.vocab.strings[entity]
if entity_hash not in self._entry_index or alias_hash not in self._alias_index:
return 0.0
alias_index = <int64_t>self._alias_index.get(alias_hash)
entry_index = self._entry_index[entity_hash]
alias_entry = self._aliases_table[alias_index]
for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs):
if self._entries[entry_index].entity_hash == entity_hash:
return prior_prob
return 0.0
def dump(self, loc):
cdef Writer writer = Writer(loc)
@ -222,7 +251,7 @@ cdef class KnowledgeBase:
entry = self._entries[entry_index]
assert entry.entity_hash == entry_hash
assert entry_index == i
writer.write_entry(entry.entity_hash, entry.prob, entry.vector_index)
writer.write_entry(entry.entity_hash, entry.freq, entry.vector_index)
i = i+1
writer.write_alias_length(self.get_size_aliases())
@ -248,7 +277,7 @@ cdef class KnowledgeBase:
cdef hash_t entity_hash
cdef hash_t alias_hash
cdef int64_t entry_index
cdef float prob
cdef float freq, prob
cdef int32_t vector_index
cdef KBEntryC entry
cdef AliasC alias
@ -284,10 +313,10 @@ cdef class KnowledgeBase:
# index 0 is a dummy object not stored in the _entry_index and can be ignored.
i = 1
while i <= nr_entities:
reader.read_entry(&entity_hash, &prob, &vector_index)
reader.read_entry(&entity_hash, &freq, &vector_index)
entry.entity_hash = entity_hash
entry.prob = prob
entry.freq = freq
entry.vector_index = vector_index
entry.feats_row = -1 # Features table currently not implemented
@ -343,7 +372,8 @@ cdef class Writer:
loc = bytes(loc)
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
self._fp = fopen(<char*>bytes_loc, 'wb')
assert self._fp != NULL
if not self._fp:
raise IOError(Errors.E146.format(path=loc))
fseek(self._fp, 0, 0)
def close(self):
@ -357,9 +387,9 @@ cdef class Writer:
cdef int write_vector_element(self, float element) except -1:
self._write(&element, sizeof(element))
cdef int write_entry(self, hash_t entry_hash, float entry_prob, int32_t vector_index) except -1:
cdef int write_entry(self, hash_t entry_hash, float entry_freq, int32_t vector_index) except -1:
self._write(&entry_hash, sizeof(entry_hash))
self._write(&entry_prob, sizeof(entry_prob))
self._write(&entry_freq, sizeof(entry_freq))
self._write(&vector_index, sizeof(vector_index))
# Features table currently not implemented and not written to file
@ -399,39 +429,39 @@ cdef class Reader:
if status < 1:
if feof(self._fp):
return 0 # end of file
raise IOError("error reading header from input file")
raise IOError(Errors.E145.format(param="header"))
status = self._read(entity_vector_length, sizeof(int64_t))
if status < 1:
if feof(self._fp):
return 0 # end of file
raise IOError("error reading header from input file")
raise IOError(Errors.E145.format(param="vector length"))
cdef int read_vector_element(self, float* element) except -1:
status = self._read(element, sizeof(float))
if status < 1:
if feof(self._fp):
return 0 # end of file
raise IOError("error reading entity vector from input file")
raise IOError(Errors.E145.format(param="vector element"))
cdef int read_entry(self, hash_t* entity_hash, float* prob, int32_t* vector_index) except -1:
cdef int read_entry(self, hash_t* entity_hash, float* freq, int32_t* vector_index) except -1:
status = self._read(entity_hash, sizeof(hash_t))
if status < 1:
if feof(self._fp):
return 0 # end of file
raise IOError("error reading entity hash from input file")
raise IOError(Errors.E145.format(param="entity hash"))
status = self._read(prob, sizeof(float))
status = self._read(freq, sizeof(float))
if status < 1:
if feof(self._fp):
return 0 # end of file
raise IOError("error reading entity prob from input file")
raise IOError(Errors.E145.format(param="entity freq"))
status = self._read(vector_index, sizeof(int32_t))
if status < 1:
if feof(self._fp):
return 0 # end of file
raise IOError("error reading entity vector from input file")
raise IOError(Errors.E145.format(param="vector index"))
if feof(self._fp):
return 0
@ -443,33 +473,33 @@ cdef class Reader:
if status < 1:
if feof(self._fp):
return 0 # end of file
raise IOError("error reading alias length from input file")
raise IOError(Errors.E145.format(param="alias length"))
cdef int read_alias_header(self, hash_t* alias_hash, int64_t* candidate_length) except -1:
status = self._read(alias_hash, sizeof(hash_t))
if status < 1:
if feof(self._fp):
return 0 # end of file
raise IOError("error reading alias hash from input file")
raise IOError(Errors.E145.format(param="alias hash"))
status = self._read(candidate_length, sizeof(int64_t))
if status < 1:
if feof(self._fp):
return 0 # end of file
raise IOError("error reading candidate length from input file")
raise IOError(Errors.E145.format(param="candidate length"))
cdef int read_alias(self, int64_t* entry_index, float* prob) except -1:
status = self._read(entry_index, sizeof(int64_t))
if status < 1:
if feof(self._fp):
return 0 # end of file
raise IOError("error reading entry index for alias from input file")
raise IOError(Errors.E145.format(param="entry index"))
status = self._read(prob, sizeof(float))
if status < 1:
if feof(self._fp):
return 0 # end of file
raise IOError("error reading prob for entity/alias from input file")
raise IOError(Errors.E145.format(param="prior probability"))
cdef int _read(self, void* value, size_t size) except -1:
status = fread(value, size, 1, self._fp)

View File

@ -14,7 +14,6 @@ from thinc.neural.util import to_categorical
from thinc.neural.util import get_array_module
from spacy.kb import KnowledgeBase
from ..cli.pretrain import get_cossim_loss
from .functions import merge_subtokens
from ..tokens.doc cimport Doc
from ..syntax.nn_parser cimport Parser
@ -1077,6 +1076,7 @@ class EntityLinker(Pipe):
DOCS: TODO
"""
name = 'entity_linker'
NIL = "NIL" # string used to refer to a non-existing link
@classmethod
def Model(cls, **cfg):
@ -1093,6 +1093,8 @@ class EntityLinker(Pipe):
self.kb = None
self.cfg = dict(cfg)
self.sgd_context = None
if not self.cfg.get("context_width"):
self.cfg["context_width"] = 128
def set_kb(self, kb):
self.kb = kb
@ -1140,7 +1142,7 @@ class EntityLinker(Pipe):
context_docs = []
entity_encodings = []
cats = []
priors = []
type_vectors = []
@ -1149,50 +1151,44 @@ class EntityLinker(Pipe):
for doc, gold in zip(docs, golds):
ents_by_offset = dict()
for ent in doc.ents:
ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
for entity in gold.links:
start, end, gold_kb = entity
ents_by_offset["{}_{}".format(ent.start_char, ent.end_char)] = ent
for entity, kb_dict in gold.links.items():
start, end = entity
mention = doc.text[start:end]
for kb_id, value in kb_dict.items():
entity_encoding = self.kb.get_vector(kb_id)
prior_prob = self.kb.get_prior_prob(kb_id, mention)
gold_ent = ents_by_offset["{}_{}".format(start, end)]
if gold_ent is None:
raise RuntimeError(Errors.E147.format(method="update", msg="gold entity not found"))
gold_ent = ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)]
assert gold_ent is not None
type_vector = [0 for i in range(len(type_to_int))]
if len(type_to_int) > 0:
type_vector[type_to_int[gold_ent.label_]] = 1
candidates = self.kb.get_candidates(mention)
random.shuffle(candidates)
nr_neg = 0
for c in candidates:
kb_id = c.entity_
entity_encoding = c.entity_vector
# store data
entity_encodings.append(entity_encoding)
context_docs.append(doc)
type_vectors.append(type_vector)
if self.cfg.get("prior_weight", 1) > 0:
priors.append([c.prior_prob])
priors.append([prior_prob])
else:
priors.append([0])
if kb_id == gold_kb:
cats.append([1])
else:
nr_neg += 1
cats.append([0])
if len(entity_encodings) > 0:
assert len(priors) == len(entity_encodings) == len(context_docs) == len(cats) == len(type_vectors)
if not (len(priors) == len(entity_encodings) == len(context_docs) == len(type_vectors)):
raise RuntimeError(Errors.E147.format(method="update", msg="vector lengths not equal"))
context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop)
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop)
mention_encodings = [list(context_encodings[i]) + list(entity_encodings[i]) + priors[i] + type_vectors[i]
for i in range(len(entity_encodings))]
pred, bp_mention = self.model.begin_update(self.model.ops.asarray(mention_encodings, dtype="float32"), drop=drop)
cats = self.model.ops.asarray(cats, dtype="float32")
loss, d_scores = self.get_loss(prediction=pred, golds=cats, docs=None)
loss, d_scores = self.get_loss(scores=pred, golds=golds, docs=docs)
mention_gradient = bp_mention(d_scores, sgd=sgd)
context_gradients = [list(x[0:self.cfg.get("context_width")]) for x in mention_gradient]
@ -1203,39 +1199,45 @@ class EntityLinker(Pipe):
return loss
return 0
def get_loss(self, docs, golds, prediction):
d_scores = (prediction - golds)
def get_loss(self, docs, golds, scores):
cats = []
for gold in golds:
for entity, kb_dict in gold.links.items():
for kb_id, value in kb_dict.items():
cats.append([value])
cats = self.model.ops.asarray(cats, dtype="float32")
if len(scores) != len(cats):
raise RuntimeError(Errors.E147.format(method="get_loss", msg="gold entities do not match up"))
d_scores = (scores - cats)
loss = (d_scores ** 2).sum()
loss = loss / len(golds)
loss = loss / len(cats)
return loss, d_scores
def get_loss_old(self, docs, golds, scores):
# this loss function assumes we're only using positive examples
loss, gradients = get_cossim_loss(yh=scores, y=golds)
loss = loss / len(golds)
return loss, gradients
def __call__(self, doc):
entities, kb_ids = self.predict([doc])
self.set_annotations([doc], entities, kb_ids)
kb_ids, tensors = self.predict([doc])
self.set_annotations([doc], kb_ids, tensors=tensors)
return doc
def pipe(self, stream, batch_size=128, n_threads=-1):
for docs in util.minibatch(stream, size=batch_size):
docs = list(docs)
entities, kb_ids = self.predict(docs)
self.set_annotations(docs, entities, kb_ids)
kb_ids, tensors = self.predict(docs)
self.set_annotations(docs, kb_ids, tensors=tensors)
yield from docs
def predict(self, docs):
""" Return the KB IDs for each entity in each doc, including NIL if there is no prediction """
self.require_model()
self.require_kb()
final_entities = []
entity_count = 0
final_kb_ids = []
final_tensors = []
if not docs:
return final_entities, final_kb_ids
return final_kb_ids, final_tensors
if isinstance(docs, Doc):
docs = [docs]
@ -1247,14 +1249,19 @@ class EntityLinker(Pipe):
for i, doc in enumerate(docs):
if len(doc) > 0:
# currently, the context is the same for each entity in a sentence (should be refined)
context_encoding = context_encodings[i]
for ent in doc.ents:
entity_count += 1
type_vector = [0 for i in range(len(type_to_int))]
if len(type_to_int) > 0:
type_vector[type_to_int[ent.label_]] = 1
candidates = self.kb.get_candidates(ent.text)
if candidates:
if not candidates:
final_kb_ids.append(self.NIL) # no prediction possible for this entity
final_tensors.append(context_encoding)
else:
random.shuffle(candidates)
# this will set the prior probabilities to 0 (just like in training) if their weight is 0
@ -1264,7 +1271,9 @@ class EntityLinker(Pipe):
if self.cfg.get("context_weight", 1) > 0:
entity_encodings = xp.asarray([c.entity_vector for c in candidates])
assert len(entity_encodings) == len(prior_probs)
if len(entity_encodings) != len(prior_probs):
raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length"))
mention_encodings = [list(context_encoding) + list(entity_encodings[i])
+ list(prior_probs[i]) + type_vector
for i in range(len(entity_encodings))]
@ -1273,14 +1282,25 @@ class EntityLinker(Pipe):
# TODO: thresholding
best_index = scores.argmax()
best_candidate = candidates[best_index]
final_entities.append(ent)
final_kb_ids.append(best_candidate.entity_)
final_tensors.append(context_encoding)
return final_entities, final_kb_ids
if not (len(final_tensors) == len(final_kb_ids) == entity_count):
raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length"))
def set_annotations(self, docs, entities, kb_ids=None):
for entity, kb_id in zip(entities, kb_ids):
for token in entity:
return final_kb_ids, final_tensors
def set_annotations(self, docs, kb_ids, tensors=None):
count_ents = len([ent for doc in docs for ent in doc.ents])
if count_ents != len(kb_ids):
raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))
i=0
for doc in docs:
for ent in doc.ents:
kb_id = kb_ids[i]
i += 1
for token in ent:
token.ent_kb_id_ = kb_id
def to_disk(self, path, exclude=tuple(), **kwargs):

View File

@ -93,7 +93,7 @@ cdef struct KBEntryC:
int32_t feats_row
# log probability of entity, based on corpus frequency
float prob
float freq
# Each alias struct stores a list of Entry pointers with their prior probabilities

View File

@ -13,22 +13,38 @@ def nlp():
return English()
def assert_almost_equal(a, b):
delta = 0.0001
assert a - delta <= b <= a + delta
def test_kb_valid_entities(nlp):
"""Test the valid construction of a KB with 3 entities and two aliases"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
# adding entities
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
mykb.add_entity(entity='Q2', prob=0.5, entity_vector=[2])
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[8, 4, 3])
mykb.add_entity(entity="Q2", freq=0.5, entity_vector=[2, 1, 0])
mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[-1, -6, 5])
# adding aliases
mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.2])
mykb.add_alias(alias='adam', entities=['Q2'], probabilities=[0.9])
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.2])
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
# test the size of the corresponding KB
assert(mykb.get_size_entities() == 3)
assert(mykb.get_size_aliases() == 2)
assert mykb.get_size_entities() == 3
assert mykb.get_size_aliases() == 2
# test retrieval of the entity vectors
assert mykb.get_vector("Q1") == [8, 4, 3]
assert mykb.get_vector("Q2") == [2, 1, 0]
assert mykb.get_vector("Q3") == [-1, -6, 5]
# test retrieval of prior probabilities
assert_almost_equal(mykb.get_prior_prob(entity="Q2", alias="douglas"), 0.8)
assert_almost_equal(mykb.get_prior_prob(entity="Q3", alias="douglas"), 0.2)
assert_almost_equal(mykb.get_prior_prob(entity="Q342", alias="douglas"), 0.0)
assert_almost_equal(mykb.get_prior_prob(entity="Q3", alias="douglassssss"), 0.0)
def test_kb_invalid_entities(nlp):
@ -36,13 +52,15 @@ def test_kb_invalid_entities(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1])
mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3])
# adding aliases - should fail because one of the given IDs is not valid
with pytest.raises(ValueError):
mykb.add_alias(alias='douglas', entities=['Q2', 'Q342'], probabilities=[0.8, 0.2])
mykb.add_alias(
alias="douglas", entities=["Q2", "Q342"], probabilities=[0.8, 0.2]
)
def test_kb_invalid_probabilities(nlp):
@ -50,13 +68,13 @@ def test_kb_invalid_probabilities(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1])
mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3])
# adding aliases - should fail because the sum of the probabilities exceeds 1
with pytest.raises(ValueError):
mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.4])
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.4])
def test_kb_invalid_combination(nlp):
@ -64,13 +82,15 @@ def test_kb_invalid_combination(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1])
mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3])
# adding aliases - should fail because the entities and probabilities vectors are not of equal length
with pytest.raises(ValueError):
mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.3, 0.4, 0.1])
mykb.add_alias(
alias="douglas", entities=["Q2", "Q3"], probabilities=[0.3, 0.4, 0.1]
)
def test_kb_invalid_entity_vector(nlp):
@ -78,11 +98,11 @@ def test_kb_invalid_entity_vector(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
# adding entities
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1, 2, 3])
mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1, 2, 3])
# this should fail because the kb's expected entity vector length is 3
with pytest.raises(ValueError):
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
def test_candidate_generation(nlp):
@ -90,18 +110,24 @@ def test_candidate_generation(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
mykb.add_entity(entity="Q1", freq=0.7, entity_vector=[1])
mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3])
# adding aliases
mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.2])
mykb.add_alias(alias='adam', entities=['Q2'], probabilities=[0.9])
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1])
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
# test the size of the relevant candidates
assert(len(mykb.get_candidates('douglas')) == 2)
assert(len(mykb.get_candidates('adam')) == 1)
assert(len(mykb.get_candidates('shrubbery')) == 0)
assert len(mykb.get_candidates("douglas")) == 2
assert len(mykb.get_candidates("adam")) == 1
assert len(mykb.get_candidates("shrubbery")) == 0
# test the content of the candidates
assert mykb.get_candidates("adam")[0].entity_ == "Q2"
assert mykb.get_candidates("adam")[0].alias_ == "adam"
assert_almost_equal(mykb.get_candidates("adam")[0].entity_freq, 0.2)
assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9)
def test_preserving_links_asdoc(nlp):
@ -109,24 +135,26 @@ def test_preserving_links_asdoc(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
mykb.add_entity(entity='Q2', prob=0.8, entity_vector=[1])
mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1])
mykb.add_entity(entity="Q2", freq=0.8, entity_vector=[1])
# adding aliases
mykb.add_alias(alias='Boston', entities=['Q1'], probabilities=[0.7])
mykb.add_alias(alias='Denver', entities=['Q2'], probabilities=[0.6])
mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7])
mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6])
# set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained)
sentencizer = nlp.create_pipe("sentencizer")
nlp.add_pipe(sentencizer)
ruler = EntityRuler(nlp)
patterns = [{"label": "GPE", "pattern": "Boston"},
{"label": "GPE", "pattern": "Denver"}]
patterns = [
{"label": "GPE", "pattern": "Boston"},
{"label": "GPE", "pattern": "Denver"},
]
ruler.add_patterns(patterns)
nlp.add_pipe(ruler)
el_pipe = nlp.create_pipe(name='entity_linker', config={"context_width": 64})
el_pipe = nlp.create_pipe(name="entity_linker", config={"context_width": 64})
el_pipe.set_kb(mykb)
el_pipe.begin_training()
el_pipe.context_weight = 0

View File

@ -30,10 +30,10 @@ def test_serialize_kb_disk(en_vocab):
def _get_dummy_kb(vocab):
kb = KnowledgeBase(vocab=vocab, entity_vector_length=3)
kb.add_entity(entity='Q53', prob=0.33, entity_vector=[0, 5, 3])
kb.add_entity(entity='Q17', prob=0.2, entity_vector=[7, 1, 0])
kb.add_entity(entity='Q007', prob=0.7, entity_vector=[0, 0, 7])
kb.add_entity(entity='Q44', prob=0.4, entity_vector=[4, 4, 4])
kb.add_entity(entity='Q53', freq=0.33, entity_vector=[0, 5, 3])
kb.add_entity(entity='Q17', freq=0.2, entity_vector=[7, 1, 0])
kb.add_entity(entity='Q007', freq=0.7, entity_vector=[0, 0, 7])
kb.add_entity(entity='Q44', freq=0.4, entity_vector=[4, 4, 4])
kb.add_alias(alias='double07', entities=['Q17', 'Q007'], probabilities=[0.1, 0.9])
kb.add_alias(alias='guy', entities=['Q53', 'Q007', 'Q17', 'Q44'], probabilities=[0.3, 0.3, 0.2, 0.1])

View File

@ -348,7 +348,7 @@ cdef class Tokenizer:
"""Add a special-case tokenization rule.
string (unicode): The string to specially tokenize.
token_attrs (iterable): A sequence of dicts, where each dict describes
substrings (iterable): A sequence of dicts, where each dict describes
a token and its attributes. The `ORTH` fields of the attributes
must exactly match the string when they are concatenated.