mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 18:26:30 +03:00
KB extensions and better parsing of WikiData (#4375)
* fix overflow error on windows * more documentation & logging fixes * md fix * 3 different limit parameters to play with execution time * bug fixes directory locations * small fixes * exclude dev test articles from prior probabilities stats * small fixes * filtering wikidata entities, removing numeric and meta items * adding aliases from wikidata also to the KB * fix adding WD aliases * adding also new aliases to previously added entities * fixing comma's * small doc fixes * adding subclassof filtering * append alias functionality in KB * prevent appending the same entity-alias pair * fix for appending WD aliases * remove date filter * remove unnecessary import * small corrections and reformatting * remove WD aliases for now (too slow) * removing numeric entities from training and evaluation * small fixes * shortcut during prediction if there is only one candidate * add counts and fscore logging, remove FP NER from evaluation * fix entity_linker.predict to take docs instead of single sentences * remove enumeration sentences from the WP dataset * entity_linker.update to process full doc instead of single sentence * spelling corrections and dump locations in readme * NLP IO fix * reading KB is unnecessary at the end of the pipeline * small logging fix * remove empty files
This commit is contained in:
parent
428887b8f2
commit
2d249a9502
34
bin/wiki_entity_linking/README.md
Normal file
34
bin/wiki_entity_linking/README.md
Normal file
|
@ -0,0 +1,34 @@
|
|||
## Entity Linking with Wikipedia and Wikidata
|
||||
|
||||
### Step 1: Create a Knowledge Base (KB) and training data
|
||||
|
||||
Run `wikipedia_pretrain_kb.py`
|
||||
* This takes as input the locations of a **Wikipedia and a Wikidata dump**, and produces a **KB directory** + **training file**
|
||||
* WikiData: get `latest-all.json.bz2` from https://dumps.wikimedia.org/wikidatawiki/entities/
|
||||
* Wikipedia: get `enwiki-latest-pages-articles-multistream.xml.bz2` from https://dumps.wikimedia.org/enwiki/latest/ (or for any other language)
|
||||
* You can set the filtering parameters for KB construction:
|
||||
* `max_per_alias`: (max) number of candidate entities in the KB per alias/synonym
|
||||
* `min_freq`: threshold of number of times an entity should occur in the corpus to be included in the KB
|
||||
* `min_pair`: threshold of number of times an entity+alias combination should occur in the corpus to be included in the KB
|
||||
* Further parameters to set:
|
||||
* `descriptions_from_wikipedia`: whether to parse descriptions from Wikipedia (`True`) or Wikidata (`False`)
|
||||
* `entity_vector_length`: length of the pre-trained entity description vectors
|
||||
* `lang`: language for which to fetch Wikidata information (as the dump contains all languages)
|
||||
|
||||
Quick testing and rerunning:
|
||||
* When trying out the pipeline for a quick test, set `limit_prior`, `limit_train` and/or `limit_wd` to read only parts of the dumps instead of everything.
|
||||
* If you only want to (re)run certain parts of the pipeline, just remove the corresponding files and they will be recalculated or reparsed.
|
||||
|
||||
|
||||
### Step 2: Train an Entity Linking model
|
||||
|
||||
Run `wikidata_train_entity_linker.py`
|
||||
* This takes the **KB directory** produced by Step 1, and trains an **Entity Linking model**
|
||||
* You can set the learning parameters for the EL training:
|
||||
* `epochs`: number of training iterations
|
||||
* `dropout`: dropout rate
|
||||
* `lr`: learning rate
|
||||
* `l2`: L2 regularization
|
||||
* Specify the number of training and dev testing entities with `train_inst` and `dev_inst` respectively
|
||||
* Further parameters to set:
|
||||
* `labels_discard`: NER label types to discard during training
|
|
@ -6,6 +6,7 @@ OUTPUT_MODEL_DIR = "nlp"
|
|||
PRIOR_PROB_PATH = "prior_prob.csv"
|
||||
ENTITY_DEFS_PATH = "entity_defs.csv"
|
||||
ENTITY_FREQ_PATH = "entity_freq.csv"
|
||||
ENTITY_ALIAS_PATH = "entity_alias.csv"
|
||||
ENTITY_DESCR_PATH = "entity_descriptions.csv"
|
||||
|
||||
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
|
||||
|
|
|
@ -15,10 +15,11 @@ class Metrics(object):
|
|||
candidate_is_correct = true_entity == candidate
|
||||
|
||||
# Assume that we have no labeled negatives in the data (i.e. cases where true_entity is "NIL")
|
||||
# Therefore, if candidate_is_correct then we have a true positive and never a true negative
|
||||
# Therefore, if candidate_is_correct then we have a true positive and never a true negative.
|
||||
self.true_pos += candidate_is_correct
|
||||
self.false_neg += not candidate_is_correct
|
||||
if candidate not in {"", "NIL"}:
|
||||
if candidate and candidate not in {"", "NIL"}:
|
||||
# A wrong prediction (e.g. Q42 != Q3) counts both as a FP as well as a FN.
|
||||
self.false_pos += not candidate_is_correct
|
||||
|
||||
def calculate_precision(self):
|
||||
|
@ -33,6 +34,14 @@ class Metrics(object):
|
|||
else:
|
||||
return self.true_pos / (self.true_pos + self.false_neg)
|
||||
|
||||
def calculate_fscore(self):
|
||||
p = self.calculate_precision()
|
||||
r = self.calculate_recall()
|
||||
if p + r == 0:
|
||||
return 0.0
|
||||
else:
|
||||
return 2 * p * r / (p + r)
|
||||
|
||||
|
||||
class EvaluationResults(object):
|
||||
def __init__(self):
|
||||
|
@ -43,18 +52,20 @@ class EvaluationResults(object):
|
|||
self.metrics.update_results(true_entity, candidate)
|
||||
self.metrics_by_label[ent_label].update_results(true_entity, candidate)
|
||||
|
||||
def increment_false_negatives(self):
|
||||
self.metrics.false_neg += 1
|
||||
|
||||
def report_metrics(self, model_name):
|
||||
model_str = model_name.title()
|
||||
recall = self.metrics.calculate_recall()
|
||||
precision = self.metrics.calculate_precision()
|
||||
return ("{}: ".format(model_str) +
|
||||
"Recall = {} | ".format(round(recall, 3)) +
|
||||
"Precision = {} | ".format(round(precision, 3)) +
|
||||
"Precision by label = {}".format({k: v.calculate_precision()
|
||||
for k, v in self.metrics_by_label.items()}))
|
||||
fscore = self.metrics.calculate_fscore()
|
||||
return (
|
||||
"{}: ".format(model_str)
|
||||
+ "F-score = {} | ".format(round(fscore, 3))
|
||||
+ "Recall = {} | ".format(round(recall, 3))
|
||||
+ "Precision = {} | ".format(round(precision, 3))
|
||||
+ "F-score by label = {}".format(
|
||||
{k: v.calculate_fscore() for k, v in sorted(self.metrics_by_label.items())}
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class BaselineResults(object):
|
||||
|
@ -63,40 +74,51 @@ class BaselineResults(object):
|
|||
self.prior = EvaluationResults()
|
||||
self.oracle = EvaluationResults()
|
||||
|
||||
def report_accuracy(self, model):
|
||||
def report_performance(self, model):
|
||||
results = getattr(self, model)
|
||||
return results.report_metrics(model)
|
||||
|
||||
def update_baselines(self, true_entity, ent_label, random_candidate, prior_candidate, oracle_candidate):
|
||||
def update_baselines(
|
||||
self,
|
||||
true_entity,
|
||||
ent_label,
|
||||
random_candidate,
|
||||
prior_candidate,
|
||||
oracle_candidate,
|
||||
):
|
||||
self.oracle.update_metrics(ent_label, true_entity, oracle_candidate)
|
||||
self.prior.update_metrics(ent_label, true_entity, prior_candidate)
|
||||
self.random.update_metrics(ent_label, true_entity, random_candidate)
|
||||
|
||||
|
||||
def measure_performance(dev_data, kb, el_pipe):
|
||||
baseline_accuracies = measure_baselines(
|
||||
dev_data, kb
|
||||
)
|
||||
def measure_performance(dev_data, kb, el_pipe, baseline=True, context=True):
|
||||
if baseline:
|
||||
baseline_accuracies, counts = measure_baselines(dev_data, kb)
|
||||
logger.info("Counts: {}".format({k: v for k, v in sorted(counts.items())}))
|
||||
logger.info(baseline_accuracies.report_performance("random"))
|
||||
logger.info(baseline_accuracies.report_performance("prior"))
|
||||
logger.info(baseline_accuracies.report_performance("oracle"))
|
||||
|
||||
logger.info(baseline_accuracies.report_accuracy("random"))
|
||||
logger.info(baseline_accuracies.report_accuracy("prior"))
|
||||
logger.info(baseline_accuracies.report_accuracy("oracle"))
|
||||
if context:
|
||||
# using only context
|
||||
el_pipe.cfg["incl_context"] = True
|
||||
el_pipe.cfg["incl_prior"] = False
|
||||
results = get_eval_results(dev_data, el_pipe)
|
||||
logger.info(results.report_metrics("context only"))
|
||||
|
||||
# using only context
|
||||
el_pipe.cfg["incl_context"] = True
|
||||
el_pipe.cfg["incl_prior"] = False
|
||||
results = get_eval_results(dev_data, el_pipe)
|
||||
logger.info(results.report_metrics("context only"))
|
||||
|
||||
# measuring combined accuracy (prior + context)
|
||||
el_pipe.cfg["incl_context"] = True
|
||||
el_pipe.cfg["incl_prior"] = True
|
||||
results = get_eval_results(dev_data, el_pipe)
|
||||
logger.info(results.report_metrics("context and prior"))
|
||||
# measuring combined accuracy (prior + context)
|
||||
el_pipe.cfg["incl_context"] = True
|
||||
el_pipe.cfg["incl_prior"] = True
|
||||
results = get_eval_results(dev_data, el_pipe)
|
||||
logger.info(results.report_metrics("context and prior"))
|
||||
|
||||
|
||||
def get_eval_results(data, el_pipe=None):
|
||||
# If the docs in the data require further processing with an entity linker, set el_pipe
|
||||
"""
|
||||
Evaluate the ent.kb_id_ annotations against the gold standard.
|
||||
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
|
||||
If the docs in the data require further processing with an entity linker, set el_pipe.
|
||||
"""
|
||||
from tqdm import tqdm
|
||||
|
||||
docs = []
|
||||
|
@ -111,18 +133,15 @@ def get_eval_results(data, el_pipe=None):
|
|||
|
||||
results = EvaluationResults()
|
||||
for doc, gold in zip(docs, golds):
|
||||
tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
|
||||
try:
|
||||
correct_entries_per_article = dict()
|
||||
for entity, kb_dict in gold.links.items():
|
||||
start, end = entity
|
||||
# only evaluating on positive examples
|
||||
for gold_kb, value in kb_dict.items():
|
||||
if value:
|
||||
# only evaluating on positive examples
|
||||
offset = _offset(start, end)
|
||||
correct_entries_per_article[offset] = gold_kb
|
||||
if offset not in tagged_entries_per_article:
|
||||
results.increment_false_negatives()
|
||||
|
||||
for ent in doc.ents:
|
||||
ent_label = ent.label_
|
||||
|
@ -142,7 +161,11 @@ def get_eval_results(data, el_pipe=None):
|
|||
|
||||
|
||||
def measure_baselines(data, kb):
|
||||
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
|
||||
"""
|
||||
Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound.
|
||||
Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
|
||||
Also return a dictionary of counts by entity label.
|
||||
"""
|
||||
counts_d = dict()
|
||||
|
||||
baseline_results = BaselineResults()
|
||||
|
@ -152,7 +175,6 @@ def measure_baselines(data, kb):
|
|||
|
||||
for doc, gold in zip(docs, golds):
|
||||
correct_entries_per_article = dict()
|
||||
tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
|
||||
for entity, kb_dict in gold.links.items():
|
||||
start, end = entity
|
||||
for gold_kb, value in kb_dict.items():
|
||||
|
@ -160,10 +182,6 @@ def measure_baselines(data, kb):
|
|||
if value:
|
||||
offset = _offset(start, end)
|
||||
correct_entries_per_article[offset] = gold_kb
|
||||
if offset not in tagged_entries_per_article:
|
||||
baseline_results.random.increment_false_negatives()
|
||||
baseline_results.oracle.increment_false_negatives()
|
||||
baseline_results.prior.increment_false_negatives()
|
||||
|
||||
for ent in doc.ents:
|
||||
ent_label = ent.label_
|
||||
|
@ -176,7 +194,7 @@ def measure_baselines(data, kb):
|
|||
if gold_entity is not None:
|
||||
candidates = kb.get_candidates(ent.text)
|
||||
oracle_candidate = ""
|
||||
best_candidate = ""
|
||||
prior_candidate = ""
|
||||
random_candidate = ""
|
||||
if candidates:
|
||||
scores = []
|
||||
|
@ -187,13 +205,21 @@ def measure_baselines(data, kb):
|
|||
oracle_candidate = c.entity_
|
||||
|
||||
best_index = scores.index(max(scores))
|
||||
best_candidate = candidates[best_index].entity_
|
||||
prior_candidate = candidates[best_index].entity_
|
||||
random_candidate = random.choice(candidates).entity_
|
||||
|
||||
baseline_results.update_baselines(gold_entity, ent_label,
|
||||
random_candidate, best_candidate, oracle_candidate)
|
||||
current_count = counts_d.get(ent_label, 0)
|
||||
counts_d[ent_label] = current_count+1
|
||||
|
||||
return baseline_results
|
||||
baseline_results.update_baselines(
|
||||
gold_entity,
|
||||
ent_label,
|
||||
random_candidate,
|
||||
prior_candidate,
|
||||
oracle_candidate,
|
||||
)
|
||||
|
||||
return baseline_results, counts_d
|
||||
|
||||
|
||||
def _offset(start, end):
|
||||
|
|
|
@ -1,17 +1,12 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import csv
|
||||
import logging
|
||||
import spacy
|
||||
import sys
|
||||
|
||||
from spacy.kb import KnowledgeBase
|
||||
|
||||
from bin.wiki_entity_linking import wikipedia_processor as wp
|
||||
from bin.wiki_entity_linking.train_descriptions import EntityEncoder
|
||||
|
||||
csv.field_size_limit(sys.maxsize)
|
||||
from bin.wiki_entity_linking import wiki_io as io
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
@ -22,18 +17,24 @@ def create_kb(
|
|||
max_entities_per_alias,
|
||||
min_entity_freq,
|
||||
min_occ,
|
||||
entity_def_input,
|
||||
entity_def_path,
|
||||
entity_descr_path,
|
||||
count_input,
|
||||
prior_prob_input,
|
||||
entity_alias_path,
|
||||
entity_freq_path,
|
||||
prior_prob_path,
|
||||
entity_vector_length,
|
||||
):
|
||||
# Create the knowledge base from Wikidata entries
|
||||
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=entity_vector_length)
|
||||
entity_list, filtered_title_to_id = _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_freq, entity_freq_path, entity_vector_length)
|
||||
_define_aliases(kb, entity_alias_path, entity_list, filtered_title_to_id, max_entities_per_alias, min_occ, prior_prob_path)
|
||||
return kb
|
||||
|
||||
|
||||
def _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_freq, entity_freq_path, entity_vector_length):
|
||||
# read the mappings from file
|
||||
title_to_id = get_entity_to_id(entity_def_input)
|
||||
id_to_descr = get_id_to_description(entity_descr_path)
|
||||
title_to_id = io.read_title_to_id(entity_def_path)
|
||||
id_to_descr = io.read_id_to_descr(entity_descr_path)
|
||||
|
||||
# check the length of the nlp vectors
|
||||
if "vectors" in nlp.meta and nlp.vocab.vectors.size:
|
||||
|
@ -45,10 +46,8 @@ def create_kb(
|
|||
" cf. https://spacy.io/usage/models#languages."
|
||||
)
|
||||
|
||||
logger.info("Get entity frequencies")
|
||||
entity_frequencies = wp.get_all_frequencies(count_input=count_input)
|
||||
|
||||
logger.info("Filtering entities with fewer than {} mentions".format(min_entity_freq))
|
||||
entity_frequencies = io.read_entity_to_count(entity_freq_path)
|
||||
# filter the entities for in the KB by frequency, because there's just too much data (8M entities) otherwise
|
||||
filtered_title_to_id, entity_list, description_list, frequency_list = get_filtered_entities(
|
||||
title_to_id,
|
||||
|
@ -56,36 +55,33 @@ def create_kb(
|
|||
entity_frequencies,
|
||||
min_entity_freq
|
||||
)
|
||||
logger.info("Left with {} entities".format(len(description_list)))
|
||||
logger.info("Kept {} entities from the set of {}".format(len(description_list), len(title_to_id.keys())))
|
||||
|
||||
logger.info("Train entity encoder")
|
||||
logger.info("Training entity encoder")
|
||||
encoder = EntityEncoder(nlp, input_dim, entity_vector_length)
|
||||
encoder.train(description_list=description_list, to_print=True)
|
||||
|
||||
logger.info("Get entity embeddings:")
|
||||
logger.info("Getting entity embeddings")
|
||||
embeddings = encoder.apply_encoder(description_list)
|
||||
|
||||
logger.info("Adding {} entities".format(len(entity_list)))
|
||||
kb.set_entities(
|
||||
entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
|
||||
)
|
||||
return entity_list, filtered_title_to_id
|
||||
|
||||
logger.info("Adding aliases")
|
||||
|
||||
def _define_aliases(kb, entity_alias_path, entity_list, filtered_title_to_id, max_entities_per_alias, min_occ, prior_prob_path):
|
||||
logger.info("Adding aliases from Wikipedia and Wikidata")
|
||||
_add_aliases(
|
||||
kb,
|
||||
entity_list=entity_list,
|
||||
title_to_id=filtered_title_to_id,
|
||||
max_entities_per_alias=max_entities_per_alias,
|
||||
min_occ=min_occ,
|
||||
prior_prob_input=prior_prob_input,
|
||||
prior_prob_path=prior_prob_path,
|
||||
)
|
||||
|
||||
logger.info("KB size: {} entities, {} aliases".format(
|
||||
kb.get_size_entities(),
|
||||
kb.get_size_aliases()))
|
||||
|
||||
logger.info("Done with kb")
|
||||
return kb
|
||||
|
||||
|
||||
def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies,
|
||||
min_entity_freq: int = 10):
|
||||
|
@ -104,34 +100,13 @@ def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies,
|
|||
return filtered_title_to_id, entity_list, description_list, frequency_list
|
||||
|
||||
|
||||
def get_entity_to_id(entity_def_output):
|
||||
entity_to_id = dict()
|
||||
with entity_def_output.open("r", encoding="utf8") as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter="|")
|
||||
# skip header
|
||||
next(csvreader)
|
||||
for row in csvreader:
|
||||
entity_to_id[row[0]] = row[1]
|
||||
return entity_to_id
|
||||
|
||||
|
||||
def get_id_to_description(entity_descr_path):
|
||||
id_to_desc = dict()
|
||||
with entity_descr_path.open("r", encoding="utf8") as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter="|")
|
||||
# skip header
|
||||
next(csvreader)
|
||||
for row in csvreader:
|
||||
id_to_desc[row[0]] = row[1]
|
||||
return id_to_desc
|
||||
|
||||
|
||||
def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input):
|
||||
def _add_aliases(kb, entity_list, title_to_id, max_entities_per_alias, min_occ, prior_prob_path):
|
||||
wp_titles = title_to_id.keys()
|
||||
|
||||
# adding aliases with prior probabilities
|
||||
# we can read this file sequentially, it's sorted by alias, and then by count
|
||||
with prior_prob_input.open("r", encoding="utf8") as prior_file:
|
||||
logger.info("Adding WP aliases")
|
||||
with prior_prob_path.open("r", encoding="utf8") as prior_file:
|
||||
# skip header
|
||||
prior_file.readline()
|
||||
line = prior_file.readline()
|
||||
|
@ -180,10 +155,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
|
|||
line = prior_file.readline()
|
||||
|
||||
|
||||
def read_nlp_kb(model_dir, kb_file):
|
||||
nlp = spacy.load(model_dir)
|
||||
def read_kb(nlp, kb_file):
|
||||
kb = KnowledgeBase(vocab=nlp.vocab)
|
||||
kb.load_bulk(kb_file)
|
||||
logger.info("kb entities: {}".format(kb.get_size_entities()))
|
||||
logger.info("kb aliases: {}".format(kb.get_size_aliases()))
|
||||
return nlp, kb
|
||||
return kb
|
||||
|
|
|
@ -53,7 +53,7 @@ class EntityEncoder:
|
|||
|
||||
start = start + batch_size
|
||||
stop = min(stop + batch_size, len(description_list))
|
||||
logger.info("encoded: {} entities".format(stop))
|
||||
logger.info("Encoded: {} entities".format(stop))
|
||||
|
||||
return encodings
|
||||
|
||||
|
@ -62,7 +62,7 @@ class EntityEncoder:
|
|||
if to_print:
|
||||
logger.info(
|
||||
"Trained entity descriptions on {} ".format(processed) +
|
||||
"(non-unique) entities across {} ".format(self.epochs) +
|
||||
"(non-unique) descriptions across {} ".format(self.epochs) +
|
||||
"epochs"
|
||||
)
|
||||
logger.info("Final loss: {}".format(loss))
|
||||
|
|
|
@ -1,395 +0,0 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import logging
|
||||
import random
|
||||
import re
|
||||
import bz2
|
||||
import json
|
||||
|
||||
from functools import partial
|
||||
|
||||
from spacy.gold import GoldParse
|
||||
from bin.wiki_entity_linking import kb_creator
|
||||
|
||||
"""
|
||||
Process Wikipedia interlinks to generate a training dataset for the EL algorithm.
|
||||
Gold-standard entities are stored in one file in standoff format (by character offset).
|
||||
"""
|
||||
|
||||
ENTITY_FILE = "gold_entities.csv"
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_training_examples_and_descriptions(wikipedia_input,
|
||||
entity_def_input,
|
||||
description_output,
|
||||
training_output,
|
||||
parse_descriptions,
|
||||
limit=None):
|
||||
wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
|
||||
_process_wikipedia_texts(wikipedia_input,
|
||||
wp_to_id,
|
||||
description_output,
|
||||
training_output,
|
||||
parse_descriptions,
|
||||
limit)
|
||||
|
||||
|
||||
def _process_wikipedia_texts(wikipedia_input,
|
||||
wp_to_id,
|
||||
output,
|
||||
training_output,
|
||||
parse_descriptions,
|
||||
limit=None):
|
||||
"""
|
||||
Read the XML wikipedia data to parse out training data:
|
||||
raw text data + positive instances
|
||||
"""
|
||||
title_regex = re.compile(r"(?<=<title>).*(?=</title>)")
|
||||
id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
|
||||
|
||||
read_ids = set()
|
||||
|
||||
with output.open("a", encoding="utf8") as descr_file, training_output.open("w", encoding="utf8") as entity_file:
|
||||
if parse_descriptions:
|
||||
_write_training_description(descr_file, "WD_id", "description")
|
||||
with bz2.open(wikipedia_input, mode="rb") as file:
|
||||
article_count = 0
|
||||
article_text = ""
|
||||
article_title = None
|
||||
article_id = None
|
||||
reading_text = False
|
||||
reading_revision = False
|
||||
|
||||
logger.info("Processed {} articles".format(article_count))
|
||||
|
||||
for line in file:
|
||||
clean_line = line.strip().decode("utf-8")
|
||||
|
||||
if clean_line == "<revision>":
|
||||
reading_revision = True
|
||||
elif clean_line == "</revision>":
|
||||
reading_revision = False
|
||||
|
||||
# Start reading new page
|
||||
if clean_line == "<page>":
|
||||
article_text = ""
|
||||
article_title = None
|
||||
article_id = None
|
||||
# finished reading this page
|
||||
elif clean_line == "</page>":
|
||||
if article_id:
|
||||
clean_text, entities = _process_wp_text(
|
||||
article_title,
|
||||
article_text,
|
||||
wp_to_id
|
||||
)
|
||||
if clean_text is not None and entities is not None:
|
||||
_write_training_entities(entity_file,
|
||||
article_id,
|
||||
clean_text,
|
||||
entities)
|
||||
|
||||
if article_title in wp_to_id and parse_descriptions:
|
||||
description = " ".join(clean_text[:1000].split(" ")[:-1])
|
||||
_write_training_description(
|
||||
descr_file,
|
||||
wp_to_id[article_title],
|
||||
description
|
||||
)
|
||||
article_count += 1
|
||||
if article_count % 10000 == 0:
|
||||
logger.info("Processed {} articles".format(article_count))
|
||||
if limit and article_count >= limit:
|
||||
break
|
||||
article_text = ""
|
||||
article_title = None
|
||||
article_id = None
|
||||
reading_text = False
|
||||
reading_revision = False
|
||||
|
||||
# start reading text within a page
|
||||
if "<text" in clean_line:
|
||||
reading_text = True
|
||||
|
||||
if reading_text:
|
||||
article_text += " " + clean_line
|
||||
|
||||
# stop reading text within a page (we assume a new page doesn't start on the same line)
|
||||
if "</text" in clean_line:
|
||||
reading_text = False
|
||||
|
||||
# read the ID of this article (outside the revision portion of the document)
|
||||
if not reading_revision:
|
||||
ids = id_regex.search(clean_line)
|
||||
if ids:
|
||||
article_id = ids[0]
|
||||
if article_id in read_ids:
|
||||
logger.info(
|
||||
"Found duplicate article ID", article_id, clean_line
|
||||
) # This should never happen ...
|
||||
read_ids.add(article_id)
|
||||
|
||||
# read the title of this article (outside the revision portion of the document)
|
||||
if not reading_revision:
|
||||
titles = title_regex.search(clean_line)
|
||||
if titles:
|
||||
article_title = titles[0].strip()
|
||||
logger.info("Finished. Processed {} articles".format(article_count))
|
||||
|
||||
|
||||
text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")
|
||||
info_regex = re.compile(r"{[^{]*?}")
|
||||
htlm_regex = re.compile(r"<!--[^-]*-->")
|
||||
category_regex = re.compile(r"\[\[Category:[^\[]*]]")
|
||||
file_regex = re.compile(r"\[\[File:[^[\]]+]]")
|
||||
ref_regex = re.compile(r"<ref.*?>") # non-greedy
|
||||
ref_2_regex = re.compile(r"</ref.*?>") # non-greedy
|
||||
|
||||
|
||||
def _process_wp_text(article_title, article_text, wp_to_id):
|
||||
# ignore meta Wikipedia pages
|
||||
if (
|
||||
article_title.startswith("Wikipedia:") or
|
||||
article_title.startswith("Kategori:")
|
||||
):
|
||||
return None, None
|
||||
|
||||
# remove the text tags
|
||||
text_search = text_regex.search(article_text)
|
||||
if text_search is None:
|
||||
return None, None
|
||||
text = text_search.group(0)
|
||||
|
||||
# stop processing if this is a redirect page
|
||||
if text.startswith("#REDIRECT"):
|
||||
return None, None
|
||||
|
||||
# get the raw text without markup etc, keeping only interwiki links
|
||||
clean_text, entities = _remove_links(_get_clean_wp_text(text), wp_to_id)
|
||||
return clean_text, entities
|
||||
|
||||
|
||||
def _get_clean_wp_text(article_text):
|
||||
clean_text = article_text.strip()
|
||||
|
||||
# remove bolding & italic markup
|
||||
clean_text = clean_text.replace("'''", "")
|
||||
clean_text = clean_text.replace("''", "")
|
||||
|
||||
# remove nested {{info}} statements by removing the inner/smallest ones first and iterating
|
||||
try_again = True
|
||||
previous_length = len(clean_text)
|
||||
while try_again:
|
||||
clean_text = info_regex.sub(
|
||||
"", clean_text
|
||||
) # non-greedy match excluding a nested {
|
||||
if len(clean_text) < previous_length:
|
||||
try_again = True
|
||||
else:
|
||||
try_again = False
|
||||
previous_length = len(clean_text)
|
||||
|
||||
# remove HTML comments
|
||||
clean_text = htlm_regex.sub("", clean_text)
|
||||
|
||||
# remove Category and File statements
|
||||
clean_text = category_regex.sub("", clean_text)
|
||||
clean_text = file_regex.sub("", clean_text)
|
||||
|
||||
# remove multiple =
|
||||
while "==" in clean_text:
|
||||
clean_text = clean_text.replace("==", "=")
|
||||
|
||||
clean_text = clean_text.replace(". =", ".")
|
||||
clean_text = clean_text.replace(" = ", ". ")
|
||||
clean_text = clean_text.replace("= ", ".")
|
||||
clean_text = clean_text.replace(" =", "")
|
||||
|
||||
# remove refs (non-greedy match)
|
||||
clean_text = ref_regex.sub("", clean_text)
|
||||
clean_text = ref_2_regex.sub("", clean_text)
|
||||
|
||||
# remove additional wikiformatting
|
||||
clean_text = re.sub(r"<blockquote>", "", clean_text)
|
||||
clean_text = re.sub(r"</blockquote>", "", clean_text)
|
||||
|
||||
# change special characters back to normal ones
|
||||
clean_text = clean_text.replace(r"<", "<")
|
||||
clean_text = clean_text.replace(r">", ">")
|
||||
clean_text = clean_text.replace(r""", '"')
|
||||
clean_text = clean_text.replace(r"&nbsp;", " ")
|
||||
clean_text = clean_text.replace(r"&", "&")
|
||||
|
||||
# remove multiple spaces
|
||||
while " " in clean_text:
|
||||
clean_text = clean_text.replace(" ", " ")
|
||||
|
||||
return clean_text.strip()
|
||||
|
||||
|
||||
def _remove_links(clean_text, wp_to_id):
|
||||
# read the text char by char to get the right offsets for the interwiki links
|
||||
entities = []
|
||||
final_text = ""
|
||||
open_read = 0
|
||||
reading_text = True
|
||||
reading_entity = False
|
||||
reading_mention = False
|
||||
reading_special_case = False
|
||||
entity_buffer = ""
|
||||
mention_buffer = ""
|
||||
for index, letter in enumerate(clean_text):
|
||||
if letter == "[":
|
||||
open_read += 1
|
||||
elif letter == "]":
|
||||
open_read -= 1
|
||||
elif letter == "|":
|
||||
if reading_text:
|
||||
final_text += letter
|
||||
# switch from reading entity to mention in the [[entity|mention]] pattern
|
||||
elif reading_entity:
|
||||
reading_text = False
|
||||
reading_entity = False
|
||||
reading_mention = True
|
||||
else:
|
||||
reading_special_case = True
|
||||
else:
|
||||
if reading_entity:
|
||||
entity_buffer += letter
|
||||
elif reading_mention:
|
||||
mention_buffer += letter
|
||||
elif reading_text:
|
||||
final_text += letter
|
||||
else:
|
||||
raise ValueError("Not sure at point", clean_text[index - 2: index + 2])
|
||||
|
||||
if open_read > 2:
|
||||
reading_special_case = True
|
||||
|
||||
if open_read == 2 and reading_text:
|
||||
reading_text = False
|
||||
reading_entity = True
|
||||
reading_mention = False
|
||||
|
||||
# we just finished reading an entity
|
||||
if open_read == 0 and not reading_text:
|
||||
if "#" in entity_buffer or entity_buffer.startswith(":"):
|
||||
reading_special_case = True
|
||||
# Ignore cases with nested structures like File: handles etc
|
||||
if not reading_special_case:
|
||||
if not mention_buffer:
|
||||
mention_buffer = entity_buffer
|
||||
start = len(final_text)
|
||||
end = start + len(mention_buffer)
|
||||
qid = wp_to_id.get(entity_buffer, None)
|
||||
if qid:
|
||||
entities.append((mention_buffer, qid, start, end))
|
||||
final_text += mention_buffer
|
||||
|
||||
entity_buffer = ""
|
||||
mention_buffer = ""
|
||||
|
||||
reading_text = True
|
||||
reading_entity = False
|
||||
reading_mention = False
|
||||
reading_special_case = False
|
||||
return final_text, entities
|
||||
|
||||
|
||||
def _write_training_description(outputfile, qid, description):
|
||||
if description is not None:
|
||||
line = str(qid) + "|" + description + "\n"
|
||||
outputfile.write(line)
|
||||
|
||||
|
||||
def _write_training_entities(outputfile, article_id, clean_text, entities):
|
||||
entities_data = [{"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]} for ent in entities]
|
||||
line = json.dumps(
|
||||
{
|
||||
"article_id": article_id,
|
||||
"clean_text": clean_text,
|
||||
"entities": entities_data
|
||||
},
|
||||
ensure_ascii=False) + "\n"
|
||||
outputfile.write(line)
|
||||
|
||||
|
||||
def read_training(nlp, entity_file_path, dev, limit, kb):
|
||||
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
|
||||
For training,, it will include negative training examples by using the candidate generator,
|
||||
and it will only keep positive training examples that can be found by using the candidate generator.
|
||||
For testing, it will include all positive examples only."""
|
||||
|
||||
from tqdm import tqdm
|
||||
data = []
|
||||
num_entities = 0
|
||||
get_gold_parse = partial(_get_gold_parse, dev=dev, kb=kb)
|
||||
|
||||
logger.info("Reading {} data with limit {}".format('dev' if dev else 'train', limit))
|
||||
with entity_file_path.open("r", encoding="utf8") as file:
|
||||
with tqdm(total=limit, leave=False) as pbar:
|
||||
for i, line in enumerate(file):
|
||||
example = json.loads(line)
|
||||
article_id = example["article_id"]
|
||||
clean_text = example["clean_text"]
|
||||
entities = example["entities"]
|
||||
|
||||
if dev != is_dev(article_id) or len(clean_text) >= 30000:
|
||||
continue
|
||||
|
||||
doc = nlp(clean_text)
|
||||
gold = get_gold_parse(doc, entities)
|
||||
if gold and len(gold.links) > 0:
|
||||
data.append((doc, gold))
|
||||
num_entities += len(gold.links)
|
||||
pbar.update(len(gold.links))
|
||||
if limit and num_entities >= limit:
|
||||
break
|
||||
logger.info("Read {} entities in {} articles".format(num_entities, len(data)))
|
||||
return data
|
||||
|
||||
|
||||
def _get_gold_parse(doc, entities, dev, kb):
|
||||
gold_entities = {}
|
||||
tagged_ent_positions = set(
|
||||
[(ent.start_char, ent.end_char) for ent in doc.ents]
|
||||
)
|
||||
|
||||
for entity in entities:
|
||||
entity_id = entity["entity"]
|
||||
alias = entity["alias"]
|
||||
start = entity["start"]
|
||||
end = entity["end"]
|
||||
|
||||
candidates = kb.get_candidates(alias)
|
||||
candidate_ids = [
|
||||
c.entity_ for c in candidates
|
||||
]
|
||||
|
||||
should_add_ent = (
|
||||
dev or
|
||||
(
|
||||
(start, end) in tagged_ent_positions and
|
||||
entity_id in candidate_ids and
|
||||
len(candidates) > 1
|
||||
)
|
||||
)
|
||||
|
||||
if should_add_ent:
|
||||
value_by_id = {entity_id: 1.0}
|
||||
if not dev:
|
||||
random.shuffle(candidate_ids)
|
||||
value_by_id.update({
|
||||
kb_id: 0.0
|
||||
for kb_id in candidate_ids
|
||||
if kb_id != entity_id
|
||||
})
|
||||
gold_entities[(start, end)] = value_by_id
|
||||
|
||||
return GoldParse(doc, links=gold_entities)
|
||||
|
||||
|
||||
def is_dev(article_id):
|
||||
return article_id.endswith("3")
|
127
bin/wiki_entity_linking/wiki_io.py
Normal file
127
bin/wiki_entity_linking/wiki_io.py
Normal file
|
@ -0,0 +1,127 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import sys
|
||||
import csv
|
||||
|
||||
# min() needed to prevent error on windows, cf https://stackoverflow.com/questions/52404416/
|
||||
csv.field_size_limit(min(sys.maxsize, 2147483646))
|
||||
|
||||
""" This class provides reading/writing methods for temp files """
|
||||
|
||||
|
||||
# Entity definition: WP title -> WD ID #
|
||||
def write_title_to_id(entity_def_output, title_to_id):
|
||||
with entity_def_output.open("w", encoding="utf8") as id_file:
|
||||
id_file.write("WP_title" + "|" + "WD_id" + "\n")
|
||||
for title, qid in title_to_id.items():
|
||||
id_file.write(title + "|" + str(qid) + "\n")
|
||||
|
||||
|
||||
def read_title_to_id(entity_def_output):
|
||||
title_to_id = dict()
|
||||
with entity_def_output.open("r", encoding="utf8") as id_file:
|
||||
csvreader = csv.reader(id_file, delimiter="|")
|
||||
# skip header
|
||||
next(csvreader)
|
||||
for row in csvreader:
|
||||
title_to_id[row[0]] = row[1]
|
||||
return title_to_id
|
||||
|
||||
|
||||
# Entity aliases from WD: WD ID -> WD alias #
|
||||
def write_id_to_alias(entity_alias_path, id_to_alias):
|
||||
with entity_alias_path.open("w", encoding="utf8") as alias_file:
|
||||
alias_file.write("WD_id" + "|" + "alias" + "\n")
|
||||
for qid, alias_list in id_to_alias.items():
|
||||
for alias in alias_list:
|
||||
alias_file.write(str(qid) + "|" + alias + "\n")
|
||||
|
||||
|
||||
def read_id_to_alias(entity_alias_path):
|
||||
id_to_alias = dict()
|
||||
with entity_alias_path.open("r", encoding="utf8") as alias_file:
|
||||
csvreader = csv.reader(alias_file, delimiter="|")
|
||||
# skip header
|
||||
next(csvreader)
|
||||
for row in csvreader:
|
||||
qid = row[0]
|
||||
alias = row[1]
|
||||
alias_list = id_to_alias.get(qid, [])
|
||||
alias_list.append(alias)
|
||||
id_to_alias[qid] = alias_list
|
||||
return id_to_alias
|
||||
|
||||
|
||||
def read_alias_to_id_generator(entity_alias_path):
|
||||
""" Read (aliases, qid) tuples """
|
||||
|
||||
with entity_alias_path.open("r", encoding="utf8") as alias_file:
|
||||
csvreader = csv.reader(alias_file, delimiter="|")
|
||||
# skip header
|
||||
next(csvreader)
|
||||
for row in csvreader:
|
||||
qid = row[0]
|
||||
alias = row[1]
|
||||
yield alias, qid
|
||||
|
||||
|
||||
# Entity descriptions from WD: WD ID -> WD alias #
|
||||
def write_id_to_descr(entity_descr_output, id_to_descr):
|
||||
with entity_descr_output.open("w", encoding="utf8") as descr_file:
|
||||
descr_file.write("WD_id" + "|" + "description" + "\n")
|
||||
for qid, descr in id_to_descr.items():
|
||||
descr_file.write(str(qid) + "|" + descr + "\n")
|
||||
|
||||
|
||||
def read_id_to_descr(entity_desc_path):
|
||||
id_to_desc = dict()
|
||||
with entity_desc_path.open("r", encoding="utf8") as descr_file:
|
||||
csvreader = csv.reader(descr_file, delimiter="|")
|
||||
# skip header
|
||||
next(csvreader)
|
||||
for row in csvreader:
|
||||
id_to_desc[row[0]] = row[1]
|
||||
return id_to_desc
|
||||
|
||||
|
||||
# Entity counts from WP: WP title -> count #
|
||||
def write_entity_to_count(prior_prob_input, count_output):
|
||||
# Write entity counts for quick access later
|
||||
entity_to_count = dict()
|
||||
total_count = 0
|
||||
|
||||
with prior_prob_input.open("r", encoding="utf8") as prior_file:
|
||||
# skip header
|
||||
prior_file.readline()
|
||||
line = prior_file.readline()
|
||||
|
||||
while line:
|
||||
splits = line.replace("\n", "").split(sep="|")
|
||||
# alias = splits[0]
|
||||
count = int(splits[1])
|
||||
entity = splits[2]
|
||||
|
||||
current_count = entity_to_count.get(entity, 0)
|
||||
entity_to_count[entity] = current_count + count
|
||||
|
||||
total_count += count
|
||||
|
||||
line = prior_file.readline()
|
||||
|
||||
with count_output.open("w", encoding="utf8") as entity_file:
|
||||
entity_file.write("entity" + "|" + "count" + "\n")
|
||||
for entity, count in entity_to_count.items():
|
||||
entity_file.write(entity + "|" + str(count) + "\n")
|
||||
|
||||
|
||||
def read_entity_to_count(count_input):
|
||||
entity_to_count = dict()
|
||||
with count_input.open("r", encoding="utf8") as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter="|")
|
||||
# skip header
|
||||
next(csvreader)
|
||||
for row in csvreader:
|
||||
entity_to_count[row[0]] = int(row[1])
|
||||
|
||||
return entity_to_count
|
128
bin/wiki_entity_linking/wiki_namespaces.py
Normal file
128
bin/wiki_entity_linking/wiki_namespaces.py
Normal file
|
@ -0,0 +1,128 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
# List of meta pages in Wikidata, should be kept out of the Knowledge base
|
||||
WD_META_ITEMS = [
|
||||
"Q163875",
|
||||
"Q191780",
|
||||
"Q224414",
|
||||
"Q4167836",
|
||||
"Q4167410",
|
||||
"Q4663903",
|
||||
"Q11266439",
|
||||
"Q13406463",
|
||||
"Q15407973",
|
||||
"Q18616576",
|
||||
"Q19887878",
|
||||
"Q22808320",
|
||||
"Q23894233",
|
||||
"Q33120876",
|
||||
"Q42104522",
|
||||
"Q47460393",
|
||||
"Q64875536",
|
||||
"Q66480449",
|
||||
]
|
||||
|
||||
|
||||
# TODO: add more cases from non-English WP's
|
||||
|
||||
# List of prefixes that refer to Wikipedia "file" pages
|
||||
WP_FILE_NAMESPACE = ["Bestand", "File"]
|
||||
|
||||
# List of prefixes that refer to Wikipedia "category" pages
|
||||
WP_CATEGORY_NAMESPACE = ["Kategori", "Category", "Categorie"]
|
||||
|
||||
# List of prefixes that refer to Wikipedia "meta" pages
|
||||
# these will/should be matched ignoring case
|
||||
WP_META_NAMESPACE = (
|
||||
WP_FILE_NAMESPACE
|
||||
+ WP_CATEGORY_NAMESPACE
|
||||
+ [
|
||||
"b",
|
||||
"betawikiversity",
|
||||
"Book",
|
||||
"c",
|
||||
"Commons",
|
||||
"d",
|
||||
"dbdump",
|
||||
"download",
|
||||
"Draft",
|
||||
"Education",
|
||||
"Foundation",
|
||||
"Gadget",
|
||||
"Gadget definition",
|
||||
"Gebruiker",
|
||||
"gerrit",
|
||||
"Help",
|
||||
"Image",
|
||||
"Incubator",
|
||||
"m",
|
||||
"mail",
|
||||
"mailarchive",
|
||||
"media",
|
||||
"MediaWiki",
|
||||
"MediaWiki talk",
|
||||
"Mediawikiwiki",
|
||||
"MediaZilla",
|
||||
"Meta",
|
||||
"Metawikipedia",
|
||||
"Module",
|
||||
"mw",
|
||||
"n",
|
||||
"nost",
|
||||
"oldwikisource",
|
||||
"otrs",
|
||||
"OTRSwiki",
|
||||
"Overleg gebruiker",
|
||||
"outreach",
|
||||
"outreachwiki",
|
||||
"Portal",
|
||||
"phab",
|
||||
"Phabricator",
|
||||
"Project",
|
||||
"q",
|
||||
"quality",
|
||||
"rev",
|
||||
"s",
|
||||
"spcom",
|
||||
"Special",
|
||||
"species",
|
||||
"Strategy",
|
||||
"sulutil",
|
||||
"svn",
|
||||
"Talk",
|
||||
"Template",
|
||||
"Template talk",
|
||||
"Testwiki",
|
||||
"ticket",
|
||||
"TimedText",
|
||||
"Toollabs",
|
||||
"tools",
|
||||
"tswiki",
|
||||
"User",
|
||||
"User talk",
|
||||
"v",
|
||||
"voy",
|
||||
"w",
|
||||
"Wikibooks",
|
||||
"Wikidata",
|
||||
"wikiHow",
|
||||
"Wikinvest",
|
||||
"wikilivres",
|
||||
"Wikimedia",
|
||||
"Wikinews",
|
||||
"Wikipedia",
|
||||
"Wikipedia talk",
|
||||
"Wikiquote",
|
||||
"Wikisource",
|
||||
"Wikispecies",
|
||||
"Wikitech",
|
||||
"Wikiversity",
|
||||
"Wikivoyage",
|
||||
"wikt",
|
||||
"wiktionary",
|
||||
"wmf",
|
||||
"wmania",
|
||||
"WP",
|
||||
]
|
||||
)
|
|
@ -18,11 +18,12 @@ from pathlib import Path
|
|||
import plac
|
||||
|
||||
from bin.wiki_entity_linking import wikipedia_processor as wp, wikidata_processor as wd
|
||||
from bin.wiki_entity_linking import wiki_io as io
|
||||
from bin.wiki_entity_linking import kb_creator
|
||||
from bin.wiki_entity_linking import training_set_creator
|
||||
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_FILE, ENTITY_DESCR_PATH, KB_MODEL_DIR, LOG_FORMAT
|
||||
from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH
|
||||
from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH, ENTITY_ALIAS_PATH
|
||||
import spacy
|
||||
from bin.wiki_entity_linking.kb_creator import read_kb
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -39,9 +40,11 @@ logger = logging.getLogger(__name__)
|
|||
loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path),
|
||||
loc_entity_defs=("Location to file with entity definitions", "option", "d", Path),
|
||||
loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path),
|
||||
descriptions_from_wikipedia=("Flag for using wp descriptions not wd", "flag", "wp"),
|
||||
limit=("Optional threshold to limit lines read from dumps", "option", "l", int),
|
||||
lang=("Optional language for which to get wikidata titles. Defaults to 'en'", "option", "la", str),
|
||||
descr_from_wp=("Flag for using wp descriptions not wd", "flag", "wp"),
|
||||
limit_prior=("Threshold to limit lines read from WP for prior probabilities", "option", "lp", int),
|
||||
limit_train=("Threshold to limit lines read from WP for training set", "option", "lt", int),
|
||||
limit_wd=("Threshold to limit lines read from WD", "option", "lw", int),
|
||||
lang=("Optional language for which to get Wikidata titles. Defaults to 'en'", "option", "la", str),
|
||||
)
|
||||
def main(
|
||||
wd_json,
|
||||
|
@ -54,13 +57,16 @@ def main(
|
|||
entity_vector_length=64,
|
||||
loc_prior_prob=None,
|
||||
loc_entity_defs=None,
|
||||
loc_entity_alias=None,
|
||||
loc_entity_desc=None,
|
||||
descriptions_from_wikipedia=False,
|
||||
limit=None,
|
||||
descr_from_wp=False,
|
||||
limit_prior=None,
|
||||
limit_train=None,
|
||||
limit_wd=None,
|
||||
lang="en",
|
||||
):
|
||||
|
||||
entity_defs_path = loc_entity_defs if loc_entity_defs else output_dir / ENTITY_DEFS_PATH
|
||||
entity_alias_path = loc_entity_alias if loc_entity_alias else output_dir / ENTITY_ALIAS_PATH
|
||||
entity_descr_path = loc_entity_desc if loc_entity_desc else output_dir / ENTITY_DESCR_PATH
|
||||
entity_freq_path = output_dir / ENTITY_FREQ_PATH
|
||||
prior_prob_path = loc_prior_prob if loc_prior_prob else output_dir / PRIOR_PROB_PATH
|
||||
|
@ -69,15 +75,12 @@ def main(
|
|||
|
||||
logger.info("Creating KB with Wikipedia and WikiData")
|
||||
|
||||
if limit is not None:
|
||||
logger.warning("Warning: reading only {} lines of Wikipedia/Wikidata dumps.".format(limit))
|
||||
|
||||
# STEP 0: set up IO
|
||||
if not output_dir.exists():
|
||||
output_dir.mkdir(parents=True)
|
||||
|
||||
# STEP 1: create the NLP object
|
||||
logger.info("STEP 1: Loading model {}".format(model))
|
||||
# STEP 1: Load the NLP object
|
||||
logger.info("STEP 1: Loading NLP model {}".format(model))
|
||||
nlp = spacy.load(model)
|
||||
|
||||
# check the length of the nlp vectors
|
||||
|
@ -90,62 +93,83 @@ def main(
|
|||
# STEP 2: create prior probabilities from WP
|
||||
if not prior_prob_path.exists():
|
||||
# It takes about 2h to process 1000M lines of Wikipedia XML dump
|
||||
logger.info("STEP 2: writing prior probabilities to {}".format(prior_prob_path))
|
||||
wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit)
|
||||
logger.info("STEP 2: reading prior probabilities from {}".format(prior_prob_path))
|
||||
logger.info("STEP 2: Writing prior probabilities to {}".format(prior_prob_path))
|
||||
if limit_prior is not None:
|
||||
logger.warning("Warning: reading only {} lines of Wikipedia dump".format(limit_prior))
|
||||
wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit_prior)
|
||||
else:
|
||||
logger.info("STEP 2: Reading prior probabilities from {}".format(prior_prob_path))
|
||||
|
||||
# STEP 3: deduce entity frequencies from WP (takes only a few minutes)
|
||||
logger.info("STEP 3: calculating entity frequencies")
|
||||
wp.write_entity_counts(prior_prob_path, entity_freq_path, to_print=False)
|
||||
# STEP 3: calculate entity frequencies
|
||||
if not entity_freq_path.exists():
|
||||
logger.info("STEP 3: Calculating and writing entity frequencies to {}".format(entity_freq_path))
|
||||
io.write_entity_to_count(prior_prob_path, entity_freq_path)
|
||||
else:
|
||||
logger.info("STEP 3: Reading entity frequencies from {}".format(entity_freq_path))
|
||||
|
||||
# STEP 4: reading definitions and (possibly) descriptions from WikiData or from file
|
||||
message = " and descriptions" if not descriptions_from_wikipedia else ""
|
||||
if (not entity_defs_path.exists()) or (not descriptions_from_wikipedia and not entity_descr_path.exists()):
|
||||
if (not entity_defs_path.exists()) or (not descr_from_wp and not entity_descr_path.exists()):
|
||||
# It takes about 10h to process 55M lines of Wikidata JSON dump
|
||||
logger.info("STEP 4: parsing wikidata for entity definitions" + message)
|
||||
title_to_id, id_to_descr = wd.read_wikidata_entities_json(
|
||||
logger.info("STEP 4: Parsing and writing Wikidata entity definitions to {}".format(entity_defs_path))
|
||||
if limit_wd is not None:
|
||||
logger.warning("Warning: reading only {} lines of Wikidata dump".format(limit_wd))
|
||||
title_to_id, id_to_descr, id_to_alias = wd.read_wikidata_entities_json(
|
||||
wd_json,
|
||||
limit,
|
||||
limit_wd,
|
||||
to_print=False,
|
||||
lang=lang,
|
||||
parse_descriptions=(not descriptions_from_wikipedia),
|
||||
parse_descr=(not descr_from_wp),
|
||||
)
|
||||
wd.write_entity_files(entity_defs_path, title_to_id)
|
||||
if not descriptions_from_wikipedia:
|
||||
wd.write_entity_description_files(entity_descr_path, id_to_descr)
|
||||
logger.info("STEP 4: read entity definitions" + message)
|
||||
io.write_title_to_id(entity_defs_path, title_to_id)
|
||||
|
||||
# STEP 5: Getting gold entities from wikipedia
|
||||
message = " and descriptions" if descriptions_from_wikipedia else ""
|
||||
if (not training_entities_path.exists()) or (descriptions_from_wikipedia and not entity_descr_path.exists()):
|
||||
logger.info("STEP 5: parsing wikipedia for gold entities" + message)
|
||||
training_set_creator.create_training_examples_and_descriptions(
|
||||
wp_xml,
|
||||
entity_defs_path,
|
||||
entity_descr_path,
|
||||
training_entities_path,
|
||||
parse_descriptions=descriptions_from_wikipedia,
|
||||
limit=limit,
|
||||
)
|
||||
logger.info("STEP 5: read gold entities" + message)
|
||||
logger.info("STEP 4b: Writing Wikidata entity aliases to {}".format(entity_alias_path))
|
||||
io.write_id_to_alias(entity_alias_path, id_to_alias)
|
||||
|
||||
if not descr_from_wp:
|
||||
logger.info("STEP 4c: Writing Wikidata entity descriptions to {}".format(entity_descr_path))
|
||||
io.write_id_to_descr(entity_descr_path, id_to_descr)
|
||||
else:
|
||||
logger.info("STEP 4: Reading entity definitions from {}".format(entity_defs_path))
|
||||
logger.info("STEP 4b: Reading entity aliases from {}".format(entity_alias_path))
|
||||
if not descr_from_wp:
|
||||
logger.info("STEP 4c: Reading entity descriptions from {}".format(entity_descr_path))
|
||||
|
||||
# STEP 5: Getting gold entities from Wikipedia
|
||||
if (not training_entities_path.exists()) or (descr_from_wp and not entity_descr_path.exists()):
|
||||
logger.info("STEP 5: Parsing and writing Wikipedia gold entities to {}".format(training_entities_path))
|
||||
if limit_train is not None:
|
||||
logger.warning("Warning: reading only {} lines of Wikipedia dump".format(limit_train))
|
||||
wp.create_training_and_desc(wp_xml, entity_defs_path, entity_descr_path,
|
||||
training_entities_path, descr_from_wp, limit_train)
|
||||
if descr_from_wp:
|
||||
logger.info("STEP 5b: Parsing and writing Wikipedia descriptions to {}".format(entity_descr_path))
|
||||
else:
|
||||
logger.info("STEP 5: Reading gold entities from {}".format(training_entities_path))
|
||||
if descr_from_wp:
|
||||
logger.info("STEP 5b: Reading entity descriptions from {}".format(entity_descr_path))
|
||||
|
||||
# STEP 6: creating the actual KB
|
||||
# It takes ca. 30 minutes to pretrain the entity embeddings
|
||||
logger.info("STEP 6: creating the KB at {}".format(kb_path))
|
||||
kb = kb_creator.create_kb(
|
||||
nlp=nlp,
|
||||
max_entities_per_alias=max_per_alias,
|
||||
min_entity_freq=min_freq,
|
||||
min_occ=min_pair,
|
||||
entity_def_input=entity_defs_path,
|
||||
entity_descr_path=entity_descr_path,
|
||||
count_input=entity_freq_path,
|
||||
prior_prob_input=prior_prob_path,
|
||||
entity_vector_length=entity_vector_length,
|
||||
)
|
||||
|
||||
kb.dump(kb_path)
|
||||
nlp.to_disk(output_dir / KB_MODEL_DIR)
|
||||
if not kb_path.exists():
|
||||
logger.info("STEP 6: Creating the KB at {}".format(kb_path))
|
||||
kb = kb_creator.create_kb(
|
||||
nlp=nlp,
|
||||
max_entities_per_alias=max_per_alias,
|
||||
min_entity_freq=min_freq,
|
||||
min_occ=min_pair,
|
||||
entity_def_path=entity_defs_path,
|
||||
entity_descr_path=entity_descr_path,
|
||||
entity_alias_path=entity_alias_path,
|
||||
entity_freq_path=entity_freq_path,
|
||||
prior_prob_path=prior_prob_path,
|
||||
entity_vector_length=entity_vector_length,
|
||||
)
|
||||
kb.dump(kb_path)
|
||||
logger.info("kb entities: {}".format(kb.get_size_entities()))
|
||||
logger.info("kb aliases: {}".format(kb.get_size_aliases()))
|
||||
nlp.to_disk(output_dir / KB_MODEL_DIR)
|
||||
else:
|
||||
logger.info("STEP 6: KB already exists at {}".format(kb_path))
|
||||
|
||||
logger.info("Done!")
|
||||
|
||||
|
|
|
@ -1,40 +1,52 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import gzip
|
||||
import bz2
|
||||
import json
|
||||
import logging
|
||||
import datetime
|
||||
|
||||
from bin.wiki_entity_linking.wiki_namespaces import WD_META_ITEMS
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descriptions=True):
|
||||
# Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines.
|
||||
def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descr=True):
|
||||
# Read the JSON wiki data and parse out the entities. Takes about 7-10h to parse 55M lines.
|
||||
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
|
||||
|
||||
site_filter = '{}wiki'.format(lang)
|
||||
|
||||
# properties filter (currently disabled to get ALL data)
|
||||
prop_filter = dict()
|
||||
# prop_filter = {'P31': {'Q5', 'Q15632617'}} # currently defined as OR: one property suffices to be selected
|
||||
# filter: currently defined as OR: one hit suffices to be removed from further processing
|
||||
exclude_list = WD_META_ITEMS
|
||||
|
||||
# punctuation
|
||||
exclude_list.extend(["Q1383557", "Q10617810"])
|
||||
|
||||
# letters etc
|
||||
exclude_list.extend(["Q188725", "Q19776628", "Q3841820", "Q17907810", "Q9788", "Q9398093"])
|
||||
|
||||
neg_prop_filter = {
|
||||
'P31': exclude_list, # instance of
|
||||
'P279': exclude_list # subclass
|
||||
}
|
||||
|
||||
title_to_id = dict()
|
||||
id_to_descr = dict()
|
||||
id_to_alias = dict()
|
||||
|
||||
# parse appropriate fields - depending on what we need in the KB
|
||||
parse_properties = False
|
||||
parse_sitelinks = True
|
||||
parse_labels = False
|
||||
parse_aliases = False
|
||||
parse_claims = False
|
||||
parse_aliases = True
|
||||
parse_claims = True
|
||||
|
||||
with gzip.open(wikidata_file, mode='rb') as file:
|
||||
with bz2.open(wikidata_file, mode='rb') as file:
|
||||
for cnt, line in enumerate(file):
|
||||
if limit and cnt >= limit:
|
||||
break
|
||||
if cnt % 500000 == 0:
|
||||
logger.info("processed {} lines of WikiData dump".format(cnt))
|
||||
if cnt % 500000 == 0 and cnt > 0:
|
||||
logger.info("processed {} lines of WikiData JSON dump".format(cnt))
|
||||
clean_line = line.strip()
|
||||
if clean_line.endswith(b","):
|
||||
clean_line = clean_line[:-1]
|
||||
|
@ -43,13 +55,11 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
|
|||
entry_type = obj["type"]
|
||||
|
||||
if entry_type == "item":
|
||||
# filtering records on their properties (currently disabled to get ALL data)
|
||||
# keep = False
|
||||
keep = True
|
||||
|
||||
claims = obj["claims"]
|
||||
if parse_claims:
|
||||
for prop, value_set in prop_filter.items():
|
||||
for prop, value_set in neg_prop_filter.items():
|
||||
claim_property = claims.get(prop, None)
|
||||
if claim_property:
|
||||
for cp in claim_property:
|
||||
|
@ -61,7 +71,7 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
|
|||
)
|
||||
cp_rank = cp["rank"]
|
||||
if cp_rank != "deprecated" and cp_id in value_set:
|
||||
keep = True
|
||||
keep = False
|
||||
|
||||
if keep:
|
||||
unique_id = obj["id"]
|
||||
|
@ -108,7 +118,7 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
|
|||
"label (" + lang + "):", lang_label["value"]
|
||||
)
|
||||
|
||||
if found_link and parse_descriptions:
|
||||
if found_link and parse_descr:
|
||||
descriptions = obj["descriptions"]
|
||||
if descriptions:
|
||||
lang_descr = descriptions.get(lang, None)
|
||||
|
@ -130,22 +140,15 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
|
|||
print(
|
||||
"alias (" + lang + "):", item["value"]
|
||||
)
|
||||
alias_list = id_to_alias.get(unique_id, [])
|
||||
alias_list.append(item["value"])
|
||||
id_to_alias[unique_id] = alias_list
|
||||
|
||||
if to_print:
|
||||
print()
|
||||
|
||||
return title_to_id, id_to_descr
|
||||
# log final number of lines processed
|
||||
logger.info("Finished. Processed {} lines of WikiData JSON dump".format(cnt))
|
||||
return title_to_id, id_to_descr, id_to_alias
|
||||
|
||||
|
||||
def write_entity_files(entity_def_output, title_to_id):
|
||||
with entity_def_output.open("w", encoding="utf8") as id_file:
|
||||
id_file.write("WP_title" + "|" + "WD_id" + "\n")
|
||||
for title, qid in title_to_id.items():
|
||||
id_file.write(title + "|" + str(qid) + "\n")
|
||||
|
||||
|
||||
def write_entity_description_files(entity_descr_output, id_to_descr):
|
||||
with entity_descr_output.open("w", encoding="utf8") as descr_file:
|
||||
descr_file.write("WD_id" + "|" + "description" + "\n")
|
||||
for qid, descr in id_to_descr.items():
|
||||
descr_file.write(str(qid) + "|" + descr + "\n")
|
||||
|
|
|
@ -6,19 +6,19 @@ as created by the script `wikidata_create_kb`.
|
|||
|
||||
For the Wikipedia dump: get enwiki-latest-pages-articles-multistream.xml.bz2
|
||||
from https://dumps.wikimedia.org/enwiki/latest/
|
||||
|
||||
"""
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import random
|
||||
import logging
|
||||
import spacy
|
||||
from pathlib import Path
|
||||
import plac
|
||||
|
||||
from bin.wiki_entity_linking import training_set_creator
|
||||
from bin.wiki_entity_linking import wikipedia_processor
|
||||
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR
|
||||
from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance, measure_baselines
|
||||
from bin.wiki_entity_linking.kb_creator import read_nlp_kb
|
||||
from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance
|
||||
from bin.wiki_entity_linking.kb_creator import read_kb
|
||||
|
||||
from spacy.util import minibatch, compounding
|
||||
|
||||
|
@ -35,6 +35,7 @@ logger = logging.getLogger(__name__)
|
|||
l2=("L2 regularization", "option", "r", float),
|
||||
train_inst=("# training instances (default 90% of all)", "option", "t", int),
|
||||
dev_inst=("# test instances (default 10% of all)", "option", "d", int),
|
||||
labels_discard=("NER labels to discard (default None)", "option", "l", str),
|
||||
)
|
||||
def main(
|
||||
dir_kb,
|
||||
|
@ -46,13 +47,14 @@ def main(
|
|||
l2=1e-6,
|
||||
train_inst=None,
|
||||
dev_inst=None,
|
||||
labels_discard=None
|
||||
):
|
||||
logger.info("Creating Entity Linker with Wikipedia and WikiData")
|
||||
|
||||
output_dir = Path(output_dir) if output_dir else dir_kb
|
||||
training_path = loc_training if loc_training else output_dir / TRAINING_DATA_FILE
|
||||
training_path = loc_training if loc_training else dir_kb / TRAINING_DATA_FILE
|
||||
nlp_dir = dir_kb / KB_MODEL_DIR
|
||||
kb_path = output_dir / KB_FILE
|
||||
kb_path = dir_kb / KB_FILE
|
||||
nlp_output_dir = output_dir / OUTPUT_MODEL_DIR
|
||||
|
||||
# STEP 0: set up IO
|
||||
|
@ -60,38 +62,47 @@ def main(
|
|||
output_dir.mkdir()
|
||||
|
||||
# STEP 1 : load the NLP object
|
||||
logger.info("STEP 1: loading model from {}".format(nlp_dir))
|
||||
nlp, kb = read_nlp_kb(nlp_dir, kb_path)
|
||||
logger.info("STEP 1a: Loading model from {}".format(nlp_dir))
|
||||
nlp = spacy.load(nlp_dir)
|
||||
logger.info("STEP 1b: Loading KB from {}".format(kb_path))
|
||||
kb = read_kb(nlp, kb_path)
|
||||
|
||||
# check that there is a NER component in the pipeline
|
||||
if "ner" not in nlp.pipe_names:
|
||||
raise ValueError("The `nlp` object should have a pretrained `ner` component.")
|
||||
|
||||
# STEP 2: create a training dataset from WP
|
||||
logger.info("STEP 2: reading training dataset from {}".format(training_path))
|
||||
# STEP 2: read the training dataset previously created from WP
|
||||
logger.info("STEP 2: Reading training dataset from {}".format(training_path))
|
||||
|
||||
train_data = training_set_creator.read_training(
|
||||
if labels_discard:
|
||||
labels_discard = [x.strip() for x in labels_discard.split(",")]
|
||||
logger.info("Discarding {} NER types: {}".format(len(labels_discard), labels_discard))
|
||||
|
||||
train_data = wikipedia_processor.read_training(
|
||||
nlp=nlp,
|
||||
entity_file_path=training_path,
|
||||
dev=False,
|
||||
limit=train_inst,
|
||||
kb=kb,
|
||||
labels_discard=labels_discard
|
||||
)
|
||||
|
||||
# for testing, get all pos instances, whether or not they are in the kb
|
||||
dev_data = training_set_creator.read_training(
|
||||
# for testing, get all pos instances (independently of KB)
|
||||
dev_data = wikipedia_processor.read_training(
|
||||
nlp=nlp,
|
||||
entity_file_path=training_path,
|
||||
dev=True,
|
||||
limit=dev_inst,
|
||||
kb=kb,
|
||||
kb=None,
|
||||
labels_discard=labels_discard
|
||||
)
|
||||
|
||||
# STEP 3: create and train the entity linking pipe
|
||||
logger.info("STEP 3: training Entity Linking pipe")
|
||||
# STEP 3: create and train an entity linking pipe
|
||||
logger.info("STEP 3: Creating and training an Entity Linking pipe")
|
||||
|
||||
el_pipe = nlp.create_pipe(
|
||||
name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name}
|
||||
name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name,
|
||||
"labels_discard": labels_discard}
|
||||
)
|
||||
el_pipe.set_kb(kb)
|
||||
nlp.add_pipe(el_pipe, last=True)
|
||||
|
@ -105,14 +116,9 @@ def main(
|
|||
logger.info("Training on {} articles".format(len(train_data)))
|
||||
logger.info("Dev testing on {} articles".format(len(dev_data)))
|
||||
|
||||
dev_baseline_accuracies = measure_baselines(
|
||||
dev_data, kb
|
||||
)
|
||||
|
||||
# baseline performance on dev data
|
||||
logger.info("Dev Baseline Accuracies:")
|
||||
logger.info(dev_baseline_accuracies.report_accuracy("random"))
|
||||
logger.info(dev_baseline_accuracies.report_accuracy("prior"))
|
||||
logger.info(dev_baseline_accuracies.report_accuracy("oracle"))
|
||||
measure_performance(dev_data, kb, el_pipe, baseline=True, context=False)
|
||||
|
||||
for itn in range(epochs):
|
||||
random.shuffle(train_data)
|
||||
|
@ -136,18 +142,18 @@ def main(
|
|||
logger.error("Error updating batch:" + str(e))
|
||||
if batchnr > 0:
|
||||
logging.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2)))
|
||||
measure_performance(dev_data, kb, el_pipe)
|
||||
measure_performance(dev_data, kb, el_pipe, baseline=False, context=True)
|
||||
|
||||
# STEP 4: measure the performance of our trained pipe on an independent dev set
|
||||
logger.info("STEP 4: performance measurement of Entity Linking pipe")
|
||||
logger.info("STEP 4: Final performance measurement of Entity Linking pipe")
|
||||
measure_performance(dev_data, kb, el_pipe)
|
||||
|
||||
# STEP 5: apply the EL pipe on a toy example
|
||||
logger.info("STEP 5: applying Entity Linking to toy example")
|
||||
logger.info("STEP 5: Applying Entity Linking to toy example")
|
||||
run_el_toy_example(nlp=nlp)
|
||||
|
||||
if output_dir:
|
||||
# STEP 6: write the NLP pipeline (including entity linker) to file
|
||||
# STEP 6: write the NLP pipeline (now including an EL model) to file
|
||||
logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir))
|
||||
nlp.to_disk(nlp_output_dir)
|
||||
|
||||
|
|
|
@ -3,147 +3,104 @@ from __future__ import unicode_literals
|
|||
|
||||
import re
|
||||
import bz2
|
||||
import csv
|
||||
import datetime
|
||||
import logging
|
||||
import random
|
||||
import json
|
||||
|
||||
from bin.wiki_entity_linking import LOG_FORMAT
|
||||
from functools import partial
|
||||
|
||||
from spacy.gold import GoldParse
|
||||
from bin.wiki_entity_linking import wiki_io as io
|
||||
from bin.wiki_entity_linking.wiki_namespaces import (
|
||||
WP_META_NAMESPACE,
|
||||
WP_FILE_NAMESPACE,
|
||||
WP_CATEGORY_NAMESPACE,
|
||||
)
|
||||
|
||||
"""
|
||||
Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions.
|
||||
Write these results to file for downstream KB and training data generation.
|
||||
|
||||
Process Wikipedia interlinks to generate a training dataset for the EL algorithm.
|
||||
"""
|
||||
|
||||
ENTITY_FILE = "gold_entities.csv"
|
||||
|
||||
map_alias_to_link = dict()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# these will/should be matched ignoring case
|
||||
wiki_namespaces = [
|
||||
"b",
|
||||
"betawikiversity",
|
||||
"Book",
|
||||
"c",
|
||||
"Category",
|
||||
"Commons",
|
||||
"d",
|
||||
"dbdump",
|
||||
"download",
|
||||
"Draft",
|
||||
"Education",
|
||||
"Foundation",
|
||||
"Gadget",
|
||||
"Gadget definition",
|
||||
"gerrit",
|
||||
"File",
|
||||
"Help",
|
||||
"Image",
|
||||
"Incubator",
|
||||
"m",
|
||||
"mail",
|
||||
"mailarchive",
|
||||
"media",
|
||||
"MediaWiki",
|
||||
"MediaWiki talk",
|
||||
"Mediawikiwiki",
|
||||
"MediaZilla",
|
||||
"Meta",
|
||||
"Metawikipedia",
|
||||
"Module",
|
||||
"mw",
|
||||
"n",
|
||||
"nost",
|
||||
"oldwikisource",
|
||||
"outreach",
|
||||
"outreachwiki",
|
||||
"otrs",
|
||||
"OTRSwiki",
|
||||
"Portal",
|
||||
"phab",
|
||||
"Phabricator",
|
||||
"Project",
|
||||
"q",
|
||||
"quality",
|
||||
"rev",
|
||||
"s",
|
||||
"spcom",
|
||||
"Special",
|
||||
"species",
|
||||
"Strategy",
|
||||
"sulutil",
|
||||
"svn",
|
||||
"Talk",
|
||||
"Template",
|
||||
"Template talk",
|
||||
"Testwiki",
|
||||
"ticket",
|
||||
"TimedText",
|
||||
"Toollabs",
|
||||
"tools",
|
||||
"tswiki",
|
||||
"User",
|
||||
"User talk",
|
||||
"v",
|
||||
"voy",
|
||||
"w",
|
||||
"Wikibooks",
|
||||
"Wikidata",
|
||||
"wikiHow",
|
||||
"Wikinvest",
|
||||
"wikilivres",
|
||||
"Wikimedia",
|
||||
"Wikinews",
|
||||
"Wikipedia",
|
||||
"Wikipedia talk",
|
||||
"Wikiquote",
|
||||
"Wikisource",
|
||||
"Wikispecies",
|
||||
"Wikitech",
|
||||
"Wikiversity",
|
||||
"Wikivoyage",
|
||||
"wikt",
|
||||
"wiktionary",
|
||||
"wmf",
|
||||
"wmania",
|
||||
"WP",
|
||||
]
|
||||
title_regex = re.compile(r"(?<=<title>).*(?=</title>)")
|
||||
id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
|
||||
text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")
|
||||
info_regex = re.compile(r"{[^{]*?}")
|
||||
html_regex = re.compile(r"<!--[^-]*-->")
|
||||
ref_regex = re.compile(r"<ref.*?>") # non-greedy
|
||||
ref_2_regex = re.compile(r"</ref.*?>") # non-greedy
|
||||
|
||||
# find the links
|
||||
link_regex = re.compile(r"\[\[[^\[\]]*\]\]")
|
||||
|
||||
# match on interwiki links, e.g. `en:` or `:fr:`
|
||||
ns_regex = r":?" + "[a-z][a-z]" + ":"
|
||||
|
||||
# match on Namespace: optionally preceded by a :
|
||||
for ns in wiki_namespaces:
|
||||
for ns in WP_META_NAMESPACE:
|
||||
ns_regex += "|" + ":?" + ns + ":"
|
||||
|
||||
ns_regex = re.compile(ns_regex, re.IGNORECASE)
|
||||
|
||||
files = r""
|
||||
for f in WP_FILE_NAMESPACE:
|
||||
files += "\[\[" + f + ":[^[\]]+]]" + "|"
|
||||
files = files[0 : len(files) - 1]
|
||||
file_regex = re.compile(files)
|
||||
|
||||
cats = r""
|
||||
for c in WP_CATEGORY_NAMESPACE:
|
||||
cats += "\[\[" + c + ":[^\[]*]]" + "|"
|
||||
cats = cats[0 : len(cats) - 1]
|
||||
category_regex = re.compile(cats)
|
||||
|
||||
|
||||
def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
|
||||
"""
|
||||
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
|
||||
The full file takes about 2h to parse 1100M lines.
|
||||
It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from.
|
||||
The full file takes about 2-3h to parse 1100M lines.
|
||||
It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from,
|
||||
though dev test articles are excluded in order not to get an artificially strong baseline.
|
||||
"""
|
||||
cnt = 0
|
||||
read_id = False
|
||||
current_article_id = None
|
||||
with bz2.open(wikipedia_input, mode="rb") as file:
|
||||
line = file.readline()
|
||||
cnt = 0
|
||||
while line and (not limit or cnt < limit):
|
||||
if cnt % 25000000 == 0:
|
||||
if cnt % 25000000 == 0 and cnt > 0:
|
||||
logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
|
||||
clean_line = line.strip().decode("utf-8")
|
||||
|
||||
aliases, entities, normalizations = get_wp_links(clean_line)
|
||||
for alias, entity, norm in zip(aliases, entities, normalizations):
|
||||
_store_alias(alias, entity, normalize_alias=norm, normalize_entity=True)
|
||||
_store_alias(alias, entity, normalize_alias=norm, normalize_entity=True)
|
||||
# we attempt at reading the article's ID (but not the revision or contributor ID)
|
||||
if "<revision>" in clean_line or "<contributor>" in clean_line:
|
||||
read_id = False
|
||||
if "<page>" in clean_line:
|
||||
read_id = True
|
||||
|
||||
if read_id:
|
||||
ids = id_regex.search(clean_line)
|
||||
if ids:
|
||||
current_article_id = ids[0]
|
||||
|
||||
# only processing prior probabilities from true training (non-dev) articles
|
||||
if not is_dev(current_article_id):
|
||||
aliases, entities, normalizations = get_wp_links(clean_line)
|
||||
for alias, entity, norm in zip(aliases, entities, normalizations):
|
||||
_store_alias(
|
||||
alias, entity, normalize_alias=norm, normalize_entity=True
|
||||
)
|
||||
|
||||
line = file.readline()
|
||||
cnt += 1
|
||||
logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
|
||||
logger.info("Finished. processed {} lines of Wikipedia XML dump".format(cnt))
|
||||
|
||||
# write all aliases and their entities and count occurrences to file
|
||||
with prior_prob_output.open("w", encoding="utf8") as outputfile:
|
||||
|
@ -182,7 +139,7 @@ def get_wp_links(text):
|
|||
match = match[2:][:-2].replace("_", " ").strip()
|
||||
|
||||
if ns_regex.match(match):
|
||||
pass # ignore namespaces at the beginning of the string
|
||||
pass # ignore the entity if it points to a "meta" page
|
||||
|
||||
# this is a simple [[link]], with the alias the same as the mention
|
||||
elif "|" not in match:
|
||||
|
@ -218,47 +175,382 @@ def _capitalize_first(text):
|
|||
return result
|
||||
|
||||
|
||||
def write_entity_counts(prior_prob_input, count_output, to_print=False):
|
||||
# Write entity counts for quick access later
|
||||
entity_to_count = dict()
|
||||
total_count = 0
|
||||
|
||||
with prior_prob_input.open("r", encoding="utf8") as prior_file:
|
||||
# skip header
|
||||
prior_file.readline()
|
||||
line = prior_file.readline()
|
||||
|
||||
while line:
|
||||
splits = line.replace("\n", "").split(sep="|")
|
||||
# alias = splits[0]
|
||||
count = int(splits[1])
|
||||
entity = splits[2]
|
||||
|
||||
current_count = entity_to_count.get(entity, 0)
|
||||
entity_to_count[entity] = current_count + count
|
||||
|
||||
total_count += count
|
||||
|
||||
line = prior_file.readline()
|
||||
|
||||
with count_output.open("w", encoding="utf8") as entity_file:
|
||||
entity_file.write("entity" + "|" + "count" + "\n")
|
||||
for entity, count in entity_to_count.items():
|
||||
entity_file.write(entity + "|" + str(count) + "\n")
|
||||
|
||||
if to_print:
|
||||
for entity, count in entity_to_count.items():
|
||||
print("Entity count:", entity, count)
|
||||
print("Total count:", total_count)
|
||||
def create_training_and_desc(
|
||||
wp_input, def_input, desc_output, training_output, parse_desc, limit=None
|
||||
):
|
||||
wp_to_id = io.read_title_to_id(def_input)
|
||||
_process_wikipedia_texts(
|
||||
wp_input, wp_to_id, desc_output, training_output, parse_desc, limit
|
||||
)
|
||||
|
||||
|
||||
def get_all_frequencies(count_input):
|
||||
entity_to_count = dict()
|
||||
with count_input.open("r", encoding="utf8") as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter="|")
|
||||
# skip header
|
||||
next(csvreader)
|
||||
for row in csvreader:
|
||||
entity_to_count[row[0]] = int(row[1])
|
||||
def _process_wikipedia_texts(
|
||||
wikipedia_input, wp_to_id, output, training_output, parse_descriptions, limit=None
|
||||
):
|
||||
"""
|
||||
Read the XML wikipedia data to parse out training data:
|
||||
raw text data + positive instances
|
||||
"""
|
||||
|
||||
return entity_to_count
|
||||
read_ids = set()
|
||||
|
||||
with output.open("a", encoding="utf8") as descr_file, training_output.open(
|
||||
"w", encoding="utf8"
|
||||
) as entity_file:
|
||||
if parse_descriptions:
|
||||
_write_training_description(descr_file, "WD_id", "description")
|
||||
with bz2.open(wikipedia_input, mode="rb") as file:
|
||||
article_count = 0
|
||||
article_text = ""
|
||||
article_title = None
|
||||
article_id = None
|
||||
reading_text = False
|
||||
reading_revision = False
|
||||
|
||||
for line in file:
|
||||
clean_line = line.strip().decode("utf-8")
|
||||
|
||||
if clean_line == "<revision>":
|
||||
reading_revision = True
|
||||
elif clean_line == "</revision>":
|
||||
reading_revision = False
|
||||
|
||||
# Start reading new page
|
||||
if clean_line == "<page>":
|
||||
article_text = ""
|
||||
article_title = None
|
||||
article_id = None
|
||||
# finished reading this page
|
||||
elif clean_line == "</page>":
|
||||
if article_id:
|
||||
clean_text, entities = _process_wp_text(
|
||||
article_title, article_text, wp_to_id
|
||||
)
|
||||
if clean_text is not None and entities is not None:
|
||||
_write_training_entities(
|
||||
entity_file, article_id, clean_text, entities
|
||||
)
|
||||
|
||||
if article_title in wp_to_id and parse_descriptions:
|
||||
description = " ".join(
|
||||
clean_text[:1000].split(" ")[:-1]
|
||||
)
|
||||
_write_training_description(
|
||||
descr_file, wp_to_id[article_title], description
|
||||
)
|
||||
article_count += 1
|
||||
if article_count % 10000 == 0 and article_count > 0:
|
||||
logger.info(
|
||||
"Processed {} articles".format(article_count)
|
||||
)
|
||||
if limit and article_count >= limit:
|
||||
break
|
||||
article_text = ""
|
||||
article_title = None
|
||||
article_id = None
|
||||
reading_text = False
|
||||
reading_revision = False
|
||||
|
||||
# start reading text within a page
|
||||
if "<text" in clean_line:
|
||||
reading_text = True
|
||||
|
||||
if reading_text:
|
||||
article_text += " " + clean_line
|
||||
|
||||
# stop reading text within a page (we assume a new page doesn't start on the same line)
|
||||
if "</text" in clean_line:
|
||||
reading_text = False
|
||||
|
||||
# read the ID of this article (outside the revision portion of the document)
|
||||
if not reading_revision:
|
||||
ids = id_regex.search(clean_line)
|
||||
if ids:
|
||||
article_id = ids[0]
|
||||
if article_id in read_ids:
|
||||
logger.info(
|
||||
"Found duplicate article ID", article_id, clean_line
|
||||
) # This should never happen ...
|
||||
read_ids.add(article_id)
|
||||
|
||||
# read the title of this article (outside the revision portion of the document)
|
||||
if not reading_revision:
|
||||
titles = title_regex.search(clean_line)
|
||||
if titles:
|
||||
article_title = titles[0].strip()
|
||||
logger.info("Finished. Processed {} articles".format(article_count))
|
||||
|
||||
|
||||
def _process_wp_text(article_title, article_text, wp_to_id):
|
||||
# ignore meta Wikipedia pages
|
||||
if ns_regex.match(article_title):
|
||||
return None, None
|
||||
|
||||
# remove the text tags
|
||||
text_search = text_regex.search(article_text)
|
||||
if text_search is None:
|
||||
return None, None
|
||||
text = text_search.group(0)
|
||||
|
||||
# stop processing if this is a redirect page
|
||||
if text.startswith("#REDIRECT"):
|
||||
return None, None
|
||||
|
||||
# get the raw text without markup etc, keeping only interwiki links
|
||||
clean_text, entities = _remove_links(_get_clean_wp_text(text), wp_to_id)
|
||||
return clean_text, entities
|
||||
|
||||
|
||||
def _get_clean_wp_text(article_text):
|
||||
clean_text = article_text.strip()
|
||||
|
||||
# remove bolding & italic markup
|
||||
clean_text = clean_text.replace("'''", "")
|
||||
clean_text = clean_text.replace("''", "")
|
||||
|
||||
# remove nested {{info}} statements by removing the inner/smallest ones first and iterating
|
||||
try_again = True
|
||||
previous_length = len(clean_text)
|
||||
while try_again:
|
||||
clean_text = info_regex.sub(
|
||||
"", clean_text
|
||||
) # non-greedy match excluding a nested {
|
||||
if len(clean_text) < previous_length:
|
||||
try_again = True
|
||||
else:
|
||||
try_again = False
|
||||
previous_length = len(clean_text)
|
||||
|
||||
# remove HTML comments
|
||||
clean_text = html_regex.sub("", clean_text)
|
||||
|
||||
# remove Category and File statements
|
||||
clean_text = category_regex.sub("", clean_text)
|
||||
clean_text = file_regex.sub("", clean_text)
|
||||
|
||||
# remove multiple =
|
||||
while "==" in clean_text:
|
||||
clean_text = clean_text.replace("==", "=")
|
||||
|
||||
clean_text = clean_text.replace(". =", ".")
|
||||
clean_text = clean_text.replace(" = ", ". ")
|
||||
clean_text = clean_text.replace("= ", ".")
|
||||
clean_text = clean_text.replace(" =", "")
|
||||
|
||||
# remove refs (non-greedy match)
|
||||
clean_text = ref_regex.sub("", clean_text)
|
||||
clean_text = ref_2_regex.sub("", clean_text)
|
||||
|
||||
# remove additional wikiformatting
|
||||
clean_text = re.sub(r"<blockquote>", "", clean_text)
|
||||
clean_text = re.sub(r"</blockquote>", "", clean_text)
|
||||
|
||||
# change special characters back to normal ones
|
||||
clean_text = clean_text.replace(r"<", "<")
|
||||
clean_text = clean_text.replace(r">", ">")
|
||||
clean_text = clean_text.replace(r""", '"')
|
||||
clean_text = clean_text.replace(r"&nbsp;", " ")
|
||||
clean_text = clean_text.replace(r"&", "&")
|
||||
|
||||
# remove multiple spaces
|
||||
while " " in clean_text:
|
||||
clean_text = clean_text.replace(" ", " ")
|
||||
|
||||
return clean_text.strip()
|
||||
|
||||
|
||||
def _remove_links(clean_text, wp_to_id):
|
||||
# read the text char by char to get the right offsets for the interwiki links
|
||||
entities = []
|
||||
final_text = ""
|
||||
open_read = 0
|
||||
reading_text = True
|
||||
reading_entity = False
|
||||
reading_mention = False
|
||||
reading_special_case = False
|
||||
entity_buffer = ""
|
||||
mention_buffer = ""
|
||||
for index, letter in enumerate(clean_text):
|
||||
if letter == "[":
|
||||
open_read += 1
|
||||
elif letter == "]":
|
||||
open_read -= 1
|
||||
elif letter == "|":
|
||||
if reading_text:
|
||||
final_text += letter
|
||||
# switch from reading entity to mention in the [[entity|mention]] pattern
|
||||
elif reading_entity:
|
||||
reading_text = False
|
||||
reading_entity = False
|
||||
reading_mention = True
|
||||
else:
|
||||
reading_special_case = True
|
||||
else:
|
||||
if reading_entity:
|
||||
entity_buffer += letter
|
||||
elif reading_mention:
|
||||
mention_buffer += letter
|
||||
elif reading_text:
|
||||
final_text += letter
|
||||
else:
|
||||
raise ValueError("Not sure at point", clean_text[index - 2 : index + 2])
|
||||
|
||||
if open_read > 2:
|
||||
reading_special_case = True
|
||||
|
||||
if open_read == 2 and reading_text:
|
||||
reading_text = False
|
||||
reading_entity = True
|
||||
reading_mention = False
|
||||
|
||||
# we just finished reading an entity
|
||||
if open_read == 0 and not reading_text:
|
||||
if "#" in entity_buffer or entity_buffer.startswith(":"):
|
||||
reading_special_case = True
|
||||
# Ignore cases with nested structures like File: handles etc
|
||||
if not reading_special_case:
|
||||
if not mention_buffer:
|
||||
mention_buffer = entity_buffer
|
||||
start = len(final_text)
|
||||
end = start + len(mention_buffer)
|
||||
qid = wp_to_id.get(entity_buffer, None)
|
||||
if qid:
|
||||
entities.append((mention_buffer, qid, start, end))
|
||||
final_text += mention_buffer
|
||||
|
||||
entity_buffer = ""
|
||||
mention_buffer = ""
|
||||
|
||||
reading_text = True
|
||||
reading_entity = False
|
||||
reading_mention = False
|
||||
reading_special_case = False
|
||||
return final_text, entities
|
||||
|
||||
|
||||
def _write_training_description(outputfile, qid, description):
|
||||
if description is not None:
|
||||
line = str(qid) + "|" + description + "\n"
|
||||
outputfile.write(line)
|
||||
|
||||
|
||||
def _write_training_entities(outputfile, article_id, clean_text, entities):
|
||||
entities_data = [
|
||||
{"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]}
|
||||
for ent in entities
|
||||
]
|
||||
line = (
|
||||
json.dumps(
|
||||
{
|
||||
"article_id": article_id,
|
||||
"clean_text": clean_text,
|
||||
"entities": entities_data,
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
+ "\n"
|
||||
)
|
||||
outputfile.write(line)
|
||||
|
||||
|
||||
def read_training(nlp, entity_file_path, dev, limit, kb, labels_discard=None):
|
||||
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
|
||||
For training, it will include both positive and negative examples by using the candidate generator from the kb.
|
||||
For testing (kb=None), it will include all positive examples only."""
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
if not labels_discard:
|
||||
labels_discard = []
|
||||
|
||||
data = []
|
||||
num_entities = 0
|
||||
get_gold_parse = partial(
|
||||
_get_gold_parse, dev=dev, kb=kb, labels_discard=labels_discard
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"Reading {} data with limit {}".format("dev" if dev else "train", limit)
|
||||
)
|
||||
with entity_file_path.open("r", encoding="utf8") as file:
|
||||
with tqdm(total=limit, leave=False) as pbar:
|
||||
for i, line in enumerate(file):
|
||||
example = json.loads(line)
|
||||
article_id = example["article_id"]
|
||||
clean_text = example["clean_text"]
|
||||
entities = example["entities"]
|
||||
|
||||
if dev != is_dev(article_id) or not is_valid_article(clean_text):
|
||||
continue
|
||||
|
||||
doc = nlp(clean_text)
|
||||
gold = get_gold_parse(doc, entities)
|
||||
if gold and len(gold.links) > 0:
|
||||
data.append((doc, gold))
|
||||
num_entities += len(gold.links)
|
||||
pbar.update(len(gold.links))
|
||||
if limit and num_entities >= limit:
|
||||
break
|
||||
logger.info("Read {} entities in {} articles".format(num_entities, len(data)))
|
||||
return data
|
||||
|
||||
|
||||
def _get_gold_parse(doc, entities, dev, kb, labels_discard):
|
||||
gold_entities = {}
|
||||
tagged_ent_positions = {
|
||||
(ent.start_char, ent.end_char): ent
|
||||
for ent in doc.ents
|
||||
if ent.label_ not in labels_discard
|
||||
}
|
||||
|
||||
for entity in entities:
|
||||
entity_id = entity["entity"]
|
||||
alias = entity["alias"]
|
||||
start = entity["start"]
|
||||
end = entity["end"]
|
||||
|
||||
candidate_ids = []
|
||||
if kb and not dev:
|
||||
candidates = kb.get_candidates(alias)
|
||||
candidate_ids = [cand.entity_ for cand in candidates]
|
||||
|
||||
tagged_ent = tagged_ent_positions.get((start, end), None)
|
||||
if tagged_ent:
|
||||
# TODO: check that alias == doc.text[start:end]
|
||||
should_add_ent = (dev or entity_id in candidate_ids) and is_valid_sentence(
|
||||
tagged_ent.sent.text
|
||||
)
|
||||
|
||||
if should_add_ent:
|
||||
value_by_id = {entity_id: 1.0}
|
||||
if not dev:
|
||||
random.shuffle(candidate_ids)
|
||||
value_by_id.update(
|
||||
{kb_id: 0.0 for kb_id in candidate_ids if kb_id != entity_id}
|
||||
)
|
||||
gold_entities[(start, end)] = value_by_id
|
||||
|
||||
return GoldParse(doc, links=gold_entities)
|
||||
|
||||
|
||||
def is_dev(article_id):
|
||||
if not article_id:
|
||||
return False
|
||||
return article_id.endswith("3")
|
||||
|
||||
|
||||
def is_valid_article(doc_text):
|
||||
# custom length cut-off
|
||||
return 10 < len(doc_text) < 30000
|
||||
|
||||
|
||||
def is_valid_sentence(sent_text):
|
||||
if not 10 < len(sent_text) < 3000:
|
||||
# custom length cut-off
|
||||
return False
|
||||
|
||||
if sent_text.strip().startswith("*") or sent_text.strip().startswith("#"):
|
||||
# remove 'enumeration' sentences (occurs often on Wikipedia)
|
||||
return False
|
||||
|
||||
return True
|
||||
|
|
|
@ -246,7 +246,7 @@ def make_update(model, docs, optimizer, drop=0.0, objective="L2"):
|
|||
"""Perform an update over a single batch of documents.
|
||||
|
||||
docs (iterable): A batch of `Doc` objects.
|
||||
drop (float): The droput rate.
|
||||
drop (float): The dropout rate.
|
||||
optimizer (callable): An optimizer.
|
||||
RETURNS loss: A float for the loss.
|
||||
"""
|
||||
|
|
|
@ -80,8 +80,8 @@ class Warnings(object):
|
|||
"the v2.x models cannot release the global interpreter lock. "
|
||||
"Future versions may introduce a `n_process` argument for "
|
||||
"parallel inference via multiprocessing.")
|
||||
W017 = ("Alias '{alias}' already exists in the Knowledge base.")
|
||||
W018 = ("Entity '{entity}' already exists in the Knowledge base.")
|
||||
W017 = ("Alias '{alias}' already exists in the Knowledge Base.")
|
||||
W018 = ("Entity '{entity}' already exists in the Knowledge Base.")
|
||||
W019 = ("Changing vectors name from {old} to {new}, to avoid clash with "
|
||||
"previously loaded vectors. See Issue #3853.")
|
||||
W020 = ("Unnamed vectors. This won't allow multiple vectors models to be "
|
||||
|
@ -96,6 +96,8 @@ class Warnings(object):
|
|||
"If this is surprising, make sure you have the spacy-lookups-data "
|
||||
"package installed.")
|
||||
W023 = ("Multiprocessing of Language.pipe is not supported in Python2. 'n_process' will be set to 1.")
|
||||
W024 = ("Entity '{entity}' - Alias '{alias}' combination already exists in "
|
||||
"the Knowledge Base.")
|
||||
|
||||
|
||||
@add_codes
|
||||
|
@ -408,7 +410,7 @@ class Errors(object):
|
|||
"{probabilities_length} respectively.")
|
||||
E133 = ("The sum of prior probabilities for alias '{alias}' should not "
|
||||
"exceed 1, but found {sum}.")
|
||||
E134 = ("Alias '{alias}' defined for unknown entity '{entity}'.")
|
||||
E134 = ("Entity '{entity}' is not defined in the Knowledge Base.")
|
||||
E135 = ("If you meant to replace a built-in component, use `create_pipe`: "
|
||||
"`nlp.replace_pipe('{name}', nlp.create_pipe('{name}'))`")
|
||||
E136 = ("This additional feature requires the jsonschema library to be "
|
||||
|
@ -420,7 +422,7 @@ class Errors(object):
|
|||
E138 = ("Invalid JSONL format for raw text '{text}'. Make sure the input "
|
||||
"includes either the `text` or `tokens` key. For more info, see "
|
||||
"the docs:\nhttps://spacy.io/api/cli#pretrain-jsonl")
|
||||
E139 = ("Knowledge base for component '{name}' not initialized. Did you "
|
||||
E139 = ("Knowledge Base for component '{name}' not initialized. Did you "
|
||||
"forget to call set_kb()?")
|
||||
E140 = ("The list of entities, prior probabilities and entity vectors "
|
||||
"should be of equal length.")
|
||||
|
@ -499,6 +501,7 @@ class Errors(object):
|
|||
E174 = ("Architecture '{name}' not found in registry. Available "
|
||||
"names: {names}")
|
||||
E175 = ("Can't remove rule for unknown match pattern ID: {key}")
|
||||
E176 = ("Alias '{alias}' is not defined in the Knowledge Base.")
|
||||
|
||||
|
||||
@add_codes
|
||||
|
|
69
spacy/kb.pyx
69
spacy/kb.pyx
|
@ -142,6 +142,7 @@ cdef class KnowledgeBase:
|
|||
|
||||
i = 0
|
||||
cdef KBEntryC entry
|
||||
cdef hash_t entity_hash
|
||||
while i < nr_entities:
|
||||
entity_vector = vector_list[i]
|
||||
if len(entity_vector) != self.entity_vector_length:
|
||||
|
@ -161,6 +162,14 @@ cdef class KnowledgeBase:
|
|||
|
||||
i += 1
|
||||
|
||||
def contains_entity(self, unicode entity):
|
||||
cdef hash_t entity_hash = self.vocab.strings.add(entity)
|
||||
return entity_hash in self._entry_index
|
||||
|
||||
def contains_alias(self, unicode alias):
|
||||
cdef hash_t alias_hash = self.vocab.strings.add(alias)
|
||||
return alias_hash in self._alias_index
|
||||
|
||||
def add_alias(self, unicode alias, entities, probabilities):
|
||||
"""
|
||||
For a given alias, add its potential entities and prior probabilies to the KB.
|
||||
|
@ -190,7 +199,7 @@ cdef class KnowledgeBase:
|
|||
for entity, prob in zip(entities, probabilities):
|
||||
entity_hash = self.vocab.strings[entity]
|
||||
if not entity_hash in self._entry_index:
|
||||
raise ValueError(Errors.E134.format(alias=alias, entity=entity))
|
||||
raise ValueError(Errors.E134.format(entity=entity))
|
||||
|
||||
entry_index = <int64_t>self._entry_index.get(entity_hash)
|
||||
entry_indices.push_back(int(entry_index))
|
||||
|
@ -201,8 +210,63 @@ cdef class KnowledgeBase:
|
|||
|
||||
return alias_hash
|
||||
|
||||
def get_candidates(self, unicode alias):
|
||||
def append_alias(self, unicode alias, unicode entity, float prior_prob, ignore_warnings=False):
|
||||
"""
|
||||
For an alias already existing in the KB, extend its potential entities with one more.
|
||||
Throw a warning if either the alias or the entity is unknown,
|
||||
or when the combination is already previously recorded.
|
||||
Throw an error if this entity+prior prob would exceed the sum of 1.
|
||||
For efficiency, it's best to use the method `add_alias` as much as possible instead of this one.
|
||||
"""
|
||||
# Check if the alias exists in the KB
|
||||
cdef hash_t alias_hash = self.vocab.strings[alias]
|
||||
if not alias_hash in self._alias_index:
|
||||
raise ValueError(Errors.E176.format(alias=alias))
|
||||
|
||||
# Check if the entity exists in the KB
|
||||
cdef hash_t entity_hash = self.vocab.strings[entity]
|
||||
if not entity_hash in self._entry_index:
|
||||
raise ValueError(Errors.E134.format(entity=entity))
|
||||
entry_index = <int64_t>self._entry_index.get(entity_hash)
|
||||
|
||||
# Throw an error if the prior probabilities (including the new one) sum up to more than 1
|
||||
alias_index = <int64_t>self._alias_index.get(alias_hash)
|
||||
alias_entry = self._aliases_table[alias_index]
|
||||
current_sum = sum([p for p in alias_entry.probs])
|
||||
new_sum = current_sum + prior_prob
|
||||
|
||||
if new_sum > 1.00001:
|
||||
raise ValueError(Errors.E133.format(alias=alias, sum=new_sum))
|
||||
|
||||
entry_indices = alias_entry.entry_indices
|
||||
|
||||
is_present = False
|
||||
for i in range(entry_indices.size()):
|
||||
if entry_indices[i] == int(entry_index):
|
||||
is_present = True
|
||||
|
||||
if is_present:
|
||||
if not ignore_warnings:
|
||||
user_warning(Warnings.W024.format(entity=entity, alias=alias))
|
||||
else:
|
||||
entry_indices.push_back(int(entry_index))
|
||||
alias_entry.entry_indices = entry_indices
|
||||
|
||||
probs = alias_entry.probs
|
||||
probs.push_back(float(prior_prob))
|
||||
alias_entry.probs = probs
|
||||
self._aliases_table[alias_index] = alias_entry
|
||||
|
||||
|
||||
def get_candidates(self, unicode alias):
|
||||
"""
|
||||
Return candidate entities for an alias. Each candidate defines the entity, the original alias,
|
||||
and the prior probability of that alias resolving to that entity.
|
||||
If the alias is not known in the KB, and empty list is returned.
|
||||
"""
|
||||
cdef hash_t alias_hash = self.vocab.strings[alias]
|
||||
if not alias_hash in self._alias_index:
|
||||
return []
|
||||
alias_index = <int64_t>self._alias_index.get(alias_hash)
|
||||
alias_entry = self._aliases_table[alias_index]
|
||||
|
||||
|
@ -341,7 +405,6 @@ cdef class KnowledgeBase:
|
|||
assert nr_entities == self.get_size_entities()
|
||||
|
||||
# STEP 3: load aliases
|
||||
|
||||
cdef int64_t nr_aliases
|
||||
reader.read_alias_length(&nr_aliases)
|
||||
self._alias_index = PreshMap(nr_aliases+1)
|
||||
|
|
|
@ -483,7 +483,7 @@ class Language(object):
|
|||
|
||||
docs (iterable): A batch of `Doc` objects.
|
||||
golds (iterable): A batch of `GoldParse` objects.
|
||||
drop (float): The droput rate.
|
||||
drop (float): The dropout rate.
|
||||
sgd (callable): An optimizer.
|
||||
losses (dict): Dictionary to update with the loss, keyed by component.
|
||||
component_cfg (dict): Config parameters for specific pipeline
|
||||
|
@ -531,7 +531,7 @@ class Language(object):
|
|||
even if you're updating it with a smaller set of examples.
|
||||
|
||||
docs (iterable): A batch of `Doc` objects.
|
||||
drop (float): The droput rate.
|
||||
drop (float): The dropout rate.
|
||||
sgd (callable): An optimizer.
|
||||
RETURNS (dict): Results from the update.
|
||||
|
||||
|
|
|
@ -1195,23 +1195,26 @@ class EntityLinker(Pipe):
|
|||
docs = [docs]
|
||||
golds = [golds]
|
||||
|
||||
context_docs = []
|
||||
sentence_docs = []
|
||||
|
||||
for doc, gold in zip(docs, golds):
|
||||
ents_by_offset = dict()
|
||||
for ent in doc.ents:
|
||||
ents_by_offset["{}_{}".format(ent.start_char, ent.end_char)] = ent
|
||||
ents_by_offset[(ent.start_char, ent.end_char)] = ent
|
||||
|
||||
for entity, kb_dict in gold.links.items():
|
||||
start, end = entity
|
||||
mention = doc.text[start:end]
|
||||
# the gold annotations should link to proper entities - if this fails, the dataset is likely corrupt
|
||||
ent = ents_by_offset[(start, end)]
|
||||
|
||||
for kb_id, value in kb_dict.items():
|
||||
# Currently only training on the positive instances
|
||||
if value:
|
||||
context_docs.append(doc)
|
||||
sentence_docs.append(ent.sent.as_doc())
|
||||
|
||||
context_encodings, bp_context = self.model.begin_update(context_docs, drop=drop)
|
||||
loss, d_scores = self.get_similarity_loss(scores=context_encodings, golds=golds, docs=None)
|
||||
sentence_encodings, bp_context = self.model.begin_update(sentence_docs, drop=drop)
|
||||
loss, d_scores = self.get_similarity_loss(scores=sentence_encodings, golds=golds, docs=None)
|
||||
bp_context(d_scores, sgd=sgd)
|
||||
|
||||
if losses is not None:
|
||||
|
@ -1280,50 +1283,68 @@ class EntityLinker(Pipe):
|
|||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
|
||||
context_encodings = self.model(docs)
|
||||
xp = get_array_module(context_encodings)
|
||||
|
||||
for i, doc in enumerate(docs):
|
||||
if len(doc) > 0:
|
||||
# currently, the context is the same for each entity in a sentence (should be refined)
|
||||
context_encoding = context_encodings[i]
|
||||
context_enc_t = context_encoding.T
|
||||
norm_1 = xp.linalg.norm(context_enc_t)
|
||||
for ent in doc.ents:
|
||||
entity_count += 1
|
||||
# Looping through each sentence and each entity
|
||||
# This may go wrong if there are entities across sentences - because they might not get a KB ID
|
||||
for sent in doc.ents:
|
||||
sent_doc = sent.as_doc()
|
||||
# currently, the context is the same for each entity in a sentence (should be refined)
|
||||
sentence_encoding = self.model([sent_doc])[0]
|
||||
xp = get_array_module(sentence_encoding)
|
||||
sentence_encoding_t = sentence_encoding.T
|
||||
sentence_norm = xp.linalg.norm(sentence_encoding_t)
|
||||
|
||||
candidates = self.kb.get_candidates(ent.text)
|
||||
if not candidates:
|
||||
final_kb_ids.append(self.NIL) # no prediction possible for this entity
|
||||
final_tensors.append(context_encoding)
|
||||
else:
|
||||
random.shuffle(candidates)
|
||||
for ent in sent_doc.ents:
|
||||
entity_count += 1
|
||||
|
||||
# this will set all prior probabilities to 0 if they should be excluded from the model
|
||||
prior_probs = xp.asarray([c.prior_prob for c in candidates])
|
||||
if not self.cfg.get("incl_prior", True):
|
||||
prior_probs = xp.asarray([0.0 for c in candidates])
|
||||
scores = prior_probs
|
||||
if ent.label_ in self.cfg.get("labels_discard", []):
|
||||
# ignoring this entity - setting to NIL
|
||||
final_kb_ids.append(self.NIL)
|
||||
final_tensors.append(sentence_encoding)
|
||||
|
||||
# add in similarity from the context
|
||||
if self.cfg.get("incl_context", True):
|
||||
entity_encodings = xp.asarray([c.entity_vector for c in candidates])
|
||||
norm_2 = xp.linalg.norm(entity_encodings, axis=1)
|
||||
else:
|
||||
candidates = self.kb.get_candidates(ent.text)
|
||||
if not candidates:
|
||||
# no prediction possible for this entity - setting to NIL
|
||||
final_kb_ids.append(self.NIL)
|
||||
final_tensors.append(sentence_encoding)
|
||||
|
||||
if len(entity_encodings) != len(prior_probs):
|
||||
raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length"))
|
||||
elif len(candidates) == 1:
|
||||
# shortcut for efficiency reasons: take the 1 candidate
|
||||
|
||||
# cosine similarity
|
||||
sims = xp.dot(entity_encodings, context_enc_t) / (norm_1 * norm_2)
|
||||
if sims.shape != prior_probs.shape:
|
||||
raise ValueError(Errors.E161)
|
||||
scores = prior_probs + sims - (prior_probs*sims)
|
||||
# TODO: thresholding
|
||||
final_kb_ids.append(candidates[0].entity_)
|
||||
final_tensors.append(sentence_encoding)
|
||||
|
||||
# TODO: thresholding
|
||||
best_index = scores.argmax()
|
||||
best_candidate = candidates[best_index]
|
||||
final_kb_ids.append(best_candidate.entity_)
|
||||
final_tensors.append(context_encoding)
|
||||
else:
|
||||
random.shuffle(candidates)
|
||||
|
||||
# this will set all prior probabilities to 0 if they should be excluded from the model
|
||||
prior_probs = xp.asarray([c.prior_prob for c in candidates])
|
||||
if not self.cfg.get("incl_prior", True):
|
||||
prior_probs = xp.asarray([0.0 for c in candidates])
|
||||
scores = prior_probs
|
||||
|
||||
# add in similarity from the context
|
||||
if self.cfg.get("incl_context", True):
|
||||
entity_encodings = xp.asarray([c.entity_vector for c in candidates])
|
||||
entity_norm = xp.linalg.norm(entity_encodings, axis=1)
|
||||
|
||||
if len(entity_encodings) != len(prior_probs):
|
||||
raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length"))
|
||||
|
||||
# cosine similarity
|
||||
sims = xp.dot(entity_encodings, sentence_encoding_t) / (sentence_norm * entity_norm)
|
||||
if sims.shape != prior_probs.shape:
|
||||
raise ValueError(Errors.E161)
|
||||
scores = prior_probs + sims - (prior_probs*sims)
|
||||
|
||||
# TODO: thresholding
|
||||
best_index = scores.argmax()
|
||||
best_candidate = candidates[best_index]
|
||||
final_kb_ids.append(best_candidate.entity_)
|
||||
final_tensors.append(sentence_encoding)
|
||||
|
||||
if not (len(final_tensors) == len(final_kb_ids) == entity_count):
|
||||
raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length"))
|
||||
|
|
|
@ -131,6 +131,53 @@ def test_candidate_generation(nlp):
|
|||
assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9)
|
||||
|
||||
|
||||
def test_append_alias(nlp):
|
||||
"""Test that we can append additional alias-entity pairs"""
|
||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||
|
||||
# adding entities
|
||||
mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
|
||||
mykb.add_entity(entity="Q2", freq=12, entity_vector=[2])
|
||||
mykb.add_entity(entity="Q3", freq=5, entity_vector=[3])
|
||||
|
||||
# adding aliases
|
||||
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.4, 0.1])
|
||||
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
|
||||
|
||||
# test the size of the relevant candidates
|
||||
assert len(mykb.get_candidates("douglas")) == 2
|
||||
|
||||
# append an alias
|
||||
mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.2)
|
||||
|
||||
# test the size of the relevant candidates has been incremented
|
||||
assert len(mykb.get_candidates("douglas")) == 3
|
||||
|
||||
# append the same alias-entity pair again should not work (will throw a warning)
|
||||
mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.3)
|
||||
|
||||
# test the size of the relevant candidates remained unchanged
|
||||
assert len(mykb.get_candidates("douglas")) == 3
|
||||
|
||||
|
||||
def test_append_invalid_alias(nlp):
|
||||
"""Test that append an alias will throw an error if prior probs are exceeding 1"""
|
||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||
|
||||
# adding entities
|
||||
mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
|
||||
mykb.add_entity(entity="Q2", freq=12, entity_vector=[2])
|
||||
mykb.add_entity(entity="Q3", freq=5, entity_vector=[3])
|
||||
|
||||
# adding aliases
|
||||
mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1])
|
||||
mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
|
||||
|
||||
# append an alias - should fail because the entities and probabilities vectors are not of equal length
|
||||
with pytest.raises(ValueError):
|
||||
mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.2)
|
||||
|
||||
|
||||
def test_preserving_links_asdoc(nlp):
|
||||
"""Test that Span.as_doc preserves the existing entity links"""
|
||||
mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
|
||||
|
|
|
@ -430,7 +430,7 @@ def test_issue957(en_tokenizer):
|
|||
def test_issue999(train_data):
|
||||
"""Test that adding entities and resuming training works passably OK.
|
||||
There are two issues here:
|
||||
1) We have to read labels. This isn't very nice.
|
||||
1) We have to re-add labels. This isn't very nice.
|
||||
2) There's no way to set the learning rate for the weight update, so we
|
||||
end up out-of-scale, causing it to learn too fast.
|
||||
"""
|
||||
|
|
Loading…
Reference in New Issue
Block a user