Mirror of https://github.com/explosion/spaCy.git (synced 2025-11-04 01:48:04 +03:00)
	KB extensions and better parsing of WikiData (#4375)
* fix overflow error on Windows
* more documentation & logging fixes
* md fix
* 3 different limit parameters to play with execution time
* bug fixes for directory locations
* small fixes
* exclude dev test articles from the prior probability stats
* small fixes
* filtering Wikidata entities, removing numeric and meta items
* adding aliases from Wikidata to the KB as well
* fix adding WD aliases
* also adding new aliases to previously added entities
* fixing commas
* small doc fixes
* adding subclass-of filtering
* append-alias functionality in the KB
* prevent appending the same entity-alias pair
* fix for appending WD aliases
* remove date filter
* remove unnecessary import
* small corrections and reformatting
* remove WD aliases for now (too slow)
* removing numeric entities from training and evaluation
* small fixes
* shortcut during prediction if there is only one candidate
* add counts and F-score logging, remove FP NER from the evaluation
* fix entity_linker.predict to take docs instead of single sentences
* remove enumeration sentences from the WP dataset
* entity_linker.update to process a full doc instead of a single sentence
* spelling corrections and dump locations in the README
* NLP IO fix
* reading the KB is unnecessary at the end of the pipeline
* small logging fix
* remove empty files
This commit is contained in:
parent 428887b8f2
commit 2d249a9502
bin/wiki_entity_linking/README.md (new file, 34 lines added)
@@ -0,0 +1,34 @@
+## Entity Linking with Wikipedia and Wikidata
+
+### Step 1: Create a Knowledge Base (KB) and training data
+
+Run `wikipedia_pretrain_kb.py`
+* This takes as input the locations of a **Wikipedia and a Wikidata dump**, and produces a **KB directory** + **training file**
+  * WikiData: get `latest-all.json.bz2` from https://dumps.wikimedia.org/wikidatawiki/entities/
+  * Wikipedia: get `enwiki-latest-pages-articles-multistream.xml.bz2` from https://dumps.wikimedia.org/enwiki/latest/ (or for any other language)
+* You can set the filtering parameters for KB construction:
+  * `max_per_alias`: (max) number of candidate entities in the KB per alias/synonym
+  * `min_freq`: threshold of number of times an entity should occur in the corpus to be included in the KB
+  * `min_pair`: threshold of number of times an entity+alias combination should occur in the corpus to be included in the KB
+* Further parameters to set:
+  * `descriptions_from_wikipedia`: whether to parse descriptions from Wikipedia (`True`) or Wikidata (`False`)
+  * `entity_vector_length`: length of the pre-trained entity description vectors
+  * `lang`: language for which to fetch Wikidata information (as the dump contains all languages)
+
+Quick testing and rerunning:
+* When trying out the pipeline for a quick test, set `limit_prior`, `limit_train` and/or `limit_wd` to read only parts of the dumps instead of everything.
+* If you only want to (re)run certain parts of the pipeline, just remove the corresponding files and they will be recalculated or reparsed.
+
+
+### Step 2: Train an Entity Linking model
+
+Run `wikidata_train_entity_linker.py`
+* This takes the **KB directory** produced by Step 1, and trains an **Entity Linking model**
+* You can set the learning parameters for the EL training:
+  * `epochs`: number of training iterations
+  * `dropout`: dropout rate
+  * `lr`: learning rate
+  * `l2`: L2 regularization
+* Specify the number of training and dev testing entities with `train_inst` and `dev_inst` respectively
+* Further parameters to set:
+  * `labels_discard`: NER label types to discard during training
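For orientation, the KB directory produced by Step 1 can be loaded back and queried afterwards; the following is a minimal sketch (not part of this commit) that mirrors the `read_kb` helper further down in this diff. The directory layout (`nlp` model folder, `kb` file name) and the "Berlin" alias are placeholder assumptions.

```python
from pathlib import Path

import spacy
from spacy.kb import KnowledgeBase

# Placeholder paths: point these at the output directory of wikipedia_pretrain_kb.py.
output_dir = Path("wd_kb_output")
nlp = spacy.load(output_dir / "nlp")   # the vocab the KB was built with
kb = KnowledgeBase(vocab=nlp.vocab)
kb.load_bulk(output_dir / "kb")        # assumed file name of the serialized KB

# Each alias maps to at most `max_per_alias` candidate entities with prior probabilities.
for candidate in kb.get_candidates("Berlin"):
    print(candidate.entity_, candidate.prior_prob)
```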
@@ -6,6 +6,7 @@ OUTPUT_MODEL_DIR = "nlp"
 PRIOR_PROB_PATH = "prior_prob.csv"
 ENTITY_DEFS_PATH = "entity_defs.csv"
 ENTITY_FREQ_PATH = "entity_freq.csv"
+ENTITY_ALIAS_PATH = "entity_alias.csv"
 ENTITY_DESCR_PATH = "entity_descriptions.csv"
 
 LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
@@ -15,10 +15,11 @@ class Metrics(object):
         candidate_is_correct = true_entity == candidate
 
         # Assume that we have no labeled negatives in the data (i.e. cases where true_entity is "NIL")
-        # Therefore, if candidate_is_correct then we have a true positive and never a true negative
+        # Therefore, if candidate_is_correct then we have a true positive and never a true negative.
         self.true_pos += candidate_is_correct
         self.false_neg += not candidate_is_correct
-        if candidate not in {"", "NIL"}:
+        if candidate and candidate not in {"", "NIL"}:
+            # A wrong prediction (e.g. Q42 != Q3) counts both as a FP as well as a FN.
             self.false_pos += not candidate_is_correct
 
     def calculate_precision(self):
@@ -33,6 +34,14 @@ class Metrics(object):
         else:
             return self.true_pos / (self.true_pos + self.false_neg)
 
+    def calculate_fscore(self):
+        p = self.calculate_precision()
+        r = self.calculate_recall()
+        if p + r == 0:
+            return 0.0
+        else:
+            return 2 * p * r / (p + r)
+
 
 class EvaluationResults(object):
     def __init__(self):
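As a sanity check on the counting convention above (there are no labeled true negatives, and a wrong non-NIL prediction counts as both a false positive and a false negative), here is a small illustration with made-up counts:

```python
# Illustration only: 8 gold mentions, 6 linked correctly, 1 linked to the wrong entity
# (counts as FP and FN), 1 predicted as NIL (counts only as FN).
true_pos, false_pos, false_neg = 6, 1, 2

precision = true_pos / (true_pos + false_pos)            # 6 / 7 ~ 0.857
recall = true_pos / (true_pos + false_neg)               # 6 / 8 = 0.750
fscore = 2 * precision * recall / (precision + recall)   # ~ 0.800
print(round(precision, 3), round(recall, 3), round(fscore, 3))
```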
@@ -43,18 +52,20 @@ class EvaluationResults(object):
         self.metrics.update_results(true_entity, candidate)
         self.metrics_by_label[ent_label].update_results(true_entity, candidate)
 
-    def increment_false_negatives(self):
-        self.metrics.false_neg += 1
-
     def report_metrics(self, model_name):
         model_str = model_name.title()
         recall = self.metrics.calculate_recall()
         precision = self.metrics.calculate_precision()
-        return ("{}: ".format(model_str) +
-                "Recall = {} | ".format(round(recall, 3)) +
-                "Precision = {} | ".format(round(precision, 3)) +
-                "Precision by label = {}".format({k: v.calculate_precision()
-                                                  for k, v in self.metrics_by_label.items()}))
+        fscore = self.metrics.calculate_fscore()
+        return (
+            "{}: ".format(model_str)
+            + "F-score = {} | ".format(round(fscore, 3))
+            + "Recall = {} | ".format(round(recall, 3))
+            + "Precision = {} | ".format(round(precision, 3))
+            + "F-score by label = {}".format(
+                {k: v.calculate_fscore() for k, v in sorted(self.metrics_by_label.items())}
+            )
+        )
 
 
 class BaselineResults(object):
@@ -63,40 +74,51 @@ class BaselineResults(object):
         self.prior = EvaluationResults()
         self.oracle = EvaluationResults()
 
-    def report_accuracy(self, model):
+    def report_performance(self, model):
         results = getattr(self, model)
         return results.report_metrics(model)
 
-    def update_baselines(self, true_entity, ent_label, random_candidate, prior_candidate, oracle_candidate):
+    def update_baselines(
+        self,
+        true_entity,
+        ent_label,
+        random_candidate,
+        prior_candidate,
+        oracle_candidate,
+    ):
         self.oracle.update_metrics(ent_label, true_entity, oracle_candidate)
         self.prior.update_metrics(ent_label, true_entity, prior_candidate)
         self.random.update_metrics(ent_label, true_entity, random_candidate)
 
 
-def measure_performance(dev_data, kb, el_pipe):
-    baseline_accuracies = measure_baselines(
-        dev_data, kb
-    )
-
-    logger.info(baseline_accuracies.report_accuracy("random"))
-    logger.info(baseline_accuracies.report_accuracy("prior"))
-    logger.info(baseline_accuracies.report_accuracy("oracle"))
-
-    # using only context
-    el_pipe.cfg["incl_context"] = True
-    el_pipe.cfg["incl_prior"] = False
-    results = get_eval_results(dev_data, el_pipe)
-    logger.info(results.report_metrics("context only"))
-
-    # measuring combined accuracy (prior + context)
-    el_pipe.cfg["incl_context"] = True
-    el_pipe.cfg["incl_prior"] = True
-    results = get_eval_results(dev_data, el_pipe)
-    logger.info(results.report_metrics("context and prior"))
+def measure_performance(dev_data, kb, el_pipe, baseline=True, context=True):
+    if baseline:
+        baseline_accuracies, counts = measure_baselines(dev_data, kb)
+        logger.info("Counts: {}".format({k: v for k, v in sorted(counts.items())}))
+        logger.info(baseline_accuracies.report_performance("random"))
+        logger.info(baseline_accuracies.report_performance("prior"))
+        logger.info(baseline_accuracies.report_performance("oracle"))
+
+    if context:
+        # using only context
+        el_pipe.cfg["incl_context"] = True
+        el_pipe.cfg["incl_prior"] = False
+        results = get_eval_results(dev_data, el_pipe)
+        logger.info(results.report_metrics("context only"))
+
+        # measuring combined accuracy (prior + context)
+        el_pipe.cfg["incl_context"] = True
+        el_pipe.cfg["incl_prior"] = True
+        results = get_eval_results(dev_data, el_pipe)
+        logger.info(results.report_metrics("context and prior"))
 
 
 def get_eval_results(data, el_pipe=None):
-    # If the docs in the data require further processing with an entity linker, set el_pipe
+    """
+    Evaluate the ent.kb_id_ annotations against the gold standard.
+    Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
+    If the docs in the data require further processing with an entity linker, set el_pipe.
+    """
     from tqdm import tqdm
 
     docs = []
@@ -111,18 +133,15 @@ def get_eval_results(data, el_pipe=None):
 
     results = EvaluationResults()
     for doc, gold in zip(docs, golds):
-        tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
         try:
             correct_entries_per_article = dict()
             for entity, kb_dict in gold.links.items():
                 start, end = entity
-                # only evaluating on positive examples
                 for gold_kb, value in kb_dict.items():
                     if value:
+                        # only evaluating on positive examples
                         offset = _offset(start, end)
                         correct_entries_per_article[offset] = gold_kb
-                        if offset not in tagged_entries_per_article:
-                            results.increment_false_negatives()
 
             for ent in doc.ents:
                 ent_label = ent.label_
@@ -142,7 +161,11 @@ def get_eval_results(data, el_pipe=None):
 
 
 def measure_baselines(data, kb):
-    # Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
+    """
+    Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound.
+    Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL.
+    Also return a dictionary of counts by entity label.
+    """
     counts_d = dict()
 
     baseline_results = BaselineResults()
@@ -152,7 +175,6 @@ def measure_baselines(data, kb):
 
     for doc, gold in zip(docs, golds):
         correct_entries_per_article = dict()
-        tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
        for entity, kb_dict in gold.links.items():
             start, end = entity
             for gold_kb, value in kb_dict.items():
@@ -160,10 +182,6 @@ def measure_baselines(data, kb):
                 if value:
                     offset = _offset(start, end)
                     correct_entries_per_article[offset] = gold_kb
-                    if offset not in tagged_entries_per_article:
-                        baseline_results.random.increment_false_negatives()
-                        baseline_results.oracle.increment_false_negatives()
-                        baseline_results.prior.increment_false_negatives()
 
         for ent in doc.ents:
             ent_label = ent.label_
@@ -176,7 +194,7 @@ def measure_baselines(data, kb):
             if gold_entity is not None:
                 candidates = kb.get_candidates(ent.text)
                 oracle_candidate = ""
-                best_candidate = ""
+                prior_candidate = ""
                 random_candidate = ""
                 if candidates:
                     scores = []
@@ -187,13 +205,21 @@ def measure_baselines(data, kb):
                             oracle_candidate = c.entity_
 
                     best_index = scores.index(max(scores))
-                    best_candidate = candidates[best_index].entity_
+                    prior_candidate = candidates[best_index].entity_
                     random_candidate = random.choice(candidates).entity_
 
-                baseline_results.update_baselines(gold_entity, ent_label,
-                                                  random_candidate, best_candidate, oracle_candidate)
-
-    return baseline_results
+                current_count = counts_d.get(ent_label, 0)
+                counts_d[ent_label] = current_count+1
+
+                baseline_results.update_baselines(
+                    gold_entity,
+                    ent_label,
+                    random_candidate,
+                    prior_candidate,
+                    oracle_candidate,
+                )
+
+    return baseline_results, counts_d
 
 
 def _offset(start, end):
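To make the three baselines concrete, here is a small self-contained illustration (not from this commit) of how the random, prior and oracle choices differ for a single mention; the candidate type is a stand-in for spaCy's `Candidate`, and the QIDs and probabilities are invented:

```python
import random
from collections import namedtuple

# Stand-in for spaCy's Candidate: an entity ID plus its prior probability for this alias.
Candidate = namedtuple("Candidate", ["entity_", "prior_prob"])

candidates = [Candidate("Q64", 0.82), Candidate("Q1022", 0.15), Candidate("Q821244", 0.03)]
gold_entity = "Q1022"

random_candidate = random.choice(candidates).entity_                   # random baseline
prior_candidate = max(candidates, key=lambda c: c.prior_prob).entity_  # most frequent sense
oracle_candidate = next((c.entity_ for c in candidates if c.entity_ == gold_entity), "")  # upper bound

print(random_candidate, prior_candidate, oracle_candidate)
```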
@@ -1,17 +1,12 @@
 # coding: utf-8
 from __future__ import unicode_literals
 
-import csv
 import logging
-import spacy
-import sys
 
 from spacy.kb import KnowledgeBase
 
-from bin.wiki_entity_linking import wikipedia_processor as wp
 from bin.wiki_entity_linking.train_descriptions import EntityEncoder
-
-csv.field_size_limit(sys.maxsize)
+from bin.wiki_entity_linking import wiki_io as io
 
 
 logger = logging.getLogger(__name__)
@@ -22,18 +17,24 @@ def create_kb(
     max_entities_per_alias,
     min_entity_freq,
     min_occ,
-    entity_def_input,
+    entity_def_path,
     entity_descr_path,
-    count_input,
-    prior_prob_input,
+    entity_alias_path,
+    entity_freq_path,
+    prior_prob_path,
     entity_vector_length,
 ):
     # Create the knowledge base from Wikidata entries
     kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=entity_vector_length)
-
+    entity_list, filtered_title_to_id = _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_freq, entity_freq_path, entity_vector_length)
+    _define_aliases(kb, entity_alias_path, entity_list, filtered_title_to_id, max_entities_per_alias, min_occ, prior_prob_path)
+    return kb
+
+
+def _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_freq, entity_freq_path, entity_vector_length):
     # read the mappings from file
-    title_to_id = get_entity_to_id(entity_def_input)
-    id_to_descr = get_id_to_description(entity_descr_path)
+    title_to_id = io.read_title_to_id(entity_def_path)
+    id_to_descr = io.read_id_to_descr(entity_descr_path)
 
     # check the length of the nlp vectors
     if "vectors" in nlp.meta and nlp.vocab.vectors.size:
@@ -45,10 +46,8 @@ def create_kb(
             " cf. https://spacy.io/usage/models#languages."
         )
 
-    logger.info("Get entity frequencies")
-    entity_frequencies = wp.get_all_frequencies(count_input=count_input)
-
     logger.info("Filtering entities with fewer than {} mentions".format(min_entity_freq))
+    entity_frequencies = io.read_entity_to_count(entity_freq_path)
     # filter the entities for in the KB by frequency, because there's just too much data (8M entities) otherwise
     filtered_title_to_id, entity_list, description_list, frequency_list = get_filtered_entities(
         title_to_id,
@@ -56,36 +55,33 @@ def create_kb(
         entity_frequencies,
         min_entity_freq
     )
-    logger.info("Left with {} entities".format(len(description_list)))
+    logger.info("Kept {} entities from the set of {}".format(len(description_list), len(title_to_id.keys())))
 
-    logger.info("Train entity encoder")
+    logger.info("Training entity encoder")
     encoder = EntityEncoder(nlp, input_dim, entity_vector_length)
     encoder.train(description_list=description_list, to_print=True)
 
-    logger.info("Get entity embeddings:")
+    logger.info("Getting entity embeddings")
     embeddings = encoder.apply_encoder(description_list)
 
     logger.info("Adding {} entities".format(len(entity_list)))
     kb.set_entities(
         entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
     )
+    return entity_list, filtered_title_to_id
 
-    logger.info("Adding aliases")
+
+def _define_aliases(kb, entity_alias_path, entity_list, filtered_title_to_id, max_entities_per_alias, min_occ, prior_prob_path):
+    logger.info("Adding aliases from Wikipedia and Wikidata")
     _add_aliases(
         kb,
+        entity_list=entity_list,
         title_to_id=filtered_title_to_id,
         max_entities_per_alias=max_entities_per_alias,
         min_occ=min_occ,
-        prior_prob_input=prior_prob_input,
+        prior_prob_path=prior_prob_path,
     )
 
-    logger.info("KB size: {} entities, {} aliases".format(
-        kb.get_size_entities(),
-        kb.get_size_aliases()))
-
-    logger.info("Done with kb")
-    return kb
-
 
 def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies,
                           min_entity_freq: int = 10):
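For reference, a tiny self-contained sketch (illustration only, not from this commit) of the KB-building calls used above, with invented entities, frequencies and description vectors:

```python
import spacy
from spacy.kb import KnowledgeBase

nlp = spacy.blank("en")
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3)

# set_entities takes parallel lists of entity IDs, corpus frequencies and description vectors.
kb.set_entities(
    entity_list=["Q64", "Q1022"],
    freq_list=[340, 12],
    vector_list=[[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]],
)

# An alias maps to candidate entities with prior probabilities.
kb.add_alias(alias="Berlin", entities=["Q64", "Q1022"], probabilities=[0.9, 0.05])

print(kb.get_size_entities(), kb.get_size_aliases())
print([c.entity_ for c in kb.get_candidates("Berlin")])
```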
@@ -104,34 +100,13 @@ def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies,
     return filtered_title_to_id, entity_list, description_list, frequency_list
 
 
-def get_entity_to_id(entity_def_output):
-    entity_to_id = dict()
-    with entity_def_output.open("r", encoding="utf8") as csvfile:
-        csvreader = csv.reader(csvfile, delimiter="|")
-        # skip header
-        next(csvreader)
-        for row in csvreader:
-            entity_to_id[row[0]] = row[1]
-    return entity_to_id
-
-
-def get_id_to_description(entity_descr_path):
-    id_to_desc = dict()
-    with entity_descr_path.open("r", encoding="utf8") as csvfile:
-        csvreader = csv.reader(csvfile, delimiter="|")
-        # skip header
-        next(csvreader)
-        for row in csvreader:
-            id_to_desc[row[0]] = row[1]
-    return id_to_desc
-
-
-def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input):
+def _add_aliases(kb, entity_list, title_to_id, max_entities_per_alias, min_occ, prior_prob_path):
     wp_titles = title_to_id.keys()
 
     # adding aliases with prior probabilities
     # we can read this file sequentially, it's sorted by alias, and then by count
-    with prior_prob_input.open("r", encoding="utf8") as prior_file:
+    logger.info("Adding WP aliases")
+    with prior_prob_path.open("r", encoding="utf8") as prior_file:
         # skip header
         prior_file.readline()
         line = prior_file.readline()
@@ -180,10 +155,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
             line = prior_file.readline()
 
 
-def read_nlp_kb(model_dir, kb_file):
-    nlp = spacy.load(model_dir)
+def read_kb(nlp, kb_file):
     kb = KnowledgeBase(vocab=nlp.vocab)
     kb.load_bulk(kb_file)
-    logger.info("kb entities: {}".format(kb.get_size_entities()))
-    logger.info("kb aliases: {}".format(kb.get_size_aliases()))
-    return nlp, kb
+    return kb
@@ -53,7 +53,7 @@ class EntityEncoder:
 
             start = start + batch_size
             stop = min(stop + batch_size, len(description_list))
-            logger.info("encoded: {} entities".format(stop))
+            logger.info("Encoded: {} entities".format(stop))
 
         return encodings
 
@@ -62,7 +62,7 @@ class EntityEncoder:
         if to_print:
             logger.info(
                 "Trained entity descriptions on {} ".format(processed) +
-                "(non-unique) entities across {} ".format(self.epochs) +
+                "(non-unique) descriptions across {} ".format(self.epochs) +
                 "epochs"
             )
             logger.info("Final loss: {}".format(loss))
@@ -1,395 +0,0 @@
-# coding: utf-8
-from __future__ import unicode_literals
-
-import logging
-import random
-import re
-import bz2
-import json
-
-from functools import partial
-
-from spacy.gold import GoldParse
-from bin.wiki_entity_linking import kb_creator
-
-"""
-Process Wikipedia interlinks to generate a training dataset for the EL algorithm.
-Gold-standard entities are stored in one file in standoff format (by character offset).
-"""
-
-ENTITY_FILE = "gold_entities.csv"
-logger = logging.getLogger(__name__)
-
-
-def create_training_examples_and_descriptions(wikipedia_input,
-                                              entity_def_input,
-                                              description_output,
-                                              training_output,
-                                              parse_descriptions,
-                                              limit=None):
-    wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
-    _process_wikipedia_texts(wikipedia_input,
-                             wp_to_id,
-                             description_output,
-                             training_output,
-                             parse_descriptions,
-                             limit)
-
-
-def _process_wikipedia_texts(wikipedia_input,
-                             wp_to_id,
-                             output,
-                             training_output,
-                             parse_descriptions,
-                             limit=None):
-    """
-    Read the XML wikipedia data to parse out training data:
-    raw text data + positive instances
-    """
-    title_regex = re.compile(r"(?<=<title>).*(?=</title>)")
-    id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
-
-    read_ids = set()
-
-    with output.open("a", encoding="utf8") as descr_file, training_output.open("w", encoding="utf8") as entity_file:
-        if parse_descriptions:
-            _write_training_description(descr_file, "WD_id", "description")
-        with bz2.open(wikipedia_input, mode="rb") as file:
-            article_count = 0
-            article_text = ""
-            article_title = None
-            article_id = None
-            reading_text = False
-            reading_revision = False
-
-            logger.info("Processed {} articles".format(article_count))
-
-            for line in file:
-                clean_line = line.strip().decode("utf-8")
-
-                if clean_line == "<revision>":
-                    reading_revision = True
-                elif clean_line == "</revision>":
-                    reading_revision = False
-
-                # Start reading new page
-                if clean_line == "<page>":
-                    article_text = ""
-                    article_title = None
-                    article_id = None
-                # finished reading this page
-                elif clean_line == "</page>":
-                    if article_id:
-                        clean_text, entities = _process_wp_text(
-                            article_title,
-                            article_text,
-                            wp_to_id
-                        )
-                        if clean_text is not None and entities is not None:
-                            _write_training_entities(entity_file,
-                                                     article_id,
-                                                     clean_text,
-                                                     entities)
-
-                            if article_title in wp_to_id and parse_descriptions:
-                                description = " ".join(clean_text[:1000].split(" ")[:-1])
-                                _write_training_description(
-                                    descr_file,
-                                    wp_to_id[article_title],
-                                    description
-                                )
-                            article_count += 1
-                            if article_count % 10000 == 0:
-                                logger.info("Processed {} articles".format(article_count))
-                            if limit and article_count >= limit:
-                                break
-                    article_text = ""
-                    article_title = None
-                    article_id = None
-                    reading_text = False
-                    reading_revision = False
-
-                # start reading text within a page
-                if "<text" in clean_line:
-                    reading_text = True
-
-                if reading_text:
-                    article_text += " " + clean_line
-
-                # stop reading text within a page (we assume a new page doesn't start on the same line)
-                if "</text" in clean_line:
-                    reading_text = False
-
-                # read the ID of this article (outside the revision portion of the document)
-                if not reading_revision:
-                    ids = id_regex.search(clean_line)
-                    if ids:
-                        article_id = ids[0]
-                        if article_id in read_ids:
-                            logger.info(
-                                "Found duplicate article ID", article_id, clean_line
-                            )  # This should never happen ...
-                        read_ids.add(article_id)
-
-                # read the title of this article (outside the revision portion of the document)
-                if not reading_revision:
-                    titles = title_regex.search(clean_line)
-                    if titles:
-                        article_title = titles[0].strip()
-    logger.info("Finished. Processed {} articles".format(article_count))
-
-
-text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")
-info_regex = re.compile(r"{[^{]*?}")
-htlm_regex = re.compile(r"<!--[^-]*-->")
-category_regex = re.compile(r"\[\[Category:[^\[]*]]")
-file_regex = re.compile(r"\[\[File:[^[\]]+]]")
-ref_regex = re.compile(r"<ref.*?>")  # non-greedy
-ref_2_regex = re.compile(r"</ref.*?>")  # non-greedy
-
-
-def _process_wp_text(article_title, article_text, wp_to_id):
-    # ignore meta Wikipedia pages
-    if (
-        article_title.startswith("Wikipedia:") or
-        article_title.startswith("Kategori:")
-    ):
-        return None, None
-
-    # remove the text tags
-    text_search = text_regex.search(article_text)
-    if text_search is None:
-        return None, None
-    text = text_search.group(0)
-
-    # stop processing if this is a redirect page
-    if text.startswith("#REDIRECT"):
-        return None, None
-
-    # get the raw text without markup etc, keeping only interwiki links
-    clean_text, entities = _remove_links(_get_clean_wp_text(text), wp_to_id)
-    return clean_text, entities
-
-
-def _get_clean_wp_text(article_text):
-    clean_text = article_text.strip()
-
-    # remove bolding & italic markup
-    clean_text = clean_text.replace("'''", "")
-    clean_text = clean_text.replace("''", "")
-
-    # remove nested {{info}} statements by removing the inner/smallest ones first and iterating
-    try_again = True
-    previous_length = len(clean_text)
-    while try_again:
-        clean_text = info_regex.sub(
-            "", clean_text
-        )  # non-greedy match excluding a nested {
-        if len(clean_text) < previous_length:
-            try_again = True
-        else:
-            try_again = False
-        previous_length = len(clean_text)
-
-    # remove HTML comments
-    clean_text = htlm_regex.sub("", clean_text)
-
-    # remove Category and File statements
-    clean_text = category_regex.sub("", clean_text)
-    clean_text = file_regex.sub("", clean_text)
-
-    # remove multiple =
-    while "==" in clean_text:
-        clean_text = clean_text.replace("==", "=")
-
-    clean_text = clean_text.replace(". =", ".")
-    clean_text = clean_text.replace(" = ", ". ")
-    clean_text = clean_text.replace("= ", ".")
-    clean_text = clean_text.replace(" =", "")
-
-    # remove refs (non-greedy match)
-    clean_text = ref_regex.sub("", clean_text)
-    clean_text = ref_2_regex.sub("", clean_text)
-
-    # remove additional wikiformatting
-    clean_text = re.sub(r"<blockquote>", "", clean_text)
-    clean_text = re.sub(r"</blockquote>", "", clean_text)
-
-    # change special characters back to normal ones
-    clean_text = clean_text.replace(r"&lt;", "<")
-    clean_text = clean_text.replace(r"&gt;", ">")
-    clean_text = clean_text.replace(r"&quot;", '"')
-    clean_text = clean_text.replace(r"&amp;nbsp;", " ")
-    clean_text = clean_text.replace(r"&amp;", "&")
-
-    # remove multiple spaces
-    while "  " in clean_text:
-        clean_text = clean_text.replace("  ", " ")
-
-    return clean_text.strip()
-
-
-def _remove_links(clean_text, wp_to_id):
-    # read the text char by char to get the right offsets for the interwiki links
-    entities = []
-    final_text = ""
-    open_read = 0
-    reading_text = True
-    reading_entity = False
-    reading_mention = False
-    reading_special_case = False
-    entity_buffer = ""
-    mention_buffer = ""
-    for index, letter in enumerate(clean_text):
-        if letter == "[":
-            open_read += 1
-        elif letter == "]":
-            open_read -= 1
-        elif letter == "|":
-            if reading_text:
-                final_text += letter
-            # switch from reading entity to mention in the [[entity|mention]] pattern
-            elif reading_entity:
-                reading_text = False
-                reading_entity = False
-                reading_mention = True
-            else:
-                reading_special_case = True
-        else:
-            if reading_entity:
-                entity_buffer += letter
-            elif reading_mention:
-                mention_buffer += letter
-            elif reading_text:
-                final_text += letter
-            else:
-                raise ValueError("Not sure at point", clean_text[index - 2: index + 2])
-
-        if open_read > 2:
-            reading_special_case = True
-
-        if open_read == 2 and reading_text:
-            reading_text = False
-            reading_entity = True
-            reading_mention = False
-
-        # we just finished reading an entity
-        if open_read == 0 and not reading_text:
-            if "#" in entity_buffer or entity_buffer.startswith(":"):
-                reading_special_case = True
-            # Ignore cases with nested structures like File: handles etc
-            if not reading_special_case:
-                if not mention_buffer:
-                    mention_buffer = entity_buffer
-                start = len(final_text)
-                end = start + len(mention_buffer)
-                qid = wp_to_id.get(entity_buffer, None)
-                if qid:
-                    entities.append((mention_buffer, qid, start, end))
-                final_text += mention_buffer
-
-            entity_buffer = ""
-            mention_buffer = ""
-
-            reading_text = True
-            reading_entity = False
-            reading_mention = False
-            reading_special_case = False
-    return final_text, entities
-
-
-def _write_training_description(outputfile, qid, description):
-    if description is not None:
-        line = str(qid) + "|" + description + "\n"
-        outputfile.write(line)
-
-
-def _write_training_entities(outputfile, article_id, clean_text, entities):
-    entities_data = [{"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]} for ent in entities]
-    line = json.dumps(
-        {
-            "article_id": article_id,
-            "clean_text": clean_text,
-            "entities": entities_data
-        },
-        ensure_ascii=False) + "\n"
-    outputfile.write(line)
-
-
def read_training(nlp, entity_file_path, dev, limit, kb):
 | 
					 | 
				
			||||||
    """ This method provides training examples that correspond to the entity annotations found by the nlp object.
     For training, it will include negative training examples by using the candidate generator,
     and it will only keep positive training examples that can be found by using the candidate generator.
     For testing, it will include all positive examples only."""

    from tqdm import tqdm
    data = []
    num_entities = 0
    get_gold_parse = partial(_get_gold_parse, dev=dev, kb=kb)

    logger.info("Reading {} data with limit {}".format('dev' if dev else 'train', limit))
    with entity_file_path.open("r", encoding="utf8") as file:
        with tqdm(total=limit, leave=False) as pbar:
            for i, line in enumerate(file):
                example = json.loads(line)
                article_id = example["article_id"]
                clean_text = example["clean_text"]
                entities = example["entities"]

                if dev != is_dev(article_id) or len(clean_text) >= 30000:
                    continue

                doc = nlp(clean_text)
                gold = get_gold_parse(doc, entities)
                if gold and len(gold.links) > 0:
                    data.append((doc, gold))
                    num_entities += len(gold.links)
                    pbar.update(len(gold.links))
                if limit and num_entities >= limit:
                    break
    logger.info("Read {} entities in {} articles".format(num_entities, len(data)))
    return data


def _get_gold_parse(doc, entities, dev, kb):
    gold_entities = {}
    tagged_ent_positions = set(
        [(ent.start_char, ent.end_char) for ent in doc.ents]
    )

    for entity in entities:
        entity_id = entity["entity"]
        alias = entity["alias"]
        start = entity["start"]
        end = entity["end"]

        candidates = kb.get_candidates(alias)
        candidate_ids = [
            c.entity_ for c in candidates
        ]

        should_add_ent = (
            dev or
            (
                (start, end) in tagged_ent_positions and
                entity_id in candidate_ids and
                len(candidates) > 1
            )
        )

        if should_add_ent:
            value_by_id = {entity_id: 1.0}
            if not dev:
                random.shuffle(candidate_ids)
                value_by_id.update({
                    kb_id: 0.0
                    for kb_id in candidate_ids
                    if kb_id != entity_id
                })
            gold_entities[(start, end)] = value_by_id

    return GoldParse(doc, links=gold_entities)


def is_dev(article_id):
    return article_id.endswith("3")
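The read_training helper above only needs a loaded pipeline, the JSONL file written by _write_training_entities, and a knowledge base. A minimal usage sketch under assumptions not taken from this diff: a spaCy 2.x install, the "en_core_web_lg" model, and hypothetical output paths.

    from pathlib import Path

    import spacy
    from spacy.kb import KnowledgeBase

    nlp = spacy.load("en_core_web_lg")  # assumption: any model with NER and pretrained vectors
    kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=64)
    kb.load_bulk("output/kb")  # assumption: hypothetical location of a previously dumped KB

    # training data: positives plus candidate-generator negatives, as described above
    train_data = read_training(
        nlp=nlp,
        entity_file_path=Path("output/gold_entities.jsonl"),  # hypothetical file name
        dev=False,
        limit=10000,
        kb=kb,
    )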
127 bin/wiki_entity_linking/wiki_io.py Normal file
@@ -0,0 +1,127 @@
# coding: utf-8
from __future__ import unicode_literals

import sys
import csv

# min() needed to prevent error on windows, cf https://stackoverflow.com/questions/52404416/
csv.field_size_limit(min(sys.maxsize, 2147483646))

""" This module provides reading/writing methods for temp files """


# Entity definition: WP title -> WD ID #
def write_title_to_id(entity_def_output, title_to_id):
    with entity_def_output.open("w", encoding="utf8") as id_file:
        id_file.write("WP_title" + "|" + "WD_id" + "\n")
        for title, qid in title_to_id.items():
            id_file.write(title + "|" + str(qid) + "\n")


def read_title_to_id(entity_def_output):
    title_to_id = dict()
    with entity_def_output.open("r", encoding="utf8") as id_file:
        csvreader = csv.reader(id_file, delimiter="|")
        # skip header
        next(csvreader)
        for row in csvreader:
            title_to_id[row[0]] = row[1]
    return title_to_id


# Entity aliases from WD: WD ID -> WD alias #
def write_id_to_alias(entity_alias_path, id_to_alias):
    with entity_alias_path.open("w", encoding="utf8") as alias_file:
        alias_file.write("WD_id" + "|" + "alias" + "\n")
        for qid, alias_list in id_to_alias.items():
            for alias in alias_list:
                alias_file.write(str(qid) + "|" + alias + "\n")


def read_id_to_alias(entity_alias_path):
    id_to_alias = dict()
    with entity_alias_path.open("r", encoding="utf8") as alias_file:
        csvreader = csv.reader(alias_file, delimiter="|")
        # skip header
        next(csvreader)
        for row in csvreader:
            qid = row[0]
            alias = row[1]
            alias_list = id_to_alias.get(qid, [])
            alias_list.append(alias)
            id_to_alias[qid] = alias_list
    return id_to_alias


def read_alias_to_id_generator(entity_alias_path):
    """ Read (alias, qid) tuples """

    with entity_alias_path.open("r", encoding="utf8") as alias_file:
        csvreader = csv.reader(alias_file, delimiter="|")
        # skip header
        next(csvreader)
        for row in csvreader:
            qid = row[0]
            alias = row[1]
            yield alias, qid


# Entity descriptions from WD: WD ID -> WD description #
def write_id_to_descr(entity_descr_output, id_to_descr):
    with entity_descr_output.open("w", encoding="utf8") as descr_file:
        descr_file.write("WD_id" + "|" + "description" + "\n")
        for qid, descr in id_to_descr.items():
            descr_file.write(str(qid) + "|" + descr + "\n")


def read_id_to_descr(entity_desc_path):
    id_to_desc = dict()
    with entity_desc_path.open("r", encoding="utf8") as descr_file:
        csvreader = csv.reader(descr_file, delimiter="|")
        # skip header
        next(csvreader)
        for row in csvreader:
            id_to_desc[row[0]] = row[1]
    return id_to_desc


# Entity counts from WP: WP title -> count #
def write_entity_to_count(prior_prob_input, count_output):
    # Write entity counts for quick access later
    entity_to_count = dict()
    total_count = 0

    with prior_prob_input.open("r", encoding="utf8") as prior_file:
        # skip header
        prior_file.readline()
        line = prior_file.readline()

        while line:
            splits = line.replace("\n", "").split(sep="|")
            # alias = splits[0]
            count = int(splits[1])
            entity = splits[2]

            current_count = entity_to_count.get(entity, 0)
            entity_to_count[entity] = current_count + count

            total_count += count

            line = prior_file.readline()

    with count_output.open("w", encoding="utf8") as entity_file:
        entity_file.write("entity" + "|" + "count" + "\n")
        for entity, count in entity_to_count.items():
            entity_file.write(entity + "|" + str(count) + "\n")


def read_entity_to_count(count_input):
    entity_to_count = dict()
    with count_input.open("r", encoding="utf8") as csvfile:
        csvreader = csv.reader(csvfile, delimiter="|")
        # skip header
        next(csvreader)
        for row in csvreader:
            entity_to_count[row[0]] = int(row[1])

    return entity_to_count
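Because every helper in this module writes plain pipe-delimited text with a single header row, writes and reads are symmetric. A small round-trip sketch (the path below is hypothetical and not part of the diff); note that a value containing a literal "|" would break this simple format.

    from pathlib import Path

    descr_path = Path("output/entity_descriptions.csv")  # hypothetical location
    write_id_to_descr(descr_path, {"Q42": "English writer and humorist"})
    assert read_id_to_descr(descr_path)["Q42"] == "English writer and humorist"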
128 bin/wiki_entity_linking/wiki_namespaces.py Normal file
@@ -0,0 +1,128 @@
# coding: utf8
from __future__ import unicode_literals

# List of meta pages in Wikidata, should be kept out of the Knowledge base
WD_META_ITEMS = [
    "Q163875",
    "Q191780",
    "Q224414",
    "Q4167836",
    "Q4167410",
    "Q4663903",
    "Q11266439",
    "Q13406463",
    "Q15407973",
    "Q18616576",
    "Q19887878",
    "Q22808320",
    "Q23894233",
    "Q33120876",
    "Q42104522",
    "Q47460393",
    "Q64875536",
    "Q66480449",
]


# TODO: add more cases from non-English WP's

# List of prefixes that refer to Wikipedia "file" pages
WP_FILE_NAMESPACE = ["Bestand", "File"]

# List of prefixes that refer to Wikipedia "category" pages
WP_CATEGORY_NAMESPACE = ["Kategori", "Category", "Categorie"]

# List of prefixes that refer to Wikipedia "meta" pages
# these will/should be matched ignoring case
WP_META_NAMESPACE = (
    WP_FILE_NAMESPACE
    + WP_CATEGORY_NAMESPACE
    + [
        "b",
        "betawikiversity",
        "Book",
        "c",
        "Commons",
        "d",
        "dbdump",
        "download",
        "Draft",
        "Education",
        "Foundation",
        "Gadget",
        "Gadget definition",
        "Gebruiker",
        "gerrit",
        "Help",
        "Image",
        "Incubator",
        "m",
        "mail",
        "mailarchive",
        "media",
        "MediaWiki",
        "MediaWiki talk",
        "Mediawikiwiki",
        "MediaZilla",
        "Meta",
        "Metawikipedia",
        "Module",
        "mw",
        "n",
        "nost",
        "oldwikisource",
        "otrs",
        "OTRSwiki",
        "Overleg gebruiker",
        "outreach",
        "outreachwiki",
        "Portal",
        "phab",
        "Phabricator",
        "Project",
        "q",
        "quality",
        "rev",
        "s",
        "spcom",
        "Special",
        "species",
        "Strategy",
        "sulutil",
        "svn",
        "Talk",
        "Template",
        "Template talk",
        "Testwiki",
        "ticket",
        "TimedText",
        "Toollabs",
        "tools",
        "tswiki",
        "User",
        "User talk",
        "v",
        "voy",
        "w",
        "Wikibooks",
        "Wikidata",
        "wikiHow",
        "Wikinvest",
        "wikilivres",
        "Wikimedia",
        "Wikinews",
        "Wikipedia",
        "Wikipedia talk",
        "Wikiquote",
        "Wikisource",
        "Wikispecies",
        "Wikitech",
        "Wikiversity",
        "Wikivoyage",
        "wikt",
        "wiktionary",
        "wmf",
        "wmania",
        "WP",
    ]
)
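One way these prefix lists can be used, in line with the "matched ignoring case" note above, is to skip any link target that starts with a meta namespace followed by a colon. This is an illustration only, not the repository's exact filtering code.

    import re

    # build one case-insensitive pattern out of all meta prefixes
    ns_regex = re.compile(
        r"^(" + "|".join(re.escape(ns) for ns in WP_META_NAMESPACE) + r"):",
        re.IGNORECASE,
    )

    def is_meta_page(wp_title):
        # "File:Example.jpg" or "Category:Physics" -> True, "Douglas Adams" -> False
        return bool(ns_regex.match(wp_title))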
@@ -18,11 +18,12 @@ from pathlib import Path
 import plac
 
 from bin.wiki_entity_linking import wikipedia_processor as wp, wikidata_processor as wd
+from bin.wiki_entity_linking import wiki_io as io
 from bin.wiki_entity_linking import kb_creator
-from bin.wiki_entity_linking import training_set_creator
 from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_FILE, ENTITY_DESCR_PATH, KB_MODEL_DIR, LOG_FORMAT
-from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH
+from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH, ENTITY_ALIAS_PATH
 import spacy
+from bin.wiki_entity_linking.kb_creator import read_kb
 
 logger = logging.getLogger(__name__)
 
@@ -39,9 +40,11 @@ logger = logging.getLogger(__name__)
     loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path),
     loc_entity_defs=("Location to file with entity definitions", "option", "d", Path),
     loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path),
-    descriptions_from_wikipedia=("Flag for using wp descriptions not wd", "flag", "wp"),
-    limit=("Optional threshold to limit lines read from dumps", "option", "l", int),
-    lang=("Optional language for which to get wikidata titles. Defaults to 'en'", "option", "la", str),
+    descr_from_wp=("Flag for using wp descriptions not wd", "flag", "wp"),
+    limit_prior=("Threshold to limit lines read from WP for prior probabilities", "option", "lp", int),
+    limit_train=("Threshold to limit lines read from WP for training set", "option", "lt", int),
+    limit_wd=("Threshold to limit lines read from WD", "option", "lw", int),
+    lang=("Optional language for which to get Wikidata titles. Defaults to 'en'", "option", "la", str),
 )
 def main(
     wd_json,
@@ -54,13 +57,16 @@ def main(
     entity_vector_length=64,
     loc_prior_prob=None,
     loc_entity_defs=None,
+    loc_entity_alias=None,
     loc_entity_desc=None,
-    descriptions_from_wikipedia=False,
-    limit=None,
+    descr_from_wp=False,
+    limit_prior=None,
+    limit_train=None,
+    limit_wd=None,
     lang="en",
 ):
-
     entity_defs_path = loc_entity_defs if loc_entity_defs else output_dir / ENTITY_DEFS_PATH
+    entity_alias_path = loc_entity_alias if loc_entity_alias else output_dir / ENTITY_ALIAS_PATH
     entity_descr_path = loc_entity_desc if loc_entity_desc else output_dir / ENTITY_DESCR_PATH
     entity_freq_path = output_dir / ENTITY_FREQ_PATH
     prior_prob_path = loc_prior_prob if loc_prior_prob else output_dir / PRIOR_PROB_PATH
@@ -69,15 +75,12 @@ def main(
 
     logger.info("Creating KB with Wikipedia and WikiData")
 
-    if limit is not None:
-        logger.warning("Warning: reading only {} lines of Wikipedia/Wikidata dumps.".format(limit))
-
     # STEP 0: set up IO
     if not output_dir.exists():
         output_dir.mkdir(parents=True)
 
-    # STEP 1: create the NLP object
-    logger.info("STEP 1: Loading model {}".format(model))
+    # STEP 1: Load the NLP object
+    logger.info("STEP 1: Loading NLP model {}".format(model))
     nlp = spacy.load(model)
 
     # check the length of the nlp vectors
@@ -90,62 +93,83 @@ def main(
     # STEP 2: create prior probabilities from WP
     if not prior_prob_path.exists():
         # It takes about 2h to process 1000M lines of Wikipedia XML dump
-        logger.info("STEP 2: writing prior probabilities to {}".format(prior_prob_path))
-        wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit)
-    logger.info("STEP 2: reading prior probabilities from {}".format(prior_prob_path))
+        logger.info("STEP 2: Writing prior probabilities to {}".format(prior_prob_path))
+        if limit_prior is not None:
+            logger.warning("Warning: reading only {} lines of Wikipedia dump".format(limit_prior))
+        wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit_prior)
+    else:
+        logger.info("STEP 2: Reading prior probabilities from {}".format(prior_prob_path))
 
-    # STEP 3: deduce entity frequencies from WP (takes only a few minutes)
-    logger.info("STEP 3: calculating entity frequencies")
-    wp.write_entity_counts(prior_prob_path, entity_freq_path, to_print=False)
+    # STEP 3: calculate entity frequencies
+    if not entity_freq_path.exists():
+        logger.info("STEP 3: Calculating and writing entity frequencies to {}".format(entity_freq_path))
+        io.write_entity_to_count(prior_prob_path, entity_freq_path)
+    else:
+        logger.info("STEP 3: Reading entity frequencies from {}".format(entity_freq_path))
 
     # STEP 4: reading definitions and (possibly) descriptions from WikiData or from file
-    message = " and descriptions" if not descriptions_from_wikipedia else ""
-    if (not entity_defs_path.exists()) or (not descriptions_from_wikipedia and not entity_descr_path.exists()):
+    if (not entity_defs_path.exists()) or (not descr_from_wp and not entity_descr_path.exists()):
         # It takes about 10h to process 55M lines of Wikidata JSON dump
-        logger.info("STEP 4: parsing wikidata for entity definitions" + message)
-        title_to_id, id_to_descr = wd.read_wikidata_entities_json(
+        logger.info("STEP 4: Parsing and writing Wikidata entity definitions to {}".format(entity_defs_path))
+        if limit_wd is not None:
+            logger.warning("Warning: reading only {} lines of Wikidata dump".format(limit_wd))
+        title_to_id, id_to_descr, id_to_alias = wd.read_wikidata_entities_json(
            wd_json,
-            limit,
+            limit_wd,
            to_print=False,
            lang=lang,
-            parse_descriptions=(not descriptions_from_wikipedia),
+            parse_descr=(not descr_from_wp),
        )
-        wd.write_entity_files(entity_defs_path, title_to_id)
-        if not descriptions_from_wikipedia:
-            wd.write_entity_description_files(entity_descr_path, id_to_descr)
-    logger.info("STEP 4: read entity definitions" + message)
+        io.write_title_to_id(entity_defs_path, title_to_id)
 
-    # STEP 5: Getting gold entities from wikipedia
-    message = " and descriptions" if descriptions_from_wikipedia else ""
-    if (not training_entities_path.exists()) or (descriptions_from_wikipedia and not entity_descr_path.exists()):
-        logger.info("STEP 5: parsing wikipedia for gold entities" + message)
-        training_set_creator.create_training_examples_and_descriptions(
-            wp_xml,
-            entity_defs_path,
-            entity_descr_path,
-            training_entities_path,
-            parse_descriptions=descriptions_from_wikipedia,
-            limit=limit,
-        )
-    logger.info("STEP 5: read gold entities" + message)
+        logger.info("STEP 4b: Writing Wikidata entity aliases to {}".format(entity_alias_path))
+        io.write_id_to_alias(entity_alias_path, id_to_alias)
+
+        if not descr_from_wp:
+            logger.info("STEP 4c: Writing Wikidata entity descriptions to {}".format(entity_descr_path))
+            io.write_id_to_descr(entity_descr_path, id_to_descr)
+    else:
+        logger.info("STEP 4: Reading entity definitions from {}".format(entity_defs_path))
+        logger.info("STEP 4b: Reading entity aliases from {}".format(entity_alias_path))
+        if not descr_from_wp:
+            logger.info("STEP 4c: Reading entity descriptions from {}".format(entity_descr_path))
+
+    # STEP 5: Getting gold entities from Wikipedia
+    if (not training_entities_path.exists()) or (descr_from_wp and not entity_descr_path.exists()):
+        logger.info("STEP 5: Parsing and writing Wikipedia gold entities to {}".format(training_entities_path))
+        if limit_train is not None:
+            logger.warning("Warning: reading only {} lines of Wikipedia dump".format(limit_train))
+        wp.create_training_and_desc(wp_xml, entity_defs_path, entity_descr_path,
+                                    training_entities_path, descr_from_wp, limit_train)
+        if descr_from_wp:
+            logger.info("STEP 5b: Parsing and writing Wikipedia descriptions to {}".format(entity_descr_path))
+    else:
+        logger.info("STEP 5: Reading gold entities from {}".format(training_entities_path))
+        if descr_from_wp:
+            logger.info("STEP 5b: Reading entity descriptions from {}".format(entity_descr_path))
 
     # STEP 6: creating the actual KB
     # It takes ca. 30 minutes to pretrain the entity embeddings
-    logger.info("STEP 6: creating the KB at {}".format(kb_path))
-    kb = kb_creator.create_kb(
-        nlp=nlp,
-        max_entities_per_alias=max_per_alias,
-        min_entity_freq=min_freq,
-        min_occ=min_pair,
-        entity_def_input=entity_defs_path,
-        entity_descr_path=entity_descr_path,
-        count_input=entity_freq_path,
-        prior_prob_input=prior_prob_path,
-        entity_vector_length=entity_vector_length,
-    )
-
-    kb.dump(kb_path)
-    nlp.to_disk(output_dir / KB_MODEL_DIR)
+    if not kb_path.exists():
+        logger.info("STEP 6: Creating the KB at {}".format(kb_path))
+        kb = kb_creator.create_kb(
+            nlp=nlp,
+            max_entities_per_alias=max_per_alias,
+            min_entity_freq=min_freq,
+            min_occ=min_pair,
+            entity_def_path=entity_defs_path,
+            entity_descr_path=entity_descr_path,
+            entity_alias_path=entity_alias_path,
+            entity_freq_path=entity_freq_path,
+            prior_prob_path=prior_prob_path,
+            entity_vector_length=entity_vector_length,
+        )
+        kb.dump(kb_path)
+        logger.info("kb entities: {}".format(kb.get_size_entities()))
+        logger.info("kb aliases: {}".format(kb.get_size_aliases()))
+        nlp.to_disk(output_dir / KB_MODEL_DIR)
+    else:
+        logger.info("STEP 6: KB already exists at {}".format(kb_path))
 
     logger.info("Done!")
 
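Each step in the rewritten pipeline above follows the same pattern: the expensive parse runs only when its output file is missing, otherwise the previously written file is reused. A generic sketch of that caching idea (the names here are placeholders, not the script's API):

    import logging

    logger = logging.getLogger(__name__)

    def cached_step(path, compute_and_write, read):
        # run the expensive computation only if its output is not on disk yet
        if not path.exists():
            logger.info("Writing {}".format(path))
            compute_and_write(path)
        else:
            logger.info("Reading from {}".format(path))
        return read(path)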
@@ -1,40 +1,52 @@
 # coding: utf-8
 from __future__ import unicode_literals
 
-import gzip
+import bz2
 import json
 import logging
-import datetime
+
+from bin.wiki_entity_linking.wiki_namespaces import WD_META_ITEMS
 
 logger = logging.getLogger(__name__)
 
 
-def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descriptions=True):
-    # Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines.
+def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descr=True):
+    # Read the JSON wiki data and parse out the entities. Takes about 7-10h to parse 55M lines.
     # get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
 
     site_filter = '{}wiki'.format(lang)
 
-    # properties filter (currently disabled to get ALL data)
-    prop_filter = dict()
-    # prop_filter = {'P31': {'Q5', 'Q15632617'}}     # currently defined as OR: one property suffices to be selected
+    # filter: currently defined as OR: one hit suffices to be removed from further processing
+    exclude_list = WD_META_ITEMS
+
+    # punctuation
+    exclude_list.extend(["Q1383557", "Q10617810"])
+
+    # letters etc
+    exclude_list.extend(["Q188725", "Q19776628", "Q3841820", "Q17907810", "Q9788", "Q9398093"])
+
+    neg_prop_filter = {
+        'P31': exclude_list,    # instance of
+        'P279': exclude_list    # subclass
+    }
 
     title_to_id = dict()
     id_to_descr = dict()
+    id_to_alias = dict()
 
     # parse appropriate fields - depending on what we need in the KB
     parse_properties = False
     parse_sitelinks = True
     parse_labels = False
-    parse_aliases = False
-    parse_claims = False
+    parse_aliases = True
+    parse_claims = True
 
-    with gzip.open(wikidata_file, mode='rb') as file:
+    with bz2.open(wikidata_file, mode='rb') as file:
         for cnt, line in enumerate(file):
             if limit and cnt >= limit:
                 break
-            if cnt % 500000 == 0:
-                logger.info("processed {} lines of WikiData dump".format(cnt))
+            if cnt % 500000 == 0 and cnt > 0:
+                logger.info("processed {} lines of WikiData JSON dump".format(cnt))
             clean_line = line.strip()
             if clean_line.endswith(b","):
                 clean_line = clean_line[:-1]
@@ -43,13 +55,11 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
                 entry_type = obj["type"]
 
                 if entry_type == "item":
-                    # filtering records on their properties (currently disabled to get ALL data)
-                    # keep = False
                     keep = True
 
                     claims = obj["claims"]
                     if parse_claims:
-                        for prop, value_set in prop_filter.items():
+                        for prop, value_set in neg_prop_filter.items():
                             claim_property = claims.get(prop, None)
                             if claim_property:
                                 for cp in claim_property:
@@ -61,7 +71,7 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
                                     )
                                     cp_rank = cp["rank"]
                                     if cp_rank != "deprecated" and cp_id in value_set:
-                                        keep = True
+                                        keep = False
 
                     if keep:
                         unique_id = obj["id"]
@@ -108,7 +118,7 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
                                             "label (" + lang + "):", lang_label["value"]
                                         )
 
-                        if found_link and parse_descriptions:
+                        if found_link and parse_descr:
                             descriptions = obj["descriptions"]
                             if descriptions:
                                 lang_descr = descriptions.get(lang, None)
@@ -130,22 +140,15 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang=
                                             print(
                                                 "alias (" + lang + "):", item["value"]
                                             )
+                                        alias_list = id_to_alias.get(unique_id, [])
+                                        alias_list.append(item["value"])
+                                        id_to_alias[unique_id] = alias_list
 
                         if to_print:
                             print()
 
-    return title_to_id, id_to_descr
+    # log final number of lines processed
+    logger.info("Finished. Processed {} lines of WikiData JSON dump".format(cnt))
+    return title_to_id, id_to_descr, id_to_alias
 
 
-def write_entity_files(entity_def_output, title_to_id):
-    with entity_def_output.open("w", encoding="utf8") as id_file:
-        id_file.write("WP_title" + "|" + "WD_id" + "\n")
-        for title, qid in title_to_id.items():
-            id_file.write(title + "|" + str(qid) + "\n")
-
-
-def write_entity_description_files(entity_descr_output, id_to_descr):
-    with entity_descr_output.open("w", encoding="utf8") as descr_file:
-        descr_file.write("WD_id" + "|" + "description" + "\n")
-        for qid, descr in id_to_descr.items():
-            descr_file.write(str(qid) + "|" + descr + "\n")
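The negative property filter above works on the raw Wikidata claim structure: an item is dropped as soon as one non-deprecated P31 ("instance of") or P279 ("subclass of") claim points at an excluded meta item. A hedged sketch of that check, written against the standard Wikidata JSON layout rather than copied from the module:

    def keep_item(obj, neg_prop_filter):
        # obj is one parsed line of the Wikidata JSON dump
        claims = obj.get("claims", {})
        for prop, excluded_ids in neg_prop_filter.items():
            for cp in claims.get(prop, []):
                cp_id = (
                    cp.get("mainsnak", {})
                    .get("datavalue", {})
                    .get("value", {})
                    .get("id")
                )
                if cp.get("rank") != "deprecated" and cp_id in excluded_ids:
                    return False
        return True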
@@ -6,19 +6,19 @@ as created by the script `wikidata_create_kb`.
 
 For the Wikipedia dump: get enwiki-latest-pages-articles-multistream.xml.bz2
 from https://dumps.wikimedia.org/enwiki/latest/
-
 """
 from __future__ import unicode_literals
 
 import random
 import logging
+import spacy
 from pathlib import Path
 import plac
 
-from bin.wiki_entity_linking import training_set_creator
+from bin.wiki_entity_linking import wikipedia_processor
 from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR
-from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance, measure_baselines
-from bin.wiki_entity_linking.kb_creator import read_nlp_kb
+from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance
+from bin.wiki_entity_linking.kb_creator import read_kb
 
 from spacy.util import minibatch, compounding
 
@@ -35,6 +35,7 @@ logger = logging.getLogger(__name__)
     l2=("L2 regularization", "option", "r", float),
     train_inst=("# training instances (default 90% of all)", "option", "t", int),
     dev_inst=("# test instances (default 10% of all)", "option", "d", int),
+    labels_discard=("NER labels to discard (default None)", "option", "l", str),
 )
 def main(
     dir_kb,
@@ -46,13 +47,14 @@ def main(
     l2=1e-6,
     train_inst=None,
     dev_inst=None,
+    labels_discard=None
 ):
     logger.info("Creating Entity Linker with Wikipedia and WikiData")
 
     output_dir = Path(output_dir) if output_dir else dir_kb
-    training_path = loc_training if loc_training else output_dir / TRAINING_DATA_FILE
+    training_path = loc_training if loc_training else dir_kb / TRAINING_DATA_FILE
     nlp_dir = dir_kb / KB_MODEL_DIR
-    kb_path = output_dir / KB_FILE
+    kb_path = dir_kb / KB_FILE
     nlp_output_dir = output_dir / OUTPUT_MODEL_DIR
 
     # STEP 0: set up IO
@@ -60,38 +62,47 @@ def main(
         output_dir.mkdir()
 
     # STEP 1 : load the NLP object
-    logger.info("STEP 1: loading model from {}".format(nlp_dir))
-    nlp, kb = read_nlp_kb(nlp_dir, kb_path)
+    logger.info("STEP 1a: Loading model from {}".format(nlp_dir))
+    nlp = spacy.load(nlp_dir)
+    logger.info("STEP 1b: Loading KB from {}".format(kb_path))
+    kb = read_kb(nlp, kb_path)
 
     # check that there is a NER component in the pipeline
     if "ner" not in nlp.pipe_names:
         raise ValueError("The `nlp` object should have a pretrained `ner` component.")
 
-    # STEP 2: create a training dataset from WP
-    logger.info("STEP 2: reading training dataset from {}".format(training_path))
+    # STEP 2: read the training dataset previously created from WP
+    logger.info("STEP 2: Reading training dataset from {}".format(training_path))
 
-    train_data = training_set_creator.read_training(
+    if labels_discard:
+        labels_discard = [x.strip() for x in labels_discard.split(",")]
+        logger.info("Discarding {} NER types: {}".format(len(labels_discard), labels_discard))
+
+    train_data = wikipedia_processor.read_training(
         nlp=nlp,
         entity_file_path=training_path,
         dev=False,
         limit=train_inst,
         kb=kb,
+        labels_discard=labels_discard
     )
 
-    # for testing, get all pos instances, whether or not they are in the kb
-    dev_data = training_set_creator.read_training(
+    # for testing, get all pos instances (independently of KB)
+    dev_data = wikipedia_processor.read_training(
        nlp=nlp,
        entity_file_path=training_path,
        dev=True,
        limit=dev_inst,
-        kb=kb,
+        kb=None,
+        labels_discard=labels_discard
    )
 
-    # STEP 3: create and train the entity linking pipe
-    logger.info("STEP 3: training Entity Linking pipe")
+    # STEP 3: create and train an entity linking pipe
+    logger.info("STEP 3: Creating and training an Entity Linking pipe")
 
     el_pipe = nlp.create_pipe(
-        name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name}
+        name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name,
+                                      "labels_discard": labels_discard}
     )
     el_pipe.set_kb(kb)
     nlp.add_pipe(el_pipe, last=True)
@@ -105,14 +116,9 @@ def main(
     logger.info("Training on {} articles".format(len(train_data)))
     logger.info("Dev testing on {} articles".format(len(dev_data)))
 
-    dev_baseline_accuracies = measure_baselines(
-        dev_data, kb
-    )
-
+    # baseline performance on dev data
     logger.info("Dev Baseline Accuracies:")
-    logger.info(dev_baseline_accuracies.report_accuracy("random"))
-    logger.info(dev_baseline_accuracies.report_accuracy("prior"))
-    logger.info(dev_baseline_accuracies.report_accuracy("oracle"))
+    measure_performance(dev_data, kb, el_pipe, baseline=True, context=False)
 
     for itn in range(epochs):
         random.shuffle(train_data)
@@ -136,18 +142,18 @@ def main(
                     logger.error("Error updating batch:" + str(e))
         if batchnr > 0:
             logging.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2)))
-        measure_performance(dev_data, kb, el_pipe)
+            measure_performance(dev_data, kb, el_pipe, baseline=False, context=True)
 
     # STEP 4: measure the performance of our trained pipe on an independent dev set
-    logger.info("STEP 4: performance measurement of Entity Linking pipe")
+    logger.info("STEP 4: Final performance measurement of Entity Linking pipe")
     measure_performance(dev_data, kb, el_pipe)
 
     # STEP 5: apply the EL pipe on a toy example
-    logger.info("STEP 5: applying Entity Linking to toy example")
+    logger.info("STEP 5: Applying Entity Linking to toy example")
     run_el_toy_example(nlp=nlp)
 
     if output_dir:
-        # STEP 6: write the NLP pipeline (including entity linker) to file
+        # STEP 6: write the NLP pipeline (now including an EL model) to file
         logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir))
         nlp.to_disk(nlp_output_dir)
 
| 
						 | 
					
 | 
				
			||||||
@@ -3,147 +3,104 @@ from __future__ import unicode_literals
 
 import re
 import bz2
-import csv
-import datetime
 import logging
+import random
+import json
 
-from bin.wiki_entity_linking import LOG_FORMAT
+from functools import partial
 
+from spacy.gold import GoldParse
+from bin.wiki_entity_linking import wiki_io as io
+from bin.wiki_entity_linking.wiki_namespaces import (
+    WP_META_NAMESPACE,
+    WP_FILE_NAMESPACE,
+    WP_CATEGORY_NAMESPACE,
+)
 
 """
 Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions.
 Write these results to file for downstream KB and training data generation.
+
+Process Wikipedia interlinks to generate a training dataset for the EL algorithm.
 """
 
+ENTITY_FILE = "gold_entities.csv"
+
 map_alias_to_link = dict()
 
 logger = logging.getLogger(__name__)
 
-# these will/should be matched ignoring case
-wiki_namespaces = [
-    "b",
-    "betawikiversity",
-    "Book",
-    "c",
-    "Category",
-    "Commons",
-    "d",
-    "dbdump",
-    "download",
-    "Draft",
-    "Education",
-    "Foundation",
-    "Gadget",
-    "Gadget definition",
-    "gerrit",
-    "File",
-    "Help",
-    "Image",
-    "Incubator",
-    "m",
-    "mail",
-    "mailarchive",
-    "media",
-    "MediaWiki",
-    "MediaWiki talk",
-    "Mediawikiwiki",
-    "MediaZilla",
-    "Meta",
-    "Metawikipedia",
-    "Module",
-    "mw",
-    "n",
-    "nost",
-    "oldwikisource",
-    "outreach",
-    "outreachwiki",
-    "otrs",
-    "OTRSwiki",
-    "Portal",
-    "phab",
-    "Phabricator",
-    "Project",
-    "q",
-    "quality",
-    "rev",
-    "s",
-    "spcom",
-    "Special",
-    "species",
-    "Strategy",
-    "sulutil",
-    "svn",
-    "Talk",
-    "Template",
-    "Template talk",
-    "Testwiki",
-    "ticket",
-    "TimedText",
-    "Toollabs",
-    "tools",
-    "tswiki",
-    "User",
-    "User talk",
-    "v",
-    "voy",
-    "w",
-    "Wikibooks",
-    "Wikidata",
-    "wikiHow",
-    "Wikinvest",
-    "wikilivres",
-    "Wikimedia",
-    "Wikinews",
-    "Wikipedia",
-    "Wikipedia talk",
-    "Wikiquote",
-    "Wikisource",
-    "Wikispecies",
-    "Wikitech",
-    "Wikiversity",
-    "Wikivoyage",
-    "wikt",
-    "wiktionary",
-    "wmf",
-    "wmania",
-    "WP",
-]
+title_regex = re.compile(r"(?<=<title>).*(?=</title>)")
+id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
+text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")
+info_regex = re.compile(r"{[^{]*?}")
+html_regex = re.compile(r"&lt;!--[^-]*--&gt;")
+ref_regex = re.compile(r"&lt;ref.*?&gt;")  # non-greedy
+ref_2_regex = re.compile(r"&lt;/ref.*?&gt;")  # non-greedy
 
 # find the links
 link_regex = re.compile(r"\[\[[^\[\]]*\]\]")
 
 # match on interwiki links, e.g. `en:` or `:fr:`
 ns_regex = r":?" + "[a-z][a-z]" + ":"
-
 # match on Namespace: optionally preceded by a :
-for ns in wiki_namespaces:
+for ns in WP_META_NAMESPACE:
     ns_regex += "|" + ":?" + ns + ":"
-
 ns_regex = re.compile(ns_regex, re.IGNORECASE)
 
+files = r""
+for f in WP_FILE_NAMESPACE:
+    files += "\[\[" + f + ":[^[\]]+]]" + "|"
+files = files[0 : len(files) - 1]
+file_regex = re.compile(files)
+
+cats = r""
+for c in WP_CATEGORY_NAMESPACE:
+    cats += "\[\[" + c + ":[^\[]*]]" + "|"
+cats = cats[0 : len(cats) - 1]
+category_regex = re.compile(cats)
+
 
 def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
     """
     Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
-    The full file takes about 2h to parse 1100M lines.
-    It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from.
+    The full file takes about 2-3h to parse 1100M lines.
+    It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from,
+    though dev test articles are excluded in order not to get an artificially strong baseline.
     """
+    cnt = 0
+    read_id = False
+    current_article_id = None
     with bz2.open(wikipedia_input, mode="rb") as file:
         line = file.readline()
-        cnt = 0
         while line and (not limit or cnt < limit):
-            if cnt % 25000000 == 0:
+            if cnt % 25000000 == 0 and cnt > 0:
                 logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
             clean_line = line.strip().decode("utf-8")
 
-            aliases, entities, normalizations = get_wp_links(clean_line)
-            for alias, entity, norm in zip(aliases, entities, normalizations):
-                _store_alias(alias, entity, normalize_alias=norm, normalize_entity=True)
-                _store_alias(alias, entity, normalize_alias=norm, normalize_entity=True)
+            # we attempt at reading the article's ID (but not the revision or contributor ID)
+            if "<revision>" in clean_line or "<contributor>" in clean_line:
+                read_id = False
+            if "<page>" in clean_line:
+                read_id = True
+
+            if read_id:
+                ids = id_regex.search(clean_line)
+                if ids:
+                    current_article_id = ids[0]
+
+            # only processing prior probabilities from true training (non-dev) articles
+            if not is_dev(current_article_id):
+                aliases, entities, normalizations = get_wp_links(clean_line)
+                for alias, entity, norm in zip(aliases, entities, normalizations):
+                    _store_alias(
+                        alias, entity, normalize_alias=norm, normalize_entity=True
+                    )
+
             line = file.readline()
             cnt += 1
         logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
+    logger.info("Finished. processed {} lines of Wikipedia XML dump".format(cnt))
 
     # write all aliases and their entities and count occurrences to file
     with prior_prob_output.open("w", encoding="utf8") as outputfile:
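The file and category filters built above join the namespace lists into a single alternation pattern each. A small illustration of the same idea, using a toy stand-in list (the real WP_FILE_NAMESPACE lives in wiki_namespaces and is not shown in this excerpt):

    import re

    # Toy namespace list; the pattern mirrors the one constructed above.
    file_namespaces = ["File", "Image"]
    pattern = "|".join(r"\[\[" + ns + r":[^[\]]+]]" for ns in file_namespaces)
    toy_file_regex = re.compile(pattern)

    text = "Intro [[File:Example.jpg|thumb]] text with a normal link [[Paris]]."
    print(toy_file_regex.sub("", text))  # strips the [[File:...]] block, keeps [[Paris]]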
@@ -182,7 +139,7 @@ def get_wp_links(text):
         match = match[2:][:-2].replace("_", " ").strip()
 
         if ns_regex.match(match):
-            pass  # ignore namespaces at the beginning of the string
+            pass  # ignore the entity if it points to a "meta" page
 
         # this is a simple [[link]], with the alias the same as the mention
         elif "|" not in match:
@@ -218,47 +175,382 @@ def _capitalize_first(text):
     return result
 
 
-def write_entity_counts(prior_prob_input, count_output, to_print=False):
-    # Write entity counts for quick access later
-    entity_to_count = dict()
-    total_count = 0
-
-    with prior_prob_input.open("r", encoding="utf8") as prior_file:
-        # skip header
-        prior_file.readline()
-        line = prior_file.readline()
-
-        while line:
-            splits = line.replace("\n", "").split(sep="|")
-            # alias = splits[0]
-            count = int(splits[1])
-            entity = splits[2]
-
-            current_count = entity_to_count.get(entity, 0)
-            entity_to_count[entity] = current_count + count
-
-            total_count += count
-
-            line = prior_file.readline()
-
-    with count_output.open("w", encoding="utf8") as entity_file:
-        entity_file.write("entity" + "|" + "count" + "\n")
-        for entity, count in entity_to_count.items():
-            entity_file.write(entity + "|" + str(count) + "\n")
-
-    if to_print:
-        for entity, count in entity_to_count.items():
-            print("Entity count:", entity, count)
-        print("Total count:", total_count)
+def create_training_and_desc(
+    wp_input, def_input, desc_output, training_output, parse_desc, limit=None
+):
+    wp_to_id = io.read_title_to_id(def_input)
+    _process_wikipedia_texts(
+        wp_input, wp_to_id, desc_output, training_output, parse_desc, limit
+    )
 
 
-def get_all_frequencies(count_input):
-    entity_to_count = dict()
-    with count_input.open("r", encoding="utf8") as csvfile:
-        csvreader = csv.reader(csvfile, delimiter="|")
-        # skip header
-        next(csvreader)
-        for row in csvreader:
-            entity_to_count[row[0]] = int(row[1])
-
-    return entity_to_count
+def _process_wikipedia_texts(
+    wikipedia_input, wp_to_id, output, training_output, parse_descriptions, limit=None
+):
+    """
+    Read the XML wikipedia data to parse out training data:
+    raw text data + positive instances
+    """
+
+    read_ids = set()
+
+    with output.open("a", encoding="utf8") as descr_file, training_output.open(
+        "w", encoding="utf8"
+    ) as entity_file:
+        if parse_descriptions:
+            _write_training_description(descr_file, "WD_id", "description")
+        with bz2.open(wikipedia_input, mode="rb") as file:
+            article_count = 0
+            article_text = ""
+            article_title = None
+            article_id = None
+            reading_text = False
+            reading_revision = False
+
+            for line in file:
+                clean_line = line.strip().decode("utf-8")
+
+                if clean_line == "<revision>":
+                    reading_revision = True
+                elif clean_line == "</revision>":
+                    reading_revision = False
+
+                # Start reading new page
+                if clean_line == "<page>":
+                    article_text = ""
+                    article_title = None
+                    article_id = None
+                # finished reading this page
+                elif clean_line == "</page>":
+                    if article_id:
+                        clean_text, entities = _process_wp_text(
+                            article_title, article_text, wp_to_id
+                        )
+                        if clean_text is not None and entities is not None:
+                            _write_training_entities(
+                                entity_file, article_id, clean_text, entities
+                            )
+
+                            if article_title in wp_to_id and parse_descriptions:
+                                description = " ".join(
+                                    clean_text[:1000].split(" ")[:-1]
+                                )
+                                _write_training_description(
+                                    descr_file, wp_to_id[article_title], description
+                                )
+                            article_count += 1
+                            if article_count % 10000 == 0 and article_count > 0:
+                                logger.info(
+                                    "Processed {} articles".format(article_count)
+                                )
+                            if limit and article_count >= limit:
+                                break
+                    article_text = ""
+                    article_title = None
+                    article_id = None
+                    reading_text = False
+                    reading_revision = False
+
+                # start reading text within a page
+                if "<text" in clean_line:
+                    reading_text = True
+
+                if reading_text:
+                    article_text += " " + clean_line
+
+                # stop reading text within a page (we assume a new page doesn't start on the same line)
+                if "</text" in clean_line:
+                    reading_text = False
+
+                # read the ID of this article (outside the revision portion of the document)
+                if not reading_revision:
+                    ids = id_regex.search(clean_line)
+                    if ids:
+                        article_id = ids[0]
+                        if article_id in read_ids:
+                            logger.info(
+                                "Found duplicate article ID", article_id, clean_line
+                            )  # This should never happen ...
+                        read_ids.add(article_id)
+
+                # read the title of this article (outside the revision portion of the document)
+                if not reading_revision:
+                    titles = title_regex.search(clean_line)
+                    if titles:
+                        article_title = titles[0].strip()
+    logger.info("Finished. Processed {} articles".format(article_count))
+
+
+def _process_wp_text(article_title, article_text, wp_to_id):
+    # ignore meta Wikipedia pages
+    if ns_regex.match(article_title):
+        return None, None
+
+    # remove the text tags
+    text_search = text_regex.search(article_text)
+    if text_search is None:
+        return None, None
+    text = text_search.group(0)
+
+    # stop processing if this is a redirect page
+    if text.startswith("#REDIRECT"):
+        return None, None
+
+    # get the raw text without markup etc, keeping only interwiki links
+    clean_text, entities = _remove_links(_get_clean_wp_text(text), wp_to_id)
+    return clean_text, entities
+
+
+def _get_clean_wp_text(article_text):
+    clean_text = article_text.strip()
+
+    # remove bolding & italic markup
+    clean_text = clean_text.replace("'''", "")
+    clean_text = clean_text.replace("''", "")
+
+    # remove nested {{info}} statements by removing the inner/smallest ones first and iterating
+    try_again = True
+    previous_length = len(clean_text)
+    while try_again:
+        clean_text = info_regex.sub(
+            "", clean_text
+        )  # non-greedy match excluding a nested {
+        if len(clean_text) < previous_length:
+            try_again = True
+        else:
+            try_again = False
+        previous_length = len(clean_text)
+
+    # remove HTML comments
+    clean_text = html_regex.sub("", clean_text)
+
+    # remove Category and File statements
+    clean_text = category_regex.sub("", clean_text)
+    clean_text = file_regex.sub("", clean_text)
+
+    # remove multiple =
+    while "==" in clean_text:
+        clean_text = clean_text.replace("==", "=")
+
+    clean_text = clean_text.replace(". =", ".")
+    clean_text = clean_text.replace(" = ", ". ")
+    clean_text = clean_text.replace("= ", ".")
+    clean_text = clean_text.replace(" =", "")
+
+    # remove refs (non-greedy match)
+    clean_text = ref_regex.sub("", clean_text)
+    clean_text = ref_2_regex.sub("", clean_text)
+
+    # remove additional wikiformatting
+    clean_text = re.sub(r"&lt;blockquote&gt;", "", clean_text)
+    clean_text = re.sub(r"&lt;/blockquote&gt;", "", clean_text)
+
+    # change special characters back to normal ones
+    clean_text = clean_text.replace(r"&lt;", "<")
+    clean_text = clean_text.replace(r"&gt;", ">")
+    clean_text = clean_text.replace(r"&quot;", '"')
+    clean_text = clean_text.replace(r"&amp;nbsp;", " ")
+    clean_text = clean_text.replace(r"&amp;", "&")
+
+    # remove multiple spaces
+    while "  " in clean_text:
+        clean_text = clean_text.replace("  ", " ")
+
+    return clean_text.strip()
+
+
+def _remove_links(clean_text, wp_to_id):
+    # read the text char by char to get the right offsets for the interwiki links
+    entities = []
+    final_text = ""
+    open_read = 0
+    reading_text = True
+    reading_entity = False
+    reading_mention = False
+    reading_special_case = False
+    entity_buffer = ""
+    mention_buffer = ""
+    for index, letter in enumerate(clean_text):
+        if letter == "[":
+            open_read += 1
+        elif letter == "]":
+            open_read -= 1
+        elif letter == "|":
+            if reading_text:
+                final_text += letter
+            # switch from reading entity to mention in the [[entity|mention]] pattern
+            elif reading_entity:
+                reading_text = False
+                reading_entity = False
+                reading_mention = True
+            else:
+                reading_special_case = True
+        else:
+            if reading_entity:
+                entity_buffer += letter
+            elif reading_mention:
+                mention_buffer += letter
+            elif reading_text:
+                final_text += letter
+            else:
+                raise ValueError("Not sure at point", clean_text[index - 2 : index + 2])
+
+        if open_read > 2:
+            reading_special_case = True
+
+        if open_read == 2 and reading_text:
+            reading_text = False
+            reading_entity = True
+            reading_mention = False
+
+        # we just finished reading an entity
+        if open_read == 0 and not reading_text:
+            if "#" in entity_buffer or entity_buffer.startswith(":"):
+                reading_special_case = True
+            # Ignore cases with nested structures like File: handles etc
+            if not reading_special_case:
+                if not mention_buffer:
+                    mention_buffer = entity_buffer
+                start = len(final_text)
+                end = start + len(mention_buffer)
+                qid = wp_to_id.get(entity_buffer, None)
+                if qid:
+                    entities.append((mention_buffer, qid, start, end))
+                final_text += mention_buffer
+
+            entity_buffer = ""
+            mention_buffer = ""
+
+            reading_text = True
+            reading_entity = False
+            reading_mention = False
+            reading_special_case = False
+    return final_text, entities
+
+
+def _write_training_description(outputfile, qid, description):
+    if description is not None:
+        line = str(qid) + "|" + description + "\n"
+        outputfile.write(line)
+
+
+def _write_training_entities(outputfile, article_id, clean_text, entities):
+    entities_data = [
+        {"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]}
+        for ent in entities
+    ]
+    line = (
+        json.dumps(
+            {
+                "article_id": article_id,
+                "clean_text": clean_text,
+                "entities": entities_data,
+            },
+            ensure_ascii=False,
+        )
+        + "\n"
+    )
+    outputfile.write(line)
+
+
+def read_training(nlp, entity_file_path, dev, limit, kb, labels_discard=None):
+    """ This method provides training examples that correspond to the entity annotations found by the nlp object.
+     For training, it will include both positive and negative examples by using the candidate generator from the kb.
+     For testing (kb=None), it will include all positive examples only."""
+
+    from tqdm import tqdm
+
+    if not labels_discard:
+        labels_discard = []
+
+    data = []
+    num_entities = 0
+    get_gold_parse = partial(
+        _get_gold_parse, dev=dev, kb=kb, labels_discard=labels_discard
+    )
+
+    logger.info(
+        "Reading {} data with limit {}".format("dev" if dev else "train", limit)
+    )
+    with entity_file_path.open("r", encoding="utf8") as file:
+        with tqdm(total=limit, leave=False) as pbar:
+            for i, line in enumerate(file):
+                example = json.loads(line)
+                article_id = example["article_id"]
+                clean_text = example["clean_text"]
+                entities = example["entities"]
+
+                if dev != is_dev(article_id) or not is_valid_article(clean_text):
+                    continue
+
+                doc = nlp(clean_text)
+                gold = get_gold_parse(doc, entities)
+                if gold and len(gold.links) > 0:
+                    data.append((doc, gold))
+                    num_entities += len(gold.links)
+                    pbar.update(len(gold.links))
+                if limit and num_entities >= limit:
+                    break
+    logger.info("Read {} entities in {} articles".format(num_entities, len(data)))
+    return data
+
+
+def _get_gold_parse(doc, entities, dev, kb, labels_discard):
+    gold_entities = {}
+    tagged_ent_positions = {
+        (ent.start_char, ent.end_char): ent
+        for ent in doc.ents
+        if ent.label_ not in labels_discard
+    }
+
+    for entity in entities:
+        entity_id = entity["entity"]
+        alias = entity["alias"]
+        start = entity["start"]
+        end = entity["end"]
+
+        candidate_ids = []
+        if kb and not dev:
+            candidates = kb.get_candidates(alias)
+            candidate_ids = [cand.entity_ for cand in candidates]
+
+        tagged_ent = tagged_ent_positions.get((start, end), None)
+        if tagged_ent:
+            # TODO: check that alias == doc.text[start:end]
+            should_add_ent = (dev or entity_id in candidate_ids) and is_valid_sentence(
+                tagged_ent.sent.text
+            )
+
+            if should_add_ent:
+                value_by_id = {entity_id: 1.0}
+                if not dev:
+                    random.shuffle(candidate_ids)
+                    value_by_id.update(
+                        {kb_id: 0.0 for kb_id in candidate_ids if kb_id != entity_id}
+                    )
+                gold_entities[(start, end)] = value_by_id
+
+    return GoldParse(doc, links=gold_entities)
+
+
+def is_dev(article_id):
+    if not article_id:
+        return False
+    return article_id.endswith("3")
+
+
+def is_valid_article(doc_text):
+    # custom length cut-off
+    return 10 < len(doc_text) < 30000
+
+
+def is_valid_sentence(sent_text):
+    if not 10 < len(sent_text) < 3000:
+        # custom length cut-off
+        return False
+
+    if sent_text.strip().startswith("*") or sent_text.strip().startswith("#"):
+        # remove 'enumeration' sentences (occurs often on Wikipedia)
+        return False
+
+    return True
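The is_dev rule above deterministically reserves articles whose ID ends in "3" (roughly one in ten) for the dev set, so the same split is reused by the prior-probability, training and evaluation steps. A quick illustration with made-up article IDs:

    # Sketch of the split rule: IDs ending in "3" go to dev, everything else to train.
    article_ids = ["1", "12", "23", "103", "50"]
    dev_ids = [a for a in article_ids if a.endswith("3")]        # ['23', '103']
    train_ids = [a for a in article_ids if not a.endswith("3")]  # ['1', '12', '50']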
@@ -246,7 +246,7 @@ def make_update(model, docs, optimizer, drop=0.0, objective="L2"):
     """Perform an update over a single batch of documents.
 
     docs (iterable): A batch of `Doc` objects.
-    drop (float): The droput rate.
+    drop (float): The dropout rate.
     optimizer (callable): An optimizer.
     RETURNS loss: A float for the loss.
     """
@@ -80,8 +80,8 @@ class Warnings(object):
             "the v2.x models cannot release the global interpreter lock. "
             "Future versions may introduce a `n_process` argument for "
             "parallel inference via multiprocessing.")
-    W017 = ("Alias '{alias}' already exists in the Knowledge base.")
-    W018 = ("Entity '{entity}' already exists in the Knowledge base.")
+    W017 = ("Alias '{alias}' already exists in the Knowledge Base.")
+    W018 = ("Entity '{entity}' already exists in the Knowledge Base.")
     W019 = ("Changing vectors name from {old} to {new}, to avoid clash with "
             "previously loaded vectors. See Issue #3853.")
     W020 = ("Unnamed vectors. This won't allow multiple vectors models to be "
@@ -96,6 +96,8 @@ class Warnings(object):
             "If this is surprising, make sure you have the spacy-lookups-data "
             "package installed.")
     W023 = ("Multiprocessing of Language.pipe is not supported in Python2. 'n_process' will be set to 1.")
+    W024 = ("Entity '{entity}' - Alias '{alias}' combination already exists in "
+            "the Knowledge Base.")
 
 
 @add_codes
@@ -408,7 +410,7 @@ class Errors(object):
             "{probabilities_length} respectively.")
     E133 = ("The sum of prior probabilities for alias '{alias}' should not "
             "exceed 1, but found {sum}.")
-    E134 = ("Alias '{alias}' defined for unknown entity '{entity}'.")
+    E134 = ("Entity '{entity}' is not defined in the Knowledge Base.")
     E135 = ("If you meant to replace a built-in component, use `create_pipe`: "
             "`nlp.replace_pipe('{name}', nlp.create_pipe('{name}'))`")
     E136 = ("This additional feature requires the jsonschema library to be "
@@ -420,7 +422,7 @@ class Errors(object):
     E138 = ("Invalid JSONL format for raw text '{text}'. Make sure the input "
             "includes either the `text` or `tokens` key. For more info, see "
             "the docs:\nhttps://spacy.io/api/cli#pretrain-jsonl")
-    E139 = ("Knowledge base for component '{name}' not initialized. Did you "
+    E139 = ("Knowledge Base for component '{name}' not initialized. Did you "
             "forget to call set_kb()?")
     E140 = ("The list of entities, prior probabilities and entity vectors "
             "should be of equal length.")
@@ -499,6 +501,7 @@ class Errors(object):
     E174 = ("Architecture '{name}' not found in registry. Available "
             "names: {names}")
     E175 = ("Can't remove rule for unknown match pattern ID: {key}")
+    E176 = ("Alias '{alias}' is not defined in the Knowledge Base.")
 
 
 @add_codes

69  spacy/kb.pyx
@@ -142,6 +142,7 @@ cdef class KnowledgeBase:
 
         i = 0
         cdef KBEntryC entry
+        cdef hash_t entity_hash
         while i < nr_entities:
             entity_vector = vector_list[i]
             if len(entity_vector) != self.entity_vector_length:
@@ -161,6 +162,14 @@ cdef class KnowledgeBase:
 
             i += 1
 
+    def contains_entity(self, unicode entity):
+        cdef hash_t entity_hash = self.vocab.strings.add(entity)
+        return entity_hash in self._entry_index
+
+    def contains_alias(self, unicode alias):
+        cdef hash_t alias_hash = self.vocab.strings.add(alias)
+        return alias_hash in self._alias_index
+
     def add_alias(self, unicode alias, entities, probabilities):
         """
         For a given alias, add its potential entities and prior probabilities to the KB.
@@ -190,7 +199,7 @@ cdef class KnowledgeBase:
         for entity, prob in zip(entities, probabilities):
             entity_hash = self.vocab.strings[entity]
             if not entity_hash in self._entry_index:
-                raise ValueError(Errors.E134.format(alias=alias, entity=entity))
+                raise ValueError(Errors.E134.format(entity=entity))
 
             entry_index = <int64_t>self._entry_index.get(entity_hash)
             entry_indices.push_back(int(entry_index))
@@ -201,8 +210,63 @@ cdef class KnowledgeBase:
 
         return alias_hash
 
-    def get_candidates(self, unicode alias):
+    def append_alias(self, unicode alias, unicode entity, float prior_prob, ignore_warnings=False):
+        """
+        For an alias already existing in the KB, extend its potential entities with one more.
+        Throw a warning if either the alias or the entity is unknown,
+        or when the combination is already previously recorded.
+        Throw an error if this entity+prior prob would exceed the sum of 1.
+        For efficiency, it's best to use the method `add_alias` as much as possible instead of this one.
+        """
+        # Check if the alias exists in the KB
         cdef hash_t alias_hash = self.vocab.strings[alias]
+        if not alias_hash in self._alias_index:
+            raise ValueError(Errors.E176.format(alias=alias))
+
+        # Check if the entity exists in the KB
+        cdef hash_t entity_hash = self.vocab.strings[entity]
+        if not entity_hash in self._entry_index:
+            raise ValueError(Errors.E134.format(entity=entity))
+        entry_index = <int64_t>self._entry_index.get(entity_hash)
+
+        # Throw an error if the prior probabilities (including the new one) sum up to more than 1
+        alias_index = <int64_t>self._alias_index.get(alias_hash)
+        alias_entry = self._aliases_table[alias_index]
+        current_sum = sum([p for p in alias_entry.probs])
+        new_sum = current_sum + prior_prob
+
+        if new_sum > 1.00001:
+            raise ValueError(Errors.E133.format(alias=alias, sum=new_sum))
+
+        entry_indices = alias_entry.entry_indices
+
+        is_present = False
+        for i in range(entry_indices.size()):
+            if entry_indices[i] == int(entry_index):
+                is_present = True
+
+        if is_present:
+            if not ignore_warnings:
+                user_warning(Warnings.W024.format(entity=entity, alias=alias))
+        else:
+            entry_indices.push_back(int(entry_index))
+            alias_entry.entry_indices = entry_indices
+
+            probs = alias_entry.probs
+            probs.push_back(float(prior_prob))
+            alias_entry.probs = probs
+            self._aliases_table[alias_index] = alias_entry
+
+
+    def get_candidates(self, unicode alias):
+        """
+        Return candidate entities for an alias. Each candidate defines the entity, the original alias,
+        and the prior probability of that alias resolving to that entity.
+        If the alias is not known in the KB, an empty list is returned.
+        """
+        cdef hash_t alias_hash = self.vocab.strings[alias]
+        if not alias_hash in self._alias_index:
+            return []
         alias_index = <int64_t>self._alias_index.get(alias_hash)
         alias_entry = self._aliases_table[alias_index]
 
@@ -341,7 +405,6 @@ cdef class KnowledgeBase:
         assert nr_entities == self.get_size_entities()
 
         # STEP 3: load aliases
-
         cdef int64_t nr_aliases
         reader.read_alias_length(&nr_aliases)
         self._alias_index = PreshMap(nr_aliases+1)
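The new contains_entity, contains_alias and append_alias helpers make it possible to check membership and to grow an existing alias one candidate at a time. A minimal usage sketch, assuming the spaCy 2.2-era KnowledgeBase constructor and the add_entity/add_alias signatures, which are not part of this diff:

    from spacy.vocab import Vocab
    from spacy.kb import KnowledgeBase

    kb = KnowledgeBase(vocab=Vocab(), entity_vector_length=3)
    kb.add_entity(entity="Q42", freq=50, entity_vector=[1.0, 2.0, 3.0])
    kb.add_entity(entity="Q5301561", freq=20, entity_vector=[4.0, 5.0, 6.0])
    kb.add_alias(alias="Douglas Adams", entities=["Q42"], probabilities=[0.8])

    assert kb.contains_entity("Q42")
    assert kb.contains_alias("Douglas Adams")

    # Extend the existing alias with one more candidate; repeating the same
    # entity+alias pair would only trigger warning W024 instead of failing.
    kb.append_alias(alias="Douglas Adams", entity="Q5301561", prior_prob=0.1)
    print([c.entity_ for c in kb.get_candidates("Douglas Adams")])  # e.g. ['Q42', 'Q5301561']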
@@ -483,7 +483,7 @@ class Language(object):

         docs (iterable): A batch of `Doc` objects.
         golds (iterable): A batch of `GoldParse` objects.
-        drop (float): The droput rate.
+        drop (float): The dropout rate.
         sgd (callable): An optimizer.
         losses (dict): Dictionary to update with the loss, keyed by component.
         component_cfg (dict): Config parameters for specific pipeline
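For context, a minimal sketch of a call that matches the documented parameters (`docs`, `golds`, `drop`, `sgd`, `losses`); the toy pipeline and annotations are made up for illustration:

```python
import spacy

nlp = spacy.blank("en")
ner = nlp.create_pipe("ner")
ner.add_label("PERSON")
nlp.add_pipe(ner)

optimizer = nlp.begin_training()
losses = {}

# one toy update step: raw texts and gold dicts, with dropout and an optimizer
docs = ["Douglas Adams wrote books."]
golds = [{"entities": [(0, 13, "PERSON")]}]
nlp.update(docs, golds, drop=0.2, sgd=optimizer, losses=losses)
print(losses)
```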
@@ -531,7 +531,7 @@ class Language(object):
         even if you're updating it with a smaller set of examples.

         docs (iterable): A batch of `Doc` objects.
-        drop (float): The droput rate.
+        drop (float): The dropout rate.
         sgd (callable): An optimizer.
         RETURNS (dict): Results from the update.

@@ -1195,23 +1195,26 @@ class EntityLinker(Pipe):
             docs = [docs]
             golds = [golds]

-        context_docs = []
+        sentence_docs = []

         for doc, gold in zip(docs, golds):
             ents_by_offset = dict()
             for ent in doc.ents:
-                ents_by_offset["{}_{}".format(ent.start_char, ent.end_char)] = ent
+                ents_by_offset[(ent.start_char, ent.end_char)] = ent

             for entity, kb_dict in gold.links.items():
                 start, end = entity
                 mention = doc.text[start:end]
+                # the gold annotations should link to proper entities - if this fails, the dataset is likely corrupt
+                ent = ents_by_offset[(start, end)]
+
                 for kb_id, value in kb_dict.items():
                     # Currently only training on the positive instances
                     if value:
-                        context_docs.append(doc)
+                        sentence_docs.append(ent.sent.as_doc())

-        context_encodings, bp_context = self.model.begin_update(context_docs, drop=drop)
-        loss, d_scores = self.get_similarity_loss(scores=context_encodings, golds=golds, docs=None)
+        sentence_encodings, bp_context = self.model.begin_update(sentence_docs, drop=drop)
+        loss, d_scores = self.get_similarity_loss(scores=sentence_encodings, golds=golds, docs=None)
         bp_context(d_scores, sgd=sgd)

         if losses is not None:
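The `update` changes above key gold links by character offsets and train on the sentence around each gold entity (`ent.sent.as_doc()`). A minimal sketch of that data shape, assuming a sentencizer is in the pipeline so that `ent.sent` is defined:

```python
import spacy
from spacy.gold import GoldParse
from spacy.tokens import Span

nlp = spacy.blank("en")
nlp.add_pipe(nlp.create_pipe("sentencizer"))  # ent.sent needs sentence boundaries

doc = nlp("Douglas Adams wrote books. He was born in Cambridge.")
doc.ents = [Span(doc, 0, 2, label="PERSON")]  # "Douglas Adams"

# gold entity links are keyed by character offsets, mapping KB ids to 0/1 values
gold = GoldParse(doc, links={(0, 13): {"Q42": 1.0}})

ent = doc.ents[0]
sent_doc = ent.sent.as_doc()  # the sentence containing the entity, as its own Doc
print(sent_doc.text)
```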
@@ -1280,50 +1283,68 @@ class EntityLinker(Pipe):
         if isinstance(docs, Doc):
             docs = [docs]

-        context_encodings = self.model(docs)
-        xp = get_array_module(context_encodings)
-
         for i, doc in enumerate(docs):
             if len(doc) > 0:
-                # currently, the context is the same for each entity in a sentence (should be refined)
-                context_encoding = context_encodings[i]
-                context_enc_t = context_encoding.T
-                norm_1 = xp.linalg.norm(context_enc_t)
-                for ent in doc.ents:
-                    entity_count += 1
-                    candidates = self.kb.get_candidates(ent.text)
-                    if not candidates:
-                        final_kb_ids.append(self.NIL)  # no prediction possible for this entity
-                        final_tensors.append(context_encoding)
-                    else:
-                        random.shuffle(candidates)
-
-                        # this will set all prior probabilities to 0 if they should be excluded from the model
-                        prior_probs = xp.asarray([c.prior_prob for c in candidates])
-                        if not self.cfg.get("incl_prior", True):
-                            prior_probs = xp.asarray([0.0 for c in candidates])
-                        scores = prior_probs
-
-                        # add in similarity from the context
-                        if self.cfg.get("incl_context", True):
-                            entity_encodings = xp.asarray([c.entity_vector for c in candidates])
-                            norm_2 = xp.linalg.norm(entity_encodings, axis=1)
-
-                            if len(entity_encodings) != len(prior_probs):
-                                raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length"))
-
-                            # cosine similarity
-                            sims = xp.dot(entity_encodings, context_enc_t) / (norm_1 * norm_2)
-                            if sims.shape != prior_probs.shape:
-                                raise ValueError(Errors.E161)
-                            scores = prior_probs + sims - (prior_probs*sims)
-
-                        # TODO: thresholding
-                        best_index = scores.argmax()
-                        best_candidate = candidates[best_index]
-                        final_kb_ids.append(best_candidate.entity_)
-                        final_tensors.append(context_encoding)
+                # Looping through each sentence and each entity
+                # This may go wrong if there are entities across sentences - because they might not get a KB ID
+                for sent in doc.sents:
+                    sent_doc = sent.as_doc()
+                    # currently, the context is the same for each entity in a sentence (should be refined)
+                    sentence_encoding = self.model([sent_doc])[0]
+                    xp = get_array_module(sentence_encoding)
+                    sentence_encoding_t = sentence_encoding.T
+                    sentence_norm = xp.linalg.norm(sentence_encoding_t)
+
+                    for ent in sent_doc.ents:
+                        entity_count += 1
+
+                        if ent.label_ in self.cfg.get("labels_discard", []):
+                            # ignoring this entity - setting to NIL
+                            final_kb_ids.append(self.NIL)
+                            final_tensors.append(sentence_encoding)
+
+                        else:
+                            candidates = self.kb.get_candidates(ent.text)
+                            if not candidates:
+                                # no prediction possible for this entity - setting to NIL
+                                final_kb_ids.append(self.NIL)
+                                final_tensors.append(sentence_encoding)
+
+                            elif len(candidates) == 1:
+                                # shortcut for efficiency reasons: take the 1 candidate
+
+                                # TODO: thresholding
+                                final_kb_ids.append(candidates[0].entity_)
+                                final_tensors.append(sentence_encoding)
+
+                            else:
+                                random.shuffle(candidates)
+
+                                # this will set all prior probabilities to 0 if they should be excluded from the model
+                                prior_probs = xp.asarray([c.prior_prob for c in candidates])
+                                if not self.cfg.get("incl_prior", True):
+                                    prior_probs = xp.asarray([0.0 for c in candidates])
+                                scores = prior_probs
+
+                                # add in similarity from the context
+                                if self.cfg.get("incl_context", True):
+                                    entity_encodings = xp.asarray([c.entity_vector for c in candidates])
+                                    entity_norm = xp.linalg.norm(entity_encodings, axis=1)
+
+                                    if len(entity_encodings) != len(prior_probs):
+                                        raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length"))
+
+                                    # cosine similarity
+                                    sims = xp.dot(entity_encodings, sentence_encoding_t) / (sentence_norm * entity_norm)
+                                    if sims.shape != prior_probs.shape:
+                                        raise ValueError(Errors.E161)
+                                    scores = prior_probs + sims - (prior_probs*sims)
+
+                                # TODO: thresholding
+                                best_index = scores.argmax()
+                                best_candidate = candidates[best_index]
+                                final_kb_ids.append(best_candidate.entity_)
+                                final_tensors.append(sentence_encoding)

         if not (len(final_tensors) == len(final_kb_ids) == entity_count):
             raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length"))
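The final candidate score above combines the KB prior with the context similarity as `prior + sim - prior * sim`, a probabilistic-OR style blend that stays in [0, 1] as long as both inputs do (note that a cosine similarity can be negative). A small numeric sketch:

```python
import numpy as np

prior_probs = np.asarray([0.7, 0.2, 0.05])  # alias -> entity priors from the KB
sims = np.asarray([0.1, 0.8, 0.3])          # cosine similarity of each entity vector to the sentence encoding

# same combination as in predict(): a high prior or a high context similarity both push the score up
scores = prior_probs + sims - (prior_probs * sims)
print(scores)           # [0.73  0.84  0.335]
print(scores.argmax())  # 1 -> the second candidate wins on context despite a lower prior
```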
@@ -131,6 +131,53 @@ def test_candidate_generation(nlp):
     assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9)


+def test_append_alias(nlp):
+    """Test that we can append additional alias-entity pairs"""
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+
+    # adding entities
+    mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
+    mykb.add_entity(entity="Q2", freq=12, entity_vector=[2])
+    mykb.add_entity(entity="Q3", freq=5, entity_vector=[3])
+
+    # adding aliases
+    mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.4, 0.1])
+    mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
+
+    # test the size of the relevant candidates
+    assert len(mykb.get_candidates("douglas")) == 2
+
+    # append an alias
+    mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.2)
+
+    # test that the size of the relevant candidates has been incremented
+    assert len(mykb.get_candidates("douglas")) == 3
+
+    # appending the same alias-entity pair again should not work (will throw a warning)
+    mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.3)
+
+    # test that the size of the relevant candidates remained unchanged
+    assert len(mykb.get_candidates("douglas")) == 3
+
+
+def test_append_invalid_alias(nlp):
+    """Test that appending an alias will throw an error if the prior probs would exceed 1"""
+    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
+
+    # adding entities
+    mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
+    mykb.add_entity(entity="Q2", freq=12, entity_vector=[2])
+    mykb.add_entity(entity="Q3", freq=5, entity_vector=[3])
+
+    # adding aliases
+    mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1])
+    mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9])
+
+    # append an alias - should fail because the prior probabilities for "douglas" would sum to more than 1
+    with pytest.raises(ValueError):
+        mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.2)
+
+
 def test_preserving_links_asdoc(nlp):
     """Test that Span.as_doc preserves the existing entity links"""
     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
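Related to the duplicate-pair case exercised above: assuming the W024 warning is emitted as a regular `UserWarning` (as `user_warning` does here) and warnings are not filtered out, a hypothetical extra test could assert it explicitly:

```python
import pytest
from spacy.kb import KnowledgeBase


def test_append_duplicate_warns(nlp):  # assumes the same `nlp` fixture as the tests above
    mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1)
    mykb.add_entity(entity="Q1", freq=27, entity_vector=[1])
    mykb.add_alias(alias="douglas", entities=["Q1"], probabilities=[0.4])

    # appending a pair that is already in the KB should only warn, not grow the candidate list
    with pytest.warns(UserWarning):
        mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.3)
    assert len(mykb.get_candidates("douglas")) == 1
```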
@@ -430,7 +430,7 @@ def test_issue957(en_tokenizer):
 def test_issue999(train_data):
     """Test that adding entities and resuming training works passably OK.
     There are two issues here:
-    1) We have to read labels. This isn't very nice.
+    1) We have to re-add labels. This isn't very nice.
     2) There's no way to set the learning rate for the weight update, so we
        end up out-of-scale, causing it to learn too fast.
     """
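A minimal sketch of what "re-add labels" means in practice when resuming training; the model path and labels are hypothetical:

```python
import spacy

# resuming training on a previously saved pipeline: the entity labels seen in the
# training data have to be added back to the NER component before updating it again
nlp = spacy.load("my_saved_model")  # hypothetical path to a model trained earlier
ner = nlp.get_pipe("ner")
for label in ["PERSON", "LOC"]:
    ner.add_label(label)
```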