diff --git a/bin/wiki_entity_linking/README.md b/bin/wiki_entity_linking/README.md
new file mode 100644
index 000000000..540878592
--- /dev/null
+++ b/bin/wiki_entity_linking/README.md
@@ -0,0 +1,34 @@
+## Entity Linking with Wikipedia and Wikidata
+
+### Step 1: Create a Knowledge Base (KB) and training data
+
+Run `wikidata_pretrain_kb.py`
+* This takes as input the locations of a **Wikipedia and a Wikidata dump**, and produces a **KB directory** + **training file** (an example invocation is shown at the end of this step)
+ * WikiData: get `latest-all.json.bz2` from https://dumps.wikimedia.org/wikidatawiki/entities/
+ * Wikipedia: get `enwiki-latest-pages-articles-multistream.xml.bz2` from https://dumps.wikimedia.org/enwiki/latest/ (or for any other language)
+* You can set the filtering parameters for KB construction:
+ * `max_per_alias`: maximum number of candidate entities kept in the KB per alias/synonym
+ * `min_freq`: minimum number of times an entity should occur in the corpus to be included in the KB
+ * `min_pair`: minimum number of times an entity+alias combination should occur in the corpus to be included in the KB
+* Further parameters to set:
+ * `descr_from_wp`: whether to parse descriptions from Wikipedia (`True`) or Wikidata (`False`)
+ * `entity_vector_length`: length of the pre-trained entity description vectors
+ * `lang`: language for which to fetch Wikidata information (as the dump contains all languages)
+
+Quick testing and rerunning:
+* When trying out the pipeline for a quick test, set `limit_prior`, `limit_train` and/or `limit_wd` to read only parts of the dumps instead of everything.
+* If you only want to (re)run certain parts of the pipeline, just remove the corresponding files and they will be recalculated or reparsed.
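+
+For example, assuming the command is run from the root of the spaCy checkout, and using placeholder dump locations, output directory, limit values and model name (the model should have pretrained vectors), a quick test run could look like this:
+
+```
+python bin/wiki_entity_linking/wikidata_pretrain_kb.py latest-all.json.bz2 enwiki-latest-pages-articles-multistream.xml.bz2 output/wiki_kb/ en_core_web_lg -lp 1000000 -lt 20000 -lw 3000000
+```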
+
+
+### Step 2: Train an Entity Linking model
+
+Run `wikidata_train_entity_linker.py`
+* This takes the **KB directory** produced by Step 1, and trains an **Entity Linking model** (an example invocation is shown at the end of this step)
+* You can set the learning parameters for the EL training:
+ * `epochs`: number of training iterations
+ * `dropout`: dropout rate
+ * `lr`: learning rate
+ * `l2`: L2 regularization
+* Specify the number of training and dev testing entities with `train_inst` and `dev_inst` respectively
+* Further parameters to set:
+ * `labels_discard`: NER label types to discard during training (comma-separated)
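+
+For example, assuming the KB directory from Step 1 was written to `output/wiki_kb/` and using placeholder instance counts, a training run could be started like this:
+
+```
+python bin/wiki_entity_linking/wikidata_train_entity_linker.py output/wiki_kb/ -t 50000 -d 5000
+```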
diff --git a/bin/wiki_entity_linking/__init__.py b/bin/wiki_entity_linking/__init__.py
index a604bcc2f..de486bbcf 100644
--- a/bin/wiki_entity_linking/__init__.py
+++ b/bin/wiki_entity_linking/__init__.py
@@ -6,6 +6,7 @@ OUTPUT_MODEL_DIR = "nlp"
PRIOR_PROB_PATH = "prior_prob.csv"
ENTITY_DEFS_PATH = "entity_defs.csv"
ENTITY_FREQ_PATH = "entity_freq.csv"
+ENTITY_ALIAS_PATH = "entity_alias.csv"
ENTITY_DESCR_PATH = "entity_descriptions.csv"
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
diff --git a/bin/wiki_entity_linking/entity_linker_evaluation.py b/bin/wiki_entity_linking/entity_linker_evaluation.py
index 1b1200564..94bafbf30 100644
--- a/bin/wiki_entity_linking/entity_linker_evaluation.py
+++ b/bin/wiki_entity_linking/entity_linker_evaluation.py
@@ -15,10 +15,11 @@ class Metrics(object):
candidate_is_correct = true_entity == candidate
# Assume that we have no labeled negatives in the data (i.e. cases where true_entity is "NIL")
- # Therefore, if candidate_is_correct then we have a true positive and never a true negative
+ # Therefore, if candidate_is_correct then we have a true positive and never a true negative.
self.true_pos += candidate_is_correct
self.false_neg += not candidate_is_correct
- if candidate not in {"", "NIL"}:
+ if candidate and candidate not in {"", "NIL"}:
+ # A wrong prediction (e.g. Q42 != Q3) counts both as a FP as well as a FN.
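+            # An empty or "NIL" prediction, by contrast, only counts as a FN.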
self.false_pos += not candidate_is_correct
def calculate_precision(self):
@@ -33,6 +34,14 @@ class Metrics(object):
else:
return self.true_pos / (self.true_pos + self.false_neg)
+ def calculate_fscore(self):
+ p = self.calculate_precision()
+ r = self.calculate_recall()
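+        # harmonic mean of precision and recall, e.g. p=0.8 and r=0.6 give a score of about 0.686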
+ if p + r == 0:
+ return 0.0
+ else:
+ return 2 * p * r / (p + r)
+
class EvaluationResults(object):
def __init__(self):
@@ -43,18 +52,20 @@ class EvaluationResults(object):
self.metrics.update_results(true_entity, candidate)
self.metrics_by_label[ent_label].update_results(true_entity, candidate)
- def increment_false_negatives(self):
- self.metrics.false_neg += 1
-
def report_metrics(self, model_name):
model_str = model_name.title()
recall = self.metrics.calculate_recall()
precision = self.metrics.calculate_precision()
- return ("{}: ".format(model_str) +
- "Recall = {} | ".format(round(recall, 3)) +
- "Precision = {} | ".format(round(precision, 3)) +
- "Precision by label = {}".format({k: v.calculate_precision()
- for k, v in self.metrics_by_label.items()}))
+ fscore = self.metrics.calculate_fscore()
+ return (
+ "{}: ".format(model_str)
+ + "F-score = {} | ".format(round(fscore, 3))
+ + "Recall = {} | ".format(round(recall, 3))
+ + "Precision = {} | ".format(round(precision, 3))
+ + "F-score by label = {}".format(
+ {k: v.calculate_fscore() for k, v in sorted(self.metrics_by_label.items())}
+ )
+ )
class BaselineResults(object):
@@ -63,40 +74,51 @@ class BaselineResults(object):
self.prior = EvaluationResults()
self.oracle = EvaluationResults()
- def report_accuracy(self, model):
+ def report_performance(self, model):
results = getattr(self, model)
return results.report_metrics(model)
- def update_baselines(self, true_entity, ent_label, random_candidate, prior_candidate, oracle_candidate):
+ def update_baselines(
+ self,
+ true_entity,
+ ent_label,
+ random_candidate,
+ prior_candidate,
+ oracle_candidate,
+ ):
self.oracle.update_metrics(ent_label, true_entity, oracle_candidate)
self.prior.update_metrics(ent_label, true_entity, prior_candidate)
self.random.update_metrics(ent_label, true_entity, random_candidate)
-def measure_performance(dev_data, kb, el_pipe):
- baseline_accuracies = measure_baselines(
- dev_data, kb
- )
+def measure_performance(dev_data, kb, el_pipe, baseline=True, context=True):
+ if baseline:
+ baseline_accuracies, counts = measure_baselines(dev_data, kb)
+ logger.info("Counts: {}".format({k: v for k, v in sorted(counts.items())}))
+ logger.info(baseline_accuracies.report_performance("random"))
+ logger.info(baseline_accuracies.report_performance("prior"))
+ logger.info(baseline_accuracies.report_performance("oracle"))
- logger.info(baseline_accuracies.report_accuracy("random"))
- logger.info(baseline_accuracies.report_accuracy("prior"))
- logger.info(baseline_accuracies.report_accuracy("oracle"))
+ if context:
+ # using only context
+ el_pipe.cfg["incl_context"] = True
+ el_pipe.cfg["incl_prior"] = False
+ results = get_eval_results(dev_data, el_pipe)
+ logger.info(results.report_metrics("context only"))
- # using only context
- el_pipe.cfg["incl_context"] = True
- el_pipe.cfg["incl_prior"] = False
- results = get_eval_results(dev_data, el_pipe)
- logger.info(results.report_metrics("context only"))
-
- # measuring combined accuracy (prior + context)
- el_pipe.cfg["incl_context"] = True
- el_pipe.cfg["incl_prior"] = True
- results = get_eval_results(dev_data, el_pipe)
- logger.info(results.report_metrics("context and prior"))
+ # measuring combined accuracy (prior + context)
+ el_pipe.cfg["incl_context"] = True
+ el_pipe.cfg["incl_prior"] = True
+ results = get_eval_results(dev_data, el_pipe)
+ logger.info(results.report_metrics("context and prior"))
def get_eval_results(data, el_pipe=None):
- # If the docs in the data require further processing with an entity linker, set el_pipe
+ """
+ Evaluate the ent.kb_id_ annotations against the gold standard.
+ Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
+ If the docs in the data require further processing with an entity linker, set el_pipe.
+ """
from tqdm import tqdm
docs = []
@@ -111,18 +133,15 @@ def get_eval_results(data, el_pipe=None):
results = EvaluationResults()
for doc, gold in zip(docs, golds):
- tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
try:
correct_entries_per_article = dict()
for entity, kb_dict in gold.links.items():
start, end = entity
- # only evaluating on positive examples
for gold_kb, value in kb_dict.items():
if value:
+ # only evaluating on positive examples
offset = _offset(start, end)
correct_entries_per_article[offset] = gold_kb
- if offset not in tagged_entries_per_article:
- results.increment_false_negatives()
for ent in doc.ents:
ent_label = ent.label_
@@ -142,7 +161,11 @@ def get_eval_results(data, el_pipe=None):
def measure_baselines(data, kb):
- # Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
+ """
+ Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound.
+ Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
+ Also return a dictionary of counts by entity label.
+ """
counts_d = dict()
baseline_results = BaselineResults()
@@ -152,7 +175,6 @@ def measure_baselines(data, kb):
for doc, gold in zip(docs, golds):
correct_entries_per_article = dict()
- tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
for entity, kb_dict in gold.links.items():
start, end = entity
for gold_kb, value in kb_dict.items():
@@ -160,10 +182,6 @@ def measure_baselines(data, kb):
if value:
offset = _offset(start, end)
correct_entries_per_article[offset] = gold_kb
- if offset not in tagged_entries_per_article:
- baseline_results.random.increment_false_negatives()
- baseline_results.oracle.increment_false_negatives()
- baseline_results.prior.increment_false_negatives()
for ent in doc.ents:
ent_label = ent.label_
@@ -176,7 +194,7 @@ def measure_baselines(data, kb):
if gold_entity is not None:
candidates = kb.get_candidates(ent.text)
oracle_candidate = ""
- best_candidate = ""
+ prior_candidate = ""
random_candidate = ""
if candidates:
scores = []
@@ -187,13 +205,21 @@ def measure_baselines(data, kb):
oracle_candidate = c.entity_
best_index = scores.index(max(scores))
- best_candidate = candidates[best_index].entity_
+ prior_candidate = candidates[best_index].entity_
random_candidate = random.choice(candidates).entity_
- baseline_results.update_baselines(gold_entity, ent_label,
- random_candidate, best_candidate, oracle_candidate)
+ current_count = counts_d.get(ent_label, 0)
+            counts_d[ent_label] = current_count + 1
- return baseline_results
+ baseline_results.update_baselines(
+ gold_entity,
+ ent_label,
+ random_candidate,
+ prior_candidate,
+ oracle_candidate,
+ )
+
+ return baseline_results, counts_d
def _offset(start, end):
diff --git a/bin/wiki_entity_linking/kb_creator.py b/bin/wiki_entity_linking/kb_creator.py
index 0eeb63803..7778fc701 100644
--- a/bin/wiki_entity_linking/kb_creator.py
+++ b/bin/wiki_entity_linking/kb_creator.py
@@ -1,17 +1,12 @@
# coding: utf-8
from __future__ import unicode_literals
-import csv
import logging
-import spacy
-import sys
from spacy.kb import KnowledgeBase
-from bin.wiki_entity_linking import wikipedia_processor as wp
from bin.wiki_entity_linking.train_descriptions import EntityEncoder
-
-csv.field_size_limit(sys.maxsize)
+from bin.wiki_entity_linking import wiki_io as io
logger = logging.getLogger(__name__)
@@ -22,18 +17,24 @@ def create_kb(
max_entities_per_alias,
min_entity_freq,
min_occ,
- entity_def_input,
+ entity_def_path,
entity_descr_path,
- count_input,
- prior_prob_input,
+ entity_alias_path,
+ entity_freq_path,
+ prior_prob_path,
entity_vector_length,
):
# Create the knowledge base from Wikidata entries
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=entity_vector_length)
+ entity_list, filtered_title_to_id = _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_freq, entity_freq_path, entity_vector_length)
+ _define_aliases(kb, entity_alias_path, entity_list, filtered_title_to_id, max_entities_per_alias, min_occ, prior_prob_path)
+ return kb
+
+def _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_freq, entity_freq_path, entity_vector_length):
# read the mappings from file
- title_to_id = get_entity_to_id(entity_def_input)
- id_to_descr = get_id_to_description(entity_descr_path)
+ title_to_id = io.read_title_to_id(entity_def_path)
+ id_to_descr = io.read_id_to_descr(entity_descr_path)
# check the length of the nlp vectors
if "vectors" in nlp.meta and nlp.vocab.vectors.size:
@@ -45,10 +46,8 @@ def create_kb(
" cf. https://spacy.io/usage/models#languages."
)
- logger.info("Get entity frequencies")
- entity_frequencies = wp.get_all_frequencies(count_input=count_input)
-
logger.info("Filtering entities with fewer than {} mentions".format(min_entity_freq))
+ entity_frequencies = io.read_entity_to_count(entity_freq_path)
# filter the entities for in the KB by frequency, because there's just too much data (8M entities) otherwise
filtered_title_to_id, entity_list, description_list, frequency_list = get_filtered_entities(
title_to_id,
@@ -56,36 +55,33 @@ def create_kb(
entity_frequencies,
min_entity_freq
)
- logger.info("Left with {} entities".format(len(description_list)))
+ logger.info("Kept {} entities from the set of {}".format(len(description_list), len(title_to_id.keys())))
- logger.info("Train entity encoder")
+ logger.info("Training entity encoder")
encoder = EntityEncoder(nlp, input_dim, entity_vector_length)
encoder.train(description_list=description_list, to_print=True)
- logger.info("Get entity embeddings:")
+ logger.info("Getting entity embeddings")
embeddings = encoder.apply_encoder(description_list)
logger.info("Adding {} entities".format(len(entity_list)))
kb.set_entities(
entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
)
+ return entity_list, filtered_title_to_id
- logger.info("Adding aliases")
+
+def _define_aliases(kb, entity_alias_path, entity_list, filtered_title_to_id, max_entities_per_alias, min_occ, prior_prob_path):
+ logger.info("Adding aliases from Wikipedia and Wikidata")
_add_aliases(
kb,
+ entity_list=entity_list,
title_to_id=filtered_title_to_id,
max_entities_per_alias=max_entities_per_alias,
min_occ=min_occ,
- prior_prob_input=prior_prob_input,
+ prior_prob_path=prior_prob_path,
)
- logger.info("KB size: {} entities, {} aliases".format(
- kb.get_size_entities(),
- kb.get_size_aliases()))
-
- logger.info("Done with kb")
- return kb
-
def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies,
min_entity_freq: int = 10):
@@ -104,34 +100,13 @@ def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies,
return filtered_title_to_id, entity_list, description_list, frequency_list
-def get_entity_to_id(entity_def_output):
- entity_to_id = dict()
- with entity_def_output.open("r", encoding="utf8") as csvfile:
- csvreader = csv.reader(csvfile, delimiter="|")
- # skip header
- next(csvreader)
- for row in csvreader:
- entity_to_id[row[0]] = row[1]
- return entity_to_id
-
-
-def get_id_to_description(entity_descr_path):
- id_to_desc = dict()
- with entity_descr_path.open("r", encoding="utf8") as csvfile:
- csvreader = csv.reader(csvfile, delimiter="|")
- # skip header
- next(csvreader)
- for row in csvreader:
- id_to_desc[row[0]] = row[1]
- return id_to_desc
-
-
-def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input):
+def _add_aliases(kb, entity_list, title_to_id, max_entities_per_alias, min_occ, prior_prob_path):
wp_titles = title_to_id.keys()
# adding aliases with prior probabilities
# we can read this file sequentially, it's sorted by alias, and then by count
- with prior_prob_input.open("r", encoding="utf8") as prior_file:
+ logger.info("Adding WP aliases")
+ with prior_prob_path.open("r", encoding="utf8") as prior_file:
# skip header
prior_file.readline()
line = prior_file.readline()
@@ -180,10 +155,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
line = prior_file.readline()
-def read_nlp_kb(model_dir, kb_file):
- nlp = spacy.load(model_dir)
+def read_kb(nlp, kb_file):
kb = KnowledgeBase(vocab=nlp.vocab)
kb.load_bulk(kb_file)
- logger.info("kb entities: {}".format(kb.get_size_entities()))
- logger.info("kb aliases: {}".format(kb.get_size_aliases()))
- return nlp, kb
+ return kb
diff --git a/bin/wiki_entity_linking/train_descriptions.py b/bin/wiki_entity_linking/train_descriptions.py
index 2cb66909f..af08d6b8f 100644
--- a/bin/wiki_entity_linking/train_descriptions.py
+++ b/bin/wiki_entity_linking/train_descriptions.py
@@ -53,7 +53,7 @@ class EntityEncoder:
start = start + batch_size
stop = min(stop + batch_size, len(description_list))
- logger.info("encoded: {} entities".format(stop))
+ logger.info("Encoded: {} entities".format(stop))
return encodings
@@ -62,7 +62,7 @@ class EntityEncoder:
if to_print:
logger.info(
"Trained entity descriptions on {} ".format(processed) +
- "(non-unique) entities across {} ".format(self.epochs) +
+ "(non-unique) descriptions across {} ".format(self.epochs) +
"epochs"
)
logger.info("Final loss: {}".format(loss))
diff --git a/bin/wiki_entity_linking/training_set_creator.py b/bin/wiki_entity_linking/training_set_creator.py
deleted file mode 100644
index 3f42f8bdd..000000000
--- a/bin/wiki_entity_linking/training_set_creator.py
+++ /dev/null
@@ -1,395 +0,0 @@
-# coding: utf-8
-from __future__ import unicode_literals
-
-import logging
-import random
-import re
-import bz2
-import json
-
-from functools import partial
-
-from spacy.gold import GoldParse
-from bin.wiki_entity_linking import kb_creator
-
-"""
-Process Wikipedia interlinks to generate a training dataset for the EL algorithm.
-Gold-standard entities are stored in one file in standoff format (by character offset).
-"""
-
-ENTITY_FILE = "gold_entities.csv"
-logger = logging.getLogger(__name__)
-
-
-def create_training_examples_and_descriptions(wikipedia_input,
- entity_def_input,
- description_output,
- training_output,
- parse_descriptions,
- limit=None):
- wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
- _process_wikipedia_texts(wikipedia_input,
- wp_to_id,
- description_output,
- training_output,
- parse_descriptions,
- limit)
-
-
-def _process_wikipedia_texts(wikipedia_input,
- wp_to_id,
- output,
- training_output,
- parse_descriptions,
- limit=None):
- """
- Read the XML wikipedia data to parse out training data:
- raw text data + positive instances
- """
-    title_regex = re.compile(r"(?<=<title>).*(?=</title>)")
-    id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
-
- read_ids = set()
-
- with output.open("a", encoding="utf8") as descr_file, training_output.open("w", encoding="utf8") as entity_file:
- if parse_descriptions:
- _write_training_description(descr_file, "WD_id", "description")
- with bz2.open(wikipedia_input, mode="rb") as file:
- article_count = 0
- article_text = ""
- article_title = None
- article_id = None
- reading_text = False
- reading_revision = False
-
- logger.info("Processed {} articles".format(article_count))
-
- for line in file:
- clean_line = line.strip().decode("utf-8")
-
-                if clean_line == "<revision>":
-                    reading_revision = True
-                elif clean_line == "</revision>":
-                    reading_revision = False
-
- # Start reading new page
-                if clean_line == "<page>":
- article_text = ""
- article_title = None
- article_id = None
- # finished reading this page
-                elif clean_line == "</page>":
- if article_id:
- clean_text, entities = _process_wp_text(
- article_title,
- article_text,
- wp_to_id
- )
- if clean_text is not None and entities is not None:
- _write_training_entities(entity_file,
- article_id,
- clean_text,
- entities)
-
- if article_title in wp_to_id and parse_descriptions:
- description = " ".join(clean_text[:1000].split(" ")[:-1])
- _write_training_description(
- descr_file,
- wp_to_id[article_title],
- description
- )
- article_count += 1
- if article_count % 10000 == 0:
- logger.info("Processed {} articles".format(article_count))
- if limit and article_count >= limit:
- break
- article_text = ""
- article_title = None
- article_id = None
- reading_text = False
- reading_revision = False
-
- # start reading text within a page
-            if "<text" in clean_line:
-    clean_text = clean_text.replace(r"&quot;", '"')
-    clean_text = clean_text.replace(r"&nbsp;", " ")
-    clean_text = clean_text.replace(r"&amp;", "&")
-
- # remove multiple spaces
-    while "  " in clean_text:
-        clean_text = clean_text.replace("  ", " ")
-
- return clean_text.strip()
-
-
-def _remove_links(clean_text, wp_to_id):
- # read the text char by char to get the right offsets for the interwiki links
- entities = []
- final_text = ""
- open_read = 0
- reading_text = True
- reading_entity = False
- reading_mention = False
- reading_special_case = False
- entity_buffer = ""
- mention_buffer = ""
- for index, letter in enumerate(clean_text):
- if letter == "[":
- open_read += 1
- elif letter == "]":
- open_read -= 1
- elif letter == "|":
- if reading_text:
- final_text += letter
- # switch from reading entity to mention in the [[entity|mention]] pattern
- elif reading_entity:
- reading_text = False
- reading_entity = False
- reading_mention = True
- else:
- reading_special_case = True
- else:
- if reading_entity:
- entity_buffer += letter
- elif reading_mention:
- mention_buffer += letter
- elif reading_text:
- final_text += letter
- else:
- raise ValueError("Not sure at point", clean_text[index - 2: index + 2])
-
- if open_read > 2:
- reading_special_case = True
-
- if open_read == 2 and reading_text:
- reading_text = False
- reading_entity = True
- reading_mention = False
-
- # we just finished reading an entity
- if open_read == 0 and not reading_text:
- if "#" in entity_buffer or entity_buffer.startswith(":"):
- reading_special_case = True
- # Ignore cases with nested structures like File: handles etc
- if not reading_special_case:
- if not mention_buffer:
- mention_buffer = entity_buffer
- start = len(final_text)
- end = start + len(mention_buffer)
- qid = wp_to_id.get(entity_buffer, None)
- if qid:
- entities.append((mention_buffer, qid, start, end))
- final_text += mention_buffer
-
- entity_buffer = ""
- mention_buffer = ""
-
- reading_text = True
- reading_entity = False
- reading_mention = False
- reading_special_case = False
- return final_text, entities
-
-
-def _write_training_description(outputfile, qid, description):
- if description is not None:
- line = str(qid) + "|" + description + "\n"
- outputfile.write(line)
-
-
-def _write_training_entities(outputfile, article_id, clean_text, entities):
- entities_data = [{"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]} for ent in entities]
- line = json.dumps(
- {
- "article_id": article_id,
- "clean_text": clean_text,
- "entities": entities_data
- },
- ensure_ascii=False) + "\n"
- outputfile.write(line)
-
-
-def read_training(nlp, entity_file_path, dev, limit, kb):
- """ This method provides training examples that correspond to the entity annotations found by the nlp object.
- For training,, it will include negative training examples by using the candidate generator,
- and it will only keep positive training examples that can be found by using the candidate generator.
- For testing, it will include all positive examples only."""
-
- from tqdm import tqdm
- data = []
- num_entities = 0
- get_gold_parse = partial(_get_gold_parse, dev=dev, kb=kb)
-
- logger.info("Reading {} data with limit {}".format('dev' if dev else 'train', limit))
- with entity_file_path.open("r", encoding="utf8") as file:
- with tqdm(total=limit, leave=False) as pbar:
- for i, line in enumerate(file):
- example = json.loads(line)
- article_id = example["article_id"]
- clean_text = example["clean_text"]
- entities = example["entities"]
-
- if dev != is_dev(article_id) or len(clean_text) >= 30000:
- continue
-
- doc = nlp(clean_text)
- gold = get_gold_parse(doc, entities)
- if gold and len(gold.links) > 0:
- data.append((doc, gold))
- num_entities += len(gold.links)
- pbar.update(len(gold.links))
- if limit and num_entities >= limit:
- break
- logger.info("Read {} entities in {} articles".format(num_entities, len(data)))
- return data
-
-
-def _get_gold_parse(doc, entities, dev, kb):
- gold_entities = {}
- tagged_ent_positions = set(
- [(ent.start_char, ent.end_char) for ent in doc.ents]
- )
-
- for entity in entities:
- entity_id = entity["entity"]
- alias = entity["alias"]
- start = entity["start"]
- end = entity["end"]
-
- candidates = kb.get_candidates(alias)
- candidate_ids = [
- c.entity_ for c in candidates
- ]
-
- should_add_ent = (
- dev or
- (
- (start, end) in tagged_ent_positions and
- entity_id in candidate_ids and
- len(candidates) > 1
- )
- )
-
- if should_add_ent:
- value_by_id = {entity_id: 1.0}
- if not dev:
- random.shuffle(candidate_ids)
- value_by_id.update({
- kb_id: 0.0
- for kb_id in candidate_ids
- if kb_id != entity_id
- })
- gold_entities[(start, end)] = value_by_id
-
- return GoldParse(doc, links=gold_entities)
-
-
-def is_dev(article_id):
- return article_id.endswith("3")
diff --git a/bin/wiki_entity_linking/wiki_io.py b/bin/wiki_entity_linking/wiki_io.py
new file mode 100644
index 000000000..43ae87f0f
--- /dev/null
+++ b/bin/wiki_entity_linking/wiki_io.py
@@ -0,0 +1,127 @@
+# coding: utf-8
+from __future__ import unicode_literals
+
+import sys
+import csv
+
+# min() needed to prevent error on windows, cf https://stackoverflow.com/questions/52404416/
+csv.field_size_limit(min(sys.maxsize, 2147483646))
+
+""" This module provides reading/writing methods for temp files """
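+# All of these temp files are pipe-delimited text files with a one-line header,
+# e.g. "WP_title|WD_id" for the entity definitions file.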
+
+
+# Entity definition: WP title -> WD ID #
+def write_title_to_id(entity_def_output, title_to_id):
+ with entity_def_output.open("w", encoding="utf8") as id_file:
+ id_file.write("WP_title" + "|" + "WD_id" + "\n")
+ for title, qid in title_to_id.items():
+ id_file.write(title + "|" + str(qid) + "\n")
+
+
+def read_title_to_id(entity_def_output):
+ title_to_id = dict()
+ with entity_def_output.open("r", encoding="utf8") as id_file:
+ csvreader = csv.reader(id_file, delimiter="|")
+ # skip header
+ next(csvreader)
+ for row in csvreader:
+ title_to_id[row[0]] = row[1]
+ return title_to_id
+
+
+# Entity aliases from WD: WD ID -> WD alias #
+def write_id_to_alias(entity_alias_path, id_to_alias):
+ with entity_alias_path.open("w", encoding="utf8") as alias_file:
+ alias_file.write("WD_id" + "|" + "alias" + "\n")
+ for qid, alias_list in id_to_alias.items():
+ for alias in alias_list:
+ alias_file.write(str(qid) + "|" + alias + "\n")
+
+
+def read_id_to_alias(entity_alias_path):
+ id_to_alias = dict()
+ with entity_alias_path.open("r", encoding="utf8") as alias_file:
+ csvreader = csv.reader(alias_file, delimiter="|")
+ # skip header
+ next(csvreader)
+ for row in csvreader:
+ qid = row[0]
+ alias = row[1]
+ alias_list = id_to_alias.get(qid, [])
+ alias_list.append(alias)
+ id_to_alias[qid] = alias_list
+ return id_to_alias
+
+
+def read_alias_to_id_generator(entity_alias_path):
+    """ Yield (alias, qid) tuples from the entity alias file """
+
+ with entity_alias_path.open("r", encoding="utf8") as alias_file:
+ csvreader = csv.reader(alias_file, delimiter="|")
+ # skip header
+ next(csvreader)
+ for row in csvreader:
+ qid = row[0]
+ alias = row[1]
+ yield alias, qid
+
+
+# Entity descriptions from WD: WD ID -> WD description #
+def write_id_to_descr(entity_descr_output, id_to_descr):
+ with entity_descr_output.open("w", encoding="utf8") as descr_file:
+ descr_file.write("WD_id" + "|" + "description" + "\n")
+ for qid, descr in id_to_descr.items():
+ descr_file.write(str(qid) + "|" + descr + "\n")
+
+
+def read_id_to_descr(entity_desc_path):
+ id_to_desc = dict()
+ with entity_desc_path.open("r", encoding="utf8") as descr_file:
+ csvreader = csv.reader(descr_file, delimiter="|")
+ # skip header
+ next(csvreader)
+ for row in csvreader:
+ id_to_desc[row[0]] = row[1]
+ return id_to_desc
+
+
+# Entity counts from WP: WP title -> count #
+def write_entity_to_count(prior_prob_input, count_output):
+ # Write entity counts for quick access later
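+    # Each line of prior_prob_input is expected to look like "alias|count|entity" (after a one-line header).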
+ entity_to_count = dict()
+ total_count = 0
+
+ with prior_prob_input.open("r", encoding="utf8") as prior_file:
+ # skip header
+ prior_file.readline()
+ line = prior_file.readline()
+
+ while line:
+ splits = line.replace("\n", "").split(sep="|")
+ # alias = splits[0]
+ count = int(splits[1])
+ entity = splits[2]
+
+ current_count = entity_to_count.get(entity, 0)
+ entity_to_count[entity] = current_count + count
+
+ total_count += count
+
+ line = prior_file.readline()
+
+ with count_output.open("w", encoding="utf8") as entity_file:
+ entity_file.write("entity" + "|" + "count" + "\n")
+ for entity, count in entity_to_count.items():
+ entity_file.write(entity + "|" + str(count) + "\n")
+
+
+def read_entity_to_count(count_input):
+ entity_to_count = dict()
+ with count_input.open("r", encoding="utf8") as csvfile:
+ csvreader = csv.reader(csvfile, delimiter="|")
+ # skip header
+ next(csvreader)
+ for row in csvreader:
+ entity_to_count[row[0]] = int(row[1])
+
+ return entity_to_count
diff --git a/bin/wiki_entity_linking/wiki_namespaces.py b/bin/wiki_entity_linking/wiki_namespaces.py
new file mode 100644
index 000000000..e8f099ccd
--- /dev/null
+++ b/bin/wiki_entity_linking/wiki_namespaces.py
@@ -0,0 +1,128 @@
+# coding: utf8
+from __future__ import unicode_literals
+
+# List of meta pages in Wikidata, should be kept out of the Knowledge base
+WD_META_ITEMS = [
+ "Q163875",
+ "Q191780",
+ "Q224414",
+ "Q4167836",
+ "Q4167410",
+ "Q4663903",
+ "Q11266439",
+ "Q13406463",
+ "Q15407973",
+ "Q18616576",
+ "Q19887878",
+ "Q22808320",
+ "Q23894233",
+ "Q33120876",
+ "Q42104522",
+ "Q47460393",
+ "Q64875536",
+ "Q66480449",
+]
+
+
+# TODO: add more cases from non-English WP's
+
+# List of prefixes that refer to Wikipedia "file" pages
+WP_FILE_NAMESPACE = ["Bestand", "File"]
+
+# List of prefixes that refer to Wikipedia "category" pages
+WP_CATEGORY_NAMESPACE = ["Kategori", "Category", "Categorie"]
+
+# List of prefixes that refer to Wikipedia "meta" pages
+# these will/should be matched ignoring case
+WP_META_NAMESPACE = (
+ WP_FILE_NAMESPACE
+ + WP_CATEGORY_NAMESPACE
+ + [
+ "b",
+ "betawikiversity",
+ "Book",
+ "c",
+ "Commons",
+ "d",
+ "dbdump",
+ "download",
+ "Draft",
+ "Education",
+ "Foundation",
+ "Gadget",
+ "Gadget definition",
+ "Gebruiker",
+ "gerrit",
+ "Help",
+ "Image",
+ "Incubator",
+ "m",
+ "mail",
+ "mailarchive",
+ "media",
+ "MediaWiki",
+ "MediaWiki talk",
+ "Mediawikiwiki",
+ "MediaZilla",
+ "Meta",
+ "Metawikipedia",
+ "Module",
+ "mw",
+ "n",
+ "nost",
+ "oldwikisource",
+ "otrs",
+ "OTRSwiki",
+ "Overleg gebruiker",
+ "outreach",
+ "outreachwiki",
+ "Portal",
+ "phab",
+ "Phabricator",
+ "Project",
+ "q",
+ "quality",
+ "rev",
+ "s",
+ "spcom",
+ "Special",
+ "species",
+ "Strategy",
+ "sulutil",
+ "svn",
+ "Talk",
+ "Template",
+ "Template talk",
+ "Testwiki",
+ "ticket",
+ "TimedText",
+ "Toollabs",
+ "tools",
+ "tswiki",
+ "User",
+ "User talk",
+ "v",
+ "voy",
+ "w",
+ "Wikibooks",
+ "Wikidata",
+ "wikiHow",
+ "Wikinvest",
+ "wikilivres",
+ "Wikimedia",
+ "Wikinews",
+ "Wikipedia",
+ "Wikipedia talk",
+ "Wikiquote",
+ "Wikisource",
+ "Wikispecies",
+ "Wikitech",
+ "Wikiversity",
+ "Wikivoyage",
+ "wikt",
+ "wiktionary",
+ "wmf",
+ "wmania",
+ "WP",
+ ]
+)
diff --git a/bin/wiki_entity_linking/wikidata_pretrain_kb.py b/bin/wiki_entity_linking/wikidata_pretrain_kb.py
index 28650f039..940607b72 100644
--- a/bin/wiki_entity_linking/wikidata_pretrain_kb.py
+++ b/bin/wiki_entity_linking/wikidata_pretrain_kb.py
@@ -18,11 +18,12 @@ from pathlib import Path
import plac
from bin.wiki_entity_linking import wikipedia_processor as wp, wikidata_processor as wd
+from bin.wiki_entity_linking import wiki_io as io
from bin.wiki_entity_linking import kb_creator
-from bin.wiki_entity_linking import training_set_creator
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_FILE, ENTITY_DESCR_PATH, KB_MODEL_DIR, LOG_FORMAT
-from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH
+from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH, ENTITY_ALIAS_PATH
import spacy
+from bin.wiki_entity_linking.kb_creator import read_kb
logger = logging.getLogger(__name__)
@@ -39,9 +40,11 @@ logger = logging.getLogger(__name__)
loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path),
loc_entity_defs=("Location to file with entity definitions", "option", "d", Path),
loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path),
- descriptions_from_wikipedia=("Flag for using wp descriptions not wd", "flag", "wp"),
- limit=("Optional threshold to limit lines read from dumps", "option", "l", int),
- lang=("Optional language for which to get wikidata titles. Defaults to 'en'", "option", "la", str),
+    descr_from_wp=("Flag for using WP descriptions instead of WD descriptions", "flag", "wp"),
+ limit_prior=("Threshold to limit lines read from WP for prior probabilities", "option", "lp", int),
+ limit_train=("Threshold to limit lines read from WP for training set", "option", "lt", int),
+ limit_wd=("Threshold to limit lines read from WD", "option", "lw", int),
+ lang=("Optional language for which to get Wikidata titles. Defaults to 'en'", "option", "la", str),
)
def main(
wd_json,
@@ -54,13 +57,16 @@ def main(
entity_vector_length=64,
loc_prior_prob=None,
loc_entity_defs=None,
+ loc_entity_alias=None,
loc_entity_desc=None,
- descriptions_from_wikipedia=False,
- limit=None,
+ descr_from_wp=False,
+ limit_prior=None,
+ limit_train=None,
+ limit_wd=None,
lang="en",
):
-
entity_defs_path = loc_entity_defs if loc_entity_defs else output_dir / ENTITY_DEFS_PATH
+ entity_alias_path = loc_entity_alias if loc_entity_alias else output_dir / ENTITY_ALIAS_PATH
entity_descr_path = loc_entity_desc if loc_entity_desc else output_dir / ENTITY_DESCR_PATH
entity_freq_path = output_dir / ENTITY_FREQ_PATH
prior_prob_path = loc_prior_prob if loc_prior_prob else output_dir / PRIOR_PROB_PATH
@@ -69,15 +75,12 @@ def main(
logger.info("Creating KB with Wikipedia and WikiData")
- if limit is not None:
- logger.warning("Warning: reading only {} lines of Wikipedia/Wikidata dumps.".format(limit))
-
# STEP 0: set up IO
if not output_dir.exists():
output_dir.mkdir(parents=True)
- # STEP 1: create the NLP object
- logger.info("STEP 1: Loading model {}".format(model))
+ # STEP 1: Load the NLP object
+ logger.info("STEP 1: Loading NLP model {}".format(model))
nlp = spacy.load(model)
# check the length of the nlp vectors
@@ -90,62 +93,83 @@ def main(
# STEP 2: create prior probabilities from WP
if not prior_prob_path.exists():
# It takes about 2h to process 1000M lines of Wikipedia XML dump
- logger.info("STEP 2: writing prior probabilities to {}".format(prior_prob_path))
- wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit)
- logger.info("STEP 2: reading prior probabilities from {}".format(prior_prob_path))
+ logger.info("STEP 2: Writing prior probabilities to {}".format(prior_prob_path))
+ if limit_prior is not None:
+ logger.warning("Warning: reading only {} lines of Wikipedia dump".format(limit_prior))
+ wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit_prior)
+ else:
+ logger.info("STEP 2: Reading prior probabilities from {}".format(prior_prob_path))
- # STEP 3: deduce entity frequencies from WP (takes only a few minutes)
- logger.info("STEP 3: calculating entity frequencies")
- wp.write_entity_counts(prior_prob_path, entity_freq_path, to_print=False)
+ # STEP 3: calculate entity frequencies
+ if not entity_freq_path.exists():
+ logger.info("STEP 3: Calculating and writing entity frequencies to {}".format(entity_freq_path))
+ io.write_entity_to_count(prior_prob_path, entity_freq_path)
+ else:
+ logger.info("STEP 3: Reading entity frequencies from {}".format(entity_freq_path))
# STEP 4: reading definitions and (possibly) descriptions from WikiData or from file
- message = " and descriptions" if not descriptions_from_wikipedia else ""
- if (not entity_defs_path.exists()) or (not descriptions_from_wikipedia and not entity_descr_path.exists()):
+ if (not entity_defs_path.exists()) or (not descr_from_wp and not entity_descr_path.exists()):
# It takes about 10h to process 55M lines of Wikidata JSON dump
- logger.info("STEP 4: parsing wikidata for entity definitions" + message)
- title_to_id, id_to_descr = wd.read_wikidata_entities_json(
+ logger.info("STEP 4: Parsing and writing Wikidata entity definitions to {}".format(entity_defs_path))
+ if limit_wd is not None:
+ logger.warning("Warning: reading only {} lines of Wikidata dump".format(limit_wd))
+ title_to_id, id_to_descr, id_to_alias = wd.read_wikidata_entities_json(
wd_json,
- limit,
+ limit_wd,
to_print=False,
lang=lang,
- parse_descriptions=(not descriptions_from_wikipedia),
+ parse_descr=(not descr_from_wp),
)
- wd.write_entity_files(entity_defs_path, title_to_id)
- if not descriptions_from_wikipedia:
- wd.write_entity_description_files(entity_descr_path, id_to_descr)
- logger.info("STEP 4: read entity definitions" + message)
+ io.write_title_to_id(entity_defs_path, title_to_id)
- # STEP 5: Getting gold entities from wikipedia
- message = " and descriptions" if descriptions_from_wikipedia else ""
- if (not training_entities_path.exists()) or (descriptions_from_wikipedia and not entity_descr_path.exists()):
- logger.info("STEP 5: parsing wikipedia for gold entities" + message)
- training_set_creator.create_training_examples_and_descriptions(
- wp_xml,
- entity_defs_path,
- entity_descr_path,
- training_entities_path,
- parse_descriptions=descriptions_from_wikipedia,
- limit=limit,
- )
- logger.info("STEP 5: read gold entities" + message)
+ logger.info("STEP 4b: Writing Wikidata entity aliases to {}".format(entity_alias_path))
+ io.write_id_to_alias(entity_alias_path, id_to_alias)
+
+ if not descr_from_wp:
+ logger.info("STEP 4c: Writing Wikidata entity descriptions to {}".format(entity_descr_path))
+ io.write_id_to_descr(entity_descr_path, id_to_descr)
+ else:
+ logger.info("STEP 4: Reading entity definitions from {}".format(entity_defs_path))
+ logger.info("STEP 4b: Reading entity aliases from {}".format(entity_alias_path))
+ if not descr_from_wp:
+ logger.info("STEP 4c: Reading entity descriptions from {}".format(entity_descr_path))
+
+ # STEP 5: Getting gold entities from Wikipedia
+ if (not training_entities_path.exists()) or (descr_from_wp and not entity_descr_path.exists()):
+ logger.info("STEP 5: Parsing and writing Wikipedia gold entities to {}".format(training_entities_path))
+ if limit_train is not None:
+ logger.warning("Warning: reading only {} lines of Wikipedia dump".format(limit_train))
+ wp.create_training_and_desc(wp_xml, entity_defs_path, entity_descr_path,
+ training_entities_path, descr_from_wp, limit_train)
+ if descr_from_wp:
+ logger.info("STEP 5b: Parsing and writing Wikipedia descriptions to {}".format(entity_descr_path))
+ else:
+ logger.info("STEP 5: Reading gold entities from {}".format(training_entities_path))
+ if descr_from_wp:
+ logger.info("STEP 5b: Reading entity descriptions from {}".format(entity_descr_path))
# STEP 6: creating the actual KB
# It takes ca. 30 minutes to pretrain the entity embeddings
- logger.info("STEP 6: creating the KB at {}".format(kb_path))
- kb = kb_creator.create_kb(
- nlp=nlp,
- max_entities_per_alias=max_per_alias,
- min_entity_freq=min_freq,
- min_occ=min_pair,
- entity_def_input=entity_defs_path,
- entity_descr_path=entity_descr_path,
- count_input=entity_freq_path,
- prior_prob_input=prior_prob_path,
- entity_vector_length=entity_vector_length,
- )
-
- kb.dump(kb_path)
- nlp.to_disk(output_dir / KB_MODEL_DIR)
+ if not kb_path.exists():
+ logger.info("STEP 6: Creating the KB at {}".format(kb_path))
+ kb = kb_creator.create_kb(
+ nlp=nlp,
+ max_entities_per_alias=max_per_alias,
+ min_entity_freq=min_freq,
+ min_occ=min_pair,
+ entity_def_path=entity_defs_path,
+ entity_descr_path=entity_descr_path,
+ entity_alias_path=entity_alias_path,
+ entity_freq_path=entity_freq_path,
+ prior_prob_path=prior_prob_path,
+ entity_vector_length=entity_vector_length,
+ )
+ kb.dump(kb_path)
+ logger.info("kb entities: {}".format(kb.get_size_entities()))
+ logger.info("kb aliases: {}".format(kb.get_size_aliases()))
+ nlp.to_disk(output_dir / KB_MODEL_DIR)
+ else:
+ logger.info("STEP 6: KB already exists at {}".format(kb_path))
logger.info("Done!")
diff --git a/bin/wiki_entity_linking/wikidata_processor.py b/bin/wiki_entity_linking/wikidata_processor.py
index b4034cb1a..8a070f567 100644
--- a/bin/wiki_entity_linking/wikidata_processor.py
+++ b/bin/wiki_entity_linking/wikidata_processor.py
@@ -1,40 +1,52 @@
# coding: utf-8
from __future__ import unicode_literals
-import gzip
+import bz2
import json
import logging
-import datetime
+
+from bin.wiki_entity_linking.wiki_namespaces import WD_META_ITEMS
logger = logging.getLogger(__name__)
-def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descriptions=True):
- # Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines.
+def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descr=True):
+ # Read the JSON wiki data and parse out the entities. Takes about 7-10h to parse 55M lines.
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
site_filter = '{}wiki'.format(lang)
- # properties filter (currently disabled to get ALL data)
- prop_filter = dict()
- # prop_filter = {'P31': {'Q5', 'Q15632617'}} # currently defined as OR: one property suffices to be selected
+    # exclusion filter, applied as an OR: one matching property value suffices to remove an entity from further processing
+ exclude_list = WD_META_ITEMS
+
+ # punctuation
+ exclude_list.extend(["Q1383557", "Q10617810"])
+
+ # letters etc
+ exclude_list.extend(["Q188725", "Q19776628", "Q3841820", "Q17907810", "Q9788", "Q9398093"])
+
+ neg_prop_filter = {
+ 'P31': exclude_list, # instance of
+ 'P279': exclude_list # subclass
+ }
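+    # e.g. an item whose P31 ("instance of") or P279 ("subclass of") value appears in exclude_list is skipped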
title_to_id = dict()
id_to_descr = dict()
+ id_to_alias = dict()
# parse appropriate fields - depending on what we need in the KB
parse_properties = False
parse_sitelinks = True
parse_labels = False
- parse_aliases = False
- parse_claims = False
+ parse_aliases = True
+ parse_claims = True
- with gzip.open(wikidata_file, mode='rb') as file:
+ with bz2.open(wikidata_file, mode='rb') as file:
for cnt, line in enumerate(file):
if limit and cnt >= limit:
break
- if cnt % 500000 == 0:
- logger.info("processed {} lines of WikiData dump".format(cnt))
+ if cnt % 500000 == 0 and cnt > 0:
+ logger.info("processed {} lines of WikiData JSON dump".format(cnt))
clean_line = line.strip()
if clean_line.endswith(b","):
clean_line = clean_line[:-1]
@@ -43,13 +55,11 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
entry_type = obj["type"]
if entry_type == "item":
- # filtering records on their properties (currently disabled to get ALL data)
- # keep = False
keep = True
claims = obj["claims"]
if parse_claims:
- for prop, value_set in prop_filter.items():
+ for prop, value_set in neg_prop_filter.items():
claim_property = claims.get(prop, None)
if claim_property:
for cp in claim_property:
@@ -61,7 +71,7 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
)
cp_rank = cp["rank"]
if cp_rank != "deprecated" and cp_id in value_set:
- keep = True
+ keep = False
if keep:
unique_id = obj["id"]
@@ -108,7 +118,7 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
"label (" + lang + "):", lang_label["value"]
)
- if found_link and parse_descriptions:
+ if found_link and parse_descr:
descriptions = obj["descriptions"]
if descriptions:
lang_descr = descriptions.get(lang, None)
@@ -130,22 +140,15 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
print(
"alias (" + lang + "):", item["value"]
)
+ alias_list = id_to_alias.get(unique_id, [])
+ alias_list.append(item["value"])
+ id_to_alias[unique_id] = alias_list
if to_print:
print()
- return title_to_id, id_to_descr
+ # log final number of lines processed
+ logger.info("Finished. Processed {} lines of WikiData JSON dump".format(cnt))
+ return title_to_id, id_to_descr, id_to_alias
-def write_entity_files(entity_def_output, title_to_id):
- with entity_def_output.open("w", encoding="utf8") as id_file:
- id_file.write("WP_title" + "|" + "WD_id" + "\n")
- for title, qid in title_to_id.items():
- id_file.write(title + "|" + str(qid) + "\n")
-
-
-def write_entity_description_files(entity_descr_output, id_to_descr):
- with entity_descr_output.open("w", encoding="utf8") as descr_file:
- descr_file.write("WD_id" + "|" + "description" + "\n")
- for qid, descr in id_to_descr.items():
- descr_file.write(str(qid) + "|" + descr + "\n")
diff --git a/bin/wiki_entity_linking/wikidata_train_entity_linker.py b/bin/wiki_entity_linking/wikidata_train_entity_linker.py
index ac131e0ef..20c5fe91b 100644
--- a/bin/wiki_entity_linking/wikidata_train_entity_linker.py
+++ b/bin/wiki_entity_linking/wikidata_train_entity_linker.py
@@ -6,19 +6,19 @@ as created by the script `wikidata_create_kb`.
For the Wikipedia dump: get enwiki-latest-pages-articles-multistream.xml.bz2
from https://dumps.wikimedia.org/enwiki/latest/
-
"""
from __future__ import unicode_literals
import random
import logging
+import spacy
from pathlib import Path
import plac
-from bin.wiki_entity_linking import training_set_creator
+from bin.wiki_entity_linking import wikipedia_processor
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR
-from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance, measure_baselines
-from bin.wiki_entity_linking.kb_creator import read_nlp_kb
+from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance
+from bin.wiki_entity_linking.kb_creator import read_kb
from spacy.util import minibatch, compounding
@@ -35,6 +35,7 @@ logger = logging.getLogger(__name__)
l2=("L2 regularization", "option", "r", float),
train_inst=("# training instances (default 90% of all)", "option", "t", int),
dev_inst=("# test instances (default 10% of all)", "option", "d", int),
+ labels_discard=("NER labels to discard (default None)", "option", "l", str),
)
def main(
dir_kb,
@@ -46,13 +47,14 @@ def main(
l2=1e-6,
train_inst=None,
dev_inst=None,
+ labels_discard=None
):
logger.info("Creating Entity Linker with Wikipedia and WikiData")
output_dir = Path(output_dir) if output_dir else dir_kb
- training_path = loc_training if loc_training else output_dir / TRAINING_DATA_FILE
+ training_path = loc_training if loc_training else dir_kb / TRAINING_DATA_FILE
nlp_dir = dir_kb / KB_MODEL_DIR
- kb_path = output_dir / KB_FILE
+ kb_path = dir_kb / KB_FILE
nlp_output_dir = output_dir / OUTPUT_MODEL_DIR
# STEP 0: set up IO
@@ -60,38 +62,47 @@ def main(
output_dir.mkdir()
# STEP 1 : load the NLP object
- logger.info("STEP 1: loading model from {}".format(nlp_dir))
- nlp, kb = read_nlp_kb(nlp_dir, kb_path)
+ logger.info("STEP 1a: Loading model from {}".format(nlp_dir))
+ nlp = spacy.load(nlp_dir)
+ logger.info("STEP 1b: Loading KB from {}".format(kb_path))
+ kb = read_kb(nlp, kb_path)
# check that there is a NER component in the pipeline
if "ner" not in nlp.pipe_names:
raise ValueError("The `nlp` object should have a pretrained `ner` component.")
- # STEP 2: create a training dataset from WP
- logger.info("STEP 2: reading training dataset from {}".format(training_path))
+ # STEP 2: read the training dataset previously created from WP
+ logger.info("STEP 2: Reading training dataset from {}".format(training_path))
- train_data = training_set_creator.read_training(
+ if labels_discard:
+ labels_discard = [x.strip() for x in labels_discard.split(",")]
+ logger.info("Discarding {} NER types: {}".format(len(labels_discard), labels_discard))
+
+ train_data = wikipedia_processor.read_training(
nlp=nlp,
entity_file_path=training_path,
dev=False,
limit=train_inst,
kb=kb,
+ labels_discard=labels_discard
)
- # for testing, get all pos instances, whether or not they are in the kb
- dev_data = training_set_creator.read_training(
+ # for testing, get all pos instances (independently of KB)
+ dev_data = wikipedia_processor.read_training(
nlp=nlp,
entity_file_path=training_path,
dev=True,
limit=dev_inst,
- kb=kb,
+ kb=None,
+ labels_discard=labels_discard
)
- # STEP 3: create and train the entity linking pipe
- logger.info("STEP 3: training Entity Linking pipe")
+ # STEP 3: create and train an entity linking pipe
+ logger.info("STEP 3: Creating and training an Entity Linking pipe")
el_pipe = nlp.create_pipe(
- name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name}
+ name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name,
+ "labels_discard": labels_discard}
)
el_pipe.set_kb(kb)
nlp.add_pipe(el_pipe, last=True)
@@ -105,14 +116,9 @@ def main(
logger.info("Training on {} articles".format(len(train_data)))
logger.info("Dev testing on {} articles".format(len(dev_data)))
- dev_baseline_accuracies = measure_baselines(
- dev_data, kb
- )
-
+ # baseline performance on dev data
logger.info("Dev Baseline Accuracies:")
- logger.info(dev_baseline_accuracies.report_accuracy("random"))
- logger.info(dev_baseline_accuracies.report_accuracy("prior"))
- logger.info(dev_baseline_accuracies.report_accuracy("oracle"))
+ measure_performance(dev_data, kb, el_pipe, baseline=True, context=False)
for itn in range(epochs):
random.shuffle(train_data)
@@ -136,18 +142,18 @@ def main(
logger.error("Error updating batch:" + str(e))
if batchnr > 0:
logging.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2)))
- measure_performance(dev_data, kb, el_pipe)
+ measure_performance(dev_data, kb, el_pipe, baseline=False, context=True)
# STEP 4: measure the performance of our trained pipe on an independent dev set
- logger.info("STEP 4: performance measurement of Entity Linking pipe")
+ logger.info("STEP 4: Final performance measurement of Entity Linking pipe")
measure_performance(dev_data, kb, el_pipe)
# STEP 5: apply the EL pipe on a toy example
- logger.info("STEP 5: applying Entity Linking to toy example")
+ logger.info("STEP 5: Applying Entity Linking to toy example")
run_el_toy_example(nlp=nlp)
if output_dir:
- # STEP 6: write the NLP pipeline (including entity linker) to file
+ # STEP 6: write the NLP pipeline (now including an EL model) to file
logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir))
nlp.to_disk(nlp_output_dir)
diff --git a/bin/wiki_entity_linking/wikipedia_processor.py b/bin/wiki_entity_linking/wikipedia_processor.py
index 8f928723e..25e914b32 100644
--- a/bin/wiki_entity_linking/wikipedia_processor.py
+++ b/bin/wiki_entity_linking/wikipedia_processor.py
@@ -3,147 +3,104 @@ from __future__ import unicode_literals
import re
import bz2
-import csv
-import datetime
import logging
+import random
+import json
-from bin.wiki_entity_linking import LOG_FORMAT
+from functools import partial
+
+from spacy.gold import GoldParse
+from bin.wiki_entity_linking import wiki_io as io
+from bin.wiki_entity_linking.wiki_namespaces import (
+ WP_META_NAMESPACE,
+ WP_FILE_NAMESPACE,
+ WP_CATEGORY_NAMESPACE,
+)
"""
Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions.
Write these results to file for downstream KB and training data generation.
+
+Process Wikipedia interlinks to generate a training dataset for the EL algorithm.
"""
+ENTITY_FILE = "gold_entities.csv"
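+# Each line of the training file is one JSON object, as written by _write_training_entities below:
+# {"article_id": ..., "clean_text": ..., "entities": [{"alias": ..., "entity": ..., "start": ..., "end": ...}, ...]}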
+
map_alias_to_link = dict()
logger = logging.getLogger(__name__)
-
-# these will/should be matched ignoring case
-wiki_namespaces = [
- "b",
- "betawikiversity",
- "Book",
- "c",
- "Category",
- "Commons",
- "d",
- "dbdump",
- "download",
- "Draft",
- "Education",
- "Foundation",
- "Gadget",
- "Gadget definition",
- "gerrit",
- "File",
- "Help",
- "Image",
- "Incubator",
- "m",
- "mail",
- "mailarchive",
- "media",
- "MediaWiki",
- "MediaWiki talk",
- "Mediawikiwiki",
- "MediaZilla",
- "Meta",
- "Metawikipedia",
- "Module",
- "mw",
- "n",
- "nost",
- "oldwikisource",
- "outreach",
- "outreachwiki",
- "otrs",
- "OTRSwiki",
- "Portal",
- "phab",
- "Phabricator",
- "Project",
- "q",
- "quality",
- "rev",
- "s",
- "spcom",
- "Special",
- "species",
- "Strategy",
- "sulutil",
- "svn",
- "Talk",
- "Template",
- "Template talk",
- "Testwiki",
- "ticket",
- "TimedText",
- "Toollabs",
- "tools",
- "tswiki",
- "User",
- "User talk",
- "v",
- "voy",
- "w",
- "Wikibooks",
- "Wikidata",
- "wikiHow",
- "Wikinvest",
- "wikilivres",
- "Wikimedia",
- "Wikinews",
- "Wikipedia",
- "Wikipedia talk",
- "Wikiquote",
- "Wikisource",
- "Wikispecies",
- "Wikitech",
- "Wikiversity",
- "Wikivoyage",
- "wikt",
- "wiktionary",
- "wmf",
- "wmania",
- "WP",
-]
+title_regex = re.compile(r"(?<=<title>).*(?=</title>)")
+id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
+text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")
logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
clean_line = line.strip().decode("utf-8")
- aliases, entities, normalizations = get_wp_links(clean_line)
- for alias, entity, norm in zip(aliases, entities, normalizations):
- _store_alias(alias, entity, normalize_alias=norm, normalize_entity=True)
- _store_alias(alias, entity, normalize_alias=norm, normalize_entity=True)
+ # we attempt at reading the article's ID (but not the revision or contributor ID)
+            if "<revision>" in clean_line or "<contributor>" in clean_line:
+                read_id = False
+            if "<page>" in clean_line:
+                read_id = True
+
+ if read_id:
+ ids = id_regex.search(clean_line)
+ if ids:
+ current_article_id = ids[0]
+
+ # only processing prior probabilities from true training (non-dev) articles
+ if not is_dev(current_article_id):
+ aliases, entities, normalizations = get_wp_links(clean_line)
+ for alias, entity, norm in zip(aliases, entities, normalizations):
+ _store_alias(
+ alias, entity, normalize_alias=norm, normalize_entity=True
+ )
line = file.readline()
cnt += 1
logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
+    logger.info("Finished. Processed {} lines of Wikipedia XML dump".format(cnt))
# write all aliases and their entities and count occurrences to file
with prior_prob_output.open("w", encoding="utf8") as outputfile:
@@ -182,7 +139,7 @@ def get_wp_links(text):
match = match[2:][:-2].replace("_", " ").strip()
if ns_regex.match(match):
- pass # ignore namespaces at the beginning of the string
+ pass # ignore the entity if it points to a "meta" page
# this is a simple [[link]], with the alias the same as the mention
elif "|" not in match:
@@ -218,47 +175,382 @@ def _capitalize_first(text):
return result
-def write_entity_counts(prior_prob_input, count_output, to_print=False):
- # Write entity counts for quick access later
- entity_to_count = dict()
- total_count = 0
-
- with prior_prob_input.open("r", encoding="utf8") as prior_file:
- # skip header
- prior_file.readline()
- line = prior_file.readline()
-
- while line:
- splits = line.replace("\n", "").split(sep="|")
- # alias = splits[0]
- count = int(splits[1])
- entity = splits[2]
-
- current_count = entity_to_count.get(entity, 0)
- entity_to_count[entity] = current_count + count
-
- total_count += count
-
- line = prior_file.readline()
-
- with count_output.open("w", encoding="utf8") as entity_file:
- entity_file.write("entity" + "|" + "count" + "\n")
- for entity, count in entity_to_count.items():
- entity_file.write(entity + "|" + str(count) + "\n")
-
- if to_print:
- for entity, count in entity_to_count.items():
- print("Entity count:", entity, count)
- print("Total count:", total_count)
+def create_training_and_desc(
+ wp_input, def_input, desc_output, training_output, parse_desc, limit=None
+):
+ wp_to_id = io.read_title_to_id(def_input)
+ _process_wikipedia_texts(
+ wp_input, wp_to_id, desc_output, training_output, parse_desc, limit
+ )
-def get_all_frequencies(count_input):
- entity_to_count = dict()
- with count_input.open("r", encoding="utf8") as csvfile:
- csvreader = csv.reader(csvfile, delimiter="|")
- # skip header
- next(csvreader)
- for row in csvreader:
- entity_to_count[row[0]] = int(row[1])
+def _process_wikipedia_texts(
+ wikipedia_input, wp_to_id, output, training_output, parse_descriptions, limit=None
+):
+ """
+ Read the XML wikipedia data to parse out training data:
+ raw text data + positive instances
+ """
- return entity_to_count
+ read_ids = set()
+
+ with output.open("a", encoding="utf8") as descr_file, training_output.open(
+ "w", encoding="utf8"
+ ) as entity_file:
+ if parse_descriptions:
+ _write_training_description(descr_file, "WD_id", "description")
+ with bz2.open(wikipedia_input, mode="rb") as file:
+ article_count = 0
+ article_text = ""
+ article_title = None
+ article_id = None
+ reading_text = False
+ reading_revision = False
+
+ for line in file:
+ clean_line = line.strip().decode("utf-8")
+
+            if clean_line == "<revision>":
+                reading_revision = True
+            elif clean_line == "</revision>":
+                reading_revision = False
+
+ # Start reading new page
+            if clean_line == "<page>":
+ article_text = ""
+ article_title = None
+ article_id = None
+ # finished reading this page
+                elif clean_line == "</page>":
+ if article_id:
+ clean_text, entities = _process_wp_text(
+ article_title, article_text, wp_to_id
+ )
+ if clean_text is not None and entities is not None:
+ _write_training_entities(
+ entity_file, article_id, clean_text, entities
+ )
+
+ if article_title in wp_to_id and parse_descriptions:
+ description = " ".join(
+ clean_text[:1000].split(" ")[:-1]
+ )
+ _write_training_description(
+ descr_file, wp_to_id[article_title], description
+ )
+ article_count += 1
+ if article_count % 10000 == 0 and article_count > 0:
+ logger.info(
+ "Processed {} articles".format(article_count)
+ )
+ if limit and article_count >= limit:
+ break
+ article_text = ""
+ article_title = None
+ article_id = None
+ reading_text = False
+ reading_revision = False
+
+ # start reading text within a page
+                if "<text" in clean_line:
+                    reading_text = True
+
+                if reading_text:
+                    article_text += " " + clean_line
+
+                # stop reading text within a page
+                if "</text" in clean_line:
+                    reading_text = False
+
+
+def _get_clean_wp_text(article_text):
+    clean_text = article_text.strip()
+
+    # change special characters back to normal ones
+    clean_text = clean_text.replace(r"&quot;", '"')
+    clean_text = clean_text.replace(r"&amp;nbsp;", " ")
+    clean_text = clean_text.replace(r"&amp;", "&")
+
+    # remove multiple spaces
+    while "  " in clean_text:
+        clean_text = clean_text.replace("  ", " ")
+
+    return clean_text.strip()
+
+
+def _remove_links(clean_text, wp_to_id):
+ # read the text char by char to get the right offsets for the interwiki links
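+    # Illustrative examples of the wiki markup handled below (hypothetical article text):
+    #   "[[Natural language processing]] is fun"      -> the mention and the entity title are identical
+    #   "[[Natural language processing|NLP]] is fun"  -> entity title left of "|", surface mention right of it
+    # The reading_* flags and `open_read` implement a small state machine over these brackets,
+    # and each resolved entity title is mapped to its WikiData QID via `wp_to_id`.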
+ entities = []
+ final_text = ""
+ open_read = 0
+ reading_text = True
+ reading_entity = False
+ reading_mention = False
+ reading_special_case = False
+ entity_buffer = ""
+ mention_buffer = ""
+ for index, letter in enumerate(clean_text):
+ if letter == "[":
+ open_read += 1
+ elif letter == "]":
+ open_read -= 1
+ elif letter == "|":
+ if reading_text:
+ final_text += letter
+ # switch from reading entity to mention in the [[entity|mention]] pattern
+ elif reading_entity:
+ reading_text = False
+ reading_entity = False
+ reading_mention = True
+ else:
+ reading_special_case = True
+ else:
+ if reading_entity:
+ entity_buffer += letter
+ elif reading_mention:
+ mention_buffer += letter
+ elif reading_text:
+ final_text += letter
+ else:
+ raise ValueError("Not sure at point", clean_text[index - 2 : index + 2])
+
+ if open_read > 2:
+ reading_special_case = True
+
+ if open_read == 2 and reading_text:
+ reading_text = False
+ reading_entity = True
+ reading_mention = False
+
+ # we just finished reading an entity
+ if open_read == 0 and not reading_text:
+ if "#" in entity_buffer or entity_buffer.startswith(":"):
+ reading_special_case = True
+ # Ignore cases with nested structures like File: handles etc
+ if not reading_special_case:
+ if not mention_buffer:
+ mention_buffer = entity_buffer
+ start = len(final_text)
+ end = start + len(mention_buffer)
+ qid = wp_to_id.get(entity_buffer, None)
+ if qid:
+ entities.append((mention_buffer, qid, start, end))
+ final_text += mention_buffer
+
+ entity_buffer = ""
+ mention_buffer = ""
+
+ reading_text = True
+ reading_entity = False
+ reading_mention = False
+ reading_special_case = False
+ return final_text, entities
+
+
+def _write_training_description(outputfile, qid, description):
+ if description is not None:
+ line = str(qid) + "|" + description + "\n"
+ outputfile.write(line)
+
+
+def _write_training_entities(outputfile, article_id, clean_text, entities):
+ entities_data = [
+ {"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]}
+ for ent in entities
+ ]
+ line = (
+ json.dumps(
+ {
+ "article_id": article_id,
+ "clean_text": clean_text,
+ "entities": entities_data,
+ },
+ ensure_ascii=False,
+ )
+ + "\n"
+ )
+ outputfile.write(line)
+
+
+def read_training(nlp, entity_file_path, dev, limit, kb, labels_discard=None):
+ """ This method provides training examples that correspond to the entity annotations found by the nlp object.
+ For training, it will include both positive and negative examples by using the candidate generator from the kb.
+ For testing (kb=None), it will include all positive examples only."""
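+    # Each line of `entity_file_path` is expected to be JSON as written by _write_training_entities, e.g.
+    # {"article_id": "...", "clean_text": "...", "entities": [{"alias": "...", "entity": "Q...", "start": 0, "end": 3}]}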
+
+ from tqdm import tqdm
+
+ if not labels_discard:
+ labels_discard = []
+
+ data = []
+ num_entities = 0
+ get_gold_parse = partial(
+ _get_gold_parse, dev=dev, kb=kb, labels_discard=labels_discard
+ )
+
+ logger.info(
+ "Reading {} data with limit {}".format("dev" if dev else "train", limit)
+ )
+ with entity_file_path.open("r", encoding="utf8") as file:
+ with tqdm(total=limit, leave=False) as pbar:
+ for i, line in enumerate(file):
+ example = json.loads(line)
+ article_id = example["article_id"]
+ clean_text = example["clean_text"]
+ entities = example["entities"]
+
+ if dev != is_dev(article_id) or not is_valid_article(clean_text):
+ continue
+
+ doc = nlp(clean_text)
+ gold = get_gold_parse(doc, entities)
+ if gold and len(gold.links) > 0:
+ data.append((doc, gold))
+ num_entities += len(gold.links)
+ pbar.update(len(gold.links))
+ if limit and num_entities >= limit:
+ break
+ logger.info("Read {} entities in {} articles".format(num_entities, len(data)))
+ return data
+
+
+def _get_gold_parse(doc, entities, dev, kb, labels_discard):
+ gold_entities = {}
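+    # index the NER-predicted spans by character offsets, so the gold offsets can be matched exactly below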
+ tagged_ent_positions = {
+ (ent.start_char, ent.end_char): ent
+ for ent in doc.ents
+ if ent.label_ not in labels_discard
+ }
+
+ for entity in entities:
+ entity_id = entity["entity"]
+ alias = entity["alias"]
+ start = entity["start"]
+ end = entity["end"]
+
+ candidate_ids = []
+ if kb and not dev:
+ candidates = kb.get_candidates(alias)
+ candidate_ids = [cand.entity_ for cand in candidates]
+
+ tagged_ent = tagged_ent_positions.get((start, end), None)
+ if tagged_ent:
+ # TODO: check that alias == doc.text[start:end]
+ should_add_ent = (dev or entity_id in candidate_ids) and is_valid_sentence(
+ tagged_ent.sent.text
+ )
+
+ if should_add_ent:
+ value_by_id = {entity_id: 1.0}
+ if not dev:
+ random.shuffle(candidate_ids)
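+                    # all other KB candidates for this alias act as negative examples (probability 0.0)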
+ value_by_id.update(
+ {kb_id: 0.0 for kb_id in candidate_ids if kb_id != entity_id}
+ )
+ gold_entities[(start, end)] = value_by_id
+
+ return GoldParse(doc, links=gold_entities)
+
+
+def is_dev(article_id):
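+    # simple heuristic split: articles whose ID ends in "3" (roughly 10%) are reserved for dev evaluation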
+ if not article_id:
+ return False
+ return article_id.endswith("3")
+
+
+def is_valid_article(doc_text):
+ # custom length cut-off
+ return 10 < len(doc_text) < 30000
+
+
+def is_valid_sentence(sent_text):
+ if not 10 < len(sent_text) < 3000:
+ # custom length cut-off
+ return False
+
+ if sent_text.strip().startswith("*") or sent_text.strip().startswith("#"):
+ # remove 'enumeration' sentences (occurs often on Wikipedia)
+ return False
+
+ return True
diff --git a/examples/pipeline/dummy_entity_linking.py b/examples/pipeline/dummy_entity_linking.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/examples/pipeline/wikidata_entity_linking.py b/examples/pipeline/wikidata_entity_linking.py
deleted file mode 100644
index e69de29bb..000000000
diff --git a/spacy/cli/pretrain.py b/spacy/cli/pretrain.py
index 891e15fa2..9b63b31f0 100644
--- a/spacy/cli/pretrain.py
+++ b/spacy/cli/pretrain.py
@@ -246,7 +246,7 @@ def make_update(model, docs, optimizer, drop=0.0, objective="L2"):
"""Perform an update over a single batch of documents.
docs (iterable): A batch of `Doc` objects.
- drop (float): The droput rate.
+ drop (float): The dropout rate.
optimizer (callable): An optimizer.
RETURNS loss: A float for the loss.
"""
diff --git a/spacy/errors.py b/spacy/errors.py
index d75b1cec8..69dd77be7 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -80,8 +80,8 @@ class Warnings(object):
"the v2.x models cannot release the global interpreter lock. "
"Future versions may introduce a `n_process` argument for "
"parallel inference via multiprocessing.")
- W017 = ("Alias '{alias}' already exists in the Knowledge base.")
- W018 = ("Entity '{entity}' already exists in the Knowledge base.")
+ W017 = ("Alias '{alias}' already exists in the Knowledge Base.")
+ W018 = ("Entity '{entity}' already exists in the Knowledge Base.")
W019 = ("Changing vectors name from {old} to {new}, to avoid clash with "
"previously loaded vectors. See Issue #3853.")
W020 = ("Unnamed vectors. This won't allow multiple vectors models to be "
@@ -96,6 +96,8 @@ class Warnings(object):
"If this is surprising, make sure you have the spacy-lookups-data "
"package installed.")
W023 = ("Multiprocessing of Language.pipe is not supported in Python2. 'n_process' will be set to 1.")
+ W024 = ("Entity '{entity}' - Alias '{alias}' combination already exists in "
+ "the Knowledge Base.")
@add_codes
@@ -408,7 +410,7 @@ class Errors(object):
"{probabilities_length} respectively.")
E133 = ("The sum of prior probabilities for alias '{alias}' should not "
"exceed 1, but found {sum}.")
- E134 = ("Alias '{alias}' defined for unknown entity '{entity}'.")
+ E134 = ("Entity '{entity}' is not defined in the Knowledge Base.")
E135 = ("If you meant to replace a built-in component, use `create_pipe`: "
"`nlp.replace_pipe('{name}', nlp.create_pipe('{name}'))`")
E136 = ("This additional feature requires the jsonschema library to be "
@@ -420,7 +422,7 @@ class Errors(object):
E138 = ("Invalid JSONL format for raw text '{text}'. Make sure the input "
"includes either the `text` or `tokens` key. For more info, see "
"the docs:\nhttps://spacy.io/api/cli#pretrain-jsonl")
- E139 = ("Knowledge base for component '{name}' not initialized. Did you "
+ E139 = ("Knowledge Base for component '{name}' not initialized. Did you "
"forget to call set_kb()?")
E140 = ("The list of entities, prior probabilities and entity vectors "
"should be of equal length.")
@@ -499,6 +501,7 @@ class Errors(object):
E174 = ("Architecture '{name}' not found in registry. Available "
"names: {names}")
E175 = ("Can't remove rule for unknown match pattern ID: {key}")
+ E176 = ("Alias '{alias}' is not defined in the Knowledge Base.")
@add_codes
diff --git a/spacy/kb.pyx b/spacy/kb.pyx
index 6cbc06e2c..31fd1706e 100644
--- a/spacy/kb.pyx
+++ b/spacy/kb.pyx
@@ -142,6 +142,7 @@ cdef class KnowledgeBase:
i = 0
cdef KBEntryC entry
+ cdef hash_t entity_hash
while i < nr_entities:
entity_vector = vector_list[i]
if len(entity_vector) != self.entity_vector_length:
@@ -161,6 +162,14 @@ cdef class KnowledgeBase:
i += 1
+ def contains_entity(self, unicode entity):
+ cdef hash_t entity_hash = self.vocab.strings.add(entity)
+ return entity_hash in self._entry_index
+
+ def contains_alias(self, unicode alias):
+ cdef hash_t alias_hash = self.vocab.strings.add(alias)
+ return alias_hash in self._alias_index
+
def add_alias(self, unicode alias, entities, probabilities):
"""
For a given alias, add its potential entities and prior probabilies to the KB.
@@ -190,7 +199,7 @@ cdef class KnowledgeBase:
for entity, prob in zip(entities, probabilities):
entity_hash = self.vocab.strings[entity]
if not entity_hash in self._entry_index:
- raise ValueError(Errors.E134.format(alias=alias, entity=entity))
+ raise ValueError(Errors.E134.format(entity=entity))
entry_index = self._entry_index.get(entity_hash)
entry_indices.push_back(int(entry_index))
@@ -201,8 +210,63 @@ cdef class KnowledgeBase:
return alias_hash
- def get_candidates(self, unicode alias):
+ def append_alias(self, unicode alias, unicode entity, float prior_prob, ignore_warnings=False):
+ """
+ For an alias already existing in the KB, extend its potential entities with one more.
+        Raise an error if the alias or the entity is unknown in the KB,
+        or if adding this prior probability would make the sum for the alias exceed 1.
+        Throw a warning if the entity-alias combination was already previously recorded.
+ For efficiency, it's best to use the method `add_alias` as much as possible instead of this one.
+ """
+ # Check if the alias exists in the KB
cdef hash_t alias_hash = self.vocab.strings[alias]
+ if not alias_hash in self._alias_index:
+ raise ValueError(Errors.E176.format(alias=alias))
+
+ # Check if the entity exists in the KB
+ cdef hash_t entity_hash = self.vocab.strings[entity]
+ if not entity_hash in self._entry_index:
+ raise ValueError(Errors.E134.format(entity=entity))
+ entry_index = self._entry_index.get(entity_hash)
+
+ # Throw an error if the prior probabilities (including the new one) sum up to more than 1
+ alias_index = self._alias_index.get(alias_hash)
+ alias_entry = self._aliases_table[alias_index]
+ current_sum = sum([p for p in alias_entry.probs])
+ new_sum = current_sum + prior_prob
+
+ if new_sum > 1.00001:
+ raise ValueError(Errors.E133.format(alias=alias, sum=new_sum))
+
+ entry_indices = alias_entry.entry_indices
+
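+        # check whether this entity is already registered for this alias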
+ is_present = False
+ for i in range(entry_indices.size()):
+ if entry_indices[i] == int(entry_index):
+ is_present = True
+
+ if is_present:
+ if not ignore_warnings:
+ user_warning(Warnings.W024.format(entity=entity, alias=alias))
+ else:
+ entry_indices.push_back(int(entry_index))
+ alias_entry.entry_indices = entry_indices
+
+ probs = alias_entry.probs
+ probs.push_back(float(prior_prob))
+ alias_entry.probs = probs
+ self._aliases_table[alias_index] = alias_entry
+
+
+ def get_candidates(self, unicode alias):
+ """
+ Return candidate entities for an alias. Each candidate defines the entity, the original alias,
+ and the prior probability of that alias resolving to that entity.
+        If the alias is not known in the KB, an empty list is returned.
+ """
+ cdef hash_t alias_hash = self.vocab.strings[alias]
+ if not alias_hash in self._alias_index:
+ return []
alias_index = self._alias_index.get(alias_hash)
alias_entry = self._aliases_table[alias_index]
@@ -341,7 +405,6 @@ cdef class KnowledgeBase:
assert nr_entities == self.get_size_entities()
# STEP 3: load aliases
-
cdef int64_t nr_aliases
reader.read_alias_length(&nr_aliases)
self._alias_index = PreshMap(nr_aliases+1)
diff --git a/spacy/language.py b/spacy/language.py
index b240dd8a5..33ac21a85 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -483,7 +483,7 @@ class Language(object):
docs (iterable): A batch of `Doc` objects.
golds (iterable): A batch of `GoldParse` objects.
- drop (float): The droput rate.
+ drop (float): The dropout rate.
sgd (callable): An optimizer.
losses (dict): Dictionary to update with the loss, keyed by component.
component_cfg (dict): Config parameters for specific pipeline
@@ -531,7 +531,7 @@ class Language(object):
even if you're updating it with a smaller set of examples.
docs (iterable): A batch of `Doc` objects.
- drop (float): The droput rate.
+ drop (float): The dropout rate.
sgd (callable): An optimizer.
RETURNS (dict): Results from the update.
diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx
index 63ab09e56..0607ac43d 100644
--- a/spacy/pipeline/pipes.pyx
+++ b/spacy/pipeline/pipes.pyx
@@ -1195,23 +1195,26 @@ class EntityLinker(Pipe):
docs = [docs]
golds = [golds]
- context_docs = []
+ sentence_docs = []
for doc, gold in zip(docs, golds):
ents_by_offset = dict()
for ent in doc.ents:
- ents_by_offset["{}_{}".format(ent.start_char, ent.end_char)] = ent
+ ents_by_offset[(ent.start_char, ent.end_char)] = ent
+
for entity, kb_dict in gold.links.items():
start, end = entity
mention = doc.text[start:end]
+ # the gold annotations should link to proper entities - if this fails, the dataset is likely corrupt
+ ent = ents_by_offset[(start, end)]
for kb_id, value in kb_dict.items():
# Currently only training on the positive instances
if value:
- context_docs.append(doc)
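+                        # use only the sentence containing the entity as context, instead of the whole document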
+ sentence_docs.append(ent.sent.as_doc())
- context_encodings, bp_context = self.model.begin_update(context_docs, drop=drop)
- loss, d_scores = self.get_similarity_loss(scores=context_encodings, golds=golds, docs=None)
+ sentence_encodings, bp_context = self.model.begin_update(sentence_docs, drop=drop)
+ loss, d_scores = self.get_similarity_loss(scores=sentence_encodings, golds=golds, docs=None)
bp_context(d_scores, sgd=sgd)
if losses is not None:
@@ -1280,50 +1283,68 @@ class EntityLinker(Pipe):
if isinstance(docs, Doc):
docs = [docs]
- context_encodings = self.model(docs)
- xp = get_array_module(context_encodings)
-
for i, doc in enumerate(docs):
if len(doc) > 0:
- # currently, the context is the same for each entity in a sentence (should be refined)
- context_encoding = context_encodings[i]
- context_enc_t = context_encoding.T
- norm_1 = xp.linalg.norm(context_enc_t)
- for ent in doc.ents:
- entity_count += 1
+ # Looping through each sentence and each entity
+ # This may go wrong if there are entities across sentences - because they might not get a KB ID
+                for sent in doc.sents:
+ sent_doc = sent.as_doc()
+ # currently, the context is the same for each entity in a sentence (should be refined)
+ sentence_encoding = self.model([sent_doc])[0]
+ xp = get_array_module(sentence_encoding)
+ sentence_encoding_t = sentence_encoding.T
+ sentence_norm = xp.linalg.norm(sentence_encoding_t)
- candidates = self.kb.get_candidates(ent.text)
- if not candidates:
- final_kb_ids.append(self.NIL) # no prediction possible for this entity
- final_tensors.append(context_encoding)
- else:
- random.shuffle(candidates)
+ for ent in sent_doc.ents:
+ entity_count += 1
- # this will set all prior probabilities to 0 if they should be excluded from the model
- prior_probs = xp.asarray([c.prior_prob for c in candidates])
- if not self.cfg.get("incl_prior", True):
- prior_probs = xp.asarray([0.0 for c in candidates])
- scores = prior_probs
+ if ent.label_ in self.cfg.get("labels_discard", []):
+ # ignoring this entity - setting to NIL
+ final_kb_ids.append(self.NIL)
+ final_tensors.append(sentence_encoding)
- # add in similarity from the context
- if self.cfg.get("incl_context", True):
- entity_encodings = xp.asarray([c.entity_vector for c in candidates])
- norm_2 = xp.linalg.norm(entity_encodings, axis=1)
+ else:
+ candidates = self.kb.get_candidates(ent.text)
+ if not candidates:
+ # no prediction possible for this entity - setting to NIL
+ final_kb_ids.append(self.NIL)
+ final_tensors.append(sentence_encoding)
- if len(entity_encodings) != len(prior_probs):
- raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length"))
+ elif len(candidates) == 1:
+ # shortcut for efficiency reasons: take the 1 candidate
- # cosine similarity
- sims = xp.dot(entity_encodings, context_enc_t) / (norm_1 * norm_2)
- if sims.shape != prior_probs.shape:
- raise ValueError(Errors.E161)
- scores = prior_probs + sims - (prior_probs*sims)
+ # TODO: thresholding
+ final_kb_ids.append(candidates[0].entity_)
+ final_tensors.append(sentence_encoding)
- # TODO: thresholding
- best_index = scores.argmax()
- best_candidate = candidates[best_index]
- final_kb_ids.append(best_candidate.entity_)
- final_tensors.append(context_encoding)
+ else:
+ random.shuffle(candidates)
+
+ # this will set all prior probabilities to 0 if they should be excluded from the model
+ prior_probs = xp.asarray([c.prior_prob for c in candidates])
+ if not self.cfg.get("incl_prior", True):
+ prior_probs = xp.asarray([0.0 for c in candidates])
+ scores = prior_probs
+
+ # add in similarity from the context
+ if self.cfg.get("incl_context", True):
+ entity_encodings = xp.asarray([c.entity_vector for c in candidates])
+ entity_norm = xp.linalg.norm(entity_encodings, axis=1)
+
+ if len(entity_encodings) != len(prior_probs):
+ raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length"))
+
+ # cosine similarity
+ sims = xp.dot(entity_encodings, sentence_encoding_t) / (sentence_norm * entity_norm)
+ if sims.shape != prior_probs.shape:
+ raise ValueError(Errors.E161)
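+                            # combine prior probability and context similarity as a probabilistic union: p + s - p*s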
+ scores = prior_probs + sims - (prior_probs*sims)
+
+ # TODO: thresholding
+ best_index = scores.argmax()
+ best_candidate = candidates[best_index]
+ final_kb_ids.append(best_candidate.entity_)
+ final_tensors.append(sentence_encoding)
if not (len(final_tensors) == len(final_kb_ids) == entity_count):
raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length"))
diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py
index 0c89a2e14..bc76d1e47 100644
--- a/spacy/tests/pipeline/test_entity_linker.py
+++ b/spacy/tests/pipeline/test_entity_linker.py
@@ -131,6 +131,53 @@ def test_candidate_generation(nlp):
assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9)
+def test_append_alias(nlp):
+ """Test that we can append additional alias-entity pairs"""
+ mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+
+ # adding entities
+ mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
+ mykb.add_entity(entity="Q2", freq=12, entity_vector=[2])
+ mykb.add_entity(entity="Q3", freq=5, entity_vector=[3])
+
+ # adding aliases
+ mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.4, 0.1])
+ mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
+
+ # test the size of the relevant candidates
+ assert len(mykb.get_candidates("douglas")) == 2
+
+ # append an alias
+ mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.2)
+
+ # test the size of the relevant candidates has been incremented
+ assert len(mykb.get_candidates("douglas")) == 3
+
+    # appending the same alias-entity pair again should not work (will throw a warning)
+ mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.3)
+
+ # test the size of the relevant candidates remained unchanged
+ assert len(mykb.get_candidates("douglas")) == 3
+
+
+def test_append_invalid_alias(nlp):
+    """Test that appending an alias will throw an error if the prior probabilities exceed 1"""
+ mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+
+ # adding entities
+ mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
+ mykb.add_entity(entity="Q2", freq=12, entity_vector=[2])
+ mykb.add_entity(entity="Q3", freq=5, entity_vector=[3])
+
+ # adding aliases
+ mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1])
+ mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
+
+    # append an alias - should fail because the prior probabilities would sum to more than 1
+ with pytest.raises(ValueError):
+ mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.2)
+
+
def test_preserving_links_asdoc(nlp):
"""Test that Span.as_doc preserves the existing entity links"""
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
diff --git a/spacy/tests/regression/test_issue1-1000.py b/spacy/tests/regression/test_issue1-1000.py
index dca3d624f..989eba805 100644
--- a/spacy/tests/regression/test_issue1-1000.py
+++ b/spacy/tests/regression/test_issue1-1000.py
@@ -430,7 +430,7 @@ def test_issue957(en_tokenizer):
def test_issue999(train_data):
"""Test that adding entities and resuming training works passably OK.
There are two issues here:
- 1) We have to read labels. This isn't very nice.
+ 1) We have to re-add labels. This isn't very nice.
2) There's no way to set the learning rate for the weight update, so we
end up out-of-scale, causing it to learn too fast.
"""