From 2ae5db580edec1c83afdb52db76cc0763bed2d5c Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Fri, 13 Sep 2019 16:30:05 +0200 Subject: [PATCH 1/2] dim bugfix when incl_prior is False (#4285) --- spacy/errors.py | 2 ++ spacy/pipeline/pipes.pyx | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/spacy/errors.py b/spacy/errors.py index b8a8dccba..587a6e700 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -455,6 +455,8 @@ class Errors(object): E158 = ("Can't add table '{name}' to lookups because it already exists.") E159 = ("Can't find table '{name}' in lookups. Available tables: {tables}") E160 = ("Can't find language data file: {path}") + E161 = ("Found an internal inconsistency when predicting entity links. " + "This is likely a bug in spaCy, so feel free to open an issue.") @add_codes diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index 3d799b3da..412433565 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -1283,7 +1283,7 @@ class EntityLinker(Pipe): # this will set all prior probabilities to 0 if they should be excluded from the model prior_probs = xp.asarray([c.prior_prob for c in candidates]) if not self.cfg.get("incl_prior", True): - prior_probs = xp.asarray([[0.0] for c in candidates]) + prior_probs = xp.asarray([0.0 for c in candidates]) scores = prior_probs # add in similarity from the context @@ -1296,6 +1296,8 @@ class EntityLinker(Pipe): # cosine similarity sims = xp.dot(entity_encodings, context_enc_t) / (norm_1 * norm_2) + if sims.shape != prior_probs.shape: + raise ValueError(Errors.E161) scores = prior_probs + sims - (prior_probs*sims) # TODO: thresholding From a6830d60e8c4acc6f378a3b4e6c48e851e729408 Mon Sep 17 00:00:00 2001 From: Euan Dowers Date: Fri, 13 Sep 2019 17:03:57 +0200 Subject: [PATCH 2/2] Changes to wiki_entity_linker (#4235) * Changes to wiki_entity_linker * No more f-strings * Make some requested changes * Add back option to get descriptions from wd not wp * Fix logs * Address comments and clean evaluation * Remove type hints * Refactor evaluation, add back metrics by label * Address comments * Log training performance as well as dev --- bin/wiki_entity_linking/__init__.py | 11 + .../entity_linker_evaluation.py | 200 +++++++ bin/wiki_entity_linking/kb_creator.py | 162 +++--- bin/wiki_entity_linking/train_descriptions.py | 19 +- .../training_set_creator.py | 502 ++++++++---------- .../wikidata_pretrain_kb.py | 123 +++-- bin/wiki_entity_linking/wikidata_processor.py | 42 +- .../wikidata_train_entity_linker.py | 407 +++----------- .../wikipedia_processor.py | 14 +- 9 files changed, 709 insertions(+), 771 deletions(-) create mode 100644 bin/wiki_entity_linking/entity_linker_evaluation.py diff --git a/bin/wiki_entity_linking/__init__.py b/bin/wiki_entity_linking/__init__.py index e69de29bb..a604bcc2f 100644 --- a/bin/wiki_entity_linking/__init__.py +++ b/bin/wiki_entity_linking/__init__.py @@ -0,0 +1,11 @@ +TRAINING_DATA_FILE = "gold_entities.jsonl" +KB_FILE = "kb" +KB_MODEL_DIR = "nlp_kb" +OUTPUT_MODEL_DIR = "nlp" + +PRIOR_PROB_PATH = "prior_prob.csv" +ENTITY_DEFS_PATH = "entity_defs.csv" +ENTITY_FREQ_PATH = "entity_freq.csv" +ENTITY_DESCR_PATH = "entity_descriptions.csv" + +LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s' diff --git a/bin/wiki_entity_linking/entity_linker_evaluation.py b/bin/wiki_entity_linking/entity_linker_evaluation.py new file mode 100644 index 000000000..1b1200564 --- /dev/null +++ b/bin/wiki_entity_linking/entity_linker_evaluation.py @@ -0,0 +1,200 @@ 
+import logging +import random + +from collections import defaultdict + +logger = logging.getLogger(__name__) + + +class Metrics(object): + true_pos = 0 + false_pos = 0 + false_neg = 0 + + def update_results(self, true_entity, candidate): + candidate_is_correct = true_entity == candidate + + # Assume that we have no labeled negatives in the data (i.e. cases where true_entity is "NIL") + # Therefore, if candidate_is_correct then we have a true positive and never a true negative + self.true_pos += candidate_is_correct + self.false_neg += not candidate_is_correct + if candidate not in {"", "NIL"}: + self.false_pos += not candidate_is_correct + + def calculate_precision(self): + if self.true_pos == 0: + return 0.0 + else: + return self.true_pos / (self.true_pos + self.false_pos) + + def calculate_recall(self): + if self.true_pos == 0: + return 0.0 + else: + return self.true_pos / (self.true_pos + self.false_neg) + + +class EvaluationResults(object): + def __init__(self): + self.metrics = Metrics() + self.metrics_by_label = defaultdict(Metrics) + + def update_metrics(self, ent_label, true_entity, candidate): + self.metrics.update_results(true_entity, candidate) + self.metrics_by_label[ent_label].update_results(true_entity, candidate) + + def increment_false_negatives(self): + self.metrics.false_neg += 1 + + def report_metrics(self, model_name): + model_str = model_name.title() + recall = self.metrics.calculate_recall() + precision = self.metrics.calculate_precision() + return ("{}: ".format(model_str) + + "Recall = {} | ".format(round(recall, 3)) + + "Precision = {} | ".format(round(precision, 3)) + + "Precision by label = {}".format({k: v.calculate_precision() + for k, v in self.metrics_by_label.items()})) + + +class BaselineResults(object): + def __init__(self): + self.random = EvaluationResults() + self.prior = EvaluationResults() + self.oracle = EvaluationResults() + + def report_accuracy(self, model): + results = getattr(self, model) + return results.report_metrics(model) + + def update_baselines(self, true_entity, ent_label, random_candidate, prior_candidate, oracle_candidate): + self.oracle.update_metrics(ent_label, true_entity, oracle_candidate) + self.prior.update_metrics(ent_label, true_entity, prior_candidate) + self.random.update_metrics(ent_label, true_entity, random_candidate) + + +def measure_performance(dev_data, kb, el_pipe): + baseline_accuracies = measure_baselines( + dev_data, kb + ) + + logger.info(baseline_accuracies.report_accuracy("random")) + logger.info(baseline_accuracies.report_accuracy("prior")) + logger.info(baseline_accuracies.report_accuracy("oracle")) + + # using only context + el_pipe.cfg["incl_context"] = True + el_pipe.cfg["incl_prior"] = False + results = get_eval_results(dev_data, el_pipe) + logger.info(results.report_metrics("context only")) + + # measuring combined accuracy (prior + context) + el_pipe.cfg["incl_context"] = True + el_pipe.cfg["incl_prior"] = True + results = get_eval_results(dev_data, el_pipe) + logger.info(results.report_metrics("context and prior")) + + +def get_eval_results(data, el_pipe=None): + # If the docs in the data require further processing with an entity linker, set el_pipe + from tqdm import tqdm + + docs = [] + golds = [] + for d, g in tqdm(data, leave=False): + if len(d) > 0: + golds.append(g) + if el_pipe is not None: + docs.append(el_pipe(d)) + else: + docs.append(d) + + results = EvaluationResults() + for doc, gold in zip(docs, golds): + tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in 
doc.ents} + try: + correct_entries_per_article = dict() + for entity, kb_dict in gold.links.items(): + start, end = entity + # only evaluating on positive examples + for gold_kb, value in kb_dict.items(): + if value: + offset = _offset(start, end) + correct_entries_per_article[offset] = gold_kb + if offset not in tagged_entries_per_article: + results.increment_false_negatives() + + for ent in doc.ents: + ent_label = ent.label_ + pred_entity = ent.kb_id_ + start = ent.start_char + end = ent.end_char + offset = _offset(start, end) + gold_entity = correct_entries_per_article.get(offset, None) + # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong' + if gold_entity is not None: + results.update_metrics(ent_label, gold_entity, pred_entity) + + except Exception as e: + logging.error("Error assessing accuracy " + str(e)) + + return results + + +def measure_baselines(data, kb): + # Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound + counts_d = dict() + + baseline_results = BaselineResults() + + docs = [d for d, g in data if len(d) > 0] + golds = [g for d, g in data if len(d) > 0] + + for doc, gold in zip(docs, golds): + correct_entries_per_article = dict() + tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents} + for entity, kb_dict in gold.links.items(): + start, end = entity + for gold_kb, value in kb_dict.items(): + # only evaluating on positive examples + if value: + offset = _offset(start, end) + correct_entries_per_article[offset] = gold_kb + if offset not in tagged_entries_per_article: + baseline_results.random.increment_false_negatives() + baseline_results.oracle.increment_false_negatives() + baseline_results.prior.increment_false_negatives() + + for ent in doc.ents: + ent_label = ent.label_ + start = ent.start_char + end = ent.end_char + offset = _offset(start, end) + gold_entity = correct_entries_per_article.get(offset, None) + + # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong' + if gold_entity is not None: + candidates = kb.get_candidates(ent.text) + oracle_candidate = "" + best_candidate = "" + random_candidate = "" + if candidates: + scores = [] + + for c in candidates: + scores.append(c.prior_prob) + if c.entity_ == gold_entity: + oracle_candidate = c.entity_ + + best_index = scores.index(max(scores)) + best_candidate = candidates[best_index].entity_ + random_candidate = random.choice(candidates).entity_ + + baseline_results.update_baselines(gold_entity, ent_label, + random_candidate, best_candidate, oracle_candidate) + + return baseline_results + + +def _offset(start, end): + return "{}_{}".format(start, end) diff --git a/bin/wiki_entity_linking/kb_creator.py b/bin/wiki_entity_linking/kb_creator.py index bd862f536..54ed7815e 100644 --- a/bin/wiki_entity_linking/kb_creator.py +++ b/bin/wiki_entity_linking/kb_creator.py @@ -1,12 +1,20 @@ # coding: utf-8 from __future__ import unicode_literals -from bin.wiki_entity_linking.train_descriptions import EntityEncoder -from bin.wiki_entity_linking import wikidata_processor as wd, wikipedia_processor as wp +import csv +import logging +import spacy +import sys + from spacy.kb import KnowledgeBase -import csv -import datetime +from bin.wiki_entity_linking import wikipedia_processor as wp +from bin.wiki_entity_linking.train_descriptions import EntityEncoder + +csv.field_size_limit(sys.maxsize) + + +logger = logging.getLogger(__name__) def create_kb( @@ -14,52 +22,73 
@@ def create_kb( max_entities_per_alias, min_entity_freq, min_occ, - entity_def_output, - entity_descr_output, + entity_def_input, + entity_descr_path, count_input, prior_prob_input, - wikidata_input, entity_vector_length, - limit=None, - read_raw_data=True, ): # Create the knowledge base from Wikidata entries kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=entity_vector_length) + # read the mappings from file + title_to_id = get_entity_to_id(entity_def_input) + id_to_descr = get_id_to_description(entity_descr_path) + # check the length of the nlp vectors if "vectors" in nlp.meta and nlp.vocab.vectors.size: input_dim = nlp.vocab.vectors_length - print("Loaded pre-trained vectors of size %s" % input_dim) + logger.info("Loaded pre-trained vectors of size %s" % input_dim) else: raise ValueError( "The `nlp` object should have access to pre-trained word vectors, " " cf. https://spacy.io/usage/models#languages." ) - # disable this part of the pipeline when rerunning the KB generation from preprocessed files - if read_raw_data: - print() - print(now(), " * read wikidata entities:") - title_to_id, id_to_descr = wd.read_wikidata_entities_json( - wikidata_input, limit=limit - ) - - # write the title-ID and ID-description mappings to file - _write_entity_files( - entity_def_output, entity_descr_output, title_to_id, id_to_descr - ) - - else: - # read the mappings from file - title_to_id = get_entity_to_id(entity_def_output) - id_to_descr = get_id_to_description(entity_descr_output) - - print() - print(now(), " * get entity frequencies:") - print() + logger.info("Get entity frequencies") entity_frequencies = wp.get_all_frequencies(count_input=count_input) + logger.info("Filtering entities with fewer than {} mentions".format(min_entity_freq)) # filter the entities for in the KB by frequency, because there's just too much data (8M entities) otherwise + filtered_title_to_id, entity_list, description_list, frequency_list = get_filtered_entities( + title_to_id, + id_to_descr, + entity_frequencies, + min_entity_freq + ) + logger.info("Left with {} entities".format(len(description_list))) + + logger.info("Train entity encoder") + encoder = EntityEncoder(nlp, input_dim, entity_vector_length) + encoder.train(description_list=description_list, to_print=True) + + logger.info("Get entity embeddings:") + embeddings = encoder.apply_encoder(description_list) + + logger.info("Adding {} entities".format(len(entity_list))) + kb.set_entities( + entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings + ) + + logger.info("Adding aliases") + _add_aliases( + kb, + title_to_id=filtered_title_to_id, + max_entities_per_alias=max_entities_per_alias, + min_occ=min_occ, + prior_prob_input=prior_prob_input, + ) + + logger.info("KB size: {} entities, {} aliases".format( + kb.get_size_entities(), + kb.get_size_aliases())) + + logger.info("Done with kb") + return kb + + +def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies, + min_entity_freq: int = 10): filtered_title_to_id = dict() entity_list = [] description_list = [] @@ -72,58 +101,7 @@ def create_kb( description_list.append(desc) frequency_list.append(freq) filtered_title_to_id[title] = entity - - print(len(title_to_id.keys()), "original titles") - kept_nr = len(filtered_title_to_id.keys()) - print("kept", kept_nr, "entities with min. 
frequency", min_entity_freq) - - print() - print(now(), " * train entity encoder:") - print() - encoder = EntityEncoder(nlp, input_dim, entity_vector_length) - encoder.train(description_list=description_list, to_print=True) - - print() - print(now(), " * get entity embeddings:") - print() - embeddings = encoder.apply_encoder(description_list) - - print(now(), " * adding", len(entity_list), "entities") - kb.set_entities( - entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings - ) - - alias_cnt = _add_aliases( - kb, - title_to_id=filtered_title_to_id, - max_entities_per_alias=max_entities_per_alias, - min_occ=min_occ, - prior_prob_input=prior_prob_input, - ) - print() - print(now(), " * adding", alias_cnt, "aliases") - print() - - print() - print("# of entities in kb:", kb.get_size_entities()) - print("# of aliases in kb:", kb.get_size_aliases()) - - print(now(), "Done with kb") - return kb - - -def _write_entity_files( - entity_def_output, entity_descr_output, title_to_id, id_to_descr -): - with entity_def_output.open("w", encoding="utf8") as id_file: - id_file.write("WP_title" + "|" + "WD_id" + "\n") - for title, qid in title_to_id.items(): - id_file.write(title + "|" + str(qid) + "\n") - - with entity_descr_output.open("w", encoding="utf8") as descr_file: - descr_file.write("WD_id" + "|" + "description" + "\n") - for qid, descr in id_to_descr.items(): - descr_file.write(str(qid) + "|" + descr + "\n") + return filtered_title_to_id, entity_list, description_list, frequency_list def get_entity_to_id(entity_def_output): @@ -137,9 +115,9 @@ def get_entity_to_id(entity_def_output): return entity_to_id -def get_id_to_description(entity_descr_output): +def get_id_to_description(entity_descr_path): id_to_desc = dict() - with entity_descr_output.open("r", encoding="utf8") as csvfile: + with entity_descr_path.open("r", encoding="utf8") as csvfile: csvreader = csv.reader(csvfile, delimiter="|") # skip header next(csvreader) @@ -150,7 +128,6 @@ def get_id_to_description(entity_descr_output): def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input): wp_titles = title_to_id.keys() - cnt = 0 # adding aliases with prior probabilities # we can read this file sequentially, it's sorted by alias, and then by count @@ -187,9 +164,8 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in entities=selected_entities, probabilities=prior_probs, ) - cnt += 1 except ValueError as e: - print(e) + logger.error(e) total_count = 0 counts = [] entities = [] @@ -202,8 +178,12 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in previous_alias = new_alias line = prior_file.readline() - return cnt -def now(): - return datetime.datetime.now() +def read_nlp_kb(model_dir, kb_file): + nlp = spacy.load(model_dir) + kb = KnowledgeBase(vocab=nlp.vocab) + kb.load_bulk(kb_file) + logger.info("kb entities: {}".format(kb.get_size_entities())) + logger.info("kb aliases: {}".format(kb.get_size_aliases())) + return nlp, kb diff --git a/bin/wiki_entity_linking/train_descriptions.py b/bin/wiki_entity_linking/train_descriptions.py index 0663296e4..2cb66909f 100644 --- a/bin/wiki_entity_linking/train_descriptions.py +++ b/bin/wiki_entity_linking/train_descriptions.py @@ -1,6 +1,7 @@ # coding: utf-8 from random import shuffle +import logging import numpy as np from spacy._ml import zero_init, create_default_optimizer @@ -10,6 +11,8 @@ from thinc.v2v import Model from thinc.api import chain from thinc.neural._classes.affine import Affine 
+logger = logging.getLogger(__name__) + class EntityEncoder: """ @@ -50,21 +53,19 @@ class EntityEncoder: start = start + batch_size stop = min(stop + batch_size, len(description_list)) - print("encoded:", stop, "entities") + logger.info("encoded: {} entities".format(stop)) return encodings def train(self, description_list, to_print=False): processed, loss = self._train_model(description_list) if to_print: - print( - "Trained entity descriptions on", - processed, - "(non-unique) entities across", - self.epochs, - "epochs", + logger.info( + "Trained entity descriptions on {} ".format(processed) + + "(non-unique) entities across {} ".format(self.epochs) + + "epochs" ) - print("Final loss:", loss) + logger.info("Final loss: {}".format(loss)) def _train_model(self, description_list): best_loss = 1.0 @@ -93,7 +94,7 @@ class EntityEncoder: loss = self._update(batch) if batch_nr % 25 == 0: - print("loss:", loss) + logger.info("loss: {} ".format(loss)) processed += len(batch) # in general, continue training if we haven't reached our ideal min yet diff --git a/bin/wiki_entity_linking/training_set_creator.py b/bin/wiki_entity_linking/training_set_creator.py index 7f45d9435..3f42f8bdd 100644 --- a/bin/wiki_entity_linking/training_set_creator.py +++ b/bin/wiki_entity_linking/training_set_creator.py @@ -1,10 +1,13 @@ # coding: utf-8 from __future__ import unicode_literals +import logging import random import re import bz2 -import datetime +import json + +from functools import partial from spacy.gold import GoldParse from bin.wiki_entity_linking import kb_creator @@ -15,18 +18,30 @@ Gold-standard entities are stored in one file in standoff format (by character o """ ENTITY_FILE = "gold_entities.csv" +logger = logging.getLogger(__name__) -def now(): - return datetime.datetime.now() - - -def create_training(wikipedia_input, entity_def_input, training_output, limit=None): +def create_training_examples_and_descriptions(wikipedia_input, + entity_def_input, + description_output, + training_output, + parse_descriptions, + limit=None): wp_to_id = kb_creator.get_entity_to_id(entity_def_input) - _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=limit) + _process_wikipedia_texts(wikipedia_input, + wp_to_id, + description_output, + training_output, + parse_descriptions, + limit) -def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None): +def _process_wikipedia_texts(wikipedia_input, + wp_to_id, + output, + training_output, + parse_descriptions, + limit=None): """ Read the XML wikipedia data to parse out training data: raw text data + positive instances @@ -35,29 +50,21 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N id_regex = re.compile(r"(?<=)\d*(?=)") read_ids = set() - entityfile_loc = training_output / ENTITY_FILE - with entityfile_loc.open("w", encoding="utf8") as entityfile: - # write entity training header file - _write_training_entity( - outputfile=entityfile, - article_id="article_id", - alias="alias", - entity="WD_id", - start="start", - end="end", - ) + with output.open("a", encoding="utf8") as descr_file, training_output.open("w", encoding="utf8") as entity_file: + if parse_descriptions: + _write_training_description(descr_file, "WD_id", "description") with bz2.open(wikipedia_input, mode="rb") as file: - line = file.readline() - cnt = 0 + article_count = 0 article_text = "" article_title = None article_id = None reading_text = False reading_revision = False - while line and (not limit or cnt < limit): - if cnt % 
1000000 == 0: - print(now(), "processed", cnt, "lines of Wikipedia dump") + + logger.info("Processed {} articles".format(article_count)) + + for line in file: clean_line = line.strip().decode("utf-8") if clean_line == "": @@ -70,28 +77,32 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N article_text = "" article_title = None article_id = None - # finished reading this page elif clean_line == "": if article_id: - try: - _process_wp_text( - wp_to_id, - entityfile, - article_id, - article_title, - article_text.strip(), - training_output, - ) - except Exception as e: - print( - "Error processing article", article_id, article_title, e - ) - else: - print( - "Done processing a page, but couldn't find an article_id ?", + clean_text, entities = _process_wp_text( article_title, + article_text, + wp_to_id ) + if clean_text is not None and entities is not None: + _write_training_entities(entity_file, + article_id, + clean_text, + entities) + + if article_title in wp_to_id and parse_descriptions: + description = " ".join(clean_text[:1000].split(" ")[:-1]) + _write_training_description( + descr_file, + wp_to_id[article_title], + description + ) + article_count += 1 + if article_count % 10000 == 0: + logger.info("Processed {} articles".format(article_count)) + if limit and article_count >= limit: + break article_text = "" article_title = None article_id = None @@ -115,7 +126,7 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N if ids: article_id = ids[0] if article_id in read_ids: - print( + logger.info( "Found duplicate article ID", article_id, clean_line ) # This should never happen ... read_ids.add(article_id) @@ -125,115 +136,10 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N titles = title_regex.search(clean_line) if titles: article_title = titles[0].strip() - - line = file.readline() - cnt += 1 - print(now(), "processed", cnt, "lines of Wikipedia dump") + logger.info("Finished. 
Processed {} articles".format(article_count)) text_regex = re.compile(r"(?<=).*(?= 2: - reading_special_case = True - - if open_read == 2 and reading_text: - reading_text = False - reading_entity = True - reading_mention = False - - # we just finished reading an entity - if open_read == 0 and not reading_text: - if "#" in entity_buffer or entity_buffer.startswith(":"): - reading_special_case = True - # Ignore cases with nested structures like File: handles etc - if not reading_special_case: - if not mention_buffer: - mention_buffer = entity_buffer - start = len(final_text) - end = start + len(mention_buffer) - qid = wp_to_id.get(entity_buffer, None) - if qid: - _write_training_entity( - outputfile=entityfile, - article_id=article_id, - alias=mention_buffer, - entity=qid, - start=start, - end=end, - ) - found_entities = True - final_text += mention_buffer - - entity_buffer = "" - mention_buffer = "" - - reading_text = True - reading_entity = False - reading_mention = False - reading_special_case = False - - if found_entities: - _write_training_article( - article_id=article_id, - clean_text=final_text, - training_output=training_output, - ) - - info_regex = re.compile(r"{[^{]*?}") htlm_regex = re.compile(r"<!--[^-]*-->") category_regex = re.compile(r"\[\[Category:[^\[]*]]") @@ -242,6 +148,29 @@ ref_regex = re.compile(r"<ref.*?>") # non-greedy ref_2_regex = re.compile(r"</ref.*?>") # non-greedy +def _process_wp_text(article_title, article_text, wp_to_id): + # ignore meta Wikipedia pages + if ( + article_title.startswith("Wikipedia:") or + article_title.startswith("Kategori:") + ): + return None, None + + # remove the text tags + text_search = text_regex.search(article_text) + if text_search is None: + return None, None + text = text_search.group(0) + + # stop processing if this is a redirect page + if text.startswith("#REDIRECT"): + return None, None + + # get the raw text without markup etc, keeping only interwiki links + clean_text, entities = _remove_links(_get_clean_wp_text(text), wp_to_id) + return clean_text, entities + + def _get_clean_wp_text(article_text): clean_text = article_text.strip() @@ -300,130 +229,167 @@ def _get_clean_wp_text(article_text): return clean_text.strip() -def _write_training_article(article_id, clean_text, training_output): - file_loc = training_output / "{}.txt".format(article_id) - with file_loc.open("w", encoding="utf8") as outputfile: - outputfile.write(clean_text) +def _remove_links(clean_text, wp_to_id): + # read the text char by char to get the right offsets for the interwiki links + entities = [] + final_text = "" + open_read = 0 + reading_text = True + reading_entity = False + reading_mention = False + reading_special_case = False + entity_buffer = "" + mention_buffer = "" + for index, letter in enumerate(clean_text): + if letter == "[": + open_read += 1 + elif letter == "]": + open_read -= 1 + elif letter == "|": + if reading_text: + final_text += letter + # switch from reading entity to mention in the [[entity|mention]] pattern + elif reading_entity: + reading_text = False + reading_entity = False + reading_mention = True + else: + reading_special_case = True + else: + if reading_entity: + entity_buffer += letter + elif reading_mention: + mention_buffer += letter + elif reading_text: + final_text += letter + else: + raise ValueError("Not sure at point", clean_text[index - 2: index + 2]) + + if open_read > 2: + reading_special_case = True + + if open_read == 2 and reading_text: + reading_text = False + reading_entity = True + reading_mention = False + 
+ # we just finished reading an entity + if open_read == 0 and not reading_text: + if "#" in entity_buffer or entity_buffer.startswith(":"): + reading_special_case = True + # Ignore cases with nested structures like File: handles etc + if not reading_special_case: + if not mention_buffer: + mention_buffer = entity_buffer + start = len(final_text) + end = start + len(mention_buffer) + qid = wp_to_id.get(entity_buffer, None) + if qid: + entities.append((mention_buffer, qid, start, end)) + final_text += mention_buffer + + entity_buffer = "" + mention_buffer = "" + + reading_text = True + reading_entity = False + reading_mention = False + reading_special_case = False + return final_text, entities -def _write_training_entity(outputfile, article_id, alias, entity, start, end): - line = "{}|{}|{}|{}|{}\n".format(article_id, alias, entity, start, end) +def _write_training_description(outputfile, qid, description): + if description is not None: + line = str(qid) + "|" + description + "\n" + outputfile.write(line) + + +def _write_training_entities(outputfile, article_id, clean_text, entities): + entities_data = [{"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]} for ent in entities] + line = json.dumps( + { + "article_id": article_id, + "clean_text": clean_text, + "entities": entities_data + }, + ensure_ascii=False) + "\n" outputfile.write(line) +def read_training(nlp, entity_file_path, dev, limit, kb): + """ This method provides training examples that correspond to the entity annotations found by the nlp object. + For training, it will include negative training examples by using the candidate generator, + and it will only keep positive training examples that can be found by using the candidate generator. + For testing, it will include all positive examples only.""" + + from tqdm import tqdm + data = [] + num_entities = 0 + get_gold_parse = partial(_get_gold_parse, dev=dev, kb=kb) + + logger.info("Reading {} data with limit {}".format('dev' if dev else 'train', limit)) + with entity_file_path.open("r", encoding="utf8") as file: + with tqdm(total=limit, leave=False) as pbar: + for i, line in enumerate(file): + example = json.loads(line) + article_id = example["article_id"] + clean_text = example["clean_text"] + entities = example["entities"] + + if dev != is_dev(article_id) or len(clean_text) >= 30000: + continue + + doc = nlp(clean_text) + gold = get_gold_parse(doc, entities) + if gold and len(gold.links) > 0: + data.append((doc, gold)) + num_entities += len(gold.links) + pbar.update(len(gold.links)) + if limit and num_entities >= limit: + break + logger.info("Read {} entities in {} articles".format(num_entities, len(data))) + return data + + +def _get_gold_parse(doc, entities, dev, kb): + gold_entities = {} + tagged_ent_positions = set( + [(ent.start_char, ent.end_char) for ent in doc.ents] + ) + + for entity in entities: + entity_id = entity["entity"] + alias = entity["alias"] + start = entity["start"] + end = entity["end"] + + candidates = kb.get_candidates(alias) + candidate_ids = [ + c.entity_ for c in candidates + ] + + should_add_ent = ( + dev or + ( + (start, end) in tagged_ent_positions and + entity_id in candidate_ids and + len(candidates) > 1 + ) + ) + + if should_add_ent: + value_by_id = {entity_id: 1.0} + if not dev: + random.shuffle(candidate_ids) + value_by_id.update({ + kb_id: 0.0 + for kb_id in candidate_ids + if kb_id != entity_id + }) + gold_entities[(start, end)] = value_by_id + + return GoldParse(doc, links=gold_entities) + + def is_dev(article_id): return
article_id.endswith("3") - - -def read_training(nlp, training_dir, dev, limit, kb=None): - """ This method provides training examples that correspond to the entity annotations found by the nlp object. - When kb is provided (for training), it will include negative training examples by using the candidate generator, - and it will only keep positive training examples that can be found in the KB. - When kb=None (for testing), it will include all positive examples only.""" - entityfile_loc = training_dir / ENTITY_FILE - data = [] - - # assume the data is written sequentially, so we can reuse the article docs - current_article_id = None - current_doc = None - ents_by_offset = dict() - skip_articles = set() - total_entities = 0 - - with entityfile_loc.open("r", encoding="utf8") as file: - for line in file: - if not limit or len(data) < limit: - fields = line.replace("\n", "").split(sep="|") - article_id = fields[0] - alias = fields[1] - wd_id = fields[2] - start = fields[3] - end = fields[4] - - if ( - dev == is_dev(article_id) - and article_id != "article_id" - and article_id not in skip_articles - ): - if not current_doc or (current_article_id != article_id): - # parse the new article text - file_name = article_id + ".txt" - try: - training_file = training_dir / file_name - with training_file.open("r", encoding="utf8") as f: - text = f.read() - # threshold for convenience / speed of processing - if len(text) < 30000: - current_doc = nlp(text) - current_article_id = article_id - ents_by_offset = dict() - for ent in current_doc.ents: - sent_length = len(ent.sent) - # custom filtering to avoid too long or too short sentences - if 5 < sent_length < 100: - offset = "{}_{}".format( - ent.start_char, ent.end_char - ) - ents_by_offset[offset] = ent - else: - skip_articles.add(article_id) - current_doc = None - except Exception as e: - print("Problem parsing article", article_id, e) - skip_articles.add(article_id) - - # repeat checking this condition in case an exception was thrown - if current_doc and (current_article_id == article_id): - offset = "{}_{}".format(start, end) - found_ent = ents_by_offset.get(offset, None) - if found_ent: - if found_ent.text != alias: - skip_articles.add(article_id) - current_doc = None - else: - sent = found_ent.sent.as_doc() - - gold_start = int(start) - found_ent.sent.start_char - gold_end = int(end) - found_ent.sent.start_char - - gold_entities = {} - found_useful = False - for ent in sent.ents: - entry = (ent.start_char, ent.end_char) - gold_entry = (gold_start, gold_end) - if entry == gold_entry: - # add both pos and neg examples (in random order) - # this will exclude examples not in the KB - if kb: - value_by_id = {} - candidates = kb.get_candidates(alias) - candidate_ids = [ - c.entity_ for c in candidates - ] - random.shuffle(candidate_ids) - for kb_id in candidate_ids: - found_useful = True - if kb_id != wd_id: - value_by_id[kb_id] = 0.0 - else: - value_by_id[kb_id] = 1.0 - gold_entities[entry] = value_by_id - # if no KB, keep all positive examples - else: - found_useful = True - value_by_id = {wd_id: 1.0} - - gold_entities[entry] = value_by_id - # currently feeding the gold data one entity per sentence at a time - # setting all other entities to empty gold dictionary - else: - gold_entities[entry] = {} - if found_useful: - gold = GoldParse(doc=sent, links=gold_entities) - data.append((sent, gold)) - total_entities += 1 - if len(data) % 2500 == 0: - print(" -read", total_entities, "entities") - - print(" -read", total_entities, "entities") - return data diff 
--git a/bin/wiki_entity_linking/wikidata_pretrain_kb.py b/bin/wiki_entity_linking/wikidata_pretrain_kb.py index c5261cada..56107f3a2 100644 --- a/bin/wiki_entity_linking/wikidata_pretrain_kb.py +++ b/bin/wiki_entity_linking/wikidata_pretrain_kb.py @@ -13,27 +13,25 @@ from https://dumps.wikimedia.org/enwiki/latest/ """ from __future__ import unicode_literals -import datetime +import logging from pathlib import Path import plac -from bin.wiki_entity_linking import wikipedia_processor as wp +from bin.wiki_entity_linking import wikipedia_processor as wp, wikidata_processor as wd from bin.wiki_entity_linking import kb_creator - +from bin.wiki_entity_linking import training_set_creator +from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_FILE, ENTITY_DESCR_PATH, KB_MODEL_DIR, LOG_FORMAT +from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH import spacy -from spacy import Errors - - -def now(): - return datetime.datetime.now() +logger = logging.getLogger(__name__) @plac.annotations( wd_json=("Path to the downloaded WikiData JSON dump.", "positional", None, Path), wp_xml=("Path to the downloaded Wikipedia XML dump.", "positional", None, Path), output_dir=("Output directory", "positional", None, Path), - model=("Model name, should include pretrained vectors.", "positional", None, str), + model=("Model name or path, should include pretrained vectors.", "positional", None, str), max_per_alias=("Max. # entities per alias (default 10)", "option", "a", int), min_freq=("Min. count of an entity in the corpus (default 20)", "option", "f", int), min_pair=("Min. count of entity-alias pairs (default 5)", "option", "c", int), @@ -41,7 +39,9 @@ def now(): loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path), loc_entity_defs=("Location to file with entity definitions", "option", "d", Path), loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path), + descriptions_from_wikipedia=("Flag for using wp descriptions not wd", "flag", "wp"), limit=("Optional threshold to limit lines read from dumps", "option", "l", int), + lang=("Optional language for which to get wikidata titles. 
Defaults to 'en'", "option", "la", str), ) def main( wd_json, @@ -55,20 +55,29 @@ def main( loc_prior_prob=None, loc_entity_defs=None, loc_entity_desc=None, + descriptions_from_wikipedia=False, limit=None, + lang="en", ): - print(now(), "Creating KB with Wikipedia and WikiData") - print() + + entity_defs_path = loc_entity_defs if loc_entity_defs else output_dir / ENTITY_DEFS_PATH + entity_descr_path = loc_entity_desc if loc_entity_desc else output_dir / ENTITY_DESCR_PATH + entity_freq_path = output_dir / ENTITY_FREQ_PATH + prior_prob_path = loc_prior_prob if loc_prior_prob else output_dir / PRIOR_PROB_PATH + training_entities_path = output_dir / TRAINING_DATA_FILE + kb_path = output_dir / KB_FILE + + logger.info("Creating KB with Wikipedia and WikiData") if limit is not None: - print("Warning: reading only", limit, "lines of Wikipedia/Wikidata dumps.") + logger.warning("Warning: reading only {} lines of Wikipedia/Wikidata dumps.".format(limit)) # STEP 0: set up IO if not output_dir.exists(): - output_dir.mkdir() + output_dir.mkdir(parents=True) # STEP 1: create the NLP object - print(now(), "STEP 1: loaded model", model) + logger.info("STEP 1: Loading model {}".format(model)) nlp = spacy.load(model) # check the length of the nlp vectors @@ -79,64 +88,68 @@ def main( ) # STEP 2: create prior probabilities from WP - print() - if loc_prior_prob: - print(now(), "STEP 2: reading prior probabilities from", loc_prior_prob) - else: + if not prior_prob_path.exists(): # It takes about 2h to process 1000M lines of Wikipedia XML dump - loc_prior_prob = output_dir / "prior_prob.csv" - print(now(), "STEP 2: writing prior probabilities at", loc_prior_prob) - wp.read_prior_probs(wp_xml, loc_prior_prob, limit=limit) + logger.info("STEP 2: writing prior probabilities to {}".format(prior_prob_path)) + wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit) + logger.info("STEP 2: reading prior probabilities from {}".format(prior_prob_path)) # STEP 3: deduce entity frequencies from WP (takes only a few minutes) - print() - print(now(), "STEP 3: calculating entity frequencies") - loc_entity_freq = output_dir / "entity_freq.csv" - wp.write_entity_counts(loc_prior_prob, loc_entity_freq, to_print=False) + logger.info("STEP 3: calculating entity frequencies") + wp.write_entity_counts(prior_prob_path, entity_freq_path, to_print=False) - loc_kb = output_dir / "kb" - - # STEP 4: reading entity descriptions and definitions from WikiData or from file - print() - if loc_entity_defs and loc_entity_desc: - read_raw = False - print(now(), "STEP 4a: reading entity definitions from", loc_entity_defs) - print(now(), "STEP 4b: reading entity descriptions from", loc_entity_desc) - else: + # STEP 4: reading definitions and (possibly) descriptions from WikiData or from file + message = " and descriptions" if not descriptions_from_wikipedia else "" + if (not entity_defs_path.exists()) or (not descriptions_from_wikipedia and not entity_descr_path.exists()): # It takes about 10h to process 55M lines of Wikidata JSON dump - read_raw = True - loc_entity_defs = output_dir / "entity_defs.csv" - loc_entity_desc = output_dir / "entity_descriptions.csv" - print(now(), "STEP 4: parsing wikidata for entity definitions and descriptions") + logger.info("STEP 4: parsing wikidata for entity definitions" + message) + title_to_id, id_to_descr = wd.read_wikidata_entities_json( + wd_json, + limit, + to_print=False, + lang=lang, + parse_descriptions=(not descriptions_from_wikipedia), + ) + wd.write_entity_files(entity_defs_path, title_to_id) + if 
not descriptions_from_wikipedia: + wd.write_entity_description_files(entity_descr_path, id_to_descr) + logger.info("STEP 4: read entity definitions" + message) - # STEP 5: creating the actual KB + # STEP 5: Getting gold entities from wikipedia + message = " and descriptions" if descriptions_from_wikipedia else "" + if (not training_entities_path.exists()) or (descriptions_from_wikipedia and not entity_descr_path.exists()): + logger.info("STEP 5: parsing wikipedia for gold entities" + message) + training_set_creator.create_training_examples_and_descriptions( + wp_xml, + entity_defs_path, + entity_descr_path, + training_entities_path, + parse_descriptions=descriptions_from_wikipedia, + limit=limit, + ) + logger.info("STEP 5: read gold entities" + message) + + # STEP 6: creating the actual KB # It takes ca. 30 minutes to pretrain the entity embeddings - print() - print(now(), "STEP 5: creating the KB at", loc_kb) + logger.info("STEP 6: creating the KB at {}".format(kb_path)) kb = kb_creator.create_kb( nlp=nlp, max_entities_per_alias=max_per_alias, min_entity_freq=min_freq, min_occ=min_pair, - entity_def_output=loc_entity_defs, - entity_descr_output=loc_entity_desc, - count_input=loc_entity_freq, - prior_prob_input=loc_prior_prob, - wikidata_input=wd_json, + entity_def_input=entity_defs_path, + entity_descr_path=entity_descr_path, + count_input=entity_freq_path, + prior_prob_input=prior_prob_path, entity_vector_length=entity_vector_length, - limit=limit, - read_raw_data=read_raw, ) - if read_raw: - print(" - wrote entity definitions to", loc_entity_defs) - print(" - wrote writing entity descriptions to", loc_entity_desc) - kb.dump(loc_kb) - nlp.to_disk(output_dir / "nlp") + kb.dump(kb_path) + nlp.to_disk(output_dir / KB_MODEL_DIR) - print() - print(now(), "Done!") + logger.info("Done!") if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format=LOG_FORMAT) plac.call(main) diff --git a/bin/wiki_entity_linking/wikidata_processor.py b/bin/wiki_entity_linking/wikidata_processor.py index 660eab28e..b4034cb1a 100644 --- a/bin/wiki_entity_linking/wikidata_processor.py +++ b/bin/wiki_entity_linking/wikidata_processor.py @@ -1,17 +1,19 @@ # coding: utf-8 from __future__ import unicode_literals -import bz2 +import gzip import json +import logging import datetime +logger = logging.getLogger(__name__) -def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False): + +def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descriptions=True): # Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines. 
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/ - lang = "en" - site_filter = "enwiki" + site_filter = '{}wiki'.format(lang) # properties filter (currently disabled to get ALL data) prop_filter = dict() @@ -24,18 +26,15 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False): parse_properties = False parse_sitelinks = True parse_labels = False - parse_descriptions = True parse_aliases = False parse_claims = False - with bz2.open(wikidata_file, mode="rb") as file: - line = file.readline() - cnt = 0 - while line and (not limit or cnt < limit): - if cnt % 1000000 == 0: - print( - datetime.datetime.now(), "processed", cnt, "lines of WikiData JSON dump" - ) + with gzip.open(wikidata_file, mode='rb') as file: + for cnt, line in enumerate(file): + if limit and cnt >= limit: + break + if cnt % 500000 == 0: + logger.info("processed {} lines of WikiData dump".format(cnt)) clean_line = line.strip() if clean_line.endswith(b","): clean_line = clean_line[:-1] @@ -134,8 +133,19 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False): if to_print: print() - line = file.readline() - cnt += 1 - print(datetime.datetime.now(), "processed", cnt, "lines of WikiData JSON dump") return title_to_id, id_to_descr + + +def write_entity_files(entity_def_output, title_to_id): + with entity_def_output.open("w", encoding="utf8") as id_file: + id_file.write("WP_title" + "|" + "WD_id" + "\n") + for title, qid in title_to_id.items(): + id_file.write(title + "|" + str(qid) + "\n") + + +def write_entity_description_files(entity_descr_output, id_to_descr): + with entity_descr_output.open("w", encoding="utf8") as descr_file: + descr_file.write("WD_id" + "|" + "description" + "\n") + for qid, descr in id_to_descr.items(): + descr_file.write(str(qid) + "|" + descr + "\n") diff --git a/bin/wiki_entity_linking/wikidata_train_entity_linker.py b/bin/wiki_entity_linking/wikidata_train_entity_linker.py index 770919112..d9ed641d6 100644 --- a/bin/wiki_entity_linking/wikidata_train_entity_linker.py +++ b/bin/wiki_entity_linking/wikidata_train_entity_linker.py @@ -11,124 +11,84 @@ from https://dumps.wikimedia.org/enwiki/latest/ from __future__ import unicode_literals import random -import datetime +import logging from pathlib import Path import plac from bin.wiki_entity_linking import training_set_creator +from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR +from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance, measure_baselines +from bin.wiki_entity_linking.kb_creator import read_nlp_kb -import spacy -from spacy.kb import KnowledgeBase from spacy.util import minibatch, compounding - -def now(): - return datetime.datetime.now() +logger = logging.getLogger(__name__) @plac.annotations( dir_kb=("Directory with KB, NLP and related files", "positional", None, Path), output_dir=("Output directory", "option", "o", Path), loc_training=("Location to training data", "option", "k", Path), - wp_xml=("Path to the downloaded Wikipedia XML dump.", "option", "w", Path), epochs=("Number of training iterations (default 10)", "option", "e", int), dropout=("Dropout to prevent overfitting (default 0.5)", "option", "p", float), lr=("Learning rate (default 0.005)", "option", "n", float), l2=("L2 regularization", "option", "r", float), train_inst=("# training instances (default 90% of all)", "option", "t", int), dev_inst=("# test instances (default 10% of all)", "option", "d", int), - limit=("Optional 
threshold to limit lines read from WP dump", "option", "l", int), ) def main( dir_kb, output_dir=None, loc_training=None, - wp_xml=None, epochs=10, dropout=0.5, lr=0.005, l2=1e-6, train_inst=None, dev_inst=None, - limit=None, ): - print(now(), "Creating Entity Linker with Wikipedia and WikiData") - print() + logger.info("Creating Entity Linker with Wikipedia and WikiData") + + output_dir = Path(output_dir) if output_dir else dir_kb + training_path = loc_training if loc_training else output_dir / TRAINING_DATA_FILE + nlp_dir = dir_kb / KB_MODEL_DIR + kb_path = output_dir / KB_FILE + nlp_output_dir = output_dir / OUTPUT_MODEL_DIR # STEP 0: set up IO - if output_dir and not output_dir.exists(): + if not output_dir.exists(): output_dir.mkdir() # STEP 1 : load the NLP object - nlp_dir = dir_kb / "nlp" - print(now(), "STEP 1: loading model from", nlp_dir) - nlp = spacy.load(nlp_dir) + logger.info("STEP 1: loading model from {}".format(nlp_dir)) + nlp, kb = read_nlp_kb(nlp_dir, kb_path) # check that there is a NER component in the pipeline if "ner" not in nlp.pipe_names: raise ValueError("The `nlp` object should have a pre-trained `ner` component.") - # STEP 2 : read the KB - print() - print(now(), "STEP 2: reading the KB from", dir_kb / "kb") - kb = KnowledgeBase(vocab=nlp.vocab) - kb.load_bulk(dir_kb / "kb") + # STEP 2: create a training dataset from WP + logger.info("STEP 2: reading training dataset from {}".format(training_path)) - # STEP 3: create a training dataset from WP - print() - if loc_training: - print(now(), "STEP 3: reading training dataset from", loc_training) - else: - if not wp_xml: - raise ValueError( - "Either provide a path to a preprocessed training directory, " - "or to the original Wikipedia XML dump." - ) - - if output_dir: - loc_training = output_dir / "training_data" - else: - loc_training = dir_kb / "training_data" - if not loc_training.exists(): - loc_training.mkdir() - print(now(), "STEP 3: creating training dataset at", loc_training) - - if limit is not None: - print("Warning: reading only", limit, "lines of Wikipedia dump.") - - loc_entity_defs = dir_kb / "entity_defs.csv" - training_set_creator.create_training( - wikipedia_input=wp_xml, - entity_def_input=loc_entity_defs, - training_output=loc_training, - limit=limit, - ) - - # STEP 4: parse the training data - print() - print(now(), "STEP 4: parse the training & evaluation data") - - # for training, get pos & neg instances that correspond to entries in the kb - print("Parsing training data, limit =", train_inst) train_data = training_set_creator.read_training( - nlp=nlp, training_dir=loc_training, dev=False, limit=train_inst, kb=kb + nlp=nlp, + entity_file_path=training_path, + dev=False, + limit=train_inst, + kb=kb, ) - print("Training on", len(train_data), "articles") - print() - - print("Parsing dev testing data, limit =", dev_inst) # for testing, get all pos instances, whether or not they are in the kb dev_data = training_set_creator.read_training( - nlp=nlp, training_dir=loc_training, dev=True, limit=dev_inst, kb=None + nlp=nlp, + entity_file_path=training_path, + dev=True, + limit=dev_inst, + kb=kb, ) - print("Dev testing on", len(dev_data), "articles") - print() - - # STEP 5: create and train the entity linking pipe - print() - print(now(), "STEP 5: training Entity Linking pipe") + # STEP 3: create and train the entity linking pipe + logger.info("STEP 3: training Entity Linking pipe") el_pipe = nlp.create_pipe( name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name} @@ -142,275 +102,70 
@@ def main( optimizer.learn_rate = lr optimizer.L2 = l2 - if not train_data: - print("Did not find any training data") - else: - for itn in range(epochs): - random.shuffle(train_data) - losses = {} - batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001)) - batchnr = 0 + logger.info("Training on {} articles".format(len(train_data))) + logger.info("Dev testing on {} articles".format(len(dev_data))) - with nlp.disable_pipes(*other_pipes): - for batch in batches: - try: - docs, golds = zip(*batch) - nlp.update( - docs=docs, - golds=golds, - sgd=optimizer, - drop=dropout, - losses=losses, - ) - batchnr += 1 - except Exception as e: - print("Error updating batch:", e) - - if batchnr > 0: - el_pipe.cfg["incl_context"] = True - el_pipe.cfg["incl_prior"] = True - dev_acc_context, _ = _measure_acc(dev_data, el_pipe) - losses["entity_linker"] = losses["entity_linker"] / batchnr - print( - "Epoch, train loss", - itn, - round(losses["entity_linker"], 2), - " / dev accuracy avg", - round(dev_acc_context, 3), - ) - - # STEP 6: measure the performance of our trained pipe on an independent dev set - print() - if len(dev_data): - print() - print(now(), "STEP 6: performance measurement of Entity Linking pipe") - print() - - counts, acc_r, acc_r_d, acc_p, acc_p_d, acc_o, acc_o_d = _measure_baselines( - dev_data, kb - ) - print("dev counts:", sorted(counts.items(), key=lambda x: x[0])) - - oracle_by_label = [(x, round(y, 3)) for x, y in acc_o_d.items()] - print("dev accuracy oracle:", round(acc_o, 3), oracle_by_label) - - random_by_label = [(x, round(y, 3)) for x, y in acc_r_d.items()] - print("dev accuracy random:", round(acc_r, 3), random_by_label) - - prior_by_label = [(x, round(y, 3)) for x, y in acc_p_d.items()] - print("dev accuracy prior:", round(acc_p, 3), prior_by_label) - - # using only context - el_pipe.cfg["incl_context"] = True - el_pipe.cfg["incl_prior"] = False - dev_acc_context, dev_acc_cont_d = _measure_acc(dev_data, el_pipe) - context_by_label = [(x, round(y, 3)) for x, y in dev_acc_cont_d.items()] - print("dev accuracy context:", round(dev_acc_context, 3), context_by_label) - - # measuring combined accuracy (prior + context) - el_pipe.cfg["incl_context"] = True - el_pipe.cfg["incl_prior"] = True - dev_acc_combo, dev_acc_combo_d = _measure_acc(dev_data, el_pipe) - combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_d.items()] - print("dev accuracy prior+context:", round(dev_acc_combo, 3), combo_by_label) - - # STEP 7: apply the EL pipe on a toy example - print() - print(now(), "STEP 7: applying Entity Linking to toy example") - print() - run_el_toy_example(nlp=nlp) - - # STEP 8: write the NLP pipeline (including entity linker) to file - if output_dir: - print() - nlp_loc = output_dir / "nlp" - print(now(), "STEP 8: Writing trained NLP to", nlp_loc) - nlp.to_disk(nlp_loc) - print() - - print() - print(now(), "Done!") - - -def _measure_acc(data, el_pipe=None, error_analysis=False): - # If the docs in the data require further processing with an entity linker, set el_pipe - correct_by_label = dict() - incorrect_by_label = dict() - - docs = [d for d, g in data if len(d) > 0] - if el_pipe is not None: - docs = list(el_pipe.pipe(docs)) - golds = [g for d, g in data if len(d) > 0] - - for doc, gold in zip(docs, golds): - try: - correct_entries_per_article = dict() - for entity, kb_dict in gold.links.items(): - start, end = entity - # only evaluating on positive examples - for gold_kb, value in kb_dict.items(): - if value: - offset = _offset(start, end) - 
correct_entries_per_article[offset] = gold_kb - - for ent in doc.ents: - ent_label = ent.label_ - pred_entity = ent.kb_id_ - start = ent.start_char - end = ent.end_char - offset = _offset(start, end) - gold_entity = correct_entries_per_article.get(offset, None) - # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong' - if gold_entity is not None: - if gold_entity == pred_entity: - correct = correct_by_label.get(ent_label, 0) - correct_by_label[ent_label] = correct + 1 - else: - incorrect = incorrect_by_label.get(ent_label, 0) - incorrect_by_label[ent_label] = incorrect + 1 - if error_analysis: - print(ent.text, "in", doc) - print( - "Predicted", - pred_entity, - "should have been", - gold_entity, - ) - print() - - except Exception as e: - print("Error assessing accuracy", e) - - acc, acc_by_label = calculate_acc(correct_by_label, incorrect_by_label) - return acc, acc_by_label - - -def _measure_baselines(data, kb): - # Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound - counts_d = dict() - - random_correct_d = dict() - random_incorrect_d = dict() - - oracle_correct_d = dict() - oracle_incorrect_d = dict() - - prior_correct_d = dict() - prior_incorrect_d = dict() - - docs = [d for d, g in data if len(d) > 0] - golds = [g for d, g in data if len(d) > 0] - - for doc, gold in zip(docs, golds): - try: - correct_entries_per_article = dict() - for entity, kb_dict in gold.links.items(): - start, end = entity - for gold_kb, value in kb_dict.items(): - # only evaluating on positive examples - if value: - offset = _offset(start, end) - correct_entries_per_article[offset] = gold_kb - - for ent in doc.ents: - label = ent.label_ - start = ent.start_char - end = ent.end_char - offset = _offset(start, end) - gold_entity = correct_entries_per_article.get(offset, None) - - # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong' - if gold_entity is not None: - counts_d[label] = counts_d.get(label, 0) + 1 - candidates = kb.get_candidates(ent.text) - oracle_candidate = "" - best_candidate = "" - random_candidate = "" - if candidates: - scores = [] - - for c in candidates: - scores.append(c.prior_prob) - if c.entity_ == gold_entity: - oracle_candidate = c.entity_ - - best_index = scores.index(max(scores)) - best_candidate = candidates[best_index].entity_ - random_candidate = random.choice(candidates).entity_ - - if gold_entity == best_candidate: - prior_correct_d[label] = prior_correct_d.get(label, 0) + 1 - else: - prior_incorrect_d[label] = prior_incorrect_d.get(label, 0) + 1 - - if gold_entity == random_candidate: - random_correct_d[label] = random_correct_d.get(label, 0) + 1 - else: - random_incorrect_d[label] = random_incorrect_d.get(label, 0) + 1 - - if gold_entity == oracle_candidate: - oracle_correct_d[label] = oracle_correct_d.get(label, 0) + 1 - else: - oracle_incorrect_d[label] = oracle_incorrect_d.get(label, 0) + 1 - - except Exception as e: - print("Error assessing accuracy", e) - - acc_prior, acc_prior_d = calculate_acc(prior_correct_d, prior_incorrect_d) - acc_rand, acc_rand_d = calculate_acc(random_correct_d, random_incorrect_d) - acc_oracle, acc_oracle_d = calculate_acc(oracle_correct_d, oracle_incorrect_d) - - return ( - counts_d, - acc_rand, - acc_rand_d, - acc_prior, - acc_prior_d, - acc_oracle, - acc_oracle_d, + dev_baseline_accuracies = measure_baselines( + dev_data, kb ) + logger.info("Dev Baseline Accuracies:") + 
logger.info(dev_baseline_accuracies.report_accuracy("random")) + logger.info(dev_baseline_accuracies.report_accuracy("prior")) + logger.info(dev_baseline_accuracies.report_accuracy("oracle")) -def _offset(start, end): - return "{}_{}".format(start, end) + for itn in range(epochs): + random.shuffle(train_data) + losses = {} + batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001)) + batchnr = 0 + with nlp.disable_pipes(*other_pipes): + for batch in batches: + try: + docs, golds = zip(*batch) + nlp.update( + docs=docs, + golds=golds, + sgd=optimizer, + drop=dropout, + losses=losses, + ) + batchnr += 1 + except Exception as e: + logger.error("Error updating batch:" + str(e)) + if batchnr > 0: + logger.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2))) + measure_performance(dev_data, kb, el_pipe) -def calculate_acc(correct_by_label, incorrect_by_label): - acc_by_label = dict() - total_correct = 0 - total_incorrect = 0 - all_keys = set() - all_keys.update(correct_by_label.keys()) - all_keys.update(incorrect_by_label.keys()) - for label in sorted(all_keys): - correct = correct_by_label.get(label, 0) - incorrect = incorrect_by_label.get(label, 0) - total_correct += correct - total_incorrect += incorrect - if correct == incorrect == 0: - acc_by_label[label] = 0 - else: - acc_by_label[label] = correct / (correct + incorrect) - acc = 0 - if not (total_correct == total_incorrect == 0): - acc = total_correct / (total_correct + total_incorrect) - return acc, acc_by_label + # STEP 4: measure the performance of our trained pipe on an independent dev set + logger.info("STEP 4: performance measurement of Entity Linking pipe") + measure_performance(dev_data, kb, el_pipe) + + # STEP 5: apply the EL pipe on a toy example + logger.info("STEP 5: applying Entity Linking to toy example") + run_el_toy_example(nlp=nlp) + + if output_dir: + # STEP 6: write the NLP pipeline (including entity linker) to file + logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir)) + nlp.to_disk(nlp_output_dir) + + logger.info("Done!") def check_kb(kb): for mention in ("Bush", "Douglas Adams", "Homer", "Brazil", "China"): candidates = kb.get_candidates(mention) - print("generating candidates for " + mention + " :") + logger.info("generating candidates for " + mention + " :") for c in candidates: - print( - " ", - c.prior_prob, + logger.info(" ".join([ + str(c.prior_prob), c.alias_, "-->", - c.entity_ + " (freq=" + str(c.entity_freq) + ")", - ) - print() + c.entity_ + " (freq=" + str(c.entity_freq) + ")" + ])) def run_el_toy_example(nlp): @@ -421,11 +176,11 @@ def run_el_toy_example(nlp): "but Dougledydoug doesn't write about George Washington or Homer Simpson." ) doc = nlp(text) - print(text) + logger.info(text) for ent in doc.ents: - print(" ent", ent.text, ent.label_, ent.kb_id_) - print() + logger.info(" ".join(["ent", ent.text, ent.label_, ent.kb_id_])) if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format=LOG_FORMAT) plac.call(main) diff --git a/bin/wiki_entity_linking/wikipedia_processor.py b/bin/wiki_entity_linking/wikipedia_processor.py index fca600368..8f928723e 100644 --- a/bin/wiki_entity_linking/wikipedia_processor.py +++ b/bin/wiki_entity_linking/wikipedia_processor.py @@ -5,6 +5,9 @@ import re import bz2 import csv import datetime +import logging + +from bin.wiki_entity_linking import LOG_FORMAT """ Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions.
@@ -13,6 +16,9 @@ Write these results to file for downstream KB and training data generation. map_alias_to_link = dict() +logger = logging.getLogger(__name__) + + # these will/should be matched ignoring case wiki_namespaces = [ "b", @@ -116,10 +122,6 @@ for ns in wiki_namespaces: ns_regex = re.compile(ns_regex, re.IGNORECASE) -def now(): - return datetime.datetime.now() - - def read_prior_probs(wikipedia_input, prior_prob_output, limit=None): """ Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities. @@ -131,7 +133,7 @@ def read_prior_probs(wikipedia_input, prior_prob_output, limit=None): cnt = 0 while line and (not limit or cnt < limit): if cnt % 25000000 == 0: - print(now(), "processed", cnt, "lines of Wikipedia XML dump") + logger.info("processed {} lines of Wikipedia XML dump".format(cnt)) clean_line = line.strip().decode("utf-8") aliases, entities, normalizations = get_wp_links(clean_line) @@ -141,7 +143,7 @@ def read_prior_probs(wikipedia_input, prior_prob_output, limit=None): line = file.readline() cnt += 1 - print(now(), "processed", cnt, "lines of Wikipedia XML dump") + logger.info("processed {} lines of Wikipedia XML dump".format(cnt)) # write all aliases and their entities and count occurrences to file with prior_prob_output.open("w", encoding="utf8") as outputfile: