KB extensions and better parsing of WikiData (#4375)

* fix overflow error on windows

* more documentation & logging fixes

* md fix

* 3 different limit parameters to play with execution time

* bug fixes directory locations

* small fixes

* exclude dev test articles from prior probabilities stats

* small fixes

* filtering wikidata entities, removing numeric and meta items

* adding aliases from wikidata also to the KB

* fix adding WD aliases

* adding also new aliases to previously added entities

* fixing comma's

* small doc fixes

* adding subclassof filtering

* append alias functionality in KB

* prevent appending the same entity-alias pair

* fix for appending WD aliases

* remove date filter

* remove unnecessary import

* small corrections and reformatting

* remove WD aliases for now (too slow)

* removing numeric entities from training and evaluation

* small fixes

* shortcut during prediction if there is only one candidate

* add counts and fscore logging, remove FP NER from evaluation

* fix entity_linker.predict to take docs instead of single sentences

* remove enumeration sentences from the WP dataset

* entity_linker.update to process full doc instead of single sentence

* spelling corrections and dump locations in readme

* NLP IO fix

* reading KB is unnecessary at the end of the pipeline

* small logging fix

* remove empty files
This commit is contained in:
Sofie Van Landeghem 2019-10-14 11:28:53 +01:00 committed by Ines Montani
parent 428887b8f2
commit 2d249a9502
21 changed files with 1163 additions and 811 deletions

View File

@ -0,0 +1,34 @@
## Entity Linking with Wikipedia and Wikidata
### Step 1: Create a Knowledge Base (KB) and training data
Run `wikipedia_pretrain_kb.py`
* This takes as input the locations of a **Wikipedia and a Wikidata dump**, and produces a **KB directory** + **training file**
* WikiData: get `latest-all.json.bz2` from https://dumps.wikimedia.org/wikidatawiki/entities/
* Wikipedia: get `enwiki-latest-pages-articles-multistream.xml.bz2` from https://dumps.wikimedia.org/enwiki/latest/ (or for any other language)
* You can set the filtering parameters for KB construction:
* `max_per_alias`: (max) number of candidate entities in the KB per alias/synonym
* `min_freq`: threshold of number of times an entity should occur in the corpus to be included in the KB
* `min_pair`: threshold of number of times an entity+alias combination should occur in the corpus to be included in the KB
* Further parameters to set:
* `descriptions_from_wikipedia`: whether to parse descriptions from Wikipedia (`True`) or Wikidata (`False`)
* `entity_vector_length`: length of the pre-trained entity description vectors
* `lang`: language for which to fetch Wikidata information (as the dump contains all languages)
Quick testing and rerunning:
* When trying out the pipeline for a quick test, set `limit_prior`, `limit_train` and/or `limit_wd` to read only parts of the dumps instead of everything.
* If you only want to (re)run certain parts of the pipeline, just remove the corresponding files and they will be recalculated or reparsed.
### Step 2: Train an Entity Linking model
Run `wikidata_train_entity_linker.py`
* This takes the **KB directory** produced by Step 1, and trains an **Entity Linking model**
* You can set the learning parameters for the EL training:
* `epochs`: number of training iterations
* `dropout`: dropout rate
* `lr`: learning rate
* `l2`: L2 regularization
* Specify the number of training and dev testing entities with `train_inst` and `dev_inst` respectively
* Further parameters to set:
* `labels_discard`: NER label types to discard during training

View File

@ -6,6 +6,7 @@ OUTPUT_MODEL_DIR = "nlp"
PRIOR_PROB_PATH = "prior_prob.csv" PRIOR_PROB_PATH = "prior_prob.csv"
ENTITY_DEFS_PATH = "entity_defs.csv" ENTITY_DEFS_PATH = "entity_defs.csv"
ENTITY_FREQ_PATH = "entity_freq.csv" ENTITY_FREQ_PATH = "entity_freq.csv"
ENTITY_ALIAS_PATH = "entity_alias.csv"
ENTITY_DESCR_PATH = "entity_descriptions.csv" ENTITY_DESCR_PATH = "entity_descriptions.csv"
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s' LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'

View File

@ -15,10 +15,11 @@ class Metrics(object):
candidate_is_correct = true_entity == candidate candidate_is_correct = true_entity == candidate
# Assume that we have no labeled negatives in the data (i.e. cases where true_entity is "NIL") # Assume that we have no labeled negatives in the data (i.e. cases where true_entity is "NIL")
# Therefore, if candidate_is_correct then we have a true positive and never a true negative # Therefore, if candidate_is_correct then we have a true positive and never a true negative.
self.true_pos += candidate_is_correct self.true_pos += candidate_is_correct
self.false_neg += not candidate_is_correct self.false_neg += not candidate_is_correct
if candidate not in {"", "NIL"}: if candidate and candidate not in {"", "NIL"}:
# A wrong prediction (e.g. Q42 != Q3) counts both as a FP as well as a FN.
self.false_pos += not candidate_is_correct self.false_pos += not candidate_is_correct
def calculate_precision(self): def calculate_precision(self):
@ -33,6 +34,14 @@ class Metrics(object):
else: else:
return self.true_pos / (self.true_pos + self.false_neg) return self.true_pos / (self.true_pos + self.false_neg)
def calculate_fscore(self):
p = self.calculate_precision()
r = self.calculate_recall()
if p + r == 0:
return 0.0
else:
return 2 * p * r / (p + r)
class EvaluationResults(object): class EvaluationResults(object):
def __init__(self): def __init__(self):
@ -43,18 +52,20 @@ class EvaluationResults(object):
self.metrics.update_results(true_entity, candidate) self.metrics.update_results(true_entity, candidate)
self.metrics_by_label[ent_label].update_results(true_entity, candidate) self.metrics_by_label[ent_label].update_results(true_entity, candidate)
def increment_false_negatives(self):
self.metrics.false_neg += 1
def report_metrics(self, model_name): def report_metrics(self, model_name):
model_str = model_name.title() model_str = model_name.title()
recall = self.metrics.calculate_recall() recall = self.metrics.calculate_recall()
precision = self.metrics.calculate_precision() precision = self.metrics.calculate_precision()
return ("{}: ".format(model_str) + fscore = self.metrics.calculate_fscore()
"Recall = {} | ".format(round(recall, 3)) + return (
"Precision = {} | ".format(round(precision, 3)) + "{}: ".format(model_str)
"Precision by label = {}".format({k: v.calculate_precision() + "F-score = {} | ".format(round(fscore, 3))
for k, v in self.metrics_by_label.items()})) + "Recall = {} | ".format(round(recall, 3))
+ "Precision = {} | ".format(round(precision, 3))
+ "F-score by label = {}".format(
{k: v.calculate_fscore() for k, v in sorted(self.metrics_by_label.items())}
)
)
class BaselineResults(object): class BaselineResults(object):
@ -63,40 +74,51 @@ class BaselineResults(object):
self.prior = EvaluationResults() self.prior = EvaluationResults()
self.oracle = EvaluationResults() self.oracle = EvaluationResults()
def report_accuracy(self, model): def report_performance(self, model):
results = getattr(self, model) results = getattr(self, model)
return results.report_metrics(model) return results.report_metrics(model)
def update_baselines(self, true_entity, ent_label, random_candidate, prior_candidate, oracle_candidate): def update_baselines(
self,
true_entity,
ent_label,
random_candidate,
prior_candidate,
oracle_candidate,
):
self.oracle.update_metrics(ent_label, true_entity, oracle_candidate) self.oracle.update_metrics(ent_label, true_entity, oracle_candidate)
self.prior.update_metrics(ent_label, true_entity, prior_candidate) self.prior.update_metrics(ent_label, true_entity, prior_candidate)
self.random.update_metrics(ent_label, true_entity, random_candidate) self.random.update_metrics(ent_label, true_entity, random_candidate)
def measure_performance(dev_data, kb, el_pipe): def measure_performance(dev_data, kb, el_pipe, baseline=True, context=True):
baseline_accuracies = measure_baselines( if baseline:
dev_data, kb baseline_accuracies, counts = measure_baselines(dev_data, kb)
) logger.info("Counts: {}".format({k: v for k, v in sorted(counts.items())}))
logger.info(baseline_accuracies.report_performance("random"))
logger.info(baseline_accuracies.report_performance("prior"))
logger.info(baseline_accuracies.report_performance("oracle"))
logger.info(baseline_accuracies.report_accuracy("random")) if context:
logger.info(baseline_accuracies.report_accuracy("prior")) # using only context
logger.info(baseline_accuracies.report_accuracy("oracle")) el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = False
results = get_eval_results(dev_data, el_pipe)
logger.info(results.report_metrics("context only"))
# using only context # measuring combined accuracy (prior + context)
el_pipe.cfg["incl_context"] = True el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = False el_pipe.cfg["incl_prior"] = True
results = get_eval_results(dev_data, el_pipe) results = get_eval_results(dev_data, el_pipe)
logger.info(results.report_metrics("context only")) logger.info(results.report_metrics("context and prior"))
# measuring combined accuracy (prior + context)
el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = True
results = get_eval_results(dev_data, el_pipe)
logger.info(results.report_metrics("context and prior"))
def get_eval_results(data, el_pipe=None): def get_eval_results(data, el_pipe=None):
# If the docs in the data require further processing with an entity linker, set el_pipe """
Evaluate the ent.kb_id_ annotations against the gold standard.
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
If the docs in the data require further processing with an entity linker, set el_pipe.
"""
from tqdm import tqdm from tqdm import tqdm
docs = [] docs = []
@ -111,18 +133,15 @@ def get_eval_results(data, el_pipe=None):
results = EvaluationResults() results = EvaluationResults()
for doc, gold in zip(docs, golds): for doc, gold in zip(docs, golds):
tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
try: try:
correct_entries_per_article = dict() correct_entries_per_article = dict()
for entity, kb_dict in gold.links.items(): for entity, kb_dict in gold.links.items():
start, end = entity start, end = entity
# only evaluating on positive examples
for gold_kb, value in kb_dict.items(): for gold_kb, value in kb_dict.items():
if value: if value:
# only evaluating on positive examples
offset = _offset(start, end) offset = _offset(start, end)
correct_entries_per_article[offset] = gold_kb correct_entries_per_article[offset] = gold_kb
if offset not in tagged_entries_per_article:
results.increment_false_negatives()
for ent in doc.ents: for ent in doc.ents:
ent_label = ent.label_ ent_label = ent.label_
@ -142,7 +161,11 @@ def get_eval_results(data, el_pipe=None):
def measure_baselines(data, kb): def measure_baselines(data, kb):
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound """
Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound.
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
Also return a dictionary of counts by entity label.
"""
counts_d = dict() counts_d = dict()
baseline_results = BaselineResults() baseline_results = BaselineResults()
@ -152,7 +175,6 @@ def measure_baselines(data, kb):
for doc, gold in zip(docs, golds): for doc, gold in zip(docs, golds):
correct_entries_per_article = dict() correct_entries_per_article = dict()
tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
for entity, kb_dict in gold.links.items(): for entity, kb_dict in gold.links.items():
start, end = entity start, end = entity
for gold_kb, value in kb_dict.items(): for gold_kb, value in kb_dict.items():
@ -160,10 +182,6 @@ def measure_baselines(data, kb):
if value: if value:
offset = _offset(start, end) offset = _offset(start, end)
correct_entries_per_article[offset] = gold_kb correct_entries_per_article[offset] = gold_kb
if offset not in tagged_entries_per_article:
baseline_results.random.increment_false_negatives()
baseline_results.oracle.increment_false_negatives()
baseline_results.prior.increment_false_negatives()
for ent in doc.ents: for ent in doc.ents:
ent_label = ent.label_ ent_label = ent.label_
@ -176,7 +194,7 @@ def measure_baselines(data, kb):
if gold_entity is not None: if gold_entity is not None:
candidates = kb.get_candidates(ent.text) candidates = kb.get_candidates(ent.text)
oracle_candidate = "" oracle_candidate = ""
best_candidate = "" prior_candidate = ""
random_candidate = "" random_candidate = ""
if candidates: if candidates:
scores = [] scores = []
@ -187,13 +205,21 @@ def measure_baselines(data, kb):
oracle_candidate = c.entity_ oracle_candidate = c.entity_
best_index = scores.index(max(scores)) best_index = scores.index(max(scores))
best_candidate = candidates[best_index].entity_ prior_candidate = candidates[best_index].entity_
random_candidate = random.choice(candidates).entity_ random_candidate = random.choice(candidates).entity_
baseline_results.update_baselines(gold_entity, ent_label, current_count = counts_d.get(ent_label, 0)
random_candidate, best_candidate, oracle_candidate) counts_d[ent_label] = current_count+1
return baseline_results baseline_results.update_baselines(
gold_entity,
ent_label,
random_candidate,
prior_candidate,
oracle_candidate,
)
return baseline_results, counts_d
def _offset(start, end): def _offset(start, end):

View File

@ -1,17 +1,12 @@
# coding: utf-8 # coding: utf-8
from __future__ import unicode_literals from __future__ import unicode_literals
import csv
import logging import logging
import spacy
import sys
from spacy.kb import KnowledgeBase from spacy.kb import KnowledgeBase
from bin.wiki_entity_linking import wikipedia_processor as wp
from bin.wiki_entity_linking.train_descriptions import EntityEncoder from bin.wiki_entity_linking.train_descriptions import EntityEncoder
from bin.wiki_entity_linking import wiki_io as io
csv.field_size_limit(sys.maxsize)
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -22,18 +17,24 @@ def create_kb(
max_entities_per_alias, max_entities_per_alias,
min_entity_freq, min_entity_freq,
min_occ, min_occ,
entity_def_input, entity_def_path,
entity_descr_path, entity_descr_path,
count_input, entity_alias_path,
prior_prob_input, entity_freq_path,
prior_prob_path,
entity_vector_length, entity_vector_length,
): ):
# Create the knowledge base from Wikidata entries # Create the knowledge base from Wikidata entries
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=entity_vector_length) kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=entity_vector_length)
entity_list, filtered_title_to_id = _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_freq, entity_freq_path, entity_vector_length)
_define_aliases(kb, entity_alias_path, entity_list, filtered_title_to_id, max_entities_per_alias, min_occ, prior_prob_path)
return kb
def _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_freq, entity_freq_path, entity_vector_length):
# read the mappings from file # read the mappings from file
title_to_id = get_entity_to_id(entity_def_input) title_to_id = io.read_title_to_id(entity_def_path)
id_to_descr = get_id_to_description(entity_descr_path) id_to_descr = io.read_id_to_descr(entity_descr_path)
# check the length of the nlp vectors # check the length of the nlp vectors
if "vectors" in nlp.meta and nlp.vocab.vectors.size: if "vectors" in nlp.meta and nlp.vocab.vectors.size:
@ -45,10 +46,8 @@ def create_kb(
" cf. https://spacy.io/usage/models#languages." " cf. https://spacy.io/usage/models#languages."
) )
logger.info("Get entity frequencies")
entity_frequencies = wp.get_all_frequencies(count_input=count_input)
logger.info("Filtering entities with fewer than {} mentions".format(min_entity_freq)) logger.info("Filtering entities with fewer than {} mentions".format(min_entity_freq))
entity_frequencies = io.read_entity_to_count(entity_freq_path)
# filter the entities for in the KB by frequency, because there's just too much data (8M entities) otherwise # filter the entities for in the KB by frequency, because there's just too much data (8M entities) otherwise
filtered_title_to_id, entity_list, description_list, frequency_list = get_filtered_entities( filtered_title_to_id, entity_list, description_list, frequency_list = get_filtered_entities(
title_to_id, title_to_id,
@ -56,36 +55,33 @@ def create_kb(
entity_frequencies, entity_frequencies,
min_entity_freq min_entity_freq
) )
logger.info("Left with {} entities".format(len(description_list))) logger.info("Kept {} entities from the set of {}".format(len(description_list), len(title_to_id.keys())))
logger.info("Train entity encoder") logger.info("Training entity encoder")
encoder = EntityEncoder(nlp, input_dim, entity_vector_length) encoder = EntityEncoder(nlp, input_dim, entity_vector_length)
encoder.train(description_list=description_list, to_print=True) encoder.train(description_list=description_list, to_print=True)
logger.info("Get entity embeddings:") logger.info("Getting entity embeddings")
embeddings = encoder.apply_encoder(description_list) embeddings = encoder.apply_encoder(description_list)
logger.info("Adding {} entities".format(len(entity_list))) logger.info("Adding {} entities".format(len(entity_list)))
kb.set_entities( kb.set_entities(
entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
) )
return entity_list, filtered_title_to_id
logger.info("Adding aliases")
def _define_aliases(kb, entity_alias_path, entity_list, filtered_title_to_id, max_entities_per_alias, min_occ, prior_prob_path):
logger.info("Adding aliases from Wikipedia and Wikidata")
_add_aliases( _add_aliases(
kb, kb,
entity_list=entity_list,
title_to_id=filtered_title_to_id, title_to_id=filtered_title_to_id,
max_entities_per_alias=max_entities_per_alias, max_entities_per_alias=max_entities_per_alias,
min_occ=min_occ, min_occ=min_occ,
prior_prob_input=prior_prob_input, prior_prob_path=prior_prob_path,
) )
logger.info("KB size: {} entities, {} aliases".format(
kb.get_size_entities(),
kb.get_size_aliases()))
logger.info("Done with kb")
return kb
def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies, def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies,
min_entity_freq: int = 10): min_entity_freq: int = 10):
@ -104,34 +100,13 @@ def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies,
return filtered_title_to_id, entity_list, description_list, frequency_list return filtered_title_to_id, entity_list, description_list, frequency_list
def get_entity_to_id(entity_def_output): def _add_aliases(kb, entity_list, title_to_id, max_entities_per_alias, min_occ, prior_prob_path):
entity_to_id = dict()
with entity_def_output.open("r", encoding="utf8") as csvfile:
csvreader = csv.reader(csvfile, delimiter="|")
# skip header
next(csvreader)
for row in csvreader:
entity_to_id[row[0]] = row[1]
return entity_to_id
def get_id_to_description(entity_descr_path):
id_to_desc = dict()
with entity_descr_path.open("r", encoding="utf8") as csvfile:
csvreader = csv.reader(csvfile, delimiter="|")
# skip header
next(csvreader)
for row in csvreader:
id_to_desc[row[0]] = row[1]
return id_to_desc
def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input):
wp_titles = title_to_id.keys() wp_titles = title_to_id.keys()
# adding aliases with prior probabilities # adding aliases with prior probabilities
# we can read this file sequentially, it's sorted by alias, and then by count # we can read this file sequentially, it's sorted by alias, and then by count
with prior_prob_input.open("r", encoding="utf8") as prior_file: logger.info("Adding WP aliases")
with prior_prob_path.open("r", encoding="utf8") as prior_file:
# skip header # skip header
prior_file.readline() prior_file.readline()
line = prior_file.readline() line = prior_file.readline()
@ -180,10 +155,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
line = prior_file.readline() line = prior_file.readline()
def read_nlp_kb(model_dir, kb_file): def read_kb(nlp, kb_file):
nlp = spacy.load(model_dir)
kb = KnowledgeBase(vocab=nlp.vocab) kb = KnowledgeBase(vocab=nlp.vocab)
kb.load_bulk(kb_file) kb.load_bulk(kb_file)
logger.info("kb entities: {}".format(kb.get_size_entities())) return kb
logger.info("kb aliases: {}".format(kb.get_size_aliases()))
return nlp, kb

View File

@ -53,7 +53,7 @@ class EntityEncoder:
start = start + batch_size start = start + batch_size
stop = min(stop + batch_size, len(description_list)) stop = min(stop + batch_size, len(description_list))
logger.info("encoded: {} entities".format(stop)) logger.info("Encoded: {} entities".format(stop))
return encodings return encodings
@ -62,7 +62,7 @@ class EntityEncoder:
if to_print: if to_print:
logger.info( logger.info(
"Trained entity descriptions on {} ".format(processed) + "Trained entity descriptions on {} ".format(processed) +
"(non-unique) entities across {} ".format(self.epochs) + "(non-unique) descriptions across {} ".format(self.epochs) +
"epochs" "epochs"
) )
logger.info("Final loss: {}".format(loss)) logger.info("Final loss: {}".format(loss))

View File

@ -1,395 +0,0 @@
# coding: utf-8
from __future__ import unicode_literals
import logging
import random
import re
import bz2
import json
from functools import partial
from spacy.gold import GoldParse
from bin.wiki_entity_linking import kb_creator
"""
Process Wikipedia interlinks to generate a training dataset for the EL algorithm.
Gold-standard entities are stored in one file in standoff format (by character offset).
"""
ENTITY_FILE = "gold_entities.csv"
logger = logging.getLogger(__name__)
def create_training_examples_and_descriptions(wikipedia_input,
entity_def_input,
description_output,
training_output,
parse_descriptions,
limit=None):
wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
_process_wikipedia_texts(wikipedia_input,
wp_to_id,
description_output,
training_output,
parse_descriptions,
limit)
def _process_wikipedia_texts(wikipedia_input,
wp_to_id,
output,
training_output,
parse_descriptions,
limit=None):
"""
Read the XML wikipedia data to parse out training data:
raw text data + positive instances
"""
title_regex = re.compile(r"(?<=<title>).*(?=</title>)")
id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
read_ids = set()
with output.open("a", encoding="utf8") as descr_file, training_output.open("w", encoding="utf8") as entity_file:
if parse_descriptions:
_write_training_description(descr_file, "WD_id", "description")
with bz2.open(wikipedia_input, mode="rb") as file:
article_count = 0
article_text = ""
article_title = None
article_id = None
reading_text = False
reading_revision = False
logger.info("Processed {} articles".format(article_count))
for line in file:
clean_line = line.strip().decode("utf-8")
if clean_line == "<revision>":
reading_revision = True
elif clean_line == "</revision>":
reading_revision = False
# Start reading new page
if clean_line == "<page>":
article_text = ""
article_title = None
article_id = None
# finished reading this page
elif clean_line == "</page>":
if article_id:
clean_text, entities = _process_wp_text(
article_title,
article_text,
wp_to_id
)
if clean_text is not None and entities is not None:
_write_training_entities(entity_file,
article_id,
clean_text,
entities)
if article_title in wp_to_id and parse_descriptions:
description = " ".join(clean_text[:1000].split(" ")[:-1])
_write_training_description(
descr_file,
wp_to_id[article_title],
description
)
article_count += 1
if article_count % 10000 == 0:
logger.info("Processed {} articles".format(article_count))
if limit and article_count >= limit:
break
article_text = ""
article_title = None
article_id = None
reading_text = False
reading_revision = False
# start reading text within a page
if "<text" in clean_line:
reading_text = True
if reading_text:
article_text += " " + clean_line
# stop reading text within a page (we assume a new page doesn't start on the same line)
if "</text" in clean_line:
reading_text = False
# read the ID of this article (outside the revision portion of the document)
if not reading_revision:
ids = id_regex.search(clean_line)
if ids:
article_id = ids[0]
if article_id in read_ids:
logger.info(
"Found duplicate article ID", article_id, clean_line
) # This should never happen ...
read_ids.add(article_id)
# read the title of this article (outside the revision portion of the document)
if not reading_revision:
titles = title_regex.search(clean_line)
if titles:
article_title = titles[0].strip()
logger.info("Finished. Processed {} articles".format(article_count))
text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")
info_regex = re.compile(r"{[^{]*?}")
htlm_regex = re.compile(r"&lt;!--[^-]*--&gt;")
category_regex = re.compile(r"\[\[Category:[^\[]*]]")
file_regex = re.compile(r"\[\[File:[^[\]]+]]")
ref_regex = re.compile(r"&lt;ref.*?&gt;") # non-greedy
ref_2_regex = re.compile(r"&lt;/ref.*?&gt;") # non-greedy
def _process_wp_text(article_title, article_text, wp_to_id):
# ignore meta Wikipedia pages
if (
article_title.startswith("Wikipedia:") or
article_title.startswith("Kategori:")
):
return None, None
# remove the text tags
text_search = text_regex.search(article_text)
if text_search is None:
return None, None
text = text_search.group(0)
# stop processing if this is a redirect page
if text.startswith("#REDIRECT"):
return None, None
# get the raw text without markup etc, keeping only interwiki links
clean_text, entities = _remove_links(_get_clean_wp_text(text), wp_to_id)
return clean_text, entities
def _get_clean_wp_text(article_text):
clean_text = article_text.strip()
# remove bolding & italic markup
clean_text = clean_text.replace("'''", "")
clean_text = clean_text.replace("''", "")
# remove nested {{info}} statements by removing the inner/smallest ones first and iterating
try_again = True
previous_length = len(clean_text)
while try_again:
clean_text = info_regex.sub(
"", clean_text
) # non-greedy match excluding a nested {
if len(clean_text) < previous_length:
try_again = True
else:
try_again = False
previous_length = len(clean_text)
# remove HTML comments
clean_text = htlm_regex.sub("", clean_text)
# remove Category and File statements
clean_text = category_regex.sub("", clean_text)
clean_text = file_regex.sub("", clean_text)
# remove multiple =
while "==" in clean_text:
clean_text = clean_text.replace("==", "=")
clean_text = clean_text.replace(". =", ".")
clean_text = clean_text.replace(" = ", ". ")
clean_text = clean_text.replace("= ", ".")
clean_text = clean_text.replace(" =", "")
# remove refs (non-greedy match)
clean_text = ref_regex.sub("", clean_text)
clean_text = ref_2_regex.sub("", clean_text)
# remove additional wikiformatting
clean_text = re.sub(r"&lt;blockquote&gt;", "", clean_text)
clean_text = re.sub(r"&lt;/blockquote&gt;", "", clean_text)
# change special characters back to normal ones
clean_text = clean_text.replace(r"&lt;", "<")
clean_text = clean_text.replace(r"&gt;", ">")
clean_text = clean_text.replace(r"&quot;", '"')
clean_text = clean_text.replace(r"&amp;nbsp;", " ")
clean_text = clean_text.replace(r"&amp;", "&")
# remove multiple spaces
while " " in clean_text:
clean_text = clean_text.replace(" ", " ")
return clean_text.strip()
def _remove_links(clean_text, wp_to_id):
# read the text char by char to get the right offsets for the interwiki links
entities = []
final_text = ""
open_read = 0
reading_text = True
reading_entity = False
reading_mention = False
reading_special_case = False
entity_buffer = ""
mention_buffer = ""
for index, letter in enumerate(clean_text):
if letter == "[":
open_read += 1
elif letter == "]":
open_read -= 1
elif letter == "|":
if reading_text:
final_text += letter
# switch from reading entity to mention in the [[entity|mention]] pattern
elif reading_entity:
reading_text = False
reading_entity = False
reading_mention = True
else:
reading_special_case = True
else:
if reading_entity:
entity_buffer += letter
elif reading_mention:
mention_buffer += letter
elif reading_text:
final_text += letter
else:
raise ValueError("Not sure at point", clean_text[index - 2: index + 2])
if open_read > 2:
reading_special_case = True
if open_read == 2 and reading_text:
reading_text = False
reading_entity = True
reading_mention = False
# we just finished reading an entity
if open_read == 0 and not reading_text:
if "#" in entity_buffer or entity_buffer.startswith(":"):
reading_special_case = True
# Ignore cases with nested structures like File: handles etc
if not reading_special_case:
if not mention_buffer:
mention_buffer = entity_buffer
start = len(final_text)
end = start + len(mention_buffer)
qid = wp_to_id.get(entity_buffer, None)
if qid:
entities.append((mention_buffer, qid, start, end))
final_text += mention_buffer
entity_buffer = ""
mention_buffer = ""
reading_text = True
reading_entity = False
reading_mention = False
reading_special_case = False
return final_text, entities
def _write_training_description(outputfile, qid, description):
if description is not None:
line = str(qid) + "|" + description + "\n"
outputfile.write(line)
def _write_training_entities(outputfile, article_id, clean_text, entities):
entities_data = [{"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]} for ent in entities]
line = json.dumps(
{
"article_id": article_id,
"clean_text": clean_text,
"entities": entities_data
},
ensure_ascii=False) + "\n"
outputfile.write(line)
def read_training(nlp, entity_file_path, dev, limit, kb):
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
For training,, it will include negative training examples by using the candidate generator,
and it will only keep positive training examples that can be found by using the candidate generator.
For testing, it will include all positive examples only."""
from tqdm import tqdm
data = []
num_entities = 0
get_gold_parse = partial(_get_gold_parse, dev=dev, kb=kb)
logger.info("Reading {} data with limit {}".format('dev' if dev else 'train', limit))
with entity_file_path.open("r", encoding="utf8") as file:
with tqdm(total=limit, leave=False) as pbar:
for i, line in enumerate(file):
example = json.loads(line)
article_id = example["article_id"]
clean_text = example["clean_text"]
entities = example["entities"]
if dev != is_dev(article_id) or len(clean_text) >= 30000:
continue
doc = nlp(clean_text)
gold = get_gold_parse(doc, entities)
if gold and len(gold.links) > 0:
data.append((doc, gold))
num_entities += len(gold.links)
pbar.update(len(gold.links))
if limit and num_entities >= limit:
break
logger.info("Read {} entities in {} articles".format(num_entities, len(data)))
return data
def _get_gold_parse(doc, entities, dev, kb):
gold_entities = {}
tagged_ent_positions = set(
[(ent.start_char, ent.end_char) for ent in doc.ents]
)
for entity in entities:
entity_id = entity["entity"]
alias = entity["alias"]
start = entity["start"]
end = entity["end"]
candidates = kb.get_candidates(alias)
candidate_ids = [
c.entity_ for c in candidates
]
should_add_ent = (
dev or
(
(start, end) in tagged_ent_positions and
entity_id in candidate_ids and
len(candidates) > 1
)
)
if should_add_ent:
value_by_id = {entity_id: 1.0}
if not dev:
random.shuffle(candidate_ids)
value_by_id.update({
kb_id: 0.0
for kb_id in candidate_ids
if kb_id != entity_id
})
gold_entities[(start, end)] = value_by_id
return GoldParse(doc, links=gold_entities)
def is_dev(article_id):
return article_id.endswith("3")

View File

@ -0,0 +1,127 @@
# coding: utf-8
from __future__ import unicode_literals
import sys
import csv
# min() needed to prevent error on windows, cf https://stackoverflow.com/questions/52404416/
csv.field_size_limit(min(sys.maxsize, 2147483646))
""" This class provides reading/writing methods for temp files """
# Entity definition: WP title -> WD ID #
def write_title_to_id(entity_def_output, title_to_id):
with entity_def_output.open("w", encoding="utf8") as id_file:
id_file.write("WP_title" + "|" + "WD_id" + "\n")
for title, qid in title_to_id.items():
id_file.write(title + "|" + str(qid) + "\n")
def read_title_to_id(entity_def_output):
title_to_id = dict()
with entity_def_output.open("r", encoding="utf8") as id_file:
csvreader = csv.reader(id_file, delimiter="|")
# skip header
next(csvreader)
for row in csvreader:
title_to_id[row[0]] = row[1]
return title_to_id
# Entity aliases from WD: WD ID -> WD alias #
def write_id_to_alias(entity_alias_path, id_to_alias):
with entity_alias_path.open("w", encoding="utf8") as alias_file:
alias_file.write("WD_id" + "|" + "alias" + "\n")
for qid, alias_list in id_to_alias.items():
for alias in alias_list:
alias_file.write(str(qid) + "|" + alias + "\n")
def read_id_to_alias(entity_alias_path):
id_to_alias = dict()
with entity_alias_path.open("r", encoding="utf8") as alias_file:
csvreader = csv.reader(alias_file, delimiter="|")
# skip header
next(csvreader)
for row in csvreader:
qid = row[0]
alias = row[1]
alias_list = id_to_alias.get(qid, [])
alias_list.append(alias)
id_to_alias[qid] = alias_list
return id_to_alias
def read_alias_to_id_generator(entity_alias_path):
""" Read (aliases, qid) tuples """
with entity_alias_path.open("r", encoding="utf8") as alias_file:
csvreader = csv.reader(alias_file, delimiter="|")
# skip header
next(csvreader)
for row in csvreader:
qid = row[0]
alias = row[1]
yield alias, qid
# Entity descriptions from WD: WD ID -> WD alias #
def write_id_to_descr(entity_descr_output, id_to_descr):
with entity_descr_output.open("w", encoding="utf8") as descr_file:
descr_file.write("WD_id" + "|" + "description" + "\n")
for qid, descr in id_to_descr.items():
descr_file.write(str(qid) + "|" + descr + "\n")
def read_id_to_descr(entity_desc_path):
id_to_desc = dict()
with entity_desc_path.open("r", encoding="utf8") as descr_file:
csvreader = csv.reader(descr_file, delimiter="|")
# skip header
next(csvreader)
for row in csvreader:
id_to_desc[row[0]] = row[1]
return id_to_desc
# Entity counts from WP: WP title -> count #
def write_entity_to_count(prior_prob_input, count_output):
# Write entity counts for quick access later
entity_to_count = dict()
total_count = 0
with prior_prob_input.open("r", encoding="utf8") as prior_file:
# skip header
prior_file.readline()
line = prior_file.readline()
while line:
splits = line.replace("\n", "").split(sep="|")
# alias = splits[0]
count = int(splits[1])
entity = splits[2]
current_count = entity_to_count.get(entity, 0)
entity_to_count[entity] = current_count + count
total_count += count
line = prior_file.readline()
with count_output.open("w", encoding="utf8") as entity_file:
entity_file.write("entity" + "|" + "count" + "\n")
for entity, count in entity_to_count.items():
entity_file.write(entity + "|" + str(count) + "\n")
def read_entity_to_count(count_input):
entity_to_count = dict()
with count_input.open("r", encoding="utf8") as csvfile:
csvreader = csv.reader(csvfile, delimiter="|")
# skip header
next(csvreader)
for row in csvreader:
entity_to_count[row[0]] = int(row[1])
return entity_to_count

View File

@ -0,0 +1,128 @@
# coding: utf8
from __future__ import unicode_literals
# List of meta pages in Wikidata, should be kept out of the Knowledge base
WD_META_ITEMS = [
"Q163875",
"Q191780",
"Q224414",
"Q4167836",
"Q4167410",
"Q4663903",
"Q11266439",
"Q13406463",
"Q15407973",
"Q18616576",
"Q19887878",
"Q22808320",
"Q23894233",
"Q33120876",
"Q42104522",
"Q47460393",
"Q64875536",
"Q66480449",
]
# TODO: add more cases from non-English WP's
# List of prefixes that refer to Wikipedia "file" pages
WP_FILE_NAMESPACE = ["Bestand", "File"]
# List of prefixes that refer to Wikipedia "category" pages
WP_CATEGORY_NAMESPACE = ["Kategori", "Category", "Categorie"]
# List of prefixes that refer to Wikipedia "meta" pages
# these will/should be matched ignoring case
WP_META_NAMESPACE = (
WP_FILE_NAMESPACE
+ WP_CATEGORY_NAMESPACE
+ [
"b",
"betawikiversity",
"Book",
"c",
"Commons",
"d",
"dbdump",
"download",
"Draft",
"Education",
"Foundation",
"Gadget",
"Gadget definition",
"Gebruiker",
"gerrit",
"Help",
"Image",
"Incubator",
"m",
"mail",
"mailarchive",
"media",
"MediaWiki",
"MediaWiki talk",
"Mediawikiwiki",
"MediaZilla",
"Meta",
"Metawikipedia",
"Module",
"mw",
"n",
"nost",
"oldwikisource",
"otrs",
"OTRSwiki",
"Overleg gebruiker",
"outreach",
"outreachwiki",
"Portal",
"phab",
"Phabricator",
"Project",
"q",
"quality",
"rev",
"s",
"spcom",
"Special",
"species",
"Strategy",
"sulutil",
"svn",
"Talk",
"Template",
"Template talk",
"Testwiki",
"ticket",
"TimedText",
"Toollabs",
"tools",
"tswiki",
"User",
"User talk",
"v",
"voy",
"w",
"Wikibooks",
"Wikidata",
"wikiHow",
"Wikinvest",
"wikilivres",
"Wikimedia",
"Wikinews",
"Wikipedia",
"Wikipedia talk",
"Wikiquote",
"Wikisource",
"Wikispecies",
"Wikitech",
"Wikiversity",
"Wikivoyage",
"wikt",
"wiktionary",
"wmf",
"wmania",
"WP",
]
)

View File

@ -18,11 +18,12 @@ from pathlib import Path
import plac import plac
from bin.wiki_entity_linking import wikipedia_processor as wp, wikidata_processor as wd from bin.wiki_entity_linking import wikipedia_processor as wp, wikidata_processor as wd
from bin.wiki_entity_linking import wiki_io as io
from bin.wiki_entity_linking import kb_creator from bin.wiki_entity_linking import kb_creator
from bin.wiki_entity_linking import training_set_creator
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_FILE, ENTITY_DESCR_PATH, KB_MODEL_DIR, LOG_FORMAT from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_FILE, ENTITY_DESCR_PATH, KB_MODEL_DIR, LOG_FORMAT
from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH, ENTITY_ALIAS_PATH
import spacy import spacy
from bin.wiki_entity_linking.kb_creator import read_kb
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -39,9 +40,11 @@ logger = logging.getLogger(__name__)
loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path), loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path),
loc_entity_defs=("Location to file with entity definitions", "option", "d", Path), loc_entity_defs=("Location to file with entity definitions", "option", "d", Path),
loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path), loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path),
descriptions_from_wikipedia=("Flag for using wp descriptions not wd", "flag", "wp"), descr_from_wp=("Flag for using wp descriptions not wd", "flag", "wp"),
limit=("Optional threshold to limit lines read from dumps", "option", "l", int), limit_prior=("Threshold to limit lines read from WP for prior probabilities", "option", "lp", int),
lang=("Optional language for which to get wikidata titles. Defaults to 'en'", "option", "la", str), limit_train=("Threshold to limit lines read from WP for training set", "option", "lt", int),
limit_wd=("Threshold to limit lines read from WD", "option", "lw", int),
lang=("Optional language for which to get Wikidata titles. Defaults to 'en'", "option", "la", str),
) )
def main( def main(
wd_json, wd_json,
@ -54,13 +57,16 @@ def main(
entity_vector_length=64, entity_vector_length=64,
loc_prior_prob=None, loc_prior_prob=None,
loc_entity_defs=None, loc_entity_defs=None,
loc_entity_alias=None,
loc_entity_desc=None, loc_entity_desc=None,
descriptions_from_wikipedia=False, descr_from_wp=False,
limit=None, limit_prior=None,
limit_train=None,
limit_wd=None,
lang="en", lang="en",
): ):
entity_defs_path = loc_entity_defs if loc_entity_defs else output_dir / ENTITY_DEFS_PATH entity_defs_path = loc_entity_defs if loc_entity_defs else output_dir / ENTITY_DEFS_PATH
entity_alias_path = loc_entity_alias if loc_entity_alias else output_dir / ENTITY_ALIAS_PATH
entity_descr_path = loc_entity_desc if loc_entity_desc else output_dir / ENTITY_DESCR_PATH entity_descr_path = loc_entity_desc if loc_entity_desc else output_dir / ENTITY_DESCR_PATH
entity_freq_path = output_dir / ENTITY_FREQ_PATH entity_freq_path = output_dir / ENTITY_FREQ_PATH
prior_prob_path = loc_prior_prob if loc_prior_prob else output_dir / PRIOR_PROB_PATH prior_prob_path = loc_prior_prob if loc_prior_prob else output_dir / PRIOR_PROB_PATH
@ -69,15 +75,12 @@ def main(
logger.info("Creating KB with Wikipedia and WikiData") logger.info("Creating KB with Wikipedia and WikiData")
if limit is not None:
logger.warning("Warning: reading only {} lines of Wikipedia/Wikidata dumps.".format(limit))
# STEP 0: set up IO # STEP 0: set up IO
if not output_dir.exists(): if not output_dir.exists():
output_dir.mkdir(parents=True) output_dir.mkdir(parents=True)
# STEP 1: create the NLP object # STEP 1: Load the NLP object
logger.info("STEP 1: Loading model {}".format(model)) logger.info("STEP 1: Loading NLP model {}".format(model))
nlp = spacy.load(model) nlp = spacy.load(model)
# check the length of the nlp vectors # check the length of the nlp vectors
@ -90,62 +93,83 @@ def main(
# STEP 2: create prior probabilities from WP # STEP 2: create prior probabilities from WP
if not prior_prob_path.exists(): if not prior_prob_path.exists():
# It takes about 2h to process 1000M lines of Wikipedia XML dump # It takes about 2h to process 1000M lines of Wikipedia XML dump
logger.info("STEP 2: writing prior probabilities to {}".format(prior_prob_path)) logger.info("STEP 2: Writing prior probabilities to {}".format(prior_prob_path))
wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit) if limit_prior is not None:
logger.info("STEP 2: reading prior probabilities from {}".format(prior_prob_path)) logger.warning("Warning: reading only {} lines of Wikipedia dump".format(limit_prior))
wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit_prior)
else:
logger.info("STEP 2: Reading prior probabilities from {}".format(prior_prob_path))
# STEP 3: deduce entity frequencies from WP (takes only a few minutes) # STEP 3: calculate entity frequencies
logger.info("STEP 3: calculating entity frequencies") if not entity_freq_path.exists():
wp.write_entity_counts(prior_prob_path, entity_freq_path, to_print=False) logger.info("STEP 3: Calculating and writing entity frequencies to {}".format(entity_freq_path))
io.write_entity_to_count(prior_prob_path, entity_freq_path)
else:
logger.info("STEP 3: Reading entity frequencies from {}".format(entity_freq_path))
# STEP 4: reading definitions and (possibly) descriptions from WikiData or from file # STEP 4: reading definitions and (possibly) descriptions from WikiData or from file
message = " and descriptions" if not descriptions_from_wikipedia else "" if (not entity_defs_path.exists()) or (not descr_from_wp and not entity_descr_path.exists()):
if (not entity_defs_path.exists()) or (not descriptions_from_wikipedia and not entity_descr_path.exists()):
# It takes about 10h to process 55M lines of Wikidata JSON dump # It takes about 10h to process 55M lines of Wikidata JSON dump
logger.info("STEP 4: parsing wikidata for entity definitions" + message) logger.info("STEP 4: Parsing and writing Wikidata entity definitions to {}".format(entity_defs_path))
title_to_id, id_to_descr = wd.read_wikidata_entities_json( if limit_wd is not None:
logger.warning("Warning: reading only {} lines of Wikidata dump".format(limit_wd))
title_to_id, id_to_descr, id_to_alias = wd.read_wikidata_entities_json(
wd_json, wd_json,
limit, limit_wd,
to_print=False, to_print=False,
lang=lang, lang=lang,
parse_descriptions=(not descriptions_from_wikipedia), parse_descr=(not descr_from_wp),
) )
wd.write_entity_files(entity_defs_path, title_to_id) io.write_title_to_id(entity_defs_path, title_to_id)
if not descriptions_from_wikipedia:
wd.write_entity_description_files(entity_descr_path, id_to_descr)
logger.info("STEP 4: read entity definitions" + message)
# STEP 5: Getting gold entities from wikipedia logger.info("STEP 4b: Writing Wikidata entity aliases to {}".format(entity_alias_path))
message = " and descriptions" if descriptions_from_wikipedia else "" io.write_id_to_alias(entity_alias_path, id_to_alias)
if (not training_entities_path.exists()) or (descriptions_from_wikipedia and not entity_descr_path.exists()):
logger.info("STEP 5: parsing wikipedia for gold entities" + message) if not descr_from_wp:
training_set_creator.create_training_examples_and_descriptions( logger.info("STEP 4c: Writing Wikidata entity descriptions to {}".format(entity_descr_path))
wp_xml, io.write_id_to_descr(entity_descr_path, id_to_descr)
entity_defs_path, else:
entity_descr_path, logger.info("STEP 4: Reading entity definitions from {}".format(entity_defs_path))
training_entities_path, logger.info("STEP 4b: Reading entity aliases from {}".format(entity_alias_path))
parse_descriptions=descriptions_from_wikipedia, if not descr_from_wp:
limit=limit, logger.info("STEP 4c: Reading entity descriptions from {}".format(entity_descr_path))
)
logger.info("STEP 5: read gold entities" + message) # STEP 5: Getting gold entities from Wikipedia
if (not training_entities_path.exists()) or (descr_from_wp and not entity_descr_path.exists()):
logger.info("STEP 5: Parsing and writing Wikipedia gold entities to {}".format(training_entities_path))
if limit_train is not None:
logger.warning("Warning: reading only {} lines of Wikipedia dump".format(limit_train))
wp.create_training_and_desc(wp_xml, entity_defs_path, entity_descr_path,
training_entities_path, descr_from_wp, limit_train)
if descr_from_wp:
logger.info("STEP 5b: Parsing and writing Wikipedia descriptions to {}".format(entity_descr_path))
else:
logger.info("STEP 5: Reading gold entities from {}".format(training_entities_path))
if descr_from_wp:
logger.info("STEP 5b: Reading entity descriptions from {}".format(entity_descr_path))
# STEP 6: creating the actual KB # STEP 6: creating the actual KB
# It takes ca. 30 minutes to pretrain the entity embeddings # It takes ca. 30 minutes to pretrain the entity embeddings
logger.info("STEP 6: creating the KB at {}".format(kb_path)) if not kb_path.exists():
kb = kb_creator.create_kb( logger.info("STEP 6: Creating the KB at {}".format(kb_path))
nlp=nlp, kb = kb_creator.create_kb(
max_entities_per_alias=max_per_alias, nlp=nlp,
min_entity_freq=min_freq, max_entities_per_alias=max_per_alias,
min_occ=min_pair, min_entity_freq=min_freq,
entity_def_input=entity_defs_path, min_occ=min_pair,
entity_descr_path=entity_descr_path, entity_def_path=entity_defs_path,
count_input=entity_freq_path, entity_descr_path=entity_descr_path,
prior_prob_input=prior_prob_path, entity_alias_path=entity_alias_path,
entity_vector_length=entity_vector_length, entity_freq_path=entity_freq_path,
) prior_prob_path=prior_prob_path,
entity_vector_length=entity_vector_length,
kb.dump(kb_path) )
nlp.to_disk(output_dir / KB_MODEL_DIR) kb.dump(kb_path)
logger.info("kb entities: {}".format(kb.get_size_entities()))
logger.info("kb aliases: {}".format(kb.get_size_aliases()))
nlp.to_disk(output_dir / KB_MODEL_DIR)
else:
logger.info("STEP 6: KB already exists at {}".format(kb_path))
logger.info("Done!") logger.info("Done!")

View File

@ -1,40 +1,52 @@
# coding: utf-8 # coding: utf-8
from __future__ import unicode_literals from __future__ import unicode_literals
import gzip import bz2
import json import json
import logging import logging
import datetime
from bin.wiki_entity_linking.wiki_namespaces import WD_META_ITEMS
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descriptions=True): def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descr=True):
# Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines. # Read the JSON wiki data and parse out the entities. Takes about 7-10h to parse 55M lines.
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/ # get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
site_filter = '{}wiki'.format(lang) site_filter = '{}wiki'.format(lang)
# properties filter (currently disabled to get ALL data) # filter: currently defined as OR: one hit suffices to be removed from further processing
prop_filter = dict() exclude_list = WD_META_ITEMS
# prop_filter = {'P31': {'Q5', 'Q15632617'}} # currently defined as OR: one property suffices to be selected
# punctuation
exclude_list.extend(["Q1383557", "Q10617810"])
# letters etc
exclude_list.extend(["Q188725", "Q19776628", "Q3841820", "Q17907810", "Q9788", "Q9398093"])
neg_prop_filter = {
'P31': exclude_list, # instance of
'P279': exclude_list # subclass
}
title_to_id = dict() title_to_id = dict()
id_to_descr = dict() id_to_descr = dict()
id_to_alias = dict()
# parse appropriate fields - depending on what we need in the KB # parse appropriate fields - depending on what we need in the KB
parse_properties = False parse_properties = False
parse_sitelinks = True parse_sitelinks = True
parse_labels = False parse_labels = False
parse_aliases = False parse_aliases = True
parse_claims = False parse_claims = True
with gzip.open(wikidata_file, mode='rb') as file: with bz2.open(wikidata_file, mode='rb') as file:
for cnt, line in enumerate(file): for cnt, line in enumerate(file):
if limit and cnt >= limit: if limit and cnt >= limit:
break break
if cnt % 500000 == 0: if cnt % 500000 == 0 and cnt > 0:
logger.info("processed {} lines of WikiData dump".format(cnt)) logger.info("processed {} lines of WikiData JSON dump".format(cnt))
clean_line = line.strip() clean_line = line.strip()
if clean_line.endswith(b","): if clean_line.endswith(b","):
clean_line = clean_line[:-1] clean_line = clean_line[:-1]
@ -43,13 +55,11 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
entry_type = obj["type"] entry_type = obj["type"]
if entry_type == "item": if entry_type == "item":
# filtering records on their properties (currently disabled to get ALL data)
# keep = False
keep = True keep = True
claims = obj["claims"] claims = obj["claims"]
if parse_claims: if parse_claims:
for prop, value_set in prop_filter.items(): for prop, value_set in neg_prop_filter.items():
claim_property = claims.get(prop, None) claim_property = claims.get(prop, None)
if claim_property: if claim_property:
for cp in claim_property: for cp in claim_property:
@ -61,7 +71,7 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
) )
cp_rank = cp["rank"] cp_rank = cp["rank"]
if cp_rank != "deprecated" and cp_id in value_set: if cp_rank != "deprecated" and cp_id in value_set:
keep = True keep = False
if keep: if keep:
unique_id = obj["id"] unique_id = obj["id"]
@ -108,7 +118,7 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
"label (" + lang + "):", lang_label["value"] "label (" + lang + "):", lang_label["value"]
) )
if found_link and parse_descriptions: if found_link and parse_descr:
descriptions = obj["descriptions"] descriptions = obj["descriptions"]
if descriptions: if descriptions:
lang_descr = descriptions.get(lang, None) lang_descr = descriptions.get(lang, None)
@ -130,22 +140,15 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
print( print(
"alias (" + lang + "):", item["value"] "alias (" + lang + "):", item["value"]
) )
alias_list = id_to_alias.get(unique_id, [])
alias_list.append(item["value"])
id_to_alias[unique_id] = alias_list
if to_print: if to_print:
print() print()
return title_to_id, id_to_descr # log final number of lines processed
logger.info("Finished. Processed {} lines of WikiData JSON dump".format(cnt))
return title_to_id, id_to_descr, id_to_alias
def write_entity_files(entity_def_output, title_to_id):
with entity_def_output.open("w", encoding="utf8") as id_file:
id_file.write("WP_title" + "|" + "WD_id" + "\n")
for title, qid in title_to_id.items():
id_file.write(title + "|" + str(qid) + "\n")
def write_entity_description_files(entity_descr_output, id_to_descr):
with entity_descr_output.open("w", encoding="utf8") as descr_file:
descr_file.write("WD_id" + "|" + "description" + "\n")
for qid, descr in id_to_descr.items():
descr_file.write(str(qid) + "|" + descr + "\n")

