Merge branch 'master' into spacy.io

Ines Montani 2019-07-25 12:14:18 +02:00
commit 4361da2bba
23 changed files with 922 additions and 389 deletions

View File

@ -13,9 +13,17 @@ INPUT_DIM = 300 # dimension of pre-trained input vectors
DESC_WIDTH = 64 # dimension of output entity vectors DESC_WIDTH = 64 # dimension of output entity vectors
def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
              entity_def_output, entity_descr_output,
              count_input, prior_prob_input, wikidata_input):
def create_kb(
    nlp,
    max_entities_per_alias,
    min_entity_freq,
    min_occ,
    entity_def_output,
    entity_descr_output,
    count_input,
    prior_prob_input,
    wikidata_input,
):
# Create the knowledge base from Wikidata entries # Create the knowledge base from Wikidata entries
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH) kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH)
@ -28,7 +36,9 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
title_to_id, id_to_descr = wd.read_wikidata_entities_json(wikidata_input) title_to_id, id_to_descr = wd.read_wikidata_entities_json(wikidata_input)
# write the title-ID and ID-description mappings to file # write the title-ID and ID-description mappings to file
_write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr) _write_entity_files(
entity_def_output, entity_descr_output, title_to_id, id_to_descr
)
else: else:
# read the mappings from file # read the mappings from file
@ -54,8 +64,8 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
frequency_list.append(freq) frequency_list.append(freq)
filtered_title_to_id[title] = entity filtered_title_to_id[title] = entity
print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()),
      "titles with filter frequency", min_entity_freq)
print(len(title_to_id.keys()), "original titles")
print("kept", len(filtered_title_to_id.keys()), " with frequency", min_entity_freq)
print() print()
print(" * train entity encoder", datetime.datetime.now()) print(" * train entity encoder", datetime.datetime.now())
@ -70,14 +80,20 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
print() print()
print(" * adding", len(entity_list), "entities", datetime.datetime.now()) print(" * adding", len(entity_list), "entities", datetime.datetime.now())
kb.set_entities(entity_list=entity_list, prob_list=frequency_list, vector_list=embeddings)
kb.set_entities(
    entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
)
print() print()
print(" * adding aliases", datetime.datetime.now()) print(" * adding aliases", datetime.datetime.now())
print() print()
_add_aliases(kb, title_to_id=filtered_title_to_id, _add_aliases(
max_entities_per_alias=max_entities_per_alias, min_occ=min_occ, kb,
prior_prob_input=prior_prob_input) title_to_id=filtered_title_to_id,
max_entities_per_alias=max_entities_per_alias,
min_occ=min_occ,
prior_prob_input=prior_prob_input,
)
print() print()
print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases()) print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases())
@ -86,13 +102,15 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
return kb return kb
def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr):
    with open(entity_def_output, mode='w', encoding='utf8') as id_file:
def _write_entity_files(
    entity_def_output, entity_descr_output, title_to_id, id_to_descr
):
    with entity_def_output.open("w", encoding="utf8") as id_file:
id_file.write("WP_title" + "|" + "WD_id" + "\n") id_file.write("WP_title" + "|" + "WD_id" + "\n")
for title, qid in title_to_id.items(): for title, qid in title_to_id.items():
id_file.write(title + "|" + str(qid) + "\n") id_file.write(title + "|" + str(qid) + "\n")
with open(entity_descr_output, mode='w', encoding='utf8') as descr_file: with entity_descr_output.open("w", encoding="utf8") as descr_file:
descr_file.write("WD_id" + "|" + "description" + "\n") descr_file.write("WD_id" + "|" + "description" + "\n")
for qid, descr in id_to_descr.items(): for qid, descr in id_to_descr.items():
descr_file.write(str(qid) + "|" + descr + "\n") descr_file.write(str(qid) + "|" + descr + "\n")
@ -100,8 +118,8 @@ def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_
def get_entity_to_id(entity_def_output): def get_entity_to_id(entity_def_output):
entity_to_id = dict() entity_to_id = dict()
with open(entity_def_output, 'r', encoding='utf8') as csvfile: with entity_def_output.open("r", encoding="utf8") as csvfile:
csvreader = csv.reader(csvfile, delimiter='|') csvreader = csv.reader(csvfile, delimiter="|")
# skip header # skip header
next(csvreader) next(csvreader)
for row in csvreader: for row in csvreader:
@ -111,8 +129,8 @@ def get_entity_to_id(entity_def_output):
def get_id_to_description(entity_descr_output): def get_id_to_description(entity_descr_output):
id_to_desc = dict() id_to_desc = dict()
with open(entity_descr_output, 'r', encoding='utf8') as csvfile: with entity_descr_output.open("r", encoding="utf8") as csvfile:
csvreader = csv.reader(csvfile, delimiter='|') csvreader = csv.reader(csvfile, delimiter="|")
# skip header # skip header
next(csvreader) next(csvreader)
for row in csvreader: for row in csvreader:
@ -125,7 +143,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
# adding aliases with prior probabilities # adding aliases with prior probabilities
# we can read this file sequentially, it's sorted by alias, and then by count # we can read this file sequentially, it's sorted by alias, and then by count
with open(prior_prob_input, mode='r', encoding='utf8') as prior_file: with prior_prob_input.open("r", encoding="utf8") as prior_file:
# skip header # skip header
prior_file.readline() prior_file.readline()
line = prior_file.readline() line = prior_file.readline()
@ -134,7 +152,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
counts = [] counts = []
entities = [] entities = []
while line: while line:
splits = line.replace('\n', "").split(sep='|') splits = line.replace("\n", "").split(sep="|")
new_alias = splits[0] new_alias = splits[0]
count = int(splits[1]) count = int(splits[1])
entity = splits[2] entity = splits[2]
@ -153,7 +171,11 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
if selected_entities: if selected_entities:
try: try:
kb.add_alias(alias=previous_alias, entities=selected_entities, probabilities=prior_probs) kb.add_alias(
alias=previous_alias,
entities=selected_entities,
probabilities=prior_probs,
)
except ValueError as e: except ValueError as e:
print(e) print(e)
total_count = 0 total_count = 0
@ -168,4 +190,3 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
previous_alias = new_alias previous_alias = new_alias
line = prior_file.readline() line = prior_file.readline()
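The prob_list to freq_list rename above changes what the KB stores per entity: a raw corpus frequency instead of a probability, while per-alias prior probabilities still go through add_alias. A minimal sketch of the same calls on a toy KB (spaCy 2.1-era API as used in this diff; the ids, counts and vectors below are invented):

import spacy
from spacy.kb import KnowledgeBase

nlp = spacy.blank("en")
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3)

# freq_list holds raw counts (this replaces the old prob_list argument)
kb.set_entities(
    entity_list=["Q42", "Q5301561"],
    freq_list=[342, 12],
    vector_list=[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
)
# prior probabilities per alias are still passed to add_alias
kb.add_alias(alias="Douglas", entities=["Q42", "Q5301561"], probabilities=[0.8, 0.1])
print(len(kb), kb.get_size_entities(), kb.get_size_aliases())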

View File

@ -1,7 +1,7 @@
# coding: utf-8 # coding: utf-8
from __future__ import unicode_literals from __future__ import unicode_literals
import os import random
import re import re
import bz2 import bz2
import datetime import datetime
@ -17,6 +17,10 @@ Gold-standard entities are stored in one file in standoff format (by character o
ENTITY_FILE = "gold_entities.csv" ENTITY_FILE = "gold_entities.csv"
def now():
return datetime.datetime.now()
def create_training(wikipedia_input, entity_def_input, training_output): def create_training(wikipedia_input, entity_def_input, training_output):
wp_to_id = kb_creator.get_entity_to_id(entity_def_input) wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
_process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None) _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None)
@ -27,21 +31,23 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
Read the XML wikipedia data to parse out training data: Read the XML wikipedia data to parse out training data:
raw text data + positive instances raw text data + positive instances
""" """
title_regex = re.compile(r'(?<=<title>).*(?=</title>)') title_regex = re.compile(r"(?<=<title>).*(?=</title>)")
id_regex = re.compile(r'(?<=<id>)\d*(?=</id>)') id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
read_ids = set() read_ids = set()
entityfile_loc = training_output / ENTITY_FILE entityfile_loc = training_output / ENTITY_FILE
with open(entityfile_loc, mode="w", encoding='utf8') as entityfile: with entityfile_loc.open("w", encoding="utf8") as entityfile:
# write entity training header file # write entity training header file
_write_training_entity(outputfile=entityfile, _write_training_entity(
outputfile=entityfile,
article_id="article_id", article_id="article_id",
alias="alias", alias="alias",
entity="WD_id", entity="WD_id",
start="start", start="start",
end="end") end="end",
)
with bz2.open(wikipedia_input, mode='rb') as file: with bz2.open(wikipedia_input, mode="rb") as file:
line = file.readline() line = file.readline()
cnt = 0 cnt = 0
article_text = "" article_text = ""
@ -51,7 +57,7 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
reading_revision = False reading_revision = False
while line and (not limit or cnt < limit): while line and (not limit or cnt < limit):
if cnt % 1000000 == 0: if cnt % 1000000 == 0:
print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump") print(now(), "processed", cnt, "lines of Wikipedia dump")
clean_line = line.strip().decode("utf-8") clean_line = line.strip().decode("utf-8")
if clean_line == "<revision>": if clean_line == "<revision>":
@ -69,12 +75,23 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
elif clean_line == "</page>": elif clean_line == "</page>":
if article_id: if article_id:
try: try:
_process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text.strip(),
                 training_output)
_process_wp_text(
    wp_to_id,
    entityfile,
    article_id,
    article_title,
    article_text.strip(),
    training_output,
)
except Exception as e: except Exception as e:
print("Error processing article", article_id, article_title, e)
print(
    "Error processing article", article_id, article_title, e
)
else: else:
print("Done processing a page, but couldn't find an article_id ?", article_title)
print(
    "Done processing a page, but couldn't find an article_id ?",
    article_title,
)
article_text = "" article_text = ""
article_title = None article_title = None
article_id = None article_id = None
@ -98,7 +115,9 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
if ids: if ids:
article_id = ids[0] article_id = ids[0]
if article_id in read_ids: if article_id in read_ids:
print("Found duplicate article ID", article_id, clean_line) # This should never happen ... print(
"Found duplicate article ID", article_id, clean_line
) # This should never happen ...
read_ids.add(article_id) read_ids.add(article_id)
# read the title of this article (outside the revision portion of the document) # read the title of this article (outside the revision portion of the document)
@ -111,10 +130,12 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
cnt += 1 cnt += 1
text_regex = re.compile(r'(?<=<text xml:space=\"preserve\">).*(?=</text)') text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")
def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output): def _process_wp_text(
wp_to_id, entityfile, article_id, article_title, article_text, training_output
):
found_entities = False found_entities = False
# ignore meta Wikipedia pages # ignore meta Wikipedia pages
@ -141,11 +162,11 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
entity_buffer = "" entity_buffer = ""
mention_buffer = "" mention_buffer = ""
for index, letter in enumerate(clean_text): for index, letter in enumerate(clean_text):
if letter == '[': if letter == "[":
open_read += 1 open_read += 1
elif letter == ']': elif letter == "]":
open_read -= 1 open_read -= 1
elif letter == '|': elif letter == "|":
if reading_text: if reading_text:
final_text += letter final_text += letter
# switch from reading entity to mention in the [[entity|mention]] pattern # switch from reading entity to mention in the [[entity|mention]] pattern
@ -163,7 +184,7 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
elif reading_text: elif reading_text:
final_text += letter final_text += letter
else: else:
raise ValueError("Not sure at point", clean_text[index-2:index+2]) raise ValueError("Not sure at point", clean_text[index - 2 : index + 2])
if open_read > 2: if open_read > 2:
reading_special_case = True reading_special_case = True
@ -175,7 +196,7 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
# we just finished reading an entity # we just finished reading an entity
if open_read == 0 and not reading_text: if open_read == 0 and not reading_text:
if '#' in entity_buffer or entity_buffer.startswith(':'): if "#" in entity_buffer or entity_buffer.startswith(":"):
reading_special_case = True reading_special_case = True
# Ignore cases with nested structures like File: handles etc # Ignore cases with nested structures like File: handles etc
if not reading_special_case: if not reading_special_case:
@ -185,12 +206,14 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
end = start + len(mention_buffer) end = start + len(mention_buffer)
qid = wp_to_id.get(entity_buffer, None) qid = wp_to_id.get(entity_buffer, None)
if qid: if qid:
_write_training_entity(outputfile=entityfile, _write_training_entity(
outputfile=entityfile,
article_id=article_id, article_id=article_id,
alias=mention_buffer, alias=mention_buffer,
entity=qid, entity=qid,
start=start, start=start,
end=end) end=end,
)
found_entities = True found_entities = True
final_text += mention_buffer final_text += mention_buffer
@ -203,29 +226,35 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
reading_special_case = False reading_special_case = False
if found_entities: if found_entities:
_write_training_article(article_id=article_id, clean_text=final_text, training_output=training_output) _write_training_article(
article_id=article_id,
clean_text=final_text,
training_output=training_output,
)
info_regex = re.compile(r'{[^{]*?}') info_regex = re.compile(r"{[^{]*?}")
htlm_regex = re.compile(r'&lt;!--[^-]*--&gt;') htlm_regex = re.compile(r"&lt;!--[^-]*--&gt;")
category_regex = re.compile(r'\[\[Category:[^\[]*]]') category_regex = re.compile(r"\[\[Category:[^\[]*]]")
file_regex = re.compile(r'\[\[File:[^[\]]+]]') file_regex = re.compile(r"\[\[File:[^[\]]+]]")
ref_regex = re.compile(r'&lt;ref.*?&gt;') # non-greedy ref_regex = re.compile(r"&lt;ref.*?&gt;") # non-greedy
ref_2_regex = re.compile(r'&lt;/ref.*?&gt;') # non-greedy ref_2_regex = re.compile(r"&lt;/ref.*?&gt;") # non-greedy
def _get_clean_wp_text(article_text): def _get_clean_wp_text(article_text):
clean_text = article_text.strip() clean_text = article_text.strip()
# remove bolding & italic markup # remove bolding & italic markup
clean_text = clean_text.replace('\'\'\'', '') clean_text = clean_text.replace("'''", "")
clean_text = clean_text.replace('\'\'', '') clean_text = clean_text.replace("''", "")
# remove nested {{info}} statements by removing the inner/smallest ones first and iterating # remove nested {{info}} statements by removing the inner/smallest ones first and iterating
try_again = True try_again = True
previous_length = len(clean_text) previous_length = len(clean_text)
while try_again: while try_again:
clean_text = info_regex.sub('', clean_text) # non-greedy match excluding a nested { clean_text = info_regex.sub(
"", clean_text
) # non-greedy match excluding a nested {
if len(clean_text) < previous_length: if len(clean_text) < previous_length:
try_again = True try_again = True
else: else:
@ -233,14 +262,14 @@ def _get_clean_wp_text(article_text):
previous_length = len(clean_text) previous_length = len(clean_text)
# remove HTML comments # remove HTML comments
clean_text = htlm_regex.sub('', clean_text) clean_text = htlm_regex.sub("", clean_text)
# remove Category and File statements # remove Category and File statements
clean_text = category_regex.sub('', clean_text) clean_text = category_regex.sub("", clean_text)
clean_text = file_regex.sub('', clean_text) clean_text = file_regex.sub("", clean_text)
# remove multiple = # remove multiple =
while '==' in clean_text: while "==" in clean_text:
clean_text = clean_text.replace("==", "=") clean_text = clean_text.replace("==", "=")
clean_text = clean_text.replace(". =", ".") clean_text = clean_text.replace(". =", ".")
@ -249,43 +278,47 @@ def _get_clean_wp_text(article_text):
clean_text = clean_text.replace(" =", "") clean_text = clean_text.replace(" =", "")
# remove refs (non-greedy match) # remove refs (non-greedy match)
clean_text = ref_regex.sub('', clean_text) clean_text = ref_regex.sub("", clean_text)
clean_text = ref_2_regex.sub('', clean_text) clean_text = ref_2_regex.sub("", clean_text)
# remove additional wikiformatting # remove additional wikiformatting
clean_text = re.sub(r'&lt;blockquote&gt;', '', clean_text) clean_text = re.sub(r"&lt;blockquote&gt;", "", clean_text)
clean_text = re.sub(r'&lt;/blockquote&gt;', '', clean_text) clean_text = re.sub(r"&lt;/blockquote&gt;", "", clean_text)
# change special characters back to normal ones # change special characters back to normal ones
clean_text = clean_text.replace(r'&lt;', '<') clean_text = clean_text.replace(r"&lt;", "<")
clean_text = clean_text.replace(r'&gt;', '>') clean_text = clean_text.replace(r"&gt;", ">")
clean_text = clean_text.replace(r'&quot;', '"') clean_text = clean_text.replace(r"&quot;", '"')
clean_text = clean_text.replace(r'&amp;nbsp;', ' ') clean_text = clean_text.replace(r"&amp;nbsp;", " ")
clean_text = clean_text.replace(r'&amp;', '&') clean_text = clean_text.replace(r"&amp;", "&")
# remove multiple spaces # remove multiple spaces
while ' ' in clean_text: while " " in clean_text:
clean_text = clean_text.replace(' ', ' ') clean_text = clean_text.replace(" ", " ")
return clean_text.strip() return clean_text.strip()
def _write_training_article(article_id, clean_text, training_output): def _write_training_article(article_id, clean_text, training_output):
file_loc = training_output / str(article_id) + ".txt" file_loc = training_output / "{}.txt".format(article_id)
with open(file_loc, mode='w', encoding='utf8') as outputfile: with file_loc.open("w", encoding="utf8") as outputfile:
outputfile.write(clean_text) outputfile.write(clean_text)
def _write_training_entity(outputfile, article_id, alias, entity, start, end): def _write_training_entity(outputfile, article_id, alias, entity, start, end):
outputfile.write(article_id + "|" + alias + "|" + entity + "|" + str(start) + "|" + str(end) + "\n") line = "{}|{}|{}|{}|{}\n".format(article_id, alias, entity, start, end)
outputfile.write(line)
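For reference, a row of the gold_entities.csv produced by these helpers follows the pipe-delimited header written above; the values in this example are invented:

# article_id|alias|WD_id|start|end
# 12|Douglas Adams|Q42|21|34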
def is_dev(article_id): def is_dev(article_id):
return article_id.endswith("3") return article_id.endswith("3")
def read_training(nlp, training_dir, dev, limit):
    # This method provides training examples that correspond to the entity annotations found by the nlp object
def read_training(nlp, training_dir, dev, limit, kb=None):
    """ This method provides training examples that correspond to the entity annotations found by the nlp object.
    When kb is provided (for training), it will include negative training examples by using the candidate generator,
    and it will only keep positive training examples that can be found in the KB.
    When kb=None (for testing), it will include all positive examples only."""
entityfile_loc = training_dir / ENTITY_FILE entityfile_loc = training_dir / ENTITY_FILE
data = [] data = []
@ -296,24 +329,30 @@ def read_training(nlp, training_dir, dev, limit):
skip_articles = set() skip_articles = set()
total_entities = 0 total_entities = 0
with open(entityfile_loc, mode='r', encoding='utf8') as file: with entityfile_loc.open("r", encoding="utf8") as file:
for line in file: for line in file:
if not limit or len(data) < limit: if not limit or len(data) < limit:
fields = line.replace('\n', "").split(sep='|') fields = line.replace("\n", "").split(sep="|")
article_id = fields[0] article_id = fields[0]
alias = fields[1] alias = fields[1]
wp_title = fields[2] wd_id = fields[2]
start = fields[3] start = fields[3]
end = fields[4] end = fields[4]
if dev == is_dev(article_id) and article_id != "article_id" and article_id not in skip_articles: if (
dev == is_dev(article_id)
and article_id != "article_id"
and article_id not in skip_articles
):
if not current_doc or (current_article_id != article_id): if not current_doc or (current_article_id != article_id):
# parse the new article text # parse the new article text
file_name = article_id + ".txt" file_name = article_id + ".txt"
try: try:
with open(os.path.join(training_dir, file_name), mode="r", encoding='utf8') as f: training_file = training_dir / file_name
with training_file.open("r", encoding="utf8") as f:
text = f.read() text = f.read()
if len(text) < 30000: # threshold for convenience / speed of processing # threshold for convenience / speed of processing
if len(text) < 30000:
current_doc = nlp(text) current_doc = nlp(text)
current_article_id = article_id current_article_id = article_id
ents_by_offset = dict() ents_by_offset = dict()
@ -321,28 +360,64 @@ def read_training(nlp, training_dir, dev, limit):
sent_length = len(ent.sent) sent_length = len(ent.sent)
# custom filtering to avoid too long or too short sentences # custom filtering to avoid too long or too short sentences
if 5 < sent_length < 100: if 5 < sent_length < 100:
ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent offset = "{}_{}".format(
ent.start_char, ent.end_char
)
ents_by_offset[offset] = ent
else: else:
skip_articles.add(article_id) skip_articles.add(article_id)
current_doc = None current_doc = None
except Exception as e: except Exception as e:
print("Problem parsing article", article_id, e) print("Problem parsing article", article_id, e)
skip_articles.add(article_id) skip_articles.add(article_id)
raise e
# repeat checking this condition in case an exception was thrown # repeat checking this condition in case an exception was thrown
if current_doc and (current_article_id == article_id): if current_doc and (current_article_id == article_id):
found_ent = ents_by_offset.get(start + "_" + end, None) offset = "{}_{}".format(start, end)
found_ent = ents_by_offset.get(offset, None)
if found_ent: if found_ent:
if found_ent.text != alias: if found_ent.text != alias:
skip_articles.add(article_id) skip_articles.add(article_id)
current_doc = None current_doc = None
else: else:
sent = found_ent.sent.as_doc() sent = found_ent.sent.as_doc()
# currently feeding the gold data one entity per sentence at a time
gold_start = int(start) - found_ent.sent.start_char gold_start = int(start) - found_ent.sent.start_char
gold_end = int(end) - found_ent.sent.start_char gold_end = int(end) - found_ent.sent.start_char
gold_entities = [(gold_start, gold_end, wp_title)]
gold_entities = {}
found_useful = False
for ent in sent.ents:
entry = (ent.start_char, ent.end_char)
gold_entry = (gold_start, gold_end)
if entry == gold_entry:
# add both pos and neg examples (in random order)
# this will exclude examples not in the KB
if kb:
value_by_id = {}
candidates = kb.get_candidates(alias)
candidate_ids = [
c.entity_ for c in candidates
]
random.shuffle(candidate_ids)
for kb_id in candidate_ids:
found_useful = True
if kb_id != wd_id:
value_by_id[kb_id] = 0.0
else:
value_by_id[kb_id] = 1.0
gold_entities[entry] = value_by_id
# if no KB, keep all positive examples
else:
found_useful = True
value_by_id = {wd_id: 1.0}
gold_entities[entry] = value_by_id
# currently feeding the gold data one entity per sentence at a time
# setting all other entities to empty gold dictionary
else:
gold_entities[entry] = {}
if found_useful:
gold = GoldParse(doc=sent, links=gold_entities) gold = GoldParse(doc=sent, links=gold_entities)
data.append((sent, gold)) data.append((sent, gold))
total_entities += 1 total_entities += 1
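The negative-sampling step added to read_training boils down to mapping every candidate the KB proposes for an alias to 0.0 and the gold Wikidata id to 1.0. A condensed sketch of that logic (the function name and variables here are illustrative, not part of the commit):

import random

def value_by_id_for(kb, alias, wd_id):
    # candidates come from the KB's candidate generator; shuffle so the
    # positive example does not always sit in the same position
    candidate_ids = [c.entity_ for c in kb.get_candidates(alias)]
    random.shuffle(candidate_ids)
    return {kb_id: 1.0 if kb_id == wd_id else 0.0 for kb_id in candidate_ids}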

View File

@ -14,22 +14,97 @@ Write these results to file for downstream KB and training data generation.
map_alias_to_link = dict() map_alias_to_link = dict()
# these will/should be matched ignoring case # these will/should be matched ignoring case
wiki_namespaces = ["b", "betawikiversity", "Book", "c", "Category", "Commons",
                   "d", "dbdump", "download", "Draft", "Education", "Foundation",
                   "Gadget", "Gadget definition", "gerrit", "File", "Help", "Image", "Incubator",
                   "m", "mail", "mailarchive", "media", "MediaWiki", "MediaWiki talk", "Mediawikiwiki",
                   "MediaZilla", "Meta", "Metawikipedia", "Module",
                   "mw", "n", "nost", "oldwikisource", "outreach", "outreachwiki", "otrs", "OTRSwiki",
                   "Portal", "phab", "Phabricator", "Project", "q", "quality", "rev",
                   "s", "spcom", "Special", "species", "Strategy", "sulutil", "svn",
                   "Talk", "Template", "Template talk", "Testwiki", "ticket", "TimedText", "Toollabs", "tools",
                   "tswiki", "User", "User talk", "v", "voy",
                   "w", "Wikibooks", "Wikidata", "wikiHow", "Wikinvest", "wikilivres", "Wikimedia", "Wikinews",
                   "Wikipedia", "Wikipedia talk", "Wikiquote", "Wikisource", "Wikispecies", "Wikitech",
                   "Wikiversity", "Wikivoyage", "wikt", "wiktionary", "wmf", "wmania", "WP"]
wiki_namespaces = [
"b",
"betawikiversity",
"Book",
"c",
"Category",
"Commons",
"d",
"dbdump",
"download",
"Draft",
"Education",
"Foundation",
"Gadget",
"Gadget definition",
"gerrit",
"File",
"Help",
"Image",
"Incubator",
"m",
"mail",
"mailarchive",
"media",
"MediaWiki",
"MediaWiki talk",
"Mediawikiwiki",
"MediaZilla",
"Meta",
"Metawikipedia",
"Module",
"mw",
"n",
"nost",
"oldwikisource",
"outreach",
"outreachwiki",
"otrs",
"OTRSwiki",
"Portal",
"phab",
"Phabricator",
"Project",
"q",
"quality",
"rev",
"s",
"spcom",
"Special",
"species",
"Strategy",
"sulutil",
"svn",
"Talk",
"Template",
"Template talk",
"Testwiki",
"ticket",
"TimedText",
"Toollabs",
"tools",
"tswiki",
"User",
"User talk",
"v",
"voy",
"w",
"Wikibooks",
"Wikidata",
"wikiHow",
"Wikinvest",
"wikilivres",
"Wikimedia",
"Wikinews",
"Wikipedia",
"Wikipedia talk",
"Wikiquote",
"Wikisource",
"Wikispecies",
"Wikitech",
"Wikiversity",
"Wikivoyage",
"wikt",
"wiktionary",
"wmf",
"wmania",
"WP",
]
# find the links # find the links
link_regex = re.compile(r'\[\[[^\[\]]*\]\]') link_regex = re.compile(r"\[\[[^\[\]]*\]\]")
# match on interwiki links, e.g. `en:` or `:fr:` # match on interwiki links, e.g. `en:` or `:fr:`
ns_regex = r":?" + "[a-z][a-z]" + ":" ns_regex = r":?" + "[a-z][a-z]" + ":"
@ -41,18 +116,22 @@ for ns in wiki_namespaces:
ns_regex = re.compile(ns_regex, re.IGNORECASE) ns_regex = re.compile(ns_regex, re.IGNORECASE)
def read_wikipedia_prior_probs(wikipedia_input, prior_prob_output): def now():
return datetime.datetime.now()
def read_prior_probs(wikipedia_input, prior_prob_output):
""" """
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities. Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
The full file takes about 2h to parse 1100M lines. The full file takes about 2h to parse 1100M lines.
It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from. It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from.
""" """
with bz2.open(wikipedia_input, mode='rb') as file: with bz2.open(wikipedia_input, mode="rb") as file:
line = file.readline() line = file.readline()
cnt = 0 cnt = 0
while line: while line:
if cnt % 5000000 == 0: if cnt % 5000000 == 0:
print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump") print(now(), "processed", cnt, "lines of Wikipedia dump")
clean_line = line.strip().decode("utf-8") clean_line = line.strip().decode("utf-8")
aliases, entities, normalizations = get_wp_links(clean_line) aliases, entities, normalizations = get_wp_links(clean_line)
@ -64,10 +143,11 @@ def read_wikipedia_prior_probs(wikipedia_input, prior_prob_output):
cnt += 1 cnt += 1
# write all aliases and their entities and count occurrences to file # write all aliases and their entities and count occurrences to file
with open(prior_prob_output, mode='w', encoding='utf8') as outputfile: with prior_prob_output.open("w", encoding="utf8") as outputfile:
outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n") outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n")
for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]): for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]):
for entity, count in sorted(alias_dict.items(), key=lambda x: x[1], reverse=True): s_dict = sorted(alias_dict.items(), key=lambda x: x[1], reverse=True)
for entity, count in s_dict:
outputfile.write(alias + "|" + str(count) + "|" + entity + "\n") outputfile.write(alias + "|" + str(count) + "|" + entity + "\n")
@ -140,13 +220,13 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
entity_to_count = dict() entity_to_count = dict()
total_count = 0 total_count = 0
with open(prior_prob_input, mode='r', encoding='utf8') as prior_file: with prior_prob_input.open("r", encoding="utf8") as prior_file:
# skip header # skip header
prior_file.readline() prior_file.readline()
line = prior_file.readline() line = prior_file.readline()
while line: while line:
splits = line.replace('\n', "").split(sep='|') splits = line.replace("\n", "").split(sep="|")
# alias = splits[0] # alias = splits[0]
count = int(splits[1]) count = int(splits[1])
entity = splits[2] entity = splits[2]
@ -158,7 +238,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
line = prior_file.readline() line = prior_file.readline()
with open(count_output, mode='w', encoding='utf8') as entity_file: with count_output.open("w", encoding="utf8") as entity_file:
entity_file.write("entity" + "|" + "count" + "\n") entity_file.write("entity" + "|" + "count" + "\n")
for entity, count in entity_to_count.items(): for entity, count in entity_to_count.items():
entity_file.write(entity + "|" + str(count) + "\n") entity_file.write(entity + "|" + str(count) + "\n")
@ -171,12 +251,11 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
def get_all_frequencies(count_input): def get_all_frequencies(count_input):
entity_to_count = dict() entity_to_count = dict()
with open(count_input, 'r', encoding='utf8') as csvfile: with count_input.open("r", encoding="utf8") as csvfile:
csvreader = csv.reader(csvfile, delimiter='|') csvreader = csv.reader(csvfile, delimiter="|")
# skip header # skip header
next(csvreader) next(csvreader)
for row in csvreader: for row in csvreader:
entity_to_count[row[0]] = int(row[1]) entity_to_count[row[0]] = int(row[1])
return entity_to_count return entity_to_count
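The file written by read_prior_probs ("alias|count|entity", sorted by alias and then by descending count) is what kb_creator later turns into prior probabilities. A rough sketch of that aggregation, assuming the same file layout (this is a simplified stand-in, not the code from _add_aliases):

from collections import defaultdict
from pathlib import Path

def read_alias_priors(prior_prob_path):
    counts = defaultdict(dict)
    with Path(prior_prob_path).open("r", encoding="utf8") as f:
        next(f)  # skip the "alias|count|entity" header
        for line in f:
            alias, count, entity = line.rstrip("\n").split("|")
            counts[alias][entity] = counts[alias].get(entity, 0) + int(count)
    # prior probability of an entity given the alias = its count / total count for that alias
    return {
        alias: {e: c / sum(ents.values()) for e, c in ents.items()}
        for alias, ents in counts.items()
    }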

View File

@ -14,15 +14,15 @@ def create_kb(vocab):
# adding entities # adding entities
entity_0 = "Q1004791_Douglas" entity_0 = "Q1004791_Douglas"
print("adding entity", entity_0) print("adding entity", entity_0)
kb.add_entity(entity=entity_0, prob=0.5, entity_vector=[0]) kb.add_entity(entity=entity_0, freq=0.5, entity_vector=[0])
entity_1 = "Q42_Douglas_Adams" entity_1 = "Q42_Douglas_Adams"
print("adding entity", entity_1) print("adding entity", entity_1)
kb.add_entity(entity=entity_1, prob=0.5, entity_vector=[1]) kb.add_entity(entity=entity_1, freq=0.5, entity_vector=[1])
entity_2 = "Q5301561_Douglas_Haig" entity_2 = "Q5301561_Douglas_Haig"
print("adding entity", entity_2) print("adding entity", entity_2)
kb.add_entity(entity=entity_2, prob=0.5, entity_vector=[2]) kb.add_entity(entity=entity_2, freq=0.5, entity_vector=[2])
# adding aliases # adding aliases
print() print()

View File

@ -1,11 +1,14 @@
# coding: utf-8 # coding: utf-8
from __future__ import unicode_literals from __future__ import unicode_literals
import os
from os import path
import random import random
import datetime import datetime
from pathlib import Path from pathlib import Path
from bin.wiki_entity_linking import training_set_creator, kb_creator, wikipedia_processor as wp from bin.wiki_entity_linking import wikipedia_processor as wp
from bin.wiki_entity_linking import training_set_creator, kb_creator
from bin.wiki_entity_linking.kb_creator import DESC_WIDTH from bin.wiki_entity_linking.kb_creator import DESC_WIDTH
import spacy import spacy
@ -17,23 +20,26 @@ Demonstrate how to build a knowledge base from WikiData and run an Entity Linkin
""" """
ROOT_DIR = Path("C:/Users/Sofie/Documents/data/") ROOT_DIR = Path("C:/Users/Sofie/Documents/data/")
OUTPUT_DIR = ROOT_DIR / 'wikipedia' OUTPUT_DIR = ROOT_DIR / "wikipedia"
TRAINING_DIR = OUTPUT_DIR / 'training_data_nel' TRAINING_DIR = OUTPUT_DIR / "training_data_nel"
PRIOR_PROB = OUTPUT_DIR / 'prior_prob.csv' PRIOR_PROB = OUTPUT_DIR / "prior_prob.csv"
ENTITY_COUNTS = OUTPUT_DIR / 'entity_freq.csv' ENTITY_COUNTS = OUTPUT_DIR / "entity_freq.csv"
ENTITY_DEFS = OUTPUT_DIR / 'entity_defs.csv' ENTITY_DEFS = OUTPUT_DIR / "entity_defs.csv"
ENTITY_DESCR = OUTPUT_DIR / 'entity_descriptions.csv' ENTITY_DESCR = OUTPUT_DIR / "entity_descriptions.csv"
KB_FILE = OUTPUT_DIR / 'kb_1' / 'kb'
NLP_1_DIR = OUTPUT_DIR / 'nlp_1'
NLP_2_DIR = OUTPUT_DIR / 'nlp_2'
KB_DIR = OUTPUT_DIR / "kb_1"
KB_FILE = "kb"
NLP_1_DIR = OUTPUT_DIR / "nlp_1"
NLP_2_DIR = OUTPUT_DIR / "nlp_2"
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/ # get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
WIKIDATA_JSON = ROOT_DIR / 'wikidata' / 'wikidata-20190304-all.json.bz2' WIKIDATA_JSON = ROOT_DIR / "wikidata" / "wikidata-20190304-all.json.bz2"
# get enwiki-latest-pages-articles-multistream.xml.bz2 from https://dumps.wikimedia.org/enwiki/latest/ # get enwiki-latest-pages-articles-multistream.xml.bz2 from https://dumps.wikimedia.org/enwiki/latest/
ENWIKI_DUMP = ROOT_DIR / 'wikipedia' / 'enwiki-20190320-pages-articles-multistream.xml.bz2'
ENWIKI_DUMP = (
    ROOT_DIR / "wikipedia" / "enwiki-20190320-pages-articles-multistream.xml.bz2"
)
# KB construction parameters # KB construction parameters
MAX_CANDIDATES = 10 MAX_CANDIDATES = 10
@ -48,11 +54,15 @@ L2 = 1e-6
CONTEXT_WIDTH = 128 CONTEXT_WIDTH = 128
def now():
return datetime.datetime.now()
def run_pipeline(): def run_pipeline():
# set the appropriate booleans to define which parts of the pipeline should be re(run) # set the appropriate booleans to define which parts of the pipeline should be re(run)
print("START", datetime.datetime.now()) print("START", now())
print() print()
nlp_1 = spacy.load('en_core_web_lg') nlp_1 = spacy.load("en_core_web_lg")
nlp_2 = None nlp_2 = None
kb_2 = None kb_2 = None
@ -82,20 +92,21 @@ def run_pipeline():
# STEP 1 : create prior probabilities from WP (run only once) # STEP 1 : create prior probabilities from WP (run only once)
if to_create_prior_probs: if to_create_prior_probs:
print("STEP 1: to_create_prior_probs", datetime.datetime.now()) print("STEP 1: to_create_prior_probs", now())
wp.read_wikipedia_prior_probs(wikipedia_input=ENWIKI_DUMP, prior_prob_output=PRIOR_PROB) wp.read_prior_probs(ENWIKI_DUMP, PRIOR_PROB)
print() print()
# STEP 2 : deduce entity frequencies from WP (run only once) # STEP 2 : deduce entity frequencies from WP (run only once)
if to_create_entity_counts: if to_create_entity_counts:
print("STEP 2: to_create_entity_counts", datetime.datetime.now()) print("STEP 2: to_create_entity_counts", now())
wp.write_entity_counts(prior_prob_input=PRIOR_PROB, count_output=ENTITY_COUNTS, to_print=False) wp.write_entity_counts(PRIOR_PROB, ENTITY_COUNTS, to_print=False)
print() print()
# STEP 3 : create KB and write to file (run only once) # STEP 3 : create KB and write to file (run only once)
if to_create_kb: if to_create_kb:
print("STEP 3a: to_create_kb", datetime.datetime.now()) print("STEP 3a: to_create_kb", now())
kb_1 = kb_creator.create_kb(nlp_1, kb_1 = kb_creator.create_kb(
nlp=nlp_1,
max_entities_per_alias=MAX_CANDIDATES, max_entities_per_alias=MAX_CANDIDATES,
min_entity_freq=MIN_ENTITY_FREQ, min_entity_freq=MIN_ENTITY_FREQ,
min_occ=MIN_PAIR_OCC, min_occ=MIN_PAIR_OCC,
@ -103,22 +114,26 @@ def run_pipeline():
entity_descr_output=ENTITY_DESCR, entity_descr_output=ENTITY_DESCR,
count_input=ENTITY_COUNTS, count_input=ENTITY_COUNTS,
prior_prob_input=PRIOR_PROB, prior_prob_input=PRIOR_PROB,
wikidata_input=WIKIDATA_JSON) wikidata_input=WIKIDATA_JSON,
)
print("kb entities:", kb_1.get_size_entities()) print("kb entities:", kb_1.get_size_entities())
print("kb aliases:", kb_1.get_size_aliases()) print("kb aliases:", kb_1.get_size_aliases())
print() print()
print("STEP 3b: write KB and NLP", datetime.datetime.now()) print("STEP 3b: write KB and NLP", now())
kb_1.dump(KB_FILE)
if not path.exists(KB_DIR):
os.makedirs(KB_DIR)
kb_1.dump(KB_DIR / KB_FILE)
nlp_1.to_disk(NLP_1_DIR) nlp_1.to_disk(NLP_1_DIR)
print() print()
# STEP 4 : read KB back in from file # STEP 4 : read KB back in from file
if to_read_kb: if to_read_kb:
print("STEP 4: to_read_kb", datetime.datetime.now()) print("STEP 4: to_read_kb", now())
nlp_2 = spacy.load(NLP_1_DIR) nlp_2 = spacy.load(NLP_1_DIR)
kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH) kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH)
kb_2.load_bulk(KB_FILE) kb_2.load_bulk(KB_DIR / KB_FILE)
print("kb entities:", kb_2.get_size_entities()) print("kb entities:", kb_2.get_size_entities())
print("kb aliases:", kb_2.get_size_aliases()) print("kb aliases:", kb_2.get_size_aliases())
print() print()
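Steps 3b and 4 introduce the KB_DIR / KB_FILE split: the KB is dumped into a directory that must already exist, then read back with load_bulk into a KB created with the same entity vector length. A minimal sketch of that round trip (paths are placeholders; kb_1 and nlp_1 are the objects built in the steps above):

from pathlib import Path
import spacy
from spacy.kb import KnowledgeBase

kb_dir = Path("output/kb_1")
kb_dir.mkdir(parents=True, exist_ok=True)
kb_1.dump(kb_dir / "kb")        # assumes kb_1 from step 3a
nlp_1.to_disk("output/nlp_1")   # assumes nlp_1 loaded above

nlp_2 = spacy.load("output/nlp_1")
kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=64)  # 64 = DESC_WIDTH
kb_2.load_bulk(kb_dir / "kb")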
@ -130,20 +145,26 @@ def run_pipeline():
# STEP 5: create a training dataset from WP # STEP 5: create a training dataset from WP
if create_wp_training: if create_wp_training:
print("STEP 5: create training dataset", datetime.datetime.now()) print("STEP 5: create training dataset", now())
training_set_creator.create_training(wikipedia_input=ENWIKI_DUMP, training_set_creator.create_training(
wikipedia_input=ENWIKI_DUMP,
entity_def_input=ENTITY_DEFS, entity_def_input=ENTITY_DEFS,
training_output=TRAINING_DIR) training_output=TRAINING_DIR,
)
# STEP 6: create and train the entity linking pipe # STEP 6: create and train the entity linking pipe
if train_pipe: if train_pipe:
print("STEP 6: training Entity Linking pipe", datetime.datetime.now()) print("STEP 6: training Entity Linking pipe", now())
type_to_int = {label: i for i, label in enumerate(nlp_2.entity.labels)} type_to_int = {label: i for i, label in enumerate(nlp_2.entity.labels)}
print(" -analysing", len(type_to_int), "different entity types") print(" -analysing", len(type_to_int), "different entity types")
el_pipe = nlp_2.create_pipe(name='entity_linker', el_pipe = nlp_2.create_pipe(
config={"context_width": CONTEXT_WIDTH, name="entity_linker",
config={
"context_width": CONTEXT_WIDTH,
"pretrained_vectors": nlp_2.vocab.vectors.name, "pretrained_vectors": nlp_2.vocab.vectors.name,
"type_to_int": type_to_int}) "type_to_int": type_to_int,
},
)
el_pipe.set_kb(kb_2) el_pipe.set_kb(kb_2)
nlp_2.add_pipe(el_pipe, last=True) nlp_2.add_pipe(el_pipe, last=True)
@ -157,18 +178,22 @@ def run_pipeline():
train_limit = 5000 train_limit = 5000
dev_limit = 5000 dev_limit = 5000
train_data = training_set_creator.read_training(nlp=nlp_2,
                                                training_dir=TRAINING_DIR,
                                                dev=False,
                                                limit=train_limit)
# for training, get pos & neg instances that correspond to entries in the kb
train_data = training_set_creator.read_training(
    nlp=nlp_2,
    training_dir=TRAINING_DIR,
    dev=False,
    limit=train_limit,
    kb=el_pipe.kb,
)
print("Training on", len(train_data), "articles") print("Training on", len(train_data), "articles")
print() print()
dev_data = training_set_creator.read_training(nlp=nlp_2,
                                              training_dir=TRAINING_DIR,
                                              dev=True,
                                              limit=dev_limit)
# for testing, get all pos instances, whether or not they are in the kb
dev_data = training_set_creator.read_training(
    nlp=nlp_2, training_dir=TRAINING_DIR, dev=True, limit=dev_limit, kb=None
)
print("Dev testing on", len(dev_data), "articles") print("Dev testing on", len(dev_data), "articles")
print() print()
@ -187,8 +212,8 @@ def run_pipeline():
try: try:
docs, golds = zip(*batch) docs, golds = zip(*batch)
nlp_2.update( nlp_2.update(
docs, docs=docs,
golds, golds=golds,
sgd=optimizer, sgd=optimizer,
drop=DROPOUT, drop=DROPOUT,
losses=losses, losses=losses,
@ -200,48 +225,61 @@ def run_pipeline():
if batchnr > 0: if batchnr > 0:
el_pipe.cfg["context_weight"] = 1 el_pipe.cfg["context_weight"] = 1
el_pipe.cfg["prior_weight"] = 1 el_pipe.cfg["prior_weight"] = 1
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe) dev_acc_context, _ = _measure_acc(dev_data, el_pipe)
losses['entity_linker'] = losses['entity_linker'] / batchnr losses["entity_linker"] = losses["entity_linker"] / batchnr
print("Epoch, train loss", itn, round(losses['entity_linker'], 2), print(
" / dev acc avg", round(dev_acc_context, 3)) "Epoch, train loss",
itn,
round(losses["entity_linker"], 2),
" / dev acc avg",
round(dev_acc_context, 3),
)
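Condensed, the training loop of step 6 shuffles the (sentence, GoldParse) pairs, batches them and calls nlp.update with the new docs=/golds= keyword form. A sketch under the assumption that train_data comes from read_training above and that the optimizer comes from nlp_2.begin_training(), which is not shown in this hunk; batch size, dropout and epoch count here are arbitrary:

import random
from spacy.util import minibatch

optimizer = nlp_2.begin_training()  # assumed, as in the full script
for itn in range(10):
    random.shuffle(train_data)
    losses = {}
    for batch in minibatch(train_data, size=128):
        docs, golds = zip(*batch)
        nlp_2.update(docs=docs, golds=golds, sgd=optimizer, drop=0.1, losses=losses)
    print(itn, "entity_linker loss:", losses.get("entity_linker", 0.0))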
# STEP 7: measure the performance of our trained pipe on an independent dev set # STEP 7: measure the performance of our trained pipe on an independent dev set
if len(dev_data) and measure_performance: if len(dev_data) and measure_performance:
print() print()
print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now()) print("STEP 7: performance measurement of Entity Linking pipe", now())
print() print()
counts, acc_r, acc_r_label, acc_p, acc_p_label, acc_o, acc_o_label = _measure_baselines(dev_data, kb_2) counts, acc_r, acc_r_d, acc_p, acc_p_d, acc_o, acc_o_d = _measure_baselines(
dev_data, kb_2
)
print("dev counts:", sorted(counts.items(), key=lambda x: x[0])) print("dev counts:", sorted(counts.items(), key=lambda x: x[0]))
print("dev acc oracle:", round(acc_o, 3), [(x, round(y, 3)) for x, y in acc_o_label.items()])
print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_label.items()]) oracle_by_label = [(x, round(y, 3)) for x, y in acc_o_d.items()]
print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_label.items()]) print("dev acc oracle:", round(acc_o, 3), oracle_by_label)
random_by_label = [(x, round(y, 3)) for x, y in acc_r_d.items()]
print("dev acc random:", round(acc_r, 3), random_by_label)
prior_by_label = [(x, round(y, 3)) for x, y in acc_p_d.items()]
print("dev acc prior:", round(acc_p, 3), prior_by_label)
# using only context # using only context
el_pipe.cfg["context_weight"] = 1 el_pipe.cfg["context_weight"] = 1
el_pipe.cfg["prior_weight"] = 0 el_pipe.cfg["prior_weight"] = 0
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe) dev_acc_context, dev_acc_cont_d = _measure_acc(dev_data, el_pipe)
print("dev acc context avg:", round(dev_acc_context, 3), context_by_label = [(x, round(y, 3)) for x, y in dev_acc_cont_d.items()]
[(x, round(y, 3)) for x, y in dev_acc_context_dict.items()]) print("dev acc context avg:", round(dev_acc_context, 3), context_by_label)
# measuring combined accuracy (prior + context) # measuring combined accuracy (prior + context)
el_pipe.cfg["context_weight"] = 1 el_pipe.cfg["context_weight"] = 1
el_pipe.cfg["prior_weight"] = 1 el_pipe.cfg["prior_weight"] = 1
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe, error_analysis=False) dev_acc_combo, dev_acc_combo_d = _measure_acc(dev_data, el_pipe)
print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_d.items()]
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()]) print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label)
# STEP 8: apply the EL pipe on a toy example # STEP 8: apply the EL pipe on a toy example
if to_test_pipeline: if to_test_pipeline:
print() print()
print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now()) print("STEP 8: applying Entity Linking to toy example", now())
print() print()
run_el_toy_example(nlp=nlp_2) run_el_toy_example(nlp=nlp_2)
# STEP 9: write the NLP pipeline (including entity linker) to file # STEP 9: write the NLP pipeline (including entity linker) to file
if to_write_nlp: if to_write_nlp:
print() print()
print("STEP 9: testing NLP IO", datetime.datetime.now()) print("STEP 9: testing NLP IO", now())
print() print()
print("writing to", NLP_2_DIR) print("writing to", NLP_2_DIR)
nlp_2.to_disk(NLP_2_DIR) nlp_2.to_disk(NLP_2_DIR)
@ -262,23 +300,22 @@ def run_pipeline():
el_pipe = nlp_3.get_pipe("entity_linker") el_pipe = nlp_3.get_pipe("entity_linker")
dev_limit = 5000 dev_limit = 5000
dev_data = training_set_creator.read_training(nlp=nlp_2, dev_data = training_set_creator.read_training(
training_dir=TRAINING_DIR, nlp=nlp_2, training_dir=TRAINING_DIR, dev=True, limit=dev_limit, kb=None
dev=True, )
limit=dev_limit)
print("Dev testing from file on", len(dev_data), "articles") print("Dev testing from file on", len(dev_data), "articles")
print() print()
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe=el_pipe, error_analysis=False) dev_acc_combo, dev_acc_combo_dict = _measure_acc(dev_data, el_pipe)
print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()]
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()]) print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label)
print() print()
print("STOP", datetime.datetime.now()) print("STOP", now())
def _measure_accuracy(data, el_pipe=None, error_analysis=False): def _measure_acc(data, el_pipe=None, error_analysis=False):
# If the docs in the data require further processing with an entity linker, set el_pipe # If the docs in the data require further processing with an entity linker, set el_pipe
correct_by_label = dict() correct_by_label = dict()
incorrect_by_label = dict() incorrect_by_label = dict()
@ -291,16 +328,21 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
for doc, gold in zip(docs, golds): for doc, gold in zip(docs, golds):
try: try:
correct_entries_per_article = dict() correct_entries_per_article = dict()
for entity in gold.links:
    start, end, gold_kb = entity
    correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
for entity, kb_dict in gold.links.items():
    start, end = entity
    # only evaluating on positive examples
    for gold_kb, value in kb_dict.items():
        if value:
            offset = _offset(start, end)
            correct_entries_per_article[offset] = gold_kb
for ent in doc.ents: for ent in doc.ents:
ent_label = ent.label_ ent_label = ent.label_
pred_entity = ent.kb_id_ pred_entity = ent.kb_id_
start = ent.start_char start = ent.start_char
end = ent.end_char end = ent.end_char
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None) offset = _offset(start, end)
gold_entity = correct_entries_per_article.get(offset, None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong' # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None: if gold_entity is not None:
if gold_entity == pred_entity: if gold_entity == pred_entity:
@ -311,7 +353,12 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
incorrect_by_label[ent_label] = incorrect + 1 incorrect_by_label[ent_label] = incorrect + 1
if error_analysis: if error_analysis:
print(ent.text, "in", doc) print(ent.text, "in", doc)
print("Predicted", pred_entity, "should have been", gold_entity) print(
"Predicted",
pred_entity,
"should have been",
gold_entity,
)
print() print()
except Exception as e: except Exception as e:
@ -323,16 +370,16 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
def _measure_baselines(data, kb): def _measure_baselines(data, kb):
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound # Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
counts_by_label = dict() counts_d = dict()
random_correct_by_label = dict() random_correct_d = dict()
random_incorrect_by_label = dict() random_incorrect_d = dict()
oracle_correct_by_label = dict() oracle_correct_d = dict()
oracle_incorrect_by_label = dict() oracle_incorrect_d = dict()
prior_correct_by_label = dict() prior_correct_d = dict()
prior_incorrect_by_label = dict() prior_incorrect_d = dict()
docs = [d for d, g in data if len(d) > 0] docs = [d for d, g in data if len(d) > 0]
golds = [g for d, g in data if len(d) > 0] golds = [g for d, g in data if len(d) > 0]
@ -340,19 +387,24 @@ def _measure_baselines(data, kb):
for doc, gold in zip(docs, golds): for doc, gold in zip(docs, golds):
try: try:
correct_entries_per_article = dict() correct_entries_per_article = dict()
for entity in gold.links: for entity, kb_dict in gold.links.items():
start, end, gold_kb = entity start, end = entity
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb for gold_kb, value in kb_dict.items():
# only evaluating on positive examples
if value:
offset = _offset(start, end)
correct_entries_per_article[offset] = gold_kb
for ent in doc.ents: for ent in doc.ents:
ent_label = ent.label_ label = ent.label_
start = ent.start_char start = ent.start_char
end = ent.end_char end = ent.end_char
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None) offset = _offset(start, end)
gold_entity = correct_entries_per_article.get(offset, None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong' # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None: if gold_entity is not None:
counts_by_label[ent_label] = counts_by_label.get(ent_label, 0) + 1 counts_d[label] = counts_d.get(label, 0) + 1
candidates = kb.get_candidates(ent.text) candidates = kb.get_candidates(ent.text)
oracle_candidate = "" oracle_candidate = ""
best_candidate = "" best_candidate = ""
@ -370,28 +422,40 @@ def _measure_baselines(data, kb):
random_candidate = random.choice(candidates).entity_ random_candidate = random.choice(candidates).entity_
if gold_entity == best_candidate: if gold_entity == best_candidate:
prior_correct_by_label[ent_label] = prior_correct_by_label.get(ent_label, 0) + 1 prior_correct_d[label] = prior_correct_d.get(label, 0) + 1
else: else:
prior_incorrect_by_label[ent_label] = prior_incorrect_by_label.get(ent_label, 0) + 1 prior_incorrect_d[label] = prior_incorrect_d.get(label, 0) + 1
if gold_entity == random_candidate: if gold_entity == random_candidate:
random_correct_by_label[ent_label] = random_correct_by_label.get(ent_label, 0) + 1 random_correct_d[label] = random_correct_d.get(label, 0) + 1
else: else:
random_incorrect_by_label[ent_label] = random_incorrect_by_label.get(ent_label, 0) + 1 random_incorrect_d[label] = random_incorrect_d.get(label, 0) + 1
if gold_entity == oracle_candidate: if gold_entity == oracle_candidate:
oracle_correct_by_label[ent_label] = oracle_correct_by_label.get(ent_label, 0) + 1 oracle_correct_d[label] = oracle_correct_d.get(label, 0) + 1
else: else:
oracle_incorrect_by_label[ent_label] = oracle_incorrect_by_label.get(ent_label, 0) + 1 oracle_incorrect_d[label] = oracle_incorrect_d.get(label, 0) + 1
except Exception as e: except Exception as e:
print("Error assessing accuracy", e) print("Error assessing accuracy", e)
acc_prior, acc_prior_by_label = calculate_acc(prior_correct_by_label, prior_incorrect_by_label) acc_prior, acc_prior_d = calculate_acc(prior_correct_d, prior_incorrect_d)
acc_rand, acc_rand_by_label = calculate_acc(random_correct_by_label, random_incorrect_by_label) acc_rand, acc_rand_d = calculate_acc(random_correct_d, random_incorrect_d)
acc_oracle, acc_oracle_by_label = calculate_acc(oracle_correct_by_label, oracle_incorrect_by_label) acc_oracle, acc_oracle_d = calculate_acc(oracle_correct_d, oracle_incorrect_d)
return counts_by_label, acc_rand, acc_rand_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label return (
counts_d,
acc_rand,
acc_rand_d,
acc_prior,
acc_prior_d,
acc_oracle,
acc_oracle_d,
)
def _offset(start, end):
return "{}_{}".format(start, end)
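The body of calculate_acc is unchanged by this commit and only its def line appears below as context; a minimal version consistent with how it is called above (an overall accuracy plus a per-label breakdown) could look like this:

def calculate_acc(correct_by_label, incorrect_by_label):
    acc_by_label = {}
    total_correct = sum(correct_by_label.values())
    total = total_correct + sum(incorrect_by_label.values())
    for label in set(correct_by_label) | set(incorrect_by_label):
        correct = correct_by_label.get(label, 0)
        incorrect = incorrect_by_label.get(label, 0)
        acc_by_label[label] = correct / (correct + incorrect)
    return (total_correct / total if total else 0.0), acc_by_label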
def calculate_acc(correct_by_label, incorrect_by_label): def calculate_acc(correct_by_label, incorrect_by_label):
@ -422,15 +486,23 @@ def check_kb(kb):
print("generating candidates for " + mention + " :") print("generating candidates for " + mention + " :")
for c in candidates: for c in candidates:
print(" ", c.prior_prob, c.alias_, "-->", c.entity_ + " (freq=" + str(c.entity_freq) + ")") print(
" ",
c.prior_prob,
c.alias_,
"-->",
c.entity_ + " (freq=" + str(c.entity_freq) + ")",
)
print() print()
def run_el_toy_example(nlp): def run_el_toy_example(nlp):
text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \
       "Douglas reminds us to always bring our towel, even in China or Brazil. " \
       "The main character in Doug's novel is the man Arthur Dent, " \
       "but Douglas doesn't write about George Washington or Homer Simpson."
text = (
    "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, "
    "Douglas reminds us to always bring our towel, even in China or Brazil. "
    "The main character in Doug's novel is the man Arthur Dent, "
    "but Dougledydoug doesn't write about George Washington or Homer Simpson."
)
doc = nlp(text) doc = nlp(text)
print(text) print(text)
for ent in doc.ents: for ent in doc.ents:
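The loop printing the toy-example predictions is cut off in this view; reading the linked ids presumably looks like this (text is the toy example string above):

doc = nlp(text)
for ent in doc.ents:
    print(ent.text, ent.label_, ent.kb_id_)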

View File

@ -663,15 +663,14 @@ def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=False,
def build_nel_encoder(embed_width, hidden_width, ner_types, **cfg): def build_nel_encoder(embed_width, hidden_width, ner_types, **cfg):
# TODO proper error
if "entity_width" not in cfg: if "entity_width" not in cfg:
raise ValueError("entity_width not found") raise ValueError(Errors.E144.format(param="entity_width"))
if "context_width" not in cfg: if "context_width" not in cfg:
raise ValueError("context_width not found") raise ValueError(Errors.E144.format(param="context_width"))
conv_depth = cfg.get("conv_depth", 2) conv_depth = cfg.get("conv_depth", 2)
cnn_maxout_pieces = cfg.get("cnn_maxout_pieces", 3) cnn_maxout_pieces = cfg.get("cnn_maxout_pieces", 3)
pretrained_vectors = cfg.get("pretrained_vectors") # self.nlp.vocab.vectors.name pretrained_vectors = cfg.get("pretrained_vectors", None)
context_width = cfg.get("context_width") context_width = cfg.get("context_width")
entity_width = cfg.get("entity_width") entity_width = cfg.get("entity_width")
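The hard-coded messages are replaced here by the formatted error codes added to errors.py in this commit; the validation pattern is simply:

from spacy.errors import Errors

if "entity_width" not in cfg:
    raise ValueError(Errors.E144.format(param="entity_width"))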

View File

@ -406,7 +406,15 @@ class Errors(object):
E141 = ("Entity vectors should be of length {required} instead of the provided {found}.") E141 = ("Entity vectors should be of length {required} instead of the provided {found}.")
E142 = ("Unsupported loss_function '{loss_func}'. Use either 'L2' or 'cosine'") E142 = ("Unsupported loss_function '{loss_func}'. Use either 'L2' or 'cosine'")
E143 = ("Labels for component '{name}' not initialized. Did you forget to call add_label()?") E143 = ("Labels for component '{name}' not initialized. Did you forget to call add_label()?")
E144 = ("Could not find parameter `{param}` when building the entity linker model.")
E145 = ("Error reading `{param}` from input file.")
E146 = ("Could not access `{path}`.")
E147 = ("Unexpected error in the {method} functionality of the EntityLinker: {msg}. "
"This is likely a bug in spaCy, so feel free to open an issue.")
E148 = ("Expected {ents} KB identifiers but got {ids}. Make sure that each entity in `doc.ents` "
"is assigned to a KB identifier.")
E149 = ("Error deserializing model. Check that the config used to create the "
"component matches the model being loaded.")
@add_codes @add_codes
class TempErrors(object): class TempErrors(object):

View File

@ -31,7 +31,7 @@ cdef class GoldParse:
cdef public list ents cdef public list ents
cdef public dict brackets cdef public dict brackets
cdef public object cats cdef public object cats
cdef public list links cdef public dict links
cdef readonly list cand_to_gold cdef readonly list cand_to_gold
cdef readonly list gold_to_cand cdef readonly list gold_to_cand

View File

@ -468,8 +468,11 @@ cdef class GoldParse:
examples of a label to have the value 0.0. Labels not in the examples of a label to have the value 0.0. Labels not in the
dictionary are treated as missing - the gradient for those labels dictionary are treated as missing - the gradient for those labels
will be zero. will be zero.
links (iterable): A sequence of `(start_char, end_char, kb_id)` tuples,
    representing the external ID of an entity in a knowledge base.
links (dict): A dict with `(start_char, end_char)` keys,
    and the values being dicts with kb_id:value entries,
    representing the external IDs in a knowledge base (KB)
    mapped to either 1.0 or 0.0, indicating positive and
    negative examples respectively.
RETURNS (GoldParse): The newly constructed object.
"""
if words is None: if words is None:
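A minimal sketch of the `links` format documented above, assuming a blank English pipeline; the sentence, character offsets and Wikidata-style IDs are invented for illustration:

from spacy.lang.en import English
from spacy.gold import GoldParse

nlp = English()
doc = nlp("Arthur Dent is a fictional character.")

# keys are (start_char, end_char) offsets of a mention; values map each
# candidate KB ID to 1.0 (correct link) or 0.0 (negative example)
links = {(0, 11): {"Q613349": 1.0, "Q42": 0.0}}
gold = GoldParse(doc, links=links)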

View File

@ -79,7 +79,7 @@ cdef class KnowledgeBase:
return new_index return new_index
cdef inline int64_t c_add_entity(self, hash_t entity_hash, float prob, cdef inline int64_t c_add_entity(self, hash_t entity_hash, float freq,
int32_t vector_index, int feats_row) nogil: int32_t vector_index, int feats_row) nogil:
"""Add an entry to the vector of entries. """Add an entry to the vector of entries.
After calling this method, make sure to update also the _entry_index using the return value""" After calling this method, make sure to update also the _entry_index using the return value"""
@ -92,7 +92,7 @@ cdef class KnowledgeBase:
entry.entity_hash = entity_hash entry.entity_hash = entity_hash
entry.vector_index = vector_index entry.vector_index = vector_index
entry.feats_row = feats_row entry.feats_row = feats_row
entry.prob = prob entry.freq = freq
self._entries.push_back(entry) self._entries.push_back(entry)
return new_index return new_index
@ -125,7 +125,7 @@ cdef class KnowledgeBase:
entry.entity_hash = dummy_hash entry.entity_hash = dummy_hash
entry.vector_index = dummy_value entry.vector_index = dummy_value
entry.feats_row = dummy_value entry.feats_row = dummy_value
entry.prob = dummy_value entry.freq = dummy_value
# Avoid struct initializer to enable nogil # Avoid struct initializer to enable nogil
cdef vector[int64_t] dummy_entry_indices cdef vector[int64_t] dummy_entry_indices
@ -141,7 +141,7 @@ cdef class KnowledgeBase:
self._aliases_table.push_back(alias) self._aliases_table.push_back(alias)
cpdef load_bulk(self, loc) cpdef load_bulk(self, loc)
cpdef set_entities(self, entity_list, prob_list, vector_list) cpdef set_entities(self, entity_list, freq_list, vector_list)
cdef class Writer: cdef class Writer:
@ -149,7 +149,7 @@ cdef class Writer:
cdef int write_header(self, int64_t nr_entries, int64_t entity_vector_length) except -1 cdef int write_header(self, int64_t nr_entries, int64_t entity_vector_length) except -1
cdef int write_vector_element(self, float element) except -1 cdef int write_vector_element(self, float element) except -1
cdef int write_entry(self, hash_t entry_hash, float entry_prob, int32_t vector_index) except -1 cdef int write_entry(self, hash_t entry_hash, float entry_freq, int32_t vector_index) except -1
cdef int write_alias_length(self, int64_t alias_length) except -1 cdef int write_alias_length(self, int64_t alias_length) except -1
cdef int write_alias_header(self, hash_t alias_hash, int64_t candidate_length) except -1 cdef int write_alias_header(self, hash_t alias_hash, int64_t candidate_length) except -1
@ -162,7 +162,7 @@ cdef class Reader:
cdef int read_header(self, int64_t* nr_entries, int64_t* entity_vector_length) except -1 cdef int read_header(self, int64_t* nr_entries, int64_t* entity_vector_length) except -1
cdef int read_vector_element(self, float* element) except -1 cdef int read_vector_element(self, float* element) except -1
cdef int read_entry(self, hash_t* entity_hash, float* prob, int32_t* vector_index) except -1 cdef int read_entry(self, hash_t* entity_hash, float* freq, int32_t* vector_index) except -1
cdef int read_alias_length(self, int64_t* alias_length) except -1 cdef int read_alias_length(self, int64_t* alias_length) except -1
cdef int read_alias_header(self, hash_t* alias_hash, int64_t* candidate_length) except -1 cdef int read_alias_header(self, hash_t* alias_hash, int64_t* candidate_length) except -1

View File

@ -94,7 +94,7 @@ cdef class KnowledgeBase:
def get_alias_strings(self): def get_alias_strings(self):
return [self.vocab.strings[x] for x in self._alias_index] return [self.vocab.strings[x] for x in self._alias_index]
def add_entity(self, unicode entity, float prob, vector[float] entity_vector): def add_entity(self, unicode entity, float freq, vector[float] entity_vector):
""" """
Add an entity to the KB, optionally specifying its log probability based on corpus frequency Add an entity to the KB, optionally specifying its log probability based on corpus frequency
Return the hash of the entity ID/name at the end. Return the hash of the entity ID/name at the end.
@ -113,15 +113,15 @@ cdef class KnowledgeBase:
vector_index = self.c_add_vector(entity_vector=entity_vector) vector_index = self.c_add_vector(entity_vector=entity_vector)
new_index = self.c_add_entity(entity_hash=entity_hash, new_index = self.c_add_entity(entity_hash=entity_hash,
prob=prob, freq=freq,
vector_index=vector_index, vector_index=vector_index,
feats_row=-1) # Features table currently not implemented feats_row=-1) # Features table currently not implemented
self._entry_index[entity_hash] = new_index self._entry_index[entity_hash] = new_index
return entity_hash return entity_hash
cpdef set_entities(self, entity_list, prob_list, vector_list): cpdef set_entities(self, entity_list, freq_list, vector_list):
if len(entity_list) != len(prob_list) or len(entity_list) != len(vector_list): if len(entity_list) != len(freq_list) or len(entity_list) != len(vector_list):
raise ValueError(Errors.E140) raise ValueError(Errors.E140)
nr_entities = len(entity_list) nr_entities = len(entity_list)
@ -137,7 +137,7 @@ cdef class KnowledgeBase:
entity_hash = self.vocab.strings.add(entity_list[i]) entity_hash = self.vocab.strings.add(entity_list[i])
entry.entity_hash = entity_hash entry.entity_hash = entity_hash
entry.prob = prob_list[i] entry.freq = freq_list[i]
vector_index = self.c_add_vector(entity_vector=vector_list[i]) vector_index = self.c_add_vector(entity_vector=vector_list[i])
entry.vector_index = vector_index entry.vector_index = vector_index
@ -196,13 +196,42 @@ cdef class KnowledgeBase:
return [Candidate(kb=self, return [Candidate(kb=self,
entity_hash=self._entries[entry_index].entity_hash, entity_hash=self._entries[entry_index].entity_hash,
entity_freq=self._entries[entry_index].prob, entity_freq=self._entries[entry_index].freq,
entity_vector=self._vectors_table[self._entries[entry_index].vector_index], entity_vector=self._vectors_table[self._entries[entry_index].vector_index],
alias_hash=alias_hash, alias_hash=alias_hash,
prior_prob=prob) prior_prob=prior_prob)
for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs) for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs)
if entry_index != 0] if entry_index != 0]
def get_vector(self, unicode entity):
cdef hash_t entity_hash = self.vocab.strings[entity]
# Return an empty list if this entity is unknown in this KB
if entity_hash not in self._entry_index:
return [0] * self.entity_vector_length
entry_index = self._entry_index[entity_hash]
return self._vectors_table[self._entries[entry_index].vector_index]
def get_prior_prob(self, unicode entity, unicode alias):
""" Return the prior probability of a given alias being linked to a given entity,
or return 0.0 when this combination is not known in the knowledge base"""
cdef hash_t alias_hash = self.vocab.strings[alias]
cdef hash_t entity_hash = self.vocab.strings[entity]
if entity_hash not in self._entry_index or alias_hash not in self._alias_index:
return 0.0
alias_index = <int64_t>self._alias_index.get(alias_hash)
entry_index = self._entry_index[entity_hash]
alias_entry = self._aliases_table[alias_index]
for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs):
if self._entries[entry_index].entity_hash == entity_hash:
return prior_prob
return 0.0
def dump(self, loc): def dump(self, loc):
cdef Writer writer = Writer(loc) cdef Writer writer = Writer(loc)
@ -222,7 +251,7 @@ cdef class KnowledgeBase:
entry = self._entries[entry_index] entry = self._entries[entry_index]
assert entry.entity_hash == entry_hash assert entry.entity_hash == entry_hash
assert entry_index == i assert entry_index == i
writer.write_entry(entry.entity_hash, entry.prob, entry.vector_index) writer.write_entry(entry.entity_hash, entry.freq, entry.vector_index)
i = i+1 i = i+1
writer.write_alias_length(self.get_size_aliases()) writer.write_alias_length(self.get_size_aliases())
@ -248,7 +277,7 @@ cdef class KnowledgeBase:
cdef hash_t entity_hash cdef hash_t entity_hash
cdef hash_t alias_hash cdef hash_t alias_hash
cdef int64_t entry_index cdef int64_t entry_index
cdef float prob cdef float freq, prob
cdef int32_t vector_index cdef int32_t vector_index
cdef KBEntryC entry cdef KBEntryC entry
cdef AliasC alias cdef AliasC alias
@ -284,10 +313,10 @@ cdef class KnowledgeBase:
# index 0 is a dummy object not stored in the _entry_index and can be ignored. # index 0 is a dummy object not stored in the _entry_index and can be ignored.
i = 1 i = 1
while i <= nr_entities: while i <= nr_entities:
reader.read_entry(&entity_hash, &prob, &vector_index) reader.read_entry(&entity_hash, &freq, &vector_index)
entry.entity_hash = entity_hash entry.entity_hash = entity_hash
entry.prob = prob entry.freq = freq
entry.vector_index = vector_index entry.vector_index = vector_index
entry.feats_row = -1 # Features table currently not implemented entry.feats_row = -1 # Features table currently not implemented
@ -343,7 +372,8 @@ cdef class Writer:
loc = bytes(loc) loc = bytes(loc)
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
self._fp = fopen(<char*>bytes_loc, 'wb') self._fp = fopen(<char*>bytes_loc, 'wb')
assert self._fp != NULL if not self._fp:
raise IOError(Errors.E146.format(path=loc))
fseek(self._fp, 0, 0) fseek(self._fp, 0, 0)
def close(self): def close(self):
@ -357,9 +387,9 @@ cdef class Writer:
cdef int write_vector_element(self, float element) except -1: cdef int write_vector_element(self, float element) except -1:
self._write(&element, sizeof(element)) self._write(&element, sizeof(element))
cdef int write_entry(self, hash_t entry_hash, float entry_prob, int32_t vector_index) except -1: cdef int write_entry(self, hash_t entry_hash, float entry_freq, int32_t vector_index) except -1:
self._write(&entry_hash, sizeof(entry_hash)) self._write(&entry_hash, sizeof(entry_hash))
self._write(&entry_prob, sizeof(entry_prob)) self._write(&entry_freq, sizeof(entry_freq))
self._write(&vector_index, sizeof(vector_index)) self._write(&vector_index, sizeof(vector_index))
# Features table currently not implemented and not written to file # Features table currently not implemented and not written to file
@ -399,39 +429,39 @@ cdef class Reader:
if status < 1: if status < 1:
if feof(self._fp): if feof(self._fp):
return 0 # end of file return 0 # end of file
raise IOError("error reading header from input file") raise IOError(Errors.E145.format(param="header"))
status = self._read(entity_vector_length, sizeof(int64_t)) status = self._read(entity_vector_length, sizeof(int64_t))
if status < 1: if status < 1:
if feof(self._fp): if feof(self._fp):
return 0 # end of file return 0 # end of file
raise IOError("error reading header from input file") raise IOError(Errors.E145.format(param="vector length"))
cdef int read_vector_element(self, float* element) except -1: cdef int read_vector_element(self, float* element) except -1:
status = self._read(element, sizeof(float)) status = self._read(element, sizeof(float))
if status < 1: if status < 1:
if feof(self._fp): if feof(self._fp):
return 0 # end of file return 0 # end of file
raise IOError("error reading entity vector from input file") raise IOError(Errors.E145.format(param="vector element"))
cdef int read_entry(self, hash_t* entity_hash, float* prob, int32_t* vector_index) except -1: cdef int read_entry(self, hash_t* entity_hash, float* freq, int32_t* vector_index) except -1:
status = self._read(entity_hash, sizeof(hash_t)) status = self._read(entity_hash, sizeof(hash_t))
if status < 1: if status < 1:
if feof(self._fp): if feof(self._fp):
return 0 # end of file return 0 # end of file
raise IOError("error reading entity hash from input file") raise IOError(Errors.E145.format(param="entity hash"))
status = self._read(prob, sizeof(float)) status = self._read(freq, sizeof(float))
if status < 1: if status < 1:
if feof(self._fp): if feof(self._fp):
return 0 # end of file return 0 # end of file
raise IOError("error reading entity prob from input file") raise IOError(Errors.E145.format(param="entity freq"))
status = self._read(vector_index, sizeof(int32_t)) status = self._read(vector_index, sizeof(int32_t))
if status < 1: if status < 1:
if feof(self._fp): if feof(self._fp):
return 0 # end of file return 0 # end of file
raise IOError("error reading entity vector from input file") raise IOError(Errors.E145.format(param="vector index"))
if feof(self._fp): if feof(self._fp):
return 0 return 0
@ -443,33 +473,33 @@ cdef class Reader:
if status < 1: if status < 1:
if feof(self._fp): if feof(self._fp):
return 0 # end of file return 0 # end of file
raise IOError("error reading alias length from input file") raise IOError(Errors.E145.format(param="alias length"))
cdef int read_alias_header(self, hash_t* alias_hash, int64_t* candidate_length) except -1: cdef int read_alias_header(self, hash_t* alias_hash, int64_t* candidate_length) except -1:
status = self._read(alias_hash, sizeof(hash_t)) status = self._read(alias_hash, sizeof(hash_t))
if status < 1: if status < 1:
if feof(self._fp): if feof(self._fp):
return 0 # end of file return 0 # end of file
raise IOError("error reading alias hash from input file") raise IOError(Errors.E145.format(param="alias hash"))
status = self._read(candidate_length, sizeof(int64_t)) status = self._read(candidate_length, sizeof(int64_t))
if status < 1: if status < 1:
if feof(self._fp): if feof(self._fp):
return 0 # end of file return 0 # end of file
raise IOError("error reading candidate length from input file") raise IOError(Errors.E145.format(param="candidate length"))
cdef int read_alias(self, int64_t* entry_index, float* prob) except -1: cdef int read_alias(self, int64_t* entry_index, float* prob) except -1:
status = self._read(entry_index, sizeof(int64_t)) status = self._read(entry_index, sizeof(int64_t))
if status < 1: if status < 1:
if feof(self._fp): if feof(self._fp):
return 0 # end of file return 0 # end of file
raise IOError("error reading entry index for alias from input file") raise IOError(Errors.E145.format(param="entry index"))
status = self._read(prob, sizeof(float)) status = self._read(prob, sizeof(float))
if status < 1: if status < 1:
if feof(self._fp): if feof(self._fp):
return 0 # end of file return 0 # end of file
raise IOError("error reading prob for entity/alias from input file") raise IOError(Errors.E145.format(param="prior probability"))
cdef int _read(self, void* value, size_t size) except -1: cdef int _read(self, void* value, size_t size) except -1:
status = fread(value, size, 1, self._fp) status = fread(value, size, 1, self._fp)
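Putting the renamed and newly added KB methods together, a rough usage sketch; the entity IDs, frequencies and vectors are invented:

from spacy.lang.en import English
from spacy.kb import KnowledgeBase

nlp = English()
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3)

# entities now take a corpus frequency (`freq`) instead of `prob`
kb.add_entity(entity="Q42", freq=34.0, entity_vector=[1.0, 0.0, 2.0])
kb.add_entity(entity="Q44", freq=3.0, entity_vector=[0.5, 1.5, 0.0])
kb.add_alias(alias="Douglas", entities=["Q42", "Q44"], probabilities=[0.8, 0.1])

print(kb.get_vector("Q42"))                              # [1.0, 0.0, 2.0]
print(kb.get_prior_prob(entity="Q42", alias="Douglas"))  # 0.8
print(kb.get_prior_prob(entity="Q42", alias="unknown"))  # 0.0 for unseen pairs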

View File

@ -12,10 +12,6 @@ from ...language import Language
from ...attrs import LANG, NORM from ...attrs import LANG, NORM
from ...util import update_exc, add_lookups from ...util import update_exc, add_lookups
# Borrowing french syntax parser because both languages use
# universal dependencies for tagging/parsing.
# Read here for more:
# https://github.com/explosion/spaCy/pull/1882#issuecomment-361409573
from .syntax_iterators import SYNTAX_ITERATORS from .syntax_iterators import SYNTAX_ITERATORS

View File

@ -14,7 +14,6 @@ from thinc.neural.util import to_categorical
from thinc.neural.util import get_array_module from thinc.neural.util import get_array_module
from spacy.kb import KnowledgeBase from spacy.kb import KnowledgeBase
from ..cli.pretrain import get_cossim_loss
from .functions import merge_subtokens from .functions import merge_subtokens
from ..tokens.doc cimport Doc from ..tokens.doc cimport Doc
from ..syntax.nn_parser cimport Parser from ..syntax.nn_parser cimport Parser
@ -168,7 +167,10 @@ class Pipe(object):
self.cfg["pretrained_vectors"] = self.vocab.vectors.name self.cfg["pretrained_vectors"] = self.vocab.vectors.name
if self.model is True: if self.model is True:
self.model = self.Model(**self.cfg) self.model = self.Model(**self.cfg)
try:
self.model.from_bytes(b) self.model.from_bytes(b)
except AttributeError:
raise ValueError(Errors.E149)
deserialize = OrderedDict() deserialize = OrderedDict()
deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b)) deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
@ -197,7 +199,10 @@ class Pipe(object):
self.cfg["pretrained_vectors"] = self.vocab.vectors.name self.cfg["pretrained_vectors"] = self.vocab.vectors.name
if self.model is True: if self.model is True:
self.model = self.Model(**self.cfg) self.model = self.Model(**self.cfg)
try:
self.model.from_bytes(p.open("rb").read()) self.model.from_bytes(p.open("rb").read())
except AttributeError:
raise ValueError(Errors.E149)
deserialize = OrderedDict() deserialize = OrderedDict()
deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p)) deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p))
@ -563,7 +568,10 @@ class Tagger(Pipe):
"token_vector_width", "token_vector_width",
self.cfg.get("token_vector_width", 96)) self.cfg.get("token_vector_width", 96))
self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg) self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg)
try:
self.model.from_bytes(b) self.model.from_bytes(b)
except AttributeError:
raise ValueError(Errors.E149)
def load_tag_map(b): def load_tag_map(b):
tag_map = srsly.msgpack_loads(b) tag_map = srsly.msgpack_loads(b)
@ -601,7 +609,10 @@ class Tagger(Pipe):
if self.model is True: if self.model is True:
self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg) self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg)
with p.open("rb") as file_: with p.open("rb") as file_:
try:
self.model.from_bytes(file_.read()) self.model.from_bytes(file_.read())
except AttributeError:
raise ValueError(Errors.E149)
def load_tag_map(p): def load_tag_map(p):
tag_map = srsly.read_msgpack(p) tag_map = srsly.read_msgpack(p)
@ -1077,6 +1088,7 @@ class EntityLinker(Pipe):
DOCS: TODO DOCS: TODO
""" """
name = 'entity_linker' name = 'entity_linker'
NIL = "NIL" # string used to refer to a non-existing link
@classmethod @classmethod
def Model(cls, **cfg): def Model(cls, **cfg):
@ -1093,6 +1105,8 @@ class EntityLinker(Pipe):
self.kb = None self.kb = None
self.cfg = dict(cfg) self.cfg = dict(cfg)
self.sgd_context = None self.sgd_context = None
if not self.cfg.get("context_width"):
self.cfg["context_width"] = 128
def set_kb(self, kb): def set_kb(self, kb):
self.kb = kb self.kb = kb
@ -1140,7 +1154,7 @@ class EntityLinker(Pipe):
context_docs = [] context_docs = []
entity_encodings = [] entity_encodings = []
cats = []
priors = [] priors = []
type_vectors = [] type_vectors = []
@ -1149,50 +1163,44 @@ class EntityLinker(Pipe):
for doc, gold in zip(docs, golds): for doc, gold in zip(docs, golds):
ents_by_offset = dict() ents_by_offset = dict()
for ent in doc.ents: for ent in doc.ents:
ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent ents_by_offset["{}_{}".format(ent.start_char, ent.end_char)] = ent
for entity in gold.links: for entity, kb_dict in gold.links.items():
start, end, gold_kb = entity start, end = entity
mention = doc.text[start:end] mention = doc.text[start:end]
for kb_id, value in kb_dict.items():
entity_encoding = self.kb.get_vector(kb_id)
prior_prob = self.kb.get_prior_prob(kb_id, mention)
gold_ent = ents_by_offset["{}_{}".format(start, end)]
if gold_ent is None:
raise RuntimeError(Errors.E147.format(method="update", msg="gold entity not found"))
gold_ent = ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)]
assert gold_ent is not None
type_vector = [0 for i in range(len(type_to_int))] type_vector = [0 for i in range(len(type_to_int))]
if len(type_to_int) > 0: if len(type_to_int) > 0:
type_vector[type_to_int[gold_ent.label_]] = 1 type_vector[type_to_int[gold_ent.label_]] = 1
candidates = self.kb.get_candidates(mention) # store data
random.shuffle(candidates)
nr_neg = 0
for c in candidates:
kb_id = c.entity_
entity_encoding = c.entity_vector
entity_encodings.append(entity_encoding) entity_encodings.append(entity_encoding)
context_docs.append(doc) context_docs.append(doc)
type_vectors.append(type_vector) type_vectors.append(type_vector)
if self.cfg.get("prior_weight", 1) > 0: if self.cfg.get("prior_weight", 1) > 0:
priors.append([c.prior_prob]) priors.append([prior_prob])
else: else:
priors.append([0]) priors.append([0])
if kb_id == gold_kb:
cats.append([1])
else:
nr_neg += 1
cats.append([0])
if len(entity_encodings) > 0: if len(entity_encodings) > 0:
assert len(priors) == len(entity_encodings) == len(context_docs) == len(cats) == len(type_vectors) if not (len(priors) == len(entity_encodings) == len(context_docs) == len(type_vectors)):
raise RuntimeError(Errors.E147.format(method="update", msg="vector lengths not equal"))
context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop)
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32") entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop)
mention_encodings = [list(context_encodings[i]) + list(entity_encodings[i]) + priors[i] + type_vectors[i] mention_encodings = [list(context_encodings[i]) + list(entity_encodings[i]) + priors[i] + type_vectors[i]
for i in range(len(entity_encodings))] for i in range(len(entity_encodings))]
pred, bp_mention = self.model.begin_update(self.model.ops.asarray(mention_encodings, dtype="float32"), drop=drop) pred, bp_mention = self.model.begin_update(self.model.ops.asarray(mention_encodings, dtype="float32"), drop=drop)
cats = self.model.ops.asarray(cats, dtype="float32")
loss, d_scores = self.get_loss(prediction=pred, golds=cats, docs=None) loss, d_scores = self.get_loss(scores=pred, golds=golds, docs=docs)
mention_gradient = bp_mention(d_scores, sgd=sgd) mention_gradient = bp_mention(d_scores, sgd=sgd)
context_gradients = [list(x[0:self.cfg.get("context_width")]) for x in mention_gradient] context_gradients = [list(x[0:self.cfg.get("context_width")]) for x in mention_gradient]
@ -1203,39 +1211,45 @@ class EntityLinker(Pipe):
return loss return loss
return 0 return 0
def get_loss(self, docs, golds, prediction): def get_loss(self, docs, golds, scores):
d_scores = (prediction - golds) cats = []
for gold in golds:
for entity, kb_dict in gold.links.items():
for kb_id, value in kb_dict.items():
cats.append([value])
cats = self.model.ops.asarray(cats, dtype="float32")
if len(scores) != len(cats):
raise RuntimeError(Errors.E147.format(method="get_loss", msg="gold entities do not match up"))
d_scores = (scores - cats)
loss = (d_scores ** 2).sum() loss = (d_scores ** 2).sum()
loss = loss / len(golds) loss = loss / len(cats)
return loss, d_scores return loss, d_scores
def get_loss_old(self, docs, golds, scores):
# this loss function assumes we're only using positive examples
loss, gradients = get_cossim_loss(yh=scores, y=golds)
loss = loss / len(golds)
return loss, gradients
def __call__(self, doc): def __call__(self, doc):
entities, kb_ids = self.predict([doc]) kb_ids, tensors = self.predict([doc])
self.set_annotations([doc], entities, kb_ids) self.set_annotations([doc], kb_ids, tensors=tensors)
return doc return doc
def pipe(self, stream, batch_size=128, n_threads=-1): def pipe(self, stream, batch_size=128, n_threads=-1):
for docs in util.minibatch(stream, size=batch_size): for docs in util.minibatch(stream, size=batch_size):
docs = list(docs) docs = list(docs)
entities, kb_ids = self.predict(docs) kb_ids, tensors = self.predict(docs)
self.set_annotations(docs, entities, kb_ids) self.set_annotations(docs, kb_ids, tensors=tensors)
yield from docs yield from docs
def predict(self, docs): def predict(self, docs):
""" Return the KB IDs for each entity in each doc, including NIL if there is no prediction """
self.require_model() self.require_model()
self.require_kb() self.require_kb()
final_entities = [] entity_count = 0
final_kb_ids = [] final_kb_ids = []
final_tensors = []
if not docs: if not docs:
return final_entities, final_kb_ids return final_kb_ids, final_tensors
if isinstance(docs, Doc): if isinstance(docs, Doc):
docs = [docs] docs = [docs]
@ -1247,14 +1261,19 @@ class EntityLinker(Pipe):
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
if len(doc) > 0: if len(doc) > 0:
# currently, the context is the same for each entity in a sentence (should be refined)
context_encoding = context_encodings[i] context_encoding = context_encodings[i]
for ent in doc.ents: for ent in doc.ents:
entity_count += 1
type_vector = [0 for i in range(len(type_to_int))] type_vector = [0 for i in range(len(type_to_int))]
if len(type_to_int) > 0: if len(type_to_int) > 0:
type_vector[type_to_int[ent.label_]] = 1 type_vector[type_to_int[ent.label_]] = 1
candidates = self.kb.get_candidates(ent.text) candidates = self.kb.get_candidates(ent.text)
if candidates: if not candidates:
final_kb_ids.append(self.NIL) # no prediction possible for this entity
final_tensors.append(context_encoding)
else:
random.shuffle(candidates) random.shuffle(candidates)
# this will set the prior probabilities to 0 (just like in training) if their weight is 0 # this will set the prior probabilities to 0 (just like in training) if their weight is 0
@ -1264,7 +1283,9 @@ class EntityLinker(Pipe):
if self.cfg.get("context_weight", 1) > 0: if self.cfg.get("context_weight", 1) > 0:
entity_encodings = xp.asarray([c.entity_vector for c in candidates]) entity_encodings = xp.asarray([c.entity_vector for c in candidates])
assert len(entity_encodings) == len(prior_probs) if len(entity_encodings) != len(prior_probs):
raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length"))
mention_encodings = [list(context_encoding) + list(entity_encodings[i]) mention_encodings = [list(context_encoding) + list(entity_encodings[i])
+ list(prior_probs[i]) + type_vector + list(prior_probs[i]) + type_vector
for i in range(len(entity_encodings))] for i in range(len(entity_encodings))]
@ -1273,14 +1294,25 @@ class EntityLinker(Pipe):
# TODO: thresholding # TODO: thresholding
best_index = scores.argmax() best_index = scores.argmax()
best_candidate = candidates[best_index] best_candidate = candidates[best_index]
final_entities.append(ent)
final_kb_ids.append(best_candidate.entity_) final_kb_ids.append(best_candidate.entity_)
final_tensors.append(context_encoding)
return final_entities, final_kb_ids if not (len(final_tensors) == len(final_kb_ids) == entity_count):
raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length"))
def set_annotations(self, docs, entities, kb_ids=None): return final_kb_ids, final_tensors
for entity, kb_id in zip(entities, kb_ids):
for token in entity: def set_annotations(self, docs, kb_ids, tensors=None):
count_ents = len([ent for doc in docs for ent in doc.ents])
if count_ents != len(kb_ids):
raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))
i=0
for doc in docs:
for ent in doc.ents:
kb_id = kb_ids[i]
i += 1
for token in ent:
token.ent_kb_id_ = kb_id token.ent_kb_id_ = kb_id
def to_disk(self, path, exclude=tuple(), **kwargs): def to_disk(self, path, exclude=tuple(), **kwargs):
@ -1297,7 +1329,10 @@ class EntityLinker(Pipe):
def load_model(p): def load_model(p):
if self.model is True: if self.model is True:
self.model = self.Model(**self.cfg) self.model = self.Model(**self.cfg)
try:
self.model.from_bytes(p.open("rb").read()) self.model.from_bytes(p.open("rb").read())
except AttributeError:
raise ValueError(Errors.E149)
def load_kb(p): def load_kb(p):
kb = KnowledgeBase(vocab=self.vocab, entity_vector_length=self.cfg["entity_width"]) kb = KnowledgeBase(vocab=self.vocab, entity_vector_length=self.cfg["entity_width"])
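A rough end-to-end sketch of the updated EntityLinker flow, mirroring the KB/EntityRuler setup used in the regression test later in this diff; the KB contents, alias and `context_width` value are illustrative, and an NER-like component must populate `doc.ents` first:

from spacy.lang.en import English
from spacy.kb import KnowledgeBase
from spacy.pipeline import EntityRuler

nlp = English()
nlp.add_pipe(nlp.create_pipe("sentencizer"))

# a rule-based NER step so that doc.ents is populated
ruler = EntityRuler(nlp)
ruler.add_patterns([{"label": "GPE", "pattern": "Boston"}])
nlp.add_pipe(ruler)

# a tiny knowledge base with a single entity/alias pair
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=1)
kb.add_entity(entity="Q100", freq=42.0, entity_vector=[1])
kb.add_alias(alias="Boston", entities=["Q100"], probabilities=[0.7])

el_pipe = nlp.create_pipe("entity_linker", config={"context_width": 64})
el_pipe.set_kb(kb)
el_pipe.begin_training()
nlp.add_pipe(el_pipe)

doc = nlp("She lives in Boston.")
for ent in doc.ents:
    for token in ent:
        # entities without any viable KB candidate receive the NIL placeholder
        print(token.text, token.ent_kb_id_)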

View File

@ -93,7 +93,7 @@ cdef struct KBEntryC:
int32_t feats_row int32_t feats_row
# log probability of entity, based on corpus frequency # log probability of entity, based on corpus frequency
float prob float freq
# Each alias struct stores a list of Entry pointers with their prior probabilities # Each alias struct stores a list of Entry pointers with their prior probabilities

View File

@ -631,7 +631,10 @@ cdef class Parser:
cfg = {} cfg = {}
with (path / 'model').open('rb') as file_: with (path / 'model').open('rb') as file_:
bytes_data = file_.read() bytes_data = file_.read()
try:
self.model.from_bytes(bytes_data) self.model.from_bytes(bytes_data)
except AttributeError:
raise ValueError(Errors.E149)
self.cfg.update(cfg) self.cfg.update(cfg)
return self return self
@ -663,6 +666,9 @@ cdef class Parser:
else: else:
cfg = {} cfg = {}
if 'model' in msg: if 'model' in msg:
try:
self.model.from_bytes(msg['model']) self.model.from_bytes(msg['model'])
except AttributeError:
raise ValueError(Errors.E149)
self.cfg.update(cfg) self.cfg.update(cfg)
return self return self

View File

@ -13,22 +13,38 @@ def nlp():
return English() return English()
def assert_almost_equal(a, b):
delta = 0.0001
assert a - delta <= b <= a + delta
def test_kb_valid_entities(nlp): def test_kb_valid_entities(nlp):
"""Test the valid construction of a KB with 3 entities and two aliases""" """Test the valid construction of a KB with 3 entities and two aliases"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
# adding entities # adding entities
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1]) mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[8, 4, 3])
mykb.add_entity(entity='Q2', prob=0.5, entity_vector=[2]) mykb.add_entity(entity="Q2", freq=0.5, entity_vector=[2, 1, 0])
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3]) mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[-1, -6, 5])
# adding aliases # adding aliases
mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.2]) mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.2])
mykb.add_alias(alias='adam', entities=['Q2'], probabilities=[0.9]) mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
# test the size of the corresponding KB # test the size of the corresponding KB
assert(mykb.get_size_entities() == 3) assert mykb.get_size_entities() == 3
assert(mykb.get_size_aliases() == 2) assert mykb.get_size_aliases() == 2
# test retrieval of the entity vectors
assert mykb.get_vector("Q1") == [8, 4, 3]
assert mykb.get_vector("Q2") == [2, 1, 0]
assert mykb.get_vector("Q3") == [-1, -6, 5]
# test retrieval of prior probabilities
assert_almost_equal(mykb.get_prior_prob(entity="Q2", alias="douglas"), 0.8)
assert_almost_equal(mykb.get_prior_prob(entity="Q3", alias="douglas"), 0.2)
assert_almost_equal(mykb.get_prior_prob(entity="Q342", alias="douglas"), 0.0)
assert_almost_equal(mykb.get_prior_prob(entity="Q3", alias="douglassssss"), 0.0)
def test_kb_invalid_entities(nlp): def test_kb_invalid_entities(nlp):
@ -36,13 +52,15 @@ def test_kb_invalid_entities(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities # adding entities
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1]) mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1])
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2]) mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3]) mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3])
# adding aliases - should fail because one of the given IDs is not valid # adding aliases - should fail because one of the given IDs is not valid
with pytest.raises(ValueError): with pytest.raises(ValueError):
mykb.add_alias(alias='douglas', entities=['Q2', 'Q342'], probabilities=[0.8, 0.2]) mykb.add_alias(
alias="douglas", entities=["Q2", "Q342"], probabilities=[0.8, 0.2]
)
def test_kb_invalid_probabilities(nlp): def test_kb_invalid_probabilities(nlp):
@ -50,13 +68,13 @@ def test_kb_invalid_probabilities(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities # adding entities
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1]) mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1])
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2]) mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3]) mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3])
# adding aliases - should fail because the sum of the probabilities exceeds 1 # adding aliases - should fail because the sum of the probabilities exceeds 1
with pytest.raises(ValueError): with pytest.raises(ValueError):
mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.4]) mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.4])
def test_kb_invalid_combination(nlp): def test_kb_invalid_combination(nlp):
@ -64,13 +82,15 @@ def test_kb_invalid_combination(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities # adding entities
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1]) mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1])
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2]) mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3]) mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3])
# adding aliases - should fail because the entities and probabilities vectors are not of equal length # adding aliases - should fail because the entities and probabilities vectors are not of equal length
with pytest.raises(ValueError): with pytest.raises(ValueError):
mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.3, 0.4, 0.1]) mykb.add_alias(
alias="douglas", entities=["Q2", "Q3"], probabilities=[0.3, 0.4, 0.1]
)
def test_kb_invalid_entity_vector(nlp): def test_kb_invalid_entity_vector(nlp):
@ -78,11 +98,11 @@ def test_kb_invalid_entity_vector(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3) mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
# adding entities # adding entities
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1, 2, 3]) mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1, 2, 3])
# this should fail because the kb's expected entity vector length is 3 # this should fail because the kb's expected entity vector length is 3
with pytest.raises(ValueError): with pytest.raises(ValueError):
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2]) mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
def test_candidate_generation(nlp): def test_candidate_generation(nlp):
@ -90,18 +110,24 @@ def test_candidate_generation(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities # adding entities
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1]) mykb.add_entity(entity="Q1", freq=0.7, entity_vector=[1])
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2]) mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3]) mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3])
# adding aliases # adding aliases
mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.2]) mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1])
mykb.add_alias(alias='adam', entities=['Q2'], probabilities=[0.9]) mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
# test the size of the relevant candidates # test the size of the relevant candidates
assert(len(mykb.get_candidates('douglas')) == 2) assert len(mykb.get_candidates("douglas")) == 2
assert(len(mykb.get_candidates('adam')) == 1) assert len(mykb.get_candidates("adam")) == 1
assert(len(mykb.get_candidates('shrubbery')) == 0) assert len(mykb.get_candidates("shrubbery")) == 0
# test the content of the candidates
assert mykb.get_candidates("adam")[0].entity_ == "Q2"
assert mykb.get_candidates("adam")[0].alias_ == "adam"
assert_almost_equal(mykb.get_candidates("adam")[0].entity_freq, 0.2)
assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9)
def test_preserving_links_asdoc(nlp): def test_preserving_links_asdoc(nlp):
@ -109,24 +135,26 @@ def test_preserving_links_asdoc(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities # adding entities
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1]) mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1])
mykb.add_entity(entity='Q2', prob=0.8, entity_vector=[1]) mykb.add_entity(entity="Q2", freq=0.8, entity_vector=[1])
# adding aliases # adding aliases
mykb.add_alias(alias='Boston', entities=['Q1'], probabilities=[0.7]) mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7])
mykb.add_alias(alias='Denver', entities=['Q2'], probabilities=[0.6]) mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6])
# set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained) # set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained)
sentencizer = nlp.create_pipe("sentencizer") sentencizer = nlp.create_pipe("sentencizer")
nlp.add_pipe(sentencizer) nlp.add_pipe(sentencizer)
ruler = EntityRuler(nlp) ruler = EntityRuler(nlp)
patterns = [{"label": "GPE", "pattern": "Boston"}, patterns = [
{"label": "GPE", "pattern": "Denver"}] {"label": "GPE", "pattern": "Boston"},
{"label": "GPE", "pattern": "Denver"},
]
ruler.add_patterns(patterns) ruler.add_patterns(patterns)
nlp.add_pipe(ruler) nlp.add_pipe(ruler)
el_pipe = nlp.create_pipe(name='entity_linker', config={"context_width": 64}) el_pipe = nlp.create_pipe(name="entity_linker", config={"context_width": 64})
el_pipe.set_kb(mykb) el_pipe.set_kb(mykb)
el_pipe.begin_training() el_pipe.begin_training()
el_pipe.context_weight = 0 el_pipe.context_weight = 0

View File

@ -0,0 +1,112 @@
# coding: utf8
from __future__ import unicode_literals
import pytest
from ..util import get_doc
@pytest.fixture
def doc(en_tokenizer):
text = "He jests at scars, that never felt a wound."
heads = [1, 6, -1, -1, 3, 2, 1, 0, 1, -2, -3]
deps = [
"nsubj",
"ccomp",
"prep",
"pobj",
"punct",
"nsubj",
"neg",
"ROOT",
"det",
"dobj",
"punct",
]
tokens = en_tokenizer(text)
return get_doc(tokens.vocab, words=[t.text for t in tokens], heads=heads, deps=deps)
def test_issue3962(doc):
""" Ensure that as_doc does not result in out-of-bound access of tokens.
This is achieved by setting the head to itself if it would lie out of the span otherwise."""
span2 = doc[1:5] # "jests at scars ,"
doc2 = span2.as_doc()
doc2_json = doc2.to_json()
assert doc2_json
assert doc2[0].head.text == "jests" # head set to itself, being the new artificial root
assert doc2[0].dep_ == "dep"
assert doc2[1].head.text == "jests"
assert doc2[1].dep_ == "prep"
assert doc2[2].head.text == "at"
assert doc2[2].dep_ == "pobj"
assert doc2[3].head.text == "jests" # head set to the new artificial root
assert doc2[3].dep_ == "dep"
# We should still have 1 sentence
assert len(list(doc2.sents)) == 1
span3 = doc[6:9] # "never felt a"
doc3 = span3.as_doc()
doc3_json = doc3.to_json()
assert doc3_json
assert doc3[0].head.text == "felt"
assert doc3[0].dep_ == "neg"
assert doc3[1].head.text == "felt"
assert doc3[1].dep_ == "ROOT"
assert doc3[2].head.text == "felt" # head set to ancestor
assert doc3[2].dep_ == "dep"
# We should still have 1 sentence as "a" can be attached to "felt" instead of "wound"
assert len(list(doc3.sents)) == 1
@pytest.fixture
def two_sent_doc(en_tokenizer):
text = "He jests at scars. They never felt a wound."
heads = [1, 0, -1, -1, -3, 2, 1, 0, 1, -2, -3]
deps = [
"nsubj",
"ROOT",
"prep",
"pobj",
"punct",
"nsubj",
"neg",
"ROOT",
"det",
"dobj",
"punct",
]
tokens = en_tokenizer(text)
return get_doc(tokens.vocab, words=[t.text for t in tokens], heads=heads, deps=deps)
def test_issue3962_long(two_sent_doc):
""" Ensure that as_doc does not result in out-of-bound access of tokens.
This is achieved by setting the head to itself if it would lie out of the span otherwise."""
span2 = two_sent_doc[1:7] # "jests at scars. They never"
doc2 = span2.as_doc()
doc2_json = doc2.to_json()
assert doc2_json
assert doc2[0].head.text == "jests" # head set to itself, being the new artificial root (in sentence 1)
assert doc2[0].dep_ == "ROOT"
assert doc2[1].head.text == "jests"
assert doc2[1].dep_ == "prep"
assert doc2[2].head.text == "at"
assert doc2[2].dep_ == "pobj"
assert doc2[3].head.text == "jests"
assert doc2[3].dep_ == "punct"
assert doc2[4].head.text == "They" # head set to itself, being the new artificial root (in sentence 2)
assert doc2[4].dep_ == "dep"
assert doc2[5].head.text == "They" # head set to the new artificial head (in sentence 2)
assert doc2[5].dep_ == "dep"
# We should still have 2 sentences
sents = list(doc2.sents)
assert len(sents) == 2
assert sents[0].text == "jests at scars ."
assert sents[1].text == "They never"

View File

@ -0,0 +1,28 @@
# coding: utf8
from __future__ import unicode_literals
import pytest
from spacy.matcher import PhraseMatcher
from spacy.tokens import Doc
@pytest.mark.xfail
def test_issue4002(en_vocab):
"""Test that the PhraseMatcher can match on overwritten NORM attributes.
"""
matcher = PhraseMatcher(en_vocab, attr="NORM")
pattern1 = Doc(en_vocab, words=["c", "d"])
assert [t.norm_ for t in pattern1] == ["c", "d"]
matcher.add("TEST", None, pattern1)
doc = Doc(en_vocab, words=["a", "b", "c", "d"])
assert [t.norm_ for t in doc] == ["a", "b", "c", "d"]
matches = matcher(doc)
assert len(matches) == 1
matcher = PhraseMatcher(en_vocab, attr="NORM")
pattern2 = Doc(en_vocab, words=["1", "2"])
pattern2[0].norm_ = "c"
pattern2[1].norm_ = "d"
assert [t.norm_ for t in pattern2] == ["c", "d"]
matcher.add("TEST", None, pattern2)
matches = matcher(doc)
assert len(matches) == 1

View File

@ -30,10 +30,10 @@ def test_serialize_kb_disk(en_vocab):
def _get_dummy_kb(vocab): def _get_dummy_kb(vocab):
kb = KnowledgeBase(vocab=vocab, entity_vector_length=3) kb = KnowledgeBase(vocab=vocab, entity_vector_length=3)
kb.add_entity(entity='Q53', prob=0.33, entity_vector=[0, 5, 3]) kb.add_entity(entity='Q53', freq=0.33, entity_vector=[0, 5, 3])
kb.add_entity(entity='Q17', prob=0.2, entity_vector=[7, 1, 0]) kb.add_entity(entity='Q17', freq=0.2, entity_vector=[7, 1, 0])
kb.add_entity(entity='Q007', prob=0.7, entity_vector=[0, 0, 7]) kb.add_entity(entity='Q007', freq=0.7, entity_vector=[0, 0, 7])
kb.add_entity(entity='Q44', prob=0.4, entity_vector=[4, 4, 4]) kb.add_entity(entity='Q44', freq=0.4, entity_vector=[4, 4, 4])
kb.add_alias(alias='double07', entities=['Q17', 'Q007'], probabilities=[0.1, 0.9]) kb.add_alias(alias='double07', entities=['Q17', 'Q007'], probabilities=[0.1, 0.9])
kb.add_alias(alias='guy', entities=['Q53', 'Q007', 'Q17', 'Q44'], probabilities=[0.3, 0.3, 0.2, 0.1]) kb.add_alias(alias='guy', entities=['Q53', 'Q007', 'Q17', 'Q44'], probabilities=[0.3, 0.3, 0.2, 0.1])

View File

@ -348,7 +348,7 @@ cdef class Tokenizer:
"""Add a special-case tokenization rule. """Add a special-case tokenization rule.
string (unicode): The string to specially tokenize. string (unicode): The string to specially tokenize.
token_attrs (iterable): A sequence of dicts, where each dict describes substrings (iterable): A sequence of dicts, where each dict describes
a token and its attributes. The `ORTH` fields of the attributes a token and its attributes. The `ORTH` fields of the attributes
must exactly match the string when they are concatenated. must exactly match the string when they are concatenated.

View File

@ -794,7 +794,7 @@ cdef class Doc:
if array[i, col] != 0: if array[i, col] != 0:
self.vocab.morphology.assign_tag(&tokens[i], array[i, col]) self.vocab.morphology.assign_tag(&tokens[i], array[i, col])
# Now load the data # Now load the data
for i in range(self.length): for i in range(length):
token = &self.c[i] token = &self.c[i]
for j in range(n_attrs): for j in range(n_attrs):
if attr_ids[j] != TAG: if attr_ids[j] != TAG:
@ -804,7 +804,7 @@ cdef class Doc:
self.is_tagged = bool(self.is_tagged or TAG in attrs or POS in attrs) self.is_tagged = bool(self.is_tagged or TAG in attrs or POS in attrs)
# If document is parsed, set children # If document is parsed, set children
if self.is_parsed: if self.is_parsed:
set_children_from_heads(self.c, self.length) set_children_from_heads(self.c, length)
return self return self
def get_lca_matrix(self): def get_lca_matrix(self):

View File

@ -17,6 +17,7 @@ from ..attrs cimport attr_id_t
from ..parts_of_speech cimport univ_pos_t from ..parts_of_speech cimport univ_pos_t
from ..attrs cimport * from ..attrs cimport *
from ..lexeme cimport Lexeme from ..lexeme cimport Lexeme
from ..symbols cimport dep
from ..util import normalize_slice from ..util import normalize_slice
from ..compat import is_config, basestring_ from ..compat import is_config, basestring_
@ -206,7 +207,6 @@ cdef class Span:
DOCS: https://spacy.io/api/span#as_doc DOCS: https://spacy.io/api/span#as_doc
""" """
# TODO: Fix!
words = [t.text for t in self] words = [t.text for t in self]
spaces = [bool(t.whitespace_) for t in self] spaces = [bool(t.whitespace_) for t in self]
cdef Doc doc = Doc(self.doc.vocab, words=words, spaces=spaces) cdef Doc doc = Doc(self.doc.vocab, words=words, spaces=spaces)
@ -220,7 +220,9 @@ cdef class Span:
else: else:
array_head.append(SENT_START) array_head.append(SENT_START)
array = self.doc.to_array(array_head) array = self.doc.to_array(array_head)
doc.from_array(array_head, array[self.start : self.end]) array = array[self.start : self.end]
self._fix_dep_copy(array_head, array)
doc.from_array(array_head, array)
doc.noun_chunks_iterator = self.doc.noun_chunks_iterator doc.noun_chunks_iterator = self.doc.noun_chunks_iterator
doc.user_hooks = self.doc.user_hooks doc.user_hooks = self.doc.user_hooks
doc.user_span_hooks = self.doc.user_span_hooks doc.user_span_hooks = self.doc.user_span_hooks
@ -235,6 +237,44 @@ cdef class Span:
doc.cats[cat_label] = value doc.cats[cat_label] = value
return doc return doc
def _fix_dep_copy(self, attrs, array):
""" Rewire dependency links to make sure their heads fall into the span
while still keeping the correct number of sentences. """
cdef int length = len(array)
cdef attr_t value
cdef int i, head_col, ancestor_i
old_to_new_root = dict()
if HEAD in attrs:
head_col = attrs.index(HEAD)
for i in range(length):
# if the HEAD refers to a token outside this span, find a more appropriate ancestor
token = self[i]
ancestor_i = token.head.i - self.start # span offset
if ancestor_i not in range(length):
if DEP in attrs:
array[i, attrs.index(DEP)] = dep
# try finding an ancestor within this span
ancestors = token.ancestors
for ancestor in ancestors:
ancestor_i = ancestor.i - self.start
if ancestor_i in range(length):
array[i, head_col] = ancestor_i - i
# if there is no appropriate ancestor, define a new artificial root
value = array[i, head_col]
if (i+value) not in range(length):
new_root = old_to_new_root.get(ancestor_i, None)
if new_root is not None:
# take the same artificial root as a previous token from the same sentence
array[i, head_col] = new_root - i
else:
# set this token as the new artificial root
array[i, head_col] = 0
old_to_new_root[ancestor_i] = i
return array
def merge(self, *args, **attributes): def merge(self, *args, **attributes):
"""Retokenize the document, such that the span is merged into a single """Retokenize the document, such that the span is merged into a single
token. token.
@ -500,7 +540,7 @@ cdef class Span:
if "root" in self.doc.user_span_hooks: if "root" in self.doc.user_span_hooks:
return self.doc.user_span_hooks["root"](self) return self.doc.user_span_hooks["root"](self)
# This should probably be called 'head', and the other one called # This should probably be called 'head', and the other one called
# 'gov'. But we went with 'head' elsehwhere, and now we're stuck =/ # 'gov'. But we went with 'head' elsewhere, and now we're stuck =/
cdef int i cdef int i
# First, we scan through the Span, and check whether there's a word # First, we scan through the Span, and check whether there's a word
# with head==0, i.e. a sentence root. If so, we can return it. The # with head==0, i.e. a sentence root. If so, we can return it. The

View File

@ -45,10 +45,11 @@ Whether the provided syntactic annotations form a projective dependency tree.
| Name                              | Type | Description                                                                                                                                               |
| --------------------------------- | ---- | --------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `words`                           | list | The words.                                                                                                                                                |
| `tags`                            | list | The part-of-speech tag annotations.                                                                                                                       |
| `heads`                           | list | The syntactic head annotations.                                                                                                                           |
| `labels`                          | list | The syntactic relation-type annotations.                                                                                                                  |
| `ents`                            | list | The named entity annotations.                                                                                                                             |
| `ner`                             | list | The named entity annotations as BILUO tags.                                                                                                               |
| `cand_to_gold`                    | list | The alignment from candidate tokenization to gold tokenization.                                                                                           |
| `gold_to_cand`                    | list | The alignment from gold tokenization to candidate tokenization.                                                                                           |
| `cats` <Tag variant="new">2</Tag> | list | Entries in the list should be either a label, or a `(start, end, label)` triple. The tuple form is used for categories applied to spans of the document. |
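For the renamed `ner` attribute, gold entities are passed to `GoldParse` as BILUO tags via its `entities` argument; a minimal sketch with an invented sentence and labels:

from spacy.lang.en import English
from spacy.gold import GoldParse

nlp = English()
doc = nlp("Arthur Dent lives in London")
# one BILUO tag per token: Begin / In / Last / Unit / Out
biluo = ["B-PERSON", "L-PERSON", "O", "O", "U-GPE"]
gold = GoldParse(doc, entities=biluo)
print(gold.ner)  # ['B-PERSON', 'L-PERSON', 'O', 'O', 'U-GPE']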