From 7b96a5e10f907c325f1d71dd23daafafe237d4d9 Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Mon, 6 Jan 2020 14:59:50 +0100 Subject: [PATCH] Reduce mem usage in training Entity Linker (#4811) * move nlp processing for el pipe to batch training instead of preprocessing * adding dev eval back in, and limit in articles instead of entities * use pipe whenever possible * few more small doc changes * access dev data through generator * tqdm description * small fixes * update documentation --- bin/wiki_entity_linking/README.md | 28 +-- .../entity_linker_evaluation.py | 181 ++++++++---------- .../wikidata_train_entity_linker.py | 158 +++++++-------- .../wikipedia_processor.py | 61 +++--- spacy/pipeline/pipes.pyx | 2 +- 5 files changed, 200 insertions(+), 230 deletions(-) diff --git a/bin/wiki_entity_linking/README.md b/bin/wiki_entity_linking/README.md index 540878592..7460a455e 100644 --- a/bin/wiki_entity_linking/README.md +++ b/bin/wiki_entity_linking/README.md @@ -7,16 +7,16 @@ Run `wikipedia_pretrain_kb.py` * WikiData: get `latest-all.json.bz2` from https://dumps.wikimedia.org/wikidatawiki/entities/ * Wikipedia: get `enwiki-latest-pages-articles-multistream.xml.bz2` from https://dumps.wikimedia.org/enwiki/latest/ (or for any other language) * You can set the filtering parameters for KB construction: - * `max_per_alias`: (max) number of candidate entities in the KB per alias/synonym - * `min_freq`: threshold of number of times an entity should occur in the corpus to be included in the KB - * `min_pair`: threshold of number of times an entity+alias combination should occur in the corpus to be included in the KB + * `max_per_alias` (`-a`): (max) number of candidate entities in the KB per alias/synonym + * `min_freq` (`-f`): threshold of number of times an entity should occur in the corpus to be included in the KB + * `min_pair` (`-c`): threshold of number of times an entity+alias combination should occur in the corpus to be included in the KB * Further parameters to set: - * `descriptions_from_wikipedia`: whether to parse descriptions from Wikipedia (`True`) or Wikidata (`False`) - * `entity_vector_length`: length of the pre-trained entity description vectors - * `lang`: language for which to fetch Wikidata information (as the dump contains all languages) + * `descriptions_from_wikipedia` (`-wp`): whether to parse descriptions from Wikipedia (`True`) or Wikidata (`False`) + * `entity_vector_length` (`-v`): length of the pre-trained entity description vectors + * `lang` (`-la`): language for which to fetch Wikidata information (as the dump contains all languages) Quick testing and rerunning: -* When trying out the pipeline for a quick test, set `limit_prior`, `limit_train` and/or `limit_wd` to read only parts of the dumps instead of everything. +* When trying out the pipeline for a quick test, set `limit_prior` (`-lp`), `limit_train` (`-lt`) and/or `limit_wd` (`-lw`) to read only parts of the dumps instead of everything. * If you only want to (re)run certain parts of the pipeline, just remove the corresponding files and they will be recalculated or reparsed. 
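As a rough illustration of what the KB filtering parameters listed above control, the sketch below builds a tiny `KnowledgeBase` by hand using the spaCy v2.x API. It is not the actual `wikipedia_pretrain_kb.py` logic: the entity IDs, frequencies, prior probabilities and the 64-dim vectors are made-up examples, and the pruning is a simplified stand-in for how `min_freq` and `max_per_alias` are applied.

```python
# Minimal sketch (assumed spaCy v2.x KnowledgeBase API), not the pretraining script itself.
import spacy
from spacy.kb import KnowledgeBase

nlp = spacy.blank("en")
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=64)  # entity_vector_length / -v

min_freq = 20       # -f: drop entities seen fewer than 20 times in the corpus
max_per_alias = 10  # -a: keep at most 10 candidate entities per alias

# Hypothetical corpus statistics: entity -> frequency, plus a dummy description vector.
entity_freqs = {"Q42": 120, "Q5": 3}
for entity_id, freq in entity_freqs.items():
    if freq >= min_freq:
        kb.add_entity(entity=entity_id, freq=freq, entity_vector=[0.0] * 64)

# Hypothetical alias statistics: alias -> {entity: prior probability}.
alias_priors = {"Douglas Adams": {"Q42": 0.9}}
for alias, priors in alias_priors.items():
    # Keep only the most probable candidates that survived the frequency filter.
    kept = sorted(
        ((e, p) for e, p in priors.items() if kb.contains_entity(e)),
        key=lambda x: -x[1],
    )[:max_per_alias]
    if kept:
        kb.add_alias(
            alias=alias,
            entities=[e for e, _ in kept],
            probabilities=[p for _, p in kept],
        )
```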
@@ -24,11 +24,13 @@ Quick testing and rerunning: Run `wikidata_train_entity_linker.py` * This takes the **KB directory** produced by Step 1, and trains an **Entity Linking model** +* Specify the output directory (`-o`) in which the final, trained model will be saved * You can set the learning parameters for the EL training: - * `epochs`: number of training iterations - * `dropout`: dropout rate - * `lr`: learning rate - * `l2`: L2 regularization -* Specify the number of training and dev testing entities with `train_inst` and `dev_inst` respectively + * `epochs` (`-e`): number of training iterations + * `dropout` (`-p`): dropout rate + * `lr` (`-n`): learning rate + * `l2` (`-r`): L2 regularization +* Specify the number of training and dev testing articles with `train_articles` (`-t`) and `dev_articles` (`-d`) respectively + * If not specified, the full dataset will be processed - this may take a LONG time ! * Further parameters to set: - * `labels_discard`: NER label types to discard during training + * `labels_discard` (`-l`): NER label types to discard during training diff --git a/bin/wiki_entity_linking/entity_linker_evaluation.py b/bin/wiki_entity_linking/entity_linker_evaluation.py index 273ade0cd..2aeffbfc2 100644 --- a/bin/wiki_entity_linking/entity_linker_evaluation.py +++ b/bin/wiki_entity_linking/entity_linker_evaluation.py @@ -1,6 +1,8 @@ +# coding: utf-8 +from __future__ import unicode_literals + import logging import random - from tqdm import tqdm from collections import defaultdict @@ -92,133 +94,110 @@ class BaselineResults(object): self.random.update_metrics(ent_label, true_entity, random_candidate) -def measure_performance(dev_data, kb, el_pipe, baseline=True, context=True): - if baseline: - baseline_accuracies, counts = measure_baselines(dev_data, kb) - logger.info("Counts: {}".format({k: v for k, v in sorted(counts.items())})) - logger.info(baseline_accuracies.report_performance("random")) - logger.info(baseline_accuracies.report_performance("prior")) - logger.info(baseline_accuracies.report_performance("oracle")) +def measure_performance(dev_data, kb, el_pipe, baseline=True, context=True, dev_limit=None): + counts = dict() + baseline_results = BaselineResults() + context_results = EvaluationResults() + combo_results = EvaluationResults() - if context: - # using only context - el_pipe.cfg["incl_context"] = True - el_pipe.cfg["incl_prior"] = False - results = get_eval_results(dev_data, el_pipe) - logger.info(results.report_metrics("context only")) - - # measuring combined accuracy (prior + context) - el_pipe.cfg["incl_context"] = True - el_pipe.cfg["incl_prior"] = True - results = get_eval_results(dev_data, el_pipe) - logger.info(results.report_metrics("context and prior")) - - -def get_eval_results(data, el_pipe=None): - """ - Evaluate the ent.kb_id_ annotations against the gold standard. - Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL. - If the docs in the data require further processing with an entity linker, set el_pipe. 
- """ - docs = [] - golds = [] - for d, g in tqdm(data, leave=False): - if len(d) > 0: - golds.append(g) - if el_pipe is not None: - docs.append(el_pipe(d)) - else: - docs.append(d) - - results = EvaluationResults() - for doc, gold in zip(docs, golds): - try: - correct_entries_per_article = dict() + for doc, gold in tqdm(dev_data, total=dev_limit, leave=False, desc='Processing dev data'): + if len(doc) > 0: + correct_ents = dict() for entity, kb_dict in gold.links.items(): start, end = entity for gold_kb, value in kb_dict.items(): if value: # only evaluating on positive examples offset = _offset(start, end) - correct_entries_per_article[offset] = gold_kb + correct_ents[offset] = gold_kb - for ent in doc.ents: - ent_label = ent.label_ - pred_entity = ent.kb_id_ - start = ent.start_char - end = ent.end_char - offset = _offset(start, end) - gold_entity = correct_entries_per_article.get(offset, None) - # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong' - if gold_entity is not None: - results.update_metrics(ent_label, gold_entity, pred_entity) + if baseline: + _add_baseline(baseline_results, counts, doc, correct_ents, kb) - except Exception as e: - logging.error("Error assessing accuracy " + str(e)) + if context: + # using only context + el_pipe.cfg["incl_context"] = True + el_pipe.cfg["incl_prior"] = False + _add_eval_result(context_results, doc, correct_ents, el_pipe) - return results + # measuring combined accuracy (prior + context) + el_pipe.cfg["incl_context"] = True + el_pipe.cfg["incl_prior"] = True + _add_eval_result(combo_results, doc, correct_ents, el_pipe) + + if baseline: + logger.info("Counts: {}".format({k: v for k, v in sorted(counts.items())})) + logger.info(baseline_results.report_performance("random")) + logger.info(baseline_results.report_performance("prior")) + logger.info(baseline_results.report_performance("oracle")) + + if context: + logger.info(context_results.report_metrics("context only")) + logger.info(combo_results.report_metrics("context and prior")) -def measure_baselines(data, kb): +def _add_eval_result(results, doc, correct_ents, el_pipe): """ - Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound. + Evaluate the ent.kb_id_ annotations against the gold standard. Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL. - Also return a dictionary of counts by entity label. 
""" - counts_d = dict() - - baseline_results = BaselineResults() - - docs = [d for d, g in data if len(d) > 0] - golds = [g for d, g in data if len(d) > 0] - - for doc, gold in zip(docs, golds): - correct_entries_per_article = dict() - for entity, kb_dict in gold.links.items(): - start, end = entity - for gold_kb, value in kb_dict.items(): - # only evaluating on positive examples - if value: - offset = _offset(start, end) - correct_entries_per_article[offset] = gold_kb - + try: + doc = el_pipe(doc) for ent in doc.ents: ent_label = ent.label_ start = ent.start_char end = ent.end_char offset = _offset(start, end) - gold_entity = correct_entries_per_article.get(offset, None) - + gold_entity = correct_ents.get(offset, None) # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong' if gold_entity is not None: - candidates = kb.get_candidates(ent.text) - oracle_candidate = "" - prior_candidate = "" - random_candidate = "" - if candidates: - scores = [] + pred_entity = ent.kb_id_ + results.update_metrics(ent_label, gold_entity, pred_entity) - for c in candidates: - scores.append(c.prior_prob) - if c.entity_ == gold_entity: - oracle_candidate = c.entity_ + except Exception as e: + logging.error("Error assessing accuracy " + str(e)) - best_index = scores.index(max(scores)) - prior_candidate = candidates[best_index].entity_ - random_candidate = random.choice(candidates).entity_ - current_count = counts_d.get(ent_label, 0) - counts_d[ent_label] = current_count+1 +def _add_baseline(baseline_results, counts, doc, correct_ents, kb): + """ + Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound. + Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL. + """ + for ent in doc.ents: + ent_label = ent.label_ + start = ent.start_char + end = ent.end_char + offset = _offset(start, end) + gold_entity = correct_ents.get(offset, None) - baseline_results.update_baselines( - gold_entity, - ent_label, - random_candidate, - prior_candidate, - oracle_candidate, - ) + # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong' + if gold_entity is not None: + candidates = kb.get_candidates(ent.text) + oracle_candidate = "" + prior_candidate = "" + random_candidate = "" + if candidates: + scores = [] - return baseline_results, counts_d + for c in candidates: + scores.append(c.prior_prob) + if c.entity_ == gold_entity: + oracle_candidate = c.entity_ + + best_index = scores.index(max(scores)) + prior_candidate = candidates[best_index].entity_ + random_candidate = random.choice(candidates).entity_ + + current_count = counts.get(ent_label, 0) + counts[ent_label] = current_count+1 + + baseline_results.update_baselines( + gold_entity, + ent_label, + random_candidate, + prior_candidate, + oracle_candidate, + ) def _offset(start, end): diff --git a/bin/wiki_entity_linking/wikidata_train_entity_linker.py b/bin/wiki_entity_linking/wikidata_train_entity_linker.py index 8635ae547..54f00fc6f 100644 --- a/bin/wiki_entity_linking/wikidata_train_entity_linker.py +++ b/bin/wiki_entity_linking/wikidata_train_entity_linker.py @@ -1,5 +1,5 @@ # coding: utf-8 -"""Script to take a previously created Knowledge Base and train an entity linking +"""Script that takes a previously created Knowledge Base and trains an entity linking pipeline. 
The provided KB directory should hold the kb, the original nlp object and its vocab used to create the KB, and a few auxiliary files such as the entity definitions, as created by the script `wikidata_create_kb`. @@ -14,6 +14,7 @@ import logging import spacy from pathlib import Path import plac +from tqdm import tqdm from bin.wiki_entity_linking import wikipedia_processor from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR @@ -33,8 +34,8 @@ logger = logging.getLogger(__name__) dropout=("Dropout to prevent overfitting (default 0.5)", "option", "p", float), lr=("Learning rate (default 0.005)", "option", "n", float), l2=("L2 regularization", "option", "r", float), - train_inst=("# training instances (default 90% of all)", "option", "t", int), - dev_inst=("# test instances (default 10% of all)", "option", "d", int), + train_articles=("# training articles (default 90% of all)", "option", "t", int), + dev_articles=("# dev test articles (default 10% of all)", "option", "d", int), labels_discard=("NER labels to discard (default None)", "option", "l", str), ) def main( @@ -45,10 +46,13 @@ def main( dropout=0.5, lr=0.005, l2=1e-6, - train_inst=None, - dev_inst=None, + train_articles=None, + dev_articles=None, labels_discard=None ): + if not output_dir: + logger.warning("No output dir specified so no results will be written, are you sure about this ?") + logger.info("Creating Entity Linker with Wikipedia and WikiData") output_dir = Path(output_dir) if output_dir else dir_kb @@ -64,44 +68,33 @@ def main( # STEP 1 : load the NLP object logger.info("STEP 1a: Loading model from {}".format(nlp_dir)) nlp = spacy.load(nlp_dir) - logger.info("STEP 1b: Loading KB from {}".format(kb_path)) - kb = read_kb(nlp, kb_path) + logger.info("Original NLP pipeline has following pipeline components: {}".format(nlp.pipe_names)) # check that there is a NER component in the pipeline if "ner" not in nlp.pipe_names: raise ValueError("The `nlp` object should have a pretrained `ner` component.") - # STEP 2: read the training dataset previously created from WP - logger.info("STEP 2: Reading training dataset from {}".format(training_path)) + logger.info("STEP 1b: Loading KB from {}".format(kb_path)) + kb = read_kb(nlp, kb_path) + # STEP 2: read the training dataset previously created from WP + logger.info("STEP 2: Reading training & dev dataset from {}".format(training_path)) + train_indices, dev_indices = wikipedia_processor.read_training_indices(training_path) + logger.info("Training set has {} articles, limit set to roughly {} articles per epoch" + .format(len(train_indices), train_articles if train_articles else "all")) + logger.info("Dev set has {} articles, limit set to rougly {} articles for evaluation" + .format(len(dev_indices), dev_articles if dev_articles else "all")) + if dev_articles: + dev_indices = dev_indices[0:dev_articles] + + # STEP 3: create and train an entity linking pipe + logger.info("STEP 3: Creating and training an Entity Linking pipe for {} epochs".format(epochs)) if labels_discard: labels_discard = [x.strip() for x in labels_discard.split(",")] logger.info("Discarding {} NER types: {}".format(len(labels_discard), labels_discard)) else: labels_discard = [] - train_data = wikipedia_processor.read_training( - nlp=nlp, - entity_file_path=training_path, - dev=False, - limit=train_inst, - kb=kb, - labels_discard=labels_discard - ) - - # for testing, get all pos instances (independently of KB) - dev_data = wikipedia_processor.read_training( - nlp=nlp, - 
entity_file_path=training_path, - dev=True, - limit=dev_inst, - kb=None, - labels_discard=labels_discard - ) - - # STEP 3: create and train an entity linking pipe - logger.info("STEP 3: Creating and training an Entity Linking pipe") - el_pipe = nlp.create_pipe( name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name, "labels_discard": labels_discard} @@ -115,80 +108,65 @@ def main( optimizer.learn_rate = lr optimizer.L2 = l2 - logger.info("Training on {} articles".format(len(train_data))) - logger.info("Dev testing on {} articles".format(len(dev_data))) - - # baseline performance on dev data logger.info("Dev Baseline Accuracies:") - measure_performance(dev_data, kb, el_pipe, baseline=True, context=False) + dev_data = wikipedia_processor.read_el_docs_golds(nlp=nlp, entity_file_path=training_path, + dev=True, line_ids=dev_indices, + kb=kb, labels_discard=labels_discard) + + measure_performance(dev_data, kb, el_pipe, baseline=True, context=False, dev_limit=len(dev_indices)) for itn in range(epochs): - random.shuffle(train_data) + random.shuffle(train_indices) losses = {} - batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001)) + batches = minibatch(train_indices, size=compounding(8.0, 128.0, 1.001)) batchnr = 0 + articles_processed = 0 - with nlp.disable_pipes(*other_pipes): + # we either process the whole training file, or just a part each epoch + bar_total = len(train_indices) + if train_articles: + bar_total = train_articles + + with tqdm(total=bar_total, leave=False, desc='Epoch ' + str(itn)) as pbar: for batch in batches: - try: - docs, golds = zip(*batch) - nlp.update( - docs=docs, - golds=golds, - sgd=optimizer, - drop=dropout, - losses=losses, - ) - batchnr += 1 - except Exception as e: - logger.error("Error updating batch:" + str(e)) + if not train_articles or articles_processed < train_articles: + with nlp.disable_pipes("entity_linker"): + train_batch = wikipedia_processor.read_el_docs_golds(nlp=nlp, entity_file_path=training_path, + dev=False, line_ids=batch, + kb=kb, labels_discard=labels_discard) + docs, golds = zip(*train_batch) + try: + with nlp.disable_pipes(*other_pipes): + nlp.update( + docs=docs, + golds=golds, + sgd=optimizer, + drop=dropout, + losses=losses, + ) + batchnr += 1 + articles_processed += len(docs) + pbar.update(len(docs)) + except Exception as e: + logger.error("Error updating batch:" + str(e)) if batchnr > 0: - logging.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2))) - measure_performance(dev_data, kb, el_pipe, baseline=False, context=True) - - # STEP 4: measure the performance of our trained pipe on an independent dev set - logger.info("STEP 4: Final performance measurement of Entity Linking pipe") - measure_performance(dev_data, kb, el_pipe) - - # STEP 5: apply the EL pipe on a toy example - logger.info("STEP 5: Applying Entity Linking to toy example") - run_el_toy_example(nlp=nlp) + logging.info("Epoch {} trained on {} articles, train loss {}" + .format(itn, articles_processed, round(losses["entity_linker"] / batchnr, 2))) + # re-read the dev_data (data is returned as a generator) + dev_data = wikipedia_processor.read_el_docs_golds(nlp=nlp, entity_file_path=training_path, + dev=True, line_ids=dev_indices, + kb=kb, labels_discard=labels_discard) + measure_performance(dev_data, kb, el_pipe, baseline=False, context=True, dev_limit=len(dev_indices)) if output_dir: - # STEP 6: write the NLP pipeline (now including an EL model) to file - logger.info("STEP 6: Writing trained NLP to 
{}".format(nlp_output_dir)) + # STEP 4: write the NLP pipeline (now including an EL model) to file + logger.info("Final NLP pipeline has following pipeline components: {}".format(nlp.pipe_names)) + logger.info("STEP 4: Writing trained NLP to {}".format(nlp_output_dir)) nlp.to_disk(nlp_output_dir) logger.info("Done!") -def check_kb(kb): - for mention in ("Bush", "Douglas Adams", "Homer", "Brazil", "China"): - candidates = kb.get_candidates(mention) - - logger.info("generating candidates for " + mention + " :") - for c in candidates: - logger.info(" ".join[ - str(c.prior_prob), - c.alias_, - "-->", - c.entity_ + " (freq=" + str(c.entity_freq) + ")" - ]) - - -def run_el_toy_example(nlp): - text = ( - "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " - "Douglas reminds us to always bring our towel, even in China or Brazil. " - "The main character in Doug's novel is the man Arthur Dent, " - "but Dougledydoug doesn't write about George Washington or Homer Simpson." - ) - doc = nlp(text) - logger.info(text) - for ent in doc.ents: - logger.info(" ".join(["ent", ent.text, ent.label_, ent.kb_id_])) - - if __name__ == "__main__": logging.basicConfig(level=logging.INFO, format=LOG_FORMAT) plac.call(main) diff --git a/bin/wiki_entity_linking/wikipedia_processor.py b/bin/wiki_entity_linking/wikipedia_processor.py index 19df0cf10..315b1e916 100644 --- a/bin/wiki_entity_linking/wikipedia_processor.py +++ b/bin/wiki_entity_linking/wikipedia_processor.py @@ -6,9 +6,6 @@ import bz2 import logging import random import json -from tqdm import tqdm - -from functools import partial from spacy.gold import GoldParse from bin.wiki_entity_linking import wiki_io as io @@ -454,25 +451,40 @@ def _write_training_entities(outputfile, article_id, clean_text, entities): outputfile.write(line) -def read_training(nlp, entity_file_path, dev, limit, kb, labels_discard=None): - """ This method provides training examples that correspond to the entity annotations found by the nlp object. +def read_training_indices(entity_file_path): + """ This method creates two lists of indices into the training file: one with indices for the + training examples, and one for the dev examples.""" + train_indices = [] + dev_indices = [] + + with entity_file_path.open("r", encoding="utf8") as file: + for i, line in enumerate(file): + example = json.loads(line) + article_id = example["article_id"] + clean_text = example["clean_text"] + + if is_valid_article(clean_text): + if is_dev(article_id): + dev_indices.append(i) + else: + train_indices.append(i) + + return train_indices, dev_indices + + +def read_el_docs_golds(nlp, entity_file_path, dev, line_ids, kb, labels_discard=None): + """ This method provides training/dev examples that correspond to the entity annotations found by the nlp object. For training, it will include both positive and negative examples by using the candidate generator from the kb. 
For testing (kb=None), it will include all positive examples only.""" if not labels_discard: labels_discard = [] - data = [] - num_entities = 0 - get_gold_parse = partial( - _get_gold_parse, dev=dev, kb=kb, labels_discard=labels_discard - ) + texts = [] + entities_list = [] - logger.info( - "Reading {} data with limit {}".format("dev" if dev else "train", limit) - ) with entity_file_path.open("r", encoding="utf8") as file: - with tqdm(total=limit, leave=False) as pbar: - for i, line in enumerate(file): + for i, line in enumerate(file): + if i in line_ids: example = json.loads(line) article_id = example["article_id"] clean_text = example["clean_text"] @@ -481,16 +493,15 @@ def read_training(nlp, entity_file_path, dev, limit, kb, labels_discard=None): if dev != is_dev(article_id) or not is_valid_article(clean_text): continue - doc = nlp(clean_text) - gold = get_gold_parse(doc, entities) - if gold and len(gold.links) > 0: - data.append((doc, gold)) - num_entities += len(gold.links) - pbar.update(len(gold.links)) - if limit and num_entities >= limit: - break - logger.info("Read {} entities in {} articles".format(num_entities, len(data))) - return data + texts.append(clean_text) + entities_list.append(entities) + + docs = nlp.pipe(texts, batch_size=50) + + for doc, entities in zip(docs, entities_list): + gold = _get_gold_parse(doc, entities, dev=dev, kb=kb, labels_discard=labels_discard) + if gold and len(gold.links) > 0: + yield doc, gold def _get_gold_parse(doc, entities, dev, kb, labels_discard): diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index f57ea59d2..b51520777 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -1308,7 +1308,7 @@ class EntityLinker(Pipe): for i, doc in enumerate(docs): if len(doc) > 0: # Looping through each sentence and each entity - # This may go wrong if there are entities across sentences - because they might not get a KB ID + # This may go wrong if there are entities across sentences - which shouldn't happen normally. for sent in doc.sents: sent_doc = sent.as_doc() # currently, the context is the same for each entity in a sentence (should be refined)
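The core memory saving in this patch comes from no longer pre-processing every article with the `nlp` object before training: only line indices into the training file are kept in memory, and each minibatch of articles is parsed lazily with `nlp.pipe()` inside the epoch loop. The sketch below shows that streaming pattern in isolation; the helper `stream_docs`, the `gold_entities.jsonl` file name and the batch sizes are illustrative assumptions, while the real code lives in `read_training_indices()` and `read_el_docs_golds()` above.

```python
# Simplified sketch of the streaming pattern introduced by this patch (names are hypothetical).
import json
import random

import spacy
from spacy.util import minibatch, compounding


def stream_docs(nlp, path, line_ids):
    """Yield (doc, raw_example) pairs for the requested line numbers only."""
    wanted = set(line_ids)
    with open(path, "r", encoding="utf8") as f:
        examples = [json.loads(line) for i, line in enumerate(f) if i in wanted]
    texts = [eg["clean_text"] for eg in examples]
    # nlp.pipe streams and batches internally, so only ~50 docs are alive at once.
    for doc, eg in zip(nlp.pipe(texts, batch_size=50), examples):
        yield doc, eg


nlp = spacy.blank("en")
train_indices = list(range(1000))  # hypothetical: indices as returned by read_training_indices()
random.shuffle(train_indices)
for batch_ids in minibatch(train_indices, size=compounding(8.0, 128.0, 1.001)):
    for doc, eg in stream_docs(nlp, "gold_entities.jsonl", batch_ids):
        pass  # build GoldParse objects and call nlp.update() here
```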