View File

@ -6,19 +6,19 @@ as created by the script `wikidata_create_kb`.
For the Wikipedia dump: get enwiki-latest-pages-articles-multistream.xml.bz2 For the Wikipedia dump: get enwiki-latest-pages-articles-multistream.xml.bz2
from https://dumps.wikimedia.org/enwiki/latest/ from https://dumps.wikimedia.org/enwiki/latest/
""" """
from __future__ import unicode_literals from __future__ import unicode_literals
import random import random
import logging import logging
import spacy
from pathlib import Path from pathlib import Path
import plac import plac
from bin.wiki_entity_linking import training_set_creator from bin.wiki_entity_linking import wikipedia_processor
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR
from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance, measure_baselines from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance
from bin.wiki_entity_linking.kb_creator import read_nlp_kb from bin.wiki_entity_linking.kb_creator import read_kb
from spacy.util import minibatch, compounding from spacy.util import minibatch, compounding
@ -35,6 +35,7 @@ logger = logging.getLogger(__name__)
l2=("L2 regularization", "option", "r", float), l2=("L2 regularization", "option", "r", float),
train_inst=("# training instances (default 90% of all)", "option", "t", int), train_inst=("# training instances (default 90% of all)", "option", "t", int),
dev_inst=("# test instances (default 10% of all)", "option", "d", int), dev_inst=("# test instances (default 10% of all)", "option", "d", int),
labels_discard=("NER labels to discard (default None)", "option", "l", str),
) )
def main( def main(
dir_kb, dir_kb,
@ -46,13 +47,14 @@ def main(
l2=1e-6, l2=1e-6,
train_inst=None, train_inst=None,
dev_inst=None, dev_inst=None,
labels_discard=None
): ):
logger.info("Creating Entity Linker with Wikipedia and WikiData") logger.info("Creating Entity Linker with Wikipedia and WikiData")
output_dir = Path(output_dir) if output_dir else dir_kb output_dir = Path(output_dir) if output_dir else dir_kb
training_path = loc_training if loc_training else output_dir / TRAINING_DATA_FILE training_path = loc_training if loc_training else dir_kb / TRAINING_DATA_FILE
nlp_dir = dir_kb / KB_MODEL_DIR nlp_dir = dir_kb / KB_MODEL_DIR
kb_path = output_dir / KB_FILE kb_path = dir_kb / KB_FILE
nlp_output_dir = output_dir / OUTPUT_MODEL_DIR nlp_output_dir = output_dir / OUTPUT_MODEL_DIR
# STEP 0: set up IO # STEP 0: set up IO
@ -60,38 +62,47 @@ def main(
output_dir.mkdir() output_dir.mkdir()
# STEP 1 : load the NLP object # STEP 1 : load the NLP object
logger.info("STEP 1: loading model from {}".format(nlp_dir)) logger.info("STEP 1a: Loading model from {}".format(nlp_dir))
nlp, kb = read_nlp_kb(nlp_dir, kb_path) nlp = spacy.load(nlp_dir)
logger.info("STEP 1b: Loading KB from {}".format(kb_path))
kb = read_kb(nlp, kb_path)
# check that there is a NER component in the pipeline # check that there is a NER component in the pipeline
if "ner" not in nlp.pipe_names: if "ner" not in nlp.pipe_names:
raise ValueError("The `nlp` object should have a pretrained `ner` component.") raise ValueError("The `nlp` object should have a pretrained `ner` component.")
# STEP 2: create a training dataset from WP # STEP 2: read the training dataset previously created from WP
logger.info("STEP 2: reading training dataset from {}".format(training_path)) logger.info("STEP 2: Reading training dataset from {}".format(training_path))
train_data = training_set_creator.read_training( if labels_discard:
labels_discard = [x.strip() for x in labels_discard.split(",")]
logger.info("Discarding {} NER types: {}".format(len(labels_discard), labels_discard))
train_data = wikipedia_processor.read_training(
nlp=nlp, nlp=nlp,
entity_file_path=training_path, entity_file_path=training_path,
dev=False, dev=False,
limit=train_inst, limit=train_inst,
kb=kb, kb=kb,
labels_discard=labels_discard
) )
# for testing, get all pos instances, whether or not they are in the kb # for testing, get all pos instances (independently of KB)
dev_data = training_set_creator.read_training( dev_data = wikipedia_processor.read_training(
nlp=nlp, nlp=nlp,
entity_file_path=training_path, entity_file_path=training_path,
dev=True, dev=True,
limit=dev_inst, limit=dev_inst,
kb=kb, kb=None,
labels_discard=labels_discard
) )
# STEP 3: create and train the entity linking pipe # STEP 3: create and train an entity linking pipe
logger.info("STEP 3: training Entity Linking pipe") logger.info("STEP 3: Creating and training an Entity Linking pipe")
el_pipe = nlp.create_pipe( el_pipe = nlp.create_pipe(
name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name} name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name,
"labels_discard": labels_discard}
) )
el_pipe.set_kb(kb) el_pipe.set_kb(kb)
nlp.add_pipe(el_pipe, last=True) nlp.add_pipe(el_pipe, last=True)
@ -105,14 +116,9 @@ def main(
logger.info("Training on {} articles".format(len(train_data))) logger.info("Training on {} articles".format(len(train_data)))
logger.info("Dev testing on {} articles".format(len(dev_data))) logger.info("Dev testing on {} articles".format(len(dev_data)))
dev_baseline_accuracies = measure_baselines( # baseline performance on dev data
dev_data, kb
)
logger.info("Dev Baseline Accuracies:") logger.info("Dev Baseline Accuracies:")
logger.info(dev_baseline_accuracies.report_accuracy("random")) measure_performance(dev_data, kb, el_pipe, baseline=True, context=False)
logger.info(dev_baseline_accuracies.report_accuracy("prior"))
logger.info(dev_baseline_accuracies.report_accuracy("oracle"))
for itn in range(epochs): for itn in range(epochs):
random.shuffle(train_data) random.shuffle(train_data)
@ -136,18 +142,18 @@ def main(
logger.error("Error updating batch:" + str(e)) logger.error("Error updating batch:" + str(e))
if batchnr > 0: if batchnr > 0:
logging.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2))) logging.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2)))
measure_performance(dev_data, kb, el_pipe) measure_performance(dev_data, kb, el_pipe, baseline=False, context=True)
# STEP 4: measure the performance of our trained pipe on an independent dev set # STEP 4: measure the performance of our trained pipe on an independent dev set
logger.info("STEP 4: performance measurement of Entity Linking pipe") logger.info("STEP 4: Final performance measurement of Entity Linking pipe")
measure_performance(dev_data, kb, el_pipe) measure_performance(dev_data, kb, el_pipe)
# STEP 5: apply the EL pipe on a toy example # STEP 5: apply the EL pipe on a toy example
logger.info("STEP 5: applying Entity Linking to toy example") logger.info("STEP 5: Applying Entity Linking to toy example")
run_el_toy_example(nlp=nlp) run_el_toy_example(nlp=nlp)
if output_dir: if output_dir:
# STEP 6: write the NLP pipeline (including entity linker) to file # STEP 6: write the NLP pipeline (now including an EL model) to file
logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir)) logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir))
nlp.to_disk(nlp_output_dir) nlp.to_disk(nlp_output_dir)

