mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Merge pull request #4003 from svlandeg/feature/nel-fixes
API changes for Entity linking functionality
This commit is contained in:
commit
87fcf3141c
|
@ -13,9 +13,17 @@ INPUT_DIM = 300 # dimension of pre-trained input vectors
|
|||
DESC_WIDTH = 64 # dimension of output entity vectors
|
||||
|
||||
|
||||
def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
|
||||
entity_def_output, entity_descr_output,
|
||||
count_input, prior_prob_input, wikidata_input):
|
||||
def create_kb(
|
||||
nlp,
|
||||
max_entities_per_alias,
|
||||
min_entity_freq,
|
||||
min_occ,
|
||||
entity_def_output,
|
||||
entity_descr_output,
|
||||
count_input,
|
||||
prior_prob_input,
|
||||
wikidata_input,
|
||||
):
|
||||
# Create the knowledge base from Wikidata entries
|
||||
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH)
|
||||
|
||||
|
@ -28,7 +36,9 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
|
|||
title_to_id, id_to_descr = wd.read_wikidata_entities_json(wikidata_input)
|
||||
|
||||
# write the title-ID and ID-description mappings to file
|
||||
_write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr)
|
||||
_write_entity_files(
|
||||
entity_def_output, entity_descr_output, title_to_id, id_to_descr
|
||||
)
|
||||
|
||||
else:
|
||||
# read the mappings from file
|
||||
|
@ -54,8 +64,8 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
|
|||
frequency_list.append(freq)
|
||||
filtered_title_to_id[title] = entity
|
||||
|
||||
print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()),
|
||||
"titles with filter frequency", min_entity_freq)
|
||||
print(len(title_to_id.keys()), "original titles")
|
||||
print("kept", len(filtered_title_to_id.keys()), " with frequency", min_entity_freq)
|
||||
|
||||
print()
|
||||
print(" * train entity encoder", datetime.datetime.now())
|
||||
|
@ -70,14 +80,20 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
|
|||
|
||||
print()
|
||||
print(" * adding", len(entity_list), "entities", datetime.datetime.now())
|
||||
kb.set_entities(entity_list=entity_list, prob_list=frequency_list, vector_list=embeddings)
|
||||
kb.set_entities(
|
||||
entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
|
||||
)
|
||||
|
||||
print()
|
||||
print(" * adding aliases", datetime.datetime.now())
|
||||
print()
|
||||
_add_aliases(kb, title_to_id=filtered_title_to_id,
|
||||
max_entities_per_alias=max_entities_per_alias, min_occ=min_occ,
|
||||
prior_prob_input=prior_prob_input)
|
||||
_add_aliases(
|
||||
kb,
|
||||
title_to_id=filtered_title_to_id,
|
||||
max_entities_per_alias=max_entities_per_alias,
|
||||
min_occ=min_occ,
|
||||
prior_prob_input=prior_prob_input,
|
||||
)
|
||||
|
||||
print()
|
||||
print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases())
|
||||
|
@ -86,13 +102,15 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
|
|||
return kb
|
||||
|
||||
|
||||
def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr):
|
||||
with open(entity_def_output, mode='w', encoding='utf8') as id_file:
|
||||
def _write_entity_files(
|
||||
entity_def_output, entity_descr_output, title_to_id, id_to_descr
|
||||
):
|
||||
with entity_def_output.open("w", encoding="utf8") as id_file:
|
||||
id_file.write("WP_title" + "|" + "WD_id" + "\n")
|
||||
for title, qid in title_to_id.items():
|
||||
id_file.write(title + "|" + str(qid) + "\n")
|
||||
|
||||
with open(entity_descr_output, mode='w', encoding='utf8') as descr_file:
|
||||
with entity_descr_output.open("w", encoding="utf8") as descr_file:
|
||||
descr_file.write("WD_id" + "|" + "description" + "\n")
|
||||
for qid, descr in id_to_descr.items():
|
||||
descr_file.write(str(qid) + "|" + descr + "\n")
|
||||
|
@ -100,8 +118,8 @@ def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_
|
|||
|
||||
def get_entity_to_id(entity_def_output):
|
||||
entity_to_id = dict()
|
||||
with open(entity_def_output, 'r', encoding='utf8') as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter='|')
|
||||
with entity_def_output.open("r", encoding="utf8") as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter="|")
|
||||
# skip header
|
||||
next(csvreader)
|
||||
for row in csvreader:
|
||||
|
@ -111,8 +129,8 @@ def get_entity_to_id(entity_def_output):
|
|||
|
||||
def get_id_to_description(entity_descr_output):
|
||||
id_to_desc = dict()
|
||||
with open(entity_descr_output, 'r', encoding='utf8') as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter='|')
|
||||
with entity_descr_output.open("r", encoding="utf8") as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter="|")
|
||||
# skip header
|
||||
next(csvreader)
|
||||
for row in csvreader:
|
||||
|
@ -125,7 +143,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
|
|||
|
||||
# adding aliases with prior probabilities
|
||||
# we can read this file sequentially, it's sorted by alias, and then by count
|
||||
with open(prior_prob_input, mode='r', encoding='utf8') as prior_file:
|
||||
with prior_prob_input.open("r", encoding="utf8") as prior_file:
|
||||
# skip header
|
||||
prior_file.readline()
|
||||
line = prior_file.readline()
|
||||
|
@ -134,7 +152,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
|
|||
counts = []
|
||||
entities = []
|
||||
while line:
|
||||
splits = line.replace('\n', "").split(sep='|')
|
||||
splits = line.replace("\n", "").split(sep="|")
|
||||
new_alias = splits[0]
|
||||
count = int(splits[1])
|
||||
entity = splits[2]
|
||||
|
@ -153,7 +171,11 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
|
|||
|
||||
if selected_entities:
|
||||
try:
|
||||
kb.add_alias(alias=previous_alias, entities=selected_entities, probabilities=prior_probs)
|
||||
kb.add_alias(
|
||||
alias=previous_alias,
|
||||
entities=selected_entities,
|
||||
probabilities=prior_probs,
|
||||
)
|
||||
except ValueError as e:
|
||||
print(e)
|
||||
total_count = 0
|
||||
|
@ -168,4 +190,3 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
|
|||
previous_alias = new_alias
|
||||
|
||||
line = prior_file.readline()
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import bz2
|
||||
import datetime
|
||||
|
@ -17,6 +17,10 @@ Gold-standard entities are stored in one file in standoff format (by character o
|
|||
ENTITY_FILE = "gold_entities.csv"
|
||||
|
||||
|
||||
def now():
|
||||
return datetime.datetime.now()
|
||||
|
||||
|
||||
def create_training(wikipedia_input, entity_def_input, training_output):
|
||||
wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
|
||||
_process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None)
|
||||
|
@ -27,21 +31,23 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
|
|||
Read the XML wikipedia data to parse out training data:
|
||||
raw text data + positive instances
|
||||
"""
|
||||
title_regex = re.compile(r'(?<=<title>).*(?=</title>)')
|
||||
id_regex = re.compile(r'(?<=<id>)\d*(?=</id>)')
|
||||
title_regex = re.compile(r"(?<=<title>).*(?=</title>)")
|
||||
id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
|
||||
|
||||
read_ids = set()
|
||||
entityfile_loc = training_output / ENTITY_FILE
|
||||
with open(entityfile_loc, mode="w", encoding='utf8') as entityfile:
|
||||
with entityfile_loc.open("w", encoding="utf8") as entityfile:
|
||||
# write entity training header file
|
||||
_write_training_entity(outputfile=entityfile,
|
||||
article_id="article_id",
|
||||
alias="alias",
|
||||
entity="WD_id",
|
||||
start="start",
|
||||
end="end")
|
||||
_write_training_entity(
|
||||
outputfile=entityfile,
|
||||
article_id="article_id",
|
||||
alias="alias",
|
||||
entity="WD_id",
|
||||
start="start",
|
||||
end="end",
|
||||
)
|
||||
|
||||
with bz2.open(wikipedia_input, mode='rb') as file:
|
||||
with bz2.open(wikipedia_input, mode="rb") as file:
|
||||
line = file.readline()
|
||||
cnt = 0
|
||||
article_text = ""
|
||||
|
@ -51,7 +57,7 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
|
|||
reading_revision = False
|
||||
while line and (not limit or cnt < limit):
|
||||
if cnt % 1000000 == 0:
|
||||
print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump")
|
||||
print(now(), "processed", cnt, "lines of Wikipedia dump")
|
||||
clean_line = line.strip().decode("utf-8")
|
||||
|
||||
if clean_line == "<revision>":
|
||||
|
@ -69,12 +75,23 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
|
|||
elif clean_line == "</page>":
|
||||
if article_id:
|
||||
try:
|
||||
_process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text.strip(),
|
||||
training_output)
|
||||
_process_wp_text(
|
||||
wp_to_id,
|
||||
entityfile,
|
||||
article_id,
|
||||
article_title,
|
||||
article_text.strip(),
|
||||
training_output,
|
||||
)
|
||||
except Exception as e:
|
||||
print("Error processing article", article_id, article_title, e)
|
||||
print(
|
||||
"Error processing article", article_id, article_title, e
|
||||
)
|
||||
else:
|
||||
print("Done processing a page, but couldn't find an article_id ?", article_title)
|
||||
print(
|
||||
"Done processing a page, but couldn't find an article_id ?",
|
||||
article_title,
|
||||
)
|
||||
article_text = ""
|
||||
article_title = None
|
||||
article_id = None
|
||||
|
@ -98,7 +115,9 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
|
|||
if ids:
|
||||
article_id = ids[0]
|
||||
if article_id in read_ids:
|
||||
print("Found duplicate article ID", article_id, clean_line) # This should never happen ...
|
||||
print(
|
||||
"Found duplicate article ID", article_id, clean_line
|
||||
) # This should never happen ...
|
||||
read_ids.add(article_id)
|
||||
|
||||
# read the title of this article (outside the revision portion of the document)
|
||||
|
@ -111,10 +130,12 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
|
|||
cnt += 1
|
||||
|
||||
|
||||
text_regex = re.compile(r'(?<=<text xml:space=\"preserve\">).*(?=</text)')
|
||||
text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")
|
||||
|
||||
|
||||
def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output):
|
||||
def _process_wp_text(
|
||||
wp_to_id, entityfile, article_id, article_title, article_text, training_output
|
||||
):
|
||||
found_entities = False
|
||||
|
||||
# ignore meta Wikipedia pages
|
||||
|
@ -141,11 +162,11 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
|
|||
entity_buffer = ""
|
||||
mention_buffer = ""
|
||||
for index, letter in enumerate(clean_text):
|
||||
if letter == '[':
|
||||
if letter == "[":
|
||||
open_read += 1
|
||||
elif letter == ']':
|
||||
elif letter == "]":
|
||||
open_read -= 1
|
||||
elif letter == '|':
|
||||
elif letter == "|":
|
||||
if reading_text:
|
||||
final_text += letter
|
||||
# switch from reading entity to mention in the [[entity|mention]] pattern
|
||||
|
@ -163,7 +184,7 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
|
|||
elif reading_text:
|
||||
final_text += letter
|
||||
else:
|
||||
raise ValueError("Not sure at point", clean_text[index-2:index+2])
|
||||
raise ValueError("Not sure at point", clean_text[index - 2 : index + 2])
|
||||
|
||||
if open_read > 2:
|
||||
reading_special_case = True
|
||||
|
@ -175,7 +196,7 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
|
|||
|
||||
# we just finished reading an entity
|
||||
if open_read == 0 and not reading_text:
|
||||
if '#' in entity_buffer or entity_buffer.startswith(':'):
|
||||
if "#" in entity_buffer or entity_buffer.startswith(":"):
|
||||
reading_special_case = True
|
||||
# Ignore cases with nested structures like File: handles etc
|
||||
if not reading_special_case:
|
||||
|
@ -185,12 +206,14 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
|
|||
end = start + len(mention_buffer)
|
||||
qid = wp_to_id.get(entity_buffer, None)
|
||||
if qid:
|
||||
_write_training_entity(outputfile=entityfile,
|
||||
article_id=article_id,
|
||||
alias=mention_buffer,
|
||||
entity=qid,
|
||||
start=start,
|
||||
end=end)
|
||||
_write_training_entity(
|
||||
outputfile=entityfile,
|
||||
article_id=article_id,
|
||||
alias=mention_buffer,
|
||||
entity=qid,
|
||||
start=start,
|
||||
end=end,
|
||||
)
|
||||
found_entities = True
|
||||
final_text += mention_buffer
|
||||
|
||||
|
@ -203,29 +226,35 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
|
|||
reading_special_case = False
|
||||
|
||||
if found_entities:
|
||||
_write_training_article(article_id=article_id, clean_text=final_text, training_output=training_output)
|
||||
_write_training_article(
|
||||
article_id=article_id,
|
||||
clean_text=final_text,
|
||||
training_output=training_output,
|
||||
)
|
||||
|
||||
|
||||
info_regex = re.compile(r'{[^{]*?}')
|
||||
htlm_regex = re.compile(r'<!--[^-]*-->')
|
||||
category_regex = re.compile(r'\[\[Category:[^\[]*]]')
|
||||
file_regex = re.compile(r'\[\[File:[^[\]]+]]')
|
||||
ref_regex = re.compile(r'<ref.*?>') # non-greedy
|
||||
ref_2_regex = re.compile(r'</ref.*?>') # non-greedy
|
||||
info_regex = re.compile(r"{[^{]*?}")
|
||||
htlm_regex = re.compile(r"<!--[^-]*-->")
|
||||
category_regex = re.compile(r"\[\[Category:[^\[]*]]")
|
||||
file_regex = re.compile(r"\[\[File:[^[\]]+]]")
|
||||
ref_regex = re.compile(r"<ref.*?>") # non-greedy
|
||||
ref_2_regex = re.compile(r"</ref.*?>") # non-greedy
|
||||
|
||||
|
||||
def _get_clean_wp_text(article_text):
|
||||
clean_text = article_text.strip()
|
||||
|
||||
# remove bolding & italic markup
|
||||
clean_text = clean_text.replace('\'\'\'', '')
|
||||
clean_text = clean_text.replace('\'\'', '')
|
||||
clean_text = clean_text.replace("'''", "")
|
||||
clean_text = clean_text.replace("''", "")
|
||||
|
||||
# remove nested {{info}} statements by removing the inner/smallest ones first and iterating
|
||||
try_again = True
|
||||
previous_length = len(clean_text)
|
||||
while try_again:
|
||||
clean_text = info_regex.sub('', clean_text) # non-greedy match excluding a nested {
|
||||
clean_text = info_regex.sub(
|
||||
"", clean_text
|
||||
) # non-greedy match excluding a nested {
|
||||
if len(clean_text) < previous_length:
|
||||
try_again = True
|
||||
else:
|
||||
|
@ -233,14 +262,14 @@ def _get_clean_wp_text(article_text):
|
|||
previous_length = len(clean_text)
|
||||
|
||||
# remove HTML comments
|
||||
clean_text = htlm_regex.sub('', clean_text)
|
||||
clean_text = htlm_regex.sub("", clean_text)
|
||||
|
||||
# remove Category and File statements
|
||||
clean_text = category_regex.sub('', clean_text)
|
||||
clean_text = file_regex.sub('', clean_text)
|
||||
clean_text = category_regex.sub("", clean_text)
|
||||
clean_text = file_regex.sub("", clean_text)
|
||||
|
||||
# remove multiple =
|
||||
while '==' in clean_text:
|
||||
while "==" in clean_text:
|
||||
clean_text = clean_text.replace("==", "=")
|
||||
|
||||
clean_text = clean_text.replace(". =", ".")
|
||||
|
@ -249,43 +278,47 @@ def _get_clean_wp_text(article_text):
|
|||
clean_text = clean_text.replace(" =", "")
|
||||
|
||||
# remove refs (non-greedy match)
|
||||
clean_text = ref_regex.sub('', clean_text)
|
||||
clean_text = ref_2_regex.sub('', clean_text)
|
||||
clean_text = ref_regex.sub("", clean_text)
|
||||
clean_text = ref_2_regex.sub("", clean_text)
|
||||
|
||||
# remove additional wikiformatting
|
||||
clean_text = re.sub(r'<blockquote>', '', clean_text)
|
||||
clean_text = re.sub(r'</blockquote>', '', clean_text)
|
||||
clean_text = re.sub(r"<blockquote>", "", clean_text)
|
||||
clean_text = re.sub(r"</blockquote>", "", clean_text)
|
||||
|
||||
# change special characters back to normal ones
|
||||
clean_text = clean_text.replace(r'<', '<')
|
||||
clean_text = clean_text.replace(r'>', '>')
|
||||
clean_text = clean_text.replace(r'"', '"')
|
||||
clean_text = clean_text.replace(r'&nbsp;', ' ')
|
||||
clean_text = clean_text.replace(r'&', '&')
|
||||
clean_text = clean_text.replace(r"<", "<")
|
||||
clean_text = clean_text.replace(r">", ">")
|
||||
clean_text = clean_text.replace(r""", '"')
|
||||
clean_text = clean_text.replace(r"&nbsp;", " ")
|
||||
clean_text = clean_text.replace(r"&", "&")
|
||||
|
||||
# remove multiple spaces
|
||||
while ' ' in clean_text:
|
||||
clean_text = clean_text.replace(' ', ' ')
|
||||
while " " in clean_text:
|
||||
clean_text = clean_text.replace(" ", " ")
|
||||
|
||||
return clean_text.strip()
|
||||
|
||||
|
||||
def _write_training_article(article_id, clean_text, training_output):
|
||||
file_loc = training_output / str(article_id) + ".txt"
|
||||
with open(file_loc, mode='w', encoding='utf8') as outputfile:
|
||||
file_loc = training_output / "{}.txt".format(article_id)
|
||||
with file_loc.open("w", encoding="utf8") as outputfile:
|
||||
outputfile.write(clean_text)
|
||||
|
||||
|
||||
def _write_training_entity(outputfile, article_id, alias, entity, start, end):
|
||||
outputfile.write(article_id + "|" + alias + "|" + entity + "|" + str(start) + "|" + str(end) + "\n")
|
||||
line = "{}|{}|{}|{}|{}\n".format(article_id, alias, entity, start, end)
|
||||
outputfile.write(line)
|
||||
|
||||
|
||||
def is_dev(article_id):
|
||||
return article_id.endswith("3")
|
||||
|
||||
|
||||
def read_training(nlp, training_dir, dev, limit):
|
||||
# This method provides training examples that correspond to the entity annotations found by the nlp object
|
||||
def read_training(nlp, training_dir, dev, limit, kb=None):
|
||||
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
|
||||
When kb is provided (for training), it will include negative training examples by using the candidate generator,
|
||||
and it will only keep positive training examples that can be found in the KB.
|
||||
When kb=None (for testing), it will include all positive examples only."""
|
||||
entityfile_loc = training_dir / ENTITY_FILE
|
||||
data = []
|
||||
|
||||
|
@ -296,24 +329,30 @@ def read_training(nlp, training_dir, dev, limit):
|
|||
skip_articles = set()
|
||||
total_entities = 0
|
||||
|
||||
with open(entityfile_loc, mode='r', encoding='utf8') as file:
|
||||
with entityfile_loc.open("r", encoding="utf8") as file:
|
||||
for line in file:
|
||||
if not limit or len(data) < limit:
|
||||
fields = line.replace('\n', "").split(sep='|')
|
||||
fields = line.replace("\n", "").split(sep="|")
|
||||
article_id = fields[0]
|
||||
alias = fields[1]
|
||||
wp_title = fields[2]
|
||||
wd_id = fields[2]
|
||||
start = fields[3]
|
||||
end = fields[4]
|
||||
|
||||
if dev == is_dev(article_id) and article_id != "article_id" and article_id not in skip_articles:
|
||||
if (
|
||||
dev == is_dev(article_id)
|
||||
and article_id != "article_id"
|
||||
and article_id not in skip_articles
|
||||
):
|
||||
if not current_doc or (current_article_id != article_id):
|
||||
# parse the new article text
|
||||
file_name = article_id + ".txt"
|
||||
try:
|
||||
with open(os.path.join(training_dir, file_name), mode="r", encoding='utf8') as f:
|
||||
training_file = training_dir / file_name
|
||||
with training_file.open("r", encoding="utf8") as f:
|
||||
text = f.read()
|
||||
if len(text) < 30000: # threshold for convenience / speed of processing
|
||||
# threshold for convenience / speed of processing
|
||||
if len(text) < 30000:
|
||||
current_doc = nlp(text)
|
||||
current_article_id = article_id
|
||||
ents_by_offset = dict()
|
||||
|
@ -321,33 +360,69 @@ def read_training(nlp, training_dir, dev, limit):
|
|||
sent_length = len(ent.sent)
|
||||
# custom filtering to avoid too long or too short sentences
|
||||
if 5 < sent_length < 100:
|
||||
ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
|
||||
offset = "{}_{}".format(
|
||||
ent.start_char, ent.end_char
|
||||
)
|
||||
ents_by_offset[offset] = ent
|
||||
else:
|
||||
skip_articles.add(article_id)
|
||||
current_doc = None
|
||||
except Exception as e:
|
||||
print("Problem parsing article", article_id, e)
|
||||
skip_articles.add(article_id)
|
||||
raise e
|
||||
|
||||
# repeat checking this condition in case an exception was thrown
|
||||
if current_doc and (current_article_id == article_id):
|
||||
found_ent = ents_by_offset.get(start + "_" + end, None)
|
||||
offset = "{}_{}".format(start, end)
|
||||
found_ent = ents_by_offset.get(offset, None)
|
||||
if found_ent:
|
||||
if found_ent.text != alias:
|
||||
skip_articles.add(article_id)
|
||||
current_doc = None
|
||||
else:
|
||||
sent = found_ent.sent.as_doc()
|
||||
# currently feeding the gold data one entity per sentence at a time
|
||||
|
||||
gold_start = int(start) - found_ent.sent.start_char
|
||||
gold_end = int(end) - found_ent.sent.start_char
|
||||
gold_entities = [(gold_start, gold_end, wp_title)]
|
||||
gold = GoldParse(doc=sent, links=gold_entities)
|
||||
data.append((sent, gold))
|
||||
total_entities += 1
|
||||
if len(data) % 2500 == 0:
|
||||
print(" -read", total_entities, "entities")
|
||||
|
||||
gold_entities = {}
|
||||
found_useful = False
|
||||
for ent in sent.ents:
|
||||
entry = (ent.start_char, ent.end_char)
|
||||
gold_entry = (gold_start, gold_end)
|
||||
if entry == gold_entry:
|
||||
# add both pos and neg examples (in random order)
|
||||
# this will exclude examples not in the KB
|
||||
if kb:
|
||||
value_by_id = {}
|
||||
candidates = kb.get_candidates(alias)
|
||||
candidate_ids = [
|
||||
c.entity_ for c in candidates
|
||||
]
|
||||
random.shuffle(candidate_ids)
|
||||
for kb_id in candidate_ids:
|
||||
found_useful = True
|
||||
if kb_id != wd_id:
|
||||
value_by_id[kb_id] = 0.0
|
||||
else:
|
||||
value_by_id[kb_id] = 1.0
|
||||
gold_entities[entry] = value_by_id
|
||||
# if no KB, keep all positive examples
|
||||
else:
|
||||
found_useful = True
|
||||
value_by_id = {wd_id: 1.0}
|
||||
|
||||
gold_entities[entry] = value_by_id
|
||||
# currently feeding the gold data one entity per sentence at a time
|
||||
# setting all other entities to empty gold dictionary
|
||||
else:
|
||||
gold_entities[entry] = {}
|
||||
if found_useful:
|
||||
gold = GoldParse(doc=sent, links=gold_entities)
|
||||
data.append((sent, gold))
|
||||
total_entities += 1
|
||||
if len(data) % 2500 == 0:
|
||||
print(" -read", total_entities, "entities")
|
||||
|
||||
print(" -read", total_entities, "entities")
|
||||
return data
|
||||
|
|
|
@ -14,22 +14,97 @@ Write these results to file for downstream KB and training data generation.
|
|||
map_alias_to_link = dict()
|
||||
|
||||
# these will/should be matched ignoring case
|
||||
wiki_namespaces = ["b", "betawikiversity", "Book", "c", "Category", "Commons",
|
||||
"d", "dbdump", "download", "Draft", "Education", "Foundation",
|
||||
"Gadget", "Gadget definition", "gerrit", "File", "Help", "Image", "Incubator",
|
||||
"m", "mail", "mailarchive", "media", "MediaWiki", "MediaWiki talk", "Mediawikiwiki",
|
||||
"MediaZilla", "Meta", "Metawikipedia", "Module",
|
||||
"mw", "n", "nost", "oldwikisource", "outreach", "outreachwiki", "otrs", "OTRSwiki",
|
||||
"Portal", "phab", "Phabricator", "Project", "q", "quality", "rev",
|
||||
"s", "spcom", "Special", "species", "Strategy", "sulutil", "svn",
|
||||
"Talk", "Template", "Template talk", "Testwiki", "ticket", "TimedText", "Toollabs", "tools",
|
||||
"tswiki", "User", "User talk", "v", "voy",
|
||||
"w", "Wikibooks", "Wikidata", "wikiHow", "Wikinvest", "wikilivres", "Wikimedia", "Wikinews",
|
||||
"Wikipedia", "Wikipedia talk", "Wikiquote", "Wikisource", "Wikispecies", "Wikitech",
|
||||
"Wikiversity", "Wikivoyage", "wikt", "wiktionary", "wmf", "wmania", "WP"]
|
||||
wiki_namespaces = [
|
||||
"b",
|
||||
"betawikiversity",
|
||||
"Book",
|
||||
"c",
|
||||
"Category",
|
||||
"Commons",
|
||||
"d",
|
||||
"dbdump",
|
||||
"download",
|
||||
"Draft",
|
||||
"Education",
|
||||
"Foundation",
|
||||
"Gadget",
|
||||
"Gadget definition",
|
||||
"gerrit",
|
||||
"File",
|
||||
"Help",
|
||||
"Image",
|
||||
"Incubator",
|
||||
"m",
|
||||
"mail",
|
||||
"mailarchive",
|
||||
"media",
|
||||
"MediaWiki",
|
||||
"MediaWiki talk",
|
||||
"Mediawikiwiki",
|
||||
"MediaZilla",
|
||||
"Meta",
|
||||
"Metawikipedia",
|
||||
"Module",
|
||||
"mw",
|
||||
"n",
|
||||
"nost",
|
||||
"oldwikisource",
|
||||
"outreach",
|
||||
"outreachwiki",
|
||||
"otrs",
|
||||
"OTRSwiki",
|
||||
"Portal",
|
||||
"phab",
|
||||
"Phabricator",
|
||||
"Project",
|
||||
"q",
|
||||
"quality",
|
||||
"rev",
|
||||
"s",
|
||||
"spcom",
|
||||
"Special",
|
||||
"species",
|
||||
"Strategy",
|
||||
"sulutil",
|
||||
"svn",
|
||||
"Talk",
|
||||
"Template",
|
||||
"Template talk",
|
||||
"Testwiki",
|
||||
"ticket",
|
||||
"TimedText",
|
||||
"Toollabs",
|
||||
"tools",
|
||||
"tswiki",
|
||||
"User",
|
||||
"User talk",
|
||||
"v",
|
||||
"voy",
|
||||
"w",
|
||||
"Wikibooks",
|
||||
"Wikidata",
|
||||
"wikiHow",
|
||||
"Wikinvest",
|
||||
"wikilivres",
|
||||
"Wikimedia",
|
||||
"Wikinews",
|
||||
"Wikipedia",
|
||||
"Wikipedia talk",
|
||||
"Wikiquote",
|
||||
"Wikisource",
|
||||
"Wikispecies",
|
||||
"Wikitech",
|
||||
"Wikiversity",
|
||||
"Wikivoyage",
|
||||
"wikt",
|
||||
"wiktionary",
|
||||
"wmf",
|
||||
"wmania",
|
||||
"WP",
|
||||
]
|
||||
|
||||
# find the links
|
||||
link_regex = re.compile(r'\[\[[^\[\]]*\]\]')
|
||||
link_regex = re.compile(r"\[\[[^\[\]]*\]\]")
|
||||
|
||||
# match on interwiki links, e.g. `en:` or `:fr:`
|
||||
ns_regex = r":?" + "[a-z][a-z]" + ":"
|
||||
|
@ -41,18 +116,22 @@ for ns in wiki_namespaces:
|
|||
ns_regex = re.compile(ns_regex, re.IGNORECASE)
|
||||
|
||||
|
||||
def read_wikipedia_prior_probs(wikipedia_input, prior_prob_output):
|
||||
def now():
|
||||
return datetime.datetime.now()
|
||||
|
||||
|
||||
def read_prior_probs(wikipedia_input, prior_prob_output):
|
||||
"""
|
||||
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
|
||||
The full file takes about 2h to parse 1100M lines.
|
||||
It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from.
|
||||
"""
|
||||
with bz2.open(wikipedia_input, mode='rb') as file:
|
||||
with bz2.open(wikipedia_input, mode="rb") as file:
|
||||
line = file.readline()
|
||||
cnt = 0
|
||||
while line:
|
||||
if cnt % 5000000 == 0:
|
||||
print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump")
|
||||
print(now(), "processed", cnt, "lines of Wikipedia dump")
|
||||
clean_line = line.strip().decode("utf-8")
|
||||
|
||||
aliases, entities, normalizations = get_wp_links(clean_line)
|
||||
|
@ -64,10 +143,11 @@ def read_wikipedia_prior_probs(wikipedia_input, prior_prob_output):
|
|||
cnt += 1
|
||||
|
||||
# write all aliases and their entities and count occurrences to file
|
||||
with open(prior_prob_output, mode='w', encoding='utf8') as outputfile:
|
||||
with prior_prob_output.open("w", encoding="utf8") as outputfile:
|
||||
outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n")
|
||||
for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]):
|
||||
for entity, count in sorted(alias_dict.items(), key=lambda x: x[1], reverse=True):
|
||||
s_dict = sorted(alias_dict.items(), key=lambda x: x[1], reverse=True)
|
||||
for entity, count in s_dict:
|
||||
outputfile.write(alias + "|" + str(count) + "|" + entity + "\n")
|
||||
|
||||
|
||||
|
@ -140,13 +220,13 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
|
|||
entity_to_count = dict()
|
||||
total_count = 0
|
||||
|
||||
with open(prior_prob_input, mode='r', encoding='utf8') as prior_file:
|
||||
with prior_prob_input.open("r", encoding="utf8") as prior_file:
|
||||
# skip header
|
||||
prior_file.readline()
|
||||
line = prior_file.readline()
|
||||
|
||||
while line:
|
||||
splits = line.replace('\n', "").split(sep='|')
|
||||
splits = line.replace("\n", "").split(sep="|")
|
||||
# alias = splits[0]
|
||||
count = int(splits[1])
|
||||
entity = splits[2]
|
||||
|
@ -158,7 +238,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
|
|||
|
||||
line = prior_file.readline()
|
||||
|
||||
with open(count_output, mode='w', encoding='utf8') as entity_file:
|
||||
with count_output.open("w", encoding="utf8") as entity_file:
|
||||
entity_file.write("entity" + "|" + "count" + "\n")
|
||||
for entity, count in entity_to_count.items():
|
||||
entity_file.write(entity + "|" + str(count) + "\n")
|
||||
|
@ -171,12 +251,11 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
|
|||
|
||||
def get_all_frequencies(count_input):
|
||||
entity_to_count = dict()
|
||||
with open(count_input, 'r', encoding='utf8') as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter='|')
|
||||
with count_input.open("r", encoding="utf8") as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter="|")
|
||||
# skip header
|
||||
next(csvreader)
|
||||
for row in csvreader:
|
||||
entity_to_count[row[0]] = int(row[1])
|
||||
|
||||
return entity_to_count
|
||||
|
||||
|
|
|
@ -14,15 +14,15 @@ def create_kb(vocab):
|
|||
# adding entities
|
||||
entity_0 = "Q1004791_Douglas"
|
||||
print("adding entity", entity_0)
|
||||
kb.add_entity(entity=entity_0, prob=0.5, entity_vector=[0])
|
||||
kb.add_entity(entity=entity_0, freq=0.5, entity_vector=[0])
|
||||
|
||||
entity_1 = "Q42_Douglas_Adams"
|
||||
print("adding entity", entity_1)
|
||||
kb.add_entity(entity=entity_1, prob=0.5, entity_vector=[1])
|
||||
kb.add_entity(entity=entity_1, freq=0.5, entity_vector=[1])
|
||||
|
||||
entity_2 = "Q5301561_Douglas_Haig"
|
||||
print("adding entity", entity_2)
|
||||
kb.add_entity(entity=entity_2, prob=0.5, entity_vector=[2])
|
||||
kb.add_entity(entity=entity_2, freq=0.5, entity_vector=[2])
|
||||
|
||||
# adding aliases
|
||||
print()
|
||||
|
|
|
@ -1,11 +1,14 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import os
|
||||
from os import path
|
||||
import random
|
||||
import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from bin.wiki_entity_linking import training_set_creator, kb_creator, wikipedia_processor as wp
|
||||
from bin.wiki_entity_linking import wikipedia_processor as wp
|
||||
from bin.wiki_entity_linking import training_set_creator, kb_creator
|
||||
from bin.wiki_entity_linking.kb_creator import DESC_WIDTH
|
||||
|
||||
import spacy
|
||||
|
@ -17,23 +20,26 @@ Demonstrate how to build a knowledge base from WikiData and run an Entity Linkin
|
|||
"""
|
||||
|
||||
ROOT_DIR = Path("C:/Users/Sofie/Documents/data/")
|
||||
OUTPUT_DIR = ROOT_DIR / 'wikipedia'
|
||||
TRAINING_DIR = OUTPUT_DIR / 'training_data_nel'
|
||||
OUTPUT_DIR = ROOT_DIR / "wikipedia"
|
||||
TRAINING_DIR = OUTPUT_DIR / "training_data_nel"
|
||||
|
||||
PRIOR_PROB = OUTPUT_DIR / 'prior_prob.csv'
|
||||
ENTITY_COUNTS = OUTPUT_DIR / 'entity_freq.csv'
|
||||
ENTITY_DEFS = OUTPUT_DIR / 'entity_defs.csv'
|
||||
ENTITY_DESCR = OUTPUT_DIR / 'entity_descriptions.csv'
|
||||
PRIOR_PROB = OUTPUT_DIR / "prior_prob.csv"
|
||||
ENTITY_COUNTS = OUTPUT_DIR / "entity_freq.csv"
|
||||
ENTITY_DEFS = OUTPUT_DIR / "entity_defs.csv"
|
||||
ENTITY_DESCR = OUTPUT_DIR / "entity_descriptions.csv"
|
||||
|
||||
KB_FILE = OUTPUT_DIR / 'kb_1' / 'kb'
|
||||
NLP_1_DIR = OUTPUT_DIR / 'nlp_1'
|
||||
NLP_2_DIR = OUTPUT_DIR / 'nlp_2'
|
||||
KB_DIR = OUTPUT_DIR / "kb_1"
|
||||
KB_FILE = "kb"
|
||||
NLP_1_DIR = OUTPUT_DIR / "nlp_1"
|
||||
NLP_2_DIR = OUTPUT_DIR / "nlp_2"
|
||||
|
||||
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
|
||||
WIKIDATA_JSON = ROOT_DIR / 'wikidata' / 'wikidata-20190304-all.json.bz2'
|
||||
WIKIDATA_JSON = ROOT_DIR / "wikidata" / "wikidata-20190304-all.json.bz2"
|
||||
|
||||
# get enwiki-latest-pages-articles-multistream.xml.bz2 from https://dumps.wikimedia.org/enwiki/latest/
|
||||
ENWIKI_DUMP = ROOT_DIR / 'wikipedia' / 'enwiki-20190320-pages-articles-multistream.xml.bz2'
|
||||
ENWIKI_DUMP = (
|
||||
ROOT_DIR / "wikipedia" / "enwiki-20190320-pages-articles-multistream.xml.bz2"
|
||||
)
|
||||
|
||||
# KB construction parameters
|
||||
MAX_CANDIDATES = 10
|
||||
|
@ -48,11 +54,15 @@ L2 = 1e-6
|
|||
CONTEXT_WIDTH = 128
|
||||
|
||||
|
||||
def now():
|
||||
return datetime.datetime.now()
|
||||
|
||||
|
||||
def run_pipeline():
|
||||
# set the appropriate booleans to define which parts of the pipeline should be re(run)
|
||||
print("START", datetime.datetime.now())
|
||||
print("START", now())
|
||||
print()
|
||||
nlp_1 = spacy.load('en_core_web_lg')
|
||||
nlp_1 = spacy.load("en_core_web_lg")
|
||||
nlp_2 = None
|
||||
kb_2 = None
|
||||
|
||||
|
@ -82,43 +92,48 @@ def run_pipeline():
|
|||
|
||||
# STEP 1 : create prior probabilities from WP (run only once)
|
||||
if to_create_prior_probs:
|
||||
print("STEP 1: to_create_prior_probs", datetime.datetime.now())
|
||||
wp.read_wikipedia_prior_probs(wikipedia_input=ENWIKI_DUMP, prior_prob_output=PRIOR_PROB)
|
||||
print("STEP 1: to_create_prior_probs", now())
|
||||
wp.read_prior_probs(ENWIKI_DUMP, PRIOR_PROB)
|
||||
print()
|
||||
|
||||
# STEP 2 : deduce entity frequencies from WP (run only once)
|
||||
if to_create_entity_counts:
|
||||
print("STEP 2: to_create_entity_counts", datetime.datetime.now())
|
||||
wp.write_entity_counts(prior_prob_input=PRIOR_PROB, count_output=ENTITY_COUNTS, to_print=False)
|
||||
print("STEP 2: to_create_entity_counts", now())
|
||||
wp.write_entity_counts(PRIOR_PROB, ENTITY_COUNTS, to_print=False)
|
||||
print()
|
||||
|
||||
# STEP 3 : create KB and write to file (run only once)
|
||||
if to_create_kb:
|
||||
print("STEP 3a: to_create_kb", datetime.datetime.now())
|
||||
kb_1 = kb_creator.create_kb(nlp_1,
|
||||
max_entities_per_alias=MAX_CANDIDATES,
|
||||
min_entity_freq=MIN_ENTITY_FREQ,
|
||||
min_occ=MIN_PAIR_OCC,
|
||||
entity_def_output=ENTITY_DEFS,
|
||||
entity_descr_output=ENTITY_DESCR,
|
||||
count_input=ENTITY_COUNTS,
|
||||
prior_prob_input=PRIOR_PROB,
|
||||
wikidata_input=WIKIDATA_JSON)
|
||||
print("STEP 3a: to_create_kb", now())
|
||||
kb_1 = kb_creator.create_kb(
|
||||
nlp=nlp_1,
|
||||
max_entities_per_alias=MAX_CANDIDATES,
|
||||
min_entity_freq=MIN_ENTITY_FREQ,
|
||||
min_occ=MIN_PAIR_OCC,
|
||||
entity_def_output=ENTITY_DEFS,
|
||||
entity_descr_output=ENTITY_DESCR,
|
||||
count_input=ENTITY_COUNTS,
|
||||
prior_prob_input=PRIOR_PROB,
|
||||
wikidata_input=WIKIDATA_JSON,
|
||||
)
|
||||
print("kb entities:", kb_1.get_size_entities())
|
||||
print("kb aliases:", kb_1.get_size_aliases())
|
||||
print()
|
||||
|
||||
print("STEP 3b: write KB and NLP", datetime.datetime.now())
|
||||
kb_1.dump(KB_FILE)
|
||||
print("STEP 3b: write KB and NLP", now())
|
||||
|
||||
if not path.exists(KB_DIR):
|
||||
os.makedirs(KB_DIR)
|
||||
kb_1.dump(KB_DIR / KB_FILE)
|
||||
nlp_1.to_disk(NLP_1_DIR)
|
||||
print()
|
||||
|
||||
# STEP 4 : read KB back in from file
|
||||
if to_read_kb:
|
||||
print("STEP 4: to_read_kb", datetime.datetime.now())
|
||||
print("STEP 4: to_read_kb", now())
|
||||
nlp_2 = spacy.load(NLP_1_DIR)
|
||||
kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH)
|
||||
kb_2.load_bulk(KB_FILE)
|
||||
kb_2.load_bulk(KB_DIR / KB_FILE)
|
||||
print("kb entities:", kb_2.get_size_entities())
|
||||
print("kb aliases:", kb_2.get_size_aliases())
|
||||
print()
|
||||
|
@ -130,20 +145,26 @@ def run_pipeline():
|
|||
|
||||
# STEP 5: create a training dataset from WP
|
||||
if create_wp_training:
|
||||
print("STEP 5: create training dataset", datetime.datetime.now())
|
||||
training_set_creator.create_training(wikipedia_input=ENWIKI_DUMP,
|
||||
entity_def_input=ENTITY_DEFS,
|
||||
training_output=TRAINING_DIR)
|
||||
print("STEP 5: create training dataset", now())
|
||||
training_set_creator.create_training(
|
||||
wikipedia_input=ENWIKI_DUMP,
|
||||
entity_def_input=ENTITY_DEFS,
|
||||
training_output=TRAINING_DIR,
|
||||
)
|
||||
|
||||
# STEP 6: create and train the entity linking pipe
|
||||
if train_pipe:
|
||||
print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
|
||||
print("STEP 6: training Entity Linking pipe", now())
|
||||
type_to_int = {label: i for i, label in enumerate(nlp_2.entity.labels)}
|
||||
print(" -analysing", len(type_to_int), "different entity types")
|
||||
el_pipe = nlp_2.create_pipe(name='entity_linker',
|
||||
config={"context_width": CONTEXT_WIDTH,
|
||||
"pretrained_vectors": nlp_2.vocab.vectors.name,
|
||||
"type_to_int": type_to_int})
|
||||
el_pipe = nlp_2.create_pipe(
|
||||
name="entity_linker",
|
||||
config={
|
||||
"context_width": CONTEXT_WIDTH,
|
||||
"pretrained_vectors": nlp_2.vocab.vectors.name,
|
||||
"type_to_int": type_to_int,
|
||||
},
|
||||
)
|
||||
el_pipe.set_kb(kb_2)
|
||||
nlp_2.add_pipe(el_pipe, last=True)
|
||||
|
||||
|
@ -157,18 +178,22 @@ def run_pipeline():
|
|||
train_limit = 5000
|
||||
dev_limit = 5000
|
||||
|
||||
train_data = training_set_creator.read_training(nlp=nlp_2,
|
||||
training_dir=TRAINING_DIR,
|
||||
dev=False,
|
||||
limit=train_limit)
|
||||
# for training, get pos & neg instances that correspond to entries in the kb
|
||||
train_data = training_set_creator.read_training(
|
||||
nlp=nlp_2,
|
||||
training_dir=TRAINING_DIR,
|
||||
dev=False,
|
||||
limit=train_limit,
|
||||
kb=el_pipe.kb,
|
||||
)
|
||||
|
||||
print("Training on", len(train_data), "articles")
|
||||
print()
|
||||
|
||||
dev_data = training_set_creator.read_training(nlp=nlp_2,
|
||||
training_dir=TRAINING_DIR,
|
||||
dev=True,
|
||||
limit=dev_limit)
|
||||
# for testing, get all pos instances, whether or not they are in the kb
|
||||
dev_data = training_set_creator.read_training(
|
||||
nlp=nlp_2, training_dir=TRAINING_DIR, dev=True, limit=dev_limit, kb=None
|
||||
)
|
||||
|
||||
print("Dev testing on", len(dev_data), "articles")
|
||||
print()
|
||||
|
@ -187,8 +212,8 @@ def run_pipeline():
|
|||
try:
|
||||
docs, golds = zip(*batch)
|
||||
nlp_2.update(
|
||||
docs,
|
||||
golds,
|
||||
docs=docs,
|
||||
golds=golds,
|
||||
sgd=optimizer,
|
||||
drop=DROPOUT,
|
||||
losses=losses,
|
||||
|
@ -200,48 +225,61 @@ def run_pipeline():
|
|||
if batchnr > 0:
|
||||
el_pipe.cfg["context_weight"] = 1
|
||||
el_pipe.cfg["prior_weight"] = 1
|
||||
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
|
||||
losses['entity_linker'] = losses['entity_linker'] / batchnr
|
||||
print("Epoch, train loss", itn, round(losses['entity_linker'], 2),
|
||||
" / dev acc avg", round(dev_acc_context, 3))
|
||||
dev_acc_context, _ = _measure_acc(dev_data, el_pipe)
|
||||
losses["entity_linker"] = losses["entity_linker"] / batchnr
|
||||
print(
|
||||
"Epoch, train loss",
|
||||
itn,
|
||||
round(losses["entity_linker"], 2),
|
||||
" / dev acc avg",
|
||||
round(dev_acc_context, 3),
|
||||
)
|
||||
|
||||
# STEP 7: measure the performance of our trained pipe on an independent dev set
|
||||
if len(dev_data) and measure_performance:
|
||||
print()
|
||||
print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now())
|
||||
print("STEP 7: performance measurement of Entity Linking pipe", now())
|
||||
print()
|
||||
|
||||
counts, acc_r, acc_r_label, acc_p, acc_p_label, acc_o, acc_o_label = _measure_baselines(dev_data, kb_2)
|
||||
counts, acc_r, acc_r_d, acc_p, acc_p_d, acc_o, acc_o_d = _measure_baselines(
|
||||
dev_data, kb_2
|
||||
)
|
||||
print("dev counts:", sorted(counts.items(), key=lambda x: x[0]))
|
||||
print("dev acc oracle:", round(acc_o, 3), [(x, round(y, 3)) for x, y in acc_o_label.items()])
|
||||
print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_label.items()])
|
||||
print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_label.items()])
|
||||
|
||||
oracle_by_label = [(x, round(y, 3)) for x, y in acc_o_d.items()]
|
||||
print("dev acc oracle:", round(acc_o, 3), oracle_by_label)
|
||||
|
||||
random_by_label = [(x, round(y, 3)) for x, y in acc_r_d.items()]
|
||||
print("dev acc random:", round(acc_r, 3), random_by_label)
|
||||
|
||||
prior_by_label = [(x, round(y, 3)) for x, y in acc_p_d.items()]
|
||||
print("dev acc prior:", round(acc_p, 3), prior_by_label)
|
||||
|
||||
# using only context
|
||||
el_pipe.cfg["context_weight"] = 1
|
||||
el_pipe.cfg["prior_weight"] = 0
|
||||
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
|
||||
print("dev acc context avg:", round(dev_acc_context, 3),
|
||||
[(x, round(y, 3)) for x, y in dev_acc_context_dict.items()])
|
||||
dev_acc_context, dev_acc_cont_d = _measure_acc(dev_data, el_pipe)
|
||||
context_by_label = [(x, round(y, 3)) for x, y in dev_acc_cont_d.items()]
|
||||
print("dev acc context avg:", round(dev_acc_context, 3), context_by_label)
|
||||
|
||||
# measuring combined accuracy (prior + context)
|
||||
el_pipe.cfg["context_weight"] = 1
|
||||
el_pipe.cfg["prior_weight"] = 1
|
||||
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe, error_analysis=False)
|
||||
print("dev acc combo avg:", round(dev_acc_combo, 3),
|
||||
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
|
||||
dev_acc_combo, dev_acc_combo_d = _measure_acc(dev_data, el_pipe)
|
||||
combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_d.items()]
|
||||
print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label)
|
||||
|
||||
# STEP 8: apply the EL pipe on a toy example
|
||||
if to_test_pipeline:
|
||||
print()
|
||||
print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now())
|
||||
print("STEP 8: applying Entity Linking to toy example", now())
|
||||
print()
|
||||
run_el_toy_example(nlp=nlp_2)
|
||||
|
||||
# STEP 9: write the NLP pipeline (including entity linker) to file
|
||||
if to_write_nlp:
|
||||
print()
|
||||
print("STEP 9: testing NLP IO", datetime.datetime.now())
|
||||
print("STEP 9: testing NLP IO", now())
|
||||
print()
|
||||
print("writing to", NLP_2_DIR)
|
||||
nlp_2.to_disk(NLP_2_DIR)
|
||||
|
@ -262,23 +300,22 @@ def run_pipeline():
|
|||
el_pipe = nlp_3.get_pipe("entity_linker")
|
||||
|
||||
dev_limit = 5000
|
||||
dev_data = training_set_creator.read_training(nlp=nlp_2,
|
||||
training_dir=TRAINING_DIR,
|
||||
dev=True,
|
||||
limit=dev_limit)
|
||||
dev_data = training_set_creator.read_training(
|
||||
nlp=nlp_2, training_dir=TRAINING_DIR, dev=True, limit=dev_limit, kb=None
|
||||
)
|
||||
|
||||
print("Dev testing from file on", len(dev_data), "articles")
|
||||
print()
|
||||
|
||||
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe=el_pipe, error_analysis=False)
|
||||
print("dev acc combo avg:", round(dev_acc_combo, 3),
|
||||
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
|
||||
dev_acc_combo, dev_acc_combo_dict = _measure_acc(dev_data, el_pipe)
|
||||
combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()]
|
||||
print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label)
|
||||
|
||||
print()
|
||||
print("STOP", datetime.datetime.now())
|
||||
print("STOP", now())
|
||||
|
||||
|
||||
def _measure_accuracy(data, el_pipe=None, error_analysis=False):
|
||||
def _measure_acc(data, el_pipe=None, error_analysis=False):
|
||||
# If the docs in the data require further processing with an entity linker, set el_pipe
|
||||
correct_by_label = dict()
|
||||
incorrect_by_label = dict()
|
||||
|
@ -291,16 +328,21 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
|
|||
for doc, gold in zip(docs, golds):
|
||||
try:
|
||||
correct_entries_per_article = dict()
|
||||
for entity in gold.links:
|
||||
start, end, gold_kb = entity
|
||||
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
|
||||
for entity, kb_dict in gold.links.items():
|
||||
start, end = entity
|
||||
# only evaluating on positive examples
|
||||
for gold_kb, value in kb_dict.items():
|
||||
if value:
|
||||
offset = _offset(start, end)
|
||||
correct_entries_per_article[offset] = gold_kb
|
||||
|
||||
for ent in doc.ents:
|
||||
ent_label = ent.label_
|
||||
pred_entity = ent.kb_id_
|
||||
start = ent.start_char
|
||||
end = ent.end_char
|
||||
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
|
||||
offset = _offset(start, end)
|
||||
gold_entity = correct_entries_per_article.get(offset, None)
|
||||
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
||||
if gold_entity is not None:
|
||||
if gold_entity == pred_entity:
|
||||
|
@ -311,28 +353,33 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
|
|||
incorrect_by_label[ent_label] = incorrect + 1
|
||||
if error_analysis:
|
||||
print(ent.text, "in", doc)
|
||||
print("Predicted", pred_entity, "should have been", gold_entity)
|
||||
print(
|
||||
"Predicted",
|
||||
pred_entity,
|
||||
"should have been",
|
||||
gold_entity,
|
||||
)
|
||||
print()
|
||||
|
||||
except Exception as e:
|
||||
print("Error assessing accuracy", e)
|
||||
|
||||
acc, acc_by_label = calculate_acc(correct_by_label, incorrect_by_label)
|
||||
acc, acc_by_label = calculate_acc(correct_by_label, incorrect_by_label)
|
||||
return acc, acc_by_label
|
||||
|
||||
|
||||
def _measure_baselines(data, kb):
|
||||
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
|
||||
counts_by_label = dict()
|
||||
counts_d = dict()
|
||||
|
||||
random_correct_by_label = dict()
|
||||
random_incorrect_by_label = dict()
|
||||
random_correct_d = dict()
|
||||
random_incorrect_d = dict()
|
||||
|
||||
oracle_correct_by_label = dict()
|
||||
oracle_incorrect_by_label = dict()
|
||||
oracle_correct_d = dict()
|
||||
oracle_incorrect_d = dict()
|
||||
|
||||
prior_correct_by_label = dict()
|
||||
prior_incorrect_by_label = dict()
|
||||
prior_correct_d = dict()
|
||||
prior_incorrect_d = dict()
|
||||
|
||||
docs = [d for d, g in data if len(d) > 0]
|
||||
golds = [g for d, g in data if len(d) > 0]
|
||||
|
@ -340,19 +387,24 @@ def _measure_baselines(data, kb):
|
|||
for doc, gold in zip(docs, golds):
|
||||
try:
|
||||
correct_entries_per_article = dict()
|
||||
for entity in gold.links:
|
||||
start, end, gold_kb = entity
|
||||
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
|
||||
for entity, kb_dict in gold.links.items():
|
||||
start, end = entity
|
||||
for gold_kb, value in kb_dict.items():
|
||||
# only evaluating on positive examples
|
||||
if value:
|
||||
offset = _offset(start, end)
|
||||
correct_entries_per_article[offset] = gold_kb
|
||||
|
||||
for ent in doc.ents:
|
||||
ent_label = ent.label_
|
||||
label = ent.label_
|
||||
start = ent.start_char
|
||||
end = ent.end_char
|
||||
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
|
||||
offset = _offset(start, end)
|
||||
gold_entity = correct_entries_per_article.get(offset, None)
|
||||
|
||||
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
||||
if gold_entity is not None:
|
||||
counts_by_label[ent_label] = counts_by_label.get(ent_label, 0) + 1
|
||||
counts_d[label] = counts_d.get(label, 0) + 1
|
||||
candidates = kb.get_candidates(ent.text)
|
||||
oracle_candidate = ""
|
||||
best_candidate = ""
|
||||
|
@ -370,28 +422,40 @@ def _measure_baselines(data, kb):
|
|||
random_candidate = random.choice(candidates).entity_
|
||||
|
||||
if gold_entity == best_candidate:
|
||||
prior_correct_by_label[ent_label] = prior_correct_by_label.get(ent_label, 0) + 1
|
||||
prior_correct_d[label] = prior_correct_d.get(label, 0) + 1
|
||||
else:
|
||||
prior_incorrect_by_label[ent_label] = prior_incorrect_by_label.get(ent_label, 0) + 1
|
||||
prior_incorrect_d[label] = prior_incorrect_d.get(label, 0) + 1
|
||||
|
||||
if gold_entity == random_candidate:
|
||||
random_correct_by_label[ent_label] = random_correct_by_label.get(ent_label, 0) + 1
|
||||
random_correct_d[label] = random_correct_d.get(label, 0) + 1
|
||||
else:
|
||||
random_incorrect_by_label[ent_label] = random_incorrect_by_label.get(ent_label, 0) + 1
|
||||
random_incorrect_d[label] = random_incorrect_d.get(label, 0) + 1
|
||||
|
||||
if gold_entity == oracle_candidate:
|
||||
oracle_correct_by_label[ent_label] = oracle_correct_by_label.get(ent_label, 0) + 1
|
||||
oracle_correct_d[label] = oracle_correct_d.get(label, 0) + 1
|
||||
else:
|
||||
oracle_incorrect_by_label[ent_label] = oracle_incorrect_by_label.get(ent_label, 0) + 1
|
||||
oracle_incorrect_d[label] = oracle_incorrect_d.get(label, 0) + 1
|
||||
|
||||
except Exception as e:
|
||||
print("Error assessing accuracy", e)
|
||||
|
||||
acc_prior, acc_prior_by_label = calculate_acc(prior_correct_by_label, prior_incorrect_by_label)
|
||||
acc_rand, acc_rand_by_label = calculate_acc(random_correct_by_label, random_incorrect_by_label)
|
||||
acc_oracle, acc_oracle_by_label = calculate_acc(oracle_correct_by_label, oracle_incorrect_by_label)
|
||||
acc_prior, acc_prior_d = calculate_acc(prior_correct_d, prior_incorrect_d)
|
||||
acc_rand, acc_rand_d = calculate_acc(random_correct_d, random_incorrect_d)
|
||||
acc_oracle, acc_oracle_d = calculate_acc(oracle_correct_d, oracle_incorrect_d)
|
||||
|
||||
return counts_by_label, acc_rand, acc_rand_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label
|
||||
return (
|
||||
counts_d,
|
||||
acc_rand,
|
||||
acc_rand_d,
|
||||
acc_prior,
|
||||
acc_prior_d,
|
||||
acc_oracle,
|
||||
acc_oracle_d,
|
||||
)
|
||||
|
||||
|
||||
def _offset(start, end):
|
||||
return "{}_{}".format(start, end)
|
||||
|
||||
|
||||
def calculate_acc(correct_by_label, incorrect_by_label):
|
||||
|
@ -422,15 +486,23 @@ def check_kb(kb):
|
|||
|
||||
print("generating candidates for " + mention + " :")
|
||||
for c in candidates:
|
||||
print(" ", c.prior_prob, c.alias_, "-->", c.entity_ + " (freq=" + str(c.entity_freq) + ")")
|
||||
print(
|
||||
" ",
|
||||
c.prior_prob,
|
||||
c.alias_,
|
||||
"-->",
|
||||
c.entity_ + " (freq=" + str(c.entity_freq) + ")",
|
||||
)
|
||||
print()
|
||||
|
||||
|
||||
def run_el_toy_example(nlp):
|
||||
text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \
|
||||
"Douglas reminds us to always bring our towel, even in China or Brazil. " \
|
||||
"The main character in Doug's novel is the man Arthur Dent, " \
|
||||
"but Douglas doesn't write about George Washington or Homer Simpson."
|
||||
text = (
|
||||
"In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, "
|
||||
"Douglas reminds us to always bring our towel, even in China or Brazil. "
|
||||
"The main character in Doug's novel is the man Arthur Dent, "
|
||||
"but Dougledydoug doesn't write about George Washington or Homer Simpson."
|
||||
)
|
||||
doc = nlp(text)
|
||||
print(text)
|
||||
for ent in doc.ents:
|
||||
|
|
|
@ -663,15 +663,14 @@ def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=False,
|
|||
|
||||
|
||||
def build_nel_encoder(embed_width, hidden_width, ner_types, **cfg):
|
||||
# TODO proper error
|
||||
if "entity_width" not in cfg:
|
||||
raise ValueError("entity_width not found")
|
||||
raise ValueError(Errors.E144.format(param="entity_width"))
|
||||
if "context_width" not in cfg:
|
||||
raise ValueError("context_width not found")
|
||||
raise ValueError(Errors.E144.format(param="context_width"))
|
||||
|
||||
conv_depth = cfg.get("conv_depth", 2)
|
||||
cnn_maxout_pieces = cfg.get("cnn_maxout_pieces", 3)
|
||||
pretrained_vectors = cfg.get("pretrained_vectors") # self.nlp.vocab.vectors.name
|
||||
pretrained_vectors = cfg.get("pretrained_vectors", None)
|
||||
context_width = cfg.get("context_width")
|
||||
entity_width = cfg.get("entity_width")
|
||||
|
||||
|
|
|
@ -406,6 +406,13 @@ class Errors(object):
|
|||
E141 = ("Entity vectors should be of length {required} instead of the provided {found}.")
|
||||
E142 = ("Unsupported loss_function '{loss_func}'. Use either 'L2' or 'cosine'")
|
||||
E143 = ("Labels for component '{name}' not initialized. Did you forget to call add_label()?")
|
||||
E144 = ("Could not find parameter `{param}` when building the entity linker model.")
|
||||
E145 = ("Error reading `{param}` from input file.")
|
||||
E146 = ("Could not access `{path}`.")
|
||||
E147 = ("Unexpected error in the {method} functionality of the EntityLinker: {msg}. "
|
||||
"This is likely a bug in spaCy, so feel free to open an issue.")
|
||||
E148 = ("Expected {ents} KB identifiers but got {ids}. Make sure that each entity in `doc.ents` "
|
||||
"is assigned to a KB identifier.")
|
||||
|
||||
|
||||
@add_codes
|
||||
|
|
|
@ -31,7 +31,7 @@ cdef class GoldParse:
|
|||
cdef public list ents
|
||||
cdef public dict brackets
|
||||
cdef public object cats
|
||||
cdef public list links
|
||||
cdef public dict links
|
||||
|
||||
cdef readonly list cand_to_gold
|
||||
cdef readonly list gold_to_cand
|
||||
|
|
|
@ -468,8 +468,11 @@ cdef class GoldParse:
|
|||
examples of a label to have the value 0.0. Labels not in the
|
||||
dictionary are treated as missing - the gradient for those labels
|
||||
will be zero.
|
||||
links (iterable): A sequence of `(start_char, end_char, kb_id)` tuples,
|
||||
representing the external ID of an entity in a knowledge base.
|
||||
links (dict): A dict with `(start_char, end_char)` keys,
|
||||
and the values being dicts with kb_id:value entries,
|
||||
representing the external IDs in a knowledge base (KB)
|
||||
mapped to either 1.0 or 0.0, indicating positive and
|
||||
negative examples respectively.
|
||||
RETURNS (GoldParse): The newly constructed object.
|
||||
"""
|
||||
if words is None:
|
||||
|
|
12
spacy/kb.pxd
12
spacy/kb.pxd
|
@ -79,7 +79,7 @@ cdef class KnowledgeBase:
|
|||
return new_index
|
||||
|
||||
|
||||
cdef inline int64_t c_add_entity(self, hash_t entity_hash, float prob,
|
||||
cdef inline int64_t c_add_entity(self, hash_t entity_hash, float freq,
|
||||
int32_t vector_index, int feats_row) nogil:
|
||||
"""Add an entry to the vector of entries.
|
||||
After calling this method, make sure to update also the _entry_index using the return value"""
|
||||
|
@ -92,7 +92,7 @@ cdef class KnowledgeBase:
|
|||
entry.entity_hash = entity_hash
|
||||
entry.vector_index = vector_index
|
||||
entry.feats_row = feats_row
|
||||
entry.prob = prob
|
||||
entry.freq = freq
|
||||
|
||||
self._entries.push_back(entry)
|
||||
return new_index
|
||||
|
@ -125,7 +125,7 @@ cdef class KnowledgeBase:
|
|||
entry.entity_hash = dummy_hash
|
||||
entry.vector_index = dummy_value
|
||||
entry.feats_row = dummy_value
|
||||
entry.prob = dummy_value
|
||||
entry.freq = dummy_value
|
||||
|
||||
# Avoid struct initializer to enable nogil
|
||||
cdef vector[int64_t] dummy_entry_indices
|
||||
|
@ -141,7 +141,7 @@ cdef class KnowledgeBase:
|
|||
self._aliases_table.push_back(alias)
|
||||
|
||||
cpdef load_bulk(self, loc)
|
||||
cpdef set_entities(self, entity_list, prob_list, vector_list)
|
||||
cpdef set_entities(self, entity_list, freq_list, vector_list)
|
||||
|
||||
|
||||
cdef class Writer:
|
||||
|
@ -149,7 +149,7 @@ cdef class Writer:
|
|||
|
||||
cdef int write_header(self, int64_t nr_entries, int64_t entity_vector_length) except -1
|
||||
cdef int write_vector_element(self, float element) except -1
|
||||
cdef int write_entry(self, hash_t entry_hash, float entry_prob, int32_t vector_index) except -1
|
||||
cdef int write_entry(self, hash_t entry_hash, float entry_freq, int32_t vector_index) except -1
|
||||
|
||||
cdef int write_alias_length(self, int64_t alias_length) except -1
|
||||
cdef int write_alias_header(self, hash_t alias_hash, int64_t candidate_length) except -1
|
||||
|
@ -162,7 +162,7 @@ cdef class Reader:
|
|||
|
||||
cdef int read_header(self, int64_t* nr_entries, int64_t* entity_vector_length) except -1
|
||||
cdef int read_vector_element(self, float* element) except -1
|
||||
cdef int read_entry(self, hash_t* entity_hash, float* prob, int32_t* vector_index) except -1
|
||||
cdef int read_entry(self, hash_t* entity_hash, float* freq, int32_t* vector_index) except -1
|
||||
|
||||
cdef int read_alias_length(self, int64_t* alias_length) except -1
|
||||
cdef int read_alias_header(self, hash_t* alias_hash, int64_t* candidate_length) except -1
|
||||
|
|
86
spacy/kb.pyx
86
spacy/kb.pyx
|
@ -94,7 +94,7 @@ cdef class KnowledgeBase:
|
|||
def get_alias_strings(self):
|
||||
return [self.vocab.strings[x] for x in self._alias_index]
|
||||
|
||||
def add_entity(self, unicode entity, float prob, vector[float] entity_vector):
|
||||
def add_entity(self, unicode entity, float freq, vector[float] entity_vector):
|
||||
"""
|
||||
Add an entity to the KB, optionally specifying its log probability based on corpus frequency
|
||||
Return the hash of the entity ID/name at the end.
|
||||
|
@ -113,15 +113,15 @@ cdef class KnowledgeBase:
|
|||
vector_index = self.c_add_vector(entity_vector=entity_vector)
|
||||
|
||||
new_index = self.c_add_entity(entity_hash=entity_hash,
|
||||
prob=prob,
|
||||
freq=freq,
|
||||
vector_index=vector_index,
|
||||
feats_row=-1) # Features table currently not implemented
|
||||
self._entry_index[entity_hash] = new_index
|
||||
|
||||
return entity_hash
|
||||
|
||||
cpdef set_entities(self, entity_list, prob_list, vector_list):
|
||||
if len(entity_list) != len(prob_list) or len(entity_list) != len(vector_list):
|
||||
cpdef set_entities(self, entity_list, freq_list, vector_list):
|
||||
if len(entity_list) != len(freq_list) or len(entity_list) != len(vector_list):
|
||||
raise ValueError(Errors.E140)
|
||||
|
||||
nr_entities = len(entity_list)
|
||||
|
@ -137,7 +137,7 @@ cdef class KnowledgeBase:
|
|||
|
||||
entity_hash = self.vocab.strings.add(entity_list[i])
|
||||
entry.entity_hash = entity_hash
|
||||
entry.prob = prob_list[i]
|
||||
entry.freq = freq_list[i]
|
||||
|
||||
vector_index = self.c_add_vector(entity_vector=vector_list[i])
|
||||
entry.vector_index = vector_index
|
||||
|
@ -196,13 +196,42 @@ cdef class KnowledgeBase:
|
|||
|
||||
return [Candidate(kb=self,
|
||||
entity_hash=self._entries[entry_index].entity_hash,
|
||||
entity_freq=self._entries[entry_index].prob,
|
||||
entity_freq=self._entries[entry_index].freq,
|
||||
entity_vector=self._vectors_table[self._entries[entry_index].vector_index],
|
||||
alias_hash=alias_hash,
|
||||
prior_prob=prob)
|
||||
for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs)
|
||||
prior_prob=prior_prob)
|
||||
for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs)
|
||||
if entry_index != 0]
|
||||
|
||||
def get_vector(self, unicode entity):
|
||||
cdef hash_t entity_hash = self.vocab.strings[entity]
|
||||
|
||||
# Return an empty list if this entity is unknown in this KB
|
||||
if entity_hash not in self._entry_index:
|
||||
return [0] * self.entity_vector_length
|
||||
entry_index = self._entry_index[entity_hash]
|
||||
|
||||
return self._vectors_table[self._entries[entry_index].vector_index]
|
||||
|
||||
def get_prior_prob(self, unicode entity, unicode alias):
|
||||
""" Return the prior probability of a given alias being linked to a given entity,
|
||||
or return 0.0 when this combination is not known in the knowledge base"""
|
||||
cdef hash_t alias_hash = self.vocab.strings[alias]
|
||||
cdef hash_t entity_hash = self.vocab.strings[entity]
|
||||
|
||||
if entity_hash not in self._entry_index or alias_hash not in self._alias_index:
|
||||
return 0.0
|
||||
|
||||
alias_index = <int64_t>self._alias_index.get(alias_hash)
|
||||
entry_index = self._entry_index[entity_hash]
|
||||
|
||||
alias_entry = self._aliases_table[alias_index]
|
||||
for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs):
|
||||
if self._entries[entry_index].entity_hash == entity_hash:
|
||||
return prior_prob
|
||||
|
||||
return 0.0
|
||||
|
||||
|
||||
def dump(self, loc):
|
||||
cdef Writer writer = Writer(loc)
|
||||
|
@ -222,7 +251,7 @@ cdef class KnowledgeBase:
|
|||
entry = self._entries[entry_index]
|
||||
assert entry.entity_hash == entry_hash
|
||||
assert entry_index == i
|
||||
writer.write_entry(entry.entity_hash, entry.prob, entry.vector_index)
|
||||
writer.write_entry(entry.entity_hash, entry.freq, entry.vector_index)
|
||||
i = i+1
|
||||
|
||||
writer.write_alias_length(self.get_size_aliases())
|
||||
|
@ -248,7 +277,7 @@ cdef class KnowledgeBase:
|
|||
cdef hash_t entity_hash
|
||||
cdef hash_t alias_hash
|
||||
cdef int64_t entry_index
|
||||
cdef float prob
|
||||
cdef float freq, prob
|
||||
cdef int32_t vector_index
|
||||
cdef KBEntryC entry
|
||||
cdef AliasC alias
|
||||
|
@ -284,10 +313,10 @@ cdef class KnowledgeBase:
|
|||
# index 0 is a dummy object not stored in the _entry_index and can be ignored.
|
||||
i = 1
|
||||
while i <= nr_entities:
|
||||
reader.read_entry(&entity_hash, &prob, &vector_index)
|
||||
reader.read_entry(&entity_hash, &freq, &vector_index)
|
||||
|
||||
entry.entity_hash = entity_hash
|
||||
entry.prob = prob
|
||||
entry.freq = freq
|
||||
entry.vector_index = vector_index
|
||||
entry.feats_row = -1 # Features table currently not implemented
|
||||
|
||||
|
@ -343,7 +372,8 @@ cdef class Writer:
|
|||
loc = bytes(loc)
|
||||
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
|
||||
self._fp = fopen(<char*>bytes_loc, 'wb')
|
||||
assert self._fp != NULL
|
||||
if not self._fp:
|
||||
raise IOError(Errors.E146.format(path=loc))
|
||||
fseek(self._fp, 0, 0)
|
||||
|
||||
def close(self):
|
||||
|
@ -357,9 +387,9 @@ cdef class Writer:
|
|||
cdef int write_vector_element(self, float element) except -1:
|
||||
self._write(&element, sizeof(element))
|
||||
|
||||
cdef int write_entry(self, hash_t entry_hash, float entry_prob, int32_t vector_index) except -1:
|
||||
cdef int write_entry(self, hash_t entry_hash, float entry_freq, int32_t vector_index) except -1:
|
||||
self._write(&entry_hash, sizeof(entry_hash))
|
||||
self._write(&entry_prob, sizeof(entry_prob))
|
||||
self._write(&entry_freq, sizeof(entry_freq))
|
||||
self._write(&vector_index, sizeof(vector_index))
|
||||
# Features table currently not implemented and not written to file
|
||||
|
||||
|
@ -399,39 +429,39 @@ cdef class Reader:
|
|||
if status < 1:
|
||||
if feof(self._fp):
|
||||
return 0 # end of file
|
||||
raise IOError("error reading header from input file")
|
||||
raise IOError(Errors.E145.format(param="header"))
|
||||
|
||||
status = self._read(entity_vector_length, sizeof(int64_t))
|
||||
if status < 1:
|
||||
if feof(self._fp):
|
||||
return 0 # end of file
|
||||
raise IOError("error reading header from input file")
|
||||
raise IOError(Errors.E145.format(param="vector length"))
|
||||
|
||||
cdef int read_vector_element(self, float* element) except -1:
|
||||
status = self._read(element, sizeof(float))
|
||||
if status < 1:
|
||||
if feof(self._fp):
|
||||
return 0 # end of file
|
||||
raise IOError("error reading entity vector from input file")
|
||||
raise IOError(Errors.E145.format(param="vector element"))
|
||||
|
||||
cdef int read_entry(self, hash_t* entity_hash, float* prob, int32_t* vector_index) except -1:
|
||||
cdef int read_entry(self, hash_t* entity_hash, float* freq, int32_t* vector_index) except -1:
|
||||
status = self._read(entity_hash, sizeof(hash_t))
|
||||
if status < 1:
|
||||
if feof(self._fp):
|
||||
return 0 # end of file
|
||||
raise IOError("error reading entity hash from input file")
|
||||
raise IOError(Errors.E145.format(param="entity hash"))
|
||||
|
||||
status = self._read(prob, sizeof(float))
|
||||
status = self._read(freq, sizeof(float))
|
||||
if status < 1:
|
||||
if feof(self._fp):
|
||||
return 0 # end of file
|
||||
raise IOError("error reading entity prob from input file")
|
||||
raise IOError(Errors.E145.format(param="entity freq"))
|
||||
|
||||
status = self._read(vector_index, sizeof(int32_t))
|
||||
if status < 1:
|
||||
if feof(self._fp):
|
||||
return 0 # end of file
|
||||
raise IOError("error reading entity vector from input file")
|
||||
raise IOError(Errors.E145.format(param="vector index"))
|
||||
|
||||
if feof(self._fp):
|
||||
return 0
|
||||
|
@ -443,33 +473,33 @@ cdef class Reader:
|
|||
if status < 1:
|
||||
if feof(self._fp):
|
||||
return 0 # end of file
|
||||
raise IOError("error reading alias length from input file")
|
||||
raise IOError(Errors.E145.format(param="alias length"))
|
||||
|
||||
cdef int read_alias_header(self, hash_t* alias_hash, int64_t* candidate_length) except -1:
|
||||
status = self._read(alias_hash, sizeof(hash_t))
|
||||
if status < 1:
|
||||
if feof(self._fp):
|
||||
return 0 # end of file
|
||||
raise IOError("error reading alias hash from input file")
|
||||
raise IOError(Errors.E145.format(param="alias hash"))
|
||||
|
||||
status = self._read(candidate_length, sizeof(int64_t))
|
||||
if status < 1:
|
||||
if feof(self._fp):
|
||||
return 0 # end of file
|
||||
raise IOError("error reading candidate length from input file")
|
||||
raise IOError(Errors.E145.format(param="candidate length"))
|
||||
|
||||
cdef int read_alias(self, int64_t* entry_index, float* prob) except -1:
|
||||
status = self._read(entry_index, sizeof(int64_t))
|
||||
if status < 1:
|
||||
if feof(self._fp):
|
||||
return 0 # end of file
|
||||
raise IOError("error reading entry index for alias from input file")
|
||||
raise IOError(Errors.E145.format(param="entry index"))
|
||||
|
||||
status = self._read(prob, sizeof(float))
|
||||
if status < 1:
|
||||
if feof(self._fp):
|
||||
return 0 # end of file
|
||||
raise IOError("error reading prob for entity/alias from input file")
|
||||
raise IOError(Errors.E145.format(param="prior probability"))
|
||||
|
||||
cdef int _read(self, void* value, size_t size) except -1:
|
||||
status = fread(value, size, 1, self._fp)
|
||||
|
|
|
@ -14,7 +14,6 @@ from thinc.neural.util import to_categorical
|
|||
from thinc.neural.util import get_array_module
|
||||
|
||||
from spacy.kb import KnowledgeBase
|
||||
from ..cli.pretrain import get_cossim_loss
|
||||
from .functions import merge_subtokens
|
||||
from ..tokens.doc cimport Doc
|
||||
from ..syntax.nn_parser cimport Parser
|
||||
|
@ -1077,6 +1076,7 @@ class EntityLinker(Pipe):
|
|||
DOCS: TODO
|
||||
"""
|
||||
name = 'entity_linker'
|
||||
NIL = "NIL" # string used to refer to a non-existing link
|
||||
|
||||
@classmethod
|
||||
def Model(cls, **cfg):
|
||||
|
@ -1093,6 +1093,8 @@ class EntityLinker(Pipe):
|
|||
self.kb = None
|
||||
self.cfg = dict(cfg)
|
||||
self.sgd_context = None
|
||||
if not self.cfg.get("context_width"):
|
||||
self.cfg["context_width"] = 128
|
||||
|
||||
def set_kb(self, kb):
|
||||
self.kb = kb
|
||||
|
@ -1140,7 +1142,7 @@ class EntityLinker(Pipe):
|
|||
|
||||
context_docs = []
|
||||
entity_encodings = []
|
||||
cats = []
|
||||
|
||||
priors = []
|
||||
type_vectors = []
|
||||
|
||||
|
@ -1149,50 +1151,44 @@ class EntityLinker(Pipe):
|
|||
for doc, gold in zip(docs, golds):
|
||||
ents_by_offset = dict()
|
||||
for ent in doc.ents:
|
||||
ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
|
||||
for entity in gold.links:
|
||||
start, end, gold_kb = entity
|
||||
ents_by_offset["{}_{}".format(ent.start_char, ent.end_char)] = ent
|
||||
for entity, kb_dict in gold.links.items():
|
||||
start, end = entity
|
||||
mention = doc.text[start:end]
|
||||
for kb_id, value in kb_dict.items():
|
||||
entity_encoding = self.kb.get_vector(kb_id)
|
||||
prior_prob = self.kb.get_prior_prob(kb_id, mention)
|
||||
|
||||
gold_ent = ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)]
|
||||
assert gold_ent is not None
|
||||
type_vector = [0 for i in range(len(type_to_int))]
|
||||
if len(type_to_int) > 0:
|
||||
type_vector[type_to_int[gold_ent.label_]] = 1
|
||||
gold_ent = ents_by_offset["{}_{}".format(start, end)]
|
||||
if gold_ent is None:
|
||||
raise RuntimeError(Errors.E147.format(method="update", msg="gold entity not found"))
|
||||
|
||||
candidates = self.kb.get_candidates(mention)
|
||||
random.shuffle(candidates)
|
||||
nr_neg = 0
|
||||
for c in candidates:
|
||||
kb_id = c.entity_
|
||||
entity_encoding = c.entity_vector
|
||||
type_vector = [0 for i in range(len(type_to_int))]
|
||||
if len(type_to_int) > 0:
|
||||
type_vector[type_to_int[gold_ent.label_]] = 1
|
||||
|
||||
# store data
|
||||
entity_encodings.append(entity_encoding)
|
||||
context_docs.append(doc)
|
||||
type_vectors.append(type_vector)
|
||||
|
||||
if self.cfg.get("prior_weight", 1) > 0:
|
||||
priors.append([c.prior_prob])
|
||||
priors.append([prior_prob])
|
||||
else:
|
||||
priors.append([0])
|
||||
|
||||
if kb_id == gold_kb:
|
||||
cats.append([1])
|
||||
else:
|
||||
nr_neg += 1
|
||||
cats.append([0])
|
||||
|
||||
if len(entity_encodings) > 0:
|
||||
assert len(priors) == len(entity_encodings) == len(context_docs) == len(cats) == len(type_vectors)
|
||||
if not (len(priors) == len(entity_encodings) == len(context_docs) == len(type_vectors)):
|
||||
raise RuntimeError(Errors.E147.format(method="update", msg="vector lengths not equal"))
|
||||
|
||||
context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop)
|
||||
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
|
||||
|
||||
context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop)
|
||||
mention_encodings = [list(context_encodings[i]) + list(entity_encodings[i]) + priors[i] + type_vectors[i]
|
||||
for i in range(len(entity_encodings))]
|
||||
pred, bp_mention = self.model.begin_update(self.model.ops.asarray(mention_encodings, dtype="float32"), drop=drop)
|
||||
cats = self.model.ops.asarray(cats, dtype="float32")
|
||||
|
||||
loss, d_scores = self.get_loss(prediction=pred, golds=cats, docs=None)
|
||||
loss, d_scores = self.get_loss(scores=pred, golds=golds, docs=docs)
|
||||
mention_gradient = bp_mention(d_scores, sgd=sgd)
|
||||
|
||||
context_gradients = [list(x[0:self.cfg.get("context_width")]) for x in mention_gradient]
|
||||
|
@ -1203,39 +1199,45 @@ class EntityLinker(Pipe):
|
|||
return loss
|
||||
return 0
|
||||
|
||||
def get_loss(self, docs, golds, prediction):
|
||||
d_scores = (prediction - golds)
|
||||
def get_loss(self, docs, golds, scores):
|
||||
cats = []
|
||||
for gold in golds:
|
||||
for entity, kb_dict in gold.links.items():
|
||||
for kb_id, value in kb_dict.items():
|
||||
cats.append([value])
|
||||
|
||||
cats = self.model.ops.asarray(cats, dtype="float32")
|
||||
if len(scores) != len(cats):
|
||||
raise RuntimeError(Errors.E147.format(method="get_loss", msg="gold entities do not match up"))
|
||||
|
||||
d_scores = (scores - cats)
|
||||
loss = (d_scores ** 2).sum()
|
||||
loss = loss / len(golds)
|
||||
loss = loss / len(cats)
|
||||
return loss, d_scores
|
||||
|
||||
def get_loss_old(self, docs, golds, scores):
|
||||
# this loss function assumes we're only using positive examples
|
||||
loss, gradients = get_cossim_loss(yh=scores, y=golds)
|
||||
loss = loss / len(golds)
|
||||
return loss, gradients
|
||||
|
||||
def __call__(self, doc):
|
||||
entities, kb_ids = self.predict([doc])
|
||||
self.set_annotations([doc], entities, kb_ids)
|
||||
kb_ids, tensors = self.predict([doc])
|
||||
self.set_annotations([doc], kb_ids, tensors=tensors)
|
||||
return doc
|
||||
|
||||
def pipe(self, stream, batch_size=128, n_threads=-1):
|
||||
for docs in util.minibatch(stream, size=batch_size):
|
||||
docs = list(docs)
|
||||
entities, kb_ids = self.predict(docs)
|
||||
self.set_annotations(docs, entities, kb_ids)
|
||||
kb_ids, tensors = self.predict(docs)
|
||||
self.set_annotations(docs, kb_ids, tensors=tensors)
|
||||
yield from docs
|
||||
|
||||
def predict(self, docs):
|
||||
""" Return the KB IDs for each entity in each doc, including NIL if there is no prediction """
|
||||
self.require_model()
|
||||
self.require_kb()
|
||||
|
||||
final_entities = []
|
||||
entity_count = 0
|
||||
final_kb_ids = []
|
||||
final_tensors = []
|
||||
|
||||
if not docs:
|
||||
return final_entities, final_kb_ids
|
||||
return final_kb_ids, final_tensors
|
||||
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
|
@ -1247,14 +1249,19 @@ class EntityLinker(Pipe):
|
|||
|
||||
for i, doc in enumerate(docs):
|
||||
if len(doc) > 0:
|
||||
# currently, the context is the same for each entity in a sentence (should be refined)
|
||||
context_encoding = context_encodings[i]
|
||||
for ent in doc.ents:
|
||||
entity_count += 1
|
||||
type_vector = [0 for i in range(len(type_to_int))]
|
||||
if len(type_to_int) > 0:
|
||||
type_vector[type_to_int[ent.label_]] = 1
|
||||
|
||||
candidates = self.kb.get_candidates(ent.text)
|
||||
if candidates:
|
||||
if not candidates:
|
||||
final_kb_ids.append(self.NIL) # no prediction possible for this entity
|
||||
final_tensors.append(context_encoding)
|
||||
else:
|
||||
random.shuffle(candidates)
|
||||
|
||||
# this will set the prior probabilities to 0 (just like in training) if their weight is 0
|
||||
|
@ -1264,7 +1271,9 @@ class EntityLinker(Pipe):
|
|||
|
||||
if self.cfg.get("context_weight", 1) > 0:
|
||||
entity_encodings = xp.asarray([c.entity_vector for c in candidates])
|
||||
assert len(entity_encodings) == len(prior_probs)
|
||||
if len(entity_encodings) != len(prior_probs):
|
||||
raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length"))
|
||||
|
||||
mention_encodings = [list(context_encoding) + list(entity_encodings[i])
|
||||
+ list(prior_probs[i]) + type_vector
|
||||
for i in range(len(entity_encodings))]
|
||||
|
@ -1273,15 +1282,26 @@ class EntityLinker(Pipe):
|
|||
# TODO: thresholding
|
||||
best_index = scores.argmax()
|
||||
best_candidate = candidates[best_index]
|
||||
final_entities.append(ent)
|
||||
final_kb_ids.append(best_candidate.entity_)
|
||||
final_tensors.append(context_encoding)
|
||||
|
||||
return final_entities, final_kb_ids
|
||||
if not (len(final_tensors) == len(final_kb_ids) == entity_count):
|
||||
raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length"))
|
||||
|
||||
def set_annotations(self, docs, entities, kb_ids=None):
|
||||
for entity, kb_id in zip(entities, kb_ids):
|
||||
for token in entity:
|
||||
token.ent_kb_id_ = kb_id
|
||||
return final_kb_ids, final_tensors
|
||||
|
||||
def set_annotations(self, docs, kb_ids, tensors=None):
|
||||
count_ents = len([ent for doc in docs for ent in doc.ents])
|
||||
if count_ents != len(kb_ids):
|
||||
raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))
|
||||
|
||||
i=0
|
||||
for doc in docs:
|
||||
for ent in doc.ents:
|
||||
kb_id = kb_ids[i]
|
||||
i += 1
|
||||
for token in ent:
|
||||
token.ent_kb_id_ = kb_id
|
||||
|
||||
def to_disk(self, path, exclude=tuple(), **kwargs):
|
||||
serialize = OrderedDict()
|
||||
|
|
|
@ -93,7 +93,7 @@ cdef struct KBEntryC:
|
|||
int32_t feats_row
|
||||
|
||||
# log probability of entity, based on corpus frequency
|
||||
float prob
|
||||
float freq
|
||||
|
||||
|
||||
# Each alias struct stores a list of Entry pointers with their prior probabilities
|
||||
|
|
|
@ -13,22 +13,38 @@ def nlp():
|
|||
return English()
|
||||
|
||||
|
||||
def assert_almost_equal(a, b):
|
||||
delta = 0.0001
|
||||
assert a - delta <= b <= a + delta
|
||||
|
||||
|
||||
def test_kb_valid_entities(nlp):
|
||||
"""Test the valid construction of a KB with 3 entities and two aliases"""
|
||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
|
||||
|
||||
# adding entities
|
||||
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
|
||||
mykb.add_entity(entity='Q2', prob=0.5, entity_vector=[2])
|
||||
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
|
||||
mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[8, 4, 3])
|
||||
mykb.add_entity(entity="Q2", freq=0.5, entity_vector=[2, 1, 0])
|
||||
mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[-1, -6, 5])
|
||||
|
||||
# adding aliases
|
||||
mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.2])
|
||||
mykb.add_alias(alias='adam', entities=['Q2'], probabilities=[0.9])
|
||||
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.2])
|
||||
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
|
||||
|
||||
# test the size of the corresponding KB
|
||||
assert(mykb.get_size_entities() == 3)
|
||||
assert(mykb.get_size_aliases() == 2)
|
||||
assert mykb.get_size_entities() == 3
|
||||
assert mykb.get_size_aliases() == 2
|
||||
|
||||
# test retrieval of the entity vectors
|
||||
assert mykb.get_vector("Q1") == [8, 4, 3]
|
||||
assert mykb.get_vector("Q2") == [2, 1, 0]
|
||||
assert mykb.get_vector("Q3") == [-1, -6, 5]
|
||||
|
||||
# test retrieval of prior probabilities
|
||||
assert_almost_equal(mykb.get_prior_prob(entity="Q2", alias="douglas"), 0.8)
|
||||
assert_almost_equal(mykb.get_prior_prob(entity="Q3", alias="douglas"), 0.2)
|
||||
assert_almost_equal(mykb.get_prior_prob(entity="Q342", alias="douglas"), 0.0)
|
||||
assert_almost_equal(mykb.get_prior_prob(entity="Q3", alias="douglassssss"), 0.0)
|
||||
|
||||
|
||||
def test_kb_invalid_entities(nlp):
|
||||
|
@ -36,13 +52,15 @@ def test_kb_invalid_entities(nlp):
|
|||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||
|
||||
# adding entities
|
||||
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
|
||||
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
|
||||
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
|
||||
mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1])
|
||||
mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
|
||||
mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3])
|
||||
|
||||
# adding aliases - should fail because one of the given IDs is not valid
|
||||
with pytest.raises(ValueError):
|
||||
mykb.add_alias(alias='douglas', entities=['Q2', 'Q342'], probabilities=[0.8, 0.2])
|
||||
mykb.add_alias(
|
||||
alias="douglas", entities=["Q2", "Q342"], probabilities=[0.8, 0.2]
|
||||
)
|
||||
|
||||
|
||||
def test_kb_invalid_probabilities(nlp):
|
||||
|
@ -50,13 +68,13 @@ def test_kb_invalid_probabilities(nlp):
|
|||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||
|
||||
# adding entities
|
||||
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
|
||||
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
|
||||
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
|
||||
mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1])
|
||||
mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
|
||||
mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3])
|
||||
|
||||
# adding aliases - should fail because the sum of the probabilities exceeds 1
|
||||
with pytest.raises(ValueError):
|
||||
mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.4])
|
||||
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.4])
|
||||
|
||||
|
||||
def test_kb_invalid_combination(nlp):
|
||||
|
@ -64,13 +82,15 @@ def test_kb_invalid_combination(nlp):
|
|||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||
|
||||
# adding entities
|
||||
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
|
||||
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
|
||||
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
|
||||
mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1])
|
||||
mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
|
||||
mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3])
|
||||
|
||||
# adding aliases - should fail because the entities and probabilities vectors are not of equal length
|
||||
with pytest.raises(ValueError):
|
||||
mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.3, 0.4, 0.1])
|
||||
mykb.add_alias(
|
||||
alias="douglas", entities=["Q2", "Q3"], probabilities=[0.3, 0.4, 0.1]
|
||||
)
|
||||
|
||||
|
||||
def test_kb_invalid_entity_vector(nlp):
|
||||
|
@ -78,11 +98,11 @@ def test_kb_invalid_entity_vector(nlp):
|
|||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
|
||||
|
||||
# adding entities
|
||||
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1, 2, 3])
|
||||
mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1, 2, 3])
|
||||
|
||||
# this should fail because the kb's expected entity vector length is 3
|
||||
with pytest.raises(ValueError):
|
||||
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
|
||||
mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
|
||||
|
||||
|
||||
def test_candidate_generation(nlp):
|
||||
|
@ -90,18 +110,24 @@ def test_candidate_generation(nlp):
|
|||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||
|
||||
# adding entities
|
||||
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
|
||||
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
|
||||
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
|
||||
mykb.add_entity(entity="Q1", freq=0.7, entity_vector=[1])
|
||||
mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
|
||||
mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3])
|
||||
|
||||
# adding aliases
|
||||
mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.2])
|
||||
mykb.add_alias(alias='adam', entities=['Q2'], probabilities=[0.9])
|
||||
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1])
|
||||
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
|
||||
|
||||
# test the size of the relevant candidates
|
||||
assert(len(mykb.get_candidates('douglas')) == 2)
|
||||
assert(len(mykb.get_candidates('adam')) == 1)
|
||||
assert(len(mykb.get_candidates('shrubbery')) == 0)
|
||||
assert len(mykb.get_candidates("douglas")) == 2
|
||||
assert len(mykb.get_candidates("adam")) == 1
|
||||
assert len(mykb.get_candidates("shrubbery")) == 0
|
||||
|
||||
# test the content of the candidates
|
||||
assert mykb.get_candidates("adam")[0].entity_ == "Q2"
|
||||
assert mykb.get_candidates("adam")[0].alias_ == "adam"
|
||||
assert_almost_equal(mykb.get_candidates("adam")[0].entity_freq, 0.2)
|
||||
assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9)
|
||||
|
||||
|
||||
def test_preserving_links_asdoc(nlp):
|
||||
|
@ -109,24 +135,26 @@ def test_preserving_links_asdoc(nlp):
|
|||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||
|
||||
# adding entities
|
||||
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
|
||||
mykb.add_entity(entity='Q2', prob=0.8, entity_vector=[1])
|
||||
mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1])
|
||||
mykb.add_entity(entity="Q2", freq=0.8, entity_vector=[1])
|
||||
|
||||
# adding aliases
|
||||
mykb.add_alias(alias='Boston', entities=['Q1'], probabilities=[0.7])
|
||||
mykb.add_alias(alias='Denver', entities=['Q2'], probabilities=[0.6])
|
||||
mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7])
|
||||
mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6])
|
||||
|
||||
# set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained)
|
||||
sentencizer = nlp.create_pipe("sentencizer")
|
||||
nlp.add_pipe(sentencizer)
|
||||
|
||||
ruler = EntityRuler(nlp)
|
||||
patterns = [{"label": "GPE", "pattern": "Boston"},
|
||||
{"label": "GPE", "pattern": "Denver"}]
|
||||
patterns = [
|
||||
{"label": "GPE", "pattern": "Boston"},
|
||||
{"label": "GPE", "pattern": "Denver"},
|
||||
]
|
||||
ruler.add_patterns(patterns)
|
||||
nlp.add_pipe(ruler)
|
||||
|
||||
el_pipe = nlp.create_pipe(name='entity_linker', config={"context_width": 64})
|
||||
el_pipe = nlp.create_pipe(name="entity_linker", config={"context_width": 64})
|
||||
el_pipe.set_kb(mykb)
|
||||
el_pipe.begin_training()
|
||||
el_pipe.context_weight = 0
|
||||
|
|
|
@ -30,10 +30,10 @@ def test_serialize_kb_disk(en_vocab):
|
|||
def _get_dummy_kb(vocab):
|
||||
kb = KnowledgeBase(vocab=vocab, entity_vector_length=3)
|
||||
|
||||
kb.add_entity(entity='Q53', prob=0.33, entity_vector=[0, 5, 3])
|
||||
kb.add_entity(entity='Q17', prob=0.2, entity_vector=[7, 1, 0])
|
||||
kb.add_entity(entity='Q007', prob=0.7, entity_vector=[0, 0, 7])
|
||||
kb.add_entity(entity='Q44', prob=0.4, entity_vector=[4, 4, 4])
|
||||
kb.add_entity(entity='Q53', freq=0.33, entity_vector=[0, 5, 3])
|
||||
kb.add_entity(entity='Q17', freq=0.2, entity_vector=[7, 1, 0])
|
||||
kb.add_entity(entity='Q007', freq=0.7, entity_vector=[0, 0, 7])
|
||||
kb.add_entity(entity='Q44', freq=0.4, entity_vector=[4, 4, 4])
|
||||
|
||||
kb.add_alias(alias='double07', entities=['Q17', 'Q007'], probabilities=[0.1, 0.9])
|
||||
kb.add_alias(alias='guy', entities=['Q53', 'Q007', 'Q17', 'Q44'], probabilities=[0.3, 0.3, 0.2, 0.1])
|
||||
|
|
|
@ -348,7 +348,7 @@ cdef class Tokenizer:
|
|||
"""Add a special-case tokenization rule.
|
||||
|
||||
string (unicode): The string to specially tokenize.
|
||||
token_attrs (iterable): A sequence of dicts, where each dict describes
|
||||
substrings (iterable): A sequence of dicts, where each dict describes
|
||||
a token and its attributes. The `ORTH` fields of the attributes
|
||||
must exactly match the string when they are concatenated.
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user