mirror of https://github.com/explosion/spaCy.git
synced 2024-09-23 04:19:11 +03:00

Merge branch 'master' into spacy.io

commit 4361da2bba
@@ -13,9 +13,17 @@ INPUT_DIM = 300  # dimension of pre-trained input vectors
 DESC_WIDTH = 64  # dimension of output entity vectors


-def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
-              entity_def_output, entity_descr_output,
-              count_input, prior_prob_input, wikidata_input):
+def create_kb(
+    nlp,
+    max_entities_per_alias,
+    min_entity_freq,
+    min_occ,
+    entity_def_output,
+    entity_descr_output,
+    count_input,
+    prior_prob_input,
+    wikidata_input,
+):
     # Create the knowledge base from Wikidata entries
     kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH)

@@ -28,7 +36,9 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
         title_to_id, id_to_descr = wd.read_wikidata_entities_json(wikidata_input)

         # write the title-ID and ID-description mappings to file
-        _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr)
+        _write_entity_files(
+            entity_def_output, entity_descr_output, title_to_id, id_to_descr
+        )

     else:
         # read the mappings from file
@@ -54,8 +64,8 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
             frequency_list.append(freq)
             filtered_title_to_id[title] = entity

-    print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()),
-          "titles with filter frequency", min_entity_freq)
+    print(len(title_to_id.keys()), "original titles")
+    print("kept", len(filtered_title_to_id.keys()), " with frequency", min_entity_freq)

     print()
     print(" * train entity encoder", datetime.datetime.now())
@@ -70,14 +80,20 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,

     print()
     print(" * adding", len(entity_list), "entities", datetime.datetime.now())
-    kb.set_entities(entity_list=entity_list, prob_list=frequency_list, vector_list=embeddings)
+    kb.set_entities(
+        entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
+    )

     print()
     print(" * adding aliases", datetime.datetime.now())
     print()
-    _add_aliases(kb, title_to_id=filtered_title_to_id,
-                 max_entities_per_alias=max_entities_per_alias, min_occ=min_occ,
-                 prior_prob_input=prior_prob_input)
+    _add_aliases(
+        kb,
+        title_to_id=filtered_title_to_id,
+        max_entities_per_alias=max_entities_per_alias,
+        min_occ=min_occ,
+        prior_prob_input=prior_prob_input,
+    )

     print()
     print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases())
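Besides the black-style reformatting, this hunk renames the `prob_list` keyword of `kb.set_entities` to `freq_list`: the KB now stores raw corpus frequencies rather than probabilities. A minimal sketch of the updated call against the spaCy 2.x `KnowledgeBase` API (the QIDs, counts, and vectors below are invented for illustration):

import spacy
from spacy.kb import KnowledgeBase

nlp = spacy.blank("en")
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3)

# one frequency and one pre-trained vector per entity, aligned by position
kb.set_entities(
    entity_list=["Q42", "Q5301561"],  # hypothetical Wikidata QIDs
    freq_list=[12, 8],                # raw counts, not probabilities
    vector_list=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
)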
@@ -86,13 +102,15 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
     return kb


-def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr):
-    with open(entity_def_output, mode='w', encoding='utf8') as id_file:
+def _write_entity_files(
+    entity_def_output, entity_descr_output, title_to_id, id_to_descr
+):
+    with entity_def_output.open("w", encoding="utf8") as id_file:
         id_file.write("WP_title" + "|" + "WD_id" + "\n")
         for title, qid in title_to_id.items():
             id_file.write(title + "|" + str(qid) + "\n")

-    with open(entity_descr_output, mode='w', encoding='utf8') as descr_file:
+    with entity_descr_output.open("w", encoding="utf8") as descr_file:
         descr_file.write("WD_id" + "|" + "description" + "\n")
         for qid, descr in id_to_descr.items():
             descr_file.write(str(qid) + "|" + descr + "\n")
@@ -100,8 +118,8 @@ def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr):


 def get_entity_to_id(entity_def_output):
     entity_to_id = dict()
-    with open(entity_def_output, 'r', encoding='utf8') as csvfile:
-        csvreader = csv.reader(csvfile, delimiter='|')
+    with entity_def_output.open("r", encoding="utf8") as csvfile:
+        csvreader = csv.reader(csvfile, delimiter="|")
         # skip header
         next(csvreader)
         for row in csvreader:
@@ -111,8 +129,8 @@ def get_entity_to_id(entity_def_output):


 def get_id_to_description(entity_descr_output):
     id_to_desc = dict()
-    with open(entity_descr_output, 'r', encoding='utf8') as csvfile:
-        csvreader = csv.reader(csvfile, delimiter='|')
+    with entity_descr_output.open("r", encoding="utf8") as csvfile:
+        csvreader = csv.reader(csvfile, delimiter="|")
         # skip header
         next(csvreader)
         for row in csvreader:
@@ -125,7 +143,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input):

     # adding aliases with prior probabilities
     # we can read this file sequentially, it's sorted by alias, and then by count
-    with open(prior_prob_input, mode='r', encoding='utf8') as prior_file:
+    with prior_prob_input.open("r", encoding="utf8") as prior_file:
         # skip header
         prior_file.readline()
         line = prior_file.readline()
@@ -134,7 +152,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input):
         counts = []
         entities = []
         while line:
-            splits = line.replace('\n', "").split(sep='|')
+            splits = line.replace("\n", "").split(sep="|")
             new_alias = splits[0]
             count = int(splits[1])
             entity = splits[2]
@@ -153,7 +171,11 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input):

                     if selected_entities:
                         try:
-                            kb.add_alias(alias=previous_alias, entities=selected_entities, probabilities=prior_probs)
+                            kb.add_alias(
+                                alias=previous_alias,
+                                entities=selected_entities,
+                                probabilities=prior_probs,
+                            )
                         except ValueError as e:
                             print(e)
                 total_count = 0
@@ -168,4 +190,3 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input):
             previous_alias = new_alias

             line = prior_file.readline()
-
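The `open(...)` calls throughout this file are replaced with `Path.open(...)`, which assumes the input/output arguments are now `pathlib.Path` objects rather than plain strings. A sketch of the read pattern `get_entity_to_id` uses (the file name here is hypothetical):

import csv
from pathlib import Path

entity_def_output = Path("entity_defs.csv")  # hypothetical location
with entity_def_output.open("r", encoding="utf8") as csvfile:
    csvreader = csv.reader(csvfile, delimiter="|")
    next(csvreader)  # skip the "WP_title|WD_id" header row
    entity_to_id = {row[0]: row[1] for row in csvreader}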
@@ -1,7 +1,7 @@
 # coding: utf-8
 from __future__ import unicode_literals

-import os
+import random
 import re
 import bz2
 import datetime
@@ -17,6 +17,10 @@ Gold-standard entities are stored in one file in standoff format (by character offset).
 ENTITY_FILE = "gold_entities.csv"


+def now():
+    return datetime.datetime.now()
+
+
 def create_training(wikipedia_input, entity_def_input, training_output):
     wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
     _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None)
@@ -27,21 +31,23 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None):
     Read the XML wikipedia data to parse out training data:
     raw text data + positive instances
     """
-    title_regex = re.compile(r'(?<=<title>).*(?=</title>)')
-    id_regex = re.compile(r'(?<=<id>)\d*(?=</id>)')
+    title_regex = re.compile(r"(?<=<title>).*(?=</title>)")
+    id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")

     read_ids = set()
     entityfile_loc = training_output / ENTITY_FILE
-    with open(entityfile_loc, mode="w", encoding='utf8') as entityfile:
+    with entityfile_loc.open("w", encoding="utf8") as entityfile:
         # write entity training header file
-        _write_training_entity(outputfile=entityfile,
+        _write_training_entity(
+            outputfile=entityfile,
             article_id="article_id",
             alias="alias",
             entity="WD_id",
             start="start",
-            end="end")
+            end="end",
+        )

-        with bz2.open(wikipedia_input, mode='rb') as file:
+        with bz2.open(wikipedia_input, mode="rb") as file:
             line = file.readline()
             cnt = 0
             article_text = ""
@@ -51,7 +57,7 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None):
             reading_revision = False
             while line and (not limit or cnt < limit):
                 if cnt % 1000000 == 0:
-                    print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump")
+                    print(now(), "processed", cnt, "lines of Wikipedia dump")
                 clean_line = line.strip().decode("utf-8")

                 if clean_line == "<revision>":
@@ -69,12 +75,23 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None):
                 elif clean_line == "</page>":
                     if article_id:
                         try:
-                            _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text.strip(),
-                                             training_output)
+                            _process_wp_text(
+                                wp_to_id,
+                                entityfile,
+                                article_id,
+                                article_title,
+                                article_text.strip(),
+                                training_output,
+                            )
                         except Exception as e:
-                            print("Error processing article", article_id, article_title, e)
+                            print(
+                                "Error processing article", article_id, article_title, e
+                            )
                     else:
-                        print("Done processing a page, but couldn't find an article_id ?", article_title)
+                        print(
+                            "Done processing a page, but couldn't find an article_id ?",
+                            article_title,
+                        )
                     article_text = ""
                     article_title = None
                     article_id = None
@@ -98,7 +115,9 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None):
                     if ids:
                         article_id = ids[0]
                         if article_id in read_ids:
-                            print("Found duplicate article ID", article_id, clean_line)  # This should never happen ...
+                            print(
+                                "Found duplicate article ID", article_id, clean_line
+                            )  # This should never happen ...
                         read_ids.add(article_id)

                 # read the title of this article (outside the revision portion of the document)
@@ -111,10 +130,12 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None):
                 cnt += 1


-text_regex = re.compile(r'(?<=<text xml:space=\"preserve\">).*(?=</text)')
+text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")


-def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output):
+def _process_wp_text(
+    wp_to_id, entityfile, article_id, article_title, article_text, training_output
+):
     found_entities = False

     # ignore meta Wikipedia pages
@@ -141,11 +162,11 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output):
     entity_buffer = ""
     mention_buffer = ""
     for index, letter in enumerate(clean_text):
-        if letter == '[':
+        if letter == "[":
             open_read += 1
-        elif letter == ']':
+        elif letter == "]":
             open_read -= 1
-        elif letter == '|':
+        elif letter == "|":
             if reading_text:
                 final_text += letter
             # switch from reading entity to mention in the [[entity|mention]] pattern
@@ -163,7 +184,7 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output):
         elif reading_text:
             final_text += letter
         else:
-            raise ValueError("Not sure at point", clean_text[index-2:index+2])
+            raise ValueError("Not sure at point", clean_text[index - 2 : index + 2])

         if open_read > 2:
             reading_special_case = True
@@ -175,7 +196,7 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output):

         # we just finished reading an entity
         if open_read == 0 and not reading_text:
-            if '#' in entity_buffer or entity_buffer.startswith(':'):
+            if "#" in entity_buffer or entity_buffer.startswith(":"):
                 reading_special_case = True
             # Ignore cases with nested structures like File: handles etc
             if not reading_special_case:
@@ -185,12 +206,14 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output):
                 end = start + len(mention_buffer)
                 qid = wp_to_id.get(entity_buffer, None)
                 if qid:
-                    _write_training_entity(outputfile=entityfile,
+                    _write_training_entity(
+                        outputfile=entityfile,
                         article_id=article_id,
                         alias=mention_buffer,
                         entity=qid,
                         start=start,
-                        end=end)
+                        end=end,
+                    )
                     found_entities = True
                 final_text += mention_buffer
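The parser above walks the cleaned wikitext character by character, tracking `[[...]]` nesting, and for every `[[entity|mention]]` link whose target resolves to a QID via `wp_to_id` it emits one pipe-delimited standoff record. A small illustration of the intended mapping (article ID, title, and offsets are hypothetical):

# wikitext: "He cited [[Douglas Adams|the author]] approvingly."
# visible text after cleaning: "He cited the author approvingly."
# assuming wp_to_id == {"Douglas Adams": "Q42"}, the record appended to
# gold_entities.csv would be:
#
#   article_id|alias     |WD_id|start|end
#   12345     |the author|Q42  |9    |19
record = "{}|{}|{}|{}|{}\n".format("12345", "the author", "Q42", 9, 19)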
@@ -203,29 +226,35 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output):
             reading_special_case = False

     if found_entities:
-        _write_training_article(article_id=article_id, clean_text=final_text, training_output=training_output)
+        _write_training_article(
+            article_id=article_id,
+            clean_text=final_text,
+            training_output=training_output,
+        )


-info_regex = re.compile(r'{[^{]*?}')
-htlm_regex = re.compile(r'<!--[^-]*-->')
-category_regex = re.compile(r'\[\[Category:[^\[]*]]')
-file_regex = re.compile(r'\[\[File:[^[\]]+]]')
-ref_regex = re.compile(r'<ref.*?>')  # non-greedy
-ref_2_regex = re.compile(r'</ref.*?>')  # non-greedy
+info_regex = re.compile(r"{[^{]*?}")
+htlm_regex = re.compile(r"<!--[^-]*-->")
+category_regex = re.compile(r"\[\[Category:[^\[]*]]")
+file_regex = re.compile(r"\[\[File:[^[\]]+]]")
+ref_regex = re.compile(r"<ref.*?>")  # non-greedy
+ref_2_regex = re.compile(r"</ref.*?>")  # non-greedy


 def _get_clean_wp_text(article_text):
     clean_text = article_text.strip()

     # remove bolding & italic markup
-    clean_text = clean_text.replace('\'\'\'', '')
-    clean_text = clean_text.replace('\'\'', '')
+    clean_text = clean_text.replace("'''", "")
+    clean_text = clean_text.replace("''", "")

     # remove nested {{info}} statements by removing the inner/smallest ones first and iterating
     try_again = True
     previous_length = len(clean_text)
     while try_again:
-        clean_text = info_regex.sub('', clean_text)  # non-greedy match excluding a nested {
+        clean_text = info_regex.sub(
+            "", clean_text
+        )  # non-greedy match excluding a nested {
         if len(clean_text) < previous_length:
             try_again = True
         else:
@@ -233,14 +262,14 @@ def _get_clean_wp_text(article_text):
         previous_length = len(clean_text)

     # remove HTML comments
-    clean_text = htlm_regex.sub('', clean_text)
+    clean_text = htlm_regex.sub("", clean_text)

     # remove Category and File statements
-    clean_text = category_regex.sub('', clean_text)
-    clean_text = file_regex.sub('', clean_text)
+    clean_text = category_regex.sub("", clean_text)
+    clean_text = file_regex.sub("", clean_text)

     # remove multiple =
-    while '==' in clean_text:
+    while "==" in clean_text:
         clean_text = clean_text.replace("==", "=")

     clean_text = clean_text.replace(". =", ".")
@@ -249,43 +278,47 @@ def _get_clean_wp_text(article_text):
     clean_text = clean_text.replace(" =", "")

     # remove refs (non-greedy match)
-    clean_text = ref_regex.sub('', clean_text)
-    clean_text = ref_2_regex.sub('', clean_text)
+    clean_text = ref_regex.sub("", clean_text)
+    clean_text = ref_2_regex.sub("", clean_text)

     # remove additional wikiformatting
-    clean_text = re.sub(r'<blockquote>', '', clean_text)
-    clean_text = re.sub(r'</blockquote>', '', clean_text)
+    clean_text = re.sub(r"<blockquote>", "", clean_text)
+    clean_text = re.sub(r"</blockquote>", "", clean_text)

     # change special characters back to normal ones
-    clean_text = clean_text.replace(r'&lt;', '<')
-    clean_text = clean_text.replace(r'&gt;', '>')
-    clean_text = clean_text.replace(r'&quot;', '"')
-    clean_text = clean_text.replace(r'&nbsp;', ' ')
-    clean_text = clean_text.replace(r'&amp;', '&')
+    clean_text = clean_text.replace(r"&lt;", "<")
+    clean_text = clean_text.replace(r"&gt;", ">")
+    clean_text = clean_text.replace(r"&quot;", '"')
+    clean_text = clean_text.replace(r"&nbsp;", " ")
+    clean_text = clean_text.replace(r"&amp;", "&")

     # remove multiple spaces
-    while '  ' in clean_text:
-        clean_text = clean_text.replace('  ', ' ')
+    while "  " in clean_text:
+        clean_text = clean_text.replace("  ", " ")

     return clean_text.strip()


 def _write_training_article(article_id, clean_text, training_output):
-    file_loc = training_output / str(article_id) + ".txt"
-    with open(file_loc, mode='w', encoding='utf8') as outputfile:
+    file_loc = training_output / "{}.txt".format(article_id)
+    with file_loc.open("w", encoding="utf8") as outputfile:
         outputfile.write(clean_text)


 def _write_training_entity(outputfile, article_id, alias, entity, start, end):
-    outputfile.write(article_id + "|" + alias + "|" + entity + "|" + str(start) + "|" + str(end) + "\n")
+    line = "{}|{}|{}|{}|{}\n".format(article_id, alias, entity, start, end)
+    outputfile.write(line)


 def is_dev(article_id):
     return article_id.endswith("3")


-def read_training(nlp, training_dir, dev, limit):
-    # This method provides training examples that correspond to the entity annotations found by the nlp object
+def read_training(nlp, training_dir, dev, limit, kb=None):
+    """ This method provides training examples that correspond to the entity annotations found by the nlp object.
+    When kb is provided (for training), it will include negative training examples by using the candidate generator,
+    and it will only keep positive training examples that can be found in the KB.
+    When kb=None (for testing), it will include all positive examples only."""
     entityfile_loc = training_dir / ENTITY_FILE
     data = []
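`is_dev` gives a cheap, deterministic train/dev split: any article whose ID string ends in "3" lands in the dev set (roughly a tenth of articles, assuming uniformly distributed IDs), and `read_training(dev=...)` filters on it. For instance:

assert is_dev("123")       # ID ends in "3" -> dev article
assert not is_dev("124")   # -> train article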
@@ -296,24 +329,30 @@ def read_training(nlp, training_dir, dev, limit):
     skip_articles = set()
     total_entities = 0

-    with open(entityfile_loc, mode='r', encoding='utf8') as file:
+    with entityfile_loc.open("r", encoding="utf8") as file:
         for line in file:
             if not limit or len(data) < limit:
-                fields = line.replace('\n', "").split(sep='|')
+                fields = line.replace("\n", "").split(sep="|")
                 article_id = fields[0]
                 alias = fields[1]
-                wp_title = fields[2]
+                wd_id = fields[2]
                 start = fields[3]
                 end = fields[4]

-                if dev == is_dev(article_id) and article_id != "article_id" and article_id not in skip_articles:
+                if (
+                    dev == is_dev(article_id)
+                    and article_id != "article_id"
+                    and article_id not in skip_articles
+                ):
                     if not current_doc or (current_article_id != article_id):
                         # parse the new article text
                         file_name = article_id + ".txt"
                         try:
-                            with open(os.path.join(training_dir, file_name), mode="r", encoding='utf8') as f:
+                            training_file = training_dir / file_name
+                            with training_file.open("r", encoding="utf8") as f:
                                 text = f.read()
-                                if len(text) < 30000:  # threshold for convenience / speed of processing
+                                # threshold for convenience / speed of processing
+                                if len(text) < 30000:
                                     current_doc = nlp(text)
                                     current_article_id = article_id
                                     ents_by_offset = dict()
@@ -321,28 +360,64 @@ def read_training(nlp, training_dir, dev, limit):
                                         sent_length = len(ent.sent)
                                         # custom filtering to avoid too long or too short sentences
                                         if 5 < sent_length < 100:
-                                            ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
+                                            offset = "{}_{}".format(
+                                                ent.start_char, ent.end_char
+                                            )
+                                            ents_by_offset[offset] = ent
                                 else:
                                     skip_articles.add(article_id)
                                     current_doc = None
                         except Exception as e:
                             print("Problem parsing article", article_id, e)
                             skip_articles.add(article_id)
-                            raise e

                     # repeat checking this condition in case an exception was thrown
                     if current_doc and (current_article_id == article_id):
-                        found_ent = ents_by_offset.get(start + "_" + end, None)
+                        offset = "{}_{}".format(start, end)
+                        found_ent = ents_by_offset.get(offset, None)
                         if found_ent:
                             if found_ent.text != alias:
                                 skip_articles.add(article_id)
                                 current_doc = None
                             else:
                                 sent = found_ent.sent.as_doc()
-                                # currently feeding the gold data one entity per sentence at a time
                                 gold_start = int(start) - found_ent.sent.start_char
                                 gold_end = int(end) - found_ent.sent.start_char
-                                gold_entities = [(gold_start, gold_end, wp_title)]
+
+                                gold_entities = {}
+                                found_useful = False
+                                for ent in sent.ents:
+                                    entry = (ent.start_char, ent.end_char)
+                                    gold_entry = (gold_start, gold_end)
+                                    if entry == gold_entry:
+                                        # add both pos and neg examples (in random order)
+                                        # this will exclude examples not in the KB
+                                        if kb:
+                                            value_by_id = {}
+                                            candidates = kb.get_candidates(alias)
+                                            candidate_ids = [
+                                                c.entity_ for c in candidates
+                                            ]
+                                            random.shuffle(candidate_ids)
+                                            for kb_id in candidate_ids:
+                                                found_useful = True
+                                                if kb_id != wd_id:
+                                                    value_by_id[kb_id] = 0.0
+                                                else:
+                                                    value_by_id[kb_id] = 1.0
+                                            gold_entities[entry] = value_by_id
+                                        # if no KB, keep all positive examples
+                                        else:
+                                            found_useful = True
+                                            value_by_id = {wd_id: 1.0}
+
+                                            gold_entities[entry] = value_by_id
+                                    # currently feeding the gold data one entity per sentence at a time
+                                    # setting all other entities to empty gold dictionary
+                                    else:
+                                        gold_entities[entry] = {}
+                                if found_useful:
                                     gold = GoldParse(doc=sent, links=gold_entities)
                                     data.append((sent, gold))
                                     total_entities += 1
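The substantive change in this hunk is the gold format: `links` goes from a list of `(start, end, title)` triples to a dict keyed by character-offset pairs, where each value maps candidate KB IDs to 1.0 (the true link) or 0.0 (a negative candidate drawn from `kb.get_candidates(alias)`). A minimal sketch of the new structure for the spaCy 2.x `GoldParse` (offsets and QIDs invented):

from spacy.gold import GoldParse

# one mention spanning characters 0-13 of `sent` (the sentence Doc produced
# by found_ent.sent.as_doc() above), with one positive and two negative
# candidates; all other entities get an empty dict
links = {(0, 13): {"Q42": 1.0, "Q1004791": 0.0, "Q5301561": 0.0}}
gold = GoldParse(doc=sent, links=links)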
@@ -14,22 +14,97 @@ Write these results to file for downstream KB and training data generation.
 map_alias_to_link = dict()

 # these will/should be matched ignoring case
-wiki_namespaces = ["b", "betawikiversity", "Book", "c", "Category", "Commons",
-                   "d", "dbdump", "download", "Draft", "Education", "Foundation",
-                   "Gadget", "Gadget definition", "gerrit", "File", "Help", "Image", "Incubator",
-                   "m", "mail", "mailarchive", "media", "MediaWiki", "MediaWiki talk", "Mediawikiwiki",
-                   "MediaZilla", "Meta", "Metawikipedia", "Module",
-                   "mw", "n", "nost", "oldwikisource", "outreach", "outreachwiki", "otrs", "OTRSwiki",
-                   "Portal", "phab", "Phabricator", "Project", "q", "quality", "rev",
-                   "s", "spcom", "Special", "species", "Strategy", "sulutil", "svn",
-                   "Talk", "Template", "Template talk", "Testwiki", "ticket", "TimedText", "Toollabs", "tools",
-                   "tswiki", "User", "User talk", "v", "voy",
-                   "w", "Wikibooks", "Wikidata", "wikiHow", "Wikinvest", "wikilivres", "Wikimedia", "Wikinews",
-                   "Wikipedia", "Wikipedia talk", "Wikiquote", "Wikisource", "Wikispecies", "Wikitech",
-                   "Wikiversity", "Wikivoyage", "wikt", "wiktionary", "wmf", "wmania", "WP"]
+wiki_namespaces = [
+    "b",
+    "betawikiversity",
+    "Book",
+    "c",
+    "Category",
+    "Commons",
+    "d",
+    "dbdump",
+    "download",
+    "Draft",
+    "Education",
+    "Foundation",
+    "Gadget",
+    "Gadget definition",
+    "gerrit",
+    "File",
+    "Help",
+    "Image",
+    "Incubator",
+    "m",
+    "mail",
+    "mailarchive",
+    "media",
+    "MediaWiki",
+    "MediaWiki talk",
+    "Mediawikiwiki",
+    "MediaZilla",
+    "Meta",
+    "Metawikipedia",
+    "Module",
+    "mw",
+    "n",
+    "nost",
+    "oldwikisource",
+    "outreach",
+    "outreachwiki",
+    "otrs",
+    "OTRSwiki",
+    "Portal",
+    "phab",
+    "Phabricator",
+    "Project",
+    "q",
+    "quality",
+    "rev",
+    "s",
+    "spcom",
+    "Special",
+    "species",
+    "Strategy",
+    "sulutil",
+    "svn",
+    "Talk",
+    "Template",
+    "Template talk",
+    "Testwiki",
+    "ticket",
+    "TimedText",
+    "Toollabs",
+    "tools",
+    "tswiki",
+    "User",
+    "User talk",
+    "v",
+    "voy",
+    "w",
+    "Wikibooks",
+    "Wikidata",
+    "wikiHow",
+    "Wikinvest",
+    "wikilivres",
+    "Wikimedia",
+    "Wikinews",
+    "Wikipedia",
+    "Wikipedia talk",
+    "Wikiquote",
+    "Wikisource",
+    "Wikispecies",
+    "Wikitech",
+    "Wikiversity",
+    "Wikivoyage",
+    "wikt",
+    "wiktionary",
+    "wmf",
+    "wmania",
+    "WP",
+]

 # find the links
-link_regex = re.compile(r'\[\[[^\[\]]*\]\]')
+link_regex = re.compile(r"\[\[[^\[\]]*\]\]")

 # match on interwiki links, e.g. `en:` or `:fr:`
 ns_regex = r":?" + "[a-z][a-z]" + ":"
@@ -41,18 +116,22 @@ for ns in wiki_namespaces:
 ns_regex = re.compile(ns_regex, re.IGNORECASE)


-def read_wikipedia_prior_probs(wikipedia_input, prior_prob_output):
+def now():
+    return datetime.datetime.now()
+
+
+def read_prior_probs(wikipedia_input, prior_prob_output):
     """
     Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
     The full file takes about 2h to parse 1100M lines.
     It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from.
     """
-    with bz2.open(wikipedia_input, mode='rb') as file:
+    with bz2.open(wikipedia_input, mode="rb") as file:
         line = file.readline()
         cnt = 0
         while line:
             if cnt % 5000000 == 0:
-                print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump")
+                print(now(), "processed", cnt, "lines of Wikipedia dump")
             clean_line = line.strip().decode("utf-8")

             aliases, entities, normalizations = get_wp_links(clean_line)
@@ -64,10 +143,11 @@ def read_prior_probs(wikipedia_input, prior_prob_output):
             cnt += 1

     # write all aliases and their entities and count occurrences to file
-    with open(prior_prob_output, mode='w', encoding='utf8') as outputfile:
+    with prior_prob_output.open("w", encoding="utf8") as outputfile:
         outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n")
         for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]):
-            for entity, count in sorted(alias_dict.items(), key=lambda x: x[1], reverse=True):
+            s_dict = sorted(alias_dict.items(), key=lambda x: x[1], reverse=True)
+            for entity, count in s_dict:
                 outputfile.write(alias + "|" + str(count) + "|" + entity + "\n")


@@ -140,13 +220,13 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
     entity_to_count = dict()
     total_count = 0

-    with open(prior_prob_input, mode='r', encoding='utf8') as prior_file:
+    with prior_prob_input.open("r", encoding="utf8") as prior_file:
         # skip header
         prior_file.readline()
         line = prior_file.readline()

         while line:
-            splits = line.replace('\n', "").split(sep='|')
+            splits = line.replace("\n", "").split(sep="|")
             # alias = splits[0]
             count = int(splits[1])
             entity = splits[2]
@@ -158,7 +238,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):

             line = prior_file.readline()

-    with open(count_output, mode='w', encoding='utf8') as entity_file:
+    with count_output.open("w", encoding="utf8") as entity_file:
         entity_file.write("entity" + "|" + "count" + "\n")
         for entity, count in entity_to_count.items():
             entity_file.write(entity + "|" + str(count) + "\n")
@@ -171,12 +251,11 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):

 def get_all_frequencies(count_input):
     entity_to_count = dict()
-    with open(count_input, 'r', encoding='utf8') as csvfile:
-        csvreader = csv.reader(csvfile, delimiter='|')
+    with count_input.open("r", encoding="utf8") as csvfile:
+        csvreader = csv.reader(csvfile, delimiter="|")
         # skip header
         next(csvreader)
         for row in csvreader:
             entity_to_count[row[0]] = int(row[1])

     return entity_to_count
-
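`read_prior_probs` (renamed from `read_wikipedia_prior_probs`) writes an `alias|count|entity` file sorted by alias and then by descending count, which is the shape `_add_aliases` in kb_creator consumes. A sketch of turning one alias's counts into the prior probabilities passed to `kb.add_alias` (the counts are invented, and the real code applies min_occ/max_entities_per_alias filtering first):

# rows for one alias, already sorted by descending count
entities = ["Q42", "Q5301561"]
counts = [90, 10]

total = sum(counts)
prior_probs = [c / total for c in counts]  # [0.9, 0.1]
# aligns with kb.add_alias(alias=..., entities=entities, probabilities=prior_probs)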
@@ -14,15 +14,15 @@ def create_kb(vocab):
     # adding entities
     entity_0 = "Q1004791_Douglas"
     print("adding entity", entity_0)
-    kb.add_entity(entity=entity_0, prob=0.5, entity_vector=[0])
+    kb.add_entity(entity=entity_0, freq=0.5, entity_vector=[0])

     entity_1 = "Q42_Douglas_Adams"
     print("adding entity", entity_1)
-    kb.add_entity(entity=entity_1, prob=0.5, entity_vector=[1])
+    kb.add_entity(entity=entity_1, freq=0.5, entity_vector=[1])

     entity_2 = "Q5301561_Douglas_Haig"
     print("adding entity", entity_2)
-    kb.add_entity(entity=entity_2, prob=0.5, entity_vector=[2])
+    kb.add_entity(entity=entity_2, freq=0.5, entity_vector=[2])

     # adding aliases
     print()
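The `prob` keyword of `kb.add_entity` is renamed to `freq` in the same spirit as the `set_entities` change above. For completeness, a sketch of how these three entities and a shared alias fit together; the probability split below is invented, and the probabilities handed to `add_alias` must sum to at most 1.0, which is why kb_creator wraps the call in try/except ValueError:

kb.add_alias(
    alias="Douglas",
    entities=["Q1004791_Douglas", "Q42_Douglas_Adams", "Q5301561_Douglas_Haig"],
    probabilities=[0.6, 0.3, 0.1],  # prior P(entity | alias)
)
candidates = kb.get_candidates("Douglas")  # one Candidate per registered entity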
@@ -1,11 +1,14 @@
 # coding: utf-8
 from __future__ import unicode_literals

+import os
+from os import path
 import random
 import datetime
 from pathlib import Path

-from bin.wiki_entity_linking import training_set_creator, kb_creator, wikipedia_processor as wp
+from bin.wiki_entity_linking import wikipedia_processor as wp
+from bin.wiki_entity_linking import training_set_creator, kb_creator
 from bin.wiki_entity_linking.kb_creator import DESC_WIDTH

 import spacy
@@ -17,23 +20,26 @@ Demonstrate how to build a knowledge base from WikiData and run an Entity Linking algorithm.
 """

 ROOT_DIR = Path("C:/Users/Sofie/Documents/data/")
-OUTPUT_DIR = ROOT_DIR / 'wikipedia'
-TRAINING_DIR = OUTPUT_DIR / 'training_data_nel'
+OUTPUT_DIR = ROOT_DIR / "wikipedia"
+TRAINING_DIR = OUTPUT_DIR / "training_data_nel"

-PRIOR_PROB = OUTPUT_DIR / 'prior_prob.csv'
-ENTITY_COUNTS = OUTPUT_DIR / 'entity_freq.csv'
-ENTITY_DEFS = OUTPUT_DIR / 'entity_defs.csv'
-ENTITY_DESCR = OUTPUT_DIR / 'entity_descriptions.csv'
+PRIOR_PROB = OUTPUT_DIR / "prior_prob.csv"
+ENTITY_COUNTS = OUTPUT_DIR / "entity_freq.csv"
+ENTITY_DEFS = OUTPUT_DIR / "entity_defs.csv"
+ENTITY_DESCR = OUTPUT_DIR / "entity_descriptions.csv"

-KB_FILE = OUTPUT_DIR / 'kb_1' / 'kb'
-NLP_1_DIR = OUTPUT_DIR / 'nlp_1'
-NLP_2_DIR = OUTPUT_DIR / 'nlp_2'
+KB_DIR = OUTPUT_DIR / "kb_1"
+KB_FILE = "kb"
+NLP_1_DIR = OUTPUT_DIR / "nlp_1"
+NLP_2_DIR = OUTPUT_DIR / "nlp_2"

 # get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
-WIKIDATA_JSON = ROOT_DIR / 'wikidata' / 'wikidata-20190304-all.json.bz2'
+WIKIDATA_JSON = ROOT_DIR / "wikidata" / "wikidata-20190304-all.json.bz2"

 # get enwiki-latest-pages-articles-multistream.xml.bz2 from https://dumps.wikimedia.org/enwiki/latest/
-ENWIKI_DUMP = ROOT_DIR / 'wikipedia' / 'enwiki-20190320-pages-articles-multistream.xml.bz2'
+ENWIKI_DUMP = (
+    ROOT_DIR / "wikipedia" / "enwiki-20190320-pages-articles-multistream.xml.bz2"
+)

 # KB construction parameters
 MAX_CANDIDATES = 10
@@ -48,11 +54,15 @@ L2 = 1e-6
 CONTEXT_WIDTH = 128


+def now():
+    return datetime.datetime.now()
+
+
 def run_pipeline():
     # set the appropriate booleans to define which parts of the pipeline should be re(run)
-    print("START", datetime.datetime.now())
+    print("START", now())
     print()
-    nlp_1 = spacy.load('en_core_web_lg')
+    nlp_1 = spacy.load("en_core_web_lg")
     nlp_2 = None
     kb_2 = None

@@ -82,20 +92,21 @@ def run_pipeline():

     # STEP 1 : create prior probabilities from WP (run only once)
     if to_create_prior_probs:
-        print("STEP 1: to_create_prior_probs", datetime.datetime.now())
-        wp.read_wikipedia_prior_probs(wikipedia_input=ENWIKI_DUMP, prior_prob_output=PRIOR_PROB)
+        print("STEP 1: to_create_prior_probs", now())
+        wp.read_prior_probs(ENWIKI_DUMP, PRIOR_PROB)
         print()

     # STEP 2 : deduce entity frequencies from WP (run only once)
     if to_create_entity_counts:
-        print("STEP 2: to_create_entity_counts", datetime.datetime.now())
-        wp.write_entity_counts(prior_prob_input=PRIOR_PROB, count_output=ENTITY_COUNTS, to_print=False)
+        print("STEP 2: to_create_entity_counts", now())
+        wp.write_entity_counts(PRIOR_PROB, ENTITY_COUNTS, to_print=False)
         print()

     # STEP 3 : create KB and write to file (run only once)
     if to_create_kb:
-        print("STEP 3a: to_create_kb", datetime.datetime.now())
-        kb_1 = kb_creator.create_kb(nlp_1,
+        print("STEP 3a: to_create_kb", now())
+        kb_1 = kb_creator.create_kb(
+            nlp=nlp_1,
             max_entities_per_alias=MAX_CANDIDATES,
             min_entity_freq=MIN_ENTITY_FREQ,
             min_occ=MIN_PAIR_OCC,
@@ -103,22 +114,26 @@ def run_pipeline():
             entity_descr_output=ENTITY_DESCR,
             count_input=ENTITY_COUNTS,
             prior_prob_input=PRIOR_PROB,
-            wikidata_input=WIKIDATA_JSON)
+            wikidata_input=WIKIDATA_JSON,
+        )
         print("kb entities:", kb_1.get_size_entities())
         print("kb aliases:", kb_1.get_size_aliases())
         print()

-        print("STEP 3b: write KB and NLP", datetime.datetime.now())
-        kb_1.dump(KB_FILE)
+        print("STEP 3b: write KB and NLP", now())
+
+        if not path.exists(KB_DIR):
+            os.makedirs(KB_DIR)
+        kb_1.dump(KB_DIR / KB_FILE)
         nlp_1.to_disk(NLP_1_DIR)
         print()

     # STEP 4 : read KB back in from file
     if to_read_kb:
-        print("STEP 4: to_read_kb", datetime.datetime.now())
+        print("STEP 4: to_read_kb", now())
         nlp_2 = spacy.load(NLP_1_DIR)
         kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH)
-        kb_2.load_bulk(KB_FILE)
+        kb_2.load_bulk(KB_DIR / KB_FILE)
         print("kb entities:", kb_2.get_size_entities())
         print("kb aliases:", kb_2.get_size_aliases())
         print()
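STEP 3b/4 split the old `KB_FILE` path into `KB_DIR` plus a file name and create the directory before dumping, presumably because `kb.dump` does not create parent directories; `load_bulk` then reads the same path back into a KB built with the same `entity_vector_length`. A round-trip sketch (paths hypothetical):

import os
from os import path

kb_dir = "output/kb_1"              # hypothetical directory
if not path.exists(kb_dir):
    os.makedirs(kb_dir)
kb_1.dump(os.path.join(kb_dir, "kb"))

kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH)
kb_2.load_bulk(os.path.join(kb_dir, "kb"))  # vector length must match the dump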
@@ -130,20 +145,26 @@ def run_pipeline():

     # STEP 5: create a training dataset from WP
     if create_wp_training:
-        print("STEP 5: create training dataset", datetime.datetime.now())
-        training_set_creator.create_training(wikipedia_input=ENWIKI_DUMP,
+        print("STEP 5: create training dataset", now())
+        training_set_creator.create_training(
+            wikipedia_input=ENWIKI_DUMP,
             entity_def_input=ENTITY_DEFS,
-            training_output=TRAINING_DIR)
+            training_output=TRAINING_DIR,
+        )

     # STEP 6: create and train the entity linking pipe
     if train_pipe:
-        print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
+        print("STEP 6: training Entity Linking pipe", now())
         type_to_int = {label: i for i, label in enumerate(nlp_2.entity.labels)}
         print(" -analysing", len(type_to_int), "different entity types")
-        el_pipe = nlp_2.create_pipe(name='entity_linker',
-                                    config={"context_width": CONTEXT_WIDTH,
+        el_pipe = nlp_2.create_pipe(
+            name="entity_linker",
+            config={
+                "context_width": CONTEXT_WIDTH,
                 "pretrained_vectors": nlp_2.vocab.vectors.name,
-                "type_to_int": type_to_int})
+                "type_to_int": type_to_int,
+            },
+        )
         el_pipe.set_kb(kb_2)
         nlp_2.add_pipe(el_pipe, last=True)

@@ -157,18 +178,22 @@ def run_pipeline():
         train_limit = 5000
         dev_limit = 5000

-        train_data = training_set_creator.read_training(nlp=nlp_2,
+        # for training, get pos & neg instances that correspond to entries in the kb
+        train_data = training_set_creator.read_training(
+            nlp=nlp_2,
             training_dir=TRAINING_DIR,
             dev=False,
-            limit=train_limit)
+            limit=train_limit,
+            kb=el_pipe.kb,
+        )

         print("Training on", len(train_data), "articles")
         print()

-        dev_data = training_set_creator.read_training(nlp=nlp_2,
-                                                      training_dir=TRAINING_DIR,
-                                                      dev=True,
-                                                      limit=dev_limit)
+        # for testing, get all pos instances, whether or not they are in the kb
+        dev_data = training_set_creator.read_training(
+            nlp=nlp_2, training_dir=TRAINING_DIR, dev=True, limit=dev_limit, kb=None
+        )

         print("Dev testing on", len(dev_data), "articles")
         print()
@@ -187,8 +212,8 @@ def run_pipeline():
             try:
                 docs, golds = zip(*batch)
                 nlp_2.update(
-                    docs,
-                    golds,
+                    docs=docs,
+                    golds=golds,
                     sgd=optimizer,
                     drop=DROPOUT,
                     losses=losses,
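For context around the `nlp_2.update(docs=..., golds=...)` call: a minimal sketch of the surrounding epoch loop, under the assumption that batches come from spaCy 2.x's `minibatch`/`compounding` utilities (the epoch count and batch sizes below are invented):

import random
from spacy.util import minibatch, compounding

for itn in range(10):  # EPOCHS, assumed
    random.shuffle(train_data)
    losses = {}
    batches = minibatch(train_data, size=compounding(4.0, 32.0, 1.001))
    for batch in batches:
        docs, golds = zip(*batch)
        nlp_2.update(docs=docs, golds=golds, sgd=optimizer, drop=DROPOUT, losses=losses)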
@ -200,48 +225,61 @@ def run_pipeline():
|
||||||
if batchnr > 0:
|
if batchnr > 0:
|
||||||
el_pipe.cfg["context_weight"] = 1
|
el_pipe.cfg["context_weight"] = 1
|
||||||
el_pipe.cfg["prior_weight"] = 1
|
el_pipe.cfg["prior_weight"] = 1
|
||||||
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
|
dev_acc_context, _ = _measure_acc(dev_data, el_pipe)
|
||||||
losses['entity_linker'] = losses['entity_linker'] / batchnr
|
losses["entity_linker"] = losses["entity_linker"] / batchnr
|
||||||
print("Epoch, train loss", itn, round(losses['entity_linker'], 2),
|
print(
|
||||||
" / dev acc avg", round(dev_acc_context, 3))
|
"Epoch, train loss",
|
||||||
|
itn,
|
||||||
|
round(losses["entity_linker"], 2),
|
||||||
|
" / dev acc avg",
|
||||||
|
round(dev_acc_context, 3),
|
||||||
|
)
|
||||||
|
|
||||||
# STEP 7: measure the performance of our trained pipe on an independent dev set
|
# STEP 7: measure the performance of our trained pipe on an independent dev set
|
||||||
if len(dev_data) and measure_performance:
|
if len(dev_data) and measure_performance:
|
||||||
print()
|
print()
|
||||||
print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now())
|
print("STEP 7: performance measurement of Entity Linking pipe", now())
|
||||||
print()
|
print()
|
||||||
|
|
||||||
counts, acc_r, acc_r_label, acc_p, acc_p_label, acc_o, acc_o_label = _measure_baselines(dev_data, kb_2)
|
counts, acc_r, acc_r_d, acc_p, acc_p_d, acc_o, acc_o_d = _measure_baselines(
|
||||||
|
dev_data, kb_2
|
||||||
|
)
|
||||||
print("dev counts:", sorted(counts.items(), key=lambda x: x[0]))
|
print("dev counts:", sorted(counts.items(), key=lambda x: x[0]))
|
||||||
print("dev acc oracle:", round(acc_o, 3), [(x, round(y, 3)) for x, y in acc_o_label.items()])
|
|
||||||
print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_label.items()])
|
oracle_by_label = [(x, round(y, 3)) for x, y in acc_o_d.items()]
|
||||||
print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_label.items()])
|
print("dev acc oracle:", round(acc_o, 3), oracle_by_label)
|
||||||
|
|
||||||
|
random_by_label = [(x, round(y, 3)) for x, y in acc_r_d.items()]
|
||||||
|
print("dev acc random:", round(acc_r, 3), random_by_label)
|
||||||
|
|
||||||
|
prior_by_label = [(x, round(y, 3)) for x, y in acc_p_d.items()]
|
||||||
|
print("dev acc prior:", round(acc_p, 3), prior_by_label)
|
||||||
|
|
||||||
# using only context
|
# using only context
|
||||||
el_pipe.cfg["context_weight"] = 1
|
el_pipe.cfg["context_weight"] = 1
|
||||||
el_pipe.cfg["prior_weight"] = 0
|
el_pipe.cfg["prior_weight"] = 0
|
||||||
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
|
dev_acc_context, dev_acc_cont_d = _measure_acc(dev_data, el_pipe)
|
||||||
print("dev acc context avg:", round(dev_acc_context, 3),
|
context_by_label = [(x, round(y, 3)) for x, y in dev_acc_cont_d.items()]
|
||||||
[(x, round(y, 3)) for x, y in dev_acc_context_dict.items()])
|
print("dev acc context avg:", round(dev_acc_context, 3), context_by_label)
|
||||||
|
|
||||||
# measuring combined accuracy (prior + context)
|
# measuring combined accuracy (prior + context)
|
||||||
el_pipe.cfg["context_weight"] = 1
|
el_pipe.cfg["context_weight"] = 1
|
||||||
el_pipe.cfg["prior_weight"] = 1
|
el_pipe.cfg["prior_weight"] = 1
|
||||||
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe, error_analysis=False)
|
dev_acc_combo, dev_acc_combo_d = _measure_acc(dev_data, el_pipe)
|
||||||
print("dev acc combo avg:", round(dev_acc_combo, 3),
|
combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_d.items()]
|
||||||
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
|
print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label)
|
||||||
|
|
||||||
# STEP 8: apply the EL pipe on a toy example
|
# STEP 8: apply the EL pipe on a toy example
|
||||||
if to_test_pipeline:
|
if to_test_pipeline:
|
||||||
print()
|
print()
|
||||||
print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now())
|
print("STEP 8: applying Entity Linking to toy example", now())
|
||||||
print()
|
print()
|
||||||
run_el_toy_example(nlp=nlp_2)
|
run_el_toy_example(nlp=nlp_2)
|
||||||
|
|
||||||
# STEP 9: write the NLP pipeline (including entity linker) to file
|
# STEP 9: write the NLP pipeline (including entity linker) to file
|
||||||
if to_write_nlp:
|
if to_write_nlp:
|
||||||
print()
|
print()
|
||||||
print("STEP 9: testing NLP IO", datetime.datetime.now())
|
print("STEP 9: testing NLP IO", now())
|
||||||
print()
|
print()
|
||||||
print("writing to", NLP_2_DIR)
|
print("writing to", NLP_2_DIR)
|
||||||
nlp_2.to_disk(NLP_2_DIR)
|
nlp_2.to_disk(NLP_2_DIR)
|
||||||
|
@ -262,23 +300,22 @@ def run_pipeline():
|
||||||
el_pipe = nlp_3.get_pipe("entity_linker")
|
el_pipe = nlp_3.get_pipe("entity_linker")
|
||||||
|
|
||||||
dev_limit = 5000
|
dev_limit = 5000
|
||||||
dev_data = training_set_creator.read_training(nlp=nlp_2,
|
dev_data = training_set_creator.read_training(
|
||||||
training_dir=TRAINING_DIR,
|
nlp=nlp_2, training_dir=TRAINING_DIR, dev=True, limit=dev_limit, kb=None
|
||||||
dev=True,
|
)
|
||||||
limit=dev_limit)
|
|
||||||
|
|
||||||
print("Dev testing from file on", len(dev_data), "articles")
|
print("Dev testing from file on", len(dev_data), "articles")
|
||||||
print()
|
print()
|
||||||
|
|
||||||
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe=el_pipe, error_analysis=False)
|
dev_acc_combo, dev_acc_combo_dict = _measure_acc(dev_data, el_pipe)
|
||||||
print("dev acc combo avg:", round(dev_acc_combo, 3),
|
combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()]
|
||||||
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
|
print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label)
|
||||||
|
|
||||||
print()
|
print()
|
||||||
print("STOP", datetime.datetime.now())
|
print("STOP", now())
|
||||||
|
|
||||||
|
|
||||||
def _measure_accuracy(data, el_pipe=None, error_analysis=False):
|
def _measure_acc(data, el_pipe=None, error_analysis=False):
|
||||||
# If the docs in the data require further processing with an entity linker, set el_pipe
|
# If the docs in the data require further processing with an entity linker, set el_pipe
|
||||||
correct_by_label = dict()
|
correct_by_label = dict()
|
||||||
incorrect_by_label = dict()
|
incorrect_by_label = dict()
|
||||||
@@ -291,16 +328,21 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
     for doc, gold in zip(docs, golds):
         try:
             correct_entries_per_article = dict()
-            for entity in gold.links:
-                start, end, gold_kb = entity
-                correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
+            for entity, kb_dict in gold.links.items():
+                start, end = entity
+                # only evaluating on positive examples
+                for gold_kb, value in kb_dict.items():
+                    if value:
+                        offset = _offset(start, end)
+                        correct_entries_per_article[offset] = gold_kb

             for ent in doc.ents:
                 ent_label = ent.label_
                 pred_entity = ent.kb_id_
                 start = ent.start_char
                 end = ent.end_char
-                gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
+                offset = _offset(start, end)
+                gold_entity = correct_entries_per_article.get(offset, None)
                 # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
                 if gold_entity is not None:
                     if gold_entity == pred_entity:
@@ -311,7 +353,12 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):
                         incorrect_by_label[ent_label] = incorrect + 1
                         if error_analysis:
                             print(ent.text, "in", doc)
-                            print("Predicted", pred_entity, "should have been", gold_entity)
+                            print(
+                                "Predicted",
+                                pred_entity,
+                                "should have been",
+                                gold_entity,
+                            )
                             print()

         except Exception as e:
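Note: for readers tracking the gold.links change — keys are now (start_char, end_char) tuples and values map KB IDs to 1.0/0.0 flags. A self-contained sketch (offsets and QIDs invented) of how the positive gold entity per mention is recovered above:

    # hypothetical gold.links payload; 1.0 marks the positive KB entity
    links = {
        (0, 12): {"Q1": 1.0, "Q2": 0.0},
        (17, 23): {"Q3": 1.0},
    }
    correct_per_article = {}
    for (start, end), kb_dict in links.items():
        for gold_kb, value in kb_dict.items():
            if value:  # only positive examples define the gold answer
                correct_per_article["{}_{}".format(start, end)] = gold_kb
    assert correct_per_article == {"0_12": "Q1", "17_23": "Q3"}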
@@ -323,16 +370,16 @@ def _measure_accuracy(data, el_pipe=None, error_analysis=False):

 def _measure_baselines(data, kb):
     # Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
-    counts_by_label = dict()
+    counts_d = dict()

-    random_correct_by_label = dict()
-    random_incorrect_by_label = dict()
+    random_correct_d = dict()
+    random_incorrect_d = dict()

-    oracle_correct_by_label = dict()
-    oracle_incorrect_by_label = dict()
+    oracle_correct_d = dict()
+    oracle_incorrect_d = dict()

-    prior_correct_by_label = dict()
-    prior_incorrect_by_label = dict()
+    prior_correct_d = dict()
+    prior_incorrect_d = dict()

     docs = [d for d, g in data if len(d) > 0]
     golds = [g for d, g in data if len(d) > 0]
@@ -340,19 +387,24 @@ def _measure_baselines(data, kb):
     for doc, gold in zip(docs, golds):
         try:
             correct_entries_per_article = dict()
-            for entity in gold.links:
-                start, end, gold_kb = entity
-                correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
+            for entity, kb_dict in gold.links.items():
+                start, end = entity
+                for gold_kb, value in kb_dict.items():
+                    # only evaluating on positive examples
+                    if value:
+                        offset = _offset(start, end)
+                        correct_entries_per_article[offset] = gold_kb

             for ent in doc.ents:
-                ent_label = ent.label_
+                label = ent.label_
                 start = ent.start_char
                 end = ent.end_char
-                gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
+                offset = _offset(start, end)
+                gold_entity = correct_entries_per_article.get(offset, None)

                 # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
                 if gold_entity is not None:
-                    counts_by_label[ent_label] = counts_by_label.get(ent_label, 0) + 1
+                    counts_d[label] = counts_d.get(label, 0) + 1
                     candidates = kb.get_candidates(ent.text)
                     oracle_candidate = ""
                     best_candidate = ""
@@ -370,28 +422,40 @@ def _measure_baselines(data, kb):
                         random_candidate = random.choice(candidates).entity_

                     if gold_entity == best_candidate:
-                        prior_correct_by_label[ent_label] = prior_correct_by_label.get(ent_label, 0) + 1
+                        prior_correct_d[label] = prior_correct_d.get(label, 0) + 1
                     else:
-                        prior_incorrect_by_label[ent_label] = prior_incorrect_by_label.get(ent_label, 0) + 1
+                        prior_incorrect_d[label] = prior_incorrect_d.get(label, 0) + 1

                     if gold_entity == random_candidate:
-                        random_correct_by_label[ent_label] = random_correct_by_label.get(ent_label, 0) + 1
+                        random_correct_d[label] = random_correct_d.get(label, 0) + 1
                     else:
-                        random_incorrect_by_label[ent_label] = random_incorrect_by_label.get(ent_label, 0) + 1
+                        random_incorrect_d[label] = random_incorrect_d.get(label, 0) + 1

                     if gold_entity == oracle_candidate:
-                        oracle_correct_by_label[ent_label] = oracle_correct_by_label.get(ent_label, 0) + 1
+                        oracle_correct_d[label] = oracle_correct_d.get(label, 0) + 1
                     else:
-                        oracle_incorrect_by_label[ent_label] = oracle_incorrect_by_label.get(ent_label, 0) + 1
+                        oracle_incorrect_d[label] = oracle_incorrect_d.get(label, 0) + 1

         except Exception as e:
             print("Error assessing accuracy", e)

-    acc_prior, acc_prior_by_label = calculate_acc(prior_correct_by_label, prior_incorrect_by_label)
-    acc_rand, acc_rand_by_label = calculate_acc(random_correct_by_label, random_incorrect_by_label)
-    acc_oracle, acc_oracle_by_label = calculate_acc(oracle_correct_by_label, oracle_incorrect_by_label)
+    acc_prior, acc_prior_d = calculate_acc(prior_correct_d, prior_incorrect_d)
+    acc_rand, acc_rand_d = calculate_acc(random_correct_d, random_incorrect_d)
+    acc_oracle, acc_oracle_d = calculate_acc(oracle_correct_d, oracle_incorrect_d)

-    return counts_by_label, acc_rand, acc_rand_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label
+    return (
+        counts_d,
+        acc_rand,
+        acc_rand_d,
+        acc_prior,
+        acc_prior_d,
+        acc_oracle,
+        acc_oracle_d,
+    )


+def _offset(start, end):
+    return "{}_{}".format(start, end)
+
+
 def calculate_acc(correct_by_label, incorrect_by_label):
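Note: calculate_acc itself is unchanged by this commit; as a reminder of its contract, a minimal sketch (body and counts invented for illustration, the real function may differ in detail) of how per-label accuracies and an overall average fall out of such correct/incorrect dicts:

    def calculate_acc_sketch(correct_by_label, incorrect_by_label):
        acc_by_label = {}
        total_correct = sum(correct_by_label.values())
        total = total_correct + sum(incorrect_by_label.values())
        for label in set(correct_by_label) | set(incorrect_by_label):
            c = correct_by_label.get(label, 0)
            i = incorrect_by_label.get(label, 0)
            acc_by_label[label] = c / (c + i)
        return total_correct / total, acc_by_label

    acc, by_label = calculate_acc_sketch({"PERSON": 8, "GPE": 2}, {"PERSON": 2, "GPE": 3})
    assert abs(acc - 10 / 15) < 1e-9 and abs(by_label["PERSON"] - 0.8) < 1e-9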
@@ -422,15 +486,23 @@ def check_kb(kb):

         print("generating candidates for " + mention + " :")
         for c in candidates:
-            print(" ", c.prior_prob, c.alias_, "-->", c.entity_ + " (freq=" + str(c.entity_freq) + ")")
+            print(
+                " ",
+                c.prior_prob,
+                c.alias_,
+                "-->",
+                c.entity_ + " (freq=" + str(c.entity_freq) + ")",
+            )
         print()


 def run_el_toy_example(nlp):
-    text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \
-           "Douglas reminds us to always bring our towel, even in China or Brazil. " \
-           "The main character in Doug's novel is the man Arthur Dent, " \
-           "but Douglas doesn't write about George Washington or Homer Simpson."
+    text = (
+        "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, "
+        "Douglas reminds us to always bring our towel, even in China or Brazil. "
+        "The main character in Doug's novel is the man Arthur Dent, "
+        "but Dougledydoug doesn't write about George Washington or Homer Simpson."
+    )
     doc = nlp(text)
     print(text)
     for ent in doc.ents:
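Note: what the toy example exercises, as a hedged end-user sketch (assumes `nlp` already contains NER plus a trained entity_linker component):

    doc = nlp("Douglas Adams wrote The Hitchhiker's Guide to the Galaxy.")
    for ent in doc.ents:
        # ent.kb_id_ carries the linked KB identifier, or "NIL" if nothing was predicted
        print(ent.text, ent.label_, ent.kb_id_)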
@@ -663,15 +663,14 @@ def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=False,


 def build_nel_encoder(embed_width, hidden_width, ner_types, **cfg):
-    # TODO proper error
     if "entity_width" not in cfg:
-        raise ValueError("entity_width not found")
+        raise ValueError(Errors.E144.format(param="entity_width"))
     if "context_width" not in cfg:
-        raise ValueError("context_width not found")
+        raise ValueError(Errors.E144.format(param="context_width"))

     conv_depth = cfg.get("conv_depth", 2)
     cnn_maxout_pieces = cfg.get("cnn_maxout_pieces", 3)
-    pretrained_vectors = cfg.get("pretrained_vectors")  # self.nlp.vocab.vectors.name
+    pretrained_vectors = cfg.get("pretrained_vectors", None)
     context_width = cfg.get("context_width")
     entity_width = cfg.get("entity_width")
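Note: the checks above are a fail-fast guard on required config keys. The same pattern in a self-contained form (helper name and template are illustrative, not part of the diff):

    E144_TEMPLATE = "Could not find parameter `{param}` when building the entity linker model."

    def require_cfg_params(cfg, *params):
        # raise a formatted, searchable error instead of failing later with a KeyError
        for param in params:
            if param not in cfg:
                raise ValueError(E144_TEMPLATE.format(param=param))

    require_cfg_params({"entity_width": 64, "context_width": 128}, "entity_width", "context_width")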
@@ -406,7 +406,15 @@ class Errors(object):
     E141 = ("Entity vectors should be of length {required} instead of the provided {found}.")
     E142 = ("Unsupported loss_function '{loss_func}'. Use either 'L2' or 'cosine'")
     E143 = ("Labels for component '{name}' not initialized. Did you forget to call add_label()?")
+    E144 = ("Could not find parameter `{param}` when building the entity linker model.")
+    E145 = ("Error reading `{param}` from input file.")
+    E146 = ("Could not access `{path}`.")
+    E147 = ("Unexpected error in the {method} functionality of the EntityLinker: {msg}. "
+            "This is likely a bug in spaCy, so feel free to open an issue.")
+    E148 = ("Expected {ents} KB identifiers but got {ids}. Make sure that each entity in `doc.ents` "
+            "is assigned to a KB identifier.")
+    E149 = ("Error deserializing model. Check that the config used to create the "
+            "component matches the model being loaded.")


 @add_codes
 class TempErrors(object):
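Note: these are str.format templates; spaCy's @add_codes decorator prefixes each message with its code at class creation. Rendering E148 by hand, for illustration:

    template = ("Expected {ents} KB identifiers but got {ids}. Make sure that each entity in "
                "`doc.ents` is assigned to a KB identifier.")
    print(template.format(ents=3, ids=2))
    # set_annotations below raises this as ValueError(Errors.E148.format(ents=..., ids=...))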
@@ -31,7 +31,7 @@ cdef class GoldParse:
     cdef public list ents
     cdef public dict brackets
     cdef public object cats
-    cdef public list links
+    cdef public dict links

     cdef readonly list cand_to_gold
     cdef readonly list gold_to_cand
@@ -468,8 +468,11 @@ cdef class GoldParse:
             examples of a label to have the value 0.0. Labels not in the
             dictionary are treated as missing - the gradient for those labels
             will be zero.
-        links (iterable): A sequence of `(start_char, end_char, kb_id)` tuples,
-            representing the external ID of an entity in a knowledge base.
+        links (dict): A dict with `(start_char, end_char)` keys,
+            and the values being dicts with kb_id:value entries,
+            representing the external IDs in a knowledge base (KB)
+            mapped to either 1.0 or 0.0, indicating positive and
+            negative examples respectively.
         RETURNS (GoldParse): The newly constructed object.
         """
         if words is None:
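Note: a hedged usage sketch of the new links format when constructing a GoldParse (doc text, offsets and QIDs invented; assumes an `nlp` pipeline is available):

    from spacy.gold import GoldParse

    doc = nlp("Douglas Adams wrote the book.")
    # (0, 13) spans "Douglas Adams"; Q42 is the positive KB entity,
    # Q1000 an explicit negative example for the same mention.
    gold = GoldParse(doc, links={(0, 13): {"Q42": 1.0, "Q1000": 0.0}})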
spacy/kb.pxd (12 changes)
@@ -79,7 +79,7 @@ cdef class KnowledgeBase:
         return new_index


-    cdef inline int64_t c_add_entity(self, hash_t entity_hash, float prob,
+    cdef inline int64_t c_add_entity(self, hash_t entity_hash, float freq,
                                      int32_t vector_index, int feats_row) nogil:
         """Add an entry to the vector of entries.
         After calling this method, make sure to update also the _entry_index using the return value"""
@@ -92,7 +92,7 @@ cdef class KnowledgeBase:
         entry.entity_hash = entity_hash
         entry.vector_index = vector_index
         entry.feats_row = feats_row
-        entry.prob = prob
+        entry.freq = freq

         self._entries.push_back(entry)
         return new_index
@@ -125,7 +125,7 @@ cdef class KnowledgeBase:
         entry.entity_hash = dummy_hash
         entry.vector_index = dummy_value
         entry.feats_row = dummy_value
-        entry.prob = dummy_value
+        entry.freq = dummy_value

         # Avoid struct initializer to enable nogil
         cdef vector[int64_t] dummy_entry_indices
@@ -141,7 +141,7 @@ cdef class KnowledgeBase:
         self._aliases_table.push_back(alias)

     cpdef load_bulk(self, loc)
-    cpdef set_entities(self, entity_list, prob_list, vector_list)
+    cpdef set_entities(self, entity_list, freq_list, vector_list)


 cdef class Writer:
@@ -149,7 +149,7 @@ cdef class Writer:

     cdef int write_header(self, int64_t nr_entries, int64_t entity_vector_length) except -1
     cdef int write_vector_element(self, float element) except -1
-    cdef int write_entry(self, hash_t entry_hash, float entry_prob, int32_t vector_index) except -1
+    cdef int write_entry(self, hash_t entry_hash, float entry_freq, int32_t vector_index) except -1

     cdef int write_alias_length(self, int64_t alias_length) except -1
     cdef int write_alias_header(self, hash_t alias_hash, int64_t candidate_length) except -1
@@ -162,7 +162,7 @@ cdef class Reader:

     cdef int read_header(self, int64_t* nr_entries, int64_t* entity_vector_length) except -1
     cdef int read_vector_element(self, float* element) except -1
-    cdef int read_entry(self, hash_t* entity_hash, float* prob, int32_t* vector_index) except -1
+    cdef int read_entry(self, hash_t* entity_hash, float* freq, int32_t* vector_index) except -1

     cdef int read_alias_length(self, int64_t* alias_length) except -1
     cdef int read_alias_header(self, hash_t* alias_hash, int64_t* candidate_length) except -1
spacy/kb.pyx (86 changes)
@@ -94,7 +94,7 @@ cdef class KnowledgeBase:
     def get_alias_strings(self):
         return [self.vocab.strings[x] for x in self._alias_index]

-    def add_entity(self, unicode entity, float prob, vector[float] entity_vector):
+    def add_entity(self, unicode entity, float freq, vector[float] entity_vector):
         """
         Add an entity to the KB, optionally specifying its log probability based on corpus frequency
         Return the hash of the entity ID/name at the end.
@@ -113,15 +113,15 @@ cdef class KnowledgeBase:
         vector_index = self.c_add_vector(entity_vector=entity_vector)

         new_index = self.c_add_entity(entity_hash=entity_hash,
-                                      prob=prob,
+                                      freq=freq,
                                       vector_index=vector_index,
                                       feats_row=-1)  # Features table currently not implemented
         self._entry_index[entity_hash] = new_index

         return entity_hash

-    cpdef set_entities(self, entity_list, prob_list, vector_list):
-        if len(entity_list) != len(prob_list) or len(entity_list) != len(vector_list):
+    cpdef set_entities(self, entity_list, freq_list, vector_list):
+        if len(entity_list) != len(freq_list) or len(entity_list) != len(vector_list):
             raise ValueError(Errors.E140)

         nr_entities = len(entity_list)
@@ -137,7 +137,7 @@ cdef class KnowledgeBase:

             entity_hash = self.vocab.strings.add(entity_list[i])
             entry.entity_hash = entity_hash
-            entry.prob = prob_list[i]
+            entry.freq = freq_list[i]

             vector_index = self.c_add_vector(entity_vector=vector_list[i])
             entry.vector_index = vector_index
@@ -196,13 +196,42 @@ cdef class KnowledgeBase:

         return [Candidate(kb=self,
                           entity_hash=self._entries[entry_index].entity_hash,
-                          entity_freq=self._entries[entry_index].prob,
+                          entity_freq=self._entries[entry_index].freq,
                           entity_vector=self._vectors_table[self._entries[entry_index].vector_index],
                           alias_hash=alias_hash,
-                          prior_prob=prob)
-                for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs)
+                          prior_prob=prior_prob)
+                for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs)
                 if entry_index != 0]

+    def get_vector(self, unicode entity):
+        cdef hash_t entity_hash = self.vocab.strings[entity]
+
+        # Return an empty list if this entity is unknown in this KB
+        if entity_hash not in self._entry_index:
+            return [0] * self.entity_vector_length
+        entry_index = self._entry_index[entity_hash]
+
+        return self._vectors_table[self._entries[entry_index].vector_index]
+
+    def get_prior_prob(self, unicode entity, unicode alias):
+        """ Return the prior probability of a given alias being linked to a given entity,
+        or return 0.0 when this combination is not known in the knowledge base"""
+        cdef hash_t alias_hash = self.vocab.strings[alias]
+        cdef hash_t entity_hash = self.vocab.strings[entity]
+
+        if entity_hash not in self._entry_index or alias_hash not in self._alias_index:
+            return 0.0
+
+        alias_index = <int64_t>self._alias_index.get(alias_hash)
+        entry_index = self._entry_index[entity_hash]
+
+        alias_entry = self._aliases_table[alias_index]
+        for (entry_index, prior_prob) in zip(alias_entry.entry_indices, alias_entry.probs):
+            if self._entries[entry_index].entity_hash == entity_hash:
+                return prior_prob
+
+        return 0.0
+
+
     def dump(self, loc):
         cdef Writer writer = Writer(loc)
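Note: a short usage sketch of the two new lookup methods (entity/alias values invented):

    from spacy.kb import KnowledgeBase
    from spacy.lang.en import English

    nlp = English()
    kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3)
    kb.add_entity(entity="Q42", freq=0.9, entity_vector=[1, 2, 3])
    kb.add_alias(alias="Douglas", entities=["Q42"], probabilities=[0.85])

    assert kb.get_vector("Q42") == [1, 2, 3]
    assert kb.get_vector("Q999") == [0, 0, 0]          # unknown entity -> zero vector
    assert abs(kb.get_prior_prob("Q42", "Douglas") - 0.85) < 1e-4
    assert kb.get_prior_prob("Q42", "douggie") == 0.0  # unknown combination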
@@ -222,7 +251,7 @@ cdef class KnowledgeBase:
             entry = self._entries[entry_index]
             assert entry.entity_hash == entry_hash
             assert entry_index == i
-            writer.write_entry(entry.entity_hash, entry.prob, entry.vector_index)
+            writer.write_entry(entry.entity_hash, entry.freq, entry.vector_index)
             i = i+1

         writer.write_alias_length(self.get_size_aliases())
@@ -248,7 +277,7 @@ cdef class KnowledgeBase:
         cdef hash_t entity_hash
         cdef hash_t alias_hash
         cdef int64_t entry_index
-        cdef float prob
+        cdef float freq, prob
         cdef int32_t vector_index
         cdef KBEntryC entry
         cdef AliasC alias
@@ -284,10 +313,10 @@ cdef class KnowledgeBase:
         # index 0 is a dummy object not stored in the _entry_index and can be ignored.
         i = 1
         while i <= nr_entities:
-            reader.read_entry(&entity_hash, &prob, &vector_index)
+            reader.read_entry(&entity_hash, &freq, &vector_index)

             entry.entity_hash = entity_hash
-            entry.prob = prob
+            entry.freq = freq
             entry.vector_index = vector_index
             entry.feats_row = -1  # Features table currently not implemented

@@ -343,7 +372,8 @@ cdef class Writer:
         loc = bytes(loc)
         cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
         self._fp = fopen(<char*>bytes_loc, 'wb')
-        assert self._fp != NULL
+        if not self._fp:
+            raise IOError(Errors.E146.format(path=loc))
         fseek(self._fp, 0, 0)

     def close(self):
@@ -357,9 +387,9 @@ cdef class Writer:
     cdef int write_vector_element(self, float element) except -1:
         self._write(&element, sizeof(element))

-    cdef int write_entry(self, hash_t entry_hash, float entry_prob, int32_t vector_index) except -1:
+    cdef int write_entry(self, hash_t entry_hash, float entry_freq, int32_t vector_index) except -1:
         self._write(&entry_hash, sizeof(entry_hash))
-        self._write(&entry_prob, sizeof(entry_prob))
+        self._write(&entry_freq, sizeof(entry_freq))
         self._write(&vector_index, sizeof(vector_index))
         # Features table currently not implemented and not written to file

@@ -399,39 +429,39 @@ cdef class Reader:
         if status < 1:
             if feof(self._fp):
                 return 0  # end of file
-            raise IOError("error reading header from input file")
+            raise IOError(Errors.E145.format(param="header"))

         status = self._read(entity_vector_length, sizeof(int64_t))
         if status < 1:
             if feof(self._fp):
                 return 0  # end of file
-            raise IOError("error reading header from input file")
+            raise IOError(Errors.E145.format(param="vector length"))

     cdef int read_vector_element(self, float* element) except -1:
         status = self._read(element, sizeof(float))
         if status < 1:
             if feof(self._fp):
                 return 0  # end of file
-            raise IOError("error reading entity vector from input file")
+            raise IOError(Errors.E145.format(param="vector element"))

-    cdef int read_entry(self, hash_t* entity_hash, float* prob, int32_t* vector_index) except -1:
+    cdef int read_entry(self, hash_t* entity_hash, float* freq, int32_t* vector_index) except -1:
         status = self._read(entity_hash, sizeof(hash_t))
         if status < 1:
             if feof(self._fp):
                 return 0  # end of file
-            raise IOError("error reading entity hash from input file")
+            raise IOError(Errors.E145.format(param="entity hash"))

-        status = self._read(prob, sizeof(float))
+        status = self._read(freq, sizeof(float))
         if status < 1:
             if feof(self._fp):
                 return 0  # end of file
-            raise IOError("error reading entity prob from input file")
+            raise IOError(Errors.E145.format(param="entity freq"))

         status = self._read(vector_index, sizeof(int32_t))
         if status < 1:
             if feof(self._fp):
                 return 0  # end of file
-            raise IOError("error reading entity vector from input file")
+            raise IOError(Errors.E145.format(param="vector index"))

         if feof(self._fp):
             return 0
@@ -443,33 +473,33 @@ cdef class Reader:
         if status < 1:
             if feof(self._fp):
                 return 0  # end of file
-            raise IOError("error reading alias length from input file")
+            raise IOError(Errors.E145.format(param="alias length"))

     cdef int read_alias_header(self, hash_t* alias_hash, int64_t* candidate_length) except -1:
         status = self._read(alias_hash, sizeof(hash_t))
         if status < 1:
             if feof(self._fp):
                 return 0  # end of file
-            raise IOError("error reading alias hash from input file")
+            raise IOError(Errors.E145.format(param="alias hash"))

         status = self._read(candidate_length, sizeof(int64_t))
         if status < 1:
             if feof(self._fp):
                 return 0  # end of file
-            raise IOError("error reading candidate length from input file")
+            raise IOError(Errors.E145.format(param="candidate length"))

     cdef int read_alias(self, int64_t* entry_index, float* prob) except -1:
         status = self._read(entry_index, sizeof(int64_t))
         if status < 1:
             if feof(self._fp):
                 return 0  # end of file
-            raise IOError("error reading entry index for alias from input file")
+            raise IOError(Errors.E145.format(param="entry index"))

         status = self._read(prob, sizeof(float))
         if status < 1:
             if feof(self._fp):
                 return 0  # end of file
-            raise IOError("error reading prob for entity/alias from input file")
+            raise IOError(Errors.E145.format(param="prior probability"))

     cdef int _read(self, void* value, size_t size) except -1:
         status = fread(value, size, 1, self._fp)
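Note: Writer/Reader implement the KB's flat binary dump/load; with the E145/E146 codes, truncated or inaccessible files now fail with actionable errors. A hedged round-trip sketch via the public API (temp path invented):

    import os, tempfile
    from spacy.kb import KnowledgeBase
    from spacy.lang.en import English

    nlp = English()
    kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=1)
    kb.add_entity(entity="Q1", freq=0.9, entity_vector=[1])
    kb.add_alias(alias="one", entities=["Q1"], probabilities=[0.8])

    loc = os.path.join(tempfile.mkdtemp(), "kb")
    kb.dump(loc)  # Writer: header, entries, aliases

    kb2 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=1)
    kb2.load_bulk(loc)  # Reader: raises IOError(E145...) on truncated input
    assert kb2.get_size_entities() == 1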
@@ -12,10 +12,6 @@ from ...language import Language
 from ...attrs import LANG, NORM
 from ...util import update_exc, add_lookups

-# Borrowing french syntax parser because both languages use
-# universal dependencies for tagging/parsing.
-# Read here for more:
-# https://github.com/explosion/spaCy/pull/1882#issuecomment-361409573
 from .syntax_iterators import SYNTAX_ITERATORS
@@ -14,7 +14,6 @@ from thinc.neural.util import to_categorical
 from thinc.neural.util import get_array_module

 from spacy.kb import KnowledgeBase
-from ..cli.pretrain import get_cossim_loss
 from .functions import merge_subtokens
 from ..tokens.doc cimport Doc
 from ..syntax.nn_parser cimport Parser
@@ -168,7 +167,10 @@ class Pipe(object):
                 self.cfg["pretrained_vectors"] = self.vocab.vectors.name
             if self.model is True:
                 self.model = self.Model(**self.cfg)
-            self.model.from_bytes(b)
+            try:
+                self.model.from_bytes(b)
+            except AttributeError:
+                raise ValueError(Errors.E149)

         deserialize = OrderedDict()
         deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
@@ -197,7 +199,10 @@ class Pipe(object):
                 self.cfg["pretrained_vectors"] = self.vocab.vectors.name
             if self.model is True:
                 self.model = self.Model(**self.cfg)
-            self.model.from_bytes(p.open("rb").read())
+            try:
+                self.model.from_bytes(p.open("rb").read())
+            except AttributeError:
+                raise ValueError(Errors.E149)

         deserialize = OrderedDict()
         deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p))
@@ -563,7 +568,10 @@ class Tagger(Pipe):
                     "token_vector_width",
                     self.cfg.get("token_vector_width", 96))
                 self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg)
-            self.model.from_bytes(b)
+            try:
+                self.model.from_bytes(b)
+            except AttributeError:
+                raise ValueError(Errors.E149)

         def load_tag_map(b):
             tag_map = srsly.msgpack_loads(b)
@@ -601,7 +609,10 @@ class Tagger(Pipe):
             if self.model is True:
                 self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg)
             with p.open("rb") as file_:
-                self.model.from_bytes(file_.read())
+                try:
+                    self.model.from_bytes(file_.read())
+                except AttributeError:
+                    raise ValueError(Errors.E149)

         def load_tag_map(p):
             tag_map = srsly.read_msgpack(p)
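Note: the recurring try/except converts the bare AttributeError that thinc raises when stored weights don't match the constructed architecture into the coded E149 message. The idea, in a generic self-contained form (helper name invented):

    def safe_from_bytes(model, data, error):
        try:
            model.from_bytes(data)
        except AttributeError:
            # config/architecture mismatch; surface an actionable error instead
            raise ValueError(error)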
@@ -1077,6 +1088,7 @@ class EntityLinker(Pipe):
     DOCS: TODO
     """
     name = 'entity_linker'
+    NIL = "NIL"  # string used to refer to a non-existing link

     @classmethod
     def Model(cls, **cfg):
@@ -1093,6 +1105,8 @@ class EntityLinker(Pipe):
         self.kb = None
         self.cfg = dict(cfg)
         self.sgd_context = None
+        if not self.cfg.get("context_width"):
+            self.cfg["context_width"] = 128

     def set_kb(self, kb):
         self.kb = kb
@@ -1140,7 +1154,7 @@ class EntityLinker(Pipe):

         context_docs = []
         entity_encodings = []
-        cats = []
         priors = []
         type_vectors = []
@@ -1149,50 +1163,44 @@ class EntityLinker(Pipe):
         for doc, gold in zip(docs, golds):
             ents_by_offset = dict()
             for ent in doc.ents:
-                ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
-            for entity in gold.links:
-                start, end, gold_kb = entity
+                ents_by_offset["{}_{}".format(ent.start_char, ent.end_char)] = ent
+            for entity, kb_dict in gold.links.items():
+                start, end = entity
                 mention = doc.text[start:end]
+                for kb_id, value in kb_dict.items():
+                    entity_encoding = self.kb.get_vector(kb_id)
+                    prior_prob = self.kb.get_prior_prob(kb_id, mention)

-                gold_ent = ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)]
-                assert gold_ent is not None
-                type_vector = [0 for i in range(len(type_to_int))]
-                if len(type_to_int) > 0:
-                    type_vector[type_to_int[gold_ent.label_]] = 1
+                    gold_ent = ents_by_offset["{}_{}".format(start, end)]
+                    if gold_ent is None:
+                        raise RuntimeError(Errors.E147.format(method="update", msg="gold entity not found"))

-                candidates = self.kb.get_candidates(mention)
-                random.shuffle(candidates)
-                nr_neg = 0
-                for c in candidates:
-                    kb_id = c.entity_
-                    entity_encoding = c.entity_vector
+                    type_vector = [0 for i in range(len(type_to_int))]
+                    if len(type_to_int) > 0:
+                        type_vector[type_to_int[gold_ent.label_]] = 1
+
+                    # store data
                     entity_encodings.append(entity_encoding)
                     context_docs.append(doc)
                     type_vectors.append(type_vector)

                     if self.cfg.get("prior_weight", 1) > 0:
-                        priors.append([c.prior_prob])
+                        priors.append([prior_prob])
                     else:
                         priors.append([0])

-                    if kb_id == gold_kb:
-                        cats.append([1])
-                    else:
-                        nr_neg += 1
-                        cats.append([0])
-
         if len(entity_encodings) > 0:
-            assert len(priors) == len(entity_encodings) == len(context_docs) == len(cats) == len(type_vectors)
+            if not (len(priors) == len(entity_encodings) == len(context_docs) == len(type_vectors)):
+                raise RuntimeError(Errors.E147.format(method="update", msg="vector lengths not equal"))

-            context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop)
             entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")

+            context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop)
             mention_encodings = [list(context_encodings[i]) + list(entity_encodings[i]) + priors[i] + type_vectors[i]
                                  for i in range(len(entity_encodings))]
             pred, bp_mention = self.model.begin_update(self.model.ops.asarray(mention_encodings, dtype="float32"), drop=drop)
-            cats = self.model.ops.asarray(cats, dtype="float32")

-            loss, d_scores = self.get_loss(prediction=pred, golds=cats, docs=None)
+            loss, d_scores = self.get_loss(scores=pred, golds=golds, docs=docs)
             mention_gradient = bp_mention(d_scores, sgd=sgd)

             context_gradients = [list(x[0:self.cfg.get("context_width")]) for x in mention_gradient]
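Note: each training example is the concatenation [context encoding | entity encoding | prior | NER-type one-hot]. A toy sketch of the assembly (all numbers invented):

    import numpy

    context_encoding = numpy.array([0.1, 0.2])      # from the tok2vec context model
    entity_encoding = numpy.array([0.3, 0.4, 0.5])  # from kb.get_vector(kb_id)
    prior = [0.85]                                  # kb.get_prior_prob(kb_id, mention)
    type_vector = [0, 1, 0]                         # one-hot NER type via type_to_int

    mention_encoding = list(context_encoding) + list(entity_encoding) + prior + type_vector
    assert len(mention_encoding) == 2 + 3 + 1 + 3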
@@ -1203,39 +1211,45 @@ class EntityLinker(Pipe):
             return loss
         return 0

-    def get_loss(self, docs, golds, prediction):
-        d_scores = (prediction - golds)
+    def get_loss(self, docs, golds, scores):
+        cats = []
+        for gold in golds:
+            for entity, kb_dict in gold.links.items():
+                for kb_id, value in kb_dict.items():
+                    cats.append([value])
+
+        cats = self.model.ops.asarray(cats, dtype="float32")
+        if len(scores) != len(cats):
+            raise RuntimeError(Errors.E147.format(method="get_loss", msg="gold entities do not match up"))
+
+        d_scores = (scores - cats)
         loss = (d_scores ** 2).sum()
-        loss = loss / len(golds)
+        loss = loss / len(cats)
         return loss, d_scores

-    def get_loss_old(self, docs, golds, scores):
-        # this loss function assumes we're only using positive examples
-        loss, gradients = get_cossim_loss(yh=scores, y=golds)
-        loss = loss / len(golds)
-        return loss, gradients
-
     def __call__(self, doc):
-        entities, kb_ids = self.predict([doc])
-        self.set_annotations([doc], entities, kb_ids)
+        kb_ids, tensors = self.predict([doc])
+        self.set_annotations([doc], kb_ids, tensors=tensors)
         return doc

     def pipe(self, stream, batch_size=128, n_threads=-1):
         for docs in util.minibatch(stream, size=batch_size):
             docs = list(docs)
-            entities, kb_ids = self.predict(docs)
-            self.set_annotations(docs, entities, kb_ids)
+            kb_ids, tensors = self.predict(docs)
+            self.set_annotations(docs, kb_ids, tensors=tensors)
             yield from docs

     def predict(self, docs):
+        """ Return the KB IDs for each entity in each doc, including NIL if there is no prediction """
         self.require_model()
         self.require_kb()

-        final_entities = []
+        entity_count = 0
         final_kb_ids = []
+        final_tensors = []

         if not docs:
-            return final_entities, final_kb_ids
+            return final_kb_ids, final_tensors

         if isinstance(docs, Doc):
             docs = [docs]
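Note: the new get_loss is a plain mean squared error between the per-candidate scores and the 1.0/0.0 flags gathered from gold.links; d_scores doubles as the gradient. Numerically (toy values):

    import numpy

    scores = numpy.array([[0.9], [0.2], [0.4]], dtype="float32")  # model output
    cats = numpy.array([[1.0], [0.0], [1.0]], dtype="float32")    # gold flags

    d_scores = scores - cats
    loss = (d_scores ** 2).sum() / len(cats)
    assert abs(loss - (0.01 + 0.04 + 0.36) / 3) < 1e-5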
@@ -1247,14 +1261,19 @@ class EntityLinker(Pipe):

         for i, doc in enumerate(docs):
             if len(doc) > 0:
+                # currently, the context is the same for each entity in a sentence (should be refined)
                 context_encoding = context_encodings[i]
                 for ent in doc.ents:
+                    entity_count += 1
                     type_vector = [0 for i in range(len(type_to_int))]
                     if len(type_to_int) > 0:
                         type_vector[type_to_int[ent.label_]] = 1

                     candidates = self.kb.get_candidates(ent.text)
-                    if candidates:
+                    if not candidates:
+                        final_kb_ids.append(self.NIL)  # no prediction possible for this entity
+                        final_tensors.append(context_encoding)
+                    else:
                         random.shuffle(candidates)

                         # this will set the prior probabilities to 0 (just like in training) if their weight is 0
@@ -1264,7 +1283,9 @@ class EntityLinker(Pipe):

                         if self.cfg.get("context_weight", 1) > 0:
                             entity_encodings = xp.asarray([c.entity_vector for c in candidates])
-                            assert len(entity_encodings) == len(prior_probs)
+                            if len(entity_encodings) != len(prior_probs):
+                                raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length"))
+
                             mention_encodings = [list(context_encoding) + list(entity_encodings[i])
                                                  + list(prior_probs[i]) + type_vector
                                                  for i in range(len(entity_encodings))]
@@ -1273,14 +1294,25 @@ class EntityLinker(Pipe):
                         # TODO: thresholding
                         best_index = scores.argmax()
                         best_candidate = candidates[best_index]
-                        final_entities.append(ent)
                         final_kb_ids.append(best_candidate.entity_)
+                        final_tensors.append(context_encoding)

-        return final_entities, final_kb_ids
+        if not (len(final_tensors) == len(final_kb_ids) == entity_count):
+            raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length"))

-    def set_annotations(self, docs, entities, kb_ids=None):
-        for entity, kb_id in zip(entities, kb_ids):
-            for token in entity:
+        return final_kb_ids, final_tensors
+
+    def set_annotations(self, docs, kb_ids, tensors=None):
+        count_ents = len([ent for doc in docs for ent in doc.ents])
+        if count_ents != len(kb_ids):
+            raise ValueError(Errors.E148.format(ents=count_ents, ids=len(kb_ids)))
+
+        i=0
+        for doc in docs:
+            for ent in doc.ents:
+                kb_id = kb_ids[i]
+                i += 1
+                for token in ent:
                     token.ent_kb_id_ = kb_id

     def to_disk(self, path, exclude=tuple(), **kwargs):
@@ -1297,7 +1329,10 @@ class EntityLinker(Pipe):
         def load_model(p):
             if self.model is True:
                 self.model = self.Model(**self.cfg)
-            self.model.from_bytes(p.open("rb").read())
+            try:
+                self.model.from_bytes(p.open("rb").read())
+            except AttributeError:
+                raise ValueError(Errors.E149)

         def load_kb(p):
             kb = KnowledgeBase(vocab=self.vocab, entity_vector_length=self.cfg["entity_width"])
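Note: the reworked predict/set_annotations contract guarantees exactly one KB ID (possibly NIL) per entity, enforced by the entity_count and E148 checks. A hedged component-level sketch (assumes a configured `el_pipe` and pipeline `nlp`):

    doc = nlp("Boston is a city.")
    kb_ids, tensors = el_pipe.predict([doc])
    el_pipe.set_annotations([doc], kb_ids, tensors=tensors)
    for ent in doc.ents:
        print(ent.text, ent.kb_id_)  # "NIL" when the KB offered no candidates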
@@ -93,7 +93,7 @@ cdef struct KBEntryC:
     int32_t feats_row

     # log probability of entity, based on corpus frequency
-    float prob
+    float freq


 # Each alias struct stores a list of Entry pointers with their prior probabilities
@@ -631,7 +631,10 @@ cdef class Parser:
             cfg = {}
         with (path / 'model').open('rb') as file_:
             bytes_data = file_.read()
-        self.model.from_bytes(bytes_data)
+        try:
+            self.model.from_bytes(bytes_data)
+        except AttributeError:
+            raise ValueError(Errors.E149)
         self.cfg.update(cfg)
         return self

@@ -663,6 +666,9 @@ cdef class Parser:
         else:
             cfg = {}
         if 'model' in msg:
-            self.model.from_bytes(msg['model'])
+            try:
+                self.model.from_bytes(msg['model'])
+            except AttributeError:
+                raise ValueError(Errors.E149)
         self.cfg.update(cfg)
         return self
@@ -13,22 +13,38 @@ def nlp():
     return English()


+def assert_almost_equal(a, b):
+    delta = 0.0001
+    assert a - delta <= b <= a + delta
+
+
 def test_kb_valid_entities(nlp):
     """Test the valid construction of a KB with 3 entities and two aliases"""
-    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)

     # adding entities
-    mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
-    mykb.add_entity(entity='Q2', prob=0.5, entity_vector=[2])
-    mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
+    mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[8, 4, 3])
+    mykb.add_entity(entity="Q2", freq=0.5, entity_vector=[2, 1, 0])
+    mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[-1, -6, 5])

     # adding aliases
-    mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.2])
-    mykb.add_alias(alias='adam', entities=['Q2'], probabilities=[0.9])
+    mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.2])
+    mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])

     # test the size of the corresponding KB
-    assert(mykb.get_size_entities() == 3)
-    assert(mykb.get_size_aliases() == 2)
+    assert mykb.get_size_entities() == 3
+    assert mykb.get_size_aliases() == 2
+
+    # test retrieval of the entity vectors
+    assert mykb.get_vector("Q1") == [8, 4, 3]
+    assert mykb.get_vector("Q2") == [2, 1, 0]
+    assert mykb.get_vector("Q3") == [-1, -6, 5]
+
+    # test retrieval of prior probabilities
+    assert_almost_equal(mykb.get_prior_prob(entity="Q2", alias="douglas"), 0.8)
+    assert_almost_equal(mykb.get_prior_prob(entity="Q3", alias="douglas"), 0.2)
+    assert_almost_equal(mykb.get_prior_prob(entity="Q342", alias="douglas"), 0.0)
+    assert_almost_equal(mykb.get_prior_prob(entity="Q3", alias="douglassssss"), 0.0)


 def test_kb_invalid_entities(nlp):
|
@ -36,13 +52,15 @@ def test_kb_invalid_entities(nlp):
|
||||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||||
|
|
||||||
# adding entities
|
# adding entities
|
||||||
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
|
mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1])
|
||||||
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
|
mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
|
||||||
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
|
mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3])
|
||||||
|
|
||||||
# adding aliases - should fail because one of the given IDs is not valid
|
# adding aliases - should fail because one of the given IDs is not valid
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
mykb.add_alias(alias='douglas', entities=['Q2', 'Q342'], probabilities=[0.8, 0.2])
|
mykb.add_alias(
|
||||||
|
alias="douglas", entities=["Q2", "Q342"], probabilities=[0.8, 0.2]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_kb_invalid_probabilities(nlp):
|
def test_kb_invalid_probabilities(nlp):
|
||||||
|
@ -50,13 +68,13 @@ def test_kb_invalid_probabilities(nlp):
|
||||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||||
|
|
||||||
# adding entities
|
# adding entities
|
||||||
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
|
mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1])
|
||||||
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
|
mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
|
||||||
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
|
mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3])
|
||||||
|
|
||||||
# adding aliases - should fail because the sum of the probabilities exceeds 1
|
# adding aliases - should fail because the sum of the probabilities exceeds 1
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.4])
|
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.4])
|
||||||
|
|
||||||
|
|
||||||
def test_kb_invalid_combination(nlp):
|
def test_kb_invalid_combination(nlp):
|
||||||
|
@ -64,13 +82,15 @@ def test_kb_invalid_combination(nlp):
|
||||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||||
|
|
||||||
# adding entities
|
# adding entities
|
||||||
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
|
mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1])
|
||||||
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
|
mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
|
||||||
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
|
mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3])
|
||||||
|
|
||||||
# adding aliases - should fail because the entities and probabilities vectors are not of equal length
|
# adding aliases - should fail because the entities and probabilities vectors are not of equal length
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.3, 0.4, 0.1])
|
mykb.add_alias(
|
||||||
|
alias="douglas", entities=["Q2", "Q3"], probabilities=[0.3, 0.4, 0.1]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_kb_invalid_entity_vector(nlp):
|
def test_kb_invalid_entity_vector(nlp):
|
||||||
|
@ -78,11 +98,11 @@ def test_kb_invalid_entity_vector(nlp):
|
||||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
|
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
|
||||||
|
|
||||||
# adding entities
|
# adding entities
|
||||||
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1, 2, 3])
|
mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1, 2, 3])
|
||||||
|
|
||||||
# this should fail because the kb's expected entity vector length is 3
|
# this should fail because the kb's expected entity vector length is 3
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
|
mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
|
||||||
|
|
||||||
|
|
||||||
def test_candidate_generation(nlp):
|
def test_candidate_generation(nlp):
|
||||||
|
@ -90,18 +110,24 @@ def test_candidate_generation(nlp):
|
||||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||||
|
|
||||||
# adding entities
|
# adding entities
|
||||||
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
|
mykb.add_entity(entity="Q1", freq=0.7, entity_vector=[1])
|
||||||
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
|
mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2])
|
||||||
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
|
mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3])
|
||||||
|
|
||||||
# adding aliases
|
# adding aliases
|
||||||
mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.2])
|
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1])
|
||||||
mykb.add_alias(alias='adam', entities=['Q2'], probabilities=[0.9])
|
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
|
||||||
|
|
||||||
# test the size of the relevant candidates
|
# test the size of the relevant candidates
|
||||||
assert(len(mykb.get_candidates('douglas')) == 2)
|
assert len(mykb.get_candidates("douglas")) == 2
|
||||||
assert(len(mykb.get_candidates('adam')) == 1)
|
assert len(mykb.get_candidates("adam")) == 1
|
||||||
assert(len(mykb.get_candidates('shrubbery')) == 0)
|
assert len(mykb.get_candidates("shrubbery")) == 0
|
||||||
|
|
||||||
|
# test the content of the candidates
|
||||||
|
assert mykb.get_candidates("adam")[0].entity_ == "Q2"
|
||||||
|
assert mykb.get_candidates("adam")[0].alias_ == "adam"
|
||||||
|
assert_almost_equal(mykb.get_candidates("adam")[0].entity_freq, 0.2)
|
||||||
|
assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9)
|
||||||
|
|
||||||
|
|
||||||
 def test_preserving_links_asdoc(nlp):
@@ -109,24 +135,26 @@ def test_preserving_links_asdoc(nlp):
     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

     # adding entities
-    mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
-    mykb.add_entity(entity='Q2', prob=0.8, entity_vector=[1])
+    mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1])
+    mykb.add_entity(entity="Q2", freq=0.8, entity_vector=[1])

     # adding aliases
-    mykb.add_alias(alias='Boston', entities=['Q1'], probabilities=[0.7])
-    mykb.add_alias(alias='Denver', entities=['Q2'], probabilities=[0.6])
+    mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7])
+    mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6])

     # set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained)
     sentencizer = nlp.create_pipe("sentencizer")
     nlp.add_pipe(sentencizer)

     ruler = EntityRuler(nlp)
-    patterns = [{"label": "GPE", "pattern": "Boston"},
-                {"label": "GPE", "pattern": "Denver"}]
+    patterns = [
+        {"label": "GPE", "pattern": "Boston"},
+        {"label": "GPE", "pattern": "Denver"},
+    ]
     ruler.add_patterns(patterns)
     nlp.add_pipe(ruler)

-    el_pipe = nlp.create_pipe(name='entity_linker', config={"context_width": 64})
+    el_pipe = nlp.create_pipe(name="entity_linker", config={"context_width": 64})
     el_pipe.set_kb(mykb)
     el_pipe.begin_training()
     el_pipe.context_weight = 0
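Reader's note: once such a pipeline is assembled, the linker writes its decisions onto the entity spans. A hedged sketch (the `entity_linker` pipe and `Span.kb_id_` were experimental at this commit, so names may differ in other versions):

    nlp.add_pipe(el_pipe)
    doc = nlp("Boston is a city.")
    for ent in doc.ents:
        # kb_id_ holds the linked knowledge-base identifier, e.g. "Q1"
        print(ent.text, ent.label_, ent.kb_id_)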
spacy/tests/regression/test_issue3962.py (new file, 112 lines)
@@ -0,0 +1,112 @@
+# coding: utf8
+from __future__ import unicode_literals
+
+import pytest
+
+from ..util import get_doc
+
+
+@pytest.fixture
+def doc(en_tokenizer):
+    text = "He jests at scars, that never felt a wound."
+    heads = [1, 6, -1, -1, 3, 2, 1, 0, 1, -2, -3]
+    deps = [
+        "nsubj",
+        "ccomp",
+        "prep",
+        "pobj",
+        "punct",
+        "nsubj",
+        "neg",
+        "ROOT",
+        "det",
+        "dobj",
+        "punct",
+    ]
+    tokens = en_tokenizer(text)
+    return get_doc(tokens.vocab, words=[t.text for t in tokens], heads=heads, deps=deps)
+
+
+def test_issue3962(doc):
+    """ Ensure that as_doc does not result in out-of-bound access of tokens.
+    This is achieved by setting the head to itself if it would lie out of the span otherwise."""
+    span2 = doc[1:5]  # "jests at scars ,"
+    doc2 = span2.as_doc()
+    doc2_json = doc2.to_json()
+    assert doc2_json
+
+    assert doc2[0].head.text == "jests"  # head set to itself, being the new artificial root
+    assert doc2[0].dep_ == "dep"
+    assert doc2[1].head.text == "jests"
+    assert doc2[1].dep_ == "prep"
+    assert doc2[2].head.text == "at"
+    assert doc2[2].dep_ == "pobj"
+    assert doc2[3].head.text == "jests"  # head set to the new artificial root
+    assert doc2[3].dep_ == "dep"
+
+    # We should still have 1 sentence
+    assert len(list(doc2.sents)) == 1
+
+    span3 = doc[6:9]  # "never felt a"
+    doc3 = span3.as_doc()
+    doc3_json = doc3.to_json()
+    assert doc3_json
+
+    assert doc3[0].head.text == "felt"
+    assert doc3[0].dep_ == "neg"
+    assert doc3[1].head.text == "felt"
+    assert doc3[1].dep_ == "ROOT"
+    assert doc3[2].head.text == "felt"  # head set to ancestor
+    assert doc3[2].dep_ == "dep"
+
+    # We should still have 1 sentence as "a" can be attached to "felt" instead of "wound"
+    assert len(list(doc3.sents)) == 1
+
+
+@pytest.fixture
+def two_sent_doc(en_tokenizer):
+    text = "He jests at scars. They never felt a wound."
+    heads = [1, 0, -1, -1, -3, 2, 1, 0, 1, -2, -3]
+    deps = [
+        "nsubj",
+        "ROOT",
+        "prep",
+        "pobj",
+        "punct",
+        "nsubj",
+        "neg",
+        "ROOT",
+        "det",
+        "dobj",
+        "punct",
+    ]
+    tokens = en_tokenizer(text)
+    return get_doc(tokens.vocab, words=[t.text for t in tokens], heads=heads, deps=deps)
+
+
+def test_issue3962_long(two_sent_doc):
+    """ Ensure that as_doc does not result in out-of-bound access of tokens.
+    This is achieved by setting the head to itself if it would lie out of the span otherwise."""
+    span2 = two_sent_doc[1:7]  # "jests at scars. They never"
+    doc2 = span2.as_doc()
+    doc2_json = doc2.to_json()
+    assert doc2_json
+
+    assert doc2[0].head.text == "jests"  # head set to itself, being the new artificial root (in sentence 1)
+    assert doc2[0].dep_ == "ROOT"
+    assert doc2[1].head.text == "jests"
+    assert doc2[1].dep_ == "prep"
+    assert doc2[2].head.text == "at"
+    assert doc2[2].dep_ == "pobj"
+    assert doc2[3].head.text == "jests"
+    assert doc2[3].dep_ == "punct"
+    assert doc2[4].head.text == "They"  # head set to itself, being the new artificial root (in sentence 2)
+    assert doc2[4].dep_ == "dep"
+    assert doc2[5].head.text == "They"  # head set to the new artificial head (in sentence 2)
+    assert doc2[5].dep_ == "dep"
+
+    # We should still have 2 sentences
+    sents = list(doc2.sents)
+    assert len(sents) == 2
+    assert sents[0].text == "jests at scars ."
+    assert sents[1].text == "They never"
spacy/tests/regression/test_issue4002.py (new file, 28 lines)
@@ -0,0 +1,28 @@
+# coding: utf8
+from __future__ import unicode_literals
+
+import pytest
+from spacy.matcher import PhraseMatcher
+from spacy.tokens import Doc
+
+
+@pytest.mark.xfail
+def test_issue4002(en_vocab):
+    """Test that the PhraseMatcher can match on overwritten NORM attributes."""
+    matcher = PhraseMatcher(en_vocab, attr="NORM")
+    pattern1 = Doc(en_vocab, words=["c", "d"])
+    assert [t.norm_ for t in pattern1] == ["c", "d"]
+    matcher.add("TEST", None, pattern1)
+    doc = Doc(en_vocab, words=["a", "b", "c", "d"])
+    assert [t.norm_ for t in doc] == ["a", "b", "c", "d"]
+    matches = matcher(doc)
+    assert len(matches) == 1
+    matcher = PhraseMatcher(en_vocab, attr="NORM")
+    pattern2 = Doc(en_vocab, words=["1", "2"])
+    pattern2[0].norm_ = "c"
+    pattern2[1].norm_ = "d"
+    assert [t.norm_ for t in pattern2] == ["c", "d"]
+    matcher.add("TEST", None, pattern2)
+    matches = matcher(doc)
+    assert len(matches) == 1
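Reader's note: the `xfail` marks a known gap at this commit: a `PhraseMatcher(attr="NORM")` pattern does not yet see `norm_` values overwritten after pattern creation. The intended usage, as the test spells out:

    from spacy.lang.en import English
    from spacy.matcher import PhraseMatcher
    from spacy.tokens import Doc

    nlp = English()
    matcher = PhraseMatcher(nlp.vocab, attr="NORM")
    pattern = Doc(nlp.vocab, words=["c", "d"])
    matcher.add("TEST", None, pattern)  # v2.x signature: add(key, on_match, *docs)
    doc = Doc(nlp.vocab, words=["a", "b", "c", "d"])
    print(matcher(doc))  # one match over tokens 2..4, found via the NORM attribute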
@@ -30,10 +30,10 @@ def test_serialize_kb_disk(en_vocab):

 def _get_dummy_kb(vocab):
     kb = KnowledgeBase(vocab=vocab, entity_vector_length=3)

-    kb.add_entity(entity='Q53', prob=0.33, entity_vector=[0, 5, 3])
-    kb.add_entity(entity='Q17', prob=0.2, entity_vector=[7, 1, 0])
-    kb.add_entity(entity='Q007', prob=0.7, entity_vector=[0, 0, 7])
-    kb.add_entity(entity='Q44', prob=0.4, entity_vector=[4, 4, 4])
+    kb.add_entity(entity='Q53', freq=0.33, entity_vector=[0, 5, 3])
+    kb.add_entity(entity='Q17', freq=0.2, entity_vector=[7, 1, 0])
+    kb.add_entity(entity='Q007', freq=0.7, entity_vector=[0, 0, 7])
+    kb.add_entity(entity='Q44', freq=0.4, entity_vector=[4, 4, 4])

     kb.add_alias(alias='double07', entities=['Q17', 'Q007'], probabilities=[0.1, 0.9])
     kb.add_alias(alias='guy', entities=['Q53', 'Q007', 'Q17', 'Q44'], probabilities=[0.3, 0.3, 0.2, 0.1])
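Reader's note: the surrounding test round-trips this dummy KB through disk; roughly as below (the method names `dump`/`load_bulk` belong to the experimental v2.x API and may have changed since, and the path is illustrative):

    kb = _get_dummy_kb(en_vocab)
    kb.dump("/tmp/kb")  # writes entities, aliases and vectors
    kb2 = KnowledgeBase(vocab=en_vocab, entity_vector_length=3)
    kb2.load_bulk("/tmp/kb")
    assert kb2.get_size_entities() == kb.get_size_entities()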
@@ -348,7 +348,7 @@ cdef class Tokenizer:
         """Add a special-case tokenization rule.

         string (unicode): The string to specially tokenize.
-        token_attrs (iterable): A sequence of dicts, where each dict describes
+        substrings (iterable): A sequence of dicts, where each dict describes
             a token and its attributes. The `ORTH` fields of the attributes
             must exactly match the string when they are concatenated.
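Reader's note: the renamed `substrings` argument is the sequence of per-token attribute dicts; behaviour is unchanged. Standard usage:

    from spacy.attrs import ORTH, NORM
    from spacy.lang.en import English

    nlp = English()
    # the ORTH values must concatenate back to exactly the original string
    nlp.tokenizer.add_special_case("gimme", [{ORTH: "gim", NORM: "give"}, {ORTH: "me"}])
    print([t.text for t in nlp("gimme that")])  # ['gim', 'me', 'that']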
@@ -794,7 +794,7 @@ cdef class Doc:
             if array[i, col] != 0:
                 self.vocab.morphology.assign_tag(&tokens[i], array[i, col])
         # Now load the data
-        for i in range(self.length):
+        for i in range(length):
             token = &self.c[i]
             for j in range(n_attrs):
                 if attr_ids[j] != TAG:
@@ -804,7 +804,7 @@ cdef class Doc:
         self.is_tagged = bool(self.is_tagged or TAG in attrs or POS in attrs)
         # If document is parsed, set children
         if self.is_parsed:
-            set_children_from_heads(self.c, self.length)
+            set_children_from_heads(self.c, length)
         return self

     def get_lca_matrix(self):
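Reader's note: the fix makes the load loop follow the length of the incoming array (now that `Span.as_doc` passes a slice) instead of the doc's own length. The documented round-trip this loop serves, as a sketch:

    from spacy.attrs import LOWER, POS, ENT_TYPE, IS_ALPHA
    from spacy.tokens import Doc

    array = doc.to_array([LOWER, POS, ENT_TYPE, IS_ALPHA])
    doc2 = Doc(doc.vocab, words=[t.text for t in doc])
    doc2.from_array([LOWER, POS, ENT_TYPE, IS_ALPHA], array)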
@@ -17,6 +17,7 @@ from ..attrs cimport attr_id_t
 from ..parts_of_speech cimport univ_pos_t
 from ..attrs cimport *
 from ..lexeme cimport Lexeme
+from ..symbols cimport dep

 from ..util import normalize_slice
 from ..compat import is_config, basestring_
@@ -206,7 +207,6 @@ cdef class Span:

         DOCS: https://spacy.io/api/span#as_doc
         """
-        # TODO: Fix!
         words = [t.text for t in self]
         spaces = [bool(t.whitespace_) for t in self]
         cdef Doc doc = Doc(self.doc.vocab, words=words, spaces=spaces)
@@ -220,7 +220,9 @@ cdef class Span:
         else:
             array_head.append(SENT_START)
         array = self.doc.to_array(array_head)
-        doc.from_array(array_head, array[self.start : self.end])
+        array = array[self.start : self.end]
+        self._fix_dep_copy(array_head, array)
+        doc.from_array(array_head, array)
         doc.noun_chunks_iterator = self.doc.noun_chunks_iterator
         doc.user_hooks = self.doc.user_hooks
         doc.user_span_hooks = self.doc.user_span_hooks
@@ -235,6 +237,44 @@ cdef class Span:
             doc.cats[cat_label] = value
         return doc

+    def _fix_dep_copy(self, attrs, array):
+        """ Rewire dependency links to make sure their heads fall into the span
+        while still keeping the correct number of sentences. """
+        cdef int length = len(array)
+        cdef attr_t value
+        cdef int i, head_col, ancestor_i
+        old_to_new_root = dict()
+        if HEAD in attrs:
+            head_col = attrs.index(HEAD)
+            for i in range(length):
+                # if the HEAD refers to a token outside this span, find a more appropriate ancestor
+                token = self[i]
+                ancestor_i = token.head.i - self.start  # span offset
+                if ancestor_i not in range(length):
+                    if DEP in attrs:
+                        array[i, attrs.index(DEP)] = dep
+
+                    # try finding an ancestor within this span
+                    ancestors = token.ancestors
+                    for ancestor in ancestors:
+                        ancestor_i = ancestor.i - self.start
+                        if ancestor_i in range(length):
+                            array[i, head_col] = ancestor_i - i
+
+                # if there is no appropriate ancestor, define a new artificial root
+                value = array[i, head_col]
+                if (i+value) not in range(length):
+                    new_root = old_to_new_root.get(ancestor_i, None)
+                    if new_root is not None:
+                        # take the same artificial root as a previous token from the same sentence
+                        array[i, head_col] = new_root - i
+                    else:
+                        # set this token as the new artificial root
+                        array[i, head_col] = 0
+                        old_to_new_root[ancestor_i] = i
+
+        return array
+
     def merge(self, *args, **attributes):
         """Retokenize the document, such that the span is merged into a single
         token.
@@ -500,7 +540,7 @@ cdef class Span:
         if "root" in self.doc.user_span_hooks:
             return self.doc.user_span_hooks["root"](self)
         # This should probably be called 'head', and the other one called
-        # 'gov'. But we went with 'head' elsehwhere, and now we're stuck =/
+        # 'gov'. But we went with 'head' elsewhere, and now we're stuck =/
         cdef int i
         # First, we scan through the Span, and check whether there's a word
         # with head==0, i.e. a sentence root. If so, we can return it. The
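Reader's note: the core of `_fix_dep_copy` in plain Python, for readability. This is a sketch of the same idea, not the Cython source; it omits the in-span-ancestor search, which needs the full parse tree. `heads` holds relative offsets, as in the attribute array:

    def fix_heads(heads):
        """Clamp relative head offsets so every head falls inside the span."""
        n = len(heads)
        old_to_new_root = {}  # out-of-span head index -> in-span artificial root
        fixed = list(heads)
        for i, rel in enumerate(fixed):
            if 0 <= i + rel < n:
                continue  # head already inside the span
            root = old_to_new_root.get(i + rel)
            if root is not None:
                fixed[i] = root - i  # reuse the artificial root of the same sentence
            else:
                fixed[i] = 0  # this token becomes the artificial root
                old_to_new_root[i + rel] = i
        return fixed

    print(fix_heads([1, 6, -1, -1]))  # [1, 0, -1, -1]: token 1 becomes the root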
@@ -45,10 +45,11 @@ Whether the provided syntactic annotations form a projective dependency tree.

 | Name                              | Type | Description |
 | --------------------------------- | ---- | ----------- |
+| `words`                           | list | The words. |
 | `tags`                            | list | The part-of-speech tag annotations. |
 | `heads`                           | list | The syntactic head annotations. |
 | `labels`                          | list | The syntactic relation-type annotations. |
-| `ents`                            | list | The named entity annotations. |
+| `ner`                             | list | The named entity annotations as BILUO tags. |
 | `cand_to_gold`                    | list | The alignment from candidate tokenization to gold tokenization. |
 | `gold_to_cand`                    | list | The alignment from gold tokenization to candidate tokenization. |
 | `cats` <Tag variant="new">2</Tag> | list | Entries in the list should be either a label, or a `(start, end, label)` triple. The tuple form is used for categories applied to spans of the document. |
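Reader's note: the renamed `ner` attribute holds per-token BILUO tags, which `GoldParse` derives from character-offset entities; a brief v2.x sketch:

    from spacy.gold import GoldParse

    doc = nlp.make_doc("London is big")
    gold = GoldParse(doc, entities=[(0, 6, "GPE")])  # offsets converted internally
    print(gold.ner)  # ['U-GPE', 'O', 'O']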