View File

@ -3,147 +3,104 @@ from __future__ import unicode_literals
import re import re
import bz2 import bz2
import csv
import datetime
import logging import logging
import random
import json
from bin.wiki_entity_linking import LOG_FORMAT from functools import partial
from spacy.gold import GoldParse
from bin.wiki_entity_linking import wiki_io as io
from bin.wiki_entity_linking.wiki_namespaces import (
WP_META_NAMESPACE,
WP_FILE_NAMESPACE,
WP_CATEGORY_NAMESPACE,
)
""" """
Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions. Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions.
Write these results to file for downstream KB and training data generation. Write these results to file for downstream KB and training data generation.
Process Wikipedia interlinks to generate a training dataset for the EL algorithm.
""" """
ENTITY_FILE = "gold_entities.csv"
map_alias_to_link = dict() map_alias_to_link = dict()
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
title_regex = re.compile(r"(?<=<title>).*(?=</title>)")
# these will/should be matched ignoring case id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
wiki_namespaces = [ text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")
"b", info_regex = re.compile(r"{[^{]*?}")
"betawikiversity", html_regex = re.compile(r"&lt;!--[^-]*--&gt;")
"Book", ref_regex = re.compile(r"&lt;ref.*?&gt;") # non-greedy
"c", ref_2_regex = re.compile(r"&lt;/ref.*?&gt;") # non-greedy
"Category",
"Commons",
"d",
"dbdump",
"download",
"Draft",
"Education",
"Foundation",
"Gadget",
"Gadget definition",
"gerrit",
"File",
"Help",
"Image",
"Incubator",
"m",
"mail",
"mailarchive",
"media",
"MediaWiki",
"MediaWiki talk",
"Mediawikiwiki",
"MediaZilla",
"Meta",
"Metawikipedia",
"Module",
"mw",
"n",
"nost",
"oldwikisource",
"outreach",
"outreachwiki",
"otrs",
"OTRSwiki",
"Portal",
"phab",
"Phabricator",
"Project",
"q",
"quality",
"rev",
"s",
"spcom",
"Special",
"species",
"Strategy",
"sulutil",
"svn",
"Talk",
"Template",
"Template talk",
"Testwiki",
"ticket",
"TimedText",
"Toollabs",
"tools",
"tswiki",
"User",
"User talk",
"v",
"voy",
"w",
"Wikibooks",
"Wikidata",
"wikiHow",
"Wikinvest",
"wikilivres",
"Wikimedia",
"Wikinews",
"Wikipedia",
"Wikipedia talk",
"Wikiquote",
"Wikisource",
"Wikispecies",
"Wikitech",
"Wikiversity",
"Wikivoyage",
"wikt",
"wiktionary",
"wmf",
"wmania",
"WP",
]
# find the links # find the links
link_regex = re.compile(r"\[\[[^\[\]]*\]\]") link_regex = re.compile(r"\[\[[^\[\]]*\]\]")
# match on interwiki links, e.g. `en:` or `:fr:` # match on interwiki links, e.g. `en:` or `:fr:`
ns_regex = r":?" + "[a-z][a-z]" + ":" ns_regex = r":?" + "[a-z][a-z]" + ":"
# match on Namespace: optionally preceded by a : # match on Namespace: optionally preceded by a :
for ns in wiki_namespaces: for ns in WP_META_NAMESPACE:
ns_regex += "|" + ":?" + ns + ":" ns_regex += "|" + ":?" + ns + ":"
ns_regex = re.compile(ns_regex, re.IGNORECASE) ns_regex = re.compile(ns_regex, re.IGNORECASE)
files = r""
for f in WP_FILE_NAMESPACE:
files += "\[\[" + f + ":[^[\]]+]]" + "|"
files = files[0 : len(files) - 1]
file_regex = re.compile(files)
cats = r""
for c in WP_CATEGORY_NAMESPACE:
cats += "\[\[" + c + ":[^\[]*]]" + "|"
cats = cats[0 : len(cats) - 1]
category_regex = re.compile(cats)
def read_prior_probs(wikipedia_input, prior_prob_output, limit=None): def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
""" """
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities. Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
The full file takes about 2h to parse 1100M lines. The full file takes about 2-3h to parse 1100M lines.
It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from. It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from,
though dev test articles are excluded in order not to get an artificially strong baseline.
""" """
cnt = 0
read_id = False
current_article_id = None
with bz2.open(wikipedia_input, mode="rb") as file: with bz2.open(wikipedia_input, mode="rb") as file:
line = file.readline() line = file.readline()
cnt = 0
while line and (not limit or cnt < limit): while line and (not limit or cnt < limit):
if cnt % 25000000 == 0: if cnt % 25000000 == 0 and cnt > 0:
logger.info("processed {} lines of Wikipedia XML dump".format(cnt)) logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
clean_line = line.strip().decode("utf-8") clean_line = line.strip().decode("utf-8")
aliases, entities, normalizations = get_wp_links(clean_line) # we attempt at reading the article's ID (but not the revision or contributor ID)
for alias, entity, norm in zip(aliases, entities, normalizations): if "<revision>" in clean_line or "<contributor>" in clean_line:
_store_alias(alias, entity, normalize_alias=norm, normalize_entity=True) read_id = False
_store_alias(alias, entity, normalize_alias=norm, normalize_entity=True) if "<page>" in clean_line:
read_id = True
if read_id:
ids = id_regex.search(clean_line)
if ids:
current_article_id = ids[0]
# only processing prior probabilities from true training (non-dev) articles
if not is_dev(current_article_id):
aliases, entities, normalizations = get_wp_links(clean_line)
for alias, entity, norm in zip(aliases, entities, normalizations):
_store_alias(
alias, entity, normalize_alias=norm, normalize_entity=True
)
line = file.readline() line = file.readline()
cnt += 1 cnt += 1
logger.info("processed {} lines of Wikipedia XML dump".format(cnt)) logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
logger.info("Finished. processed {} lines of Wikipedia XML dump".format(cnt))
# write all aliases and their entities and count occurrences to file # write all aliases and their entities and count occurrences to file
with prior_prob_output.open("w", encoding="utf8") as outputfile: with prior_prob_output.open("w", encoding="utf8") as outputfile:
@ -182,7 +139,7 @@ def get_wp_links(text):
match = match[2:][:-2].replace("_", " ").strip() match = match[2:][:-2].replace("_", " ").strip()
if ns_regex.match(match): if ns_regex.match(match):
pass # ignore namespaces at the beginning of the string pass # ignore the entity if it points to a "meta" page
# this is a simple [[link]], with the alias the same as the mention # this is a simple [[link]], with the alias the same as the mention
elif "|" not in match: elif "|" not in match:
@ -218,47 +175,382 @@ def _capitalize_first(text):
return result return result
def write_entity_counts(prior_prob_input, count_output, to_print=False): def create_training_and_desc(
# Write entity counts for quick access later wp_input, def_input, desc_output, training_output, parse_desc, limit=None
entity_to_count = dict() ):
total_count = 0 wp_to_id = io.read_title_to_id(def_input)
_process_wikipedia_texts(
with prior_prob_input.open("r", encoding="utf8") as prior_file: wp_input, wp_to_id, desc_output, training_output, parse_desc, limit
# skip header )
prior_file.readline()
line = prior_file.readline()
while line:
splits = line.replace("\n", "").split(sep="|")
# alias = splits[0]
count = int(splits[1])
entity = splits[2]
current_count = entity_to_count.get(entity, 0)
entity_to_count[entity] = current_count + count
total_count += count
line = prior_file.readline()
with count_output.open("w", encoding="utf8") as entity_file:
entity_file.write("entity" + "|" + "count" + "\n")
for entity, count in entity_to_count.items():
entity_file.write(entity + "|" + str(count) + "\n")
if to_print:
for entity, count in entity_to_count.items():
print("Entity count:", entity, count)
print("Total count:", total_count)
def get_all_frequencies(count_input): def _process_wikipedia_texts(
entity_to_count = dict() wikipedia_input, wp_to_id, output, training_output, parse_descriptions, limit=None
with count_input.open("r", encoding="utf8") as csvfile: ):
csvreader = csv.reader(csvfile, delimiter="|") """
# skip header Read the XML wikipedia data to parse out training data:
next(csvreader) raw text data + positive instances
for row in csvreader: """
entity_to_count[row[0]] = int(row[1])
return entity_to_count read_ids = set()
with output.open("a", encoding="utf8") as descr_file, training_output.open(
"w", encoding="utf8"
) as entity_file:
if parse_descriptions:
_write_training_description(descr_file, "WD_id", "description")
with bz2.open(wikipedia_input, mode="rb") as file:
article_count = 0
article_text = ""
article_title = None
article_id = None
reading_text = False
reading_revision = False
for line in file:
clean_line = line.strip().decode("utf-8")
if clean_line == "<revision>":
reading_revision = True
elif clean_line == "</revision>":
reading_revision = False
# Start reading new page
if clean_line == "<page>":
article_text = ""
article_title = None
article_id = None
# finished reading this page
elif clean_line == "</page>":
if article_id:
clean_text, entities = _process_wp_text(
article_title, article_text, wp_to_id
)
if clean_text is not None and entities is not None:
_write_training_entities(
entity_file, article_id, clean_text, entities
)
if article_title in wp_to_id and parse_descriptions:
description = " ".join(
clean_text[:1000].split(" ")[:-1]
)
_write_training_description(
descr_file, wp_to_id[article_title], description
)
article_count += 1
if article_count % 10000 == 0 and article_count > 0:
logger.info(
"Processed {} articles".format(article_count)
)
if limit and article_count >= limit:
break
article_text = ""
article_title = None
article_id = None
reading_text = False
reading_revision = False
# start reading text within a page
if "<text" in clean_line:
reading_text = True
if reading_text:
article_text += " " + clean_line
# stop reading text within a page (we assume a new page doesn't start on the same line)
if "</text" in clean_line:
reading_text = False
# read the ID of this article (outside the revision portion of the document)
if not reading_revision:
ids = id_regex.search(clean_line)
if ids:
article_id = ids[0]
if article_id in read_ids:
logger.info(
"Found duplicate article ID", article_id, clean_line
) # This should never happen ...
read_ids.add(article_id)
# read the title of this article (outside the revision portion of the document)
if not reading_revision:
titles = title_regex.search(clean_line)
if titles:
article_title = titles[0].strip()
logger.info("Finished. Processed {} articles".format(article_count))
def _process_wp_text(article_title, article_text, wp_to_id):
# ignore meta Wikipedia pages
if ns_regex.match(article_title):
return None, None
# remove the text tags
text_search = text_regex.search(article_text)
if text_search is None:
return None, None
text = text_search.group(0)
# stop processing if this is a redirect page
if text.startswith("#REDIRECT"):
return None, None
# get the raw text without markup etc, keeping only interwiki links
clean_text, entities = _remove_links(_get_clean_wp_text(text), wp_to_id)
return clean_text, entities
def _get_clean_wp_text(article_text):
clean_text = article_text.strip()
# remove bolding & italic markup
clean_text = clean_text.replace("'''", "")
clean_text = clean_text.replace("''", "")
# remove nested {{info}} statements by removing the inner/smallest ones first and iterating
try_again = True
previous_length = len(clean_text)
while try_again:
clean_text = info_regex.sub(
"", clean_text
) # non-greedy match excluding a nested {
if len(clean_text) < previous_length:
try_again = True
else:
try_again = False
previous_length = len(clean_text)
# remove HTML comments
clean_text = html_regex.sub("", clean_text)
# remove Category and File statements
clean_text = category_regex.sub("", clean_text)
clean_text = file_regex.sub("", clean_text)
# remove multiple =
while "==" in clean_text:
clean_text = clean_text.replace("==", "=")
clean_text = clean_text.replace(". =", ".")
clean_text = clean_text.replace(" = ", ". ")
clean_text = clean_text.replace("= ", ".")
clean_text = clean_text.replace(" =", "")
# remove refs (non-greedy match)
clean_text = ref_regex.sub("", clean_text)
clean_text = ref_2_regex.sub("", clean_text)
# remove additional wikiformatting
clean_text = re.sub(r"&lt;blockquote&gt;", "", clean_text)
clean_text = re.sub(r"&lt;/blockquote&gt;", "", clean_text)
# change special characters back to normal ones
clean_text = clean_text.replace(r"&lt;", "<")
clean_text = clean_text.replace(r"&gt;", ">")
clean_text = clean_text.replace(r"&quot;", '"')
clean_text = clean_text.replace(r"&amp;nbsp;", " ")
clean_text = clean_text.replace(r"&amp;", "&")
# remove multiple spaces
while " " in clean_text:
clean_text = clean_text.replace(" ", " ")
return clean_text.strip()
def _remove_links(clean_text, wp_to_id):
# read the text char by char to get the right offsets for the interwiki links
entities = []
final_text = ""
open_read = 0
reading_text = True
reading_entity = False
reading_mention = False
reading_special_case = False
entity_buffer = ""
mention_buffer = ""
for index, letter in enumerate(clean_text):
if letter == "[":
open_read += 1
elif letter == "]":
open_read -= 1
elif letter == "|":
if reading_text:
final_text += letter
# switch from reading entity to mention in the [[entity|mention]] pattern
elif reading_entity:
reading_text = False
reading_entity = False
reading_mention = True
else:
reading_special_case = True
else:
if reading_entity:
entity_buffer += letter
elif reading_mention:
mention_buffer += letter
elif reading_text:
final_text += letter
else:
raise ValueError("Not sure at point", clean_text[index - 2 : index + 2])
if open_read > 2:
reading_special_case = True
if open_read == 2 and reading_text:
reading_text = False
reading_entity = True
reading_mention = False
# we just finished reading an entity
if open_read == 0 and not reading_text:
if "#" in entity_buffer or entity_buffer.startswith(":"):
reading_special_case = True
# Ignore cases with nested structures like File: handles etc
if not reading_special_case:
if not mention_buffer:
mention_buffer = entity_buffer
start = len(final_text)
end = start + len(mention_buffer)
qid = wp_to_id.get(entity_buffer, None)
if qid:
entities.append((mention_buffer, qid, start, end))
final_text += mention_buffer
entity_buffer = ""
mention_buffer = ""
reading_text = True
reading_entity = False
reading_mention = False
reading_special_case = False
return final_text, entities
def _write_training_description(outputfile, qid, description):
if description is not None:
line = str(qid) + "|" + description + "\n"
outputfile.write(line)
def _write_training_entities(outputfile, article_id, clean_text, entities):
entities_data = [
{"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]}
for ent in entities
]
line = (
json.dumps(
{
"article_id": article_id,
"clean_text": clean_text,
"entities": entities_data,
},
ensure_ascii=False,
)
+ "\n"
)
outputfile.write(line)
def read_training(nlp, entity_file_path, dev, limit, kb, labels_discard=None):
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
For training, it will include both positive and negative examples by using the candidate generator from the kb.
For testing (kb=None), it will include all positive examples only."""
from tqdm import tqdm
if not labels_discard:
labels_discard = []
data = []
num_entities = 0
get_gold_parse = partial(
_get_gold_parse, dev=dev, kb=kb, labels_discard=labels_discard
)
logger.info(
"Reading {} data with limit {}".format("dev" if dev else "train", limit)
)
with entity_file_path.open("r", encoding="utf8") as file:
with tqdm(total=limit, leave=False) as pbar:
for i, line in enumerate(file):
example = json.loads(line)
article_id = example["article_id"]
clean_text = example["clean_text"]
entities = example["entities"]
if dev != is_dev(article_id) or not is_valid_article(clean_text):
continue
doc = nlp(clean_text)
gold = get_gold_parse(doc, entities)
if gold and len(gold.links) > 0:
data.append((doc, gold))
num_entities += len(gold.links)
pbar.update(len(gold.links))
if limit and num_entities >= limit:
break
logger.info("Read {} entities in {} articles".format(num_entities, len(data)))
return data
def _get_gold_parse(doc, entities, dev, kb, labels_discard):
gold_entities = {}
tagged_ent_positions = {
(ent.start_char, ent.end_char): ent
for ent in doc.ents
if ent.label_ not in labels_discard
}
for entity in entities:
entity_id = entity["entity"]
alias = entity["alias"]
start = entity["start"]
end = entity["end"]
candidate_ids = []
if kb and not dev:
candidates = kb.get_candidates(alias)
candidate_ids = [cand.entity_ for cand in candidates]
tagged_ent = tagged_ent_positions.get((start, end), None)
if tagged_ent:
# TODO: check that alias == doc.text[start:end]
should_add_ent = (dev or entity_id in candidate_ids) and is_valid_sentence(
tagged_ent.sent.text
)
if should_add_ent:
value_by_id = {entity_id: 1.0}
if not dev:
random.shuffle(candidate_ids)
value_by_id.update(
{kb_id: 0.0 for kb_id in candidate_ids if kb_id != entity_id}
)
gold_entities[(start, end)] = value_by_id
return GoldParse(doc, links=gold_entities)
def is_dev(article_id):
if not article_id:
return False
return article_id.endswith("3")
def is_valid_article(doc_text):
# custom length cut-off
return 10 < len(doc_text) < 30000
def is_valid_sentence(sent_text):
if not 10 < len(sent_text) < 3000:
# custom length cut-off
return False
if sent_text.strip().startswith("*") or sent_text.strip().startswith("#"):
# remove 'enumeration' sentences (occurs often on Wikipedia)
return False
return True

View File

@ -246,7 +246,7 @@ def make_update(model, docs, optimizer, drop=0.0, objective="L2"):
"""Perform an update over a single batch of documents. """Perform an update over a single batch of documents.
docs (iterable): A batch of `Doc` objects. docs (iterable): A batch of `Doc` objects.
drop (float): The droput rate. drop (float): The dropout rate.
optimizer (callable): An optimizer. optimizer (callable): An optimizer.
RETURNS loss: A float for the loss. RETURNS loss: A float for the loss.
""" """

View File

@ -80,8 +80,8 @@ class Warnings(object):
"the v2.x models cannot release the global interpreter lock. " "the v2.x models cannot release the global interpreter lock. "
"Future versions may introduce a `n_process` argument for " "Future versions may introduce a `n_process` argument for "
"parallel inference via multiprocessing.") "parallel inference via multiprocessing.")
W017 = ("Alias '{alias}' already exists in the Knowledge base.") W017 = ("Alias '{alias}' already exists in the Knowledge Base.")
W018 = ("Entity '{entity}' already exists in the Knowledge base.") W018 = ("Entity '{entity}' already exists in the Knowledge Base.")
W019 = ("Changing vectors name from {old} to {new}, to avoid clash with " W019 = ("Changing vectors name from {old} to {new}, to avoid clash with "
"previously loaded vectors. See Issue #3853.") "previously loaded vectors. See Issue #3853.")
W020 = ("Unnamed vectors. This won't allow multiple vectors models to be " W020 = ("Unnamed vectors. This won't allow multiple vectors models to be "
@ -96,6 +96,8 @@ class Warnings(object):
"If this is surprising, make sure you have the spacy-lookups-data " "If this is surprising, make sure you have the spacy-lookups-data "
"package installed.") "package installed.")
W023 = ("Multiprocessing of Language.pipe is not supported in Python2. 'n_process' will be set to 1.") W023 = ("Multiprocessing of Language.pipe is not supported in Python2. 'n_process' will be set to 1.")
W024 = ("Entity '{entity}' - Alias '{alias}' combination already exists in "
"the Knowledge Base.")
@add_codes @add_codes
@ -408,7 +410,7 @@ class Errors(object):
"{probabilities_length} respectively.") "{probabilities_length} respectively.")
E133 = ("The sum of prior probabilities for alias '{alias}' should not " E133 = ("The sum of prior probabilities for alias '{alias}' should not "
"exceed 1, but found {sum}.") "exceed 1, but found {sum}.")
E134 = ("Alias '{alias}' defined for unknown entity '{entity}'.") E134 = ("Entity '{entity}' is not defined in the Knowledge Base.")
E135 = ("If you meant to replace a built-in component, use `create_pipe`: " E135 = ("If you meant to replace a built-in component, use `create_pipe`: "
"`nlp.replace_pipe('{name}', nlp.create_pipe('{name}'))`") "`nlp.replace_pipe('{name}', nlp.create_pipe('{name}'))`")
E136 = ("This additional feature requires the jsonschema library to be " E136 = ("This additional feature requires the jsonschema library to be "
@ -420,7 +422,7 @@ class Errors(object):
E138 = ("Invalid JSONL format for raw text '{text}'. Make sure the input " E138 = ("Invalid JSONL format for raw text '{text}'. Make sure the input "
"includes either the `text` or `tokens` key. For more info, see " "includes either the `text` or `tokens` key. For more info, see "
"the docs:\nhttps://spacy.io/api/cli#pretrain-jsonl") "the docs:\nhttps://spacy.io/api/cli#pretrain-jsonl")
E139 = ("Knowledge base for component '{name}' not initialized. Did you " E139 = ("Knowledge Base for component '{name}' not initialized. Did you "
"forget to call set_kb()?") "forget to call set_kb()?")
E140 = ("The list of entities, prior probabilities and entity vectors " E140 = ("The list of entities, prior probabilities and entity vectors "
"should be of equal length.") "should be of equal length.")
@ -499,6 +501,7 @@ class Errors(object):
E174 = ("Architecture '{name}' not found in registry. Available " E174 = ("Architecture '{name}' not found in registry. Available "
"names: {names}") "names: {names}")
E175 = ("Can't remove rule for unknown match pattern ID: {key}") E175 = ("Can't remove rule for unknown match pattern ID: {key}")
E176 = ("Alias '{alias}' is not defined in the Knowledge Base.")
@add_codes @add_codes

View File

@ -142,6 +142,7 @@ cdef class KnowledgeBase:
i = 0 i = 0
cdef KBEntryC entry cdef KBEntryC entry
cdef hash_t entity_hash
while i < nr_entities: while i < nr_entities:
entity_vector = vector_list[i] entity_vector = vector_list[i]
if len(entity_vector) != self.entity_vector_length: if len(entity_vector) != self.entity_vector_length:
@ -161,6 +162,14 @@ cdef class KnowledgeBase:
i += 1 i += 1
def contains_entity(self, unicode entity):
cdef hash_t entity_hash = self.vocab.strings.add(entity)
return entity_hash in self._entry_index
def contains_alias(self, unicode alias):
cdef hash_t alias_hash = self.vocab.strings.add(alias)
return alias_hash in self._alias_index
def add_alias(self, unicode alias, entities, probabilities): def add_alias(self, unicode alias, entities, probabilities):
""" """
For a given alias, add its potential entities and prior probabilies to the KB. For a given alias, add its potential entities and prior probabilies to the KB.
@ -190,7 +199,7 @@ cdef class KnowledgeBase:
for entity, prob in zip(entities, probabilities): for entity, prob in zip(entities, probabilities):
entity_hash = self.vocab.strings[entity] entity_hash = self.vocab.strings[entity]
if not entity_hash in self._entry_index: if not entity_hash in self._entry_index:
raise ValueError(Errors.E134.format(alias=alias, entity=entity)) raise ValueError(Errors.E134.format(entity=entity))
entry_index = <int64_t>self._entry_index.get(entity_hash) entry_index = <int64_t>self._entry_index.get(entity_hash)
entry_indices.push_back(int(entry_index)) entry_indices.push_back(int(entry_index))
@ -201,8 +210,63 @@ cdef class KnowledgeBase:
return alias_hash return alias_hash
def get_candidates(self, unicode alias): def append_alias(self, unicode alias, unicode entity, float prior_prob, ignore_warnings=False):
"""
For an alias already existing in the KB, extend its potential entities with one more.
Throw a warning if either the alias or the entity is unknown,
or when the combination is already previously recorded.
Throw an error if this entity+prior prob would exceed the sum of 1.
For efficiency, it's best to use the method `add_alias` as much as possible instead of this one.
"""
# Check if the alias exists in the KB
cdef hash_t alias_hash = self.vocab.strings[alias] cdef hash_t alias_hash = self.vocab.strings[alias]
if not alias_hash in self._alias_index:
raise ValueError(Errors.E176.format(alias=alias))
# Check if the entity exists in the KB
cdef hash_t entity_hash = self.vocab.strings[entity]
if not entity_hash in self._entry_index:
raise ValueError(Errors.E134.format(entity=entity))
entry_index = <int64_t>self._entry_index.get(entity_hash)
# Throw an error if the prior probabilities (including the new one) sum up to more than 1
alias_index = <int64_t>self._alias_index.get(alias_hash)
alias_entry = self._aliases_table[alias_index]
current_sum = sum([p for p in alias_entry.probs])
new_sum = current_sum + prior_prob
if new_sum > 1.00001:
raise ValueError(Errors.E133.format(alias=alias, sum=new_sum))
entry_indices = alias_entry.entry_indices
is_present = False
for i in range(entry_indices.size()):
if entry_indices[i] == int(entry_index):
is_present = True
if is_present:
if not ignore_warnings:
user_warning(Warnings.W024.format(entity=entity, alias=alias))
else:
entry_indices.push_back(int(entry_index))
alias_entry.entry_indices = entry_indices
probs = alias_entry.probs
probs.push_back(float(prior_prob))
alias_entry.probs = probs
self._aliases_table[alias_index] = alias_entry
def get_candidates(self, unicode alias):
"""
Return candidate entities for an alias. Each candidate defines the entity, the original alias,
and the prior probability of that alias resolving to that entity.
If the alias is not known in the KB, and empty list is returned.
"""
cdef hash_t alias_hash = self.vocab.strings[alias]
if not alias_hash in self._alias_index:
return []
alias_index = <int64_t>self._alias_index.get(alias_hash) alias_index = <int64_t>self._alias_index.get(alias_hash)
alias_entry = self._aliases_table[alias_index] alias_entry = self._aliases_table[alias_index]
@ -341,7 +405,6 @@ cdef class KnowledgeBase:
assert nr_entities == self.get_size_entities() assert nr_entities == self.get_size_entities()
# STEP 3: load aliases # STEP 3: load aliases
cdef int64_t nr_aliases cdef int64_t nr_aliases
reader.read_alias_length(&nr_aliases) reader.read_alias_length(&nr_aliases)
self._alias_index = PreshMap(nr_aliases+1) self._alias_index = PreshMap(nr_aliases+1)

View File

@ -483,7 +483,7 @@ class Language(object):
docs (iterable): A batch of `Doc` objects. docs (iterable): A batch of `Doc` objects.
golds (iterable): A batch of `GoldParse` objects. golds (iterable): A batch of `GoldParse` objects.
drop (float): The droput rate. drop (float): The dropout rate.
sgd (callable): An optimizer. sgd (callable): An optimizer.
losses (dict): Dictionary to update with the loss, keyed by component. losses (dict): Dictionary to update with the loss, keyed by component.
component_cfg (dict): Config parameters for specific pipeline component_cfg (dict): Config parameters for specific pipeline
@ -531,7 +531,7 @@ class Language(object):
even if you're updating it with a smaller set of examples. even if you're updating it with a smaller set of examples.
docs (iterable): A batch of `Doc` objects. docs (iterable): A batch of `Doc` objects.
drop (float): The droput rate. drop (float): The dropout rate.
sgd (callable): An optimizer. sgd (callable): An optimizer.
RETURNS (dict): Results from the update. RETURNS (dict): Results from the update.

View File

@ -1195,23 +1195,26 @@ class EntityLinker(Pipe):
docs = [docs] docs = [docs]
golds = [golds] golds = [golds]
context_docs = [] sentence_docs = []
for doc, gold in zip(docs, golds): for doc, gold in zip(docs, golds):
ents_by_offset = dict() ents_by_offset = dict()
for ent in doc.ents: for ent in doc.ents:
ents_by_offset["{}_{}".format(ent.start_char, ent.end_char)] = ent ents_by_offset[(ent.start_char, ent.end_char)] = ent
for entity, kb_dict in gold.links.items(): for entity, kb_dict in gold.links.items():
start, end = entity start, end = entity
mention = doc.text[start:end] mention = doc.text[start:end]
# the gold annotations should link to proper entities - if this fails, the dataset is likely corrupt
ent = ents_by_offset[(start, end)]
for kb_id, value in kb_dict.items(): for kb_id, value in kb_dict.items():
# Currently only training on the positive instances # Currently only training on the positive instances
if value: if value:
context_docs.append(doc) sentence_docs.append(ent.sent.as_doc())
context_encodings, bp_context = self.model.begin_update(context_docs, drop=drop) sentence_encodings, bp_context = self.model.begin_update(sentence_docs, drop=drop)
loss, d_scores = self.get_similarity_loss(scores=context_encodings, golds=golds, docs=None) loss, d_scores = self.get_similarity_loss(scores=sentence_encodings, golds=golds, docs=None)
bp_context(d_scores, sgd=sgd) bp_context(d_scores, sgd=sgd)
if losses is not None: if losses is not None:
@ -1280,50 +1283,68 @@ class EntityLinker(Pipe):
if isinstance(docs, Doc): if isinstance(docs, Doc):
docs = [docs] docs = [docs]
context_encodings = self.model(docs)
xp = get_array_module(context_encodings)
for i, doc in enumerate(docs): for i, doc in enumerate(docs):
if len(doc) > 0: if len(doc) > 0:
# currently, the context is the same for each entity in a sentence (should be refined) # Looping through each sentence and each entity
context_encoding = context_encodings[i] # This may go wrong if there are entities across sentences - because they might not get a KB ID
context_enc_t = context_encoding.T for sent in doc.ents:
norm_1 = xp.linalg.norm(context_enc_t) sent_doc = sent.as_doc()
for ent in doc.ents: # currently, the context is the same for each entity in a sentence (should be refined)
entity_count += 1 sentence_encoding = self.model([sent_doc])[0]
xp = get_array_module(sentence_encoding)
sentence_encoding_t = sentence_encoding.T
sentence_norm = xp.linalg.norm(sentence_encoding_t)
candidates = self.kb.get_candidates(ent.text) for ent in sent_doc.ents:
if not candidates: entity_count += 1
final_kb_ids.append(self.NIL) # no prediction possible for this entity
final_tensors.append(context_encoding)
else:
random.shuffle(candidates)
# this will set all prior probabilities to 0 if they should be excluded from the model if ent.label_ in self.cfg.get("labels_discard", []):
prior_probs = xp.asarray([c.prior_prob for c in candidates]) # ignoring this entity - setting to NIL
if not self.cfg.get("incl_prior", True): final_kb_ids.append(self.NIL)
prior_probs = xp.asarray([0.0 for c in candidates]) final_tensors.append(sentence_encoding)
scores = prior_probs
# add in similarity from the context else:
if self.cfg.get("incl_context", True): candidates = self.kb.get_candidates(ent.text)
entity_encodings = xp.asarray([c.entity_vector for c in candidates]) if not candidates:
norm_2 = xp.linalg.norm(entity_encodings, axis=1) # no prediction possible for this entity - setting to NIL
final_kb_ids.append(self.NIL)
final_tensors.append(sentence_encoding)
if len(entity_encodings) != len(prior_probs): elif len(candidates) == 1:
raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length")) # shortcut for efficiency reasons: take the 1 candidate
# cosine similarity # TODO: thresholding
sims = xp.dot(entity_encodings, context_enc_t) / (norm_1 * norm_2) final_kb_ids.append(candidates[0].entity_)
if sims.shape != prior_probs.shape: final_tensors.append(sentence_encoding)
raise ValueError(Errors.E161)
scores = prior_probs + sims - (prior_probs*sims)
# TODO: thresholding else:
best_index = scores.argmax() random.shuffle(candidates)
best_candidate = candidates[best_index]
final_kb_ids.append(best_candidate.entity_) # this will set all prior probabilities to 0 if they should be excluded from the model
final_tensors.append(context_encoding) prior_probs = xp.asarray([c.prior_prob for c in candidates])
if not self.cfg.get("incl_prior", True):
prior_probs = xp.asarray([0.0 for c in candidates])
scores = prior_probs
# add in similarity from the context
if self.cfg.get("incl_context", True):
entity_encodings = xp.asarray([c.entity_vector for c in candidates])
entity_norm = xp.linalg.norm(entity_encodings, axis=1)
if len(entity_encodings) != len(prior_probs):
raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length"))
# cosine similarity
sims = xp.dot(entity_encodings, sentence_encoding_t) / (sentence_norm * entity_norm)
if sims.shape != prior_probs.shape:
raise ValueError(Errors.E161)
scores = prior_probs + sims - (prior_probs*sims)
# TODO: thresholding
best_index = scores.argmax()
best_candidate = candidates[best_index]
final_kb_ids.append(best_candidate.entity_)
final_tensors.append(sentence_encoding)
if not (len(final_tensors) == len(final_kb_ids) == entity_count): if not (len(final_tensors) == len(final_kb_ids) == entity_count):
raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length")) raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length"))

View File

@ -131,6 +131,53 @@ def test_candidate_generation(nlp):
assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9) assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9)
def test_append_alias(nlp):
"""Test that we can append additional alias-entity pairs"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
mykb.add_entity(entity="Q2", freq=12, entity_vector=[2])
mykb.add_entity(entity="Q3", freq=5, entity_vector=[3])
# adding aliases
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.4, 0.1])
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
# test the size of the relevant candidates
assert len(mykb.get_candidates("douglas")) == 2
# append an alias
mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.2)
# test the size of the relevant candidates has been incremented
assert len(mykb.get_candidates("douglas")) == 3
# append the same alias-entity pair again should not work (will throw a warning)
mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.3)
# test the size of the relevant candidates remained unchanged
assert len(mykb.get_candidates("douglas")) == 3
def test_append_invalid_alias(nlp):
"""Test that append an alias will throw an error if prior probs are exceeding 1"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
# adding entities
mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
mykb.add_entity(entity="Q2", freq=12, entity_vector=[2])
mykb.add_entity(entity="Q3", freq=5, entity_vector=[3])
# adding aliases
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1])
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
# append an alias - should fail because the entities and probabilities vectors are not of equal length
with pytest.raises(ValueError):
mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.2)
def test_preserving_links_asdoc(nlp): def test_preserving_links_asdoc(nlp):
"""Test that Span.as_doc preserves the existing entity links""" """Test that Span.as_doc preserves the existing entity links"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)

View File

@ -430,7 +430,7 @@ def test_issue957(en_tokenizer):
def test_issue999(train_data): def test_issue999(train_data):
"""Test that adding entities and resuming training works passably OK. """Test that adding entities and resuming training works passably OK.
There are two issues here: There are two issues here:
1) We have to read labels. This isn't very nice. 1) We have to re-add labels. This isn't very nice.
2) There's no way to set the learning rate for the weight update, so we 2) There's no way to set the learning rate for the weight update, so we
end up out-of-scale, causing it to learn too fast. end up out-of-scale, causing it to learn too fast.
""" """