get vector functionality + unit test

svlandeg 2019-07-17 12:17:02 +02:00
parent a63d15a142
commit 4086c6ff60
3 changed files with 201 additions and 104 deletions


@@ -2,6 +2,7 @@
from __future__ import unicode_literals
import os
import random
import re
import bz2
import datetime
@@ -27,21 +28,23 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
Read the XML wikipedia data to parse out training data:
raw text data + positive instances
"""
title_regex = re.compile(r'(?<=<title>).*(?=</title>)')
id_regex = re.compile(r'(?<=<id>)\d*(?=</id>)')
title_regex = re.compile(r"(?<=<title>).*(?=</title>)")
id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
read_ids = set()
entityfile_loc = training_output / ENTITY_FILE
with open(entityfile_loc, mode="w", encoding='utf8') as entityfile:
with open(entityfile_loc, mode="w", encoding="utf8") as entityfile:
# write entity training header file
_write_training_entity(outputfile=entityfile,
article_id="article_id",
alias="alias",
entity="WD_id",
start="start",
end="end")
_write_training_entity(
outputfile=entityfile,
article_id="article_id",
alias="alias",
entity="WD_id",
start="start",
end="end",
)
with bz2.open(wikipedia_input, mode='rb') as file:
with bz2.open(wikipedia_input, mode="rb") as file:
line = file.readline()
cnt = 0
article_text = ""
@@ -51,7 +54,12 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
reading_revision = False
while line and (not limit or cnt < limit):
if cnt % 1000000 == 0:
print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump")
print(
datetime.datetime.now(),
"processed",
cnt,
"lines of Wikipedia dump",
)
clean_line = line.strip().decode("utf-8")
if clean_line == "<revision>":
@@ -69,12 +77,23 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
elif clean_line == "</page>":
if article_id:
try:
_process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text.strip(),
training_output)
_process_wp_text(
wp_to_id,
entityfile,
article_id,
article_title,
article_text.strip(),
training_output,
)
except Exception as e:
print("Error processing article", article_id, article_title, e)
print(
"Error processing article", article_id, article_title, e
)
else:
print("Done processing a page, but couldn't find an article_id ?", article_title)
print(
"Done processing a page, but couldn't find an article_id ?",
article_title,
)
article_text = ""
article_title = None
article_id = None
@@ -98,7 +117,9 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
if ids:
article_id = ids[0]
if article_id in read_ids:
print("Found duplicate article ID", article_id, clean_line) # This should never happen ...
print(
"Found duplicate article ID", article_id, clean_line
) # This should never happen ...
read_ids.add(article_id)
# read the title of this article (outside the revision portion of the document)
@@ -111,10 +132,12 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
cnt += 1
text_regex = re.compile(r'(?<=<text xml:space=\"preserve\">).*(?=</text)')
text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")
def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output):
def _process_wp_text(
wp_to_id, entityfile, article_id, article_title, article_text, training_output
):
found_entities = False
# ignore meta Wikipedia pages
@@ -141,11 +164,11 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
entity_buffer = ""
mention_buffer = ""
for index, letter in enumerate(clean_text):
if letter == '[':
if letter == "[":
open_read += 1
elif letter == ']':
elif letter == "]":
open_read -= 1
elif letter == '|':
elif letter == "|":
if reading_text:
final_text += letter
# switch from reading entity to mention in the [[entity|mention]] pattern
@@ -163,7 +186,7 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
elif reading_text:
final_text += letter
else:
raise ValueError("Not sure at point", clean_text[index-2:index+2])
raise ValueError("Not sure at point", clean_text[index - 2 : index + 2])
if open_read > 2:
reading_special_case = True
@@ -175,7 +198,7 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
# we just finished reading an entity
if open_read == 0 and not reading_text:
if '#' in entity_buffer or entity_buffer.startswith(':'):
if "#" in entity_buffer or entity_buffer.startswith(":"):
reading_special_case = True
# Ignore cases with nested structures like File: handles etc
if not reading_special_case:
@@ -185,12 +208,14 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
end = start + len(mention_buffer)
qid = wp_to_id.get(entity_buffer, None)
if qid:
_write_training_entity(outputfile=entityfile,
article_id=article_id,
alias=mention_buffer,
entity=qid,
start=start,
end=end)
_write_training_entity(
outputfile=entityfile,
article_id=article_id,
alias=mention_buffer,
entity=qid,
start=start,
end=end,
)
found_entities = True
final_text += mention_buffer
@@ -203,29 +228,35 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
reading_special_case = False
if found_entities:
_write_training_article(article_id=article_id, clean_text=final_text, training_output=training_output)
_write_training_article(
article_id=article_id,
clean_text=final_text,
training_output=training_output,
)
info_regex = re.compile(r'{[^{]*?}')
htlm_regex = re.compile(r'&lt;!--[^-]*--&gt;')
category_regex = re.compile(r'\[\[Category:[^\[]*]]')
file_regex = re.compile(r'\[\[File:[^[\]]+]]')
ref_regex = re.compile(r'&lt;ref.*?&gt;') # non-greedy
ref_2_regex = re.compile(r'&lt;/ref.*?&gt;') # non-greedy
info_regex = re.compile(r"{[^{]*?}")
htlm_regex = re.compile(r"&lt;!--[^-]*--&gt;")
category_regex = re.compile(r"\[\[Category:[^\[]*]]")
file_regex = re.compile(r"\[\[File:[^[\]]+]]")
ref_regex = re.compile(r"&lt;ref.*?&gt;") # non-greedy
ref_2_regex = re.compile(r"&lt;/ref.*?&gt;") # non-greedy
def _get_clean_wp_text(article_text):
clean_text = article_text.strip()
# remove bolding & italic markup
clean_text = clean_text.replace('\'\'\'', '')
clean_text = clean_text.replace('\'\'', '')
clean_text = clean_text.replace("'''", "")
clean_text = clean_text.replace("''", "")
# remove nested {{info}} statements by removing the inner/smallest ones first and iterating
try_again = True
previous_length = len(clean_text)
while try_again:
clean_text = info_regex.sub('', clean_text) # non-greedy match excluding a nested {
clean_text = info_regex.sub(
"", clean_text
) # non-greedy match excluding a nested {
if len(clean_text) < previous_length:
try_again = True
else:
@@ -233,14 +264,14 @@ def _get_clean_wp_text(article_text):
previous_length = len(clean_text)
# remove HTML comments
clean_text = htlm_regex.sub('', clean_text)
clean_text = htlm_regex.sub("", clean_text)
# remove Category and File statements
clean_text = category_regex.sub('', clean_text)
clean_text = file_regex.sub('', clean_text)
clean_text = category_regex.sub("", clean_text)
clean_text = file_regex.sub("", clean_text)
# remove multiple =
while '==' in clean_text:
while "==" in clean_text:
clean_text = clean_text.replace("==", "=")
clean_text = clean_text.replace(". =", ".")
@@ -249,43 +280,56 @@ def _get_clean_wp_text(article_text):
clean_text = clean_text.replace(" =", "")
# remove refs (non-greedy match)
clean_text = ref_regex.sub('', clean_text)
clean_text = ref_2_regex.sub('', clean_text)
clean_text = ref_regex.sub("", clean_text)
clean_text = ref_2_regex.sub("", clean_text)
# remove additional wikiformatting
clean_text = re.sub(r'&lt;blockquote&gt;', '', clean_text)
clean_text = re.sub(r'&lt;/blockquote&gt;', '', clean_text)
clean_text = re.sub(r"&lt;blockquote&gt;", "", clean_text)
clean_text = re.sub(r"&lt;/blockquote&gt;", "", clean_text)
# change special characters back to normal ones
clean_text = clean_text.replace(r'&lt;', '<')
clean_text = clean_text.replace(r'&gt;', '>')
clean_text = clean_text.replace(r'&quot;', '"')
clean_text = clean_text.replace(r'&amp;nbsp;', ' ')
clean_text = clean_text.replace(r'&amp;', '&')
clean_text = clean_text.replace(r"&lt;", "<")
clean_text = clean_text.replace(r"&gt;", ">")
clean_text = clean_text.replace(r"&quot;", '"')
clean_text = clean_text.replace(r"&amp;nbsp;", " ")
clean_text = clean_text.replace(r"&amp;", "&")
# remove multiple spaces
while ' ' in clean_text:
clean_text = clean_text.replace(' ', ' ')
while " " in clean_text:
clean_text = clean_text.replace(" ", " ")
return clean_text.strip()
def _write_training_article(article_id, clean_text, training_output):
file_loc = training_output / (str(article_id) + ".txt")
with open(file_loc, mode='w', encoding='utf8') as outputfile:
with open(file_loc, mode="w", encoding="utf8") as outputfile:
outputfile.write(clean_text)
def _write_training_entity(outputfile, article_id, alias, entity, start, end):
outputfile.write(article_id + "|" + alias + "|" + entity + "|" + str(start) + "|" + str(end) + "\n")
outputfile.write(
article_id
+ "|"
+ alias
+ "|"
+ entity
+ "|"
+ str(start)
+ "|"
+ str(end)
+ "\n"
)
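For reference, each record produced by _write_training_entity is a single pipe-delimited line in the entity file. The header row matches the one written in _process_wikipedia_texts above; the data row below uses invented values purely for illustration:

    article_id|alias|WD_id|start|end
    1234|Douglas Adams|Q42|15|28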
def is_dev(article_id):
return article_id.endswith("3")
def read_training(nlp, training_dir, dev, limit):
# This method provides training examples that correspond to the entity annotations found by the nlp object
def read_training(nlp, training_dir, dev, limit, kb=None):
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
When kb is provided, it will also include negative training examples generated by the candidate generator.
When kb=None, it will only include positive training examples."""
entityfile_loc = training_dir / ENTITY_FILE
data = []
@@ -296,24 +340,34 @@ def read_training(nlp, training_dir, dev, limit):
skip_articles = set()
total_entities = 0
with open(entityfile_loc, mode='r', encoding='utf8') as file:
with open(entityfile_loc, mode="r", encoding="utf8") as file:
for line in file:
if not limit or len(data) < limit:
fields = line.replace('\n', "").split(sep='|')
fields = line.replace("\n", "").split(sep="|")
article_id = fields[0]
alias = fields[1]
wp_title = fields[2]
wd_id = fields[2]
start = fields[3]
end = fields[4]
if dev == is_dev(article_id) and article_id != "article_id" and article_id not in skip_articles:
if (
dev == is_dev(article_id)
and article_id != "article_id"
and article_id not in skip_articles
):
if not current_doc or (current_article_id != article_id):
# parse the new article text
file_name = article_id + ".txt"
try:
with open(os.path.join(training_dir, file_name), mode="r", encoding='utf8') as f:
with open(
os.path.join(training_dir, file_name),
mode="r",
encoding="utf8",
) as f:
text = f.read()
if len(text) < 30000: # threshold for convenience / speed of processing
if (
len(text) < 30000
): # threshold for convenience / speed of processing
current_doc = nlp(text)
current_article_id = article_id
ents_by_offset = dict()
@@ -321,7 +375,11 @@ def read_training(nlp, training_dir, dev, limit):
sent_length = len(ent.sent)
# custom filtering to avoid too long or too short sentences
if 5 < sent_length < 100:
ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
ents_by_offset[
str(ent.start_char)
+ "_"
+ str(ent.end_char)
] = ent
else:
skip_articles.add(article_id)
current_doc = None
@@ -332,7 +390,7 @@ def read_training(nlp, training_dir, dev, limit):
# repeat checking this condition in case an exception was thrown
if current_doc and (current_article_id == article_id):
found_ent = ents_by_offset.get(start + "_" + end, None)
found_ent = ents_by_offset.get(start + "_" + end, None)
if found_ent:
if found_ent.text != alias:
skip_articles.add(article_id)
@@ -342,7 +400,26 @@ def read_training(nlp, training_dir, dev, limit):
# currently feeding the gold data one entity per sentence at a time
gold_start = int(start) - found_ent.sent.start_char
gold_end = int(end) - found_ent.sent.start_char
gold_entities = [(gold_start, gold_end, wp_title)]
# add both positive and negative examples (in random order just to be sure)
if kb:
gold_entities = {}
candidate_ids = [
c.entity_ for c in kb.get_candidates(alias)
]
candidate_ids.append(
wd_id
) # in case the KB doesn't have it
random.shuffle(candidate_ids)
for kb_id in candidate_ids:
entry = (gold_start, gold_end, kb_id)
if kb_id != wd_id:
gold_entities[entry] = 0.0
else:
gold_entities[entry] = 1.0
else:
gold_entities = {}
gold = GoldParse(doc=sent, links=gold_entities)
data.append((sent, gold))
total_entities += 1
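To make the positive/negative labelling described in the read_training docstring concrete, here is a minimal sketch of the links dictionary handed to GoldParse for one mention. The text, offsets, and entity ids are invented, and it assumes the development branch of spaCy that this commit targets:

    import spacy
    from spacy.gold import GoldParse

    nlp = spacy.blank("en")
    doc = nlp("Douglas Adams was a writer.")
    # one mention spanning characters 0-13, gold id Q42, one negative candidate Q5
    gold_entities = {
        (0, 13, "Q42"): 1.0,  # the annotated entity from the training file
        (0, 13, "Q5"): 0.0,   # a competing candidate returned by the KB
    }
    gold = GoldParse(doc=doc, links=gold_entities)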


@@ -203,6 +203,15 @@ cdef class KnowledgeBase:
for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs)
if entry_index != 0]
def get_vector(self, unicode entity):
cdef hash_t entity_hash = self.vocab.strings.add(entity)
# Return an empty list if this entity is unknown in this KB
if entity_hash not in self._entry_index:
return []
entry_index = self._entry_index[entity_hash]
return self._vectors_table[self._entries[entry_index].vector_index]
def dump(self, loc):
cdef Writer writer = Writer(loc)
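A minimal usage sketch of the new get_vector method, mirroring the unit test below; the entity id and vectors are illustrative, and it assumes the development branch of spaCy that this commit targets:

    from spacy.lang.en import English
    from spacy.kb import KnowledgeBase

    nlp = English()
    kb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
    kb.add_entity(entity="Q1", prob=0.9, entity_vector=[8, 4, 3])

    assert kb.get_vector("Q1") == [8, 4, 3]  # compares equal to the stored vector
    assert kb.get_vector("Q99") == []        # unknown entity: empty list, no exception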


@@ -15,20 +15,25 @@ def nlp():
def test_kb_valid_entities(nlp):
"""Test the valid construction of a KB with 3 entities and two aliases"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
# adding entities
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
mykb.add_entity(entity='Q2', prob=0.5, entity_vector=[2])
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[8, 4, 3])
mykb.add_entity(entity="Q2", prob=0.5, entity_vector=[2, 1, 0])
mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[-1, -6, 5])
# adding aliases
mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.2])
mykb.add_alias(alias='adam', entities=['Q2'], probabilities=[0.9])
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.2])
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
# test the size of the corresponding KB
assert(mykb.get_size_entities() == 3)
assert(mykb.get_size_aliases() == 2)
assert mykb.get_size_entities() == 3
assert mykb.get_size_aliases() == 2
# test retrieval of the entity vectors
assert mykb.get_vector("Q1") == [8, 4, 3]
assert mykb.get_vector("Q2") == [2, 1, 0]
assert mykb.get_vector("Q3") == [-1, -6, 5]
def test_kb_invalid_entities(nlp):
@@ -36,13 +41,15 @@ def test_kb_invalid_entities(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1])
mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])
mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[3])
# adding aliases - should fail because one of the given IDs is not valid
with pytest.raises(ValueError):
mykb.add_alias(alias='douglas', entities=['Q2', 'Q342'], probabilities=[0.8, 0.2])
mykb.add_alias(
alias="douglas", entities=["Q2", "Q342"], probabilities=[0.8, 0.2]
)
def test_kb_invalid_probabilities(nlp):
@@ -50,13 +57,13 @@ def test_kb_invalid_probabilities(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1])
mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])
mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[3])
# adding aliases - should fail because the sum of the probabilities exceeds 1
with pytest.raises(ValueError):
mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.4])
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.4])
def test_kb_invalid_combination(nlp):
@@ -64,13 +71,15 @@ def test_kb_invalid_combination(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1])
mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])
mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[3])
# adding aliases - should fail because the entities and probabilities vectors are not of equal length
with pytest.raises(ValueError):
mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.3, 0.4, 0.1])
mykb.add_alias(
alias="douglas", entities=["Q2", "Q3"], probabilities=[0.3, 0.4, 0.1]
)
def test_kb_invalid_entity_vector(nlp):
@@ -78,11 +87,11 @@ def test_kb_invalid_entity_vector(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
# adding entities
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1, 2, 3])
mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1, 2, 3])
# this should fail because the kb's expected entity vector length is 3
with pytest.raises(ValueError):
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])
def test_candidate_generation(nlp):
@@ -90,18 +99,18 @@ def test_candidate_generation(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1])
mykb.add_entity(entity="Q2", prob=0.2, entity_vector=[2])
mykb.add_entity(entity="Q3", prob=0.5, entity_vector=[3])
# adding aliases
mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.2])
mykb.add_alias(alias='adam', entities=['Q2'], probabilities=[0.9])
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.2])
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
# test the size of the relevant candidates
assert(len(mykb.get_candidates('douglas')) == 2)
assert(len(mykb.get_candidates('adam')) == 1)
assert(len(mykb.get_candidates('shrubbery')) == 0)
assert len(mykb.get_candidates("douglas")) == 2
assert len(mykb.get_candidates("adam")) == 1
assert len(mykb.get_candidates("shrubbery")) == 0
def test_preserving_links_asdoc(nlp):
@@ -109,24 +118,26 @@ def test_preserving_links_asdoc(nlp):
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
mykb.add_entity(entity='Q2', prob=0.8, entity_vector=[1])
mykb.add_entity(entity="Q1", prob=0.9, entity_vector=[1])
mykb.add_entity(entity="Q2", prob=0.8, entity_vector=[1])
# adding aliases
mykb.add_alias(alias='Boston', entities=['Q1'], probabilities=[0.7])
mykb.add_alias(alias='Denver', entities=['Q2'], probabilities=[0.6])
mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7])
mykb.add_alias(alias="Denver", entities=["Q2"], probabilities=[0.6])
# set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained)
sentencizer = nlp.create_pipe("sentencizer")
nlp.add_pipe(sentencizer)
ruler = EntityRuler(nlp)
patterns = [{"label": "GPE", "pattern": "Boston"},
{"label": "GPE", "pattern": "Denver"}]
patterns = [
{"label": "GPE", "pattern": "Boston"},
{"label": "GPE", "pattern": "Denver"},
]
ruler.add_patterns(patterns)
nlp.add_pipe(ruler)
el_pipe = nlp.create_pipe(name='entity_linker', config={"context_width": 64})
el_pipe = nlp.create_pipe(name="entity_linker", config={"context_width": 64})
el_pipe.set_kb(mykb)
el_pipe.begin_training()
el_pipe.context_weight = 0