diff --git a/bin/wiki_entity_linking/README.md b/bin/wiki_entity_linking/README.md deleted file mode 100644 index 4e4af5c21..000000000 --- a/bin/wiki_entity_linking/README.md +++ /dev/null @@ -1,37 +0,0 @@ -## Entity Linking with Wikipedia and Wikidata - -### Step 1: Create a Knowledge Base (KB) and training data - -Run `wikidata_pretrain_kb.py` -* This takes as input the locations of a **Wikipedia and a Wikidata dump**, and produces a **KB directory** + **training file** - * WikiData: get `latest-all.json.bz2` from https://dumps.wikimedia.org/wikidatawiki/entities/ - * Wikipedia: get `enwiki-latest-pages-articles-multistream.xml.bz2` from https://dumps.wikimedia.org/enwiki/latest/ (or for any other language) -* You can set the filtering parameters for KB construction: - * `max_per_alias` (`-a`): (max) number of candidate entities in the KB per alias/synonym - * `min_freq` (`-f`): threshold of number of times an entity should occur in the corpus to be included in the KB - * `min_pair` (`-c`): threshold of number of times an entity+alias combination should occur in the corpus to be included in the KB -* Further parameters to set: - * `descriptions_from_wikipedia` (`-wp`): whether to parse descriptions from Wikipedia (`True`) or Wikidata (`False`) - * `entity_vector_length` (`-v`): length of the pre-trained entity description vectors - * `lang` (`-la`): language for which to fetch Wikidata information (as the dump contains all languages) - -Quick testing and rerunning: -* When trying out the pipeline for a quick test, set `limit_prior` (`-lp`), `limit_train` (`-lt`) and/or `limit_wd` (`-lw`) to read only parts of the dumps instead of everything. - * e.g. set `-lt 20000 -lp 2000 -lw 3000 -f 1` -* If you only want to (re)run certain parts of the pipeline, just remove the corresponding files and they will be recalculated or reparsed. - - -### Step 2: Train an Entity Linking model - -Run `wikidata_train_entity_linker.py` -* This takes the **KB directory** produced by Step 1, and trains an **Entity Linking model** -* Specify the output directory (`-o`) in which the final, trained model will be saved -* You can set the learning parameters for the EL training: - * `epochs` (`-e`): number of training iterations - * `dropout` (`-p`): dropout rate - * `lr` (`-n`): learning rate - * `l2` (`-r`): L2 regularization -* Specify the number of training and dev testing articles with `train_articles` (`-t`) and `dev_articles` (`-d`) respectively - * If not specified, the full dataset will be processed - this may take a LONG time ! -* Further parameters to set: - * `labels_discard` (`-l`): NER label types to discard during training diff --git a/bin/wiki_entity_linking/__init__.py b/bin/wiki_entity_linking/__init__.py deleted file mode 100644 index de486bbcf..000000000 --- a/bin/wiki_entity_linking/__init__.py +++ /dev/null @@ -1,12 +0,0 @@ -TRAINING_DATA_FILE = "gold_entities.jsonl" -KB_FILE = "kb" -KB_MODEL_DIR = "nlp_kb" -OUTPUT_MODEL_DIR = "nlp" - -PRIOR_PROB_PATH = "prior_prob.csv" -ENTITY_DEFS_PATH = "entity_defs.csv" -ENTITY_FREQ_PATH = "entity_freq.csv" -ENTITY_ALIAS_PATH = "entity_alias.csv" -ENTITY_DESCR_PATH = "entity_descriptions.csv" - -LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s' diff --git a/bin/wiki_entity_linking/entity_linker_evaluation.py b/bin/wiki_entity_linking/entity_linker_evaluation.py deleted file mode 100644 index 2aeffbfc2..000000000 --- a/bin/wiki_entity_linking/entity_linker_evaluation.py +++ /dev/null @@ -1,204 +0,0 @@ -# coding: utf-8 -from __future__ import unicode_literals - -import logging -import random -from tqdm import tqdm -from collections import defaultdict - -logger = logging.getLogger(__name__) - - -class Metrics(object): - true_pos = 0 - false_pos = 0 - false_neg = 0 - - def update_results(self, true_entity, candidate): - candidate_is_correct = true_entity == candidate - - # Assume that we have no labeled negatives in the data (i.e. cases where true_entity is "NIL") - # Therefore, if candidate_is_correct then we have a true positive and never a true negative. - self.true_pos += candidate_is_correct - self.false_neg += not candidate_is_correct - if candidate and candidate not in {"", "NIL"}: - # A wrong prediction (e.g. Q42 != Q3) counts both as a FP as well as a FN. - self.false_pos += not candidate_is_correct - - def calculate_precision(self): - if self.true_pos == 0: - return 0.0 - else: - return self.true_pos / (self.true_pos + self.false_pos) - - def calculate_recall(self): - if self.true_pos == 0: - return 0.0 - else: - return self.true_pos / (self.true_pos + self.false_neg) - - def calculate_fscore(self): - p = self.calculate_precision() - r = self.calculate_recall() - if p + r == 0: - return 0.0 - else: - return 2 * p * r / (p + r) - - -class EvaluationResults(object): - def __init__(self): - self.metrics = Metrics() - self.metrics_by_label = defaultdict(Metrics) - - def update_metrics(self, ent_label, true_entity, candidate): - self.metrics.update_results(true_entity, candidate) - self.metrics_by_label[ent_label].update_results(true_entity, candidate) - - def report_metrics(self, model_name): - model_str = model_name.title() - recall = self.metrics.calculate_recall() - precision = self.metrics.calculate_precision() - fscore = self.metrics.calculate_fscore() - return ( - "{}: ".format(model_str) - + "F-score = {} | ".format(round(fscore, 3)) - + "Recall = {} | ".format(round(recall, 3)) - + "Precision = {} | ".format(round(precision, 3)) - + "F-score by label = {}".format( - {k: v.calculate_fscore() for k, v in sorted(self.metrics_by_label.items())} - ) - ) - - -class BaselineResults(object): - def __init__(self): - self.random = EvaluationResults() - self.prior = EvaluationResults() - self.oracle = EvaluationResults() - - def report_performance(self, model): - results = getattr(self, model) - return results.report_metrics(model) - - def update_baselines( - self, - true_entity, - ent_label, - random_candidate, - prior_candidate, - oracle_candidate, - ): - self.oracle.update_metrics(ent_label, true_entity, oracle_candidate) - self.prior.update_metrics(ent_label, true_entity, prior_candidate) - self.random.update_metrics(ent_label, true_entity, random_candidate) - - -def measure_performance(dev_data, kb, el_pipe, baseline=True, context=True, dev_limit=None): - counts = dict() - baseline_results = BaselineResults() - context_results = EvaluationResults() - combo_results = EvaluationResults() - - for doc, gold in tqdm(dev_data, total=dev_limit, leave=False, desc='Processing dev data'): - if len(doc) > 0: - correct_ents = dict() - for entity, kb_dict in gold.links.items(): - start, end = entity - for gold_kb, value in kb_dict.items(): - if value: - # only evaluating on positive examples - offset = _offset(start, end) - correct_ents[offset] = gold_kb - - if baseline: - _add_baseline(baseline_results, counts, doc, correct_ents, kb) - - if context: - # using only context - el_pipe.cfg["incl_context"] = True - el_pipe.cfg["incl_prior"] = False - _add_eval_result(context_results, doc, correct_ents, el_pipe) - - # measuring combined accuracy (prior + context) - el_pipe.cfg["incl_context"] = True - el_pipe.cfg["incl_prior"] = True - _add_eval_result(combo_results, doc, correct_ents, el_pipe) - - if baseline: - logger.info("Counts: {}".format({k: v for k, v in sorted(counts.items())})) - logger.info(baseline_results.report_performance("random")) - logger.info(baseline_results.report_performance("prior")) - logger.info(baseline_results.report_performance("oracle")) - - if context: - logger.info(context_results.report_metrics("context only")) - logger.info(combo_results.report_metrics("context and prior")) - - -def _add_eval_result(results, doc, correct_ents, el_pipe): - """ - Evaluate the ent.kb_id_ annotations against the gold standard. - Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL. - """ - try: - doc = el_pipe(doc) - for ent in doc.ents: - ent_label = ent.label_ - start = ent.start_char - end = ent.end_char - offset = _offset(start, end) - gold_entity = correct_ents.get(offset, None) - # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong' - if gold_entity is not None: - pred_entity = ent.kb_id_ - results.update_metrics(ent_label, gold_entity, pred_entity) - - except Exception as e: - logging.error("Error assessing accuracy " + str(e)) - - -def _add_baseline(baseline_results, counts, doc, correct_ents, kb): - """ - Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound. - Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL. - """ - for ent in doc.ents: - ent_label = ent.label_ - start = ent.start_char - end = ent.end_char - offset = _offset(start, end) - gold_entity = correct_ents.get(offset, None) - - # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong' - if gold_entity is not None: - candidates = kb.get_candidates(ent.text) - oracle_candidate = "" - prior_candidate = "" - random_candidate = "" - if candidates: - scores = [] - - for c in candidates: - scores.append(c.prior_prob) - if c.entity_ == gold_entity: - oracle_candidate = c.entity_ - - best_index = scores.index(max(scores)) - prior_candidate = candidates[best_index].entity_ - random_candidate = random.choice(candidates).entity_ - - current_count = counts.get(ent_label, 0) - counts[ent_label] = current_count+1 - - baseline_results.update_baselines( - gold_entity, - ent_label, - random_candidate, - prior_candidate, - oracle_candidate, - ) - - -def _offset(start, end): - return "{}_{}".format(start, end) diff --git a/bin/wiki_entity_linking/kb_creator.py b/bin/wiki_entity_linking/kb_creator.py deleted file mode 100644 index 8691308e0..000000000 --- a/bin/wiki_entity_linking/kb_creator.py +++ /dev/null @@ -1,161 +0,0 @@ -# coding: utf-8 -from __future__ import unicode_literals - -import logging - -from spacy.kb import KnowledgeBase - -from bin.wiki_entity_linking.train_descriptions import EntityEncoder -from bin.wiki_entity_linking import wiki_io as io - - -logger = logging.getLogger(__name__) - - -def create_kb( - nlp, - max_entities_per_alias, - min_entity_freq, - min_occ, - entity_def_path, - entity_descr_path, - entity_alias_path, - entity_freq_path, - prior_prob_path, - entity_vector_length, -): - # Create the knowledge base from Wikidata entries - kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=entity_vector_length) - entity_list, filtered_title_to_id = _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_freq, entity_freq_path, entity_vector_length) - _define_aliases(kb, entity_alias_path, entity_list, filtered_title_to_id, max_entities_per_alias, min_occ, prior_prob_path) - return kb - - -def _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_freq, entity_freq_path, entity_vector_length): - # read the mappings from file - title_to_id = io.read_title_to_id(entity_def_path) - id_to_descr = io.read_id_to_descr(entity_descr_path) - - # check the length of the nlp vectors - if "vectors" in nlp.meta and nlp.vocab.vectors.size: - input_dim = nlp.vocab.vectors_length - logger.info("Loaded pretrained vectors of size %s" % input_dim) - else: - raise ValueError( - "The `nlp` object should have access to pretrained word vectors, " - " cf. https://spacy.io/usage/models#languages." - ) - - logger.info("Filtering entities with fewer than {} mentions or no description".format(min_entity_freq)) - entity_frequencies = io.read_entity_to_count(entity_freq_path) - # filter the entities for in the KB by frequency, because there's just too much data (8M entities) otherwise - filtered_title_to_id, entity_list, description_list, frequency_list = get_filtered_entities( - title_to_id, - id_to_descr, - entity_frequencies, - min_entity_freq - ) - logger.info("Kept {} entities from the set of {}".format(len(description_list), len(title_to_id.keys()))) - - logger.info("Training entity encoder") - encoder = EntityEncoder(nlp, input_dim, entity_vector_length) - encoder.train(description_list=description_list, to_print=True) - - logger.info("Getting entity embeddings") - embeddings = encoder.apply_encoder(description_list) - - logger.info("Adding {} entities".format(len(entity_list))) - kb.set_entities( - entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings - ) - return entity_list, filtered_title_to_id - - -def _define_aliases(kb, entity_alias_path, entity_list, filtered_title_to_id, max_entities_per_alias, min_occ, prior_prob_path): - logger.info("Adding aliases from Wikipedia and Wikidata") - _add_aliases( - kb, - entity_list=entity_list, - title_to_id=filtered_title_to_id, - max_entities_per_alias=max_entities_per_alias, - min_occ=min_occ, - prior_prob_path=prior_prob_path, - ) - - -def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies, - min_entity_freq: int = 10): - filtered_title_to_id = dict() - entity_list = [] - description_list = [] - frequency_list = [] - for title, entity in title_to_id.items(): - freq = entity_frequencies.get(title, 0) - desc = id_to_descr.get(entity, None) - if desc and freq > min_entity_freq: - entity_list.append(entity) - description_list.append(desc) - frequency_list.append(freq) - filtered_title_to_id[title] = entity - return filtered_title_to_id, entity_list, description_list, frequency_list - - -def _add_aliases(kb, entity_list, title_to_id, max_entities_per_alias, min_occ, prior_prob_path): - wp_titles = title_to_id.keys() - - # adding aliases with prior probabilities - # we can read this file sequentially, it's sorted by alias, and then by count - logger.info("Adding WP aliases") - with prior_prob_path.open("r", encoding="utf8") as prior_file: - # skip header - prior_file.readline() - line = prior_file.readline() - previous_alias = None - total_count = 0 - counts = [] - entities = [] - while line: - splits = line.replace("\n", "").split(sep="|") - new_alias = splits[0] - count = int(splits[1]) - entity = splits[2] - - if new_alias != previous_alias and previous_alias: - # done reading the previous alias --> output - if len(entities) > 0: - selected_entities = [] - prior_probs = [] - for ent_count, ent_string in zip(counts, entities): - if ent_string in wp_titles: - wd_id = title_to_id[ent_string] - p_entity_givenalias = ent_count / total_count - selected_entities.append(wd_id) - prior_probs.append(p_entity_givenalias) - - if selected_entities: - try: - kb.add_alias( - alias=previous_alias, - entities=selected_entities, - probabilities=prior_probs, - ) - except ValueError as e: - logger.error(e) - total_count = 0 - counts = [] - entities = [] - - total_count += count - - if len(entities) < max_entities_per_alias and count >= min_occ: - counts.append(count) - entities.append(entity) - previous_alias = new_alias - - line = prior_file.readline() - - -def read_kb(nlp, kb_file): - kb = KnowledgeBase(vocab=nlp.vocab) - kb.load_bulk(kb_file) - return kb diff --git a/bin/wiki_entity_linking/train_descriptions.py b/bin/wiki_entity_linking/train_descriptions.py deleted file mode 100644 index b0cfbb4c6..000000000 --- a/bin/wiki_entity_linking/train_descriptions.py +++ /dev/null @@ -1,145 +0,0 @@ -from random import shuffle - -import logging -import numpy as np - -from thinc.api import Model, chain, CosineDistance, Linear - -from spacy.util import create_default_optimizer - -logger = logging.getLogger(__name__) - - -class EntityEncoder: - """ - Train the embeddings of entity descriptions to fit a fixed-size entity vector (e.g. 64D). - This entity vector will be stored in the KB, for further downstream use in the entity model. - """ - - DROP = 0 - BATCH_SIZE = 1000 - - # Set min. acceptable loss to avoid a 'mean of empty slice' warning by numpy - MIN_LOSS = 0.01 - - # Reasonable default to stop training when things are not improving - MAX_NO_IMPROVEMENT = 20 - - def __init__(self, nlp, input_dim, desc_width, epochs=5): - self.nlp = nlp - self.input_dim = input_dim - self.desc_width = desc_width - self.epochs = epochs - self.distance = CosineDistance(ignore_zeros=True, normalize=False) - - def apply_encoder(self, description_list): - if self.encoder is None: - raise ValueError("Can not apply encoder before training it") - - batch_size = 100000 - - start = 0 - stop = min(batch_size, len(description_list)) - encodings = [] - - while start < len(description_list): - docs = list(self.nlp.pipe(description_list[start:stop])) - doc_embeddings = [self._get_doc_embedding(doc) for doc in docs] - enc = self.encoder(np.asarray(doc_embeddings)) - encodings.extend(enc.tolist()) - - start = start + batch_size - stop = min(stop + batch_size, len(description_list)) - logger.info("Encoded: {} entities".format(stop)) - - return encodings - - def train(self, description_list, to_print=False): - processed, loss = self._train_model(description_list) - if to_print: - logger.info( - "Trained entity descriptions on {} ".format(processed) + - "(non-unique) descriptions across {} ".format(self.epochs) + - "epochs" - ) - logger.info("Final loss: {}".format(loss)) - - def _train_model(self, description_list): - best_loss = 1.0 - iter_since_best = 0 - self._build_network(self.input_dim, self.desc_width) - - processed = 0 - loss = 1 - # copy this list so that shuffling does not affect other functions - descriptions = description_list.copy() - to_continue = True - - for i in range(self.epochs): - shuffle(descriptions) - - batch_nr = 0 - start = 0 - stop = min(self.BATCH_SIZE, len(descriptions)) - - while to_continue and start < len(descriptions): - batch = [] - for descr in descriptions[start:stop]: - doc = self.nlp(descr) - doc_vector = self._get_doc_embedding(doc) - batch.append(doc_vector) - - loss = self._update(batch) - if batch_nr % 25 == 0: - logger.info("loss: {} ".format(loss)) - processed += len(batch) - - # in general, continue training if we haven't reached our ideal min yet - to_continue = loss > self.MIN_LOSS - - # store the best loss and track how long it's been - if loss < best_loss: - best_loss = loss - iter_since_best = 0 - else: - iter_since_best += 1 - - # stop learning if we haven't seen improvement since the last few iterations - if iter_since_best > self.MAX_NO_IMPROVEMENT: - to_continue = False - - batch_nr += 1 - start = start + self.BATCH_SIZE - stop = min(stop + self.BATCH_SIZE, len(descriptions)) - - return processed, loss - - @staticmethod - def _get_doc_embedding(doc): - indices = np.zeros((len(doc),), dtype="i") - for i, word in enumerate(doc): - if word.orth in doc.vocab.vectors.key2row: - indices[i] = doc.vocab.vectors.key2row[word.orth] - else: - indices[i] = 0 - word_vectors = doc.vocab.vectors.data[indices] - doc_vector = np.mean(word_vectors, axis=0) - return doc_vector - - def _build_network(self, orig_width, hidden_with): - with Model.define_operators({">>": chain}): - # very simple encoder-decoder model - self.encoder = Linear(hidden_with, orig_width) - # TODO: removed the zero_init here - is oK? - self.model = self.encoder >> Linear(orig_width, hidden_with) - self.sgd = create_default_optimizer() - - def _update(self, vectors): - truths = self.model.ops.asarray(vectors) - predictions, bp_model = self.model.begin_update( - truths, drop=self.DROP - ) - d_scores, loss = self.distance(predictions, truths) - bp_model(d_scores, sgd=self.sgd) - return loss / len(vectors) - diff --git a/bin/wiki_entity_linking/wiki_io.py b/bin/wiki_entity_linking/wiki_io.py deleted file mode 100644 index 43ae87f0f..000000000 --- a/bin/wiki_entity_linking/wiki_io.py +++ /dev/null @@ -1,127 +0,0 @@ -# coding: utf-8 -from __future__ import unicode_literals - -import sys -import csv - -# min() needed to prevent error on windows, cf https://stackoverflow.com/questions/52404416/ -csv.field_size_limit(min(sys.maxsize, 2147483646)) - -""" This class provides reading/writing methods for temp files """ - - -# Entity definition: WP title -> WD ID # -def write_title_to_id(entity_def_output, title_to_id): - with entity_def_output.open("w", encoding="utf8") as id_file: - id_file.write("WP_title" + "|" + "WD_id" + "\n") - for title, qid in title_to_id.items(): - id_file.write(title + "|" + str(qid) + "\n") - - -def read_title_to_id(entity_def_output): - title_to_id = dict() - with entity_def_output.open("r", encoding="utf8") as id_file: - csvreader = csv.reader(id_file, delimiter="|") - # skip header - next(csvreader) - for row in csvreader: - title_to_id[row[0]] = row[1] - return title_to_id - - -# Entity aliases from WD: WD ID -> WD alias # -def write_id_to_alias(entity_alias_path, id_to_alias): - with entity_alias_path.open("w", encoding="utf8") as alias_file: - alias_file.write("WD_id" + "|" + "alias" + "\n") - for qid, alias_list in id_to_alias.items(): - for alias in alias_list: - alias_file.write(str(qid) + "|" + alias + "\n") - - -def read_id_to_alias(entity_alias_path): - id_to_alias = dict() - with entity_alias_path.open("r", encoding="utf8") as alias_file: - csvreader = csv.reader(alias_file, delimiter="|") - # skip header - next(csvreader) - for row in csvreader: - qid = row[0] - alias = row[1] - alias_list = id_to_alias.get(qid, []) - alias_list.append(alias) - id_to_alias[qid] = alias_list - return id_to_alias - - -def read_alias_to_id_generator(entity_alias_path): - """ Read (aliases, qid) tuples """ - - with entity_alias_path.open("r", encoding="utf8") as alias_file: - csvreader = csv.reader(alias_file, delimiter="|") - # skip header - next(csvreader) - for row in csvreader: - qid = row[0] - alias = row[1] - yield alias, qid - - -# Entity descriptions from WD: WD ID -> WD alias # -def write_id_to_descr(entity_descr_output, id_to_descr): - with entity_descr_output.open("w", encoding="utf8") as descr_file: - descr_file.write("WD_id" + "|" + "description" + "\n") - for qid, descr in id_to_descr.items(): - descr_file.write(str(qid) + "|" + descr + "\n") - - -def read_id_to_descr(entity_desc_path): - id_to_desc = dict() - with entity_desc_path.open("r", encoding="utf8") as descr_file: - csvreader = csv.reader(descr_file, delimiter="|") - # skip header - next(csvreader) - for row in csvreader: - id_to_desc[row[0]] = row[1] - return id_to_desc - - -# Entity counts from WP: WP title -> count # -def write_entity_to_count(prior_prob_input, count_output): - # Write entity counts for quick access later - entity_to_count = dict() - total_count = 0 - - with prior_prob_input.open("r", encoding="utf8") as prior_file: - # skip header - prior_file.readline() - line = prior_file.readline() - - while line: - splits = line.replace("\n", "").split(sep="|") - # alias = splits[0] - count = int(splits[1]) - entity = splits[2] - - current_count = entity_to_count.get(entity, 0) - entity_to_count[entity] = current_count + count - - total_count += count - - line = prior_file.readline() - - with count_output.open("w", encoding="utf8") as entity_file: - entity_file.write("entity" + "|" + "count" + "\n") - for entity, count in entity_to_count.items(): - entity_file.write(entity + "|" + str(count) + "\n") - - -def read_entity_to_count(count_input): - entity_to_count = dict() - with count_input.open("r", encoding="utf8") as csvfile: - csvreader = csv.reader(csvfile, delimiter="|") - # skip header - next(csvreader) - for row in csvreader: - entity_to_count[row[0]] = int(row[1]) - - return entity_to_count diff --git a/bin/wiki_entity_linking/wiki_namespaces.py b/bin/wiki_entity_linking/wiki_namespaces.py deleted file mode 100644 index e8f099ccd..000000000 --- a/bin/wiki_entity_linking/wiki_namespaces.py +++ /dev/null @@ -1,128 +0,0 @@ -# coding: utf8 -from __future__ import unicode_literals - -# List of meta pages in Wikidata, should be kept out of the Knowledge base -WD_META_ITEMS = [ - "Q163875", - "Q191780", - "Q224414", - "Q4167836", - "Q4167410", - "Q4663903", - "Q11266439", - "Q13406463", - "Q15407973", - "Q18616576", - "Q19887878", - "Q22808320", - "Q23894233", - "Q33120876", - "Q42104522", - "Q47460393", - "Q64875536", - "Q66480449", -] - - -# TODO: add more cases from non-English WP's - -# List of prefixes that refer to Wikipedia "file" pages -WP_FILE_NAMESPACE = ["Bestand", "File"] - -# List of prefixes that refer to Wikipedia "category" pages -WP_CATEGORY_NAMESPACE = ["Kategori", "Category", "Categorie"] - -# List of prefixes that refer to Wikipedia "meta" pages -# these will/should be matched ignoring case -WP_META_NAMESPACE = ( - WP_FILE_NAMESPACE - + WP_CATEGORY_NAMESPACE - + [ - "b", - "betawikiversity", - "Book", - "c", - "Commons", - "d", - "dbdump", - "download", - "Draft", - "Education", - "Foundation", - "Gadget", - "Gadget definition", - "Gebruiker", - "gerrit", - "Help", - "Image", - "Incubator", - "m", - "mail", - "mailarchive", - "media", - "MediaWiki", - "MediaWiki talk", - "Mediawikiwiki", - "MediaZilla", - "Meta", - "Metawikipedia", - "Module", - "mw", - "n", - "nost", - "oldwikisource", - "otrs", - "OTRSwiki", - "Overleg gebruiker", - "outreach", - "outreachwiki", - "Portal", - "phab", - "Phabricator", - "Project", - "q", - "quality", - "rev", - "s", - "spcom", - "Special", - "species", - "Strategy", - "sulutil", - "svn", - "Talk", - "Template", - "Template talk", - "Testwiki", - "ticket", - "TimedText", - "Toollabs", - "tools", - "tswiki", - "User", - "User talk", - "v", - "voy", - "w", - "Wikibooks", - "Wikidata", - "wikiHow", - "Wikinvest", - "wikilivres", - "Wikimedia", - "Wikinews", - "Wikipedia", - "Wikipedia talk", - "Wikiquote", - "Wikisource", - "Wikispecies", - "Wikitech", - "Wikiversity", - "Wikivoyage", - "wikt", - "wiktionary", - "wmf", - "wmania", - "WP", - ] -) diff --git a/bin/wiki_entity_linking/wikidata_pretrain_kb.py b/bin/wiki_entity_linking/wikidata_pretrain_kb.py deleted file mode 100644 index 003074feb..000000000 --- a/bin/wiki_entity_linking/wikidata_pretrain_kb.py +++ /dev/null @@ -1,179 +0,0 @@ -# coding: utf-8 -"""Script to process Wikipedia and Wikidata dumps and create a knowledge base (KB) -with specific parameters. Intermediate files are written to disk. - -Running the full pipeline on a standard laptop, may take up to 13 hours of processing. -Use the -p, -d and -s options to speed up processing using the intermediate files -from a previous run. - -For the Wikidata dump: get the latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/ -For the Wikipedia dump: get enwiki-latest-pages-articles-multistream.xml.bz2 -from https://dumps.wikimedia.org/enwiki/latest/ - -""" -from __future__ import unicode_literals - -import logging -from pathlib import Path -import plac - -from bin.wiki_entity_linking import wikipedia_processor as wp, wikidata_processor as wd -from bin.wiki_entity_linking import wiki_io as io -from bin.wiki_entity_linking import kb_creator -from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_FILE, ENTITY_DESCR_PATH, KB_MODEL_DIR, LOG_FORMAT -from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH, ENTITY_ALIAS_PATH -import spacy -from bin.wiki_entity_linking.kb_creator import read_kb - -logger = logging.getLogger(__name__) - - -@plac.annotations( - wd_json=("Path to the downloaded WikiData JSON dump.", "positional", None, Path), - wp_xml=("Path to the downloaded Wikipedia XML dump.", "positional", None, Path), - output_dir=("Output directory", "positional", None, Path), - model=("Model name or path, should include pretrained vectors.", "positional", None, str), - max_per_alias=("Max. # entities per alias (default 10)", "option", "a", int), - min_freq=("Min. count of an entity in the corpus (default 20)", "option", "f", int), - min_pair=("Min. count of entity-alias pairs (default 5)", "option", "c", int), - entity_vector_length=("Length of entity vectors (default 64)", "option", "v", int), - loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path), - loc_entity_defs=("Location to file with entity definitions", "option", "d", Path), - loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path), - descr_from_wp=("Flag for using descriptions from WP instead of WD (default False)", "flag", "wp"), - limit_prior=("Threshold to limit lines read from WP for prior probabilities", "option", "lp", int), - limit_train=("Threshold to limit lines read from WP for training set", "option", "lt", int), - limit_wd=("Threshold to limit lines read from WD", "option", "lw", int), - lang=("Optional language for which to get Wikidata titles. Defaults to 'en'", "option", "la", str), -) -def main( - wd_json, - wp_xml, - output_dir, - model, - max_per_alias=10, - min_freq=20, - min_pair=5, - entity_vector_length=64, - loc_prior_prob=None, - loc_entity_defs=None, - loc_entity_alias=None, - loc_entity_desc=None, - descr_from_wp=False, - limit_prior=None, - limit_train=None, - limit_wd=None, - lang="en", -): - entity_defs_path = loc_entity_defs if loc_entity_defs else output_dir / ENTITY_DEFS_PATH - entity_alias_path = loc_entity_alias if loc_entity_alias else output_dir / ENTITY_ALIAS_PATH - entity_descr_path = loc_entity_desc if loc_entity_desc else output_dir / ENTITY_DESCR_PATH - entity_freq_path = output_dir / ENTITY_FREQ_PATH - prior_prob_path = loc_prior_prob if loc_prior_prob else output_dir / PRIOR_PROB_PATH - training_entities_path = output_dir / TRAINING_DATA_FILE - kb_path = output_dir / KB_FILE - - logger.info("Creating KB with Wikipedia and WikiData") - - # STEP 0: set up IO - if not output_dir.exists(): - output_dir.mkdir(parents=True) - - # STEP 1: Load the NLP object - logger.info("STEP 1: Loading NLP model {}".format(model)) - nlp = spacy.load(model) - - # check the length of the nlp vectors - if "vectors" not in nlp.meta or not nlp.vocab.vectors.size: - raise ValueError( - "The `nlp` object should have access to pretrained word vectors, " - " cf. https://spacy.io/usage/models#languages." - ) - - # STEP 2: create prior probabilities from WP - if not prior_prob_path.exists(): - # It takes about 2h to process 1000M lines of Wikipedia XML dump - logger.info("STEP 2: Writing prior probabilities to {}".format(prior_prob_path)) - if limit_prior is not None: - logger.warning("Warning: reading only {} lines of Wikipedia dump".format(limit_prior)) - wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit_prior) - else: - logger.info("STEP 2: Reading prior probabilities from {}".format(prior_prob_path)) - - # STEP 3: calculate entity frequencies - if not entity_freq_path.exists(): - logger.info("STEP 3: Calculating and writing entity frequencies to {}".format(entity_freq_path)) - io.write_entity_to_count(prior_prob_path, entity_freq_path) - else: - logger.info("STEP 3: Reading entity frequencies from {}".format(entity_freq_path)) - - # STEP 4: reading definitions and (possibly) descriptions from WikiData or from file - if (not entity_defs_path.exists()) or (not descr_from_wp and not entity_descr_path.exists()): - # It takes about 10h to process 55M lines of Wikidata JSON dump - logger.info("STEP 4: Parsing and writing Wikidata entity definitions to {}".format(entity_defs_path)) - if limit_wd is not None: - logger.warning("Warning: reading only {} lines of Wikidata dump".format(limit_wd)) - title_to_id, id_to_descr, id_to_alias = wd.read_wikidata_entities_json( - wd_json, - limit_wd, - to_print=False, - lang=lang, - parse_descr=(not descr_from_wp), - ) - io.write_title_to_id(entity_defs_path, title_to_id) - - logger.info("STEP 4b: Writing Wikidata entity aliases to {}".format(entity_alias_path)) - io.write_id_to_alias(entity_alias_path, id_to_alias) - - if not descr_from_wp: - logger.info("STEP 4c: Writing Wikidata entity descriptions to {}".format(entity_descr_path)) - io.write_id_to_descr(entity_descr_path, id_to_descr) - else: - logger.info("STEP 4: Reading entity definitions from {}".format(entity_defs_path)) - logger.info("STEP 4b: Reading entity aliases from {}".format(entity_alias_path)) - if not descr_from_wp: - logger.info("STEP 4c: Reading entity descriptions from {}".format(entity_descr_path)) - - # STEP 5: Getting gold entities from Wikipedia - if (not training_entities_path.exists()) or (descr_from_wp and not entity_descr_path.exists()): - logger.info("STEP 5: Parsing and writing Wikipedia gold entities to {}".format(training_entities_path)) - if limit_train is not None: - logger.warning("Warning: reading only {} lines of Wikipedia dump".format(limit_train)) - wp.create_training_and_desc(wp_xml, entity_defs_path, entity_descr_path, - training_entities_path, descr_from_wp, limit_train) - if descr_from_wp: - logger.info("STEP 5b: Parsing and writing Wikipedia descriptions to {}".format(entity_descr_path)) - else: - logger.info("STEP 5: Reading gold entities from {}".format(training_entities_path)) - if descr_from_wp: - logger.info("STEP 5b: Reading entity descriptions from {}".format(entity_descr_path)) - - # STEP 6: creating the actual KB - # It takes ca. 30 minutes to pretrain the entity embeddings - if not kb_path.exists(): - logger.info("STEP 6: Creating the KB at {}".format(kb_path)) - kb = kb_creator.create_kb( - nlp=nlp, - max_entities_per_alias=max_per_alias, - min_entity_freq=min_freq, - min_occ=min_pair, - entity_def_path=entity_defs_path, - entity_descr_path=entity_descr_path, - entity_alias_path=entity_alias_path, - entity_freq_path=entity_freq_path, - prior_prob_path=prior_prob_path, - entity_vector_length=entity_vector_length, - ) - kb.dump(kb_path) - logger.info("kb entities: {}".format(kb.get_size_entities())) - logger.info("kb aliases: {}".format(kb.get_size_aliases())) - nlp.to_disk(output_dir / KB_MODEL_DIR) - else: - logger.info("STEP 6: KB already exists at {}".format(kb_path)) - - logger.info("Done!") - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO, format=LOG_FORMAT) - plac.call(main) diff --git a/bin/wiki_entity_linking/wikidata_processor.py b/bin/wiki_entity_linking/wikidata_processor.py deleted file mode 100644 index 8a070f567..000000000 --- a/bin/wiki_entity_linking/wikidata_processor.py +++ /dev/null @@ -1,154 +0,0 @@ -# coding: utf-8 -from __future__ import unicode_literals - -import bz2 -import json -import logging - -from bin.wiki_entity_linking.wiki_namespaces import WD_META_ITEMS - -logger = logging.getLogger(__name__) - - -def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descr=True): - # Read the JSON wiki data and parse out the entities. Takes about 7-10h to parse 55M lines. - # get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/ - - site_filter = '{}wiki'.format(lang) - - # filter: currently defined as OR: one hit suffices to be removed from further processing - exclude_list = WD_META_ITEMS - - # punctuation - exclude_list.extend(["Q1383557", "Q10617810"]) - - # letters etc - exclude_list.extend(["Q188725", "Q19776628", "Q3841820", "Q17907810", "Q9788", "Q9398093"]) - - neg_prop_filter = { - 'P31': exclude_list, # instance of - 'P279': exclude_list # subclass - } - - title_to_id = dict() - id_to_descr = dict() - id_to_alias = dict() - - # parse appropriate fields - depending on what we need in the KB - parse_properties = False - parse_sitelinks = True - parse_labels = False - parse_aliases = True - parse_claims = True - - with bz2.open(wikidata_file, mode='rb') as file: - for cnt, line in enumerate(file): - if limit and cnt >= limit: - break - if cnt % 500000 == 0 and cnt > 0: - logger.info("processed {} lines of WikiData JSON dump".format(cnt)) - clean_line = line.strip() - if clean_line.endswith(b","): - clean_line = clean_line[:-1] - if len(clean_line) > 1: - obj = json.loads(clean_line) - entry_type = obj["type"] - - if entry_type == "item": - keep = True - - claims = obj["claims"] - if parse_claims: - for prop, value_set in neg_prop_filter.items(): - claim_property = claims.get(prop, None) - if claim_property: - for cp in claim_property: - cp_id = ( - cp["mainsnak"] - .get("datavalue", {}) - .get("value", {}) - .get("id") - ) - cp_rank = cp["rank"] - if cp_rank != "deprecated" and cp_id in value_set: - keep = False - - if keep: - unique_id = obj["id"] - - if to_print: - print("ID:", unique_id) - print("type:", entry_type) - - # parsing all properties that refer to other entities - if parse_properties: - for prop, claim_property in claims.items(): - cp_dicts = [ - cp["mainsnak"]["datavalue"].get("value") - for cp in claim_property - if cp["mainsnak"].get("datavalue") - ] - cp_values = [ - cp_dict.get("id") - for cp_dict in cp_dicts - if isinstance(cp_dict, dict) - if cp_dict.get("id") is not None - ] - if cp_values: - if to_print: - print("prop:", prop, cp_values) - - found_link = False - if parse_sitelinks: - site_value = obj["sitelinks"].get(site_filter, None) - if site_value: - site = site_value["title"] - if to_print: - print(site_filter, ":", site) - title_to_id[site] = unique_id - found_link = True - - if parse_labels: - labels = obj["labels"] - if labels: - lang_label = labels.get(lang, None) - if lang_label: - if to_print: - print( - "label (" + lang + "):", lang_label["value"] - ) - - if found_link and parse_descr: - descriptions = obj["descriptions"] - if descriptions: - lang_descr = descriptions.get(lang, None) - if lang_descr: - if to_print: - print( - "description (" + lang + "):", - lang_descr["value"], - ) - id_to_descr[unique_id] = lang_descr["value"] - - if parse_aliases: - aliases = obj["aliases"] - if aliases: - lang_aliases = aliases.get(lang, None) - if lang_aliases: - for item in lang_aliases: - if to_print: - print( - "alias (" + lang + "):", item["value"] - ) - alias_list = id_to_alias.get(unique_id, []) - alias_list.append(item["value"]) - id_to_alias[unique_id] = alias_list - - if to_print: - print() - - # log final number of lines processed - logger.info("Finished. Processed {} lines of WikiData JSON dump".format(cnt)) - return title_to_id, id_to_descr, id_to_alias - - diff --git a/bin/wiki_entity_linking/wikidata_train_entity_linker.py b/bin/wiki_entity_linking/wikidata_train_entity_linker.py deleted file mode 100644 index af0e68768..000000000 --- a/bin/wiki_entity_linking/wikidata_train_entity_linker.py +++ /dev/null @@ -1,230 +0,0 @@ -# coding: utf-8 -"""Script that takes a previously created Knowledge Base and trains an entity linking -pipeline. The provided KB directory should hold the kb, the original nlp object and -its vocab used to create the KB, and a few auxiliary files such as the entity definitions, -as created by the script `wikidata_create_kb`. - -For the Wikipedia dump: get enwiki-latest-pages-articles-multistream.xml.bz2 -from https://dumps.wikimedia.org/enwiki/latest/ -""" -from __future__ import unicode_literals - -import random -import logging -import spacy -from pathlib import Path -import plac -from tqdm import tqdm - -from bin.wiki_entity_linking import wikipedia_processor -from bin.wiki_entity_linking import ( - TRAINING_DATA_FILE, - KB_MODEL_DIR, - KB_FILE, - LOG_FORMAT, - OUTPUT_MODEL_DIR, -) -from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance -from bin.wiki_entity_linking.kb_creator import read_kb - -from spacy.util import minibatch, compounding - -logger = logging.getLogger(__name__) - - -@plac.annotations( - dir_kb=("Directory with KB, NLP and related files", "positional", None, Path), - output_dir=("Output directory", "option", "o", Path), - loc_training=("Location to training data", "option", "k", Path), - epochs=("Number of training iterations (default 10)", "option", "e", int), - dropout=("Dropout to prevent overfitting (default 0.5)", "option", "p", float), - lr=("Learning rate (default 0.005)", "option", "n", float), - l2=("L2 regularization", "option", "r", float), - train_articles=("# training articles (default 90% of all)", "option", "t", int), - dev_articles=("# dev test articles (default 10% of all)", "option", "d", int), - labels_discard=("NER labels to discard (default None)", "option", "l", str), -) -def main( - dir_kb, - output_dir=None, - loc_training=None, - epochs=10, - dropout=0.5, - lr=0.005, - l2=1e-6, - train_articles=None, - dev_articles=None, - labels_discard=None, -): - if not output_dir: - logger.warning( - "No output dir specified so no results will be written, are you sure about this ?" - ) - - logger.info("Creating Entity Linker with Wikipedia and WikiData") - - output_dir = Path(output_dir) if output_dir else dir_kb - training_path = loc_training if loc_training else dir_kb / TRAINING_DATA_FILE - nlp_dir = dir_kb / KB_MODEL_DIR - kb_path = dir_kb / KB_FILE - nlp_output_dir = output_dir / OUTPUT_MODEL_DIR - - # STEP 0: set up IO - if not output_dir.exists(): - output_dir.mkdir() - - # STEP 1 : load the NLP object - logger.info("STEP 1a: Loading model from {}".format(nlp_dir)) - nlp = spacy.load(nlp_dir) - logger.info( - "Original NLP pipeline has following pipeline components: {}".format( - nlp.pipe_names - ) - ) - - # check that there is a NER component in the pipeline - if "ner" not in nlp.pipe_names: - raise ValueError("The `nlp` object should have a pretrained `ner` component.") - - logger.info("STEP 1b: Loading KB from {}".format(kb_path)) - kb = read_kb(nlp, kb_path) - - # STEP 2: read the training dataset previously created from WP - logger.info("STEP 2: Reading training & dev dataset from {}".format(training_path)) - train_indices, dev_indices = wikipedia_processor.read_training_indices( - training_path - ) - logger.info( - "Training set has {} articles, limit set to roughly {} articles per epoch".format( - len(train_indices), train_articles if train_articles else "all" - ) - ) - logger.info( - "Dev set has {} articles, limit set to rougly {} articles for evaluation".format( - len(dev_indices), dev_articles if dev_articles else "all" - ) - ) - if dev_articles: - dev_indices = dev_indices[0:dev_articles] - - # STEP 3: create and train an entity linking pipe - logger.info( - "STEP 3: Creating and training an Entity Linking pipe for {} epochs".format( - epochs - ) - ) - if labels_discard: - labels_discard = [x.strip() for x in labels_discard.split(",")] - logger.info( - "Discarding {} NER types: {}".format(len(labels_discard), labels_discard) - ) - else: - labels_discard = [] - - el_pipe = nlp.create_pipe( - name="entity_linker", - config={ - "pretrained_vectors": nlp.vocab.vectors, - "labels_discard": labels_discard, - }, - ) - el_pipe.set_kb(kb) - nlp.add_pipe(el_pipe, last=True) - - other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "entity_linker"] - with nlp.disable_pipes(*other_pipes): # only train Entity Linking - optimizer = nlp.begin_training() - optimizer.learn_rate = lr - optimizer.L2 = l2 - - logger.info("Dev Baseline Accuracies:") - dev_data = wikipedia_processor.read_el_docs_golds( - nlp=nlp, - entity_file_path=training_path, - dev=True, - line_ids=dev_indices, - kb=kb, - labels_discard=labels_discard, - ) - - measure_performance( - dev_data, kb, el_pipe, baseline=True, context=False, dev_limit=len(dev_indices) - ) - - for itn in range(epochs): - random.shuffle(train_indices) - losses = {} - batches = minibatch(train_indices, size=compounding(8.0, 128.0, 1.001)) - batchnr = 0 - articles_processed = 0 - - # we either process the whole training file, or just a part each epoch - bar_total = len(train_indices) - if train_articles: - bar_total = train_articles - - with tqdm(total=bar_total, leave=False, desc=f"Epoch {itn}") as pbar: - for batch in batches: - if not train_articles or articles_processed < train_articles: - with nlp.disable_pipes("entity_linker"): - train_batch = wikipedia_processor.read_el_docs_golds( - nlp=nlp, - entity_file_path=training_path, - dev=False, - line_ids=batch, - kb=kb, - labels_discard=labels_discard, - ) - try: - with nlp.disable_pipes(*other_pipes): - nlp.update( - examples=train_batch, - sgd=optimizer, - drop=dropout, - losses=losses, - ) - batchnr += 1 - articles_processed += len(docs) - pbar.update(len(docs)) - except Exception as e: - logger.error("Error updating batch:" + str(e)) - if batchnr > 0: - logging.info( - "Epoch {} trained on {} articles, train loss {}".format( - itn, articles_processed, round(losses["entity_linker"] / batchnr, 2) - ) - ) - # re-read the dev_data (data is returned as a generator) - dev_data = wikipedia_processor.read_el_docs_golds( - nlp=nlp, - entity_file_path=training_path, - dev=True, - line_ids=dev_indices, - kb=kb, - labels_discard=labels_discard, - ) - measure_performance( - dev_data, - kb, - el_pipe, - baseline=False, - context=True, - dev_limit=len(dev_indices), - ) - - if output_dir: - # STEP 4: write the NLP pipeline (now including an EL model) to file - logger.info( - "Final NLP pipeline has following pipeline components: {}".format( - nlp.pipe_names - ) - ) - logger.info("STEP 4: Writing trained NLP to {}".format(nlp_output_dir)) - nlp.to_disk(nlp_output_dir) - - logger.info("Done!") - - -if __name__ == "__main__": - logging.basicConfig(level=logging.INFO, format=LOG_FORMAT) - plac.call(main) diff --git a/bin/wiki_entity_linking/wikipedia_processor.py b/bin/wiki_entity_linking/wikipedia_processor.py deleted file mode 100644 index 315b1e916..000000000 --- a/bin/wiki_entity_linking/wikipedia_processor.py +++ /dev/null @@ -1,565 +0,0 @@ -# coding: utf-8 -from __future__ import unicode_literals - -import re -import bz2 -import logging -import random -import json - -from spacy.gold import GoldParse -from bin.wiki_entity_linking import wiki_io as io -from bin.wiki_entity_linking.wiki_namespaces import ( - WP_META_NAMESPACE, - WP_FILE_NAMESPACE, - WP_CATEGORY_NAMESPACE, -) - -""" -Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions. -Write these results to file for downstream KB and training data generation. - -Process Wikipedia interlinks to generate a training dataset for the EL algorithm. -""" - -ENTITY_FILE = "gold_entities.csv" - -map_alias_to_link = dict() - -logger = logging.getLogger(__name__) - -title_regex = re.compile(r"(?<=).*(?=)") -id_regex = re.compile(r"(?<=)\d*(?=)") -text_regex = re.compile(r"(?<=).*(?= 0: - logger.info("processed {} lines of Wikipedia XML dump".format(cnt)) - clean_line = line.strip().decode("utf-8") - - # we attempt at reading the article's ID (but not the revision or contributor ID) - if "" in clean_line or "" in clean_line: - read_id = False - if "" in clean_line: - read_id = True - - if read_id: - ids = id_regex.search(clean_line) - if ids: - current_article_id = ids[0] - - # only processing prior probabilities from true training (non-dev) articles - if not is_dev(current_article_id): - aliases, entities, normalizations = get_wp_links(clean_line) - for alias, entity, norm in zip(aliases, entities, normalizations): - _store_alias( - alias, entity, normalize_alias=norm, normalize_entity=True - ) - - line = file.readline() - cnt += 1 - logger.info("processed {} lines of Wikipedia XML dump".format(cnt)) - logger.info("Finished. processed {} lines of Wikipedia XML dump".format(cnt)) - - # write all aliases and their entities and count occurrences to file - with prior_prob_output.open("w", encoding="utf8") as outputfile: - outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n") - for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]): - s_dict = sorted(alias_dict.items(), key=lambda x: x[1], reverse=True) - for entity, count in s_dict: - outputfile.write(alias + "|" + str(count) + "|" + entity + "\n") - - -def _store_alias(alias, entity, normalize_alias=False, normalize_entity=True): - alias = alias.strip() - entity = entity.strip() - - # remove everything after # as this is not part of the title but refers to a specific paragraph - if normalize_entity: - # wikipedia titles are always capitalized - entity = _capitalize_first(entity.split("#")[0]) - if normalize_alias: - alias = alias.split("#")[0] - - if alias and entity: - alias_dict = map_alias_to_link.get(alias, dict()) - entity_count = alias_dict.get(entity, 0) - alias_dict[entity] = entity_count + 1 - map_alias_to_link[alias] = alias_dict - - -def get_wp_links(text): - aliases = [] - entities = [] - normalizations = [] - - matches = link_regex.findall(text) - for match in matches: - match = match[2:][:-2].replace("_", " ").strip() - - if ns_regex.match(match): - pass # ignore the entity if it points to a "meta" page - - # this is a simple [[link]], with the alias the same as the mention - elif "|" not in match: - aliases.append(match) - entities.append(match) - normalizations.append(True) - - # in wiki format, the link is written as [[entity|alias]] - else: - splits = match.split("|") - entity = splits[0].strip() - alias = splits[1].strip() - # specific wiki format [[alias (specification)|]] - if len(alias) == 0 and "(" in entity: - alias = entity.split("(")[0] - aliases.append(alias) - entities.append(entity) - normalizations.append(False) - else: - aliases.append(alias) - entities.append(entity) - normalizations.append(False) - - return aliases, entities, normalizations - - -def _capitalize_first(text): - if not text: - return None - result = text[0].capitalize() - if len(result) > 0: - result += text[1:] - return result - - -def create_training_and_desc( - wp_input, def_input, desc_output, training_output, parse_desc, limit=None -): - wp_to_id = io.read_title_to_id(def_input) - _process_wikipedia_texts( - wp_input, wp_to_id, desc_output, training_output, parse_desc, limit - ) - - -def _process_wikipedia_texts( - wikipedia_input, wp_to_id, output, training_output, parse_descriptions, limit=None -): - """ - Read the XML wikipedia data to parse out training data: - raw text data + positive instances - """ - - read_ids = set() - - with output.open("a", encoding="utf8") as descr_file, training_output.open( - "w", encoding="utf8" - ) as entity_file: - if parse_descriptions: - _write_training_description(descr_file, "WD_id", "description") - with bz2.open(wikipedia_input, mode="rb") as file: - article_count = 0 - article_text = "" - article_title = None - article_id = None - reading_text = False - reading_revision = False - - for line in file: - clean_line = line.strip().decode("utf-8") - - if clean_line == "": - reading_revision = True - elif clean_line == "": - reading_revision = False - - # Start reading new page - if clean_line == "": - article_text = "" - article_title = None - article_id = None - # finished reading this page - elif clean_line == "": - if article_id: - clean_text, entities = _process_wp_text( - article_title, article_text, wp_to_id - ) - if clean_text is not None and entities is not None: - _write_training_entities( - entity_file, article_id, clean_text, entities - ) - - if article_title in wp_to_id and parse_descriptions: - description = " ".join( - clean_text[:1000].split(" ")[:-1] - ) - _write_training_description( - descr_file, wp_to_id[article_title], description - ) - article_count += 1 - if article_count % 10000 == 0 and article_count > 0: - logger.info( - "Processed {} articles".format(article_count) - ) - if limit and article_count >= limit: - break - article_text = "" - article_title = None - article_id = None - reading_text = False - reading_revision = False - - # start reading text within a page - if "") - clean_text = clean_text.replace(r""", '"') - clean_text = clean_text.replace(r"&nbsp;", " ") - clean_text = clean_text.replace(r"&", "&") - - # remove multiple spaces - while " " in clean_text: - clean_text = clean_text.replace(" ", " ") - - return clean_text.strip() - - -def _remove_links(clean_text, wp_to_id): - # read the text char by char to get the right offsets for the interwiki links - entities = [] - final_text = "" - open_read = 0 - reading_text = True - reading_entity = False - reading_mention = False - reading_special_case = False - entity_buffer = "" - mention_buffer = "" - for index, letter in enumerate(clean_text): - if letter == "[": - open_read += 1 - elif letter == "]": - open_read -= 1 - elif letter == "|": - if reading_text: - final_text += letter - # switch from reading entity to mention in the [[entity|mention]] pattern - elif reading_entity: - reading_text = False - reading_entity = False - reading_mention = True - else: - reading_special_case = True - else: - if reading_entity: - entity_buffer += letter - elif reading_mention: - mention_buffer += letter - elif reading_text: - final_text += letter - else: - raise ValueError("Not sure at point", clean_text[index - 2 : index + 2]) - - if open_read > 2: - reading_special_case = True - - if open_read == 2 and reading_text: - reading_text = False - reading_entity = True - reading_mention = False - - # we just finished reading an entity - if open_read == 0 and not reading_text: - if "#" in entity_buffer or entity_buffer.startswith(":"): - reading_special_case = True - # Ignore cases with nested structures like File: handles etc - if not reading_special_case: - if not mention_buffer: - mention_buffer = entity_buffer - start = len(final_text) - end = start + len(mention_buffer) - qid = wp_to_id.get(entity_buffer, None) - if qid: - entities.append((mention_buffer, qid, start, end)) - final_text += mention_buffer - - entity_buffer = "" - mention_buffer = "" - - reading_text = True - reading_entity = False - reading_mention = False - reading_special_case = False - return final_text, entities - - -def _write_training_description(outputfile, qid, description): - if description is not None: - line = str(qid) + "|" + description + "\n" - outputfile.write(line) - - -def _write_training_entities(outputfile, article_id, clean_text, entities): - entities_data = [ - {"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]} - for ent in entities - ] - line = ( - json.dumps( - { - "article_id": article_id, - "clean_text": clean_text, - "entities": entities_data, - }, - ensure_ascii=False, - ) - + "\n" - ) - outputfile.write(line) - - -def read_training_indices(entity_file_path): - """ This method creates two lists of indices into the training file: one with indices for the - training examples, and one for the dev examples.""" - train_indices = [] - dev_indices = [] - - with entity_file_path.open("r", encoding="utf8") as file: - for i, line in enumerate(file): - example = json.loads(line) - article_id = example["article_id"] - clean_text = example["clean_text"] - - if is_valid_article(clean_text): - if is_dev(article_id): - dev_indices.append(i) - else: - train_indices.append(i) - - return train_indices, dev_indices - - -def read_el_docs_golds(nlp, entity_file_path, dev, line_ids, kb, labels_discard=None): - """ This method provides training/dev examples that correspond to the entity annotations found by the nlp object. - For training, it will include both positive and negative examples by using the candidate generator from the kb. - For testing (kb=None), it will include all positive examples only.""" - if not labels_discard: - labels_discard = [] - - texts = [] - entities_list = [] - - with entity_file_path.open("r", encoding="utf8") as file: - for i, line in enumerate(file): - if i in line_ids: - example = json.loads(line) - article_id = example["article_id"] - clean_text = example["clean_text"] - entities = example["entities"] - - if dev != is_dev(article_id) or not is_valid_article(clean_text): - continue - - texts.append(clean_text) - entities_list.append(entities) - - docs = nlp.pipe(texts, batch_size=50) - - for doc, entities in zip(docs, entities_list): - gold = _get_gold_parse(doc, entities, dev=dev, kb=kb, labels_discard=labels_discard) - if gold and len(gold.links) > 0: - yield doc, gold - - -def _get_gold_parse(doc, entities, dev, kb, labels_discard): - gold_entities = {} - tagged_ent_positions = { - (ent.start_char, ent.end_char): ent - for ent in doc.ents - if ent.label_ not in labels_discard - } - - for entity in entities: - entity_id = entity["entity"] - alias = entity["alias"] - start = entity["start"] - end = entity["end"] - - candidate_ids = [] - if kb and not dev: - candidates = kb.get_candidates(alias) - candidate_ids = [cand.entity_ for cand in candidates] - - tagged_ent = tagged_ent_positions.get((start, end), None) - if tagged_ent: - # TODO: check that alias == doc.text[start:end] - should_add_ent = (dev or entity_id in candidate_ids) and is_valid_sentence( - tagged_ent.sent.text - ) - - if should_add_ent: - value_by_id = {entity_id: 1.0} - if not dev: - random.shuffle(candidate_ids) - value_by_id.update( - {kb_id: 0.0 for kb_id in candidate_ids if kb_id != entity_id} - ) - gold_entities[(start, end)] = value_by_id - - return GoldParse(doc, links=gold_entities) - - -def is_dev(article_id): - if not article_id: - return False - return article_id.endswith("3") - - -def is_valid_article(doc_text): - # custom length cut-off - return 10 < len(doc_text) < 30000 - - -def is_valid_sentence(sent_text): - if not 10 < len(sent_text) < 3000: - # custom length cut-off - return False - - if sent_text.strip().startswith("*") or sent_text.strip().startswith("#"): - # remove 'enumeration' sentences (occurs often on Wikipedia) - return False - - return True diff --git a/examples/training/pretrain_textcat.py b/examples/training/pretrain_textcat.py index 0aefec9ef..5c41c0e92 100644 --- a/examples/training/pretrain_textcat.py +++ b/examples/training/pretrain_textcat.py @@ -129,10 +129,7 @@ def train_textcat(nlp, n_texts, n_iter=10): ) train_data = list(zip(train_texts, [{"cats": cats} for cats in train_cats])) - # get names of other pipes to disable them during training - pipe_exceptions = ["textcat", "trf_wordpiecer", "trf_tok2vec"] - other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions] - with nlp.disable_pipes(*other_pipes): # only train textcat + with nlp.select_pipes(enable="textcat"): # only train textcat optimizer = nlp.begin_training() textcat.model.get_ref("tok2vec").from_bytes(tok2vec_weights) print("Training the model...") diff --git a/examples/training/rehearsal.py b/examples/training/rehearsal.py index a0455c0a9..24fc67ebb 100644 --- a/examples/training/rehearsal.py +++ b/examples/training/rehearsal.py @@ -62,11 +62,8 @@ def main(model_name, unlabelled_loc): optimizer.b1 = 0.0 optimizer.b2 = 0.0 - # get names of other pipes to disable them during training - pipe_exceptions = ["ner", "trf_wordpiecer", "trf_tok2vec"] - other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions] sizes = compounding(1.0, 4.0, 1.001) - with nlp.disable_pipes(*other_pipes): + with nlp.select_pipes(enable="ner"): for itn in range(n_iter): random.shuffle(TRAIN_DATA) random.shuffle(raw_docs) diff --git a/examples/training/textcat_example_data/textcatjsonl_to_trainjson.py b/examples/training/textcat_example_data/textcatjsonl_to_trainjson.py index 339ce39be..66d96ff68 100644 --- a/examples/training/textcat_example_data/textcatjsonl_to_trainjson.py +++ b/examples/training/textcat_example_data/textcatjsonl_to_trainjson.py @@ -5,16 +5,17 @@ from spacy.gold import docs_to_json import srsly import sys + @plac.annotations( model=("Model name. Defaults to 'en'.", "option", "m", str), input_file=("Input file (jsonl)", "positional", None, Path), output_dir=("Output directory", "positional", None, Path), n_texts=("Number of texts to convert", "option", "t", int), ) -def convert(model='en', input_file=None, output_dir=None, n_texts=0): +def convert(model="en", input_file=None, output_dir=None, n_texts=0): # Load model with tokenizer + sentencizer only nlp = spacy.load(model) - nlp.disable_pipes(*nlp.pipe_names) + nlp.select_pipes(disable=nlp.pipe_names) sentencizer = nlp.create_pipe("sentencizer") nlp.add_pipe(sentencizer, first=True) @@ -49,5 +50,6 @@ def convert(model='en', input_file=None, output_dir=None, n_texts=0): srsly.write_json(output_dir / input_file.with_suffix(".json"), [docs_to_json(docs)]) + if __name__ == "__main__": plac.call(convert) diff --git a/examples/training/train_entity_linker.py b/examples/training/train_entity_linker.py index 9776ad351..a22f255e7 100644 --- a/examples/training/train_entity_linker.py +++ b/examples/training/train_entity_linker.py @@ -97,7 +97,7 @@ def main(kb_path, vocab_path=None, output_dir=None, n_iter=50): kb_ids = nlp.get_pipe("entity_linker").kb.get_entity_strings() TRAIN_DOCS = [] for text, annotation in TRAIN_DATA: - with nlp.disable_pipes("entity_linker"): + with nlp.select_pipes(disable="entity_linker"): doc = nlp(text) annotation_clean = annotation for offset, kb_id_dict in annotation["links"].items(): @@ -112,10 +112,7 @@ def main(kb_path, vocab_path=None, output_dir=None, n_iter=50): annotation_clean["links"][offset] = new_dict TRAIN_DOCS.append((doc, annotation_clean)) - # get names of other pipes to disable them during training - pipe_exceptions = ["entity_linker", "trf_wordpiecer", "trf_tok2vec"] - other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions] - with nlp.disable_pipes(*other_pipes): # only train entity linker + with nlp.select_pipes(enable="entity_linker"): # only train entity linker # reset and initialize the weights randomly optimizer = nlp.begin_training() diff --git a/examples/training/train_intent_parser.py b/examples/training/train_intent_parser.py index bfec23d09..c3d5a279b 100644 --- a/examples/training/train_intent_parser.py +++ b/examples/training/train_intent_parser.py @@ -124,9 +124,7 @@ def main(model=None, output_dir=None, n_iter=15): for dep in annotations.get("deps", []): parser.add_label(dep) - pipe_exceptions = ["parser", "trf_wordpiecer", "trf_tok2vec"] - other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions] - with nlp.disable_pipes(*other_pipes): # only train parser + with nlp.select_pipes(enable="parser"): # only train parser optimizer = nlp.begin_training() for itn in range(n_iter): random.shuffle(TRAIN_DATA) diff --git a/examples/training/train_ner.py b/examples/training/train_ner.py index d4e0bf794..f0f3affe7 100644 --- a/examples/training/train_ner.py +++ b/examples/training/train_ner.py @@ -55,10 +55,7 @@ def main(model=None, output_dir=None, n_iter=100): print("Add label", ent[2]) ner.add_label(ent[2]) - # get names of other pipes to disable them during training - pipe_exceptions = ["simple_ner"] - other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions] - with nlp.disable_pipes(*other_pipes): # only train NER + with nlp.select_pipes(enable="ner"): # only train NER # reset and initialize the weights randomly – but only if we're # training a new model if model is None: diff --git a/examples/training/train_new_entity_type.py b/examples/training/train_new_entity_type.py index 47420e524..445c3fc27 100644 --- a/examples/training/train_new_entity_type.py +++ b/examples/training/train_new_entity_type.py @@ -94,10 +94,8 @@ def main(model=None, new_model_name="animal", output_dir=None, n_iter=30): else: optimizer = nlp.resume_training() move_names = list(ner.move_names) - # get names of other pipes to disable them during training - pipe_exceptions = ["ner", "trf_wordpiecer", "trf_tok2vec"] - other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions] - with nlp.disable_pipes(*other_pipes): # only train NER + + with nlp.select_pipes(enable="ner"): # only train NER sizes = compounding(1.0, 4.0, 1.001) # batch up the examples using spaCy's minibatch for itn in range(n_iter): diff --git a/examples/training/train_parser.py b/examples/training/train_parser.py index 7bb3e8586..4f4409e31 100644 --- a/examples/training/train_parser.py +++ b/examples/training/train_parser.py @@ -64,10 +64,7 @@ def main(model=None, output_dir=None, n_iter=15): for dep in annotations.get("deps", []): parser.add_label(dep) - # get names of other pipes to disable them during training - pipe_exceptions = ["parser", "trf_wordpiecer", "trf_tok2vec"] - other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions] - with nlp.disable_pipes(*other_pipes): # only train parser + with nlp.select_pipes(enable="parser"): # only train parser optimizer = nlp.begin_training() for itn in range(n_iter): random.shuffle(TRAIN_DATA) diff --git a/examples/training/train_textcat.py b/examples/training/train_textcat.py index dfb95b038..65acadb07 100644 --- a/examples/training/train_textcat.py +++ b/examples/training/train_textcat.py @@ -68,10 +68,7 @@ def main(config_path, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=Non ex = Example.from_gold(gold, doc=doc) train_examples.append(ex) - # get names of other pipes to disable them during training - pipe_exceptions = ["textcat", "trf_wordpiecer", "trf_tok2vec"] - other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions] - with nlp.disable_pipes(*other_pipes): # only train textcat + with nlp.select_pipes(enable="textcat"): # only train textcat optimizer = nlp.begin_training() if init_tok2vec is not None: with init_tok2vec.open("rb") as file_: diff --git a/spacy/cli/train.py b/spacy/cli/train.py index 5fa09da78..19e0a81e0 100644 --- a/spacy/cli/train.py +++ b/spacy/cli/train.py @@ -145,7 +145,7 @@ def train( msg.text(f"Loading vectors from model '{vectors}'") _load_vectors(nlp, vectors) - nlp.disable_pipes([p for p in nlp.pipe_names if p not in pipeline]) + nlp.select_pipes(disable=[p for p in nlp.pipe_names if p not in pipeline]) for pipe in pipeline: # first, create the model. # Bit of a hack after the refactor to get the vectors into a default config @@ -201,8 +201,8 @@ def train( exits=1, ) msg.text(f"Extending component from base model '{pipe}'") - disabled_pipes = nlp.disable_pipes( - [p for p in nlp.pipe_names if p not in pipeline] + disabled_pipes = nlp.select_pipes( + disable=[p for p in nlp.pipe_names if p not in pipeline] ) else: msg.text(f"Starting with blank model '{lang}'") diff --git a/spacy/errors.py b/spacy/errors.py index 99a0081c0..7a7b44731 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -104,6 +104,8 @@ class Warnings(object): "string \"Field1=Value1,Value2|Field2=Value3\".") # TODO: fix numbering after merging develop into master + W096 = ("The method 'disable_pipes' has become deprecated - use 'select_pipes' " + "instead.") W097 = ("No Model config was provided to create the '{name}' component, " "and no default configuration could be found either.") W098 = ("No Model config was provided to create the '{name}' component, " @@ -132,7 +134,7 @@ class Errors(object): E007 = ("'{name}' already exists in pipeline. Existing names: {opts}") E008 = ("Some current components would be lost when restoring previous " "pipeline state. If you added components after calling " - "`nlp.disable_pipes()`, you should remove them explicitly with " + "`nlp.select_pipes()`, you should remove them explicitly with " "`nlp.remove_pipe()` before the pipeline is restored. Names of " "the new components: {names}") E009 = ("The `update` method expects same number of docs and golds, but " @@ -546,6 +548,13 @@ class Errors(object): "token itself.") # TODO: fix numbering after merging develop into master + E991 = ("The function 'select_pipes' should be called with either a " + "'disable' argument to list the names of the pipe components " + "that should be disabled, or with an 'enable' argument that " + "specifies which pipes should not be disabled.") + E992 = ("The function `select_pipes` was called with `enable`={enable} " + "and `disable`={disable} but that information is conflicting " + "for the `nlp` pipeline with components {names}.") E993 = ("The config for 'nlp' should include either a key 'name' to " "refer to an existing model by name or path, or a key 'lang' " "to create a new blank model.") diff --git a/spacy/language.py b/spacy/language.py index a7db5ef20..5f617b1f6 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -511,11 +511,37 @@ class Language(object): of the block. Otherwise, a DisabledPipes object is returned, that has a `.restore()` method you can use to undo your changes. - DOCS: https://spacy.io/api/language#disable_pipes + This method has been deprecated since 3.0 """ + warnings.warn(Warnings.W096, DeprecationWarning) if len(names) == 1 and isinstance(names[0], (list, tuple)): names = names[0] # support list of names instead of spread - return DisabledPipes(self, *names) + return DisabledPipes(self, names) + + def select_pipes(self, disable=None, enable=None): + """Disable one or more pipeline components. If used as a context + manager, the pipeline will be restored to the initial state at the end + of the block. Otherwise, a DisabledPipes object is returned, that has + a `.restore()` method you can use to undo your changes. + + disable (str or iterable): The name(s) of the pipes to disable + enable (str or iterable): The name(s) of the pipes to enable - all others will be disabled + + DOCS: https://spacy.io/api/language#select_pipes + """ + if enable is None and disable is None: + raise ValueError(Errors.E991) + if disable is not None and isinstance(disable, str): + disable = [disable] + if enable is not None: + if isinstance(enable, str): + enable = [enable] + to_disable = [pipe for pipe in self.pipe_names if pipe not in enable] + # raise an error if the enable and disable keywords are not consistent + if disable is not None and disable != to_disable: + raise ValueError(Errors.E992.format(enable=enable, disable=disable, names=self.pipe_names)) + disable = to_disable + return DisabledPipes(self, disable) def make_doc(self, text): return self.tokenizer(text) @@ -1117,7 +1143,7 @@ def _fix_pretrained_vectors_name(nlp): class DisabledPipes(list): """Manager for temporary pipeline disabling.""" - def __init__(self, nlp, *names): + def __init__(self, nlp, names): self.nlp = nlp self.names = names # Important! Not deep copy -- we just want the container (but we also diff --git a/spacy/pipeline/entityruler.py b/spacy/pipeline/entityruler.py index 06c568ac9..58160c2e9 100644 --- a/spacy/pipeline/entityruler.py +++ b/spacy/pipeline/entityruler.py @@ -200,7 +200,7 @@ class EntityRuler(object): ] except ValueError: subsequent_pipes = [] - with self.nlp.disable_pipes(subsequent_pipes): + with self.nlp.select_pipes(disable=subsequent_pipes): token_patterns = [] phrase_pattern_labels = [] phrase_pattern_texts = [] diff --git a/spacy/tests/pipeline/test_pipe_methods.py b/spacy/tests/pipeline/test_pipe_methods.py index e2fb02a2a..d42216655 100644 --- a/spacy/tests/pipeline/test_pipe_methods.py +++ b/spacy/tests/pipeline/test_pipe_methods.py @@ -88,7 +88,16 @@ def test_remove_pipe(nlp, name): def test_disable_pipes_method(nlp, name): nlp.add_pipe(new_pipe, name=name) assert nlp.has_pipe(name) - disabled = nlp.disable_pipes(name) + disabled = nlp.select_pipes(disable=name) + assert not nlp.has_pipe(name) + disabled.restore() + + +@pytest.mark.parametrize("name", ["my_component"]) +def test_enable_pipes_method(nlp, name): + nlp.add_pipe(new_pipe, name=name) + assert nlp.has_pipe(name) + disabled = nlp.select_pipes(enable=[]) assert not nlp.has_pipe(name) disabled.restore() @@ -97,19 +106,57 @@ def test_disable_pipes_method(nlp, name): def test_disable_pipes_context(nlp, name): nlp.add_pipe(new_pipe, name=name) assert nlp.has_pipe(name) - with nlp.disable_pipes(name): + with nlp.select_pipes(disable=name): assert not nlp.has_pipe(name) assert nlp.has_pipe(name) -def test_disable_pipes_list_arg(nlp): +def test_select_pipes_list_arg(nlp): for name in ["c1", "c2", "c3"]: nlp.add_pipe(new_pipe, name=name) assert nlp.has_pipe(name) - with nlp.disable_pipes(["c1", "c2"]): + with nlp.select_pipes(disable=["c1", "c2"]): assert not nlp.has_pipe("c1") assert not nlp.has_pipe("c2") assert nlp.has_pipe("c3") + with nlp.select_pipes(enable="c3"): + assert not nlp.has_pipe("c1") + assert not nlp.has_pipe("c2") + assert nlp.has_pipe("c3") + with nlp.select_pipes(enable=["c1", "c2"], disable="c3"): + assert nlp.has_pipe("c1") + assert nlp.has_pipe("c2") + assert not nlp.has_pipe("c3") + with nlp.select_pipes(enable=[]): + assert not nlp.has_pipe("c1") + assert not nlp.has_pipe("c2") + assert not nlp.has_pipe("c3") + with nlp.select_pipes(enable=["c1", "c2", "c3"], disable=[]): + assert nlp.has_pipe("c1") + assert nlp.has_pipe("c2") + assert nlp.has_pipe("c3") + with nlp.select_pipes(disable=["c1", "c2", "c3"], enable=[]): + assert not nlp.has_pipe("c1") + assert not nlp.has_pipe("c2") + assert not nlp.has_pipe("c3") + + +def test_select_pipes_errors(nlp): + for name in ["c1", "c2", "c3"]: + nlp.add_pipe(new_pipe, name=name) + assert nlp.has_pipe(name) + + with pytest.raises(ValueError): + nlp.select_pipes() + + with pytest.raises(ValueError): + nlp.select_pipes(enable=["c1", "c2"], disable=["c1"]) + + with pytest.raises(ValueError): + nlp.select_pipes(enable=["c1", "c2"], disable=[]) + + with pytest.raises(ValueError): + nlp.select_pipes(enable=[], disable=["c3"]) @pytest.mark.parametrize("n_pipes", [100]) diff --git a/spacy/tests/regression/test_issue3611.py b/spacy/tests/regression/test_issue3611.py index 120cea1d2..cab68793c 100644 --- a/spacy/tests/regression/test_issue3611.py +++ b/spacy/tests/regression/test_issue3611.py @@ -31,7 +31,7 @@ def test_issue3611(): nlp.add_pipe(textcat, last=True) # training the network - with nlp.disable_pipes([p for p in nlp.pipe_names if p != "textcat"]): + with nlp.select_pipes(enable="textcat"): optimizer = nlp.begin_training(X=x_train, Y=y_train) for i in range(3): losses = {} diff --git a/spacy/tests/regression/test_issue4030.py b/spacy/tests/regression/test_issue4030.py index 7158d9b21..b641213ad 100644 --- a/spacy/tests/regression/test_issue4030.py +++ b/spacy/tests/regression/test_issue4030.py @@ -31,7 +31,7 @@ def test_issue4030(): nlp.add_pipe(textcat, last=True) # training the network - with nlp.disable_pipes([p for p in nlp.pipe_names if p != "textcat"]): + with nlp.select_pipes(enable="textcat"): optimizer = nlp.begin_training() for i in range(3): losses = {} diff --git a/website/docs/api/language.md b/website/docs/api/language.md index d548a1f64..703a0f678 100644 --- a/website/docs/api/language.md +++ b/website/docs/api/language.md @@ -314,45 +314,47 @@ component function. | `name` | unicode | Name of the component to remove. | | **RETURNS** | tuple | A `(name, component)` tuple of the removed component. | -## Language.disable_pipes {#disable_pipes tag="contextmanager, method" new="2"} +## Language.select_pipes {#select_pipes tag="contextmanager, method" new="3"} Disable one or more pipeline components. If used as a context manager, the pipeline will be restored to the initial state at the end of the block. Otherwise, a `DisabledPipes` object is returned, that has a `.restore()` method you can use to undo your changes. +You can specify either `disable` (as a list or string), or `enable`. In the +latter case, all components not in the `enable` list, will be disabled. + > #### Example > > ```python -> # New API as of v2.2.2 -> with nlp.disable_pipes(["tagger", "parser"]): +> # New API as of v3.0 +> with nlp.select_pipes(disable=["tagger", "parser"]): > nlp.begin_training() > -> with nlp.disable_pipes("tagger", "parser"): +> with nlp.select_pipes(enable="ner"): > nlp.begin_training() > -> disabled = nlp.disable_pipes("tagger", "parser") +> disabled = nlp.select_pipes(disable=["tagger", "parser"]) > nlp.begin_training() > disabled.restore() > ``` -| Name | Type | Description | -| ----------------------------------------- | --------------- | ------------------------------------------------------------------------------------ | -| `disabled` 2.2.2 | list | Names of pipeline components to disable. | -| `*disabled` | unicode | Names of pipeline components to disable. | -| **RETURNS** | `DisabledPipes` | The disabled pipes that can be restored by calling the object's `.restore()` method. | +| Name | Type | Description | +| ----------- | --------------- | ------------------------------------------------------------------------------------ | +| `disable` | list | Names of pipeline components to disable. | +| `disable` | unicode | Name of pipeline component to disable. | +| `enable` | list | Names of pipeline components that will not be disabled. | +| `enable` | unicode | Name of pipeline component that will not be disabled. | +| **RETURNS** | `DisabledPipes` | The disabled pipes that can be restored by calling the object's `.restore()` method. | - -As of spaCy v2.2.2, the `Language.disable_pipes` method can also take a list of -component names as its first argument (instead of a variable number of -arguments). This is especially useful if you're generating the component names -to disable programmatically. The new syntax will become the default in the -future. + + +As of spaCy v3.0, the `disable_pipes` method has been renamed to `select_pipes`: ```diff -- disabled = nlp.disable_pipes("tagger", "parser") -+ disabled = nlp.disable_pipes(["tagger", "parser"]) +- nlp.disable_pipes(["tagger", "parser"]) ++ nlp.select_pipes(disable=["tagger", "parser"]) ``` diff --git a/website/docs/usage/processing-pipelines.md b/website/docs/usage/processing-pipelines.md index 7382f2b8c..696e11106 100644 --- a/website/docs/usage/processing-pipelines.md +++ b/website/docs/usage/processing-pipelines.md @@ -252,9 +252,9 @@ for doc in nlp.pipe(texts, disable=["tagger", "parser"]): If you need to **execute more code** with components disabled – e.g. to reset the weights or update only some components during training – you can use the -[`nlp.disable_pipes`](/api/language#disable_pipes) contextmanager. At the end of +[`nlp.select_pipes`](/api/language#select_pipes) contextmanager. At the end of the `with` block, the disabled pipeline components will be restored -automatically. Alternatively, `disable_pipes` returns an object that lets you +automatically. Alternatively, `select_pipes` returns an object that lets you call its `restore()` method to restore the disabled components when needed. This can be useful if you want to prevent unnecessary code indentation of large blocks. @@ -262,16 +262,26 @@ blocks. ```python ### Disable for block # 1. Use as a contextmanager -with nlp.disable_pipes("tagger", "parser"): +with nlp.select_pipes(disable=["tagger", "parser"]): doc = nlp("I won't be tagged and parsed") doc = nlp("I will be tagged and parsed") # 2. Restore manually -disabled = nlp.disable_pipes("ner") +disabled = nlp.select_pipes(disable="ner") doc = nlp("I won't have named entities") disabled.restore() ``` +If you want to disable all pipes except for one or a few, you can use the `enable` +keyword. Just like the `disable` keyword, it takes a list of pipe names, or a string +defining just one pipe. +```python +# Enable only the parser +with nlp.select_pipes(enable="parser"): + doc = nlp("I will only be parsed") +``` + + Finally, you can also use the [`remove_pipe`](/api/language#remove_pipe) method to remove pipeline components from an existing pipeline, the [`rename_pipe`](/api/language#rename_pipe) method to rename them, or the diff --git a/website/docs/usage/rule-based-matching.md b/website/docs/usage/rule-based-matching.md index 1db2405d1..5f47bd2e3 100644 --- a/website/docs/usage/rule-based-matching.md +++ b/website/docs/usage/rule-based-matching.md @@ -906,7 +906,7 @@ pipeline component, **make sure that the pipeline component runs** when you create the pattern. For example, to match on `POS` or `LEMMA`, the pattern `Doc` objects need to have part-of-speech tags set by the `tagger`. You can either call the `nlp` object on your pattern texts instead of `nlp.make_doc`, or use -[`nlp.disable_pipes`](/api/language#disable_pipes) to disable components +[`nlp.select_pipes`](/api/language#select_pipes) to disable components selectively. @@ -1121,8 +1121,7 @@ while adding the phrase patterns. entityruler = EntityRuler(nlp) patterns = [{"label": "TEST", "pattern": str(i)} for i in range(100000)] -other_pipes = [p for p in nlp.pipe_names if p != "tagger"] -with nlp.disable_pipes(*other_pipes): +with nlp.select_pipes(enable="tagger"): entityruler.add_patterns(patterns) ``` diff --git a/website/docs/usage/spacy-101.md b/website/docs/usage/spacy-101.md index 479bdd264..39d732724 100644 --- a/website/docs/usage/spacy-101.md +++ b/website/docs/usage/spacy-101.md @@ -647,8 +647,7 @@ import random nlp = spacy.load("en_core_web_sm") train_data = [("Uber blew through $1 million", {"entities": [(0, 4, "ORG")]})] -other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "ner"] -with nlp.disable_pipes(*other_pipes): +with nlp.select_pipes(enable="ner"): optimizer = nlp.begin_training() for i in range(10): random.shuffle(train_data) diff --git a/website/docs/usage/training.md b/website/docs/usage/training.md index 479441edf..a10c60357 100644 --- a/website/docs/usage/training.md +++ b/website/docs/usage/training.md @@ -362,7 +362,7 @@ https://github.com/explosion/spaCy/tree/master/examples/training/train_ner.py you're using a blank model, don't forget to add the entity recognizer to the pipeline. If you're using an existing model, make sure to disable all other pipeline components during training using - [`nlp.disable_pipes`](/api/language#disable_pipes). This way, you'll only be + [`nlp.select_pipes`](/api/language#select_pipes). This way, you'll only be training the entity recognizer. 2. **Shuffle and loop over** the examples. For each example, **update the model** by calling [`nlp.update`](/api/language#update), which steps through @@ -403,7 +403,7 @@ referred to as the "catastrophic forgetting" problem. you're using a blank model, don't forget to add the entity recognizer to the pipeline. If you're using an existing model, make sure to disable all other pipeline components during training using - [`nlp.disable_pipes`](/api/language#disable_pipes). This way, you'll only be + [`nlp.select_pipes`](/api/language#select_pipes). This way, you'll only be training the entity recognizer. 2. **Add the new entity label** to the entity recognizer using the [`add_label`](/api/entityrecognizer#add_label) method. You can access the @@ -436,7 +436,7 @@ https://github.com/explosion/spaCy/tree/master/examples/training/train_parser.py you're using a blank model, don't forget to add the parser to the pipeline. If you're using an existing model, make sure to disable all other pipeline components during training using - [`nlp.disable_pipes`](/api/language#disable_pipes). This way, you'll only be + [`nlp.select_pipes`](/api/language#select_pipes). This way, you'll only be training the parser. 2. **Add the dependency labels** to the parser using the [`add_label`](/api/dependencyparser#add_label) method. If you're starting off @@ -470,7 +470,7 @@ https://github.com/explosion/spaCy/tree/master/examples/training/train_tagger.py you're using a blank model, don't forget to add the tagger to the pipeline. If you're using an existing model, make sure to disable all other pipeline components during training using - [`nlp.disable_pipes`](/api/language#disable_pipes). This way, you'll only be + [`nlp.select_pipes`](/api/language#select_pipes). This way, you'll only be training the tagger. 2. **Add the tag map** to the tagger using the [`add_label`](/api/tagger#add_label) method. The first argument is the new @@ -544,7 +544,7 @@ https://github.com/explosion/spaCy/tree/master/examples/training/train_intent_pa you're using a blank model, don't forget to add the custom parser to the pipeline. If you're using an existing model, make sure to **remove the old parser** from the pipeline, and disable all other pipeline components during - training using [`nlp.disable_pipes`](/api/language#disable_pipes). This way, + training using [`nlp.select_pipes`](/api/language#select_pipes). This way, you'll only be training the parser. 3. **Add the dependency labels** to the parser using the [`add_label`](/api/dependencyparser#add_label) method. @@ -576,7 +576,7 @@ https://github.com/explosion/spaCy/tree/master/examples/training/train_textcat.p [`spacy.blank`](/api/top-level#spacy.blank) with the ID of your language. If you're using an existing model, make sure to disable all other pipeline components during training using - [`nlp.disable_pipes`](/api/language#disable_pipes). This way, you'll only be + [`nlp.select_pipes`](/api/language#select_pipes). This way, you'll only be training the text classifier. 2. **Add the text classifier** to the pipeline, and add the labels you want to train – for example, `POSITIVE`. @@ -653,7 +653,7 @@ https://github.com/explosion/spaCy/tree/master/examples/training/train_entity_li pipeline including also a component for [named entity recognition](/usage/training#ner). If you're using a model with additional components, make sure to disable all other pipeline components - during training using [`nlp.disable_pipes`](/api/language#disable_pipes). + during training using [`nlp.select_pipes`](/api/language#select_pipes). This way, you'll only be training the entity linker. 2. **Shuffle and loop over** the examples. For each example, **update the model** by calling [`nlp.update`](/api/language#update), which steps through