Mirror of https://github.com/explosion/spaCy.git
Commit 4086c6ff60 (parent a63d15a142)
get vector functionality + unit test
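As a quick orientation before the diff, here is a minimal usage sketch of the new KnowledgeBase.get_vector() method. The constructor, add_entity signature, and get_vector behaviour are taken from the kb.pyx change and the unit test in this commit; English() merely stands in for the test's nlp fixture.

from spacy.lang.en import English
from spacy.kb import KnowledgeBase

nlp = English()
kb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
kb.add_entity(entity="Q1", prob=0.9, entity_vector=[8, 4, 3])

# get_vector returns the stored entity vector, or an empty list for an unknown entity
assert kb.get_vector("Q1") == [8, 4, 3]
assert kb.get_vector("Q5") == []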
@@ -2,6 +2,7 @@
 from __future__ import unicode_literals

 import os
+import random
 import re
 import bz2
 import datetime
@@ -27,21 +28,23 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
     Read the XML wikipedia data to parse out training data:
     raw text data + positive instances
     """
-    title_regex = re.compile(r'(?<=<title>).*(?=</title>)')
-    id_regex = re.compile(r'(?<=<id>)\d*(?=</id>)')
+    title_regex = re.compile(r"(?<=<title>).*(?=</title>)")
+    id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")

     read_ids = set()
     entityfile_loc = training_output / ENTITY_FILE
-    with open(entityfile_loc, mode="w", encoding='utf8') as entityfile:
+    with open(entityfile_loc, mode="w", encoding="utf8") as entityfile:
         # write entity training header file
-        _write_training_entity(outputfile=entityfile,
-                               article_id="article_id",
-                               alias="alias",
-                               entity="WD_id",
-                               start="start",
-                               end="end")
+        _write_training_entity(
+            outputfile=entityfile,
+            article_id="article_id",
+            alias="alias",
+            entity="WD_id",
+            start="start",
+            end="end",
+        )

-        with bz2.open(wikipedia_input, mode='rb') as file:
+        with bz2.open(wikipedia_input, mode="rb") as file:
             line = file.readline()
             cnt = 0
             article_text = ""
@@ -51,7 +54,12 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
             reading_revision = False
             while line and (not limit or cnt < limit):
                 if cnt % 1000000 == 0:
-                    print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump")
+                    print(
+                        datetime.datetime.now(),
+                        "processed",
+                        cnt,
+                        "lines of Wikipedia dump",
+                    )
                 clean_line = line.strip().decode("utf-8")

                 if clean_line == "<revision>":
@@ -69,12 +77,23 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
                 elif clean_line == "</page>":
                     if article_id:
                         try:
-                            _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text.strip(),
-                                             training_output)
+                            _process_wp_text(
+                                wp_to_id,
+                                entityfile,
+                                article_id,
+                                article_title,
+                                article_text.strip(),
+                                training_output,
+                            )
                         except Exception as e:
-                            print("Error processing article", article_id, article_title, e)
+                            print(
+                                "Error processing article", article_id, article_title, e
+                            )
                     else:
-                        print("Done processing a page, but couldn't find an article_id ?", article_title)
+                        print(
+                            "Done processing a page, but couldn't find an article_id ?",
+                            article_title,
+                        )
                     article_text = ""
                     article_title = None
                     article_id = None
@@ -98,7 +117,9 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
                     if ids:
                         article_id = ids[0]
                         if article_id in read_ids:
-                            print("Found duplicate article ID", article_id, clean_line)  # This should never happen ...
+                            print(
+                                "Found duplicate article ID", article_id, clean_line
+                            )  # This should never happen ...
                         read_ids.add(article_id)

                 # read the title of this article (outside the revision portion of the document)
@@ -111,10 +132,12 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
                 cnt += 1


-text_regex = re.compile(r'(?<=<text xml:space=\"preserve\">).*(?=</text)')
+text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")


-def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output):
+def _process_wp_text(
+    wp_to_id, entityfile, article_id, article_title, article_text, training_output
+):
     found_entities = False

     # ignore meta Wikipedia pages
@@ -141,11 +164,11 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
     entity_buffer = ""
     mention_buffer = ""
     for index, letter in enumerate(clean_text):
-        if letter == '[':
+        if letter == "[":
             open_read += 1
-        elif letter == ']':
+        elif letter == "]":
             open_read -= 1
-        elif letter == '|':
+        elif letter == "|":
             if reading_text:
                 final_text += letter
             # switch from reading entity to mention in the [[entity|mention]] pattern
@@ -163,7 +186,7 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
             elif reading_text:
                 final_text += letter
             else:
-                raise ValueError("Not sure at point", clean_text[index-2:index+2])
+                raise ValueError("Not sure at point", clean_text[index - 2 : index + 2])

         if open_read > 2:
             reading_special_case = True
@@ -175,7 +198,7 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te

         # we just finished reading an entity
         if open_read == 0 and not reading_text:
-            if '#' in entity_buffer or entity_buffer.startswith(':'):
+            if "#" in entity_buffer or entity_buffer.startswith(":"):
                 reading_special_case = True
             # Ignore cases with nested structures like File: handles etc
             if not reading_special_case:
@@ -185,12 +208,14 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
                 end = start + len(mention_buffer)
                 qid = wp_to_id.get(entity_buffer, None)
                 if qid:
-                    _write_training_entity(outputfile=entityfile,
-                                           article_id=article_id,
-                                           alias=mention_buffer,
-                                           entity=qid,
-                                           start=start,
-                                           end=end)
+                    _write_training_entity(
+                        outputfile=entityfile,
+                        article_id=article_id,
+                        alias=mention_buffer,
+                        entity=qid,
+                        start=start,
+                        end=end,
+                    )
                     found_entities = True
                 final_text += mention_buffer

@@ -203,29 +228,35 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
             reading_special_case = False

     if found_entities:
-        _write_training_article(article_id=article_id, clean_text=final_text, training_output=training_output)
+        _write_training_article(
+            article_id=article_id,
+            clean_text=final_text,
+            training_output=training_output,
+        )


-info_regex = re.compile(r'{[^{]*?}')
-htlm_regex = re.compile(r'<!--[^-]*-->')
-category_regex = re.compile(r'\[\[Category:[^\[]*]]')
-file_regex = re.compile(r'\[\[File:[^[\]]+]]')
-ref_regex = re.compile(r'<ref.*?>')  # non-greedy
-ref_2_regex = re.compile(r'</ref.*?>')  # non-greedy
+info_regex = re.compile(r"{[^{]*?}")
+htlm_regex = re.compile(r"<!--[^-]*-->")
+category_regex = re.compile(r"\[\[Category:[^\[]*]]")
+file_regex = re.compile(r"\[\[File:[^[\]]+]]")
+ref_regex = re.compile(r"<ref.*?>")  # non-greedy
+ref_2_regex = re.compile(r"</ref.*?>")  # non-greedy


 def _get_clean_wp_text(article_text):
     clean_text = article_text.strip()

     # remove bolding & italic markup
-    clean_text = clean_text.replace('\'\'\'', '')
-    clean_text = clean_text.replace('\'\'', '')
+    clean_text = clean_text.replace("'''", "")
+    clean_text = clean_text.replace("''", "")

     # remove nested {{info}} statements by removing the inner/smallest ones first and iterating
     try_again = True
     previous_length = len(clean_text)
     while try_again:
-        clean_text = info_regex.sub('', clean_text)  # non-greedy match excluding a nested {
+        clean_text = info_regex.sub(
+            "", clean_text
+        )  # non-greedy match excluding a nested {
         if len(clean_text) < previous_length:
             try_again = True
         else:
@@ -233,14 +264,14 @@ def _get_clean_wp_text(article_text):
         previous_length = len(clean_text)

     # remove HTML comments
-    clean_text = htlm_regex.sub('', clean_text)
+    clean_text = htlm_regex.sub("", clean_text)

     # remove Category and File statements
-    clean_text = category_regex.sub('', clean_text)
-    clean_text = file_regex.sub('', clean_text)
+    clean_text = category_regex.sub("", clean_text)
+    clean_text = file_regex.sub("", clean_text)

     # remove multiple =
-    while '==' in clean_text:
+    while "==" in clean_text:
         clean_text = clean_text.replace("==", "=")

     clean_text = clean_text.replace(". =", ".")
@@ -249,43 +280,56 @@ def _get_clean_wp_text(article_text):
     clean_text = clean_text.replace(" =", "")

     # remove refs (non-greedy match)
-    clean_text = ref_regex.sub('', clean_text)
-    clean_text = ref_2_regex.sub('', clean_text)
+    clean_text = ref_regex.sub("", clean_text)
+    clean_text = ref_2_regex.sub("", clean_text)

     # remove additional wikiformatting
-    clean_text = re.sub(r'<blockquote>', '', clean_text)
-    clean_text = re.sub(r'</blockquote>', '', clean_text)
+    clean_text = re.sub(r"<blockquote>", "", clean_text)
+    clean_text = re.sub(r"</blockquote>", "", clean_text)

     # change special characters back to normal ones
-    clean_text = clean_text.replace(r'&lt;', '<')
-    clean_text = clean_text.replace(r'&gt;', '>')
-    clean_text = clean_text.replace(r'&quot;', '"')
-    clean_text = clean_text.replace(r'&amp;nbsp;', ' ')
-    clean_text = clean_text.replace(r'&amp;', '&')
+    clean_text = clean_text.replace(r"&lt;", "<")
+    clean_text = clean_text.replace(r"&gt;", ">")
+    clean_text = clean_text.replace(r"&quot;", '"')
+    clean_text = clean_text.replace(r"&amp;nbsp;", " ")
+    clean_text = clean_text.replace(r"&amp;", "&")

     # remove multiple spaces
-    while '  ' in clean_text:
-        clean_text = clean_text.replace('  ', ' ')
+    while "  " in clean_text:
+        clean_text = clean_text.replace("  ", " ")

     return clean_text.strip()


 def _write_training_article(article_id, clean_text, training_output):
     file_loc = training_output / str(article_id) + ".txt"
-    with open(file_loc, mode='w', encoding='utf8') as outputfile:
+    with open(file_loc, mode="w", encoding="utf8") as outputfile:
         outputfile.write(clean_text)


 def _write_training_entity(outputfile, article_id, alias, entity, start, end):
-    outputfile.write(article_id + "|" + alias + "|" + entity + "|" + str(start) + "|" + str(end) + "\n")
+    outputfile.write(
+        article_id
+        + "|"
+        + alias
+        + "|"
+        + entity
+        + "|"
+        + str(start)
+        + "|"
+        + str(end)
+        + "\n"
+    )


 def is_dev(article_id):
     return article_id.endswith("3")


-def read_training(nlp, training_dir, dev, limit):
-    # This method provides training examples that correspond to the entity annotations found by the nlp object
+def read_training(nlp, training_dir, dev, limit, kb=None):
+    """ This method provides training examples that correspond to the entity annotations found by the nlp object.
+        When kb is provided, it will include also negative training examples by using the candidate generator.
+        When kb=None, it will only include positive training examples."""
     entityfile_loc = training_dir / ENTITY_FILE
     data = []

@@ -296,24 +340,34 @@ def read_training(nlp, training_dir, dev, limit):
     skip_articles = set()
     total_entities = 0

-    with open(entityfile_loc, mode='r', encoding='utf8') as file:
+    with open(entityfile_loc, mode="r", encoding="utf8") as file:
         for line in file:
             if not limit or len(data) < limit:
-                fields = line.replace('\n', "").split(sep='|')
+                fields = line.replace("\n", "").split(sep="|")
                 article_id = fields[0]
                 alias = fields[1]
-                wp_title = fields[2]
+                wd_id = fields[2]
                 start = fields[3]
                 end = fields[4]

-                if dev == is_dev(article_id) and article_id != "article_id" and article_id not in skip_articles:
+                if (
+                    dev == is_dev(article_id)
+                    and article_id != "article_id"
+                    and article_id not in skip_articles
+                ):
                     if not current_doc or (current_article_id != article_id):
                         # parse the new article text
                         file_name = article_id + ".txt"
                         try:
-                            with open(os.path.join(training_dir, file_name), mode="r", encoding='utf8') as f:
+                            with open(
+                                os.path.join(training_dir, file_name),
+                                mode="r",
+                                encoding="utf8",
+                            ) as f:
                                 text = f.read()
-                            if len(text) < 30000:  # threshold for convenience / speed of processing
+                            if (
+                                len(text) < 30000
+                            ):  # threshold for convenience / speed of processing
                                 current_doc = nlp(text)
                                 current_article_id = article_id
                                 ents_by_offset = dict()
@@ -321,7 +375,11 @@ def read_training(nlp, training_dir, dev, limit):
                                     sent_length = len(ent.sent)
                                     # custom filtering to avoid too long or too short sentences
                                     if 5 < sent_length < 100:
-                                        ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
+                                        ents_by_offset[
+                                            str(ent.start_char)
+                                            + "_"
+                                            + str(ent.end_char)
+                                        ] = ent
                             else:
                                 skip_articles.add(article_id)
                                 current_doc = None
@@ -332,7 +390,7 @@ def read_training(nlp, training_dir, dev, limit):

                 # repeat checking this condition in case an exception was thrown
                 if current_doc and (current_article_id == article_id):
                     found_ent = ents_by_offset.get(start + "_" + end, None)
                     if found_ent:
                         if found_ent.text != alias:
                             skip_articles.add(article_id)
@@ -342,7 +400,26 @@ def read_training(nlp, training_dir, dev, limit):
                                 # currently feeding the gold data one entity per sentence at a time
                                 gold_start = int(start) - found_ent.sent.start_char
                                 gold_end = int(end) - found_ent.sent.start_char
-                                gold_entities = [(gold_start, gold_end, wp_title)]
+
+                                # add both positive and negative examples (in random order just to be sure)
+                                if kb:
+                                    gold_entities = {}
+                                    candidate_ids = [
+                                        c.entity_ for c in kb.get_candidates(alias)
+                                    ]
+                                    candidate_ids.append(
+                                        wd_id
+                                    )  # in case the KB doesn't have it
+                                    random.shuffle(candidate_ids)
+                                    for kb_id in candidate_ids:
+                                        entry = (gold_start, gold_end, kb_id)
+                                        if kb_id != wd_id:
+                                            gold_entities[entry] = 0.0
+                                        else:
+                                            gold_entities[entry] = 1.0
+                                else:
+                                    gold_entities = {}
+
                                 gold = GoldParse(doc=sent, links=gold_entities)
                                 data.append((sent, gold))
                                 total_entities += 1
@@ -203,6 +203,15 @@ cdef class KnowledgeBase:
                 for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs)
                 if entry_index != 0]

+    def get_vector(self, unicode entity):
+        cdef hash_t entity_hash = self.vocab.strings.add(entity)
+
+        # Return an empty list if this entity is unknown in this KB
+        if entity_hash not in self._entry_index:
+            return []
+        entry_index = self._entry_index[entity_hash]
+
+        return self._vectors_table[self._entries[entry_index].vector_index]

     def dump(self, loc):
         cdef Writer writer = Writer(loc)
@@ -15,20 +15,25 @@ def nlp():

 def test_kb_valid_entities(nlp):
     """Test the valid construction of a KB with 3 entities and two aliases"""
-    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)

     # adding entities
-    mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
-    mykb.add_entity(entity='Q2', prob=0.5, entity_vector=[2])
-    mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
+    mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[8, 4, 3])
+    mykb.add_entity(entity="Q2", prob=0.5, entity_vector=[2, 1, 0])
+    mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[-1, -6, 5])

     # adding aliases
-    mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.2])
-    mykb.add_alias(alias='adam', entities=['Q2'], probabilities=[0.9])
+    mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.2])
+    mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])

     # test the size of the corresponding KB
-    assert(mykb.get_size_entities() == 3)
-    assert(mykb.get_size_aliases() == 2)
+    assert mykb.get_size_entities() == 3
+    assert mykb.get_size_aliases() == 2

+    # test retrieval of the entity vectors
+    assert mykb.get_vector("Q1") == [8, 4, 3]
+    assert mykb.get_vector("Q2") == [2, 1, 0]
+    assert mykb.get_vector("Q3") == [-1, -6, 5]
+

 def test_kb_invalid_entities(nlp):
@@ -36,13 +41,15 @@ def test_kb_invalid_entities(nlp):
     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

     # adding entities
-    mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
-    mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
-    mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
+    mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1])
+    mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])
+    mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[3])

     # adding aliases - should fail because one of the given IDs is not valid
     with pytest.raises(ValueError):
-        mykb.add_alias(alias='douglas', entities=['Q2', 'Q342'], probabilities=[0.8, 0.2])
+        mykb.add_alias(
+            alias="douglas", entities=["Q2", "Q342"], probabilities=[0.8, 0.2]
+        )


 def test_kb_invalid_probabilities(nlp):
@@ -50,13 +57,13 @@ def test_kb_invalid_probabilities(nlp):
     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

     # adding entities
-    mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
-    mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
-    mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
+    mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1])
+    mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])
+    mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[3])

     # adding aliases - should fail because the sum of the probabilities exceeds 1
     with pytest.raises(ValueError):
-        mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.4])
+        mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.4])


 def test_kb_invalid_combination(nlp):
@@ -64,13 +71,15 @@ def test_kb_invalid_combination(nlp):
     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

     # adding entities
-    mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
-    mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
-    mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
+    mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1])
+    mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])
+    mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[3])

     # adding aliases - should fail because the entities and probabilities vectors are not of equal length
     with pytest.raises(ValueError):
-        mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.3, 0.4, 0.1])
+        mykb.add_alias(
+            alias="douglas", entities=["Q2", "Q3"], probabilities=[0.3, 0.4, 0.1]
+        )


 def test_kb_invalid_entity_vector(nlp):
@@ -78,11 +87,11 @@ def test_kb_invalid_entity_vector(nlp):
     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)

     # adding entities
-    mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1, 2, 3])
+    mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1, 2, 3])

     # this should fail because the kb's expected entity vector length is 3
     with pytest.raises(ValueError):
-        mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
+        mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])


 def test_candidate_generation(nlp):
@@ -90,18 +99,18 @@ def test_candidate_generation(nlp):
     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

     # adding entities
-    mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
-    mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
-    mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
+    mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1])
+    mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])
+    mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[3])

     # adding aliases
-    mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.2])
-    mykb.add_alias(alias='adam', entities=['Q2'], probabilities=[0.9])
+    mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.2])
+    mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])

     # test the size of the relevant candidates
-    assert(len(mykb.get_candidates('douglas')) == 2)
-    assert(len(mykb.get_candidates('adam')) == 1)
-    assert(len(mykb.get_candidates('shrubbery')) == 0)
+    assert len(mykb.get_candidates("douglas")) == 2
+    assert len(mykb.get_candidates("adam")) == 1
+    assert len(mykb.get_candidates("shrubbery")) == 0


 def test_preserving_links_asdoc(nlp):
@@ -109,24 +118,26 @@ def test_preserving_links_asdoc(nlp):
     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

     # adding entities
-    mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
-    mykb.add_entity(entity='Q2', prob=0.8, entity_vector=[1])
+    mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1])
+    mykb.add_entity(entity="Q2", prob=0.8, entity_vector=[1])

     # adding aliases
-    mykb.add_alias(alias='Boston', entities=['Q1'], probabilities=[0.7])
-    mykb.add_alias(alias='Denver', entities=['Q2'], probabilities=[0.6])
+    mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7])
+    mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6])

     # set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained)
     sentencizer = nlp.create_pipe("sentencizer")
     nlp.add_pipe(sentencizer)

     ruler = EntityRuler(nlp)
-    patterns = [{"label": "GPE", "pattern": "Boston"},
-                {"label": "GPE", "pattern": "Denver"}]
+    patterns = [
+        {"label": "GPE", "pattern": "Boston"},
+        {"label": "GPE", "pattern": "Denver"},
+    ]
     ruler.add_patterns(patterns)
     nlp.add_pipe(ruler)

-    el_pipe = nlp.create_pipe(name='entity_linker', config={"context_width": 64})
+    el_pipe = nlp.create_pipe(name="entity_linker", config={"context_width": 64})
     el_pipe.set_kb(mykb)
     el_pipe.begin_training()
     el_pipe.context_weight = 0