get vector functionality + unit test

svlandeg 2019-07-17 12:17:02 +02:00
parent a63d15a142
commit 4086c6ff60
3 changed files with 201 additions and 104 deletions

View File

@@ -2,6 +2,7 @@
from __future__ import unicode_literals

import os
import random
import re
import bz2
import datetime
@@ -27,21 +28,23 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
    Read the XML wikipedia data to parse out training data:
    raw text data + positive instances
    """
    title_regex = re.compile(r"(?<=<title>).*(?=</title>)")
    id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")

    read_ids = set()
    entityfile_loc = training_output / ENTITY_FILE
    with open(entityfile_loc, mode="w", encoding="utf8") as entityfile:
        # write entity training header file
        _write_training_entity(
            outputfile=entityfile,
            article_id="article_id",
            alias="alias",
            entity="WD_id",
            start="start",
            end="end",
        )

        with bz2.open(wikipedia_input, mode="rb") as file:
            line = file.readline()
            cnt = 0
            article_text = ""
@@ -51,7 +54,12 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
            reading_revision = False
            while line and (not limit or cnt < limit):
                if cnt % 1000000 == 0:
                    print(
                        datetime.datetime.now(),
                        "processed",
                        cnt,
                        "lines of Wikipedia dump",
                    )
                clean_line = line.strip().decode("utf-8")

                if clean_line == "<revision>":
@@ -69,12 +77,23 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
                elif clean_line == "</page>":
                    if article_id:
                        try:
                            _process_wp_text(
                                wp_to_id,
                                entityfile,
                                article_id,
                                article_title,
                                article_text.strip(),
                                training_output,
                            )
                        except Exception as e:
                            print(
                                "Error processing article", article_id, article_title, e
                            )
                    else:
                        print(
                            "Done processing a page, but couldn't find an article_id ?",
                            article_title,
                        )
                    article_text = ""
                    article_title = None
                    article_id = None
@@ -98,7 +117,9 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
                    if ids:
                        article_id = ids[0]
                        if article_id in read_ids:
                            print(
                                "Found duplicate article ID", article_id, clean_line
                            )  # This should never happen ...
                        read_ids.add(article_id)

                # read the title of this article (outside the revision portion of the document)
@@ -111,10 +132,12 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
                cnt += 1


text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")


def _process_wp_text(
    wp_to_id, entityfile, article_id, article_title, article_text, training_output
):
    found_entities = False

    # ignore meta Wikipedia pages
@@ -141,11 +164,11 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
    entity_buffer = ""
    mention_buffer = ""
    for index, letter in enumerate(clean_text):
        if letter == "[":
            open_read += 1
        elif letter == "]":
            open_read -= 1
        elif letter == "|":
            if reading_text:
                final_text += letter
            # switch from reading entity to mention in the [[entity|mention]] pattern
@@ -163,7 +186,7 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
            elif reading_text:
                final_text += letter
            else:
                raise ValueError("Not sure at point", clean_text[index - 2 : index + 2])

        if open_read > 2:
            reading_special_case = True
@@ -175,7 +198,7 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
        # we just finished reading an entity
        if open_read == 0 and not reading_text:
            if "#" in entity_buffer or entity_buffer.startswith(":"):
                reading_special_case = True

            # Ignore cases with nested structures like File: handles etc
            if not reading_special_case:
@@ -185,12 +208,14 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
                end = start + len(mention_buffer)
                qid = wp_to_id.get(entity_buffer, None)
                if qid:
                    _write_training_entity(
                        outputfile=entityfile,
                        article_id=article_id,
                        alias=mention_buffer,
                        entity=qid,
                        start=start,
                        end=end,
                    )
                    found_entities = True

                final_text += mention_buffer
@@ -203,29 +228,35 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
            reading_special_case = False

    if found_entities:
        _write_training_article(
            article_id=article_id,
            clean_text=final_text,
            training_output=training_output,
        )


info_regex = re.compile(r"{[^{]*?}")
htlm_regex = re.compile(r"&lt;!--[^-]*--&gt;")
category_regex = re.compile(r"\[\[Category:[^\[]*]]")
file_regex = re.compile(r"\[\[File:[^[\]]+]]")
ref_regex = re.compile(r"&lt;ref.*?&gt;")  # non-greedy
ref_2_regex = re.compile(r"&lt;/ref.*?&gt;")  # non-greedy


def _get_clean_wp_text(article_text):
    clean_text = article_text.strip()

    # remove bolding & italic markup
    clean_text = clean_text.replace("'''", "")
    clean_text = clean_text.replace("''", "")

    # remove nested {{info}} statements by removing the inner/smallest ones first and iterating
    try_again = True
    previous_length = len(clean_text)
    while try_again:
        clean_text = info_regex.sub(
            "", clean_text
        )  # non-greedy match excluding a nested {
        if len(clean_text) < previous_length:
            try_again = True
        else:
@@ -233,14 +264,14 @@ def _get_clean_wp_text(article_text):
        previous_length = len(clean_text)

    # remove HTML comments
    clean_text = htlm_regex.sub("", clean_text)

    # remove Category and File statements
    clean_text = category_regex.sub("", clean_text)
    clean_text = file_regex.sub("", clean_text)

    # remove multiple =
    while "==" in clean_text:
        clean_text = clean_text.replace("==", "=")

    clean_text = clean_text.replace(". =", ".")
@@ -249,43 +280,56 @@ def _get_clean_wp_text(article_text):
    clean_text = clean_text.replace(" =", "")

    # remove refs (non-greedy match)
    clean_text = ref_regex.sub("", clean_text)
    clean_text = ref_2_regex.sub("", clean_text)

    # remove additional wikiformatting
    clean_text = re.sub(r"&lt;blockquote&gt;", "", clean_text)
    clean_text = re.sub(r"&lt;/blockquote&gt;", "", clean_text)

    # change special characters back to normal ones
    clean_text = clean_text.replace(r"&lt;", "<")
    clean_text = clean_text.replace(r"&gt;", ">")
    clean_text = clean_text.replace(r"&quot;", '"')
    clean_text = clean_text.replace(r"&amp;nbsp;", " ")
    clean_text = clean_text.replace(r"&amp;", "&")

    # remove multiple spaces
    while "  " in clean_text:
        clean_text = clean_text.replace("  ", " ")

    return clean_text.strip()


def _write_training_article(article_id, clean_text, training_output):
    file_loc = training_output / str(article_id) + ".txt"
    with open(file_loc, mode="w", encoding="utf8") as outputfile:
        outputfile.write(clean_text)


def _write_training_entity(outputfile, article_id, alias, entity, start, end):
    outputfile.write(
        article_id
        + "|"
        + alias
        + "|"
        + entity
        + "|"
        + str(start)
        + "|"
        + str(end)
        + "\n"
    )


def is_dev(article_id):
    return article_id.endswith("3")


def read_training(nlp, training_dir, dev, limit, kb=None):
    """ This method provides training examples that correspond to the entity annotations found by the nlp object.
    When kb is provided, it will include also negative training examples by using the candidate generator.
    When kb=None, it will only include positive training examples."""

    entityfile_loc = training_dir / ENTITY_FILE
    data = []
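
For orientation, the entity file written by _write_training_entity above is pipe-delimited, one annotation per line, under the header emitted in _process_wikipedia_texts. A hypothetical row (article ID, alias and character offsets invented for illustration) would look like:

article_id|alias|WD_id|start|end
12345|Douglas Adams|Q42|17|30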
@@ -296,24 +340,34 @@ def read_training(nlp, training_dir, dev, limit):
    skip_articles = set()
    total_entities = 0

    with open(entityfile_loc, mode="r", encoding="utf8") as file:
        for line in file:
            if not limit or len(data) < limit:
                fields = line.replace("\n", "").split(sep="|")
                article_id = fields[0]
                alias = fields[1]
                wd_id = fields[2]
                start = fields[3]
                end = fields[4]

                if (
                    dev == is_dev(article_id)
                    and article_id != "article_id"
                    and article_id not in skip_articles
                ):
                    if not current_doc or (current_article_id != article_id):
                        # parse the new article text
                        file_name = article_id + ".txt"
                        try:
                            with open(
                                os.path.join(training_dir, file_name),
                                mode="r",
                                encoding="utf8",
                            ) as f:
                                text = f.read()
                            if (
                                len(text) < 30000
                            ):  # threshold for convenience / speed of processing
                                current_doc = nlp(text)
                                current_article_id = article_id
                                ents_by_offset = dict()
@@ -321,7 +375,11 @@ def read_training(nlp, training_dir, dev, limit):
                                    sent_length = len(ent.sent)
                                    # custom filtering to avoid too long or too short sentences
                                    if 5 < sent_length < 100:
                                        ents_by_offset[
                                            str(ent.start_char)
                                            + "_"
                                            + str(ent.end_char)
                                        ] = ent
                            else:
                                skip_articles.add(article_id)
                                current_doc = None
@@ -332,7 +390,7 @@ def read_training(nlp, training_dir, dev, limit):
                    # repeat checking this condition in case an exception was thrown
                    if current_doc and (current_article_id == article_id):
                        found_ent = ents_by_offset.get(start + "_" + end, None)

                        if found_ent:
                            if found_ent.text != alias:
                                skip_articles.add(article_id)
@@ -342,7 +400,26 @@ def read_training(nlp, training_dir, dev, limit):
                                # currently feeding the gold data one entity per sentence at a time
                                gold_start = int(start) - found_ent.sent.start_char
                                gold_end = int(end) - found_ent.sent.start_char

                                # add both positive and negative examples (in random order just to be sure)
                                if kb:
                                    gold_entities = {}
                                    candidate_ids = [
                                        c.entity_ for c in kb.get_candidates(alias)
                                    ]
                                    candidate_ids.append(
                                        wd_id
                                    )  # in case the KB doesn't have it
                                    random.shuffle(candidate_ids)
                                    for kb_id in candidate_ids:
                                        entry = (gold_start, gold_end, kb_id)
                                        if kb_id != wd_id:
                                            gold_entities[entry] = 0.0
                                        else:
                                            gold_entities[entry] = 1.0
                                else:
                                    gold_entities = {}

                                gold = GoldParse(doc=sent, links=gold_entities)
                                data.append((sent, gold))
                                total_entities += 1
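
To make the new positive/negative labelling concrete, here is a minimal, self-contained sketch of the logic above; the gold ID, candidate IDs and offsets are invented for illustration, and in read_training they come from the entity file and kb.get_candidates(alias) instead:

import random

wd_id = "Q2"                      # gold Wikidata ID from the entity file
candidate_ids = ["Q2", "Q3"]      # stand-in for [c.entity_ for c in kb.get_candidates(alias)]
candidate_ids.append(wd_id)       # in case the KB doesn't have it
random.shuffle(candidate_ids)

gold_entities = {}
for kb_id in candidate_ids:
    entry = (0, 7, kb_id)         # (gold_start, gold_end, kb_id) within the sentence
    gold_entities[entry] = 1.0 if kb_id == wd_id else 0.0

# gold_entities == {(0, 7, 'Q2'): 1.0, (0, 7, 'Q3'): 0.0}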

View File

@@ -203,6 +203,15 @@ cdef class KnowledgeBase:
                for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs)
                if entry_index != 0]

    def get_vector(self, unicode entity):
        cdef hash_t entity_hash = self.vocab.strings.add(entity)

        # Return an empty list if this entity is unknown in this KB
        if entity_hash not in self._entry_index:
            return []
        entry_index = self._entry_index[entity_hash]

        return self._vectors_table[self._entries[entry_index].vector_index]

    def dump(self, loc):
        cdef Writer writer = Writer(loc)
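
A short usage sketch of the new get_vector method, mirroring the unit test below; the entity ID and vector values are illustrative, and this assumes a spaCy build that includes this commit:

from spacy.kb import KnowledgeBase
from spacy.lang.en import English

nlp = English()
kb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
kb.add_entity(entity="Q42", prob=0.9, entity_vector=[1, 2, 3])

assert kb.get_vector("Q42") == [1, 2, 3]
assert kb.get_vector("Q123") == []  # unknown entity: empty list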

View File

@@ -15,20 +15,25 @@ def nlp():

def test_kb_valid_entities(nlp):
    """Test the valid construction of a KB with 3 entities and two aliases"""
    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)

    # adding entities
    mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[8, 4, 3])
    mykb.add_entity(entity="Q2", prob=0.5, entity_vector=[2, 1, 0])
    mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[-1, -6, 5])

    # adding aliases
    mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.2])
    mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])

    # test the size of the corresponding KB
    assert mykb.get_size_entities() == 3
    assert mykb.get_size_aliases() == 2

    # test retrieval of the entity vectors
    assert mykb.get_vector("Q1") == [8, 4, 3]
    assert mykb.get_vector("Q2") == [2, 1, 0]
    assert mykb.get_vector("Q3") == [-1, -6, 5]


def test_kb_invalid_entities(nlp):
@@ -36,13 +41,15 @@ def test_kb_invalid_entities(nlp):
    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

    # adding entities
    mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1])
    mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])
    mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[3])

    # adding aliases - should fail because one of the given IDs is not valid
    with pytest.raises(ValueError):
        mykb.add_alias(
            alias="douglas", entities=["Q2", "Q342"], probabilities=[0.8, 0.2]
        )


def test_kb_invalid_probabilities(nlp):
@@ -50,13 +57,13 @@ def test_kb_invalid_probabilities(nlp):
    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

    # adding entities
    mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1])
    mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])
    mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[3])

    # adding aliases - should fail because the sum of the probabilities exceeds 1
    with pytest.raises(ValueError):
        mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.4])


def test_kb_invalid_combination(nlp):
@@ -64,13 +71,15 @@ def test_kb_invalid_combination(nlp):
    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

    # adding entities
    mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1])
    mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])
    mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[3])

    # adding aliases - should fail because the entities and probabilities vectors are not of equal length
    with pytest.raises(ValueError):
        mykb.add_alias(
            alias="douglas", entities=["Q2", "Q3"], probabilities=[0.3, 0.4, 0.1]
        )


def test_kb_invalid_entity_vector(nlp):
@@ -78,11 +87,11 @@ def test_kb_invalid_entity_vector(nlp):
    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)

    # adding entities
    mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1, 2, 3])

    # this should fail because the kb's expected entity vector length is 3
    with pytest.raises(ValueError):
        mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])


def test_candidate_generation(nlp):
@@ -90,18 +99,18 @@ def test_candidate_generation(nlp):
    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

    # adding entities
    mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1])
    mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])
    mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[3])

    # adding aliases
    mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.2])
    mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])

    # test the size of the relevant candidates
    assert len(mykb.get_candidates("douglas")) == 2
    assert len(mykb.get_candidates("adam")) == 1
    assert len(mykb.get_candidates("shrubbery")) == 0


def test_preserving_links_asdoc(nlp):
@@ -109,24 +118,26 @@ def test_preserving_links_asdoc(nlp):
    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

    # adding entities
    mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1])
    mykb.add_entity(entity="Q2", prob=0.8, entity_vector=[1])

    # adding aliases
    mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7])
    mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6])

    # set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained)
    sentencizer = nlp.create_pipe("sentencizer")
    nlp.add_pipe(sentencizer)

    ruler = EntityRuler(nlp)
    patterns = [
        {"label": "GPE", "pattern": "Boston"},
        {"label": "GPE", "pattern": "Denver"},
    ]
    ruler.add_patterns(patterns)
    nlp.add_pipe(ruler)

    el_pipe = nlp.create_pipe(name="entity_linker", config={"context_width": 64})
    el_pipe.set_kb(mykb)
    el_pipe.begin_training()
    el_pipe.context_weight = 0