get vector functionality + unit test
This commit is contained in:
parent a63d15a142
commit 4086c6ff60
@@ -2,6 +2,7 @@
 from __future__ import unicode_literals
 
 import os
+import random
 import re
 import bz2
 import datetime
@@ -27,21 +28,23 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None):
     Read the XML wikipedia data to parse out training data:
     raw text data + positive instances
     """
-    title_regex = re.compile(r'(?<=<title>).*(?=</title>)')
-    id_regex = re.compile(r'(?<=<id>)\d*(?=</id>)')
+    title_regex = re.compile(r"(?<=<title>).*(?=</title>)")
+    id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
 
     read_ids = set()
     entityfile_loc = training_output / ENTITY_FILE
-    with open(entityfile_loc, mode="w", encoding='utf8') as entityfile:
+    with open(entityfile_loc, mode="w", encoding="utf8") as entityfile:
         # write entity training header file
-        _write_training_entity(outputfile=entityfile,
-                               article_id="article_id",
-                               alias="alias",
-                               entity="WD_id",
-                               start="start",
-                               end="end")
+        _write_training_entity(
+            outputfile=entityfile,
+            article_id="article_id",
+            alias="alias",
+            entity="WD_id",
+            start="start",
+            end="end",
+        )
 
-        with bz2.open(wikipedia_input, mode='rb') as file:
+        with bz2.open(wikipedia_input, mode="rb") as file:
             line = file.readline()
             cnt = 0
             article_text = ""
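
Note on the lookaround patterns above: `title_regex` and `id_regex` pull out the text between the `<title>...</title>` and `<id>...</id>` tags of a dump line. A small standalone sketch (the sample XML lines are invented for illustration):

import re

title_regex = re.compile(r"(?<=<title>).*(?=</title>)")
id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")

# hypothetical lines as they might appear in a Wikipedia XML dump
lines = ["    <title>Douglas Adams</title>", "    <id>8091</id>"]

for line in lines:
    clean_line = line.strip()
    title = title_regex.search(clean_line)
    ids = id_regex.search(clean_line)
    if title:
        print("title:", title.group(0))  # -> title: Douglas Adams
    if ids:
        print("id:", ids.group(0))       # -> id: 8091
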
@@ -51,7 +54,12 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None):
             reading_revision = False
             while line and (not limit or cnt < limit):
                 if cnt % 1000000 == 0:
-                    print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump")
+                    print(
+                        datetime.datetime.now(),
+                        "processed",
+                        cnt,
+                        "lines of Wikipedia dump",
+                    )
                 clean_line = line.strip().decode("utf-8")
 
                 if clean_line == "<revision>":
@@ -69,12 +77,23 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None):
                 elif clean_line == "</page>":
                     if article_id:
                         try:
-                            _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text.strip(),
-                                             training_output)
+                            _process_wp_text(
+                                wp_to_id,
+                                entityfile,
+                                article_id,
+                                article_title,
+                                article_text.strip(),
+                                training_output,
+                            )
                         except Exception as e:
-                            print("Error processing article", article_id, article_title, e)
+                            print(
+                                "Error processing article", article_id, article_title, e
+                            )
                     else:
-                        print("Done processing a page, but couldn't find an article_id ?", article_title)
+                        print(
+                            "Done processing a page, but couldn't find an article_id ?",
+                            article_title,
+                        )
                     article_text = ""
                     article_title = None
                     article_id = None
@@ -98,7 +117,9 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None):
                     if ids:
                         article_id = ids[0]
                         if article_id in read_ids:
-                            print("Found duplicate article ID", article_id, clean_line)  # This should never happen ...
+                            print(
+                                "Found duplicate article ID", article_id, clean_line
+                            )  # This should never happen ...
                         read_ids.add(article_id)
 
                 # read the title of this article (outside the revision portion of the document)
@@ -111,10 +132,12 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None):
                 cnt += 1
 
 
-text_regex = re.compile(r'(?<=<text xml:space=\"preserve\">).*(?=</text)')
+text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")
 
 
-def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output):
+def _process_wp_text(
+    wp_to_id, entityfile, article_id, article_title, article_text, training_output
+):
     found_entities = False
 
     # ignore meta Wikipedia pages
@@ -141,11 +164,11 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output):
     entity_buffer = ""
     mention_buffer = ""
     for index, letter in enumerate(clean_text):
-        if letter == '[':
+        if letter == "[":
             open_read += 1
-        elif letter == ']':
+        elif letter == "]":
             open_read -= 1
-        elif letter == '|':
+        elif letter == "|":
             if reading_text:
                 final_text += letter
             # switch from reading entity to mention in the [[entity|mention]] pattern
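
For intuition about what the character-by-character loop above extracts: a `[[entity|mention]]` link maps surface text (the mention) to a Wikipedia page title (the entity), and `[[entity]]` uses the same text for both. A simplified, regex-based sketch of that mapping (it ignores the nested and special cases the real loop tracks with `open_read` and `reading_special_case`; the sample text is invented):

import re

# simplified: find [[entity|mention]] or [[entity]] links, assuming no nesting
link_regex = re.compile(r"\[\[([^\[\]|]+)(?:\|([^\[\]|]+))?\]\]")

text = "He was born in [[London]] and wrote [[The Hitchhiker's Guide to the Galaxy|the Guide]]."

for match in link_regex.finditer(text):
    entity = match.group(1)
    mention = match.group(2) or entity  # no "|" means the entity text is also the mention
    print(entity, "->", mention)
# London -> London
# The Hitchhiker's Guide to the Galaxy -> the Guide
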
@@ -163,7 +186,7 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output):
             elif reading_text:
                 final_text += letter
             else:
-                raise ValueError("Not sure at point", clean_text[index-2:index+2])
+                raise ValueError("Not sure at point", clean_text[index - 2 : index + 2])
 
         if open_read > 2:
             reading_special_case = True
@@ -175,7 +198,7 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output):
 
         # we just finished reading an entity
         if open_read == 0 and not reading_text:
-            if '#' in entity_buffer or entity_buffer.startswith(':'):
+            if "#" in entity_buffer or entity_buffer.startswith(":"):
                 reading_special_case = True
             # Ignore cases with nested structures like File: handles etc
             if not reading_special_case:
@@ -185,12 +208,14 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output):
                     end = start + len(mention_buffer)
                     qid = wp_to_id.get(entity_buffer, None)
                     if qid:
-                        _write_training_entity(outputfile=entityfile,
-                                               article_id=article_id,
-                                               alias=mention_buffer,
-                                               entity=qid,
-                                               start=start,
-                                               end=end)
+                        _write_training_entity(
+                            outputfile=entityfile,
+                            article_id=article_id,
+                            alias=mention_buffer,
+                            entity=qid,
+                            start=start,
+                            end=end,
+                        )
                         found_entities = True
                     final_text += mention_buffer
 
@@ -203,29 +228,35 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output):
             reading_special_case = False
 
     if found_entities:
-        _write_training_article(article_id=article_id, clean_text=final_text, training_output=training_output)
+        _write_training_article(
+            article_id=article_id,
+            clean_text=final_text,
+            training_output=training_output,
+        )
 
 
-info_regex = re.compile(r'{[^{]*?}')
-htlm_regex = re.compile(r'<!--[^-]*-->')
-category_regex = re.compile(r'\[\[Category:[^\[]*]]')
-file_regex = re.compile(r'\[\[File:[^[\]]+]]')
-ref_regex = re.compile(r'<ref.*?>')  # non-greedy
-ref_2_regex = re.compile(r'</ref.*?>')  # non-greedy
+info_regex = re.compile(r"{[^{]*?}")
+htlm_regex = re.compile(r"<!--[^-]*-->")
+category_regex = re.compile(r"\[\[Category:[^\[]*]]")
+file_regex = re.compile(r"\[\[File:[^[\]]+]]")
+ref_regex = re.compile(r"<ref.*?>")  # non-greedy
+ref_2_regex = re.compile(r"</ref.*?>")  # non-greedy
 
 
 def _get_clean_wp_text(article_text):
     clean_text = article_text.strip()
 
     # remove bolding & italic markup
-    clean_text = clean_text.replace('\'\'\'', '')
-    clean_text = clean_text.replace('\'\'', '')
+    clean_text = clean_text.replace("'''", "")
+    clean_text = clean_text.replace("''", "")
 
     # remove nested {{info}} statements by removing the inner/smallest ones first and iterating
     try_again = True
     previous_length = len(clean_text)
     while try_again:
-        clean_text = info_regex.sub('', clean_text)  # non-greedy match excluding a nested {
+        clean_text = info_regex.sub(
+            "", clean_text
+        )  # non-greedy match excluding a nested {
        if len(clean_text) < previous_length:
             try_again = True
         else:
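
The `while try_again` loop above removes nested `{{...}}` templates from the inside out: `info_regex` only ever matches an innermost brace pair, so each pass peels off one layer, and the loop stops once a pass no longer shrinks the text. A standalone sketch of the same idea on an invented snippet:

import re

info_regex = re.compile(r"{[^{]*?}")  # innermost {...} only, no nested "{" inside

text = "Douglas Adams {{Infobox person {{birth date|1952}} }} was an author."
previous_length = len(text)
try_again = True
while try_again:
    text = info_regex.sub("", text)  # peel off the innermost brace pairs
    try_again = len(text) < previous_length
    previous_length = len(text)

print(text)
# -> "Douglas Adams  was an author."  (leftover double spaces are collapsed later in _get_clean_wp_text)
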
@@ -233,14 +264,14 @@ def _get_clean_wp_text(article_text):
         previous_length = len(clean_text)
 
     # remove HTML comments
-    clean_text = htlm_regex.sub('', clean_text)
+    clean_text = htlm_regex.sub("", clean_text)
 
     # remove Category and File statements
-    clean_text = category_regex.sub('', clean_text)
-    clean_text = file_regex.sub('', clean_text)
+    clean_text = category_regex.sub("", clean_text)
+    clean_text = file_regex.sub("", clean_text)
 
     # remove multiple =
-    while '==' in clean_text:
+    while "==" in clean_text:
         clean_text = clean_text.replace("==", "=")
 
     clean_text = clean_text.replace(". =", ".")
@@ -249,43 +280,56 @@ def _get_clean_wp_text(article_text):
     clean_text = clean_text.replace(" =", "")
 
     # remove refs (non-greedy match)
-    clean_text = ref_regex.sub('', clean_text)
-    clean_text = ref_2_regex.sub('', clean_text)
+    clean_text = ref_regex.sub("", clean_text)
+    clean_text = ref_2_regex.sub("", clean_text)
 
     # remove additional wikiformatting
-    clean_text = re.sub(r'<blockquote>', '', clean_text)
-    clean_text = re.sub(r'</blockquote>', '', clean_text)
+    clean_text = re.sub(r"<blockquote>", "", clean_text)
+    clean_text = re.sub(r"</blockquote>", "", clean_text)
 
     # change special characters back to normal ones
-    clean_text = clean_text.replace(r'&lt;', '<')
-    clean_text = clean_text.replace(r'&gt;', '>')
-    clean_text = clean_text.replace(r'&quot;', '"')
-    clean_text = clean_text.replace(r'&amp;nbsp;', ' ')
-    clean_text = clean_text.replace(r'&amp;', '&')
+    clean_text = clean_text.replace(r"&lt;", "<")
+    clean_text = clean_text.replace(r"&gt;", ">")
+    clean_text = clean_text.replace(r"&quot;", '"')
+    clean_text = clean_text.replace(r"&amp;nbsp;", " ")
+    clean_text = clean_text.replace(r"&amp;", "&")
 
     # remove multiple spaces
-    while '  ' in clean_text:
-        clean_text = clean_text.replace('  ', ' ')
+    while "  " in clean_text:
+        clean_text = clean_text.replace("  ", " ")
 
     return clean_text.strip()
 
 
 def _write_training_article(article_id, clean_text, training_output):
     file_loc = training_output / str(article_id) + ".txt"
-    with open(file_loc, mode='w', encoding='utf8') as outputfile:
+    with open(file_loc, mode="w", encoding="utf8") as outputfile:
         outputfile.write(clean_text)
 
 
 def _write_training_entity(outputfile, article_id, alias, entity, start, end):
-    outputfile.write(article_id + "|" + alias + "|" + entity + "|" + str(start) + "|" + str(end) + "\n")
+    outputfile.write(
+        article_id
+        + "|"
+        + alias
+        + "|"
+        + entity
+        + "|"
+        + str(start)
+        + "|"
+        + str(end)
+        + "\n"
+    )
 
 
 def is_dev(article_id):
     return article_id.endswith("3")
 
 
-def read_training(nlp, training_dir, dev, limit):
-    # This method provides training examples that correspond to the entity annotations found by the nlp object
+def read_training(nlp, training_dir, dev, limit, kb=None):
+    """ This method provides training examples that correspond to the entity annotations found by the nlp object.
+    When kb is provided, it will include also negative training examples by using the candidate generator.
+    When kb=None, it will only include positive training examples."""
     entityfile_loc = training_dir / ENTITY_FILE
     data = []
 
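
For reference, the training data written above is a plain pipe-delimited file: one `article_id|alias|WD_id|start|end` row per mention (plus a header row with those literal column names), and `is_dev` routes articles whose ID ends in "3" to the dev split. A tiny sketch with invented values; the helper here is a hypothetical stand-in that returns the row instead of writing it:

def write_training_entity_line(article_id, alias, entity, start, end):
    # same "|"-separated layout as _write_training_entity above
    return article_id + "|" + alias + "|" + entity + "|" + str(start) + "|" + str(end) + "\n"

def is_dev(article_id):
    return article_id.endswith("3")

line = write_training_entity_line("8091", "Douglas Adams", "Q42", 0, 13)
print(line.strip())    # 8091|Douglas Adams|Q42|0|13
print(is_dev("8091"))  # False -> training split
print(is_dev("8093"))  # True  -> dev split
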
@@ -296,24 +340,34 @@ def read_training(nlp, training_dir, dev, limit):
     skip_articles = set()
     total_entities = 0
 
-    with open(entityfile_loc, mode='r', encoding='utf8') as file:
+    with open(entityfile_loc, mode="r", encoding="utf8") as file:
         for line in file:
             if not limit or len(data) < limit:
-                fields = line.replace('\n', "").split(sep='|')
+                fields = line.replace("\n", "").split(sep="|")
                 article_id = fields[0]
                 alias = fields[1]
-                wp_title = fields[2]
+                wd_id = fields[2]
                 start = fields[3]
                 end = fields[4]
 
-                if dev == is_dev(article_id) and article_id != "article_id" and article_id not in skip_articles:
+                if (
+                    dev == is_dev(article_id)
+                    and article_id != "article_id"
+                    and article_id not in skip_articles
+                ):
                     if not current_doc or (current_article_id != article_id):
                         # parse the new article text
                         file_name = article_id + ".txt"
                         try:
-                            with open(os.path.join(training_dir, file_name), mode="r", encoding='utf8') as f:
+                            with open(
+                                os.path.join(training_dir, file_name),
+                                mode="r",
+                                encoding="utf8",
+                            ) as f:
                                 text = f.read()
-                            if len(text) < 30000:  # threshold for convenience / speed of processing
+                            if (
+                                len(text) < 30000
+                            ):  # threshold for convenience / speed of processing
                                 current_doc = nlp(text)
                                 current_article_id = article_id
                                 ents_by_offset = dict()
@@ -321,7 +375,11 @@ def read_training(nlp, training_dir, dev, limit):
                                     sent_length = len(ent.sent)
                                     # custom filtering to avoid too long or too short sentences
                                     if 5 < sent_length < 100:
-                                        ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
+                                        ents_by_offset[
+                                            str(ent.start_char)
+                                            + "_"
+                                            + str(ent.end_char)
+                                        ] = ent
                             else:
                                 skip_articles.add(article_id)
                                 current_doc = None
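
The `ents_by_offset` dict built above keys every entity span the `nlp` object found by `"<start_char>_<end_char>"`, so a mention from the entity file can later be looked up by its character offsets alone. A standalone sketch of that keying scheme, using a blank English pipeline with the entity set by hand (sentence and label invented):

import spacy
from spacy.tokens import Span

nlp = spacy.blank("en")  # assumption: blank pipeline, entity assigned manually below
doc = nlp("Douglas Adams wrote the Guide.")
doc.ents = [Span(doc, 0, 2, label="PERSON")]  # "Douglas Adams"

ents_by_offset = dict()
for ent in doc.ents:
    # same "<start_char>_<end_char>" key as in read_training above
    ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent

start, end = "0", "13"  # offsets as read from the entity file (strings)
found_ent = ents_by_offset.get(start + "_" + end, None)
print(found_ent)  # Douglas Adams
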
@@ -332,7 +390,7 @@ def read_training(nlp, training_dir, dev, limit):
 
                     # repeat checking this condition in case an exception was thrown
                     if current_doc and (current_article_id == article_id):
-                        found_ent = ents_by_offset.get(start + "_" + end, None)
+                        found_ent = ents_by_offset.get(start + "_" + end, None)
                         if found_ent:
                             if found_ent.text != alias:
                                 skip_articles.add(article_id)
@@ -342,7 +400,26 @@ def read_training(nlp, training_dir, dev, limit):
                                 # currently feeding the gold data one entity per sentence at a time
                                 gold_start = int(start) - found_ent.sent.start_char
                                 gold_end = int(end) - found_ent.sent.start_char
-                                gold_entities = [(gold_start, gold_end, wp_title)]
+
+                                # add both positive and negative examples (in random order just to be sure)
+                                if kb:
+                                    gold_entities = {}
+                                    candidate_ids = [
+                                        c.entity_ for c in kb.get_candidates(alias)
+                                    ]
+                                    candidate_ids.append(
+                                        wd_id
+                                    )  # in case the KB doesn't have it
+                                    random.shuffle(candidate_ids)
+                                    for kb_id in candidate_ids:
+                                        entry = (gold_start, gold_end, kb_id)
+                                        if kb_id != wd_id:
+                                            gold_entities[entry] = 0.0
+                                        else:
+                                            gold_entities[entry] = 1.0
+                                else:
+                                    gold_entities = {}
+
                                 gold = GoldParse(doc=sent, links=gold_entities)
                                 data.append((sent, gold))
                                 total_entities += 1
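
To make the new negative-sampling branch concrete: when a `kb` is passed in, every candidate returned for the alias becomes one entry in the `links` dict handed to `GoldParse`, keyed by the mention offsets and scored 1.0 only for the gold Wikidata ID. A minimal sketch with an invented candidate list standing in for `kb.get_candidates(alias)`:

import random

gold_start, gold_end = 0, 13  # offsets of the mention within its sentence
wd_id = "Q42"                 # gold Wikidata ID from the entity file
candidate_ids = ["Q42", "Q5284", "Q76"]  # invented stand-in for kb.get_candidates(alias)

random.shuffle(candidate_ids)
gold_entities = {}
for kb_id in candidate_ids:
    entry = (gold_start, gold_end, kb_id)
    gold_entities[entry] = 1.0 if kb_id == wd_id else 0.0

print(gold_entities)
# e.g. {(0, 13, 'Q76'): 0.0, (0, 13, 'Q42'): 1.0, (0, 13, 'Q5284'): 0.0}
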
@@ -203,6 +203,15 @@ cdef class KnowledgeBase:
                 for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs)
                 if entry_index != 0]
 
+    def get_vector(self, unicode entity):
+        cdef hash_t entity_hash = self.vocab.strings.add(entity)
+
+        # Return an empty list if this entity is unknown in this KB
+        if entity_hash not in self._entry_index:
+            return []
+        entry_index = self._entry_index[entity_hash]
+
+        return self._vectors_table[self._entries[entry_index].vector_index]
+
     def dump(self, loc):
         cdef Writer writer = Writer(loc)
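
A short usage sketch of the new `get_vector` method (mirroring the unit test below; the entity ID and vector are arbitrary): a known entity returns its stored vector, an unknown one returns an empty list. The constructor and `add_entity` signatures are as of this commit.

from spacy.lang.en import English
from spacy.kb import KnowledgeBase

nlp = English()
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[8, 4, 3])

print(mykb.get_vector("Q1"))   # the stored vector, equal to [8, 4, 3] (see the test below)
print(mykb.get_vector("Q99"))  # [] -> unknown entity, nothing stored
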
@@ -15,20 +15,25 @@ def nlp():
 
 def test_kb_valid_entities(nlp):
     """Test the valid construction of a KB with 3 entities and two aliases"""
-    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
 
     # adding entities
-    mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
-    mykb.add_entity(entity='Q2', prob=0.5, entity_vector=[2])
-    mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
+    mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[8, 4, 3])
+    mykb.add_entity(entity="Q2", prob=0.5, entity_vector=[2, 1, 0])
+    mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[-1, -6, 5])
 
     # adding aliases
-    mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.2])
-    mykb.add_alias(alias='adam', entities=['Q2'], probabilities=[0.9])
+    mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.2])
+    mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
 
     # test the size of the corresponding KB
-    assert(mykb.get_size_entities() == 3)
-    assert(mykb.get_size_aliases() == 2)
+    assert mykb.get_size_entities() == 3
+    assert mykb.get_size_aliases() == 2
+
+    # test retrieval of the entity vectors
+    assert mykb.get_vector("Q1") == [8, 4, 3]
+    assert mykb.get_vector("Q2") == [2, 1, 0]
+    assert mykb.get_vector("Q3") == [-1, -6, 5]
 
 
 def test_kb_invalid_entities(nlp):
@@ -36,13 +41,15 @@ def test_kb_invalid_entities(nlp):
     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
 
     # adding entities
-    mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
-    mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
-    mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
+    mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1])
+    mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])
+    mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[3])
 
     # adding aliases - should fail because one of the given IDs is not valid
     with pytest.raises(ValueError):
-        mykb.add_alias(alias='douglas', entities=['Q2', 'Q342'], probabilities=[0.8, 0.2])
+        mykb.add_alias(
+            alias="douglas", entities=["Q2", "Q342"], probabilities=[0.8, 0.2]
+        )
 
 
 def test_kb_invalid_probabilities(nlp):
@@ -50,13 +57,13 @@ def test_kb_invalid_probabilities(nlp):
     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
 
     # adding entities
-    mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
-    mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
-    mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
+    mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1])
+    mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])
+    mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[3])
 
     # adding aliases - should fail because the sum of the probabilities exceeds 1
     with pytest.raises(ValueError):
-        mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.4])
+        mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.4])
 
 
 def test_kb_invalid_combination(nlp):
@@ -64,13 +71,15 @@ def test_kb_invalid_combination(nlp):
     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
 
     # adding entities
-    mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
-    mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
-    mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
+    mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1])
+    mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])
+    mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[3])
 
     # adding aliases - should fail because the entities and probabilities vectors are not of equal length
     with pytest.raises(ValueError):
-        mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.3, 0.4, 0.1])
+        mykb.add_alias(
+            alias="douglas", entities=["Q2", "Q3"], probabilities=[0.3, 0.4, 0.1]
+        )
 
 
 def test_kb_invalid_entity_vector(nlp):
@@ -78,11 +87,11 @@ def test_kb_invalid_entity_vector(nlp):
     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
 
     # adding entities
-    mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1, 2, 3])
+    mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1, 2, 3])
 
     # this should fail because the kb's expected entity vector length is 3
     with pytest.raises(ValueError):
-        mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
+        mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])
 
 
 def test_candidate_generation(nlp):
@@ -90,18 +99,18 @@ def test_candidate_generation(nlp):
     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
 
     # adding entities
-    mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
-    mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
-    mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
+    mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1])
+    mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])
+    mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[3])
 
     # adding aliases
-    mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.2])
-    mykb.add_alias(alias='adam', entities=['Q2'], probabilities=[0.9])
+    mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.2])
+    mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
 
     # test the size of the relevant candidates
-    assert(len(mykb.get_candidates('douglas')) == 2)
-    assert(len(mykb.get_candidates('adam')) == 1)
-    assert(len(mykb.get_candidates('shrubbery')) == 0)
+    assert len(mykb.get_candidates("douglas")) == 2
+    assert len(mykb.get_candidates("adam")) == 1
+    assert len(mykb.get_candidates("shrubbery")) == 0
 
 
 def test_preserving_links_asdoc(nlp):
@@ -109,24 +118,26 @@ def test_preserving_links_asdoc(nlp):
     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
 
     # adding entities
-    mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
-    mykb.add_entity(entity='Q2', prob=0.8, entity_vector=[1])
+    mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1])
+    mykb.add_entity(entity="Q2", prob=0.8, entity_vector=[1])
 
     # adding aliases
-    mykb.add_alias(alias='Boston', entities=['Q1'], probabilities=[0.7])
-    mykb.add_alias(alias='Denver', entities=['Q2'], probabilities=[0.6])
+    mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7])
+    mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6])
 
     # set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained)
     sentencizer = nlp.create_pipe("sentencizer")
     nlp.add_pipe(sentencizer)
 
     ruler = EntityRuler(nlp)
-    patterns = [{"label": "GPE", "pattern": "Boston"},
-                {"label": "GPE", "pattern": "Denver"}]
+    patterns = [
+        {"label": "GPE", "pattern": "Boston"},
+        {"label": "GPE", "pattern": "Denver"},
+    ]
     ruler.add_patterns(patterns)
     nlp.add_pipe(ruler)
 
-    el_pipe = nlp.create_pipe(name='entity_linker', config={"context_width": 64})
+    el_pipe = nlp.create_pipe(name="entity_linker", config={"context_width": 64})
     el_pipe.set_kb(mykb)
     el_pipe.begin_training()
     el_pipe.context_weight = 0