mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
Merge pull request #3864 from svlandeg/feature/nel-wiki
Entity linking using Wikipedia & Wikidata
This commit is contained in:
commit
6ba5ddbd5f
0
bin/__init__.py
Normal file
0
bin/__init__.py
Normal file
|
@ -292,8 +292,8 @@ def evaluate(gold_ud, system_ud, deprel_weights=None, check_parse=True):
|
||||||
|
|
||||||
def spans_score(gold_spans, system_spans):
|
def spans_score(gold_spans, system_spans):
|
||||||
correct, gi, si = 0, 0, 0
|
correct, gi, si = 0, 0, 0
|
||||||
undersegmented = list()
|
undersegmented = []
|
||||||
oversegmented = list()
|
oversegmented = []
|
||||||
combo = 0
|
combo = 0
|
||||||
previous_end_si_earlier = False
|
previous_end_si_earlier = False
|
||||||
previous_end_gi_earlier = False
|
previous_end_gi_earlier = False
|
||||||
|
|
0
bin/wiki_entity_linking/__init__.py
Normal file
0
bin/wiki_entity_linking/__init__.py
Normal file
171
bin/wiki_entity_linking/kb_creator.py
Normal file
171
bin/wiki_entity_linking/kb_creator.py
Normal file
|
@ -0,0 +1,171 @@
|
||||||
|
# coding: utf-8
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
from .train_descriptions import EntityEncoder
|
||||||
|
from . import wikidata_processor as wd, wikipedia_processor as wp
|
||||||
|
from spacy.kb import KnowledgeBase
|
||||||
|
|
||||||
|
import csv
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
|
||||||
|
INPUT_DIM = 300 # dimension of pre-trained input vectors
|
||||||
|
DESC_WIDTH = 64 # dimension of output entity vectors
|
||||||
|
|
||||||
|
|
||||||
|
def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
|
||||||
|
entity_def_output, entity_descr_output,
|
||||||
|
count_input, prior_prob_input, wikidata_input):
|
||||||
|
# Create the knowledge base from Wikidata entries
|
||||||
|
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH)
|
||||||
|
|
||||||
|
# disable this part of the pipeline when rerunning the KB generation from preprocessed files
|
||||||
|
read_raw_data = True
|
||||||
|
|
||||||
|
if read_raw_data:
|
||||||
|
print()
|
||||||
|
print(" * _read_wikidata_entities", datetime.datetime.now())
|
||||||
|
title_to_id, id_to_descr = wd.read_wikidata_entities_json(wikidata_input)
|
||||||
|
|
||||||
|
# write the title-ID and ID-description mappings to file
|
||||||
|
_write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# read the mappings from file
|
||||||
|
title_to_id = get_entity_to_id(entity_def_output)
|
||||||
|
id_to_descr = get_id_to_description(entity_descr_output)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(" * _get_entity_frequencies", datetime.datetime.now())
|
||||||
|
print()
|
||||||
|
entity_frequencies = wp.get_all_frequencies(count_input=count_input)
|
||||||
|
|
||||||
|
# filter the entities for in the KB by frequency, because there's just too much data (8M entities) otherwise
|
||||||
|
filtered_title_to_id = dict()
|
||||||
|
entity_list = []
|
||||||
|
description_list = []
|
||||||
|
frequency_list = []
|
||||||
|
for title, entity in title_to_id.items():
|
||||||
|
freq = entity_frequencies.get(title, 0)
|
||||||
|
desc = id_to_descr.get(entity, None)
|
||||||
|
if desc and freq > min_entity_freq:
|
||||||
|
entity_list.append(entity)
|
||||||
|
description_list.append(desc)
|
||||||
|
frequency_list.append(freq)
|
||||||
|
filtered_title_to_id[title] = entity
|
||||||
|
|
||||||
|
print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()),
|
||||||
|
"titles with filter frequency", min_entity_freq)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(" * train entity encoder", datetime.datetime.now())
|
||||||
|
print()
|
||||||
|
encoder = EntityEncoder(nlp, INPUT_DIM, DESC_WIDTH)
|
||||||
|
encoder.train(description_list=description_list, to_print=True)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(" * get entity embeddings", datetime.datetime.now())
|
||||||
|
print()
|
||||||
|
embeddings = encoder.apply_encoder(description_list)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(" * adding", len(entity_list), "entities", datetime.datetime.now())
|
||||||
|
kb.set_entities(entity_list=entity_list, prob_list=frequency_list, vector_list=embeddings)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print(" * adding aliases", datetime.datetime.now())
|
||||||
|
print()
|
||||||
|
_add_aliases(kb, title_to_id=filtered_title_to_id,
|
||||||
|
max_entities_per_alias=max_entities_per_alias, min_occ=min_occ,
|
||||||
|
prior_prob_input=prior_prob_input)
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases())
|
||||||
|
|
||||||
|
print("done with kb", datetime.datetime.now())
|
||||||
|
return kb
|
||||||
|
|
||||||
|
|
||||||
|
def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr):
|
||||||
|
with open(entity_def_output, mode='w', encoding='utf8') as id_file:
|
||||||
|
id_file.write("WP_title" + "|" + "WD_id" + "\n")
|
||||||
|
for title, qid in title_to_id.items():
|
||||||
|
id_file.write(title + "|" + str(qid) + "\n")
|
||||||
|
|
||||||
|
with open(entity_descr_output, mode='w', encoding='utf8') as descr_file:
|
||||||
|
descr_file.write("WD_id" + "|" + "description" + "\n")
|
||||||
|
for qid, descr in id_to_descr.items():
|
||||||
|
descr_file.write(str(qid) + "|" + descr + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
def get_entity_to_id(entity_def_output):
|
||||||
|
entity_to_id = dict()
|
||||||
|
with open(entity_def_output, 'r', encoding='utf8') as csvfile:
|
||||||
|
csvreader = csv.reader(csvfile, delimiter='|')
|
||||||
|
# skip header
|
||||||
|
next(csvreader)
|
||||||
|
for row in csvreader:
|
||||||
|
entity_to_id[row[0]] = row[1]
|
||||||
|
return entity_to_id
|
||||||
|
|
||||||
|
|
||||||
|
def get_id_to_description(entity_descr_output):
|
||||||
|
id_to_desc = dict()
|
||||||
|
with open(entity_descr_output, 'r', encoding='utf8') as csvfile:
|
||||||
|
csvreader = csv.reader(csvfile, delimiter='|')
|
||||||
|
# skip header
|
||||||
|
next(csvreader)
|
||||||
|
for row in csvreader:
|
||||||
|
id_to_desc[row[0]] = row[1]
|
||||||
|
return id_to_desc
|
||||||
|
|
||||||
|
|
||||||
|
def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input):
|
||||||
|
wp_titles = title_to_id.keys()
|
||||||
|
|
||||||
|
# adding aliases with prior probabilities
|
||||||
|
# we can read this file sequentially, it's sorted by alias, and then by count
|
||||||
|
with open(prior_prob_input, mode='r', encoding='utf8') as prior_file:
|
||||||
|
# skip header
|
||||||
|
prior_file.readline()
|
||||||
|
line = prior_file.readline()
|
||||||
|
previous_alias = None
|
||||||
|
total_count = 0
|
||||||
|
counts = []
|
||||||
|
entities = []
|
||||||
|
while line:
|
||||||
|
splits = line.replace('\n', "").split(sep='|')
|
||||||
|
new_alias = splits[0]
|
||||||
|
count = int(splits[1])
|
||||||
|
entity = splits[2]
|
||||||
|
|
||||||
|
if new_alias != previous_alias and previous_alias:
|
||||||
|
# done reading the previous alias --> output
|
||||||
|
if len(entities) > 0:
|
||||||
|
selected_entities = []
|
||||||
|
prior_probs = []
|
||||||
|
for ent_count, ent_string in zip(counts, entities):
|
||||||
|
if ent_string in wp_titles:
|
||||||
|
wd_id = title_to_id[ent_string]
|
||||||
|
p_entity_givenalias = ent_count / total_count
|
||||||
|
selected_entities.append(wd_id)
|
||||||
|
prior_probs.append(p_entity_givenalias)
|
||||||
|
|
||||||
|
if selected_entities:
|
||||||
|
try:
|
||||||
|
kb.add_alias(alias=previous_alias, entities=selected_entities, probabilities=prior_probs)
|
||||||
|
except ValueError as e:
|
||||||
|
print(e)
|
||||||
|
total_count = 0
|
||||||
|
counts = []
|
||||||
|
entities = []
|
||||||
|
|
||||||
|
total_count += count
|
||||||
|
|
||||||
|
if len(entities) < max_entities_per_alias and count >= min_occ:
|
||||||
|
counts.append(count)
|
||||||
|
entities.append(entity)
|
||||||
|
previous_alias = new_alias
|
||||||
|
|
||||||
|
line = prior_file.readline()
|
||||||
|
|
121
bin/wiki_entity_linking/train_descriptions.py
Normal file
121
bin/wiki_entity_linking/train_descriptions.py
Normal file
|
@ -0,0 +1,121 @@
|
||||||
|
# coding: utf-8
|
||||||
|
from random import shuffle
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
from spacy._ml import zero_init, create_default_optimizer
|
||||||
|
from spacy.cli.pretrain import get_cossim_loss
|
||||||
|
|
||||||
|
from thinc.v2v import Model
|
||||||
|
from thinc.api import chain
|
||||||
|
from thinc.neural._classes.affine import Affine
|
||||||
|
|
||||||
|
|
||||||
|
class EntityEncoder:
|
||||||
|
"""
|
||||||
|
Train the embeddings of entity descriptions to fit a fixed-size entity vector (e.g. 64D).
|
||||||
|
This entity vector will be stored in the KB, for further downstream use in the entity model.
|
||||||
|
"""
|
||||||
|
|
||||||
|
DROP = 0
|
||||||
|
EPOCHS = 5
|
||||||
|
STOP_THRESHOLD = 0.04
|
||||||
|
|
||||||
|
BATCH_SIZE = 1000
|
||||||
|
|
||||||
|
def __init__(self, nlp, input_dim, desc_width):
|
||||||
|
self.nlp = nlp
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.desc_width = desc_width
|
||||||
|
|
||||||
|
def apply_encoder(self, description_list):
|
||||||
|
if self.encoder is None:
|
||||||
|
raise ValueError("Can not apply encoder before training it")
|
||||||
|
|
||||||
|
batch_size = 100000
|
||||||
|
|
||||||
|
start = 0
|
||||||
|
stop = min(batch_size, len(description_list))
|
||||||
|
encodings = []
|
||||||
|
|
||||||
|
while start < len(description_list):
|
||||||
|
docs = list(self.nlp.pipe(description_list[start:stop]))
|
||||||
|
doc_embeddings = [self._get_doc_embedding(doc) for doc in docs]
|
||||||
|
enc = self.encoder(np.asarray(doc_embeddings))
|
||||||
|
encodings.extend(enc.tolist())
|
||||||
|
|
||||||
|
start = start + batch_size
|
||||||
|
stop = min(stop + batch_size, len(description_list))
|
||||||
|
|
||||||
|
return encodings
|
||||||
|
|
||||||
|
def train(self, description_list, to_print=False):
|
||||||
|
processed, loss = self._train_model(description_list)
|
||||||
|
if to_print:
|
||||||
|
print("Trained on", processed, "entities across", self.EPOCHS, "epochs")
|
||||||
|
print("Final loss:", loss)
|
||||||
|
|
||||||
|
def _train_model(self, description_list):
|
||||||
|
# TODO: when loss gets too low, a 'mean of empty slice' warning is thrown by numpy
|
||||||
|
|
||||||
|
self._build_network(self.input_dim, self.desc_width)
|
||||||
|
|
||||||
|
processed = 0
|
||||||
|
loss = 1
|
||||||
|
descriptions = description_list.copy() # copy this list so that shuffling does not affect other functions
|
||||||
|
|
||||||
|
for i in range(self.EPOCHS):
|
||||||
|
shuffle(descriptions)
|
||||||
|
|
||||||
|
batch_nr = 0
|
||||||
|
start = 0
|
||||||
|
stop = min(self.BATCH_SIZE, len(descriptions))
|
||||||
|
|
||||||
|
while loss > self.STOP_THRESHOLD and start < len(descriptions):
|
||||||
|
batch = []
|
||||||
|
for descr in descriptions[start:stop]:
|
||||||
|
doc = self.nlp(descr)
|
||||||
|
doc_vector = self._get_doc_embedding(doc)
|
||||||
|
batch.append(doc_vector)
|
||||||
|
|
||||||
|
loss = self._update(batch)
|
||||||
|
print(i, batch_nr, loss)
|
||||||
|
processed += len(batch)
|
||||||
|
|
||||||
|
batch_nr += 1
|
||||||
|
start = start + self.BATCH_SIZE
|
||||||
|
stop = min(stop + self.BATCH_SIZE, len(descriptions))
|
||||||
|
|
||||||
|
return processed, loss
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_doc_embedding(doc):
|
||||||
|
indices = np.zeros((len(doc),), dtype="i")
|
||||||
|
for i, word in enumerate(doc):
|
||||||
|
if word.orth in doc.vocab.vectors.key2row:
|
||||||
|
indices[i] = doc.vocab.vectors.key2row[word.orth]
|
||||||
|
else:
|
||||||
|
indices[i] = 0
|
||||||
|
word_vectors = doc.vocab.vectors.data[indices]
|
||||||
|
doc_vector = np.mean(word_vectors, axis=0)
|
||||||
|
return doc_vector
|
||||||
|
|
||||||
|
def _build_network(self, orig_width, hidden_with):
|
||||||
|
with Model.define_operators({">>": chain}):
|
||||||
|
# very simple encoder-decoder model
|
||||||
|
self.encoder = (
|
||||||
|
Affine(hidden_with, orig_width)
|
||||||
|
)
|
||||||
|
self.model = self.encoder >> zero_init(Affine(orig_width, hidden_with, drop_factor=0.0))
|
||||||
|
self.sgd = create_default_optimizer(self.model.ops)
|
||||||
|
|
||||||
|
def _update(self, vectors):
|
||||||
|
predictions, bp_model = self.model.begin_update(np.asarray(vectors), drop=self.DROP)
|
||||||
|
loss, d_scores = self._get_loss(scores=predictions, golds=np.asarray(vectors))
|
||||||
|
bp_model(d_scores, sgd=self.sgd)
|
||||||
|
return loss / len(vectors)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _get_loss(golds, scores):
|
||||||
|
loss, gradients = get_cossim_loss(scores, golds)
|
||||||
|
return loss, gradients
|
353
bin/wiki_entity_linking/training_set_creator.py
Normal file
353
bin/wiki_entity_linking/training_set_creator.py
Normal file
|
@ -0,0 +1,353 @@
|
||||||
|
# coding: utf-8
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
import bz2
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
from spacy.gold import GoldParse
|
||||||
|
from bin.wiki_entity_linking import kb_creator
|
||||||
|
|
||||||
|
"""
|
||||||
|
Process Wikipedia interlinks to generate a training dataset for the EL algorithm.
|
||||||
|
Gold-standard entities are stored in one file in standoff format (by character offset).
|
||||||
|
"""
|
||||||
|
|
||||||
|
ENTITY_FILE = "gold_entities.csv"
|
||||||
|
|
||||||
|
|
||||||
|
def create_training(wikipedia_input, entity_def_input, training_output):
|
||||||
|
wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
|
||||||
|
_process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None)
|
||||||
|
|
||||||
|
|
||||||
|
def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None):
|
||||||
|
"""
|
||||||
|
Read the XML wikipedia data to parse out training data:
|
||||||
|
raw text data + positive instances
|
||||||
|
"""
|
||||||
|
title_regex = re.compile(r'(?<=<title>).*(?=</title>)')
|
||||||
|
id_regex = re.compile(r'(?<=<id>)\d*(?=</id>)')
|
||||||
|
|
||||||
|
read_ids = set()
|
||||||
|
entityfile_loc = training_output / ENTITY_FILE
|
||||||
|
with open(entityfile_loc, mode="w", encoding='utf8') as entityfile:
|
||||||
|
# write entity training header file
|
||||||
|
_write_training_entity(outputfile=entityfile,
|
||||||
|
article_id="article_id",
|
||||||
|
alias="alias",
|
||||||
|
entity="WD_id",
|
||||||
|
start="start",
|
||||||
|
end="end")
|
||||||
|
|
||||||
|
with bz2.open(wikipedia_input, mode='rb') as file:
|
||||||
|
line = file.readline()
|
||||||
|
cnt = 0
|
||||||
|
article_text = ""
|
||||||
|
article_title = None
|
||||||
|
article_id = None
|
||||||
|
reading_text = False
|
||||||
|
reading_revision = False
|
||||||
|
while line and (not limit or cnt < limit):
|
||||||
|
if cnt % 1000000 == 0:
|
||||||
|
print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump")
|
||||||
|
clean_line = line.strip().decode("utf-8")
|
||||||
|
|
||||||
|
if clean_line == "<revision>":
|
||||||
|
reading_revision = True
|
||||||
|
elif clean_line == "</revision>":
|
||||||
|
reading_revision = False
|
||||||
|
|
||||||
|
# Start reading new page
|
||||||
|
if clean_line == "<page>":
|
||||||
|
article_text = ""
|
||||||
|
article_title = None
|
||||||
|
article_id = None
|
||||||
|
|
||||||
|
# finished reading this page
|
||||||
|
elif clean_line == "</page>":
|
||||||
|
if article_id:
|
||||||
|
try:
|
||||||
|
_process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text.strip(),
|
||||||
|
training_output)
|
||||||
|
except Exception as e:
|
||||||
|
print("Error processing article", article_id, article_title, e)
|
||||||
|
else:
|
||||||
|
print("Done processing a page, but couldn't find an article_id ?", article_title)
|
||||||
|
article_text = ""
|
||||||
|
article_title = None
|
||||||
|
article_id = None
|
||||||
|
reading_text = False
|
||||||
|
reading_revision = False
|
||||||
|
|
||||||
|
# start reading text within a page
|
||||||
|
if "<text" in clean_line:
|
||||||
|
reading_text = True
|
||||||
|
|
||||||
|
if reading_text:
|
||||||
|
article_text += " " + clean_line
|
||||||
|
|
||||||
|
# stop reading text within a page (we assume a new page doesn't start on the same line)
|
||||||
|
if "</text" in clean_line:
|
||||||
|
reading_text = False
|
||||||
|
|
||||||
|
# read the ID of this article (outside the revision portion of the document)
|
||||||
|
if not reading_revision:
|
||||||
|
ids = id_regex.search(clean_line)
|
||||||
|
if ids:
|
||||||
|
article_id = ids[0]
|
||||||
|
if article_id in read_ids:
|
||||||
|
print("Found duplicate article ID", article_id, clean_line) # This should never happen ...
|
||||||
|
read_ids.add(article_id)
|
||||||
|
|
||||||
|
# read the title of this article (outside the revision portion of the document)
|
||||||
|
if not reading_revision:
|
||||||
|
titles = title_regex.search(clean_line)
|
||||||
|
if titles:
|
||||||
|
article_title = titles[0].strip()
|
||||||
|
|
||||||
|
line = file.readline()
|
||||||
|
cnt += 1
|
||||||
|
|
||||||
|
|
||||||
|
text_regex = re.compile(r'(?<=<text xml:space=\"preserve\">).*(?=</text)')
|
||||||
|
|
||||||
|
|
||||||
|
def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text, training_output):
|
||||||
|
found_entities = False
|
||||||
|
|
||||||
|
# ignore meta Wikipedia pages
|
||||||
|
if article_title.startswith("Wikipedia:"):
|
||||||
|
return
|
||||||
|
|
||||||
|
# remove the text tags
|
||||||
|
text = text_regex.search(article_text).group(0)
|
||||||
|
|
||||||
|
# stop processing if this is a redirect page
|
||||||
|
if text.startswith("#REDIRECT"):
|
||||||
|
return
|
||||||
|
|
||||||
|
# get the raw text without markup etc, keeping only interwiki links
|
||||||
|
clean_text = _get_clean_wp_text(text)
|
||||||
|
|
||||||
|
# read the text char by char to get the right offsets for the interwiki links
|
||||||
|
final_text = ""
|
||||||
|
open_read = 0
|
||||||
|
reading_text = True
|
||||||
|
reading_entity = False
|
||||||
|
reading_mention = False
|
||||||
|
reading_special_case = False
|
||||||
|
entity_buffer = ""
|
||||||
|
mention_buffer = ""
|
||||||
|
for index, letter in enumerate(clean_text):
|
||||||
|
if letter == '[':
|
||||||
|
open_read += 1
|
||||||
|
elif letter == ']':
|
||||||
|
open_read -= 1
|
||||||
|
elif letter == '|':
|
||||||
|
if reading_text:
|
||||||
|
final_text += letter
|
||||||
|
# switch from reading entity to mention in the [[entity|mention]] pattern
|
||||||
|
elif reading_entity:
|
||||||
|
reading_text = False
|
||||||
|
reading_entity = False
|
||||||
|
reading_mention = True
|
||||||
|
else:
|
||||||
|
reading_special_case = True
|
||||||
|
else:
|
||||||
|
if reading_entity:
|
||||||
|
entity_buffer += letter
|
||||||
|
elif reading_mention:
|
||||||
|
mention_buffer += letter
|
||||||
|
elif reading_text:
|
||||||
|
final_text += letter
|
||||||
|
else:
|
||||||
|
raise ValueError("Not sure at point", clean_text[index-2:index+2])
|
||||||
|
|
||||||
|
if open_read > 2:
|
||||||
|
reading_special_case = True
|
||||||
|
|
||||||
|
if open_read == 2 and reading_text:
|
||||||
|
reading_text = False
|
||||||
|
reading_entity = True
|
||||||
|
reading_mention = False
|
||||||
|
|
||||||
|
# we just finished reading an entity
|
||||||
|
if open_read == 0 and not reading_text:
|
||||||
|
if '#' in entity_buffer or entity_buffer.startswith(':'):
|
||||||
|
reading_special_case = True
|
||||||
|
# Ignore cases with nested structures like File: handles etc
|
||||||
|
if not reading_special_case:
|
||||||
|
if not mention_buffer:
|
||||||
|
mention_buffer = entity_buffer
|
||||||
|
start = len(final_text)
|
||||||
|
end = start + len(mention_buffer)
|
||||||
|
qid = wp_to_id.get(entity_buffer, None)
|
||||||
|
if qid:
|
||||||
|
_write_training_entity(outputfile=entityfile,
|
||||||
|
article_id=article_id,
|
||||||
|
alias=mention_buffer,
|
||||||
|
entity=qid,
|
||||||
|
start=start,
|
||||||
|
end=end)
|
||||||
|
found_entities = True
|
||||||
|
final_text += mention_buffer
|
||||||
|
|
||||||
|
entity_buffer = ""
|
||||||
|
mention_buffer = ""
|
||||||
|
|
||||||
|
reading_text = True
|
||||||
|
reading_entity = False
|
||||||
|
reading_mention = False
|
||||||
|
reading_special_case = False
|
||||||
|
|
||||||
|
if found_entities:
|
||||||
|
_write_training_article(article_id=article_id, clean_text=final_text, training_output=training_output)
|
||||||
|
|
||||||
|
|
||||||
|
info_regex = re.compile(r'{[^{]*?}')
|
||||||
|
htlm_regex = re.compile(r'<!--[^-]*-->')
|
||||||
|
category_regex = re.compile(r'\[\[Category:[^\[]*]]')
|
||||||
|
file_regex = re.compile(r'\[\[File:[^[\]]+]]')
|
||||||
|
ref_regex = re.compile(r'<ref.*?>') # non-greedy
|
||||||
|
ref_2_regex = re.compile(r'</ref.*?>') # non-greedy
|
||||||
|
|
||||||
|
|
||||||
|
def _get_clean_wp_text(article_text):
|
||||||
|
clean_text = article_text.strip()
|
||||||
|
|
||||||
|
# remove bolding & italic markup
|
||||||
|
clean_text = clean_text.replace('\'\'\'', '')
|
||||||
|
clean_text = clean_text.replace('\'\'', '')
|
||||||
|
|
||||||
|
# remove nested {{info}} statements by removing the inner/smallest ones first and iterating
|
||||||
|
try_again = True
|
||||||
|
previous_length = len(clean_text)
|
||||||
|
while try_again:
|
||||||
|
clean_text = info_regex.sub('', clean_text) # non-greedy match excluding a nested {
|
||||||
|
if len(clean_text) < previous_length:
|
||||||
|
try_again = True
|
||||||
|
else:
|
||||||
|
try_again = False
|
||||||
|
previous_length = len(clean_text)
|
||||||
|
|
||||||
|
# remove HTML comments
|
||||||
|
clean_text = htlm_regex.sub('', clean_text)
|
||||||
|
|
||||||
|
# remove Category and File statements
|
||||||
|
clean_text = category_regex.sub('', clean_text)
|
||||||
|
clean_text = file_regex.sub('', clean_text)
|
||||||
|
|
||||||
|
# remove multiple =
|
||||||
|
while '==' in clean_text:
|
||||||
|
clean_text = clean_text.replace("==", "=")
|
||||||
|
|
||||||
|
clean_text = clean_text.replace(". =", ".")
|
||||||
|
clean_text = clean_text.replace(" = ", ". ")
|
||||||
|
clean_text = clean_text.replace("= ", ".")
|
||||||
|
clean_text = clean_text.replace(" =", "")
|
||||||
|
|
||||||
|
# remove refs (non-greedy match)
|
||||||
|
clean_text = ref_regex.sub('', clean_text)
|
||||||
|
clean_text = ref_2_regex.sub('', clean_text)
|
||||||
|
|
||||||
|
# remove additional wikiformatting
|
||||||
|
clean_text = re.sub(r'<blockquote>', '', clean_text)
|
||||||
|
clean_text = re.sub(r'</blockquote>', '', clean_text)
|
||||||
|
|
||||||
|
# change special characters back to normal ones
|
||||||
|
clean_text = clean_text.replace(r'<', '<')
|
||||||
|
clean_text = clean_text.replace(r'>', '>')
|
||||||
|
clean_text = clean_text.replace(r'"', '"')
|
||||||
|
clean_text = clean_text.replace(r'&nbsp;', ' ')
|
||||||
|
clean_text = clean_text.replace(r'&', '&')
|
||||||
|
|
||||||
|
# remove multiple spaces
|
||||||
|
while ' ' in clean_text:
|
||||||
|
clean_text = clean_text.replace(' ', ' ')
|
||||||
|
|
||||||
|
return clean_text.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _write_training_article(article_id, clean_text, training_output):
|
||||||
|
file_loc = training_output / str(article_id) + ".txt"
|
||||||
|
with open(file_loc, mode='w', encoding='utf8') as outputfile:
|
||||||
|
outputfile.write(clean_text)
|
||||||
|
|
||||||
|
|
||||||
|
def _write_training_entity(outputfile, article_id, alias, entity, start, end):
|
||||||
|
outputfile.write(article_id + "|" + alias + "|" + entity + "|" + str(start) + "|" + str(end) + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
def is_dev(article_id):
|
||||||
|
return article_id.endswith("3")
|
||||||
|
|
||||||
|
|
||||||
|
def read_training(nlp, training_dir, dev, limit):
|
||||||
|
# This method provides training examples that correspond to the entity annotations found by the nlp object
|
||||||
|
entityfile_loc = training_dir / ENTITY_FILE
|
||||||
|
data = []
|
||||||
|
|
||||||
|
# assume the data is written sequentially, so we can reuse the article docs
|
||||||
|
current_article_id = None
|
||||||
|
current_doc = None
|
||||||
|
ents_by_offset = dict()
|
||||||
|
skip_articles = set()
|
||||||
|
total_entities = 0
|
||||||
|
|
||||||
|
with open(entityfile_loc, mode='r', encoding='utf8') as file:
|
||||||
|
for line in file:
|
||||||
|
if not limit or len(data) < limit:
|
||||||
|
fields = line.replace('\n', "").split(sep='|')
|
||||||
|
article_id = fields[0]
|
||||||
|
alias = fields[1]
|
||||||
|
wp_title = fields[2]
|
||||||
|
start = fields[3]
|
||||||
|
end = fields[4]
|
||||||
|
|
||||||
|
if dev == is_dev(article_id) and article_id != "article_id" and article_id not in skip_articles:
|
||||||
|
if not current_doc or (current_article_id != article_id):
|
||||||
|
# parse the new article text
|
||||||
|
file_name = article_id + ".txt"
|
||||||
|
try:
|
||||||
|
with open(os.path.join(training_dir, file_name), mode="r", encoding='utf8') as f:
|
||||||
|
text = f.read()
|
||||||
|
if len(text) < 30000: # threshold for convenience / speed of processing
|
||||||
|
current_doc = nlp(text)
|
||||||
|
current_article_id = article_id
|
||||||
|
ents_by_offset = dict()
|
||||||
|
for ent in current_doc.ents:
|
||||||
|
sent_length = len(ent.sent)
|
||||||
|
# custom filtering to avoid too long or too short sentences
|
||||||
|
if 5 < sent_length < 100:
|
||||||
|
ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
|
||||||
|
else:
|
||||||
|
skip_articles.add(article_id)
|
||||||
|
current_doc = None
|
||||||
|
except Exception as e:
|
||||||
|
print("Problem parsing article", article_id, e)
|
||||||
|
skip_articles.add(article_id)
|
||||||
|
raise e
|
||||||
|
|
||||||
|
# repeat checking this condition in case an exception was thrown
|
||||||
|
if current_doc and (current_article_id == article_id):
|
||||||
|
found_ent = ents_by_offset.get(start + "_" + end, None)
|
||||||
|
if found_ent:
|
||||||
|
if found_ent.text != alias:
|
||||||
|
skip_articles.add(article_id)
|
||||||
|
current_doc = None
|
||||||
|
else:
|
||||||
|
sent = found_ent.sent.as_doc()
|
||||||
|
# currently feeding the gold data one entity per sentence at a time
|
||||||
|
gold_start = int(start) - found_ent.sent.start_char
|
||||||
|
gold_end = int(end) - found_ent.sent.start_char
|
||||||
|
gold_entities = [(gold_start, gold_end, wp_title)]
|
||||||
|
gold = GoldParse(doc=sent, links=gold_entities)
|
||||||
|
data.append((sent, gold))
|
||||||
|
total_entities += 1
|
||||||
|
if len(data) % 2500 == 0:
|
||||||
|
print(" -read", total_entities, "entities")
|
||||||
|
|
||||||
|
print(" -read", total_entities, "entities")
|
||||||
|
return data
|
119
bin/wiki_entity_linking/wikidata_processor.py
Normal file
119
bin/wiki_entity_linking/wikidata_processor.py
Normal file
|
@ -0,0 +1,119 @@
|
||||||
|
# coding: utf-8
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import bz2
|
||||||
|
import json
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
|
||||||
|
def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False):
|
||||||
|
# Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines.
|
||||||
|
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
|
||||||
|
|
||||||
|
lang = 'en'
|
||||||
|
site_filter = 'enwiki'
|
||||||
|
|
||||||
|
# properties filter (currently disabled to get ALL data)
|
||||||
|
prop_filter = dict()
|
||||||
|
# prop_filter = {'P31': {'Q5', 'Q15632617'}} # currently defined as OR: one property suffices to be selected
|
||||||
|
|
||||||
|
title_to_id = dict()
|
||||||
|
id_to_descr = dict()
|
||||||
|
|
||||||
|
# parse appropriate fields - depending on what we need in the KB
|
||||||
|
parse_properties = False
|
||||||
|
parse_sitelinks = True
|
||||||
|
parse_labels = False
|
||||||
|
parse_descriptions = True
|
||||||
|
parse_aliases = False
|
||||||
|
parse_claims = False
|
||||||
|
|
||||||
|
with bz2.open(wikidata_file, mode='rb') as file:
|
||||||
|
line = file.readline()
|
||||||
|
cnt = 0
|
||||||
|
while line and (not limit or cnt < limit):
|
||||||
|
if cnt % 500000 == 0:
|
||||||
|
print(datetime.datetime.now(), "processed", cnt, "lines of WikiData dump")
|
||||||
|
clean_line = line.strip()
|
||||||
|
if clean_line.endswith(b","):
|
||||||
|
clean_line = clean_line[:-1]
|
||||||
|
if len(clean_line) > 1:
|
||||||
|
obj = json.loads(clean_line)
|
||||||
|
entry_type = obj["type"]
|
||||||
|
|
||||||
|
if entry_type == "item":
|
||||||
|
# filtering records on their properties (currently disabled to get ALL data)
|
||||||
|
# keep = False
|
||||||
|
keep = True
|
||||||
|
|
||||||
|
claims = obj["claims"]
|
||||||
|
if parse_claims:
|
||||||
|
for prop, value_set in prop_filter.items():
|
||||||
|
claim_property = claims.get(prop, None)
|
||||||
|
if claim_property:
|
||||||
|
for cp in claim_property:
|
||||||
|
cp_id = cp['mainsnak'].get('datavalue', {}).get('value', {}).get('id')
|
||||||
|
cp_rank = cp['rank']
|
||||||
|
if cp_rank != "deprecated" and cp_id in value_set:
|
||||||
|
keep = True
|
||||||
|
|
||||||
|
if keep:
|
||||||
|
unique_id = obj["id"]
|
||||||
|
|
||||||
|
if to_print:
|
||||||
|
print("ID:", unique_id)
|
||||||
|
print("type:", entry_type)
|
||||||
|
|
||||||
|
# parsing all properties that refer to other entities
|
||||||
|
if parse_properties:
|
||||||
|
for prop, claim_property in claims.items():
|
||||||
|
cp_dicts = [cp['mainsnak']['datavalue'].get('value') for cp in claim_property
|
||||||
|
if cp['mainsnak'].get('datavalue')]
|
||||||
|
cp_values = [cp_dict.get('id') for cp_dict in cp_dicts if isinstance(cp_dict, dict)
|
||||||
|
if cp_dict.get('id') is not None]
|
||||||
|
if cp_values:
|
||||||
|
if to_print:
|
||||||
|
print("prop:", prop, cp_values)
|
||||||
|
|
||||||
|
found_link = False
|
||||||
|
if parse_sitelinks:
|
||||||
|
site_value = obj["sitelinks"].get(site_filter, None)
|
||||||
|
if site_value:
|
||||||
|
site = site_value['title']
|
||||||
|
if to_print:
|
||||||
|
print(site_filter, ":", site)
|
||||||
|
title_to_id[site] = unique_id
|
||||||
|
found_link = True
|
||||||
|
|
||||||
|
if parse_labels:
|
||||||
|
labels = obj["labels"]
|
||||||
|
if labels:
|
||||||
|
lang_label = labels.get(lang, None)
|
||||||
|
if lang_label:
|
||||||
|
if to_print:
|
||||||
|
print("label (" + lang + "):", lang_label["value"])
|
||||||
|
|
||||||
|
if found_link and parse_descriptions:
|
||||||
|
descriptions = obj["descriptions"]
|
||||||
|
if descriptions:
|
||||||
|
lang_descr = descriptions.get(lang, None)
|
||||||
|
if lang_descr:
|
||||||
|
if to_print:
|
||||||
|
print("description (" + lang + "):", lang_descr["value"])
|
||||||
|
id_to_descr[unique_id] = lang_descr["value"]
|
||||||
|
|
||||||
|
if parse_aliases:
|
||||||
|
aliases = obj["aliases"]
|
||||||
|
if aliases:
|
||||||
|
lang_aliases = aliases.get(lang, None)
|
||||||
|
if lang_aliases:
|
||||||
|
for item in lang_aliases:
|
||||||
|
if to_print:
|
||||||
|
print("alias (" + lang + "):", item["value"])
|
||||||
|
|
||||||
|
if to_print:
|
||||||
|
print()
|
||||||
|
line = file.readline()
|
||||||
|
cnt += 1
|
||||||
|
|
||||||
|
return title_to_id, id_to_descr
|
182
bin/wiki_entity_linking/wikipedia_processor.py
Normal file
182
bin/wiki_entity_linking/wikipedia_processor.py
Normal file
|
@ -0,0 +1,182 @@
|
||||||
|
# coding: utf-8
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import re
|
||||||
|
import bz2
|
||||||
|
import csv
|
||||||
|
import datetime
|
||||||
|
|
||||||
|
"""
|
||||||
|
Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions.
|
||||||
|
Write these results to file for downstream KB and training data generation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
map_alias_to_link = dict()
|
||||||
|
|
||||||
|
# these will/should be matched ignoring case
|
||||||
|
wiki_namespaces = ["b", "betawikiversity", "Book", "c", "Category", "Commons",
|
||||||
|
"d", "dbdump", "download", "Draft", "Education", "Foundation",
|
||||||
|
"Gadget", "Gadget definition", "gerrit", "File", "Help", "Image", "Incubator",
|
||||||
|
"m", "mail", "mailarchive", "media", "MediaWiki", "MediaWiki talk", "Mediawikiwiki",
|
||||||
|
"MediaZilla", "Meta", "Metawikipedia", "Module",
|
||||||
|
"mw", "n", "nost", "oldwikisource", "outreach", "outreachwiki", "otrs", "OTRSwiki",
|
||||||
|
"Portal", "phab", "Phabricator", "Project", "q", "quality", "rev",
|
||||||
|
"s", "spcom", "Special", "species", "Strategy", "sulutil", "svn",
|
||||||
|
"Talk", "Template", "Template talk", "Testwiki", "ticket", "TimedText", "Toollabs", "tools",
|
||||||
|
"tswiki", "User", "User talk", "v", "voy",
|
||||||
|
"w", "Wikibooks", "Wikidata", "wikiHow", "Wikinvest", "wikilivres", "Wikimedia", "Wikinews",
|
||||||
|
"Wikipedia", "Wikipedia talk", "Wikiquote", "Wikisource", "Wikispecies", "Wikitech",
|
||||||
|
"Wikiversity", "Wikivoyage", "wikt", "wiktionary", "wmf", "wmania", "WP"]
|
||||||
|
|
||||||
|
# find the links
|
||||||
|
link_regex = re.compile(r'\[\[[^\[\]]*\]\]')
|
||||||
|
|
||||||
|
# match on interwiki links, e.g. `en:` or `:fr:`
|
||||||
|
ns_regex = r":?" + "[a-z][a-z]" + ":"
|
||||||
|
|
||||||
|
# match on Namespace: optionally preceded by a :
|
||||||
|
for ns in wiki_namespaces:
|
||||||
|
ns_regex += "|" + ":?" + ns + ":"
|
||||||
|
|
||||||
|
ns_regex = re.compile(ns_regex, re.IGNORECASE)
|
||||||
|
|
||||||
|
|
||||||
|
def read_wikipedia_prior_probs(wikipedia_input, prior_prob_output):
|
||||||
|
"""
|
||||||
|
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
|
||||||
|
The full file takes about 2h to parse 1100M lines.
|
||||||
|
It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from.
|
||||||
|
"""
|
||||||
|
with bz2.open(wikipedia_input, mode='rb') as file:
|
||||||
|
line = file.readline()
|
||||||
|
cnt = 0
|
||||||
|
while line:
|
||||||
|
if cnt % 5000000 == 0:
|
||||||
|
print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump")
|
||||||
|
clean_line = line.strip().decode("utf-8")
|
||||||
|
|
||||||
|
aliases, entities, normalizations = get_wp_links(clean_line)
|
||||||
|
for alias, entity, norm in zip(aliases, entities, normalizations):
|
||||||
|
_store_alias(alias, entity, normalize_alias=norm, normalize_entity=True)
|
||||||
|
_store_alias(alias, entity, normalize_alias=norm, normalize_entity=True)
|
||||||
|
|
||||||
|
line = file.readline()
|
||||||
|
cnt += 1
|
||||||
|
|
||||||
|
# write all aliases and their entities and count occurrences to file
|
||||||
|
with open(prior_prob_output, mode='w', encoding='utf8') as outputfile:
|
||||||
|
outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n")
|
||||||
|
for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]):
|
||||||
|
for entity, count in sorted(alias_dict.items(), key=lambda x: x[1], reverse=True):
|
||||||
|
outputfile.write(alias + "|" + str(count) + "|" + entity + "\n")
|
||||||
|
|
||||||
|
|
||||||
|
def _store_alias(alias, entity, normalize_alias=False, normalize_entity=True):
|
||||||
|
alias = alias.strip()
|
||||||
|
entity = entity.strip()
|
||||||
|
|
||||||
|
# remove everything after # as this is not part of the title but refers to a specific paragraph
|
||||||
|
if normalize_entity:
|
||||||
|
# wikipedia titles are always capitalized
|
||||||
|
entity = _capitalize_first(entity.split("#")[0])
|
||||||
|
if normalize_alias:
|
||||||
|
alias = alias.split("#")[0]
|
||||||
|
|
||||||
|
if alias and entity:
|
||||||
|
alias_dict = map_alias_to_link.get(alias, dict())
|
||||||
|
entity_count = alias_dict.get(entity, 0)
|
||||||
|
alias_dict[entity] = entity_count + 1
|
||||||
|
map_alias_to_link[alias] = alias_dict
|
||||||
|
|
||||||
|
|
||||||
|
def get_wp_links(text):
|
||||||
|
aliases = []
|
||||||
|
entities = []
|
||||||
|
normalizations = []
|
||||||
|
|
||||||
|
matches = link_regex.findall(text)
|
||||||
|
for match in matches:
|
||||||
|
match = match[2:][:-2].replace("_", " ").strip()
|
||||||
|
|
||||||
|
if ns_regex.match(match):
|
||||||
|
pass # ignore namespaces at the beginning of the string
|
||||||
|
|
||||||
|
# this is a simple [[link]], with the alias the same as the mention
|
||||||
|
elif "|" not in match:
|
||||||
|
aliases.append(match)
|
||||||
|
entities.append(match)
|
||||||
|
normalizations.append(True)
|
||||||
|
|
||||||
|
# in wiki format, the link is written as [[entity|alias]]
|
||||||
|
else:
|
||||||
|
splits = match.split("|")
|
||||||
|
entity = splits[0].strip()
|
||||||
|
alias = splits[1].strip()
|
||||||
|
# specific wiki format [[alias (specification)|]]
|
||||||
|
if len(alias) == 0 and "(" in entity:
|
||||||
|
alias = entity.split("(")[0]
|
||||||
|
aliases.append(alias)
|
||||||
|
entities.append(entity)
|
||||||
|
normalizations.append(False)
|
||||||
|
else:
|
||||||
|
aliases.append(alias)
|
||||||
|
entities.append(entity)
|
||||||
|
normalizations.append(False)
|
||||||
|
|
||||||
|
return aliases, entities, normalizations
|
||||||
|
|
||||||
|
|
||||||
|
def _capitalize_first(text):
|
||||||
|
if not text:
|
||||||
|
return None
|
||||||
|
result = text[0].capitalize()
|
||||||
|
if len(result) > 0:
|
||||||
|
result += text[1:]
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def write_entity_counts(prior_prob_input, count_output, to_print=False):
|
||||||
|
# Write entity counts for quick access later
|
||||||
|
entity_to_count = dict()
|
||||||
|
total_count = 0
|
||||||
|
|
||||||
|
with open(prior_prob_input, mode='r', encoding='utf8') as prior_file:
|
||||||
|
# skip header
|
||||||
|
prior_file.readline()
|
||||||
|
line = prior_file.readline()
|
||||||
|
|
||||||
|
while line:
|
||||||
|
splits = line.replace('\n', "").split(sep='|')
|
||||||
|
# alias = splits[0]
|
||||||
|
count = int(splits[1])
|
||||||
|
entity = splits[2]
|
||||||
|
|
||||||
|
current_count = entity_to_count.get(entity, 0)
|
||||||
|
entity_to_count[entity] = current_count + count
|
||||||
|
|
||||||
|
total_count += count
|
||||||
|
|
||||||
|
line = prior_file.readline()
|
||||||
|
|
||||||
|
with open(count_output, mode='w', encoding='utf8') as entity_file:
|
||||||
|
entity_file.write("entity" + "|" + "count" + "\n")
|
||||||
|
for entity, count in entity_to_count.items():
|
||||||
|
entity_file.write(entity + "|" + str(count) + "\n")
|
||||||
|
|
||||||
|
if to_print:
|
||||||
|
for entity, count in entity_to_count.items():
|
||||||
|
print("Entity count:", entity, count)
|
||||||
|
print("Total count:", total_count)
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_frequencies(count_input):
|
||||||
|
entity_to_count = dict()
|
||||||
|
with open(count_input, 'r', encoding='utf8') as csvfile:
|
||||||
|
csvreader = csv.reader(csvfile, delimiter='|')
|
||||||
|
# skip header
|
||||||
|
next(csvreader)
|
||||||
|
for row in csvreader:
|
||||||
|
entity_to_count[row[0]] = int(row[1])
|
||||||
|
|
||||||
|
return entity_to_count
|
||||||
|
|
|
@ -9,26 +9,26 @@ from spacy.kb import KnowledgeBase
|
||||||
|
|
||||||
|
|
||||||
def create_kb(vocab):
|
def create_kb(vocab):
|
||||||
kb = KnowledgeBase(vocab=vocab)
|
kb = KnowledgeBase(vocab=vocab, entity_vector_length=1)
|
||||||
|
|
||||||
# adding entities
|
# adding entities
|
||||||
entity_0 = "Q1004791_Douglas"
|
entity_0 = "Q1004791_Douglas"
|
||||||
print("adding entity", entity_0)
|
print("adding entity", entity_0)
|
||||||
kb.add_entity(entity=entity_0, prob=0.5)
|
kb.add_entity(entity=entity_0, prob=0.5, entity_vector=[0])
|
||||||
|
|
||||||
entity_1 = "Q42_Douglas_Adams"
|
entity_1 = "Q42_Douglas_Adams"
|
||||||
print("adding entity", entity_1)
|
print("adding entity", entity_1)
|
||||||
kb.add_entity(entity=entity_1, prob=0.5)
|
kb.add_entity(entity=entity_1, prob=0.5, entity_vector=[1])
|
||||||
|
|
||||||
entity_2 = "Q5301561_Douglas_Haig"
|
entity_2 = "Q5301561_Douglas_Haig"
|
||||||
print("adding entity", entity_2)
|
print("adding entity", entity_2)
|
||||||
kb.add_entity(entity=entity_2, prob=0.5)
|
kb.add_entity(entity=entity_2, prob=0.5, entity_vector=[2])
|
||||||
|
|
||||||
# adding aliases
|
# adding aliases
|
||||||
print()
|
print()
|
||||||
alias_0 = "Douglas"
|
alias_0 = "Douglas"
|
||||||
print("adding alias", alias_0)
|
print("adding alias", alias_0)
|
||||||
kb.add_alias(alias=alias_0, entities=[entity_0, entity_1, entity_2], probabilities=[0.1, 0.6, 0.2])
|
kb.add_alias(alias=alias_0, entities=[entity_0, entity_1, entity_2], probabilities=[0.6, 0.1, 0.2])
|
||||||
|
|
||||||
alias_1 = "Douglas Adams"
|
alias_1 = "Douglas Adams"
|
||||||
print("adding alias", alias_1)
|
print("adding alias", alias_1)
|
||||||
|
@ -41,8 +41,12 @@ def create_kb(vocab):
|
||||||
|
|
||||||
|
|
||||||
def add_el(kb, nlp):
|
def add_el(kb, nlp):
|
||||||
el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": kb})
|
el_pipe = nlp.create_pipe(name='entity_linker', config={"context_width": 64})
|
||||||
|
el_pipe.set_kb(kb)
|
||||||
nlp.add_pipe(el_pipe, last=True)
|
nlp.add_pipe(el_pipe, last=True)
|
||||||
|
nlp.begin_training()
|
||||||
|
el_pipe.context_weight = 0
|
||||||
|
el_pipe.prior_weight = 1
|
||||||
|
|
||||||
for alias in ["Douglas Adams", "Douglas"]:
|
for alias in ["Douglas Adams", "Douglas"]:
|
||||||
candidates = nlp.linker.kb.get_candidates(alias)
|
candidates = nlp.linker.kb.get_candidates(alias)
|
||||||
|
@ -66,6 +70,6 @@ def add_el(kb, nlp):
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
nlp = spacy.load('en_core_web_sm')
|
my_nlp = spacy.load('en_core_web_sm')
|
||||||
my_kb = create_kb(nlp.vocab)
|
my_kb = create_kb(my_nlp.vocab)
|
||||||
add_el(my_kb, nlp)
|
add_el(my_kb, my_nlp)
|
||||||
|
|
442
examples/pipeline/wikidata_entity_linking.py
Normal file
442
examples/pipeline/wikidata_entity_linking.py
Normal file
|
@ -0,0 +1,442 @@
|
||||||
|
# coding: utf-8
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import random
|
||||||
|
import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from bin.wiki_entity_linking import training_set_creator, kb_creator, wikipedia_processor as wp
|
||||||
|
from bin.wiki_entity_linking.kb_creator import DESC_WIDTH
|
||||||
|
|
||||||
|
import spacy
|
||||||
|
from spacy.kb import KnowledgeBase
|
||||||
|
from spacy.util import minibatch, compounding
|
||||||
|
|
||||||
|
"""
|
||||||
|
Demonstrate how to build a knowledge base from WikiData and run an Entity Linking algorithm.
|
||||||
|
"""
|
||||||
|
|
||||||
|
ROOT_DIR = Path("C:/Users/Sofie/Documents/data/")
|
||||||
|
OUTPUT_DIR = ROOT_DIR / 'wikipedia'
|
||||||
|
TRAINING_DIR = OUTPUT_DIR / 'training_data_nel'
|
||||||
|
|
||||||
|
PRIOR_PROB = OUTPUT_DIR / 'prior_prob.csv'
|
||||||
|
ENTITY_COUNTS = OUTPUT_DIR / 'entity_freq.csv'
|
||||||
|
ENTITY_DEFS = OUTPUT_DIR / 'entity_defs.csv'
|
||||||
|
ENTITY_DESCR = OUTPUT_DIR / 'entity_descriptions.csv'
|
||||||
|
|
||||||
|
KB_FILE = OUTPUT_DIR / 'kb_1' / 'kb'
|
||||||
|
NLP_1_DIR = OUTPUT_DIR / 'nlp_1'
|
||||||
|
NLP_2_DIR = OUTPUT_DIR / 'nlp_2'
|
||||||
|
|
||||||
|
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
|
||||||
|
WIKIDATA_JSON = ROOT_DIR / 'wikidata' / 'wikidata-20190304-all.json.bz2'
|
||||||
|
|
||||||
|
# get enwiki-latest-pages-articles-multistream.xml.bz2 from https://dumps.wikimedia.org/enwiki/latest/
|
||||||
|
ENWIKI_DUMP = ROOT_DIR / 'wikipedia' / 'enwiki-20190320-pages-articles-multistream.xml.bz2'
|
||||||
|
|
||||||
|
# KB construction parameters
|
||||||
|
MAX_CANDIDATES = 10
|
||||||
|
MIN_ENTITY_FREQ = 20
|
||||||
|
MIN_PAIR_OCC = 5
|
||||||
|
|
||||||
|
# model training parameters
|
||||||
|
EPOCHS = 10
|
||||||
|
DROPOUT = 0.5
|
||||||
|
LEARN_RATE = 0.005
|
||||||
|
L2 = 1e-6
|
||||||
|
CONTEXT_WIDTH = 128
|
||||||
|
|
||||||
|
|
||||||
|
def run_pipeline():
|
||||||
|
# set the appropriate booleans to define which parts of the pipeline should be re(run)
|
||||||
|
print("START", datetime.datetime.now())
|
||||||
|
print()
|
||||||
|
nlp_1 = spacy.load('en_core_web_lg')
|
||||||
|
nlp_2 = None
|
||||||
|
kb_2 = None
|
||||||
|
|
||||||
|
# one-time methods to create KB and write to file
|
||||||
|
to_create_prior_probs = False
|
||||||
|
to_create_entity_counts = False
|
||||||
|
to_create_kb = False
|
||||||
|
|
||||||
|
# read KB back in from file
|
||||||
|
to_read_kb = True
|
||||||
|
to_test_kb = False
|
||||||
|
|
||||||
|
# create training dataset
|
||||||
|
create_wp_training = False
|
||||||
|
|
||||||
|
# train the EL pipe
|
||||||
|
train_pipe = True
|
||||||
|
measure_performance = True
|
||||||
|
|
||||||
|
# test the EL pipe on a simple example
|
||||||
|
to_test_pipeline = True
|
||||||
|
|
||||||
|
# write the NLP object, read back in and test again
|
||||||
|
to_write_nlp = True
|
||||||
|
to_read_nlp = True
|
||||||
|
test_from_file = False
|
||||||
|
|
||||||
|
# STEP 1 : create prior probabilities from WP (run only once)
|
||||||
|
if to_create_prior_probs:
|
||||||
|
print("STEP 1: to_create_prior_probs", datetime.datetime.now())
|
||||||
|
wp.read_wikipedia_prior_probs(wikipedia_input=ENWIKI_DUMP, prior_prob_output=PRIOR_PROB)
|
||||||
|
print()
|
||||||
|
|
||||||
|
# STEP 2 : deduce entity frequencies from WP (run only once)
|
||||||
|
if to_create_entity_counts:
|
||||||
|
print("STEP 2: to_create_entity_counts", datetime.datetime.now())
|
||||||
|
wp.write_entity_counts(prior_prob_input=PRIOR_PROB, count_output=ENTITY_COUNTS, to_print=False)
|
||||||
|
print()
|
||||||
|
|
||||||
|
# STEP 3 : create KB and write to file (run only once)
|
||||||
|
if to_create_kb:
|
||||||
|
print("STEP 3a: to_create_kb", datetime.datetime.now())
|
||||||
|
kb_1 = kb_creator.create_kb(nlp_1,
|
||||||
|
max_entities_per_alias=MAX_CANDIDATES,
|
||||||
|
min_entity_freq=MIN_ENTITY_FREQ,
|
||||||
|
min_occ=MIN_PAIR_OCC,
|
||||||
|
entity_def_output=ENTITY_DEFS,
|
||||||
|
entity_descr_output=ENTITY_DESCR,
|
||||||
|
count_input=ENTITY_COUNTS,
|
||||||
|
prior_prob_input=PRIOR_PROB,
|
||||||
|
wikidata_input=WIKIDATA_JSON)
|
||||||
|
print("kb entities:", kb_1.get_size_entities())
|
||||||
|
print("kb aliases:", kb_1.get_size_aliases())
|
||||||
|
print()
|
||||||
|
|
||||||
|
print("STEP 3b: write KB and NLP", datetime.datetime.now())
|
||||||
|
kb_1.dump(KB_FILE)
|
||||||
|
nlp_1.to_disk(NLP_1_DIR)
|
||||||
|
print()
|
||||||
|
|
||||||
|
# STEP 4 : read KB back in from file
|
||||||
|
if to_read_kb:
|
||||||
|
print("STEP 4: to_read_kb", datetime.datetime.now())
|
||||||
|
nlp_2 = spacy.load(NLP_1_DIR)
|
||||||
|
kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH)
|
||||||
|
kb_2.load_bulk(KB_FILE)
|
||||||
|
print("kb entities:", kb_2.get_size_entities())
|
||||||
|
print("kb aliases:", kb_2.get_size_aliases())
|
||||||
|
print()
|
||||||
|
|
||||||
|
# test KB
|
||||||
|
if to_test_kb:
|
||||||
|
check_kb(kb_2)
|
||||||
|
print()
|
||||||
|
|
||||||
|
# STEP 5: create a training dataset from WP
|
||||||
|
if create_wp_training:
|
||||||
|
print("STEP 5: create training dataset", datetime.datetime.now())
|
||||||
|
training_set_creator.create_training(wikipedia_input=ENWIKI_DUMP,
|
||||||
|
entity_def_input=ENTITY_DEFS,
|
||||||
|
training_output=TRAINING_DIR)
|
||||||
|
|
||||||
|
# STEP 6: create and train the entity linking pipe
|
||||||
|
if train_pipe:
|
||||||
|
print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
|
||||||
|
type_to_int = {label: i for i, label in enumerate(nlp_2.entity.labels)}
|
||||||
|
print(" -analysing", len(type_to_int), "different entity types")
|
||||||
|
el_pipe = nlp_2.create_pipe(name='entity_linker',
|
||||||
|
config={"context_width": CONTEXT_WIDTH,
|
||||||
|
"pretrained_vectors": nlp_2.vocab.vectors.name,
|
||||||
|
"type_to_int": type_to_int})
|
||||||
|
el_pipe.set_kb(kb_2)
|
||||||
|
nlp_2.add_pipe(el_pipe, last=True)
|
||||||
|
|
||||||
|
other_pipes = [pipe for pipe in nlp_2.pipe_names if pipe != "entity_linker"]
|
||||||
|
with nlp_2.disable_pipes(*other_pipes): # only train Entity Linking
|
||||||
|
optimizer = nlp_2.begin_training()
|
||||||
|
optimizer.learn_rate = LEARN_RATE
|
||||||
|
optimizer.L2 = L2
|
||||||
|
|
||||||
|
# define the size (nr of entities) of training and dev set
|
||||||
|
train_limit = 5000
|
||||||
|
dev_limit = 5000
|
||||||
|
|
||||||
|
train_data = training_set_creator.read_training(nlp=nlp_2,
|
||||||
|
training_dir=TRAINING_DIR,
|
||||||
|
dev=False,
|
||||||
|
limit=train_limit)
|
||||||
|
|
||||||
|
print("Training on", len(train_data), "articles")
|
||||||
|
print()
|
||||||
|
|
||||||
|
dev_data = training_set_creator.read_training(nlp=nlp_2,
|
||||||
|
training_dir=TRAINING_DIR,
|
||||||
|
dev=True,
|
||||||
|
limit=dev_limit)
|
||||||
|
|
||||||
|
print("Dev testing on", len(dev_data), "articles")
|
||||||
|
print()
|
||||||
|
|
||||||
|
if not train_data:
|
||||||
|
print("Did not find any training data")
|
||||||
|
else:
|
||||||
|
for itn in range(EPOCHS):
|
||||||
|
random.shuffle(train_data)
|
||||||
|
losses = {}
|
||||||
|
batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
|
||||||
|
batchnr = 0
|
||||||
|
|
||||||
|
with nlp_2.disable_pipes(*other_pipes):
|
||||||
|
for batch in batches:
|
||||||
|
try:
|
||||||
|
docs, golds = zip(*batch)
|
||||||
|
nlp_2.update(
|
||||||
|
docs,
|
||||||
|
golds,
|
||||||
|
sgd=optimizer,
|
||||||
|
drop=DROPOUT,
|
||||||
|
losses=losses,
|
||||||
|
)
|
||||||
|
batchnr += 1
|
||||||
|
except Exception as e:
|
||||||
|
print("Error updating batch:", e)
|
||||||
|
|
||||||
|
if batchnr > 0:
|
||||||
|
el_pipe.cfg["context_weight"] = 1
|
||||||
|
el_pipe.cfg["prior_weight"] = 1
|
||||||
|
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
|
||||||
|
losses['entity_linker'] = losses['entity_linker'] / batchnr
|
||||||
|
print("Epoch, train loss", itn, round(losses['entity_linker'], 2),
|
||||||
|
" / dev acc avg", round(dev_acc_context, 3))
|
||||||
|
|
||||||
|
# STEP 7: measure the performance of our trained pipe on an independent dev set
|
||||||
|
if len(dev_data) and measure_performance:
|
||||||
|
print()
|
||||||
|
print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now())
|
||||||
|
print()
|
||||||
|
|
||||||
|
counts, acc_r, acc_r_label, acc_p, acc_p_label, acc_o, acc_o_label = _measure_baselines(dev_data, kb_2)
|
||||||
|
print("dev counts:", sorted(counts.items(), key=lambda x: x[0]))
|
||||||
|
print("dev acc oracle:", round(acc_o, 3), [(x, round(y, 3)) for x, y in acc_o_label.items()])
|
||||||
|
print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_label.items()])
|
||||||
|
print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_label.items()])
|
||||||
|
|
||||||
|
# using only context
|
||||||
|
el_pipe.cfg["context_weight"] = 1
|
||||||
|
el_pipe.cfg["prior_weight"] = 0
|
||||||
|
dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
|
||||||
|
print("dev acc context avg:", round(dev_acc_context, 3),
|
||||||
|
[(x, round(y, 3)) for x, y in dev_acc_context_dict.items()])
|
||||||
|
|
||||||
|
# measuring combined accuracy (prior + context)
|
||||||
|
el_pipe.cfg["context_weight"] = 1
|
||||||
|
el_pipe.cfg["prior_weight"] = 1
|
||||||
|
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe, error_analysis=False)
|
||||||
|
print("dev acc combo avg:", round(dev_acc_combo, 3),
|
||||||
|
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
|
||||||
|
|
||||||
|
# STEP 8: apply the EL pipe on a toy example
|
||||||
|
if to_test_pipeline:
|
||||||
|
print()
|
||||||
|
print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now())
|
||||||
|
print()
|
||||||
|
run_el_toy_example(nlp=nlp_2)
|
||||||
|
|
||||||
|
# STEP 9: write the NLP pipeline (including entity linker) to file
|
||||||
|
if to_write_nlp:
|
||||||
|
print()
|
||||||
|
print("STEP 9: testing NLP IO", datetime.datetime.now())
|
||||||
|
print()
|
||||||
|
print("writing to", NLP_2_DIR)
|
||||||
|
nlp_2.to_disk(NLP_2_DIR)
|
||||||
|
print()
|
||||||
|
|
||||||
|
# verify that the IO has gone correctly
|
||||||
|
if to_read_nlp:
|
||||||
|
print("reading from", NLP_2_DIR)
|
||||||
|
nlp_3 = spacy.load(NLP_2_DIR)
|
||||||
|
|
||||||
|
print("running toy example with NLP 3")
|
||||||
|
run_el_toy_example(nlp=nlp_3)
|
||||||
|
|
||||||
|
# testing performance with an NLP model from file
|
||||||
|
if test_from_file:
|
||||||
|
nlp_2 = spacy.load(NLP_1_DIR)
|
||||||
|
nlp_3 = spacy.load(NLP_2_DIR)
|
||||||
|
el_pipe = nlp_3.get_pipe("entity_linker")
|
||||||
|
|
||||||
|
dev_limit = 5000
|
||||||
|
dev_data = training_set_creator.read_training(nlp=nlp_2,
|
||||||
|
training_dir=TRAINING_DIR,
|
||||||
|
dev=True,
|
||||||
|
limit=dev_limit)
|
||||||
|
|
||||||
|
print("Dev testing from file on", len(dev_data), "articles")
|
||||||
|
print()
|
||||||
|
|
||||||
|
dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe=el_pipe, error_analysis=False)
|
||||||
|
print("dev acc combo avg:", round(dev_acc_combo, 3),
|
||||||
|
[(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()])
|
||||||
|
|
||||||
|
print()
|
||||||
|
print("STOP", datetime.datetime.now())
|
||||||
|
|
||||||
|
|
||||||
|
def _measure_accuracy(data, el_pipe=None, error_analysis=False):
|
||||||
|
# If the docs in the data require further processing with an entity linker, set el_pipe
|
||||||
|
correct_by_label = dict()
|
||||||
|
incorrect_by_label = dict()
|
||||||
|
|
||||||
|
docs = [d for d, g in data if len(d) > 0]
|
||||||
|
if el_pipe is not None:
|
||||||
|
docs = list(el_pipe.pipe(docs))
|
||||||
|
golds = [g for d, g in data if len(d) > 0]
|
||||||
|
|
||||||
|
for doc, gold in zip(docs, golds):
|
||||||
|
try:
|
||||||
|
correct_entries_per_article = dict()
|
||||||
|
for entity in gold.links:
|
||||||
|
start, end, gold_kb = entity
|
||||||
|
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
|
||||||
|
|
||||||
|
for ent in doc.ents:
|
||||||
|
ent_label = ent.label_
|
||||||
|
pred_entity = ent.kb_id_
|
||||||
|
start = ent.start_char
|
||||||
|
end = ent.end_char
|
||||||
|
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
|
||||||
|
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
||||||
|
if gold_entity is not None:
|
||||||
|
if gold_entity == pred_entity:
|
||||||
|
correct = correct_by_label.get(ent_label, 0)
|
||||||
|
correct_by_label[ent_label] = correct + 1
|
||||||
|
else:
|
||||||
|
incorrect = incorrect_by_label.get(ent_label, 0)
|
||||||
|
incorrect_by_label[ent_label] = incorrect + 1
|
||||||
|
if error_analysis:
|
||||||
|
print(ent.text, "in", doc)
|
||||||
|
print("Predicted", pred_entity, "should have been", gold_entity)
|
||||||
|
print()
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print("Error assessing accuracy", e)
|
||||||
|
|
||||||
|
acc, acc_by_label = calculate_acc(correct_by_label, incorrect_by_label)
|
||||||
|
return acc, acc_by_label
|
||||||
|
|
||||||
|
|
||||||
|
def _measure_baselines(data, kb):
|
||||||
|
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
|
||||||
|
counts_by_label = dict()
|
||||||
|
|
||||||
|
random_correct_by_label = dict()
|
||||||
|
random_incorrect_by_label = dict()
|
||||||
|
|
||||||
|
oracle_correct_by_label = dict()
|
||||||
|
oracle_incorrect_by_label = dict()
|
||||||
|
|
||||||
|
prior_correct_by_label = dict()
|
||||||
|
prior_incorrect_by_label = dict()
|
||||||
|
|
||||||
|
docs = [d for d, g in data if len(d) > 0]
|
||||||
|
golds = [g for d, g in data if len(d) > 0]
|
||||||
|
|
||||||
|
for doc, gold in zip(docs, golds):
|
||||||
|
try:
|
||||||
|
correct_entries_per_article = dict()
|
||||||
|
for entity in gold.links:
|
||||||
|
start, end, gold_kb = entity
|
||||||
|
correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb
|
||||||
|
|
||||||
|
for ent in doc.ents:
|
||||||
|
ent_label = ent.label_
|
||||||
|
start = ent.start_char
|
||||||
|
end = ent.end_char
|
||||||
|
gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None)
|
||||||
|
|
||||||
|
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
||||||
|
if gold_entity is not None:
|
||||||
|
counts_by_label[ent_label] = counts_by_label.get(ent_label, 0) + 1
|
||||||
|
candidates = kb.get_candidates(ent.text)
|
||||||
|
oracle_candidate = ""
|
||||||
|
best_candidate = ""
|
||||||
|
random_candidate = ""
|
||||||
|
if candidates:
|
||||||
|
scores = []
|
||||||
|
|
||||||
|
for c in candidates:
|
||||||
|
scores.append(c.prior_prob)
|
||||||
|
if c.entity_ == gold_entity:
|
||||||
|
oracle_candidate = c.entity_
|
||||||
|
|
||||||
|
best_index = scores.index(max(scores))
|
||||||
|
best_candidate = candidates[best_index].entity_
|
||||||
|
random_candidate = random.choice(candidates).entity_
|
||||||
|
|
||||||
|
if gold_entity == best_candidate:
|
||||||
|
prior_correct_by_label[ent_label] = prior_correct_by_label.get(ent_label, 0) + 1
|
||||||
|
else:
|
||||||
|
prior_incorrect_by_label[ent_label] = prior_incorrect_by_label.get(ent_label, 0) + 1
|
||||||
|
|
||||||
|
if gold_entity == random_candidate:
|
||||||
|
random_correct_by_label[ent_label] = random_correct_by_label.get(ent_label, 0) + 1
|
||||||
|
else:
|
||||||
|
random_incorrect_by_label[ent_label] = random_incorrect_by_label.get(ent_label, 0) + 1
|
||||||
|
|
||||||
|
if gold_entity == oracle_candidate:
|
||||||
|
oracle_correct_by_label[ent_label] = oracle_correct_by_label.get(ent_label, 0) + 1
|
||||||
|
else:
|
||||||
|
oracle_incorrect_by_label[ent_label] = oracle_incorrect_by_label.get(ent_label, 0) + 1
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print("Error assessing accuracy", e)
|
||||||
|
|
||||||
|
acc_prior, acc_prior_by_label = calculate_acc(prior_correct_by_label, prior_incorrect_by_label)
|
||||||
|
acc_rand, acc_rand_by_label = calculate_acc(random_correct_by_label, random_incorrect_by_label)
|
||||||
|
acc_oracle, acc_oracle_by_label = calculate_acc(oracle_correct_by_label, oracle_incorrect_by_label)
|
||||||
|
|
||||||
|
return counts_by_label, acc_rand, acc_rand_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label
|
||||||
|
|
||||||
|
|
||||||
|
def calculate_acc(correct_by_label, incorrect_by_label):
|
||||||
|
acc_by_label = dict()
|
||||||
|
total_correct = 0
|
||||||
|
total_incorrect = 0
|
||||||
|
all_keys = set()
|
||||||
|
all_keys.update(correct_by_label.keys())
|
||||||
|
all_keys.update(incorrect_by_label.keys())
|
||||||
|
for label in sorted(all_keys):
|
||||||
|
correct = correct_by_label.get(label, 0)
|
||||||
|
incorrect = incorrect_by_label.get(label, 0)
|
||||||
|
total_correct += correct
|
||||||
|
total_incorrect += incorrect
|
||||||
|
if correct == incorrect == 0:
|
||||||
|
acc_by_label[label] = 0
|
||||||
|
else:
|
||||||
|
acc_by_label[label] = correct / (correct + incorrect)
|
||||||
|
acc = 0
|
||||||
|
if not (total_correct == total_incorrect == 0):
|
||||||
|
acc = total_correct / (total_correct + total_incorrect)
|
||||||
|
return acc, acc_by_label
|
||||||
|
|
||||||
|
|
||||||
|
def check_kb(kb):
|
||||||
|
for mention in ("Bush", "Douglas Adams", "Homer", "Brazil", "China"):
|
||||||
|
candidates = kb.get_candidates(mention)
|
||||||
|
|
||||||
|
print("generating candidates for " + mention + " :")
|
||||||
|
for c in candidates:
|
||||||
|
print(" ", c.prior_prob, c.alias_, "-->", c.entity_ + " (freq=" + str(c.entity_freq) + ")")
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
def run_el_toy_example(nlp):
|
||||||
|
text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \
|
||||||
|
"Douglas reminds us to always bring our towel, even in China or Brazil. " \
|
||||||
|
"The main character in Doug's novel is the man Arthur Dent, " \
|
||||||
|
"but Douglas doesn't write about George Washington or Homer Simpson."
|
||||||
|
doc = nlp(text)
|
||||||
|
print(text)
|
||||||
|
for ent in doc.ents:
|
||||||
|
print(" ent", ent.text, ent.label_, ent.kb_id_)
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
run_pipeline()
|
32
spacy/_ml.py
32
spacy/_ml.py
|
@ -652,6 +652,38 @@ def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=False,
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
def build_nel_encoder(embed_width, hidden_width, ner_types, **cfg):
|
||||||
|
# TODO proper error
|
||||||
|
if "entity_width" not in cfg:
|
||||||
|
raise ValueError("entity_width not found")
|
||||||
|
if "context_width" not in cfg:
|
||||||
|
raise ValueError("context_width not found")
|
||||||
|
|
||||||
|
conv_depth = cfg.get("conv_depth", 2)
|
||||||
|
cnn_maxout_pieces = cfg.get("cnn_maxout_pieces", 3)
|
||||||
|
pretrained_vectors = cfg.get("pretrained_vectors") # self.nlp.vocab.vectors.name
|
||||||
|
context_width = cfg.get("context_width")
|
||||||
|
entity_width = cfg.get("entity_width")
|
||||||
|
|
||||||
|
with Model.define_operators({">>": chain, "**": clone}):
|
||||||
|
model = Affine(entity_width, entity_width+context_width+1+ner_types)\
|
||||||
|
>> Affine(1, entity_width, drop_factor=0.0)\
|
||||||
|
>> logistic
|
||||||
|
|
||||||
|
# context encoder
|
||||||
|
tok2vec = Tok2Vec(width=hidden_width, embed_size=embed_width, pretrained_vectors=pretrained_vectors,
|
||||||
|
cnn_maxout_pieces=cnn_maxout_pieces, subword_features=True, conv_depth=conv_depth,
|
||||||
|
bilstm_depth=0) >> flatten_add_lengths >> Pooling(mean_pool)\
|
||||||
|
>> Residual(zero_init(Maxout(hidden_width, hidden_width))) \
|
||||||
|
>> zero_init(Affine(context_width, hidden_width))
|
||||||
|
|
||||||
|
model.tok2vec = tok2vec
|
||||||
|
|
||||||
|
model.tok2vec = tok2vec
|
||||||
|
model.tok2vec.nO = context_width
|
||||||
|
model.nO = 1
|
||||||
|
return model
|
||||||
|
|
||||||
@layerize
|
@layerize
|
||||||
def flatten(seqs, drop=0.0):
|
def flatten(seqs, drop=0.0):
|
||||||
ops = Model.ops
|
ops = Model.ops
|
||||||
|
|
|
@ -82,6 +82,7 @@ cdef enum attr_id_t:
|
||||||
DEP
|
DEP
|
||||||
ENT_IOB
|
ENT_IOB
|
||||||
ENT_TYPE
|
ENT_TYPE
|
||||||
|
ENT_KB_ID
|
||||||
HEAD
|
HEAD
|
||||||
SENT_START
|
SENT_START
|
||||||
SPACY
|
SPACY
|
||||||
|
|
|
@ -84,6 +84,7 @@ IDS = {
|
||||||
"DEP": DEP,
|
"DEP": DEP,
|
||||||
"ENT_IOB": ENT_IOB,
|
"ENT_IOB": ENT_IOB,
|
||||||
"ENT_TYPE": ENT_TYPE,
|
"ENT_TYPE": ENT_TYPE,
|
||||||
|
"ENT_KB_ID": ENT_KB_ID,
|
||||||
"HEAD": HEAD,
|
"HEAD": HEAD,
|
||||||
"SENT_START": SENT_START,
|
"SENT_START": SENT_START,
|
||||||
"SPACY": SPACY,
|
"SPACY": SPACY,
|
||||||
|
|
|
@ -301,7 +301,7 @@ def get_vectors_loss(ops, docs, prediction, objective="L2"):
|
||||||
elif objective == "cosine":
|
elif objective == "cosine":
|
||||||
loss, d_target = get_cossim_loss(prediction, target)
|
loss, d_target = get_cossim_loss(prediction, target)
|
||||||
else:
|
else:
|
||||||
raise ValueError(Errors.E139.format(loss_func=objective))
|
raise ValueError(Errors.E142.format(loss_func=objective))
|
||||||
return loss, d_target
|
return loss, d_target
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -399,7 +399,10 @@ class Errors(object):
|
||||||
E138 = ("Invalid JSONL format for raw text '{text}'. Make sure the input includes either the "
|
E138 = ("Invalid JSONL format for raw text '{text}'. Make sure the input includes either the "
|
||||||
"`text` or `tokens` key. For more info, see the docs:\n"
|
"`text` or `tokens` key. For more info, see the docs:\n"
|
||||||
"https://spacy.io/api/cli#pretrain-jsonl")
|
"https://spacy.io/api/cli#pretrain-jsonl")
|
||||||
E139 = ("Unsupported loss_function '{loss_func}'. Use either 'L2' or 'cosine'")
|
E139 = ("Knowledge base for component '{name}' not initialized. Did you forget to call set_kb()?")
|
||||||
|
E140 = ("The list of entities, prior probabilities and entity vectors should be of equal length.")
|
||||||
|
E141 = ("Entity vectors should be of length {required} instead of the provided {found}.")
|
||||||
|
E142 = ("Unsupported loss_function '{loss_func}'. Use either 'L2' or 'cosine'")
|
||||||
|
|
||||||
|
|
||||||
@add_codes
|
@add_codes
|
||||||
|
|
|
@ -31,6 +31,7 @@ cdef class GoldParse:
|
||||||
cdef public list ents
|
cdef public list ents
|
||||||
cdef public dict brackets
|
cdef public dict brackets
|
||||||
cdef public object cats
|
cdef public object cats
|
||||||
|
cdef public list links
|
||||||
|
|
||||||
cdef readonly list cand_to_gold
|
cdef readonly list cand_to_gold
|
||||||
cdef readonly list gold_to_cand
|
cdef readonly list gold_to_cand
|
||||||
|
|
|
@ -427,7 +427,7 @@ cdef class GoldParse:
|
||||||
|
|
||||||
def __init__(self, doc, annot_tuples=None, words=None, tags=None,
|
def __init__(self, doc, annot_tuples=None, words=None, tags=None,
|
||||||
heads=None, deps=None, entities=None, make_projective=False,
|
heads=None, deps=None, entities=None, make_projective=False,
|
||||||
cats=None, **_):
|
cats=None, links=None, **_):
|
||||||
"""Create a GoldParse.
|
"""Create a GoldParse.
|
||||||
|
|
||||||
doc (Doc): The document the annotations refer to.
|
doc (Doc): The document the annotations refer to.
|
||||||
|
@ -450,6 +450,8 @@ cdef class GoldParse:
|
||||||
examples of a label to have the value 0.0. Labels not in the
|
examples of a label to have the value 0.0. Labels not in the
|
||||||
dictionary are treated as missing - the gradient for those labels
|
dictionary are treated as missing - the gradient for those labels
|
||||||
will be zero.
|
will be zero.
|
||||||
|
links (iterable): A sequence of `(start_char, end_char, kb_id)` tuples,
|
||||||
|
representing the external ID of an entity in a knowledge base.
|
||||||
RETURNS (GoldParse): The newly constructed object.
|
RETURNS (GoldParse): The newly constructed object.
|
||||||
"""
|
"""
|
||||||
if words is None:
|
if words is None:
|
||||||
|
@ -485,6 +487,7 @@ cdef class GoldParse:
|
||||||
self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition))
|
self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition))
|
||||||
|
|
||||||
self.cats = {} if cats is None else dict(cats)
|
self.cats = {} if cats is None else dict(cats)
|
||||||
|
self.links = links
|
||||||
self.words = [None] * len(doc)
|
self.words = [None] * len(doc)
|
||||||
self.tags = [None] * len(doc)
|
self.tags = [None] * len(doc)
|
||||||
self.heads = [None] * len(doc)
|
self.heads = [None] * len(doc)
|
||||||
|
|
170
spacy/kb.pxd
170
spacy/kb.pxd
|
@ -1,53 +1,27 @@
|
||||||
"""Knowledge-base for entity or concept linking."""
|
"""Knowledge-base for entity or concept linking."""
|
||||||
from cymem.cymem cimport Pool
|
from cymem.cymem cimport Pool
|
||||||
from preshed.maps cimport PreshMap
|
from preshed.maps cimport PreshMap
|
||||||
|
|
||||||
from libcpp.vector cimport vector
|
from libcpp.vector cimport vector
|
||||||
from libc.stdint cimport int32_t, int64_t
|
from libc.stdint cimport int32_t, int64_t
|
||||||
|
from libc.stdio cimport FILE
|
||||||
|
|
||||||
from spacy.vocab cimport Vocab
|
from spacy.vocab cimport Vocab
|
||||||
from .typedefs cimport hash_t
|
from .typedefs cimport hash_t
|
||||||
|
|
||||||
|
from .structs cimport KBEntryC, AliasC
|
||||||
# Internal struct, for storage and disambiguation. This isn't what we return
|
ctypedef vector[KBEntryC] entry_vec
|
||||||
# to the user as the answer to "here's your entity". It's the minimum number
|
ctypedef vector[AliasC] alias_vec
|
||||||
# of bits we need to keep track of the answers.
|
ctypedef vector[float] float_vec
|
||||||
cdef struct _EntryC:
|
ctypedef vector[float_vec] float_matrix
|
||||||
|
|
||||||
# The hash of this entry's unique ID and name in the kB
|
|
||||||
hash_t entity_hash
|
|
||||||
|
|
||||||
# Allows retrieval of one or more vectors.
|
|
||||||
# Each element of vector_rows should be an index into a vectors table.
|
|
||||||
# Every entry should have the same number of vectors, so we can avoid storing
|
|
||||||
# the number of vectors in each knowledge-base struct
|
|
||||||
int32_t* vector_rows
|
|
||||||
|
|
||||||
# Allows retrieval of a struct of non-vector features. We could make this a
|
|
||||||
# pointer, but we have 32 bits left over in the struct after prob, so we'd
|
|
||||||
# like this to only be 32 bits. We can also set this to -1, for the common
|
|
||||||
# case where there are no features.
|
|
||||||
int32_t feats_row
|
|
||||||
|
|
||||||
# log probability of entity, based on corpus frequency
|
|
||||||
float prob
|
|
||||||
|
|
||||||
|
|
||||||
# Each alias struct stores a list of Entry pointers with their prior probabilities
|
|
||||||
# for this specific mention/alias.
|
|
||||||
cdef struct _AliasC:
|
|
||||||
|
|
||||||
# All entry candidates for this alias
|
|
||||||
vector[int64_t] entry_indices
|
|
||||||
|
|
||||||
# Prior probability P(entity|alias) - should sum up to (at most) 1.
|
|
||||||
vector[float] probs
|
|
||||||
|
|
||||||
|
|
||||||
# Object used by the Entity Linker that summarizes one entity-alias candidate combination.
|
# Object used by the Entity Linker that summarizes one entity-alias candidate combination.
|
||||||
cdef class Candidate:
|
cdef class Candidate:
|
||||||
|
|
||||||
cdef readonly KnowledgeBase kb
|
cdef readonly KnowledgeBase kb
|
||||||
cdef hash_t entity_hash
|
cdef hash_t entity_hash
|
||||||
|
cdef float entity_freq
|
||||||
|
cdef vector[float] entity_vector
|
||||||
cdef hash_t alias_hash
|
cdef hash_t alias_hash
|
||||||
cdef float prior_prob
|
cdef float prior_prob
|
||||||
|
|
||||||
|
@ -55,9 +29,10 @@ cdef class Candidate:
|
||||||
cdef class KnowledgeBase:
|
cdef class KnowledgeBase:
|
||||||
cdef Pool mem
|
cdef Pool mem
|
||||||
cpdef readonly Vocab vocab
|
cpdef readonly Vocab vocab
|
||||||
|
cdef int64_t entity_vector_length
|
||||||
|
|
||||||
# This maps 64bit keys (hash of unique entity string)
|
# This maps 64bit keys (hash of unique entity string)
|
||||||
# to 64bit values (position of the _EntryC struct in the _entries vector).
|
# to 64bit values (position of the _KBEntryC struct in the _entries vector).
|
||||||
# The PreshMap is pretty space efficient, as it uses open addressing. So
|
# The PreshMap is pretty space efficient, as it uses open addressing. So
|
||||||
# the only overhead is the vacancy rate, which is approximately 30%.
|
# the only overhead is the vacancy rate, which is approximately 30%.
|
||||||
cdef PreshMap _entry_index
|
cdef PreshMap _entry_index
|
||||||
|
@ -66,7 +41,7 @@ cdef class KnowledgeBase:
|
||||||
# over allocation.
|
# over allocation.
|
||||||
# In total we end up with (N*128*1.3)+(N*128*1.3) bits for N entries.
|
# In total we end up with (N*128*1.3)+(N*128*1.3) bits for N entries.
|
||||||
# Storing 1m entries would take 41.6mb under this scheme.
|
# Storing 1m entries would take 41.6mb under this scheme.
|
||||||
cdef vector[_EntryC] _entries
|
cdef entry_vec _entries
|
||||||
|
|
||||||
# This maps 64bit keys (hash of unique alias string)
|
# This maps 64bit keys (hash of unique alias string)
|
||||||
# to 64bit values (position of the _AliasC struct in the _aliases_table vector).
|
# to 64bit values (position of the _AliasC struct in the _aliases_table vector).
|
||||||
|
@ -76,7 +51,7 @@ cdef class KnowledgeBase:
|
||||||
# should be P(entity | mention), which is pretty important to know.
|
# should be P(entity | mention), which is pretty important to know.
|
||||||
# We can pack both pieces of information into a 64-bit value, to keep things
|
# We can pack both pieces of information into a 64-bit value, to keep things
|
||||||
# efficient.
|
# efficient.
|
||||||
cdef vector[_AliasC] _aliases_table
|
cdef alias_vec _aliases_table
|
||||||
|
|
||||||
# This is the part which might take more space: storing various
|
# This is the part which might take more space: storing various
|
||||||
# categorical features for the entries, and storing vectors for disambiguation
|
# categorical features for the entries, and storing vectors for disambiguation
|
||||||
|
@ -87,7 +62,7 @@ cdef class KnowledgeBase:
|
||||||
# model, that embeds different features of the entities into vectors. We'll
|
# model, that embeds different features of the entities into vectors. We'll
|
||||||
# still want some per-entity features, like the Wikipedia text or entity
|
# still want some per-entity features, like the Wikipedia text or entity
|
||||||
# co-occurrence. Hopefully those vectors can be narrow, e.g. 64 dimensions.
|
# co-occurrence. Hopefully those vectors can be narrow, e.g. 64 dimensions.
|
||||||
cdef object _vectors_table
|
cdef float_matrix _vectors_table
|
||||||
|
|
||||||
# It's very useful to track categorical features, at least for output, even
|
# It's very useful to track categorical features, at least for output, even
|
||||||
# if they're not useful in the model itself. For instance, we should be
|
# if they're not useful in the model itself. For instance, we should be
|
||||||
|
@ -96,53 +71,102 @@ cdef class KnowledgeBase:
|
||||||
# optional data, we can let users configure a DB as the backend for this.
|
# optional data, we can let users configure a DB as the backend for this.
|
||||||
cdef object _features_table
|
cdef object _features_table
|
||||||
|
|
||||||
|
|
||||||
|
cdef inline int64_t c_add_vector(self, vector[float] entity_vector) nogil:
|
||||||
|
"""Add an entity vector to the vectors table."""
|
||||||
|
cdef int64_t new_index = self._vectors_table.size()
|
||||||
|
self._vectors_table.push_back(entity_vector)
|
||||||
|
return new_index
|
||||||
|
|
||||||
|
|
||||||
cdef inline int64_t c_add_entity(self, hash_t entity_hash, float prob,
|
cdef inline int64_t c_add_entity(self, hash_t entity_hash, float prob,
|
||||||
int32_t* vector_rows, int feats_row):
|
int32_t vector_index, int feats_row) nogil:
|
||||||
"""Add an entry to the knowledge base."""
|
"""Add an entry to the vector of entries.
|
||||||
# This is what we'll map the hash key to. It's where the entry will sit
|
After calling this method, make sure to update also the _entry_index using the return value"""
|
||||||
|
# This is what we'll map the entity hash key to. It's where the entry will sit
|
||||||
# in the vector of entries, so we can get it later.
|
# in the vector of entries, so we can get it later.
|
||||||
cdef int64_t new_index = self._entries.size()
|
cdef int64_t new_index = self._entries.size()
|
||||||
self._entries.push_back(
|
|
||||||
_EntryC(
|
# Avoid struct initializer to enable nogil, cf https://github.com/cython/cython/issues/1642
|
||||||
entity_hash=entity_hash,
|
cdef KBEntryC entry
|
||||||
vector_rows=vector_rows,
|
entry.entity_hash = entity_hash
|
||||||
feats_row=feats_row,
|
entry.vector_index = vector_index
|
||||||
prob=prob
|
entry.feats_row = feats_row
|
||||||
))
|
entry.prob = prob
|
||||||
self._entry_index[entity_hash] = new_index
|
|
||||||
|
self._entries.push_back(entry)
|
||||||
return new_index
|
return new_index
|
||||||
|
|
||||||
cdef inline int64_t c_add_aliases(self, hash_t alias_hash, vector[int64_t] entry_indices, vector[float] probs):
|
cdef inline int64_t c_add_aliases(self, hash_t alias_hash, vector[int64_t] entry_indices, vector[float] probs) nogil:
|
||||||
"""Connect a mention to a list of potential entities with their prior probabilities ."""
|
"""Connect a mention to a list of potential entities with their prior probabilities .
|
||||||
|
After calling this method, make sure to update also the _alias_index using the return value"""
|
||||||
|
# This is what we'll map the alias hash key to. It's where the alias will be defined
|
||||||
|
# in the vector of aliases.
|
||||||
cdef int64_t new_index = self._aliases_table.size()
|
cdef int64_t new_index = self._aliases_table.size()
|
||||||
|
|
||||||
self._aliases_table.push_back(
|
# Avoid struct initializer to enable nogil
|
||||||
_AliasC(
|
cdef AliasC alias
|
||||||
entry_indices=entry_indices,
|
alias.entry_indices = entry_indices
|
||||||
probs=probs
|
alias.probs = probs
|
||||||
))
|
|
||||||
self._alias_index[alias_hash] = new_index
|
self._aliases_table.push_back(alias)
|
||||||
return new_index
|
return new_index
|
||||||
|
|
||||||
cdef inline _create_empty_vectors(self):
|
cdef inline void _create_empty_vectors(self, hash_t dummy_hash) nogil:
|
||||||
"""
|
"""
|
||||||
Making sure the first element of each vector is a dummy,
|
Initializing the vectors and making sure the first element of each vector is a dummy,
|
||||||
because the PreshMap maps pointing to indices in these vectors can not contain 0 as value
|
because the PreshMap maps pointing to indices in these vectors can not contain 0 as value
|
||||||
cf. https://github.com/explosion/preshed/issues/17
|
cf. https://github.com/explosion/preshed/issues/17
|
||||||
"""
|
"""
|
||||||
cdef int32_t dummy_value = 0
|
cdef int32_t dummy_value = 0
|
||||||
self.vocab.strings.add("")
|
|
||||||
self._entries.push_back(
|
# Avoid struct initializer to enable nogil
|
||||||
_EntryC(
|
cdef KBEntryC entry
|
||||||
entity_hash=self.vocab.strings[""],
|
entry.entity_hash = dummy_hash
|
||||||
vector_rows=&dummy_value,
|
entry.vector_index = dummy_value
|
||||||
feats_row=dummy_value,
|
entry.feats_row = dummy_value
|
||||||
prob=dummy_value
|
entry.prob = dummy_value
|
||||||
))
|
|
||||||
self._aliases_table.push_back(
|
# Avoid struct initializer to enable nogil
|
||||||
_AliasC(
|
cdef vector[int64_t] dummy_entry_indices
|
||||||
entry_indices=[dummy_value],
|
dummy_entry_indices.push_back(0)
|
||||||
probs=[dummy_value]
|
cdef vector[float] dummy_probs
|
||||||
))
|
dummy_probs.push_back(0)
|
||||||
|
|
||||||
|
cdef AliasC alias
|
||||||
|
alias.entry_indices = dummy_entry_indices
|
||||||
|
alias.probs = dummy_probs
|
||||||
|
|
||||||
|
self._entries.push_back(entry)
|
||||||
|
self._aliases_table.push_back(alias)
|
||||||
|
|
||||||
|
cpdef load_bulk(self, loc)
|
||||||
|
cpdef set_entities(self, entity_list, prob_list, vector_list)
|
||||||
|
|
||||||
|
|
||||||
|
cdef class Writer:
|
||||||
|
cdef FILE* _fp
|
||||||
|
|
||||||
|
cdef int write_header(self, int64_t nr_entries, int64_t entity_vector_length) except -1
|
||||||
|
cdef int write_vector_element(self, float element) except -1
|
||||||
|
cdef int write_entry(self, hash_t entry_hash, float entry_prob, int32_t vector_index) except -1
|
||||||
|
|
||||||
|
cdef int write_alias_length(self, int64_t alias_length) except -1
|
||||||
|
cdef int write_alias_header(self, hash_t alias_hash, int64_t candidate_length) except -1
|
||||||
|
cdef int write_alias(self, int64_t entry_index, float prob) except -1
|
||||||
|
|
||||||
|
cdef int _write(self, void* value, size_t size) except -1
|
||||||
|
|
||||||
|
cdef class Reader:
|
||||||
|
cdef FILE* _fp
|
||||||
|
|
||||||
|
cdef int read_header(self, int64_t* nr_entries, int64_t* entity_vector_length) except -1
|
||||||
|
cdef int read_vector_element(self, float* element) except -1
|
||||||
|
cdef int read_entry(self, hash_t* entity_hash, float* prob, int32_t* vector_index) except -1
|
||||||
|
|
||||||
|
cdef int read_alias_length(self, int64_t* alias_length) except -1
|
||||||
|
cdef int read_alias_header(self, hash_t* alias_hash, int64_t* candidate_length) except -1
|
||||||
|
cdef int read_alias(self, int64_t* entry_index, float* prob) except -1
|
||||||
|
|
||||||
|
cdef int _read(self, void* value, size_t size) except -1
|
||||||
|
|
||||||
|
|
397
spacy/kb.pyx
397
spacy/kb.pyx
|
@ -1,13 +1,30 @@
|
||||||
|
# cython: infer_types=True
|
||||||
# cython: profile=True
|
# cython: profile=True
|
||||||
# coding: utf8
|
# coding: utf8
|
||||||
from spacy.errors import Errors, Warnings, user_warning
|
from spacy.errors import Errors, Warnings, user_warning
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
from cymem.cymem cimport Pool
|
||||||
|
from preshed.maps cimport PreshMap
|
||||||
|
|
||||||
|
from cpython.exc cimport PyErr_SetFromErrno
|
||||||
|
|
||||||
|
from libc.stdio cimport fopen, fclose, fread, fwrite, feof, fseek
|
||||||
|
from libc.stdint cimport int32_t, int64_t
|
||||||
|
|
||||||
|
from .typedefs cimport hash_t
|
||||||
|
|
||||||
|
from os import path
|
||||||
|
from libcpp.vector cimport vector
|
||||||
|
|
||||||
|
|
||||||
cdef class Candidate:
|
cdef class Candidate:
|
||||||
|
|
||||||
def __init__(self, KnowledgeBase kb, entity_hash, alias_hash, prior_prob):
|
def __init__(self, KnowledgeBase kb, entity_hash, entity_freq, entity_vector, alias_hash, prior_prob):
|
||||||
self.kb = kb
|
self.kb = kb
|
||||||
self.entity_hash = entity_hash
|
self.entity_hash = entity_hash
|
||||||
|
self.entity_freq = entity_freq
|
||||||
|
self.entity_vector = entity_vector
|
||||||
self.alias_hash = alias_hash
|
self.alias_hash = alias_hash
|
||||||
self.prior_prob = prior_prob
|
self.prior_prob = prior_prob
|
||||||
|
|
||||||
|
@ -19,7 +36,7 @@ cdef class Candidate:
|
||||||
@property
|
@property
|
||||||
def entity_(self):
|
def entity_(self):
|
||||||
"""RETURNS (unicode): ID/name of this entity in the KB"""
|
"""RETURNS (unicode): ID/name of this entity in the KB"""
|
||||||
return self.kb.vocab.strings[self.entity]
|
return self.kb.vocab.strings[self.entity_hash]
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def alias(self):
|
def alias(self):
|
||||||
|
@ -29,7 +46,15 @@ cdef class Candidate:
|
||||||
@property
|
@property
|
||||||
def alias_(self):
|
def alias_(self):
|
||||||
"""RETURNS (unicode): ID of the original alias"""
|
"""RETURNS (unicode): ID of the original alias"""
|
||||||
return self.kb.vocab.strings[self.alias]
|
return self.kb.vocab.strings[self.alias_hash]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def entity_freq(self):
|
||||||
|
return self.entity_freq
|
||||||
|
|
||||||
|
@property
|
||||||
|
def entity_vector(self):
|
||||||
|
return self.entity_vector
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def prior_prob(self):
|
def prior_prob(self):
|
||||||
|
@ -38,26 +63,41 @@ cdef class Candidate:
|
||||||
|
|
||||||
cdef class KnowledgeBase:
|
cdef class KnowledgeBase:
|
||||||
|
|
||||||
def __init__(self, Vocab vocab):
|
def __init__(self, Vocab vocab, entity_vector_length):
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
|
self.mem = Pool()
|
||||||
|
self.entity_vector_length = entity_vector_length
|
||||||
|
|
||||||
self._entry_index = PreshMap()
|
self._entry_index = PreshMap()
|
||||||
self._alias_index = PreshMap()
|
self._alias_index = PreshMap()
|
||||||
self.mem = Pool()
|
|
||||||
self._create_empty_vectors()
|
self.vocab.strings.add("")
|
||||||
|
self._create_empty_vectors(dummy_hash=self.vocab.strings[""])
|
||||||
|
|
||||||
|
@property
|
||||||
|
def entity_vector_length(self):
|
||||||
|
"""RETURNS (uint64): length of the entity vectors"""
|
||||||
|
return self.entity_vector_length
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return self.get_size_entities()
|
return self.get_size_entities()
|
||||||
|
|
||||||
def get_size_entities(self):
|
def get_size_entities(self):
|
||||||
return self._entries.size() - 1 # not counting dummy element on index 0
|
return len(self._entry_index)
|
||||||
|
|
||||||
|
def get_entity_strings(self):
|
||||||
|
return [self.vocab.strings[x] for x in self._entry_index]
|
||||||
|
|
||||||
def get_size_aliases(self):
|
def get_size_aliases(self):
|
||||||
return self._aliases_table.size() - 1 # not counting dummy element on index 0
|
return len(self._alias_index)
|
||||||
|
|
||||||
def add_entity(self, unicode entity, float prob=0.5, vectors=None, features=None):
|
def get_alias_strings(self):
|
||||||
|
return [self.vocab.strings[x] for x in self._alias_index]
|
||||||
|
|
||||||
|
def add_entity(self, unicode entity, float prob, vector[float] entity_vector):
|
||||||
"""
|
"""
|
||||||
Add an entity to the KB.
|
Add an entity to the KB, optionally specifying its log probability based on corpus frequency
|
||||||
Return the hash of the entity ID at the end
|
Return the hash of the entity ID/name at the end.
|
||||||
"""
|
"""
|
||||||
cdef hash_t entity_hash = self.vocab.strings.add(entity)
|
cdef hash_t entity_hash = self.vocab.strings.add(entity)
|
||||||
|
|
||||||
|
@ -66,40 +106,72 @@ cdef class KnowledgeBase:
|
||||||
user_warning(Warnings.W018.format(entity=entity))
|
user_warning(Warnings.W018.format(entity=entity))
|
||||||
return
|
return
|
||||||
|
|
||||||
cdef int32_t dummy_value = 342
|
# Raise an error if the provided entity vector is not of the correct length
|
||||||
self.c_add_entity(entity_hash=entity_hash, prob=prob,
|
if len(entity_vector) != self.entity_vector_length:
|
||||||
vector_rows=&dummy_value, feats_row=dummy_value)
|
raise ValueError(Errors.E141.format(found=len(entity_vector), required=self.entity_vector_length))
|
||||||
# TODO self._vectors_table.get_pointer(vectors),
|
|
||||||
# self._features_table.get(features))
|
vector_index = self.c_add_vector(entity_vector=entity_vector)
|
||||||
|
|
||||||
|
new_index = self.c_add_entity(entity_hash=entity_hash,
|
||||||
|
prob=prob,
|
||||||
|
vector_index=vector_index,
|
||||||
|
feats_row=-1) # Features table currently not implemented
|
||||||
|
self._entry_index[entity_hash] = new_index
|
||||||
|
|
||||||
return entity_hash
|
return entity_hash
|
||||||
|
|
||||||
|
cpdef set_entities(self, entity_list, prob_list, vector_list):
|
||||||
|
if len(entity_list) != len(prob_list) or len(entity_list) != len(vector_list):
|
||||||
|
raise ValueError(Errors.E140)
|
||||||
|
|
||||||
|
nr_entities = len(entity_list)
|
||||||
|
self._entry_index = PreshMap(nr_entities+1)
|
||||||
|
self._entries = entry_vec(nr_entities+1)
|
||||||
|
|
||||||
|
i = 0
|
||||||
|
cdef KBEntryC entry
|
||||||
|
while i < nr_entities:
|
||||||
|
entity_vector = vector_list[i]
|
||||||
|
if len(entity_vector) != self.entity_vector_length:
|
||||||
|
raise ValueError(Errors.E141.format(found=len(entity_vector), required=self.entity_vector_length))
|
||||||
|
|
||||||
|
entity_hash = self.vocab.strings.add(entity_list[i])
|
||||||
|
entry.entity_hash = entity_hash
|
||||||
|
entry.prob = prob_list[i]
|
||||||
|
|
||||||
|
vector_index = self.c_add_vector(entity_vector=vector_list[i])
|
||||||
|
entry.vector_index = vector_index
|
||||||
|
|
||||||
|
entry.feats_row = -1 # Features table currently not implemented
|
||||||
|
|
||||||
|
self._entries[i+1] = entry
|
||||||
|
self._entry_index[entity_hash] = i+1
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
|
||||||
def add_alias(self, unicode alias, entities, probabilities):
|
def add_alias(self, unicode alias, entities, probabilities):
|
||||||
"""
|
"""
|
||||||
For a given alias, add its potential entities and prior probabilies to the KB.
|
For a given alias, add its potential entities and prior probabilies to the KB.
|
||||||
Return the alias_hash at the end
|
Return the alias_hash at the end
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Throw an error if the length of entities and probabilities are not the same
|
# Throw an error if the length of entities and probabilities are not the same
|
||||||
if not len(entities) == len(probabilities):
|
if not len(entities) == len(probabilities):
|
||||||
raise ValueError(Errors.E132.format(alias=alias,
|
raise ValueError(Errors.E132.format(alias=alias,
|
||||||
entities_length=len(entities),
|
entities_length=len(entities),
|
||||||
probabilities_length=len(probabilities)))
|
probabilities_length=len(probabilities)))
|
||||||
|
|
||||||
# Throw an error if the probabilities sum up to more than 1
|
# Throw an error if the probabilities sum up to more than 1 (allow for some rounding errors)
|
||||||
prob_sum = sum(probabilities)
|
prob_sum = sum(probabilities)
|
||||||
if prob_sum > 1:
|
if prob_sum > 1.00001:
|
||||||
raise ValueError(Errors.E133.format(alias=alias, sum=prob_sum))
|
raise ValueError(Errors.E133.format(alias=alias, sum=prob_sum))
|
||||||
|
|
||||||
cdef hash_t alias_hash = self.vocab.strings.add(alias)
|
cdef hash_t alias_hash = self.vocab.strings.add(alias)
|
||||||
|
|
||||||
# Return if this alias was added before
|
# Check whether this alias was added before
|
||||||
if alias_hash in self._alias_index:
|
if alias_hash in self._alias_index:
|
||||||
user_warning(Warnings.W017.format(alias=alias))
|
user_warning(Warnings.W017.format(alias=alias))
|
||||||
return
|
return
|
||||||
|
|
||||||
cdef hash_t entity_hash
|
|
||||||
|
|
||||||
cdef vector[int64_t] entry_indices
|
cdef vector[int64_t] entry_indices
|
||||||
cdef vector[float] probs
|
cdef vector[float] probs
|
||||||
|
|
||||||
|
@ -112,20 +184,295 @@ cdef class KnowledgeBase:
|
||||||
entry_indices.push_back(int(entry_index))
|
entry_indices.push_back(int(entry_index))
|
||||||
probs.push_back(float(prob))
|
probs.push_back(float(prob))
|
||||||
|
|
||||||
self.c_add_aliases(alias_hash=alias_hash, entry_indices=entry_indices, probs=probs)
|
new_index = self.c_add_aliases(alias_hash=alias_hash, entry_indices=entry_indices, probs=probs)
|
||||||
|
self._alias_index[alias_hash] = new_index
|
||||||
|
|
||||||
return alias_hash
|
return alias_hash
|
||||||
|
|
||||||
|
|
||||||
def get_candidates(self, unicode alias):
|
def get_candidates(self, unicode alias):
|
||||||
""" TODO: where to put this functionality ?"""
|
|
||||||
cdef hash_t alias_hash = self.vocab.strings[alias]
|
cdef hash_t alias_hash = self.vocab.strings[alias]
|
||||||
alias_index = <int64_t>self._alias_index.get(alias_hash)
|
alias_index = <int64_t>self._alias_index.get(alias_hash)
|
||||||
alias_entry = self._aliases_table[alias_index]
|
alias_entry = self._aliases_table[alias_index]
|
||||||
|
|
||||||
return [Candidate(kb=self,
|
return [Candidate(kb=self,
|
||||||
entity_hash=self._entries[entry_index].entity_hash,
|
entity_hash=self._entries[entry_index].entity_hash,
|
||||||
|
entity_freq=self._entries[entry_index].prob,
|
||||||
|
entity_vector=self._vectors_table[self._entries[entry_index].vector_index],
|
||||||
alias_hash=alias_hash,
|
alias_hash=alias_hash,
|
||||||
prior_prob=prob)
|
prior_prob=prob)
|
||||||
for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs)
|
for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs)
|
||||||
if entry_index != 0]
|
if entry_index != 0]
|
||||||
|
|
||||||
|
|
||||||
|
def dump(self, loc):
|
||||||
|
cdef Writer writer = Writer(loc)
|
||||||
|
writer.write_header(self.get_size_entities(), self.entity_vector_length)
|
||||||
|
|
||||||
|
# dumping the entity vectors in their original order
|
||||||
|
i = 0
|
||||||
|
for entity_vector in self._vectors_table:
|
||||||
|
for element in entity_vector:
|
||||||
|
writer.write_vector_element(element)
|
||||||
|
i = i+1
|
||||||
|
|
||||||
|
# dumping the entry records in the order in which they are in the _entries vector.
|
||||||
|
# index 0 is a dummy object not stored in the _entry_index and can be ignored.
|
||||||
|
i = 1
|
||||||
|
for entry_hash, entry_index in sorted(self._entry_index.items(), key=lambda x: x[1]):
|
||||||
|
entry = self._entries[entry_index]
|
||||||
|
assert entry.entity_hash == entry_hash
|
||||||
|
assert entry_index == i
|
||||||
|
writer.write_entry(entry.entity_hash, entry.prob, entry.vector_index)
|
||||||
|
i = i+1
|
||||||
|
|
||||||
|
writer.write_alias_length(self.get_size_aliases())
|
||||||
|
|
||||||
|
# dumping the aliases in the order in which they are in the _alias_index vector.
|
||||||
|
# index 0 is a dummy object not stored in the _aliases_table and can be ignored.
|
||||||
|
i = 1
|
||||||
|
for alias_hash, alias_index in sorted(self._alias_index.items(), key=lambda x: x[1]):
|
||||||
|
alias = self._aliases_table[alias_index]
|
||||||
|
assert alias_index == i
|
||||||
|
|
||||||
|
candidate_length = len(alias.entry_indices)
|
||||||
|
writer.write_alias_header(alias_hash, candidate_length)
|
||||||
|
|
||||||
|
for j in range(0, candidate_length):
|
||||||
|
writer.write_alias(alias.entry_indices[j], alias.probs[j])
|
||||||
|
|
||||||
|
i = i+1
|
||||||
|
|
||||||
|
writer.close()
|
||||||
|
|
||||||
|
cpdef load_bulk(self, loc):
|
||||||
|
cdef hash_t entity_hash
|
||||||
|
cdef hash_t alias_hash
|
||||||
|
cdef int64_t entry_index
|
||||||
|
cdef float prob
|
||||||
|
cdef int32_t vector_index
|
||||||
|
cdef KBEntryC entry
|
||||||
|
cdef AliasC alias
|
||||||
|
cdef float vector_element
|
||||||
|
|
||||||
|
cdef Reader reader = Reader(loc)
|
||||||
|
|
||||||
|
# STEP 0: load header and initialize KB
|
||||||
|
cdef int64_t nr_entities
|
||||||
|
cdef int64_t entity_vector_length
|
||||||
|
reader.read_header(&nr_entities, &entity_vector_length)
|
||||||
|
|
||||||
|
self.entity_vector_length = entity_vector_length
|
||||||
|
self._entry_index = PreshMap(nr_entities+1)
|
||||||
|
self._entries = entry_vec(nr_entities+1)
|
||||||
|
self._vectors_table = float_matrix(nr_entities+1)
|
||||||
|
|
||||||
|
# STEP 1: load entity vectors
|
||||||
|
cdef int i = 0
|
||||||
|
cdef int j = 0
|
||||||
|
while i < nr_entities:
|
||||||
|
entity_vector = float_vec(entity_vector_length)
|
||||||
|
j = 0
|
||||||
|
while j < entity_vector_length:
|
||||||
|
reader.read_vector_element(&vector_element)
|
||||||
|
entity_vector[j] = vector_element
|
||||||
|
j = j+1
|
||||||
|
self._vectors_table[i] = entity_vector
|
||||||
|
i = i+1
|
||||||
|
|
||||||
|
# STEP 2: load entities
|
||||||
|
# we assume that the entity data was written in sequence
|
||||||
|
# index 0 is a dummy object not stored in the _entry_index and can be ignored.
|
||||||
|
i = 1
|
||||||
|
while i <= nr_entities:
|
||||||
|
reader.read_entry(&entity_hash, &prob, &vector_index)
|
||||||
|
|
||||||
|
entry.entity_hash = entity_hash
|
||||||
|
entry.prob = prob
|
||||||
|
entry.vector_index = vector_index
|
||||||
|
entry.feats_row = -1 # Features table currently not implemented
|
||||||
|
|
||||||
|
self._entries[i] = entry
|
||||||
|
self._entry_index[entity_hash] = i
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
# check that all entities were read in properly
|
||||||
|
assert nr_entities == self.get_size_entities()
|
||||||
|
|
||||||
|
# STEP 3: load aliases
|
||||||
|
|
||||||
|
cdef int64_t nr_aliases
|
||||||
|
reader.read_alias_length(&nr_aliases)
|
||||||
|
self._alias_index = PreshMap(nr_aliases+1)
|
||||||
|
self._aliases_table = alias_vec(nr_aliases+1)
|
||||||
|
|
||||||
|
cdef int64_t nr_candidates
|
||||||
|
cdef vector[int64_t] entry_indices
|
||||||
|
cdef vector[float] probs
|
||||||
|
|
||||||
|
i = 1
|
||||||
|
# we assume the alias data was written in sequence
|
||||||
|
# index 0 is a dummy object not stored in the _entry_index and can be ignored.
|
||||||
|
while i <= nr_aliases:
|
||||||
|
reader.read_alias_header(&alias_hash, &nr_candidates)
|
||||||
|
entry_indices = vector[int64_t](nr_candidates)
|
||||||
|
probs = vector[float](nr_candidates)
|
||||||
|
|
||||||
|
for j in range(0, nr_candidates):
|
||||||
|
reader.read_alias(&entry_index, &prob)
|
||||||
|
entry_indices[j] = entry_index
|
||||||
|
probs[j] = prob
|
||||||
|
|
||||||
|
alias.entry_indices = entry_indices
|
||||||
|
alias.probs = probs
|
||||||
|
|
||||||
|
self._aliases_table[i] = alias
|
||||||
|
self._alias_index[alias_hash] = i
|
||||||
|
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
# check that all aliases were read in properly
|
||||||
|
assert nr_aliases == self.get_size_aliases()
|
||||||
|
|
||||||
|
|
||||||
|
cdef class Writer:
|
||||||
|
def __init__(self, object loc):
|
||||||
|
if path.exists(loc):
|
||||||
|
assert not path.isdir(loc), "%s is directory." % loc
|
||||||
|
if isinstance(loc, Path):
|
||||||
|
loc = bytes(loc)
|
||||||
|
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
|
||||||
|
self._fp = fopen(<char*>bytes_loc, 'wb')
|
||||||
|
assert self._fp != NULL
|
||||||
|
fseek(self._fp, 0, 0)
|
||||||
|
|
||||||
|
def close(self):
|
||||||
|
cdef size_t status = fclose(self._fp)
|
||||||
|
assert status == 0
|
||||||
|
|
||||||
|
cdef int write_header(self, int64_t nr_entries, int64_t entity_vector_length) except -1:
|
||||||
|
self._write(&nr_entries, sizeof(nr_entries))
|
||||||
|
self._write(&entity_vector_length, sizeof(entity_vector_length))
|
||||||
|
|
||||||
|
cdef int write_vector_element(self, float element) except -1:
|
||||||
|
self._write(&element, sizeof(element))
|
||||||
|
|
||||||
|
cdef int write_entry(self, hash_t entry_hash, float entry_prob, int32_t vector_index) except -1:
|
||||||
|
self._write(&entry_hash, sizeof(entry_hash))
|
||||||
|
self._write(&entry_prob, sizeof(entry_prob))
|
||||||
|
self._write(&vector_index, sizeof(vector_index))
|
||||||
|
# Features table currently not implemented and not written to file
|
||||||
|
|
||||||
|
cdef int write_alias_length(self, int64_t alias_length) except -1:
|
||||||
|
self._write(&alias_length, sizeof(alias_length))
|
||||||
|
|
||||||
|
cdef int write_alias_header(self, hash_t alias_hash, int64_t candidate_length) except -1:
|
||||||
|
self._write(&alias_hash, sizeof(alias_hash))
|
||||||
|
self._write(&candidate_length, sizeof(candidate_length))
|
||||||
|
|
||||||
|
cdef int write_alias(self, int64_t entry_index, float prob) except -1:
|
||||||
|
self._write(&entry_index, sizeof(entry_index))
|
||||||
|
self._write(&prob, sizeof(prob))
|
||||||
|
|
||||||
|
cdef int _write(self, void* value, size_t size) except -1:
|
||||||
|
status = fwrite(value, size, 1, self._fp)
|
||||||
|
assert status == 1, status
|
||||||
|
|
||||||
|
|
||||||
|
cdef class Reader:
|
||||||
|
def __init__(self, object loc):
|
||||||
|
assert path.exists(loc)
|
||||||
|
assert not path.isdir(loc)
|
||||||
|
if isinstance(loc, Path):
|
||||||
|
loc = bytes(loc)
|
||||||
|
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
|
||||||
|
self._fp = fopen(<char*>bytes_loc, 'rb')
|
||||||
|
if not self._fp:
|
||||||
|
PyErr_SetFromErrno(IOError)
|
||||||
|
status = fseek(self._fp, 0, 0) # this can be 0 if there is no header
|
||||||
|
|
||||||
|
def __dealloc__(self):
|
||||||
|
fclose(self._fp)
|
||||||
|
|
||||||
|
cdef int read_header(self, int64_t* nr_entries, int64_t* entity_vector_length) except -1:
|
||||||
|
status = self._read(nr_entries, sizeof(int64_t))
|
||||||
|
if status < 1:
|
||||||
|
if feof(self._fp):
|
||||||
|
return 0 # end of file
|
||||||
|
raise IOError("error reading header from input file")
|
||||||
|
|
||||||
|
status = self._read(entity_vector_length, sizeof(int64_t))
|
||||||
|
if status < 1:
|
||||||
|
if feof(self._fp):
|
||||||
|
return 0 # end of file
|
||||||
|
raise IOError("error reading header from input file")
|
||||||
|
|
||||||
|
cdef int read_vector_element(self, float* element) except -1:
|
||||||
|
status = self._read(element, sizeof(float))
|
||||||
|
if status < 1:
|
||||||
|
if feof(self._fp):
|
||||||
|
return 0 # end of file
|
||||||
|
raise IOError("error reading entity vector from input file")
|
||||||
|
|
||||||
|
cdef int read_entry(self, hash_t* entity_hash, float* prob, int32_t* vector_index) except -1:
|
||||||
|
status = self._read(entity_hash, sizeof(hash_t))
|
||||||
|
if status < 1:
|
||||||
|
if feof(self._fp):
|
||||||
|
return 0 # end of file
|
||||||
|
raise IOError("error reading entity hash from input file")
|
||||||
|
|
||||||
|
status = self._read(prob, sizeof(float))
|
||||||
|
if status < 1:
|
||||||
|
if feof(self._fp):
|
||||||
|
return 0 # end of file
|
||||||
|
raise IOError("error reading entity prob from input file")
|
||||||
|
|
||||||
|
status = self._read(vector_index, sizeof(int32_t))
|
||||||
|
if status < 1:
|
||||||
|
if feof(self._fp):
|
||||||
|
return 0 # end of file
|
||||||
|
raise IOError("error reading entity vector from input file")
|
||||||
|
|
||||||
|
if feof(self._fp):
|
||||||
|
return 0
|
||||||
|
else:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
cdef int read_alias_length(self, int64_t* alias_length) except -1:
|
||||||
|
status = self._read(alias_length, sizeof(int64_t))
|
||||||
|
if status < 1:
|
||||||
|
if feof(self._fp):
|
||||||
|
return 0 # end of file
|
||||||
|
raise IOError("error reading alias length from input file")
|
||||||
|
|
||||||
|
cdef int read_alias_header(self, hash_t* alias_hash, int64_t* candidate_length) except -1:
|
||||||
|
status = self._read(alias_hash, sizeof(hash_t))
|
||||||
|
if status < 1:
|
||||||
|
if feof(self._fp):
|
||||||
|
return 0 # end of file
|
||||||
|
raise IOError("error reading alias hash from input file")
|
||||||
|
|
||||||
|
status = self._read(candidate_length, sizeof(int64_t))
|
||||||
|
if status < 1:
|
||||||
|
if feof(self._fp):
|
||||||
|
return 0 # end of file
|
||||||
|
raise IOError("error reading candidate length from input file")
|
||||||
|
|
||||||
|
cdef int read_alias(self, int64_t* entry_index, float* prob) except -1:
|
||||||
|
status = self._read(entry_index, sizeof(int64_t))
|
||||||
|
if status < 1:
|
||||||
|
if feof(self._fp):
|
||||||
|
return 0 # end of file
|
||||||
|
raise IOError("error reading entry index for alias from input file")
|
||||||
|
|
||||||
|
status = self._read(prob, sizeof(float))
|
||||||
|
if status < 1:
|
||||||
|
if feof(self._fp):
|
||||||
|
return 0 # end of file
|
||||||
|
raise IOError("error reading prob for entity/alias from input file")
|
||||||
|
|
||||||
|
cdef int _read(self, void* value, size_t size) except -1:
|
||||||
|
status = fread(value, size, 1, self._fp)
|
||||||
|
return status
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -3,16 +3,18 @@
|
||||||
# coding: utf8
|
# coding: utf8
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
cimport numpy as np
|
|
||||||
|
|
||||||
import numpy
|
import numpy
|
||||||
import srsly
|
import srsly
|
||||||
|
import random
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from thinc.api import chain
|
from thinc.api import chain
|
||||||
from thinc.v2v import Affine, Maxout, Softmax
|
from thinc.v2v import Affine, Maxout, Softmax
|
||||||
from thinc.misc import LayerNorm
|
from thinc.misc import LayerNorm
|
||||||
from thinc.neural.util import to_categorical, copy_array
|
from thinc.neural.util import to_categorical
|
||||||
|
from thinc.neural.util import get_array_module
|
||||||
|
|
||||||
|
from spacy.kb import KnowledgeBase
|
||||||
|
from ..cli.pretrain import get_cossim_loss
|
||||||
from .functions import merge_subtokens
|
from .functions import merge_subtokens
|
||||||
from ..tokens.doc cimport Doc
|
from ..tokens.doc cimport Doc
|
||||||
from ..syntax.nn_parser cimport Parser
|
from ..syntax.nn_parser cimport Parser
|
||||||
|
@ -24,9 +26,9 @@ from ..vocab cimport Vocab
|
||||||
from ..syntax import nonproj
|
from ..syntax import nonproj
|
||||||
from ..attrs import POS, ID
|
from ..attrs import POS, ID
|
||||||
from ..parts_of_speech import X
|
from ..parts_of_speech import X
|
||||||
from .._ml import Tok2Vec, build_tagger_model
|
from .._ml import Tok2Vec, build_tagger_model, cosine
|
||||||
from .._ml import build_text_classifier, build_simple_cnn_text_classifier
|
from .._ml import build_text_classifier, build_simple_cnn_text_classifier
|
||||||
from .._ml import build_bow_text_classifier
|
from .._ml import build_bow_text_classifier, build_nel_encoder
|
||||||
from .._ml import link_vectors_to_models, zero_init, flatten
|
from .._ml import link_vectors_to_models, zero_init, flatten
|
||||||
from .._ml import masked_language_model, create_default_optimizer
|
from .._ml import masked_language_model, create_default_optimizer
|
||||||
from ..errors import Errors, TempErrors
|
from ..errors import Errors, TempErrors
|
||||||
|
@ -229,7 +231,7 @@ class Tensorizer(Pipe):
|
||||||
|
|
||||||
vocab (Vocab): A `Vocab` instance. The model must share the same
|
vocab (Vocab): A `Vocab` instance. The model must share the same
|
||||||
`Vocab` instance with the `Doc` objects it will process.
|
`Vocab` instance with the `Doc` objects it will process.
|
||||||
model (Model): A `Model` instance or `True` allocate one later.
|
model (Model): A `Model` instance or `True` to allocate one later.
|
||||||
**cfg: Config parameters.
|
**cfg: Config parameters.
|
||||||
|
|
||||||
EXAMPLE:
|
EXAMPLE:
|
||||||
|
@ -294,7 +296,7 @@ class Tensorizer(Pipe):
|
||||||
|
|
||||||
docs (iterable): A batch of `Doc` objects.
|
docs (iterable): A batch of `Doc` objects.
|
||||||
golds (iterable): A batch of `GoldParse` objects.
|
golds (iterable): A batch of `GoldParse` objects.
|
||||||
drop (float): The droput rate.
|
drop (float): The dropout rate.
|
||||||
sgd (callable): An optimizer.
|
sgd (callable): An optimizer.
|
||||||
RETURNS (dict): Results from the update.
|
RETURNS (dict): Results from the update.
|
||||||
"""
|
"""
|
||||||
|
@ -386,7 +388,7 @@ class Tagger(Pipe):
|
||||||
def predict(self, docs):
|
def predict(self, docs):
|
||||||
self.require_model()
|
self.require_model()
|
||||||
if not any(len(doc) for doc in docs):
|
if not any(len(doc) for doc in docs):
|
||||||
# Handle case where there are no tokens in any docs.
|
# Handle cases where there are no tokens in any docs.
|
||||||
n_labels = len(self.labels)
|
n_labels = len(self.labels)
|
||||||
guesses = [self.model.ops.allocate((0, n_labels)) for doc in docs]
|
guesses = [self.model.ops.allocate((0, n_labels)) for doc in docs]
|
||||||
tokvecs = self.model.ops.allocate((0, self.model.tok2vec.nO))
|
tokvecs = self.model.ops.allocate((0, self.model.tok2vec.nO))
|
||||||
|
@ -1063,52 +1065,252 @@ cdef class EntityRecognizer(Parser):
|
||||||
|
|
||||||
|
|
||||||
class EntityLinker(Pipe):
|
class EntityLinker(Pipe):
|
||||||
|
"""Pipeline component for named entity linking.
|
||||||
|
|
||||||
|
DOCS: TODO
|
||||||
|
"""
|
||||||
name = 'entity_linker'
|
name = 'entity_linker'
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def Model(cls, nr_class=1, **cfg):
|
def Model(cls, **cfg):
|
||||||
# TODO: non-dummy EL implementation
|
embed_width = cfg.get("embed_width", 300)
|
||||||
return None
|
hidden_width = cfg.get("hidden_width", 128)
|
||||||
|
type_to_int = cfg.get("type_to_int", dict())
|
||||||
|
|
||||||
def __init__(self, model=True, **cfg):
|
model = build_nel_encoder(embed_width=embed_width, hidden_width=hidden_width, ner_types=len(type_to_int), **cfg)
|
||||||
self.model = False
|
return model
|
||||||
|
|
||||||
|
def __init__(self, vocab, **cfg):
|
||||||
|
self.vocab = vocab
|
||||||
|
self.model = True
|
||||||
|
self.kb = None
|
||||||
self.cfg = dict(cfg)
|
self.cfg = dict(cfg)
|
||||||
self.kb = self.cfg["kb"]
|
self.sgd_context = None
|
||||||
|
|
||||||
|
def set_kb(self, kb):
|
||||||
|
self.kb = kb
|
||||||
|
|
||||||
|
def require_model(self):
|
||||||
|
# Raise an error if the component's model is not initialized.
|
||||||
|
if getattr(self, "model", None) in (None, True, False):
|
||||||
|
raise ValueError(Errors.E109.format(name=self.name))
|
||||||
|
|
||||||
|
def require_kb(self):
|
||||||
|
# Raise an error if the knowledge base is not initialized.
|
||||||
|
if getattr(self, "kb", None) in (None, True, False):
|
||||||
|
raise ValueError(Errors.E139.format(name=self.name))
|
||||||
|
|
||||||
|
def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None, **kwargs):
|
||||||
|
self.require_kb()
|
||||||
|
self.cfg["entity_width"] = self.kb.entity_vector_length
|
||||||
|
|
||||||
|
if self.model is True:
|
||||||
|
self.model = self.Model(**self.cfg)
|
||||||
|
self.sgd_context = self.create_optimizer()
|
||||||
|
|
||||||
|
if sgd is None:
|
||||||
|
sgd = self.create_optimizer()
|
||||||
|
|
||||||
|
return sgd
|
||||||
|
|
||||||
|
def update(self, docs, golds, state=None, drop=0.0, sgd=None, losses=None):
|
||||||
|
self.require_model()
|
||||||
|
self.require_kb()
|
||||||
|
|
||||||
|
if losses is not None:
|
||||||
|
losses.setdefault(self.name, 0.0)
|
||||||
|
|
||||||
|
if not docs or not golds:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
if len(docs) != len(golds):
|
||||||
|
raise ValueError(Errors.E077.format(value="EL training", n_docs=len(docs),
|
||||||
|
n_golds=len(golds)))
|
||||||
|
|
||||||
|
if isinstance(docs, Doc):
|
||||||
|
docs = [docs]
|
||||||
|
golds = [golds]
|
||||||
|
|
||||||
|
context_docs = []
|
||||||
|
entity_encodings = []
|
||||||
|
cats = []
|
||||||
|
priors = []
|
||||||
|
type_vectors = []
|
||||||
|
|
||||||
|
type_to_int = self.cfg.get("type_to_int", dict())
|
||||||
|
|
||||||
|
for doc, gold in zip(docs, golds):
|
||||||
|
ents_by_offset = dict()
|
||||||
|
for ent in doc.ents:
|
||||||
|
ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent
|
||||||
|
for entity in gold.links:
|
||||||
|
start, end, gold_kb = entity
|
||||||
|
mention = doc.text[start:end]
|
||||||
|
|
||||||
|
gold_ent = ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)]
|
||||||
|
assert gold_ent is not None
|
||||||
|
type_vector = [0 for i in range(len(type_to_int))]
|
||||||
|
if len(type_to_int) > 0:
|
||||||
|
type_vector[type_to_int[gold_ent.label_]] = 1
|
||||||
|
|
||||||
|
candidates = self.kb.get_candidates(mention)
|
||||||
|
random.shuffle(candidates)
|
||||||
|
nr_neg = 0
|
||||||
|
for c in candidates:
|
||||||
|
kb_id = c.entity_
|
||||||
|
entity_encoding = c.entity_vector
|
||||||
|
entity_encodings.append(entity_encoding)
|
||||||
|
context_docs.append(doc)
|
||||||
|
type_vectors.append(type_vector)
|
||||||
|
|
||||||
|
if self.cfg.get("prior_weight", 1) > 0:
|
||||||
|
priors.append([c.prior_prob])
|
||||||
|
else:
|
||||||
|
priors.append([0])
|
||||||
|
|
||||||
|
if kb_id == gold_kb:
|
||||||
|
cats.append([1])
|
||||||
|
else:
|
||||||
|
nr_neg += 1
|
||||||
|
cats.append([0])
|
||||||
|
|
||||||
|
if len(entity_encodings) > 0:
|
||||||
|
assert len(priors) == len(entity_encodings) == len(context_docs) == len(cats) == len(type_vectors)
|
||||||
|
|
||||||
|
context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop)
|
||||||
|
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
|
||||||
|
|
||||||
|
mention_encodings = [list(context_encodings[i]) + list(entity_encodings[i]) + priors[i] + type_vectors[i]
|
||||||
|
for i in range(len(entity_encodings))]
|
||||||
|
pred, bp_mention = self.model.begin_update(self.model.ops.asarray(mention_encodings, dtype="float32"), drop=drop)
|
||||||
|
cats = self.model.ops.asarray(cats, dtype="float32")
|
||||||
|
|
||||||
|
loss, d_scores = self.get_loss(prediction=pred, golds=cats, docs=None)
|
||||||
|
mention_gradient = bp_mention(d_scores, sgd=sgd)
|
||||||
|
|
||||||
|
context_gradients = [list(x[0:self.cfg.get("context_width")]) for x in mention_gradient]
|
||||||
|
bp_context(self.model.ops.asarray(context_gradients, dtype="float32"), sgd=self.sgd_context)
|
||||||
|
|
||||||
|
if losses is not None:
|
||||||
|
losses[self.name] += loss
|
||||||
|
return loss
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def get_loss(self, docs, golds, prediction):
|
||||||
|
d_scores = (prediction - golds)
|
||||||
|
loss = (d_scores ** 2).sum()
|
||||||
|
loss = loss / len(golds)
|
||||||
|
return loss, d_scores
|
||||||
|
|
||||||
|
def get_loss_old(self, docs, golds, scores):
|
||||||
|
# this loss function assumes we're only using positive examples
|
||||||
|
loss, gradients = get_cossim_loss(yh=scores, y=golds)
|
||||||
|
loss = loss / len(golds)
|
||||||
|
return loss, gradients
|
||||||
|
|
||||||
def __call__(self, doc):
|
def __call__(self, doc):
|
||||||
self.set_annotations([doc], scores=None, tensors=None)
|
entities, kb_ids = self.predict([doc])
|
||||||
|
self.set_annotations([doc], entities, kb_ids)
|
||||||
return doc
|
return doc
|
||||||
|
|
||||||
def pipe(self, stream, batch_size=128, n_threads=-1):
|
def pipe(self, stream, batch_size=128, n_threads=-1):
|
||||||
"""Apply the pipe to a stream of documents.
|
|
||||||
Both __call__ and pipe should delegate to the `predict()`
|
|
||||||
and `set_annotations()` methods.
|
|
||||||
"""
|
|
||||||
for docs in util.minibatch(stream, size=batch_size):
|
for docs in util.minibatch(stream, size=batch_size):
|
||||||
docs = list(docs)
|
docs = list(docs)
|
||||||
self.set_annotations(docs, scores=None, tensors=None)
|
entities, kb_ids = self.predict(docs)
|
||||||
|
self.set_annotations(docs, entities, kb_ids)
|
||||||
yield from docs
|
yield from docs
|
||||||
|
|
||||||
def set_annotations(self, docs, scores, tensors=None):
|
def predict(self, docs):
|
||||||
"""
|
self.require_model()
|
||||||
Currently implemented as taking the KB entry with highest prior probability for each named entity
|
self.require_kb()
|
||||||
TODO: actually use context etc
|
|
||||||
"""
|
|
||||||
for i, doc in enumerate(docs):
|
|
||||||
for ent in doc.ents:
|
|
||||||
candidates = self.kb.get_candidates(ent.text)
|
|
||||||
if candidates:
|
|
||||||
best_candidate = max(candidates, key=lambda c: c.prior_prob)
|
|
||||||
for token in ent:
|
|
||||||
token.ent_kb_id_ = best_candidate.entity_
|
|
||||||
|
|
||||||
def get_loss(self, docs, golds, scores):
|
final_entities = []
|
||||||
# TODO
|
final_kb_ids = []
|
||||||
pass
|
|
||||||
|
if not docs:
|
||||||
|
return final_entities, final_kb_ids
|
||||||
|
|
||||||
|
if isinstance(docs, Doc):
|
||||||
|
docs = [docs]
|
||||||
|
|
||||||
|
context_encodings = self.model.tok2vec(docs)
|
||||||
|
xp = get_array_module(context_encodings)
|
||||||
|
|
||||||
|
type_to_int = self.cfg.get("type_to_int", dict())
|
||||||
|
|
||||||
|
for i, doc in enumerate(docs):
|
||||||
|
if len(doc) > 0:
|
||||||
|
context_encoding = context_encodings[i]
|
||||||
|
for ent in doc.ents:
|
||||||
|
type_vector = [0 for i in range(len(type_to_int))]
|
||||||
|
if len(type_to_int) > 0:
|
||||||
|
type_vector[type_to_int[ent.label_]] = 1
|
||||||
|
|
||||||
|
candidates = self.kb.get_candidates(ent.text)
|
||||||
|
if candidates:
|
||||||
|
random.shuffle(candidates)
|
||||||
|
|
||||||
|
# this will set the prior probabilities to 0 (just like in training) if their weight is 0
|
||||||
|
prior_probs = xp.asarray([[c.prior_prob] for c in candidates])
|
||||||
|
prior_probs *= self.cfg.get("prior_weight", 1)
|
||||||
|
scores = prior_probs
|
||||||
|
|
||||||
|
if self.cfg.get("context_weight", 1) > 0:
|
||||||
|
entity_encodings = xp.asarray([c.entity_vector for c in candidates])
|
||||||
|
assert len(entity_encodings) == len(prior_probs)
|
||||||
|
mention_encodings = [list(context_encoding) + list(entity_encodings[i])
|
||||||
|
+ list(prior_probs[i]) + type_vector
|
||||||
|
for i in range(len(entity_encodings))]
|
||||||
|
scores = self.model(self.model.ops.asarray(mention_encodings, dtype="float32"))
|
||||||
|
|
||||||
|
# TODO: thresholding
|
||||||
|
best_index = scores.argmax()
|
||||||
|
best_candidate = candidates[best_index]
|
||||||
|
final_entities.append(ent)
|
||||||
|
final_kb_ids.append(best_candidate.entity_)
|
||||||
|
|
||||||
|
return final_entities, final_kb_ids
|
||||||
|
|
||||||
|
def set_annotations(self, docs, entities, kb_ids=None):
|
||||||
|
for entity, kb_id in zip(entities, kb_ids):
|
||||||
|
for token in entity:
|
||||||
|
token.ent_kb_id_ = kb_id
|
||||||
|
|
||||||
|
def to_disk(self, path, exclude=tuple(), **kwargs):
|
||||||
|
serialize = OrderedDict()
|
||||||
|
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
|
||||||
|
serialize["vocab"] = lambda p: self.vocab.to_disk(p)
|
||||||
|
serialize["kb"] = lambda p: self.kb.dump(p)
|
||||||
|
if self.model not in (None, True, False):
|
||||||
|
serialize["model"] = lambda p: p.open("wb").write(self.model.to_bytes())
|
||||||
|
exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
|
||||||
|
util.to_disk(path, serialize, exclude)
|
||||||
|
|
||||||
|
def from_disk(self, path, exclude=tuple(), **kwargs):
|
||||||
|
def load_model(p):
|
||||||
|
if self.model is True:
|
||||||
|
self.model = self.Model(**self.cfg)
|
||||||
|
self.model.from_bytes(p.open("rb").read())
|
||||||
|
|
||||||
|
def load_kb(p):
|
||||||
|
kb = KnowledgeBase(vocab=self.vocab, entity_vector_length=self.cfg["entity_width"])
|
||||||
|
kb.load_bulk(p)
|
||||||
|
self.set_kb(kb)
|
||||||
|
|
||||||
|
deserialize = OrderedDict()
|
||||||
|
deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p))
|
||||||
|
deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
|
||||||
|
deserialize["kb"] = load_kb
|
||||||
|
deserialize["model"] = load_model
|
||||||
|
exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
|
||||||
|
util.from_disk(path, deserialize, exclude)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def rehearse(self, docs, sgd=None, losses=None, **config):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def add_label(self, label):
|
def add_label(self, label):
|
||||||
# TODO
|
raise NotImplementedError
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class Sentencizer(object):
|
class Sentencizer(object):
|
||||||
|
|
|
@ -3,6 +3,10 @@ from libc.stdint cimport uint8_t, uint32_t, int32_t, uint64_t
|
||||||
from .typedefs cimport flags_t, attr_t, hash_t
|
from .typedefs cimport flags_t, attr_t, hash_t
|
||||||
from .parts_of_speech cimport univ_pos_t
|
from .parts_of_speech cimport univ_pos_t
|
||||||
|
|
||||||
|
from libcpp.vector cimport vector
|
||||||
|
from libc.stdint cimport int32_t, int64_t
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
cdef struct LexemeC:
|
cdef struct LexemeC:
|
||||||
flags_t flags
|
flags_t flags
|
||||||
|
@ -72,3 +76,32 @@ cdef struct TokenC:
|
||||||
attr_t ent_type # TODO: Is there a better way to do this? Multiple sources of truth..
|
attr_t ent_type # TODO: Is there a better way to do this? Multiple sources of truth..
|
||||||
attr_t ent_kb_id
|
attr_t ent_kb_id
|
||||||
hash_t ent_id
|
hash_t ent_id
|
||||||
|
|
||||||
|
|
||||||
|
# Internal struct, for storage and disambiguation of entities.
|
||||||
|
cdef struct KBEntryC:
|
||||||
|
|
||||||
|
# The hash of this entry's unique ID/name in the kB
|
||||||
|
hash_t entity_hash
|
||||||
|
|
||||||
|
# Allows retrieval of the entity vector, as an index into a vectors table of the KB.
|
||||||
|
# Can be expanded later to refer to multiple rows (compositional model to reduce storage footprint).
|
||||||
|
int32_t vector_index
|
||||||
|
|
||||||
|
# Allows retrieval of a struct of non-vector features.
|
||||||
|
# This is currently not implemented and set to -1 for the common case where there are no features.
|
||||||
|
int32_t feats_row
|
||||||
|
|
||||||
|
# log probability of entity, based on corpus frequency
|
||||||
|
float prob
|
||||||
|
|
||||||
|
|
||||||
|
# Each alias struct stores a list of Entry pointers with their prior probabilities
|
||||||
|
# for this specific mention/alias.
|
||||||
|
cdef struct AliasC:
|
||||||
|
|
||||||
|
# All entry candidates for this alias
|
||||||
|
vector[int64_t] entry_indices
|
||||||
|
|
||||||
|
# Prior probability P(entity|alias) - should sum up to (at most) 1.
|
||||||
|
vector[float] probs
|
||||||
|
|
|
@ -81,6 +81,7 @@ cdef enum symbol_t:
|
||||||
DEP
|
DEP
|
||||||
ENT_IOB
|
ENT_IOB
|
||||||
ENT_TYPE
|
ENT_TYPE
|
||||||
|
ENT_KB_ID
|
||||||
HEAD
|
HEAD
|
||||||
SENT_START
|
SENT_START
|
||||||
SPACY
|
SPACY
|
||||||
|
|
|
@ -86,6 +86,7 @@ IDS = {
|
||||||
"DEP": DEP,
|
"DEP": DEP,
|
||||||
"ENT_IOB": ENT_IOB,
|
"ENT_IOB": ENT_IOB,
|
||||||
"ENT_TYPE": ENT_TYPE,
|
"ENT_TYPE": ENT_TYPE,
|
||||||
|
"ENT_KB_ID": ENT_KB_ID,
|
||||||
"HEAD": HEAD,
|
"HEAD": HEAD,
|
||||||
"SENT_START": SENT_START,
|
"SENT_START": SENT_START,
|
||||||
"SPACY": SPACY,
|
"SPACY": SPACY,
|
||||||
|
|
|
@ -1,91 +0,0 @@
|
||||||
# coding: utf-8
|
|
||||||
from __future__ import unicode_literals
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from spacy.kb import KnowledgeBase
|
|
||||||
from spacy.lang.en import English
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def nlp():
|
|
||||||
return English()
|
|
||||||
|
|
||||||
|
|
||||||
def test_kb_valid_entities(nlp):
|
|
||||||
"""Test the valid construction of a KB with 3 entities and two aliases"""
|
|
||||||
mykb = KnowledgeBase(nlp.vocab)
|
|
||||||
|
|
||||||
# adding entities
|
|
||||||
mykb.add_entity(entity=u'Q1', prob=0.9)
|
|
||||||
mykb.add_entity(entity=u'Q2')
|
|
||||||
mykb.add_entity(entity=u'Q3', prob=0.5)
|
|
||||||
|
|
||||||
# adding aliases
|
|
||||||
mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q3'], probabilities=[0.8, 0.2])
|
|
||||||
mykb.add_alias(alias=u'adam', entities=[u'Q2'], probabilities=[0.9])
|
|
||||||
|
|
||||||
# test the size of the corresponding KB
|
|
||||||
assert(mykb.get_size_entities() == 3)
|
|
||||||
assert(mykb.get_size_aliases() == 2)
|
|
||||||
|
|
||||||
|
|
||||||
def test_kb_invalid_entities(nlp):
|
|
||||||
"""Test the invalid construction of a KB with an alias linked to a non-existing entity"""
|
|
||||||
mykb = KnowledgeBase(nlp.vocab)
|
|
||||||
|
|
||||||
# adding entities
|
|
||||||
mykb.add_entity(entity=u'Q1', prob=0.9)
|
|
||||||
mykb.add_entity(entity=u'Q2', prob=0.2)
|
|
||||||
mykb.add_entity(entity=u'Q3', prob=0.5)
|
|
||||||
|
|
||||||
# adding aliases - should fail because one of the given IDs is not valid
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q342'], probabilities=[0.8, 0.2])
|
|
||||||
|
|
||||||
|
|
||||||
def test_kb_invalid_probabilities(nlp):
|
|
||||||
"""Test the invalid construction of a KB with wrong prior probabilities"""
|
|
||||||
mykb = KnowledgeBase(nlp.vocab)
|
|
||||||
|
|
||||||
# adding entities
|
|
||||||
mykb.add_entity(entity=u'Q1', prob=0.9)
|
|
||||||
mykb.add_entity(entity=u'Q2', prob=0.2)
|
|
||||||
mykb.add_entity(entity=u'Q3', prob=0.5)
|
|
||||||
|
|
||||||
# adding aliases - should fail because the sum of the probabilities exceeds 1
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q3'], probabilities=[0.8, 0.4])
|
|
||||||
|
|
||||||
|
|
||||||
def test_kb_invalid_combination(nlp):
|
|
||||||
"""Test the invalid construction of a KB with non-matching entity and probability lists"""
|
|
||||||
mykb = KnowledgeBase(nlp.vocab)
|
|
||||||
|
|
||||||
# adding entities
|
|
||||||
mykb.add_entity(entity=u'Q1', prob=0.9)
|
|
||||||
mykb.add_entity(entity=u'Q2', prob=0.2)
|
|
||||||
mykb.add_entity(entity=u'Q3', prob=0.5)
|
|
||||||
|
|
||||||
# adding aliases - should fail because the entities and probabilities vectors are not of equal length
|
|
||||||
with pytest.raises(ValueError):
|
|
||||||
mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q3'], probabilities=[0.3, 0.4, 0.1])
|
|
||||||
|
|
||||||
|
|
||||||
def test_candidate_generation(nlp):
|
|
||||||
"""Test correct candidate generation"""
|
|
||||||
mykb = KnowledgeBase(nlp.vocab)
|
|
||||||
|
|
||||||
# adding entities
|
|
||||||
mykb.add_entity(entity=u'Q1', prob=0.9)
|
|
||||||
mykb.add_entity(entity=u'Q2', prob=0.2)
|
|
||||||
mykb.add_entity(entity=u'Q3', prob=0.5)
|
|
||||||
|
|
||||||
# adding aliases
|
|
||||||
mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q3'], probabilities=[0.8, 0.2])
|
|
||||||
mykb.add_alias(alias=u'adam', entities=[u'Q2'], probabilities=[0.9])
|
|
||||||
|
|
||||||
# test the size of the relevant candidates
|
|
||||||
assert(len(mykb.get_candidates(u'douglas')) == 2)
|
|
||||||
assert(len(mykb.get_candidates(u'adam')) == 1)
|
|
||||||
assert(len(mykb.get_candidates(u'shrubbery')) == 0)
|
|
145
spacy/tests/pipeline/test_entity_linker.py
Normal file
145
spacy/tests/pipeline/test_entity_linker.py
Normal file
|
@ -0,0 +1,145 @@
|
||||||
|
# coding: utf-8
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from spacy.kb import KnowledgeBase
|
||||||
|
from spacy.lang.en import English
|
||||||
|
from spacy.pipeline import EntityRuler
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def nlp():
|
||||||
|
return English()
|
||||||
|
|
||||||
|
|
||||||
|
def test_kb_valid_entities(nlp):
|
||||||
|
"""Test the valid construction of a KB with 3 entities and two aliases"""
|
||||||
|
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||||
|
|
||||||
|
# adding entities
|
||||||
|
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
|
||||||
|
mykb.add_entity(entity='Q2', prob=0.5, entity_vector=[2])
|
||||||
|
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
|
||||||
|
|
||||||
|
# adding aliases
|
||||||
|
mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.2])
|
||||||
|
mykb.add_alias(alias='adam', entities=['Q2'], probabilities=[0.9])
|
||||||
|
|
||||||
|
# test the size of the corresponding KB
|
||||||
|
assert(mykb.get_size_entities() == 3)
|
||||||
|
assert(mykb.get_size_aliases() == 2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_kb_invalid_entities(nlp):
|
||||||
|
"""Test the invalid construction of a KB with an alias linked to a non-existing entity"""
|
||||||
|
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||||
|
|
||||||
|
# adding entities
|
||||||
|
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
|
||||||
|
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
|
||||||
|
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
|
||||||
|
|
||||||
|
# adding aliases - should fail because one of the given IDs is not valid
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
mykb.add_alias(alias='douglas', entities=['Q2', 'Q342'], probabilities=[0.8, 0.2])
|
||||||
|
|
||||||
|
|
||||||
|
def test_kb_invalid_probabilities(nlp):
|
||||||
|
"""Test the invalid construction of a KB with wrong prior probabilities"""
|
||||||
|
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||||
|
|
||||||
|
# adding entities
|
||||||
|
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
|
||||||
|
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
|
||||||
|
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
|
||||||
|
|
||||||
|
# adding aliases - should fail because the sum of the probabilities exceeds 1
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.4])
|
||||||
|
|
||||||
|
|
||||||
|
def test_kb_invalid_combination(nlp):
|
||||||
|
"""Test the invalid construction of a KB with non-matching entity and probability lists"""
|
||||||
|
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||||
|
|
||||||
|
# adding entities
|
||||||
|
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
|
||||||
|
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
|
||||||
|
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
|
||||||
|
|
||||||
|
# adding aliases - should fail because the entities and probabilities vectors are not of equal length
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.3, 0.4, 0.1])
|
||||||
|
|
||||||
|
|
||||||
|
def test_kb_invalid_entity_vector(nlp):
|
||||||
|
"""Test the invalid construction of a KB with non-matching entity vector lengths"""
|
||||||
|
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3)
|
||||||
|
|
||||||
|
# adding entities
|
||||||
|
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1, 2, 3])
|
||||||
|
|
||||||
|
# this should fail because the kb's expected entity vector length is 3
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
|
||||||
|
|
||||||
|
|
||||||
|
def test_candidate_generation(nlp):
|
||||||
|
"""Test correct candidate generation"""
|
||||||
|
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||||
|
|
||||||
|
# adding entities
|
||||||
|
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
|
||||||
|
mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2])
|
||||||
|
mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3])
|
||||||
|
|
||||||
|
# adding aliases
|
||||||
|
mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.2])
|
||||||
|
mykb.add_alias(alias='adam', entities=['Q2'], probabilities=[0.9])
|
||||||
|
|
||||||
|
# test the size of the relevant candidates
|
||||||
|
assert(len(mykb.get_candidates('douglas')) == 2)
|
||||||
|
assert(len(mykb.get_candidates('adam')) == 1)
|
||||||
|
assert(len(mykb.get_candidates('shrubbery')) == 0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_preserving_links_asdoc(nlp):
|
||||||
|
"""Test that Span.as_doc preserves the existing entity links"""
|
||||||
|
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||||
|
|
||||||
|
# adding entities
|
||||||
|
mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1])
|
||||||
|
mykb.add_entity(entity='Q2', prob=0.8, entity_vector=[1])
|
||||||
|
|
||||||
|
# adding aliases
|
||||||
|
mykb.add_alias(alias='Boston', entities=['Q1'], probabilities=[0.7])
|
||||||
|
mykb.add_alias(alias='Denver', entities=['Q2'], probabilities=[0.6])
|
||||||
|
|
||||||
|
# set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained)
|
||||||
|
sentencizer = nlp.create_pipe("sentencizer")
|
||||||
|
nlp.add_pipe(sentencizer)
|
||||||
|
|
||||||
|
ruler = EntityRuler(nlp)
|
||||||
|
patterns = [{"label": "GPE", "pattern": "Boston"},
|
||||||
|
{"label": "GPE", "pattern": "Denver"}]
|
||||||
|
ruler.add_patterns(patterns)
|
||||||
|
nlp.add_pipe(ruler)
|
||||||
|
|
||||||
|
el_pipe = nlp.create_pipe(name='entity_linker', config={"context_width": 64})
|
||||||
|
el_pipe.set_kb(mykb)
|
||||||
|
el_pipe.begin_training()
|
||||||
|
el_pipe.context_weight = 0
|
||||||
|
el_pipe.prior_weight = 1
|
||||||
|
nlp.add_pipe(el_pipe, last=True)
|
||||||
|
|
||||||
|
# test whether the entity links are preserved by the `as_doc()` function
|
||||||
|
text = "She lives in Boston. He lives in Denver."
|
||||||
|
doc = nlp(text)
|
||||||
|
for ent in doc.ents:
|
||||||
|
orig_text = ent.text
|
||||||
|
orig_kb_id = ent.kb_id_
|
||||||
|
sent_doc = ent.sent.as_doc()
|
||||||
|
for s_ent in sent_doc.ents:
|
||||||
|
if s_ent.text == orig_text:
|
||||||
|
assert s_ent.kb_id_ == orig_kb_id
|
74
spacy/tests/serialize/test_serialize_kb.py
Normal file
74
spacy/tests/serialize/test_serialize_kb.py
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
# coding: utf-8
|
||||||
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
|
from ..util import make_tempdir
|
||||||
|
from ...util import ensure_path
|
||||||
|
|
||||||
|
from spacy.kb import KnowledgeBase
|
||||||
|
|
||||||
|
|
||||||
|
def test_serialize_kb_disk(en_vocab):
|
||||||
|
# baseline assertions
|
||||||
|
kb1 = _get_dummy_kb(en_vocab)
|
||||||
|
_check_kb(kb1)
|
||||||
|
|
||||||
|
# dumping to file & loading back in
|
||||||
|
with make_tempdir() as d:
|
||||||
|
dir_path = ensure_path(d)
|
||||||
|
if not dir_path.exists():
|
||||||
|
dir_path.mkdir()
|
||||||
|
file_path = dir_path / "kb"
|
||||||
|
kb1.dump(str(file_path))
|
||||||
|
|
||||||
|
kb2 = KnowledgeBase(vocab=en_vocab, entity_vector_length=3)
|
||||||
|
kb2.load_bulk(str(file_path))
|
||||||
|
|
||||||
|
# final assertions
|
||||||
|
_check_kb(kb2)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_dummy_kb(vocab):
|
||||||
|
kb = KnowledgeBase(vocab=vocab, entity_vector_length=3)
|
||||||
|
|
||||||
|
kb.add_entity(entity='Q53', prob=0.33, entity_vector=[0, 5, 3])
|
||||||
|
kb.add_entity(entity='Q17', prob=0.2, entity_vector=[7, 1, 0])
|
||||||
|
kb.add_entity(entity='Q007', prob=0.7, entity_vector=[0, 0, 7])
|
||||||
|
kb.add_entity(entity='Q44', prob=0.4, entity_vector=[4, 4, 4])
|
||||||
|
|
||||||
|
kb.add_alias(alias='double07', entities=['Q17', 'Q007'], probabilities=[0.1, 0.9])
|
||||||
|
kb.add_alias(alias='guy', entities=['Q53', 'Q007', 'Q17', 'Q44'], probabilities=[0.3, 0.3, 0.2, 0.1])
|
||||||
|
kb.add_alias(alias='random', entities=['Q007'], probabilities=[1.0])
|
||||||
|
|
||||||
|
return kb
|
||||||
|
|
||||||
|
|
||||||
|
def _check_kb(kb):
|
||||||
|
# check entities
|
||||||
|
assert kb.get_size_entities() == 4
|
||||||
|
for entity_string in ['Q53', 'Q17', 'Q007', 'Q44']:
|
||||||
|
assert entity_string in kb.get_entity_strings()
|
||||||
|
for entity_string in ['', 'Q0']:
|
||||||
|
assert entity_string not in kb.get_entity_strings()
|
||||||
|
|
||||||
|
# check aliases
|
||||||
|
assert kb.get_size_aliases() == 3
|
||||||
|
for alias_string in ['double07', 'guy', 'random']:
|
||||||
|
assert alias_string in kb.get_alias_strings()
|
||||||
|
for alias_string in ['nothingness', '', 'randomnoise']:
|
||||||
|
assert alias_string not in kb.get_alias_strings()
|
||||||
|
|
||||||
|
# check candidates & probabilities
|
||||||
|
candidates = sorted(kb.get_candidates('double07'), key=lambda x: x.entity_)
|
||||||
|
assert len(candidates) == 2
|
||||||
|
|
||||||
|
assert candidates[0].entity_ == 'Q007'
|
||||||
|
assert 0.6999 < candidates[0].entity_freq < 0.701
|
||||||
|
assert candidates[0].entity_vector == [0, 0, 7]
|
||||||
|
assert candidates[0].alias_ == 'double07'
|
||||||
|
assert 0.899 < candidates[0].prior_prob < 0.901
|
||||||
|
|
||||||
|
assert candidates[1].entity_ == 'Q17'
|
||||||
|
assert 0.199 < candidates[1].entity_freq < 0.201
|
||||||
|
assert candidates[1].entity_vector == [7, 1, 0]
|
||||||
|
assert candidates[1].alias_ == 'double07'
|
||||||
|
assert 0.099 < candidates[1].prior_prob < 0.101
|
|
@ -22,7 +22,7 @@ from ..lexeme cimport Lexeme, EMPTY_LEXEME
|
||||||
from ..typedefs cimport attr_t, flags_t
|
from ..typedefs cimport attr_t, flags_t
|
||||||
from ..attrs cimport ID, ORTH, NORM, LOWER, SHAPE, PREFIX, SUFFIX, CLUSTER
|
from ..attrs cimport ID, ORTH, NORM, LOWER, SHAPE, PREFIX, SUFFIX, CLUSTER
|
||||||
from ..attrs cimport LENGTH, POS, LEMMA, TAG, DEP, HEAD, SPACY, ENT_IOB
|
from ..attrs cimport LENGTH, POS, LEMMA, TAG, DEP, HEAD, SPACY, ENT_IOB
|
||||||
from ..attrs cimport ENT_TYPE, SENT_START, attr_id_t
|
from ..attrs cimport ENT_TYPE, ENT_KB_ID, SENT_START, attr_id_t
|
||||||
from ..parts_of_speech cimport CCONJ, PUNCT, NOUN, univ_pos_t
|
from ..parts_of_speech cimport CCONJ, PUNCT, NOUN, univ_pos_t
|
||||||
|
|
||||||
from ..attrs import intify_attrs, IDS
|
from ..attrs import intify_attrs, IDS
|
||||||
|
@ -64,6 +64,8 @@ cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) nogil:
|
||||||
return token.ent_iob
|
return token.ent_iob
|
||||||
elif feat_name == ENT_TYPE:
|
elif feat_name == ENT_TYPE:
|
||||||
return token.ent_type
|
return token.ent_type
|
||||||
|
elif feat_name == ENT_KB_ID:
|
||||||
|
return token.ent_kb_id
|
||||||
else:
|
else:
|
||||||
return Lexeme.get_struct_attr(token.lex, feat_name)
|
return Lexeme.get_struct_attr(token.lex, feat_name)
|
||||||
|
|
||||||
|
@ -851,7 +853,7 @@ cdef class Doc:
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/doc#to_bytes
|
DOCS: https://spacy.io/api/doc#to_bytes
|
||||||
"""
|
"""
|
||||||
array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE]
|
array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE] # TODO: ENT_KB_ID ?
|
||||||
if self.is_tagged:
|
if self.is_tagged:
|
||||||
array_head.append(TAG)
|
array_head.append(TAG)
|
||||||
# If doc parsed add head and dep attribute
|
# If doc parsed add head and dep attribute
|
||||||
|
@ -1005,6 +1007,7 @@ cdef class Doc:
|
||||||
"""
|
"""
|
||||||
cdef unicode tag, lemma, ent_type
|
cdef unicode tag, lemma, ent_type
|
||||||
deprecation_warning(Warnings.W013.format(obj="Doc"))
|
deprecation_warning(Warnings.W013.format(obj="Doc"))
|
||||||
|
# TODO: ENT_KB_ID ?
|
||||||
if len(args) == 3:
|
if len(args) == 3:
|
||||||
deprecation_warning(Warnings.W003)
|
deprecation_warning(Warnings.W003)
|
||||||
tag, lemma, ent_type = args
|
tag, lemma, ent_type = args
|
||||||
|
|
|
@ -210,7 +210,7 @@ cdef class Span:
|
||||||
words = [t.text for t in self]
|
words = [t.text for t in self]
|
||||||
spaces = [bool(t.whitespace_) for t in self]
|
spaces = [bool(t.whitespace_) for t in self]
|
||||||
cdef Doc doc = Doc(self.doc.vocab, words=words, spaces=spaces)
|
cdef Doc doc = Doc(self.doc.vocab, words=words, spaces=spaces)
|
||||||
array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE]
|
array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE, ENT_KB_ID]
|
||||||
if self.doc.is_tagged:
|
if self.doc.is_tagged:
|
||||||
array_head.append(TAG)
|
array_head.append(TAG)
|
||||||
# If doc parsed add head and dep attribute
|
# If doc parsed add head and dep attribute
|
||||||
|
|
|
@ -53,6 +53,8 @@ cdef class Token:
|
||||||
return token.ent_iob
|
return token.ent_iob
|
||||||
elif feat_name == ENT_TYPE:
|
elif feat_name == ENT_TYPE:
|
||||||
return token.ent_type
|
return token.ent_type
|
||||||
|
elif feat_name == ENT_KB_ID:
|
||||||
|
return token.ent_kb_id
|
||||||
elif feat_name == SENT_START:
|
elif feat_name == SENT_START:
|
||||||
return token.sent_start
|
return token.sent_start
|
||||||
else:
|
else:
|
||||||
|
@ -79,5 +81,7 @@ cdef class Token:
|
||||||
token.ent_iob = value
|
token.ent_iob = value
|
||||||
elif feat_name == ENT_TYPE:
|
elif feat_name == ENT_TYPE:
|
||||||
token.ent_type = value
|
token.ent_type = value
|
||||||
|
elif feat_name == ENT_KB_ID:
|
||||||
|
token.ent_kb_id = value
|
||||||
elif feat_name == SENT_START:
|
elif feat_name == SENT_START:
|
||||||
token.sent_start = value
|
token.sent_start = value
|
||||||
|
|
Loading…
Reference in New Issue
Block a user