diff --git a/bin/wiki_entity_linking/README.md b/bin/wiki_entity_linking/README.md new file mode 100644 index 000000000..540878592 --- /dev/null +++ b/bin/wiki_entity_linking/README.md @@ -0,0 +1,34 @@ +## Entity Linking with Wikipedia and Wikidata + +### Step 1: Create a Knowledge Base (KB) and training data + +Run `wikipedia_pretrain_kb.py` +* This takes as input the locations of a **Wikipedia and a Wikidata dump**, and produces a **KB directory** + **training file** + * WikiData: get `latest-all.json.bz2` from https://dumps.wikimedia.org/wikidatawiki/entities/ + * Wikipedia: get `enwiki-latest-pages-articles-multistream.xml.bz2` from https://dumps.wikimedia.org/enwiki/latest/ (or for any other language) +* You can set the filtering parameters for KB construction: + * `max_per_alias`: (max) number of candidate entities in the KB per alias/synonym + * `min_freq`: threshold of number of times an entity should occur in the corpus to be included in the KB + * `min_pair`: threshold of number of times an entity+alias combination should occur in the corpus to be included in the KB +* Further parameters to set: + * `descriptions_from_wikipedia`: whether to parse descriptions from Wikipedia (`True`) or Wikidata (`False`) + * `entity_vector_length`: length of the pre-trained entity description vectors + * `lang`: language for which to fetch Wikidata information (as the dump contains all languages) + +Quick testing and rerunning: +* When trying out the pipeline for a quick test, set `limit_prior`, `limit_train` and/or `limit_wd` to read only parts of the dumps instead of everything. +* If you only want to (re)run certain parts of the pipeline, just remove the corresponding files and they will be recalculated or reparsed. + + +### Step 2: Train an Entity Linking model + +Run `wikidata_train_entity_linker.py` +* This takes the **KB directory** produced by Step 1, and trains an **Entity Linking model** +* You can set the learning parameters for the EL training: + * `epochs`: number of training iterations + * `dropout`: dropout rate + * `lr`: learning rate + * `l2`: L2 regularization +* Specify the number of training and dev testing entities with `train_inst` and `dev_inst` respectively +* Further parameters to set: + * `labels_discard`: NER label types to discard during training diff --git a/bin/wiki_entity_linking/__init__.py b/bin/wiki_entity_linking/__init__.py index a604bcc2f..de486bbcf 100644 --- a/bin/wiki_entity_linking/__init__.py +++ b/bin/wiki_entity_linking/__init__.py @@ -6,6 +6,7 @@ OUTPUT_MODEL_DIR = "nlp" PRIOR_PROB_PATH = "prior_prob.csv" ENTITY_DEFS_PATH = "entity_defs.csv" ENTITY_FREQ_PATH = "entity_freq.csv" +ENTITY_ALIAS_PATH = "entity_alias.csv" ENTITY_DESCR_PATH = "entity_descriptions.csv" LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s' diff --git a/bin/wiki_entity_linking/entity_linker_evaluation.py b/bin/wiki_entity_linking/entity_linker_evaluation.py index 1b1200564..94bafbf30 100644 --- a/bin/wiki_entity_linking/entity_linker_evaluation.py +++ b/bin/wiki_entity_linking/entity_linker_evaluation.py @@ -15,10 +15,11 @@ class Metrics(object): candidate_is_correct = true_entity == candidate # Assume that we have no labeled negatives in the data (i.e. cases where true_entity is "NIL") - # Therefore, if candidate_is_correct then we have a true positive and never a true negative + # Therefore, if candidate_is_correct then we have a true positive and never a true negative. 
self.true_pos += candidate_is_correct self.false_neg += not candidate_is_correct - if candidate not in {"", "NIL"}: + if candidate and candidate not in {"", "NIL"}: + # A wrong prediction (e.g. Q42 != Q3) counts both as a FP as well as a FN. self.false_pos += not candidate_is_correct def calculate_precision(self): @@ -33,6 +34,14 @@ class Metrics(object): else: return self.true_pos / (self.true_pos + self.false_neg) + def calculate_fscore(self): + p = self.calculate_precision() + r = self.calculate_recall() + if p + r == 0: + return 0.0 + else: + return 2 * p * r / (p + r) + class EvaluationResults(object): def __init__(self): @@ -43,18 +52,20 @@ class EvaluationResults(object): self.metrics.update_results(true_entity, candidate) self.metrics_by_label[ent_label].update_results(true_entity, candidate) - def increment_false_negatives(self): - self.metrics.false_neg += 1 - def report_metrics(self, model_name): model_str = model_name.title() recall = self.metrics.calculate_recall() precision = self.metrics.calculate_precision() - return ("{}: ".format(model_str) + - "Recall = {} | ".format(round(recall, 3)) + - "Precision = {} | ".format(round(precision, 3)) + - "Precision by label = {}".format({k: v.calculate_precision() - for k, v in self.metrics_by_label.items()})) + fscore = self.metrics.calculate_fscore() + return ( + "{}: ".format(model_str) + + "F-score = {} | ".format(round(fscore, 3)) + + "Recall = {} | ".format(round(recall, 3)) + + "Precision = {} | ".format(round(precision, 3)) + + "F-score by label = {}".format( + {k: v.calculate_fscore() for k, v in sorted(self.metrics_by_label.items())} + ) + ) class BaselineResults(object): @@ -63,40 +74,51 @@ class BaselineResults(object): self.prior = EvaluationResults() self.oracle = EvaluationResults() - def report_accuracy(self, model): + def report_performance(self, model): results = getattr(self, model) return results.report_metrics(model) - def update_baselines(self, true_entity, ent_label, random_candidate, prior_candidate, oracle_candidate): + def update_baselines( + self, + true_entity, + ent_label, + random_candidate, + prior_candidate, + oracle_candidate, + ): self.oracle.update_metrics(ent_label, true_entity, oracle_candidate) self.prior.update_metrics(ent_label, true_entity, prior_candidate) self.random.update_metrics(ent_label, true_entity, random_candidate) -def measure_performance(dev_data, kb, el_pipe): - baseline_accuracies = measure_baselines( - dev_data, kb - ) +def measure_performance(dev_data, kb, el_pipe, baseline=True, context=True): + if baseline: + baseline_accuracies, counts = measure_baselines(dev_data, kb) + logger.info("Counts: {}".format({k: v for k, v in sorted(counts.items())})) + logger.info(baseline_accuracies.report_performance("random")) + logger.info(baseline_accuracies.report_performance("prior")) + logger.info(baseline_accuracies.report_performance("oracle")) - logger.info(baseline_accuracies.report_accuracy("random")) - logger.info(baseline_accuracies.report_accuracy("prior")) - logger.info(baseline_accuracies.report_accuracy("oracle")) + if context: + # using only context + el_pipe.cfg["incl_context"] = True + el_pipe.cfg["incl_prior"] = False + results = get_eval_results(dev_data, el_pipe) + logger.info(results.report_metrics("context only")) - # using only context - el_pipe.cfg["incl_context"] = True - el_pipe.cfg["incl_prior"] = False - results = get_eval_results(dev_data, el_pipe) - logger.info(results.report_metrics("context only")) - - # measuring combined accuracy (prior + context) - 
el_pipe.cfg["incl_context"] = True - el_pipe.cfg["incl_prior"] = True - results = get_eval_results(dev_data, el_pipe) - logger.info(results.report_metrics("context and prior")) + # measuring combined accuracy (prior + context) + el_pipe.cfg["incl_context"] = True + el_pipe.cfg["incl_prior"] = True + results = get_eval_results(dev_data, el_pipe) + logger.info(results.report_metrics("context and prior")) def get_eval_results(data, el_pipe=None): - # If the docs in the data require further processing with an entity linker, set el_pipe + """ + Evaluate the ent.kb_id_ annotations against the gold standard. + Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL. + If the docs in the data require further processing with an entity linker, set el_pipe. + """ from tqdm import tqdm docs = [] @@ -111,18 +133,15 @@ def get_eval_results(data, el_pipe=None): results = EvaluationResults() for doc, gold in zip(docs, golds): - tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents} try: correct_entries_per_article = dict() for entity, kb_dict in gold.links.items(): start, end = entity - # only evaluating on positive examples for gold_kb, value in kb_dict.items(): if value: + # only evaluating on positive examples offset = _offset(start, end) correct_entries_per_article[offset] = gold_kb - if offset not in tagged_entries_per_article: - results.increment_false_negatives() for ent in doc.ents: ent_label = ent.label_ @@ -142,7 +161,11 @@ def get_eval_results(data, el_pipe=None): def measure_baselines(data, kb): - # Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound + """ + Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound. + Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL. + Also return a dictionary of counts by entity label. 
+ """ counts_d = dict() baseline_results = BaselineResults() @@ -152,7 +175,6 @@ def measure_baselines(data, kb): for doc, gold in zip(docs, golds): correct_entries_per_article = dict() - tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents} for entity, kb_dict in gold.links.items(): start, end = entity for gold_kb, value in kb_dict.items(): @@ -160,10 +182,6 @@ def measure_baselines(data, kb): if value: offset = _offset(start, end) correct_entries_per_article[offset] = gold_kb - if offset not in tagged_entries_per_article: - baseline_results.random.increment_false_negatives() - baseline_results.oracle.increment_false_negatives() - baseline_results.prior.increment_false_negatives() for ent in doc.ents: ent_label = ent.label_ @@ -176,7 +194,7 @@ def measure_baselines(data, kb): if gold_entity is not None: candidates = kb.get_candidates(ent.text) oracle_candidate = "" - best_candidate = "" + prior_candidate = "" random_candidate = "" if candidates: scores = [] @@ -187,13 +205,21 @@ def measure_baselines(data, kb): oracle_candidate = c.entity_ best_index = scores.index(max(scores)) - best_candidate = candidates[best_index].entity_ + prior_candidate = candidates[best_index].entity_ random_candidate = random.choice(candidates).entity_ - baseline_results.update_baselines(gold_entity, ent_label, - random_candidate, best_candidate, oracle_candidate) + current_count = counts_d.get(ent_label, 0) + counts_d[ent_label] = current_count+1 - return baseline_results + baseline_results.update_baselines( + gold_entity, + ent_label, + random_candidate, + prior_candidate, + oracle_candidate, + ) + + return baseline_results, counts_d def _offset(start, end): diff --git a/bin/wiki_entity_linking/kb_creator.py b/bin/wiki_entity_linking/kb_creator.py index 0eeb63803..7778fc701 100644 --- a/bin/wiki_entity_linking/kb_creator.py +++ b/bin/wiki_entity_linking/kb_creator.py @@ -1,17 +1,12 @@ # coding: utf-8 from __future__ import unicode_literals -import csv import logging -import spacy -import sys from spacy.kb import KnowledgeBase -from bin.wiki_entity_linking import wikipedia_processor as wp from bin.wiki_entity_linking.train_descriptions import EntityEncoder - -csv.field_size_limit(sys.maxsize) +from bin.wiki_entity_linking import wiki_io as io logger = logging.getLogger(__name__) @@ -22,18 +17,24 @@ def create_kb( max_entities_per_alias, min_entity_freq, min_occ, - entity_def_input, + entity_def_path, entity_descr_path, - count_input, - prior_prob_input, + entity_alias_path, + entity_freq_path, + prior_prob_path, entity_vector_length, ): # Create the knowledge base from Wikidata entries kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=entity_vector_length) + entity_list, filtered_title_to_id = _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_freq, entity_freq_path, entity_vector_length) + _define_aliases(kb, entity_alias_path, entity_list, filtered_title_to_id, max_entities_per_alias, min_occ, prior_prob_path) + return kb + +def _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_freq, entity_freq_path, entity_vector_length): # read the mappings from file - title_to_id = get_entity_to_id(entity_def_input) - id_to_descr = get_id_to_description(entity_descr_path) + title_to_id = io.read_title_to_id(entity_def_path) + id_to_descr = io.read_id_to_descr(entity_descr_path) # check the length of the nlp vectors if "vectors" in nlp.meta and nlp.vocab.vectors.size: @@ -45,10 +46,8 @@ def create_kb( " cf. 
https://spacy.io/usage/models#languages." ) - logger.info("Get entity frequencies") - entity_frequencies = wp.get_all_frequencies(count_input=count_input) - logger.info("Filtering entities with fewer than {} mentions".format(min_entity_freq)) + entity_frequencies = io.read_entity_to_count(entity_freq_path) # filter the entities for in the KB by frequency, because there's just too much data (8M entities) otherwise filtered_title_to_id, entity_list, description_list, frequency_list = get_filtered_entities( title_to_id, @@ -56,36 +55,33 @@ def create_kb( entity_frequencies, min_entity_freq ) - logger.info("Left with {} entities".format(len(description_list))) + logger.info("Kept {} entities from the set of {}".format(len(description_list), len(title_to_id.keys()))) - logger.info("Train entity encoder") + logger.info("Training entity encoder") encoder = EntityEncoder(nlp, input_dim, entity_vector_length) encoder.train(description_list=description_list, to_print=True) - logger.info("Get entity embeddings:") + logger.info("Getting entity embeddings") embeddings = encoder.apply_encoder(description_list) logger.info("Adding {} entities".format(len(entity_list))) kb.set_entities( entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings ) + return entity_list, filtered_title_to_id - logger.info("Adding aliases") + +def _define_aliases(kb, entity_alias_path, entity_list, filtered_title_to_id, max_entities_per_alias, min_occ, prior_prob_path): + logger.info("Adding aliases from Wikipedia and Wikidata") _add_aliases( kb, + entity_list=entity_list, title_to_id=filtered_title_to_id, max_entities_per_alias=max_entities_per_alias, min_occ=min_occ, - prior_prob_input=prior_prob_input, + prior_prob_path=prior_prob_path, ) - logger.info("KB size: {} entities, {} aliases".format( - kb.get_size_entities(), - kb.get_size_aliases())) - - logger.info("Done with kb") - return kb - def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies, min_entity_freq: int = 10): @@ -104,34 +100,13 @@ def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies, return filtered_title_to_id, entity_list, description_list, frequency_list -def get_entity_to_id(entity_def_output): - entity_to_id = dict() - with entity_def_output.open("r", encoding="utf8") as csvfile: - csvreader = csv.reader(csvfile, delimiter="|") - # skip header - next(csvreader) - for row in csvreader: - entity_to_id[row[0]] = row[1] - return entity_to_id - - -def get_id_to_description(entity_descr_path): - id_to_desc = dict() - with entity_descr_path.open("r", encoding="utf8") as csvfile: - csvreader = csv.reader(csvfile, delimiter="|") - # skip header - next(csvreader) - for row in csvreader: - id_to_desc[row[0]] = row[1] - return id_to_desc - - -def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input): +def _add_aliases(kb, entity_list, title_to_id, max_entities_per_alias, min_occ, prior_prob_path): wp_titles = title_to_id.keys() # adding aliases with prior probabilities # we can read this file sequentially, it's sorted by alias, and then by count - with prior_prob_input.open("r", encoding="utf8") as prior_file: + logger.info("Adding WP aliases") + with prior_prob_path.open("r", encoding="utf8") as prior_file: # skip header prior_file.readline() line = prior_file.readline() @@ -180,10 +155,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in line = prior_file.readline() -def read_nlp_kb(model_dir, kb_file): - nlp = spacy.load(model_dir) +def read_kb(nlp, 
kb_file): kb = KnowledgeBase(vocab=nlp.vocab) kb.load_bulk(kb_file) - logger.info("kb entities: {}".format(kb.get_size_entities())) - logger.info("kb aliases: {}".format(kb.get_size_aliases())) - return nlp, kb + return kb diff --git a/bin/wiki_entity_linking/train_descriptions.py b/bin/wiki_entity_linking/train_descriptions.py index 2cb66909f..af08d6b8f 100644 --- a/bin/wiki_entity_linking/train_descriptions.py +++ b/bin/wiki_entity_linking/train_descriptions.py @@ -53,7 +53,7 @@ class EntityEncoder: start = start + batch_size stop = min(stop + batch_size, len(description_list)) - logger.info("encoded: {} entities".format(stop)) + logger.info("Encoded: {} entities".format(stop)) return encodings @@ -62,7 +62,7 @@ class EntityEncoder: if to_print: logger.info( "Trained entity descriptions on {} ".format(processed) + - "(non-unique) entities across {} ".format(self.epochs) + + "(non-unique) descriptions across {} ".format(self.epochs) + "epochs" ) logger.info("Final loss: {}".format(loss)) diff --git a/bin/wiki_entity_linking/training_set_creator.py b/bin/wiki_entity_linking/training_set_creator.py deleted file mode 100644 index 3f42f8bdd..000000000 --- a/bin/wiki_entity_linking/training_set_creator.py +++ /dev/null @@ -1,395 +0,0 @@ -# coding: utf-8 -from __future__ import unicode_literals - -import logging -import random -import re -import bz2 -import json - -from functools import partial - -from spacy.gold import GoldParse -from bin.wiki_entity_linking import kb_creator - -""" -Process Wikipedia interlinks to generate a training dataset for the EL algorithm. -Gold-standard entities are stored in one file in standoff format (by character offset). -""" - -ENTITY_FILE = "gold_entities.csv" -logger = logging.getLogger(__name__) - - -def create_training_examples_and_descriptions(wikipedia_input, - entity_def_input, - description_output, - training_output, - parse_descriptions, - limit=None): - wp_to_id = kb_creator.get_entity_to_id(entity_def_input) - _process_wikipedia_texts(wikipedia_input, - wp_to_id, - description_output, - training_output, - parse_descriptions, - limit) - - -def _process_wikipedia_texts(wikipedia_input, - wp_to_id, - output, - training_output, - parse_descriptions, - limit=None): - """ - Read the XML wikipedia data to parse out training data: - raw text data + positive instances - """ - title_regex = re.compile(r"(?<=).*(?=)") - id_regex = re.compile(r"(?<=)\d*(?=)") - - read_ids = set() - - with output.open("a", encoding="utf8") as descr_file, training_output.open("w", encoding="utf8") as entity_file: - if parse_descriptions: - _write_training_description(descr_file, "WD_id", "description") - with bz2.open(wikipedia_input, mode="rb") as file: - article_count = 0 - article_text = "" - article_title = None - article_id = None - reading_text = False - reading_revision = False - - logger.info("Processed {} articles".format(article_count)) - - for line in file: - clean_line = line.strip().decode("utf-8") - - if clean_line == "": - reading_revision = True - elif clean_line == "": - reading_revision = False - - # Start reading new page - if clean_line == "": - article_text = "" - article_title = None - article_id = None - # finished reading this page - elif clean_line == "": - if article_id: - clean_text, entities = _process_wp_text( - article_title, - article_text, - wp_to_id - ) - if clean_text is not None and entities is not None: - _write_training_entities(entity_file, - article_id, - clean_text, - entities) - - if article_title in wp_to_id and parse_descriptions: - 
description = " ".join(clean_text[:1000].split(" ")[:-1]) - _write_training_description( - descr_file, - wp_to_id[article_title], - description - ) - article_count += 1 - if article_count % 10000 == 0: - logger.info("Processed {} articles".format(article_count)) - if limit and article_count >= limit: - break - article_text = "" - article_title = None - article_id = None - reading_text = False - reading_revision = False - - # start reading text within a page - if ").*(?=") - clean_text = clean_text.replace(r""", '"') - clean_text = clean_text.replace(r"&nbsp;", " ") - clean_text = clean_text.replace(r"&", "&") - - # remove multiple spaces - while " " in clean_text: - clean_text = clean_text.replace(" ", " ") - - return clean_text.strip() - - -def _remove_links(clean_text, wp_to_id): - # read the text char by char to get the right offsets for the interwiki links - entities = [] - final_text = "" - open_read = 0 - reading_text = True - reading_entity = False - reading_mention = False - reading_special_case = False - entity_buffer = "" - mention_buffer = "" - for index, letter in enumerate(clean_text): - if letter == "[": - open_read += 1 - elif letter == "]": - open_read -= 1 - elif letter == "|": - if reading_text: - final_text += letter - # switch from reading entity to mention in the [[entity|mention]] pattern - elif reading_entity: - reading_text = False - reading_entity = False - reading_mention = True - else: - reading_special_case = True - else: - if reading_entity: - entity_buffer += letter - elif reading_mention: - mention_buffer += letter - elif reading_text: - final_text += letter - else: - raise ValueError("Not sure at point", clean_text[index - 2: index + 2]) - - if open_read > 2: - reading_special_case = True - - if open_read == 2 and reading_text: - reading_text = False - reading_entity = True - reading_mention = False - - # we just finished reading an entity - if open_read == 0 and not reading_text: - if "#" in entity_buffer or entity_buffer.startswith(":"): - reading_special_case = True - # Ignore cases with nested structures like File: handles etc - if not reading_special_case: - if not mention_buffer: - mention_buffer = entity_buffer - start = len(final_text) - end = start + len(mention_buffer) - qid = wp_to_id.get(entity_buffer, None) - if qid: - entities.append((mention_buffer, qid, start, end)) - final_text += mention_buffer - - entity_buffer = "" - mention_buffer = "" - - reading_text = True - reading_entity = False - reading_mention = False - reading_special_case = False - return final_text, entities - - -def _write_training_description(outputfile, qid, description): - if description is not None: - line = str(qid) + "|" + description + "\n" - outputfile.write(line) - - -def _write_training_entities(outputfile, article_id, clean_text, entities): - entities_data = [{"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]} for ent in entities] - line = json.dumps( - { - "article_id": article_id, - "clean_text": clean_text, - "entities": entities_data - }, - ensure_ascii=False) + "\n" - outputfile.write(line) - - -def read_training(nlp, entity_file_path, dev, limit, kb): - """ This method provides training examples that correspond to the entity annotations found by the nlp object. - For training,, it will include negative training examples by using the candidate generator, - and it will only keep positive training examples that can be found by using the candidate generator. 
- For testing, it will include all positive examples only.""" - - from tqdm import tqdm - data = [] - num_entities = 0 - get_gold_parse = partial(_get_gold_parse, dev=dev, kb=kb) - - logger.info("Reading {} data with limit {}".format('dev' if dev else 'train', limit)) - with entity_file_path.open("r", encoding="utf8") as file: - with tqdm(total=limit, leave=False) as pbar: - for i, line in enumerate(file): - example = json.loads(line) - article_id = example["article_id"] - clean_text = example["clean_text"] - entities = example["entities"] - - if dev != is_dev(article_id) or len(clean_text) >= 30000: - continue - - doc = nlp(clean_text) - gold = get_gold_parse(doc, entities) - if gold and len(gold.links) > 0: - data.append((doc, gold)) - num_entities += len(gold.links) - pbar.update(len(gold.links)) - if limit and num_entities >= limit: - break - logger.info("Read {} entities in {} articles".format(num_entities, len(data))) - return data - - -def _get_gold_parse(doc, entities, dev, kb): - gold_entities = {} - tagged_ent_positions = set( - [(ent.start_char, ent.end_char) for ent in doc.ents] - ) - - for entity in entities: - entity_id = entity["entity"] - alias = entity["alias"] - start = entity["start"] - end = entity["end"] - - candidates = kb.get_candidates(alias) - candidate_ids = [ - c.entity_ for c in candidates - ] - - should_add_ent = ( - dev or - ( - (start, end) in tagged_ent_positions and - entity_id in candidate_ids and - len(candidates) > 1 - ) - ) - - if should_add_ent: - value_by_id = {entity_id: 1.0} - if not dev: - random.shuffle(candidate_ids) - value_by_id.update({ - kb_id: 0.0 - for kb_id in candidate_ids - if kb_id != entity_id - }) - gold_entities[(start, end)] = value_by_id - - return GoldParse(doc, links=gold_entities) - - -def is_dev(article_id): - return article_id.endswith("3") diff --git a/bin/wiki_entity_linking/wiki_io.py b/bin/wiki_entity_linking/wiki_io.py new file mode 100644 index 000000000..43ae87f0f --- /dev/null +++ b/bin/wiki_entity_linking/wiki_io.py @@ -0,0 +1,127 @@ +# coding: utf-8 +from __future__ import unicode_literals + +import sys +import csv + +# min() needed to prevent error on windows, cf https://stackoverflow.com/questions/52404416/ +csv.field_size_limit(min(sys.maxsize, 2147483646)) + +""" This class provides reading/writing methods for temp files """ + + +# Entity definition: WP title -> WD ID # +def write_title_to_id(entity_def_output, title_to_id): + with entity_def_output.open("w", encoding="utf8") as id_file: + id_file.write("WP_title" + "|" + "WD_id" + "\n") + for title, qid in title_to_id.items(): + id_file.write(title + "|" + str(qid) + "\n") + + +def read_title_to_id(entity_def_output): + title_to_id = dict() + with entity_def_output.open("r", encoding="utf8") as id_file: + csvreader = csv.reader(id_file, delimiter="|") + # skip header + next(csvreader) + for row in csvreader: + title_to_id[row[0]] = row[1] + return title_to_id + + +# Entity aliases from WD: WD ID -> WD alias # +def write_id_to_alias(entity_alias_path, id_to_alias): + with entity_alias_path.open("w", encoding="utf8") as alias_file: + alias_file.write("WD_id" + "|" + "alias" + "\n") + for qid, alias_list in id_to_alias.items(): + for alias in alias_list: + alias_file.write(str(qid) + "|" + alias + "\n") + + +def read_id_to_alias(entity_alias_path): + id_to_alias = dict() + with entity_alias_path.open("r", encoding="utf8") as alias_file: + csvreader = csv.reader(alias_file, delimiter="|") + # skip header + next(csvreader) + for row in csvreader: + qid = row[0] + 
alias = row[1] + alias_list = id_to_alias.get(qid, []) + alias_list.append(alias) + id_to_alias[qid] = alias_list + return id_to_alias + + +def read_alias_to_id_generator(entity_alias_path): + """ Read (aliases, qid) tuples """ + + with entity_alias_path.open("r", encoding="utf8") as alias_file: + csvreader = csv.reader(alias_file, delimiter="|") + # skip header + next(csvreader) + for row in csvreader: + qid = row[0] + alias = row[1] + yield alias, qid + + +# Entity descriptions from WD: WD ID -> WD alias # +def write_id_to_descr(entity_descr_output, id_to_descr): + with entity_descr_output.open("w", encoding="utf8") as descr_file: + descr_file.write("WD_id" + "|" + "description" + "\n") + for qid, descr in id_to_descr.items(): + descr_file.write(str(qid) + "|" + descr + "\n") + + +def read_id_to_descr(entity_desc_path): + id_to_desc = dict() + with entity_desc_path.open("r", encoding="utf8") as descr_file: + csvreader = csv.reader(descr_file, delimiter="|") + # skip header + next(csvreader) + for row in csvreader: + id_to_desc[row[0]] = row[1] + return id_to_desc + + +# Entity counts from WP: WP title -> count # +def write_entity_to_count(prior_prob_input, count_output): + # Write entity counts for quick access later + entity_to_count = dict() + total_count = 0 + + with prior_prob_input.open("r", encoding="utf8") as prior_file: + # skip header + prior_file.readline() + line = prior_file.readline() + + while line: + splits = line.replace("\n", "").split(sep="|") + # alias = splits[0] + count = int(splits[1]) + entity = splits[2] + + current_count = entity_to_count.get(entity, 0) + entity_to_count[entity] = current_count + count + + total_count += count + + line = prior_file.readline() + + with count_output.open("w", encoding="utf8") as entity_file: + entity_file.write("entity" + "|" + "count" + "\n") + for entity, count in entity_to_count.items(): + entity_file.write(entity + "|" + str(count) + "\n") + + +def read_entity_to_count(count_input): + entity_to_count = dict() + with count_input.open("r", encoding="utf8") as csvfile: + csvreader = csv.reader(csvfile, delimiter="|") + # skip header + next(csvreader) + for row in csvreader: + entity_to_count[row[0]] = int(row[1]) + + return entity_to_count diff --git a/bin/wiki_entity_linking/wiki_namespaces.py b/bin/wiki_entity_linking/wiki_namespaces.py new file mode 100644 index 000000000..e8f099ccd --- /dev/null +++ b/bin/wiki_entity_linking/wiki_namespaces.py @@ -0,0 +1,128 @@ +# coding: utf8 +from __future__ import unicode_literals + +# List of meta pages in Wikidata, should be kept out of the Knowledge base +WD_META_ITEMS = [ + "Q163875", + "Q191780", + "Q224414", + "Q4167836", + "Q4167410", + "Q4663903", + "Q11266439", + "Q13406463", + "Q15407973", + "Q18616576", + "Q19887878", + "Q22808320", + "Q23894233", + "Q33120876", + "Q42104522", + "Q47460393", + "Q64875536", + "Q66480449", +] + + +# TODO: add more cases from non-English WP's + +# List of prefixes that refer to Wikipedia "file" pages +WP_FILE_NAMESPACE = ["Bestand", "File"] + +# List of prefixes that refer to Wikipedia "category" pages +WP_CATEGORY_NAMESPACE = ["Kategori", "Category", "Categorie"] + +# List of prefixes that refer to Wikipedia "meta" pages +# these will/should be matched ignoring case +WP_META_NAMESPACE = ( + WP_FILE_NAMESPACE + + WP_CATEGORY_NAMESPACE + + [ + "b", + "betawikiversity", + "Book", + "c", + "Commons", + "d", + "dbdump", + "download", + "Draft", + "Education", + "Foundation", + "Gadget", + "Gadget definition", + "Gebruiker", + "gerrit", + "Help", + 
"Image", + "Incubator", + "m", + "mail", + "mailarchive", + "media", + "MediaWiki", + "MediaWiki talk", + "Mediawikiwiki", + "MediaZilla", + "Meta", + "Metawikipedia", + "Module", + "mw", + "n", + "nost", + "oldwikisource", + "otrs", + "OTRSwiki", + "Overleg gebruiker", + "outreach", + "outreachwiki", + "Portal", + "phab", + "Phabricator", + "Project", + "q", + "quality", + "rev", + "s", + "spcom", + "Special", + "species", + "Strategy", + "sulutil", + "svn", + "Talk", + "Template", + "Template talk", + "Testwiki", + "ticket", + "TimedText", + "Toollabs", + "tools", + "tswiki", + "User", + "User talk", + "v", + "voy", + "w", + "Wikibooks", + "Wikidata", + "wikiHow", + "Wikinvest", + "wikilivres", + "Wikimedia", + "Wikinews", + "Wikipedia", + "Wikipedia talk", + "Wikiquote", + "Wikisource", + "Wikispecies", + "Wikitech", + "Wikiversity", + "Wikivoyage", + "wikt", + "wiktionary", + "wmf", + "wmania", + "WP", + ] +) diff --git a/bin/wiki_entity_linking/wikidata_pretrain_kb.py b/bin/wiki_entity_linking/wikidata_pretrain_kb.py index 28650f039..940607b72 100644 --- a/bin/wiki_entity_linking/wikidata_pretrain_kb.py +++ b/bin/wiki_entity_linking/wikidata_pretrain_kb.py @@ -18,11 +18,12 @@ from pathlib import Path import plac from bin.wiki_entity_linking import wikipedia_processor as wp, wikidata_processor as wd +from bin.wiki_entity_linking import wiki_io as io from bin.wiki_entity_linking import kb_creator -from bin.wiki_entity_linking import training_set_creator from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_FILE, ENTITY_DESCR_PATH, KB_MODEL_DIR, LOG_FORMAT -from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH +from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH, ENTITY_ALIAS_PATH import spacy +from bin.wiki_entity_linking.kb_creator import read_kb logger = logging.getLogger(__name__) @@ -39,9 +40,11 @@ logger = logging.getLogger(__name__) loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path), loc_entity_defs=("Location to file with entity definitions", "option", "d", Path), loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path), - descriptions_from_wikipedia=("Flag for using wp descriptions not wd", "flag", "wp"), - limit=("Optional threshold to limit lines read from dumps", "option", "l", int), - lang=("Optional language for which to get wikidata titles. Defaults to 'en'", "option", "la", str), + descr_from_wp=("Flag for using wp descriptions not wd", "flag", "wp"), + limit_prior=("Threshold to limit lines read from WP for prior probabilities", "option", "lp", int), + limit_train=("Threshold to limit lines read from WP for training set", "option", "lt", int), + limit_wd=("Threshold to limit lines read from WD", "option", "lw", int), + lang=("Optional language for which to get Wikidata titles. 
Defaults to 'en'", "option", "la", str), ) def main( wd_json, @@ -54,13 +57,16 @@ def main( entity_vector_length=64, loc_prior_prob=None, loc_entity_defs=None, + loc_entity_alias=None, loc_entity_desc=None, - descriptions_from_wikipedia=False, - limit=None, + descr_from_wp=False, + limit_prior=None, + limit_train=None, + limit_wd=None, lang="en", ): - entity_defs_path = loc_entity_defs if loc_entity_defs else output_dir / ENTITY_DEFS_PATH + entity_alias_path = loc_entity_alias if loc_entity_alias else output_dir / ENTITY_ALIAS_PATH entity_descr_path = loc_entity_desc if loc_entity_desc else output_dir / ENTITY_DESCR_PATH entity_freq_path = output_dir / ENTITY_FREQ_PATH prior_prob_path = loc_prior_prob if loc_prior_prob else output_dir / PRIOR_PROB_PATH @@ -69,15 +75,12 @@ def main( logger.info("Creating KB with Wikipedia and WikiData") - if limit is not None: - logger.warning("Warning: reading only {} lines of Wikipedia/Wikidata dumps.".format(limit)) - # STEP 0: set up IO if not output_dir.exists(): output_dir.mkdir(parents=True) - # STEP 1: create the NLP object - logger.info("STEP 1: Loading model {}".format(model)) + # STEP 1: Load the NLP object + logger.info("STEP 1: Loading NLP model {}".format(model)) nlp = spacy.load(model) # check the length of the nlp vectors @@ -90,62 +93,83 @@ def main( # STEP 2: create prior probabilities from WP if not prior_prob_path.exists(): # It takes about 2h to process 1000M lines of Wikipedia XML dump - logger.info("STEP 2: writing prior probabilities to {}".format(prior_prob_path)) - wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit) - logger.info("STEP 2: reading prior probabilities from {}".format(prior_prob_path)) + logger.info("STEP 2: Writing prior probabilities to {}".format(prior_prob_path)) + if limit_prior is not None: + logger.warning("Warning: reading only {} lines of Wikipedia dump".format(limit_prior)) + wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit_prior) + else: + logger.info("STEP 2: Reading prior probabilities from {}".format(prior_prob_path)) - # STEP 3: deduce entity frequencies from WP (takes only a few minutes) - logger.info("STEP 3: calculating entity frequencies") - wp.write_entity_counts(prior_prob_path, entity_freq_path, to_print=False) + # STEP 3: calculate entity frequencies + if not entity_freq_path.exists(): + logger.info("STEP 3: Calculating and writing entity frequencies to {}".format(entity_freq_path)) + io.write_entity_to_count(prior_prob_path, entity_freq_path) + else: + logger.info("STEP 3: Reading entity frequencies from {}".format(entity_freq_path)) # STEP 4: reading definitions and (possibly) descriptions from WikiData or from file - message = " and descriptions" if not descriptions_from_wikipedia else "" - if (not entity_defs_path.exists()) or (not descriptions_from_wikipedia and not entity_descr_path.exists()): + if (not entity_defs_path.exists()) or (not descr_from_wp and not entity_descr_path.exists()): # It takes about 10h to process 55M lines of Wikidata JSON dump - logger.info("STEP 4: parsing wikidata for entity definitions" + message) - title_to_id, id_to_descr = wd.read_wikidata_entities_json( + logger.info("STEP 4: Parsing and writing Wikidata entity definitions to {}".format(entity_defs_path)) + if limit_wd is not None: + logger.warning("Warning: reading only {} lines of Wikidata dump".format(limit_wd)) + title_to_id, id_to_descr, id_to_alias = wd.read_wikidata_entities_json( wd_json, - limit, + limit_wd, to_print=False, lang=lang, - parse_descriptions=(not 
descriptions_from_wikipedia), + parse_descr=(not descr_from_wp), ) - wd.write_entity_files(entity_defs_path, title_to_id) - if not descriptions_from_wikipedia: - wd.write_entity_description_files(entity_descr_path, id_to_descr) - logger.info("STEP 4: read entity definitions" + message) + io.write_title_to_id(entity_defs_path, title_to_id) - # STEP 5: Getting gold entities from wikipedia - message = " and descriptions" if descriptions_from_wikipedia else "" - if (not training_entities_path.exists()) or (descriptions_from_wikipedia and not entity_descr_path.exists()): - logger.info("STEP 5: parsing wikipedia for gold entities" + message) - training_set_creator.create_training_examples_and_descriptions( - wp_xml, - entity_defs_path, - entity_descr_path, - training_entities_path, - parse_descriptions=descriptions_from_wikipedia, - limit=limit, - ) - logger.info("STEP 5: read gold entities" + message) + logger.info("STEP 4b: Writing Wikidata entity aliases to {}".format(entity_alias_path)) + io.write_id_to_alias(entity_alias_path, id_to_alias) + + if not descr_from_wp: + logger.info("STEP 4c: Writing Wikidata entity descriptions to {}".format(entity_descr_path)) + io.write_id_to_descr(entity_descr_path, id_to_descr) + else: + logger.info("STEP 4: Reading entity definitions from {}".format(entity_defs_path)) + logger.info("STEP 4b: Reading entity aliases from {}".format(entity_alias_path)) + if not descr_from_wp: + logger.info("STEP 4c: Reading entity descriptions from {}".format(entity_descr_path)) + + # STEP 5: Getting gold entities from Wikipedia + if (not training_entities_path.exists()) or (descr_from_wp and not entity_descr_path.exists()): + logger.info("STEP 5: Parsing and writing Wikipedia gold entities to {}".format(training_entities_path)) + if limit_train is not None: + logger.warning("Warning: reading only {} lines of Wikipedia dump".format(limit_train)) + wp.create_training_and_desc(wp_xml, entity_defs_path, entity_descr_path, + training_entities_path, descr_from_wp, limit_train) + if descr_from_wp: + logger.info("STEP 5b: Parsing and writing Wikipedia descriptions to {}".format(entity_descr_path)) + else: + logger.info("STEP 5: Reading gold entities from {}".format(training_entities_path)) + if descr_from_wp: + logger.info("STEP 5b: Reading entity descriptions from {}".format(entity_descr_path)) # STEP 6: creating the actual KB # It takes ca. 
30 minutes to pretrain the entity embeddings - logger.info("STEP 6: creating the KB at {}".format(kb_path)) - kb = kb_creator.create_kb( - nlp=nlp, - max_entities_per_alias=max_per_alias, - min_entity_freq=min_freq, - min_occ=min_pair, - entity_def_input=entity_defs_path, - entity_descr_path=entity_descr_path, - count_input=entity_freq_path, - prior_prob_input=prior_prob_path, - entity_vector_length=entity_vector_length, - ) - - kb.dump(kb_path) - nlp.to_disk(output_dir / KB_MODEL_DIR) + if not kb_path.exists(): + logger.info("STEP 6: Creating the KB at {}".format(kb_path)) + kb = kb_creator.create_kb( + nlp=nlp, + max_entities_per_alias=max_per_alias, + min_entity_freq=min_freq, + min_occ=min_pair, + entity_def_path=entity_defs_path, + entity_descr_path=entity_descr_path, + entity_alias_path=entity_alias_path, + entity_freq_path=entity_freq_path, + prior_prob_path=prior_prob_path, + entity_vector_length=entity_vector_length, + ) + kb.dump(kb_path) + logger.info("kb entities: {}".format(kb.get_size_entities())) + logger.info("kb aliases: {}".format(kb.get_size_aliases())) + nlp.to_disk(output_dir / KB_MODEL_DIR) + else: + logger.info("STEP 6: KB already exists at {}".format(kb_path)) logger.info("Done!") diff --git a/bin/wiki_entity_linking/wikidata_processor.py b/bin/wiki_entity_linking/wikidata_processor.py index b4034cb1a..8a070f567 100644 --- a/bin/wiki_entity_linking/wikidata_processor.py +++ b/bin/wiki_entity_linking/wikidata_processor.py @@ -1,40 +1,52 @@ # coding: utf-8 from __future__ import unicode_literals -import gzip +import bz2 import json import logging -import datetime + +from bin.wiki_entity_linking.wiki_namespaces import WD_META_ITEMS logger = logging.getLogger(__name__) -def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descriptions=True): - # Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines. +def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descr=True): + # Read the JSON wiki data and parse out the entities. Takes about 7-10h to parse 55M lines. 
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/ site_filter = '{}wiki'.format(lang) - # properties filter (currently disabled to get ALL data) - prop_filter = dict() - # prop_filter = {'P31': {'Q5', 'Q15632617'}} # currently defined as OR: one property suffices to be selected + # filter: currently defined as OR: one hit suffices to be removed from further processing + exclude_list = WD_META_ITEMS + + # punctuation + exclude_list.extend(["Q1383557", "Q10617810"]) + + # letters etc + exclude_list.extend(["Q188725", "Q19776628", "Q3841820", "Q17907810", "Q9788", "Q9398093"]) + + neg_prop_filter = { + 'P31': exclude_list, # instance of + 'P279': exclude_list # subclass + } title_to_id = dict() id_to_descr = dict() + id_to_alias = dict() # parse appropriate fields - depending on what we need in the KB parse_properties = False parse_sitelinks = True parse_labels = False - parse_aliases = False - parse_claims = False + parse_aliases = True + parse_claims = True - with gzip.open(wikidata_file, mode='rb') as file: + with bz2.open(wikidata_file, mode='rb') as file: for cnt, line in enumerate(file): if limit and cnt >= limit: break - if cnt % 500000 == 0: - logger.info("processed {} lines of WikiData dump".format(cnt)) + if cnt % 500000 == 0 and cnt > 0: + logger.info("processed {} lines of WikiData JSON dump".format(cnt)) clean_line = line.strip() if clean_line.endswith(b","): clean_line = clean_line[:-1] @@ -43,13 +55,11 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang= entry_type = obj["type"] if entry_type == "item": - # filtering records on their properties (currently disabled to get ALL data) - # keep = False keep = True claims = obj["claims"] if parse_claims: - for prop, value_set in prop_filter.items(): + for prop, value_set in neg_prop_filter.items(): claim_property = claims.get(prop, None) if claim_property: for cp in claim_property: @@ -61,7 +71,7 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang= ) cp_rank = cp["rank"] if cp_rank != "deprecated" and cp_id in value_set: - keep = True + keep = False if keep: unique_id = obj["id"] @@ -108,7 +118,7 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang= "label (" + lang + "):", lang_label["value"] ) - if found_link and parse_descriptions: + if found_link and parse_descr: descriptions = obj["descriptions"] if descriptions: lang_descr = descriptions.get(lang, None) @@ -130,22 +140,15 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang= print( "alias (" + lang + "):", item["value"] ) + alias_list = id_to_alias.get(unique_id, []) + alias_list.append(item["value"]) + id_to_alias[unique_id] = alias_list if to_print: print() - return title_to_id, id_to_descr + # log final number of lines processed + logger.info("Finished. 
Processed {} lines of WikiData JSON dump".format(cnt)) + return title_to_id, id_to_descr, id_to_alias -def write_entity_files(entity_def_output, title_to_id): - with entity_def_output.open("w", encoding="utf8") as id_file: - id_file.write("WP_title" + "|" + "WD_id" + "\n") - for title, qid in title_to_id.items(): - id_file.write(title + "|" + str(qid) + "\n") - - -def write_entity_description_files(entity_descr_output, id_to_descr): - with entity_descr_output.open("w", encoding="utf8") as descr_file: - descr_file.write("WD_id" + "|" + "description" + "\n") - for qid, descr in id_to_descr.items(): - descr_file.write(str(qid) + "|" + descr + "\n") diff --git a/bin/wiki_entity_linking/wikidata_train_entity_linker.py b/bin/wiki_entity_linking/wikidata_train_entity_linker.py index ac131e0ef..20c5fe91b 100644 --- a/bin/wiki_entity_linking/wikidata_train_entity_linker.py +++ b/bin/wiki_entity_linking/wikidata_train_entity_linker.py @@ -6,19 +6,19 @@ as created by the script `wikidata_create_kb`. For the Wikipedia dump: get enwiki-latest-pages-articles-multistream.xml.bz2 from https://dumps.wikimedia.org/enwiki/latest/ - """ from __future__ import unicode_literals import random import logging +import spacy from pathlib import Path import plac -from bin.wiki_entity_linking import training_set_creator +from bin.wiki_entity_linking import wikipedia_processor from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR -from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance, measure_baselines -from bin.wiki_entity_linking.kb_creator import read_nlp_kb +from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance +from bin.wiki_entity_linking.kb_creator import read_kb from spacy.util import minibatch, compounding @@ -35,6 +35,7 @@ logger = logging.getLogger(__name__) l2=("L2 regularization", "option", "r", float), train_inst=("# training instances (default 90% of all)", "option", "t", int), dev_inst=("# test instances (default 10% of all)", "option", "d", int), + labels_discard=("NER labels to discard (default None)", "option", "l", str), ) def main( dir_kb, @@ -46,13 +47,14 @@ def main( l2=1e-6, train_inst=None, dev_inst=None, + labels_discard=None ): logger.info("Creating Entity Linker with Wikipedia and WikiData") output_dir = Path(output_dir) if output_dir else dir_kb - training_path = loc_training if loc_training else output_dir / TRAINING_DATA_FILE + training_path = loc_training if loc_training else dir_kb / TRAINING_DATA_FILE nlp_dir = dir_kb / KB_MODEL_DIR - kb_path = output_dir / KB_FILE + kb_path = dir_kb / KB_FILE nlp_output_dir = output_dir / OUTPUT_MODEL_DIR # STEP 0: set up IO @@ -60,38 +62,47 @@ def main( output_dir.mkdir() # STEP 1 : load the NLP object - logger.info("STEP 1: loading model from {}".format(nlp_dir)) - nlp, kb = read_nlp_kb(nlp_dir, kb_path) + logger.info("STEP 1a: Loading model from {}".format(nlp_dir)) + nlp = spacy.load(nlp_dir) + logger.info("STEP 1b: Loading KB from {}".format(kb_path)) + kb = read_kb(nlp, kb_path) # check that there is a NER component in the pipeline if "ner" not in nlp.pipe_names: raise ValueError("The `nlp` object should have a pretrained `ner` component.") - # STEP 2: create a training dataset from WP - logger.info("STEP 2: reading training dataset from {}".format(training_path)) + # STEP 2: read the training dataset previously created from WP + logger.info("STEP 2: Reading training dataset from {}".format(training_path)) - train_data = 
training_set_creator.read_training( + if labels_discard: + labels_discard = [x.strip() for x in labels_discard.split(",")] + logger.info("Discarding {} NER types: {}".format(len(labels_discard), labels_discard)) + + train_data = wikipedia_processor.read_training( nlp=nlp, entity_file_path=training_path, dev=False, limit=train_inst, kb=kb, + labels_discard=labels_discard ) - # for testing, get all pos instances, whether or not they are in the kb - dev_data = training_set_creator.read_training( + # for testing, get all pos instances (independently of KB) + dev_data = wikipedia_processor.read_training( nlp=nlp, entity_file_path=training_path, dev=True, limit=dev_inst, - kb=kb, + kb=None, + labels_discard=labels_discard ) - # STEP 3: create and train the entity linking pipe - logger.info("STEP 3: training Entity Linking pipe") + # STEP 3: create and train an entity linking pipe + logger.info("STEP 3: Creating and training an Entity Linking pipe") el_pipe = nlp.create_pipe( - name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name} + name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name, + "labels_discard": labels_discard} ) el_pipe.set_kb(kb) nlp.add_pipe(el_pipe, last=True) @@ -105,14 +116,9 @@ def main( logger.info("Training on {} articles".format(len(train_data))) logger.info("Dev testing on {} articles".format(len(dev_data))) - dev_baseline_accuracies = measure_baselines( - dev_data, kb - ) - + # baseline performance on dev data logger.info("Dev Baseline Accuracies:") - logger.info(dev_baseline_accuracies.report_accuracy("random")) - logger.info(dev_baseline_accuracies.report_accuracy("prior")) - logger.info(dev_baseline_accuracies.report_accuracy("oracle")) + measure_performance(dev_data, kb, el_pipe, baseline=True, context=False) for itn in range(epochs): random.shuffle(train_data) @@ -136,18 +142,18 @@ def main( logger.error("Error updating batch:" + str(e)) if batchnr > 0: logging.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2))) - measure_performance(dev_data, kb, el_pipe) + measure_performance(dev_data, kb, el_pipe, baseline=False, context=True) # STEP 4: measure the performance of our trained pipe on an independent dev set - logger.info("STEP 4: performance measurement of Entity Linking pipe") + logger.info("STEP 4: Final performance measurement of Entity Linking pipe") measure_performance(dev_data, kb, el_pipe) # STEP 5: apply the EL pipe on a toy example - logger.info("STEP 5: applying Entity Linking to toy example") + logger.info("STEP 5: Applying Entity Linking to toy example") run_el_toy_example(nlp=nlp) if output_dir: - # STEP 6: write the NLP pipeline (including entity linker) to file + # STEP 6: write the NLP pipeline (now including an EL model) to file logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir)) nlp.to_disk(nlp_output_dir) diff --git a/bin/wiki_entity_linking/wikipedia_processor.py b/bin/wiki_entity_linking/wikipedia_processor.py index 8f928723e..25e914b32 100644 --- a/bin/wiki_entity_linking/wikipedia_processor.py +++ b/bin/wiki_entity_linking/wikipedia_processor.py @@ -3,147 +3,104 @@ from __future__ import unicode_literals import re import bz2 -import csv -import datetime import logging +import random +import json -from bin.wiki_entity_linking import LOG_FORMAT +from functools import partial + +from spacy.gold import GoldParse +from bin.wiki_entity_linking import wiki_io as io +from bin.wiki_entity_linking.wiki_namespaces import ( + WP_META_NAMESPACE, + 
WP_FILE_NAMESPACE, + WP_CATEGORY_NAMESPACE, +) """ Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions. Write these results to file for downstream KB and training data generation. + +Process Wikipedia interlinks to generate a training dataset for the EL algorithm. """ +ENTITY_FILE = "gold_entities.csv" + map_alias_to_link = dict() logger = logging.getLogger(__name__) - -# these will/should be matched ignoring case -wiki_namespaces = [ - "b", - "betawikiversity", - "Book", - "c", - "Category", - "Commons", - "d", - "dbdump", - "download", - "Draft", - "Education", - "Foundation", - "Gadget", - "Gadget definition", - "gerrit", - "File", - "Help", - "Image", - "Incubator", - "m", - "mail", - "mailarchive", - "media", - "MediaWiki", - "MediaWiki talk", - "Mediawikiwiki", - "MediaZilla", - "Meta", - "Metawikipedia", - "Module", - "mw", - "n", - "nost", - "oldwikisource", - "outreach", - "outreachwiki", - "otrs", - "OTRSwiki", - "Portal", - "phab", - "Phabricator", - "Project", - "q", - "quality", - "rev", - "s", - "spcom", - "Special", - "species", - "Strategy", - "sulutil", - "svn", - "Talk", - "Template", - "Template talk", - "Testwiki", - "ticket", - "TimedText", - "Toollabs", - "tools", - "tswiki", - "User", - "User talk", - "v", - "voy", - "w", - "Wikibooks", - "Wikidata", - "wikiHow", - "Wikinvest", - "wikilivres", - "Wikimedia", - "Wikinews", - "Wikipedia", - "Wikipedia talk", - "Wikiquote", - "Wikisource", - "Wikispecies", - "Wikitech", - "Wikiversity", - "Wikivoyage", - "wikt", - "wiktionary", - "wmf", - "wmania", - "WP", -] +title_regex = re.compile(r"(?<=).*(?=)") +id_regex = re.compile(r"(?<=)\d*(?=)") +text_regex = re.compile(r"(?<=).*(?= 0: logger.info("processed {} lines of Wikipedia XML dump".format(cnt)) clean_line = line.strip().decode("utf-8") - aliases, entities, normalizations = get_wp_links(clean_line) - for alias, entity, norm in zip(aliases, entities, normalizations): - _store_alias(alias, entity, normalize_alias=norm, normalize_entity=True) - _store_alias(alias, entity, normalize_alias=norm, normalize_entity=True) + # we attempt at reading the article's ID (but not the revision or contributor ID) + if "" in clean_line or "" in clean_line: + read_id = False + if "" in clean_line: + read_id = True + + if read_id: + ids = id_regex.search(clean_line) + if ids: + current_article_id = ids[0] + + # only processing prior probabilities from true training (non-dev) articles + if not is_dev(current_article_id): + aliases, entities, normalizations = get_wp_links(clean_line) + for alias, entity, norm in zip(aliases, entities, normalizations): + _store_alias( + alias, entity, normalize_alias=norm, normalize_entity=True + ) line = file.readline() cnt += 1 logger.info("processed {} lines of Wikipedia XML dump".format(cnt)) + logger.info("Finished. 
processed {} lines of Wikipedia XML dump".format(cnt)) # write all aliases and their entities and count occurrences to file with prior_prob_output.open("w", encoding="utf8") as outputfile: @@ -182,7 +139,7 @@ def get_wp_links(text): match = match[2:][:-2].replace("_", " ").strip() if ns_regex.match(match): - pass # ignore namespaces at the beginning of the string + pass # ignore the entity if it points to a "meta" page # this is a simple [[link]], with the alias the same as the mention elif "|" not in match: @@ -218,47 +175,382 @@ def _capitalize_first(text): return result -def write_entity_counts(prior_prob_input, count_output, to_print=False): - # Write entity counts for quick access later - entity_to_count = dict() - total_count = 0 - - with prior_prob_input.open("r", encoding="utf8") as prior_file: - # skip header - prior_file.readline() - line = prior_file.readline() - - while line: - splits = line.replace("\n", "").split(sep="|") - # alias = splits[0] - count = int(splits[1]) - entity = splits[2] - - current_count = entity_to_count.get(entity, 0) - entity_to_count[entity] = current_count + count - - total_count += count - - line = prior_file.readline() - - with count_output.open("w", encoding="utf8") as entity_file: - entity_file.write("entity" + "|" + "count" + "\n") - for entity, count in entity_to_count.items(): - entity_file.write(entity + "|" + str(count) + "\n") - - if to_print: - for entity, count in entity_to_count.items(): - print("Entity count:", entity, count) - print("Total count:", total_count) +def create_training_and_desc( + wp_input, def_input, desc_output, training_output, parse_desc, limit=None +): + wp_to_id = io.read_title_to_id(def_input) + _process_wikipedia_texts( + wp_input, wp_to_id, desc_output, training_output, parse_desc, limit + ) -def get_all_frequencies(count_input): - entity_to_count = dict() - with count_input.open("r", encoding="utf8") as csvfile: - csvreader = csv.reader(csvfile, delimiter="|") - # skip header - next(csvreader) - for row in csvreader: - entity_to_count[row[0]] = int(row[1]) +def _process_wikipedia_texts( + wikipedia_input, wp_to_id, output, training_output, parse_descriptions, limit=None +): + """ + Read the XML wikipedia data to parse out training data: + raw text data + positive instances + """ - return entity_to_count + read_ids = set() + + with output.open("a", encoding="utf8") as descr_file, training_output.open( + "w", encoding="utf8" + ) as entity_file: + if parse_descriptions: + _write_training_description(descr_file, "WD_id", "description") + with bz2.open(wikipedia_input, mode="rb") as file: + article_count = 0 + article_text = "" + article_title = None + article_id = None + reading_text = False + reading_revision = False + + for line in file: + clean_line = line.strip().decode("utf-8") + + if clean_line == "": + reading_revision = True + elif clean_line == "": + reading_revision = False + + # Start reading new page + if clean_line == "": + article_text = "" + article_title = None + article_id = None + # finished reading this page + elif clean_line == "": + if article_id: + clean_text, entities = _process_wp_text( + article_title, article_text, wp_to_id + ) + if clean_text is not None and entities is not None: + _write_training_entities( + entity_file, article_id, clean_text, entities + ) + + if article_title in wp_to_id and parse_descriptions: + description = " ".join( + clean_text[:1000].split(" ")[:-1] + ) + _write_training_description( + descr_file, wp_to_id[article_title], description + ) + article_count += 1 + if 
article_count % 10000 == 0 and article_count > 0: + logger.info( + "Processed {} articles".format(article_count) + ) + if limit and article_count >= limit: + break + article_text = "" + article_title = None + article_id = None + reading_text = False + reading_revision = False + + # start reading text within a page + if "") + clean_text = clean_text.replace(r"&quot;", '"') + clean_text = clean_text.replace(r"&nbsp;", " ") + clean_text = clean_text.replace(r"&amp;", "&") + + # remove multiple spaces + while "  " in clean_text: + clean_text = clean_text.replace("  ", " ") + + return clean_text.strip() + + +def _remove_links(clean_text, wp_to_id): + # read the text char by char to get the right offsets for the interwiki links + entities = [] + final_text = "" + open_read = 0 + reading_text = True + reading_entity = False + reading_mention = False + reading_special_case = False + entity_buffer = "" + mention_buffer = "" + for index, letter in enumerate(clean_text): + if letter == "[": + open_read += 1 + elif letter == "]": + open_read -= 1 + elif letter == "|": + if reading_text: + final_text += letter + # switch from reading entity to mention in the [[entity|mention]] pattern + elif reading_entity: + reading_text = False + reading_entity = False + reading_mention = True + else: + reading_special_case = True + else: + if reading_entity: + entity_buffer += letter + elif reading_mention: + mention_buffer += letter + elif reading_text: + final_text += letter + else: + raise ValueError("Not sure at point", clean_text[index - 2 : index + 2]) + + if open_read > 2: + reading_special_case = True + + if open_read == 2 and reading_text: + reading_text = False + reading_entity = True + reading_mention = False + + # we just finished reading an entity + if open_read == 0 and not reading_text: + if "#" in entity_buffer or entity_buffer.startswith(":"): + reading_special_case = True + # Ignore cases with nested structures like File: handles etc + if not reading_special_case: + if not mention_buffer: + mention_buffer = entity_buffer + start = len(final_text) + end = start + len(mention_buffer) + qid = wp_to_id.get(entity_buffer, None) + if qid: + entities.append((mention_buffer, qid, start, end)) + final_text += mention_buffer + + entity_buffer = "" + mention_buffer = "" + + reading_text = True + reading_entity = False + reading_mention = False + reading_special_case = False + return final_text, entities + + +def _write_training_description(outputfile, qid, description): + if description is not None: + line = str(qid) + "|" + description + "\n" + outputfile.write(line) + + +def _write_training_entities(outputfile, article_id, clean_text, entities): + entities_data = [ + {"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]} + for ent in entities + ] + line = ( + json.dumps( + { + "article_id": article_id, + "clean_text": clean_text, + "entities": entities_data, + }, + ensure_ascii=False, + ) + + "\n" + ) + outputfile.write(line) + + +def read_training(nlp, entity_file_path, dev, limit, kb, labels_discard=None): + """ This method provides training examples that correspond to the entity annotations found by the nlp object. + For training, it will include both positive and negative examples by using the candidate generator from the kb. 
+ For testing (kb=None), it will include all positive examples only.""" + + from tqdm import tqdm + + if not labels_discard: + labels_discard = [] + + data = [] + num_entities = 0 + get_gold_parse = partial( + _get_gold_parse, dev=dev, kb=kb, labels_discard=labels_discard + ) + + logger.info( + "Reading {} data with limit {}".format("dev" if dev else "train", limit) + ) + with entity_file_path.open("r", encoding="utf8") as file: + with tqdm(total=limit, leave=False) as pbar: + for i, line in enumerate(file): + example = json.loads(line) + article_id = example["article_id"] + clean_text = example["clean_text"] + entities = example["entities"] + + if dev != is_dev(article_id) or not is_valid_article(clean_text): + continue + + doc = nlp(clean_text) + gold = get_gold_parse(doc, entities) + if gold and len(gold.links) > 0: + data.append((doc, gold)) + num_entities += len(gold.links) + pbar.update(len(gold.links)) + if limit and num_entities >= limit: + break + logger.info("Read {} entities in {} articles".format(num_entities, len(data))) + return data + + +def _get_gold_parse(doc, entities, dev, kb, labels_discard): + gold_entities = {} + tagged_ent_positions = { + (ent.start_char, ent.end_char): ent + for ent in doc.ents + if ent.label_ not in labels_discard + } + + for entity in entities: + entity_id = entity["entity"] + alias = entity["alias"] + start = entity["start"] + end = entity["end"] + + candidate_ids = [] + if kb and not dev: + candidates = kb.get_candidates(alias) + candidate_ids = [cand.entity_ for cand in candidates] + + tagged_ent = tagged_ent_positions.get((start, end), None) + if tagged_ent: + # TODO: check that alias == doc.text[start:end] + should_add_ent = (dev or entity_id in candidate_ids) and is_valid_sentence( + tagged_ent.sent.text + ) + + if should_add_ent: + value_by_id = {entity_id: 1.0} + if not dev: + random.shuffle(candidate_ids) + value_by_id.update( + {kb_id: 0.0 for kb_id in candidate_ids if kb_id != entity_id} + ) + gold_entities[(start, end)] = value_by_id + + return GoldParse(doc, links=gold_entities) + + +def is_dev(article_id): + if not article_id: + return False + return article_id.endswith("3") + + +def is_valid_article(doc_text): + # custom length cut-off + return 10 < len(doc_text) < 30000 + + +def is_valid_sentence(sent_text): + if not 10 < len(sent_text) < 3000: + # custom length cut-off + return False + + if sent_text.strip().startswith("*") or sent_text.strip().startswith("#"): + # remove 'enumeration' sentences (occurs often on Wikipedia) + return False + + return True diff --git a/examples/pipeline/dummy_entity_linking.py b/examples/pipeline/dummy_entity_linking.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/examples/pipeline/wikidata_entity_linking.py b/examples/pipeline/wikidata_entity_linking.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/spacy/cli/pretrain.py b/spacy/cli/pretrain.py index 891e15fa2..9b63b31f0 100644 --- a/spacy/cli/pretrain.py +++ b/spacy/cli/pretrain.py @@ -246,7 +246,7 @@ def make_update(model, docs, optimizer, drop=0.0, objective="L2"): """Perform an update over a single batch of documents. docs (iterable): A batch of `Doc` objects. - drop (float): The droput rate. + drop (float): The dropout rate. optimizer (callable): An optimizer. RETURNS loss: A float for the loss. 
""" diff --git a/spacy/errors.py b/spacy/errors.py index d75b1cec8..69dd77be7 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -80,8 +80,8 @@ class Warnings(object): "the v2.x models cannot release the global interpreter lock. " "Future versions may introduce a `n_process` argument for " "parallel inference via multiprocessing.") - W017 = ("Alias '{alias}' already exists in the Knowledge base.") - W018 = ("Entity '{entity}' already exists in the Knowledge base.") + W017 = ("Alias '{alias}' already exists in the Knowledge Base.") + W018 = ("Entity '{entity}' already exists in the Knowledge Base.") W019 = ("Changing vectors name from {old} to {new}, to avoid clash with " "previously loaded vectors. See Issue #3853.") W020 = ("Unnamed vectors. This won't allow multiple vectors models to be " @@ -96,6 +96,8 @@ class Warnings(object): "If this is surprising, make sure you have the spacy-lookups-data " "package installed.") W023 = ("Multiprocessing of Language.pipe is not supported in Python2. 'n_process' will be set to 1.") + W024 = ("Entity '{entity}' - Alias '{alias}' combination already exists in " + "the Knowledge Base.") @add_codes @@ -408,7 +410,7 @@ class Errors(object): "{probabilities_length} respectively.") E133 = ("The sum of prior probabilities for alias '{alias}' should not " "exceed 1, but found {sum}.") - E134 = ("Alias '{alias}' defined for unknown entity '{entity}'.") + E134 = ("Entity '{entity}' is not defined in the Knowledge Base.") E135 = ("If you meant to replace a built-in component, use `create_pipe`: " "`nlp.replace_pipe('{name}', nlp.create_pipe('{name}'))`") E136 = ("This additional feature requires the jsonschema library to be " @@ -420,7 +422,7 @@ class Errors(object): E138 = ("Invalid JSONL format for raw text '{text}'. Make sure the input " "includes either the `text` or `tokens` key. For more info, see " "the docs:\nhttps://spacy.io/api/cli#pretrain-jsonl") - E139 = ("Knowledge base for component '{name}' not initialized. Did you " + E139 = ("Knowledge Base for component '{name}' not initialized. Did you " "forget to call set_kb()?") E140 = ("The list of entities, prior probabilities and entity vectors " "should be of equal length.") @@ -499,6 +501,7 @@ class Errors(object): E174 = ("Architecture '{name}' not found in registry. Available " "names: {names}") E175 = ("Can't remove rule for unknown match pattern ID: {key}") + E176 = ("Alias '{alias}' is not defined in the Knowledge Base.") @add_codes diff --git a/spacy/kb.pyx b/spacy/kb.pyx index 6cbc06e2c..31fd1706e 100644 --- a/spacy/kb.pyx +++ b/spacy/kb.pyx @@ -142,6 +142,7 @@ cdef class KnowledgeBase: i = 0 cdef KBEntryC entry + cdef hash_t entity_hash while i < nr_entities: entity_vector = vector_list[i] if len(entity_vector) != self.entity_vector_length: @@ -161,6 +162,14 @@ cdef class KnowledgeBase: i += 1 + def contains_entity(self, unicode entity): + cdef hash_t entity_hash = self.vocab.strings.add(entity) + return entity_hash in self._entry_index + + def contains_alias(self, unicode alias): + cdef hash_t alias_hash = self.vocab.strings.add(alias) + return alias_hash in self._alias_index + def add_alias(self, unicode alias, entities, probabilities): """ For a given alias, add its potential entities and prior probabilies to the KB. 
@@ -190,7 +199,7 @@ cdef class KnowledgeBase: for entity, prob in zip(entities, probabilities): entity_hash = self.vocab.strings[entity] if not entity_hash in self._entry_index: - raise ValueError(Errors.E134.format(alias=alias, entity=entity)) + raise ValueError(Errors.E134.format(entity=entity)) entry_index = self._entry_index.get(entity_hash) entry_indices.push_back(int(entry_index)) @@ -201,8 +210,63 @@ cdef class KnowledgeBase: return alias_hash - def get_candidates(self, unicode alias): + def append_alias(self, unicode alias, unicode entity, float prior_prob, ignore_warnings=False): + """ + For an alias already existing in the KB, extend its potential entities with one more. + Throw a warning if either the alias or the entity is unknown, + or when the combination has already been recorded. + Throw an error if the prior probabilities, including this new one, would sum to more than 1. + For efficiency, it's best to use the method `add_alias` as much as possible instead of this one. + """ + # Check if the alias exists in the KB cdef hash_t alias_hash = self.vocab.strings[alias] + if not alias_hash in self._alias_index: + raise ValueError(Errors.E176.format(alias=alias)) + + # Check if the entity exists in the KB + cdef hash_t entity_hash = self.vocab.strings[entity] + if not entity_hash in self._entry_index: + raise ValueError(Errors.E134.format(entity=entity)) + entry_index = self._entry_index.get(entity_hash) + + # Throw an error if the prior probabilities (including the new one) sum up to more than 1 + alias_index = self._alias_index.get(alias_hash) + alias_entry = self._aliases_table[alias_index] + current_sum = sum([p for p in alias_entry.probs]) + new_sum = current_sum + prior_prob + + if new_sum > 1.00001: + raise ValueError(Errors.E133.format(alias=alias, sum=new_sum)) + + entry_indices = alias_entry.entry_indices + + is_present = False + for i in range(entry_indices.size()): + if entry_indices[i] == int(entry_index): + is_present = True + + if is_present: + if not ignore_warnings: + user_warning(Warnings.W024.format(entity=entity, alias=alias)) + else: + entry_indices.push_back(int(entry_index)) + alias_entry.entry_indices = entry_indices + + probs = alias_entry.probs + probs.push_back(float(prior_prob)) + alias_entry.probs = probs + self._aliases_table[alias_index] = alias_entry + + + def get_candidates(self, unicode alias): + """ + Return candidate entities for an alias. Each candidate defines the entity, the original alias, + and the prior probability of that alias resolving to that entity. + If the alias is not known in the KB, an empty list is returned. + """ + cdef hash_t alias_hash = self.vocab.strings[alias] + if not alias_hash in self._alias_index: + return [] alias_index = self._alias_index.get(alias_hash) alias_entry = self._aliases_table[alias_index] @@ -341,7 +405,6 @@ cdef class KnowledgeBase: assert nr_entities == self.get_size_entities() # STEP 3: load aliases - cdef int64_t nr_aliases reader.read_alias_length(&nr_aliases) self._alias_index = PreshMap(nr_aliases+1) diff --git a/spacy/language.py b/spacy/language.py index b240dd8a5..33ac21a85 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -483,7 +483,7 @@ class Language(object): docs (iterable): A batch of `Doc` objects. golds (iterable): A batch of `GoldParse` objects. - drop (float): The droput rate. + drop (float): The dropout rate. sgd (callable): An optimizer. losses (dict): Dictionary to update with the loss, keyed by component. 
component_cfg (dict): Config parameters for specific pipeline @@ -531,7 +531,7 @@ class Language(object): even if you're updating it with a smaller set of examples. docs (iterable): A batch of `Doc` objects. - drop (float): The droput rate. + drop (float): The dropout rate. sgd (callable): An optimizer. RETURNS (dict): Results from the update. diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index 63ab09e56..0607ac43d 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -1195,23 +1195,26 @@ class EntityLinker(Pipe): docs = [docs] golds = [golds] - context_docs = [] + sentence_docs = [] for doc, gold in zip(docs, golds): ents_by_offset = dict() for ent in doc.ents: - ents_by_offset["{}_{}".format(ent.start_char, ent.end_char)] = ent + ents_by_offset[(ent.start_char, ent.end_char)] = ent + for entity, kb_dict in gold.links.items(): start, end = entity mention = doc.text[start:end] + # the gold annotations should link to proper entities - if this fails, the dataset is likely corrupt + ent = ents_by_offset[(start, end)] for kb_id, value in kb_dict.items(): # Currently only training on the positive instances if value: - context_docs.append(doc) + sentence_docs.append(ent.sent.as_doc()) - context_encodings, bp_context = self.model.begin_update(context_docs, drop=drop) - loss, d_scores = self.get_similarity_loss(scores=context_encodings, golds=golds, docs=None) + sentence_encodings, bp_context = self.model.begin_update(sentence_docs, drop=drop) + loss, d_scores = self.get_similarity_loss(scores=sentence_encodings, golds=golds, docs=None) bp_context(d_scores, sgd=sgd) if losses is not None: @@ -1280,50 +1283,68 @@ class EntityLinker(Pipe): if isinstance(docs, Doc): docs = [docs] - context_encodings = self.model(docs) - xp = get_array_module(context_encodings) - for i, doc in enumerate(docs): if len(doc) > 0: - # currently, the context is the same for each entity in a sentence (should be refined) - context_encoding = context_encodings[i] - context_enc_t = context_encoding.T - norm_1 = xp.linalg.norm(context_enc_t) - for ent in doc.ents: - entity_count += 1 + # Looping through each sentence and each entity + # This may go wrong if there are entities across sentences - because they might not get a KB ID + for sent in doc.sents: + sent_doc = sent.as_doc() + # currently, the context is the same for each entity in a sentence (should be refined) + sentence_encoding = self.model([sent_doc])[0] + xp = get_array_module(sentence_encoding) + sentence_encoding_t = sentence_encoding.T + sentence_norm = xp.linalg.norm(sentence_encoding_t) - candidates = self.kb.get_candidates(ent.text) - if not candidates: - final_kb_ids.append(self.NIL) # no prediction possible for this entity - final_tensors.append(context_encoding) - else: - random.shuffle(candidates) + for ent in sent_doc.ents: + entity_count += 1 - # this will set all prior probabilities to 0 if they should be excluded from the model - prior_probs = xp.asarray([c.prior_prob for c in candidates]) - if not self.cfg.get("incl_prior", True): - prior_probs = xp.asarray([0.0 for c in candidates]) - scores = prior_probs + if ent.label_ in self.cfg.get("labels_discard", []): + # ignoring this entity - setting to NIL + final_kb_ids.append(self.NIL) + final_tensors.append(sentence_encoding) - # add in similarity from the context - if self.cfg.get("incl_context", True): - entity_encodings = xp.asarray([c.entity_vector for c in candidates]) - norm_2 = xp.linalg.norm(entity_encodings, axis=1) + else: + candidates = 
self.kb.get_candidates(ent.text) + if not candidates: + # no prediction possible for this entity - setting to NIL + final_kb_ids.append(self.NIL) + final_tensors.append(sentence_encoding) - if len(entity_encodings) != len(prior_probs): - raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length")) + elif len(candidates) == 1: + # shortcut for efficiency reasons: take the 1 candidate - # cosine similarity - sims = xp.dot(entity_encodings, context_enc_t) / (norm_1 * norm_2) - if sims.shape != prior_probs.shape: - raise ValueError(Errors.E161) - scores = prior_probs + sims - (prior_probs*sims) + # TODO: thresholding + final_kb_ids.append(candidates[0].entity_) + final_tensors.append(sentence_encoding) - # TODO: thresholding - best_index = scores.argmax() - best_candidate = candidates[best_index] - final_kb_ids.append(best_candidate.entity_) - final_tensors.append(context_encoding) + else: + random.shuffle(candidates) + + # this will set all prior probabilities to 0 if they should be excluded from the model + prior_probs = xp.asarray([c.prior_prob for c in candidates]) + if not self.cfg.get("incl_prior", True): + prior_probs = xp.asarray([0.0 for c in candidates]) + scores = prior_probs + + # add in similarity from the context + if self.cfg.get("incl_context", True): + entity_encodings = xp.asarray([c.entity_vector for c in candidates]) + entity_norm = xp.linalg.norm(entity_encodings, axis=1) + + if len(entity_encodings) != len(prior_probs): + raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length")) + + # cosine similarity + sims = xp.dot(entity_encodings, sentence_encoding_t) / (sentence_norm * entity_norm) + if sims.shape != prior_probs.shape: + raise ValueError(Errors.E161) + scores = prior_probs + sims - (prior_probs*sims) + + # TODO: thresholding + best_index = scores.argmax() + best_candidate = candidates[best_index] + final_kb_ids.append(best_candidate.entity_) + final_tensors.append(sentence_encoding) if not (len(final_tensors) == len(final_kb_ids) == entity_count): raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length")) diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py index 0c89a2e14..bc76d1e47 100644 --- a/spacy/tests/pipeline/test_entity_linker.py +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -131,6 +131,53 @@ def test_candidate_generation(nlp): assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9) +def test_append_alias(nlp): + """Test that we can append additional alias-entity pairs""" + mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) + + # adding entities + mykb.add_entity(entity="Q1", freq=27, entity_vector=[1]) + mykb.add_entity(entity="Q2", freq=12, entity_vector=[2]) + mykb.add_entity(entity="Q3", freq=5, entity_vector=[3]) + + # adding aliases + mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.4, 0.1]) + mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9]) + + # test the size of the relevant candidates + assert len(mykb.get_candidates("douglas")) == 2 + + # append an alias + mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.2) + + # test the size of the relevant candidates has been incremented + assert len(mykb.get_candidates("douglas")) == 3 + + # append the same alias-entity pair again should not work (will throw a warning) + mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.3) + + # test the size of the relevant candidates 
remained unchanged + assert len(mykb.get_candidates("douglas")) == 3 + + +def test_append_invalid_alias(nlp): + """Test that appending an alias throws an error if the prior probs exceed 1""" + mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) + + # adding entities + mykb.add_entity(entity="Q1", freq=27, entity_vector=[1]) + mykb.add_entity(entity="Q2", freq=12, entity_vector=[2]) + mykb.add_entity(entity="Q3", freq=5, entity_vector=[3]) + + # adding aliases + mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1]) + mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9]) + + # append an alias - should fail because the prior probabilities for "douglas" would sum to more than 1 + with pytest.raises(ValueError): + mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.2) + + def test_preserving_links_asdoc(nlp): """Test that Span.as_doc preserves the existing entity links""" mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) diff --git a/spacy/tests/regression/test_issue1-1000.py b/spacy/tests/regression/test_issue1-1000.py index dca3d624f..989eba805 100644 --- a/spacy/tests/regression/test_issue1-1000.py +++ b/spacy/tests/regression/test_issue1-1000.py @@ -430,7 +430,7 @@ def test_issue957(en_tokenizer): def test_issue999(train_data): """Test that adding entities and resuming training works passably OK. There are two issues here: - 1) We have to read labels. This isn't very nice. + 1) We have to re-add labels. This isn't very nice. 2) There's no way to set the learning rate for the weight update, so we end up out-of-scale, causing it to learn too fast. """
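# A standalone sketch with invented numbers, not part of the patch: it mimics how the reworked
# EntityLinker.predict combines the KB prior probability of each candidate with the cosine
# similarity between the candidate's entity vector and the sentence encoding, via
# scores = prior_probs + sims - (prior_probs * sims) as in the pipes.pyx hunk above.
import numpy as np

prior_probs = np.asarray([0.70, 0.25, 0.05])  # prior probabilities from the KB, one per candidate
sims = np.asarray([0.10, 0.80, 0.30])         # cosine(sentence encoding, entity vector)

scores = prior_probs + sims - (prior_probs * sims)
print(scores.argmax())  # prints 1: a strong context match can outweigh a higher prior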