mirror of https://github.com/explosion/spaCy.git (synced 2025-11-01 00:17:44 +03:00)

Merge branch 'master' into spacy.io
This commit is contained in: commit 0b2df3b879

106  .github/contributors/PeterGilles.md  (vendored, new file)
							|  | @ -0,0 +1,106 @@ | |||
| # spaCy contributor agreement | ||||
| 
 | ||||
| This spaCy Contributor Agreement (**"SCA"**) is based on the | ||||
| [Oracle Contributor Agreement](http://www.oracle.com/technetwork/oca-405177.pdf). | ||||
| The SCA applies to any contribution that you make to any product or project | ||||
| managed by us (the **"project"**), and sets out the intellectual property rights | ||||
| you grant to us in the contributed materials. The term **"us"** shall mean | ||||
| [ExplosionAI GmbH](https://explosion.ai/legal). The term | ||||
| **"you"** shall mean the person or entity identified below. | ||||
| 
 | ||||
| If you agree to be bound by these terms, fill in the information requested | ||||
| below and include the filled-in version with your first pull request, under the | ||||
| folder [`.github/contributors/`](/.github/contributors/). The name of the file | ||||
| should be your GitHub username, with the extension `.md`. For example, the user | ||||
| example_user would create the file `.github/contributors/example_user.md`. | ||||
| 
 | ||||
| Read this agreement carefully before signing. These terms and conditions | ||||
| constitute a binding legal agreement. | ||||
| 
 | ||||
| ## Contributor Agreement | ||||
| 
 | ||||
| 1. The term "contribution" or "contributed materials" means any source code, | ||||
| object code, patch, tool, sample, graphic, specification, manual, | ||||
| documentation, or any other material posted or submitted by you to the project. | ||||
| 
 | ||||
| 2. With respect to any worldwide copyrights, or copyright applications and | ||||
| registrations, in your contribution: | ||||
| 
 | ||||
|     * you hereby assign to us joint ownership, and to the extent that such | ||||
|     assignment is or becomes invalid, ineffective or unenforceable, you hereby | ||||
|     grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge, | ||||
|     royalty-free, unrestricted license to exercise all rights under those | ||||
|     copyrights. This includes, at our option, the right to sublicense these same | ||||
|     rights to third parties through multiple levels of sublicensees or other | ||||
|     licensing arrangements; | ||||
| 
 | ||||
|     * you agree that each of us can do all things in relation to your | ||||
|     contribution as if each of us were the sole owners, and if one of us makes | ||||
|     a derivative work of your contribution, the one who makes the derivative | ||||
|     work (or has it made) will be the sole owner of that derivative work; | ||||
| 
 | ||||
|     * you agree that you will not assert any moral rights in your contribution | ||||
|     against us, our licensees or transferees; | ||||
| 
 | ||||
|     * you agree that we may register a copyright in your contribution and | ||||
|     exercise all ownership rights associated with it; and | ||||
| 
 | ||||
|     * you agree that neither of us has any duty to consult with, obtain the | ||||
|     consent of, pay or render an accounting to the other for any use or | ||||
|     distribution of your contribution. | ||||
| 
 | ||||
| 3. With respect to any patents you own, or that you can license without payment | ||||
| to any third party, you hereby grant to us a perpetual, irrevocable, | ||||
| non-exclusive, worldwide, no-charge, royalty-free license to: | ||||
| 
 | ||||
|     * make, have made, use, sell, offer to sell, import, and otherwise transfer | ||||
|     your contribution in whole or in part, alone or in combination with or | ||||
|     included in any product, work or materials arising out of the project to | ||||
|     which your contribution was submitted, and | ||||
| 
 | ||||
|     * at our option, to sublicense these same rights to third parties through | ||||
|     multiple levels of sublicensees or other licensing arrangements. | ||||
| 
 | ||||
| 4. Except as set out above, you keep all right, title, and interest in your | ||||
| contribution. The rights that you grant to us under these terms are effective | ||||
| on the date you first submitted a contribution to us, even if your submission | ||||
| took place before the date you sign these terms. | ||||
| 
 | ||||
| 5. You covenant, represent, warrant and agree that: | ||||
| 
 | ||||
|     * Each contribution that you submit is and shall be an original work of | ||||
|     authorship and you can legally grant the rights set out in this SCA; | ||||
| 
 | ||||
|     * to the best of your knowledge, each contribution will not violate any | ||||
|     third party's copyrights, trademarks, patents, or other intellectual | ||||
|     property rights; and | ||||
| 
 | ||||
|     * each contribution shall be in compliance with U.S. export control laws and | ||||
|     other applicable export and import laws. You agree to notify us if you | ||||
|     become aware of any circumstance which would make any of the foregoing | ||||
|     representations inaccurate in any respect. We may publicly disclose your | ||||
|     participation in the project, including the fact that you have signed the SCA. | ||||
| 
 | ||||
| 6. This SCA is governed by the laws of the State of California and applicable | ||||
| U.S. Federal law. Any choice of law rules will not apply. | ||||
| 
 | ||||
| 7. Please place an “x” on one of the applicable statements below. Please do NOT | ||||
| mark both statements: | ||||
| 
 | ||||
|     * [X] I am signing on behalf of myself as an individual and no other person | ||||
|     or entity, including my employer, has or will have rights with respect to my | ||||
|     contributions. | ||||
| 
 | ||||
|     * [ ] I am signing on behalf of my employer or a legal entity and I have the | ||||
|     actual authority to contractually bind that entity. | ||||
| 
 | ||||
| ## Contributor Details | ||||
| 
 | ||||
| | Field                          | Entry                | | ||||
| |------------------------------- | -------------------- | | ||||
| | Name                           |  Peter Gilles        | | ||||
| | Company name (if applicable)   |                      | | ||||
| | Title or role (if applicable)  |                      | | ||||
| | Date                           |  10.10.              | | ||||
| | GitHub username                |  Peter Gilles        | | ||||
| | Website (optional)             |                      | | ||||
|  | @ -197,7 +197,7 @@ path to the model data directory. | |||
| ```python | ||||
| import spacy | ||||
| nlp = spacy.load("en_core_web_sm") | ||||
| doc = nlp(u"This is a sentence.") | ||||
| doc = nlp("This is a sentence.") | ||||
| ``` | ||||
| 
 | ||||
| You can also `import` a model directly via its full name and then call its | ||||
|  | @ -208,7 +208,7 @@ import spacy | |||
| import en_core_web_sm | ||||
| 
 | ||||
| nlp = en_core_web_sm.load() | ||||
| doc = nlp(u"This is a sentence.") | ||||
| doc = nlp("This is a sentence.") | ||||
| ``` | ||||
| 
 | ||||
| 📖 **For more info and examples, check out the | ||||
|  |  | |||
							
								
								
									
34  bin/wiki_entity_linking/README.md  (new file)
							|  | @ -0,0 +1,34 @@ | |||
| ## Entity Linking with Wikipedia and Wikidata | ||||
| 
 | ||||
| ### Step 1: Create a Knowledge Base (KB) and training data | ||||
| 
 | ||||
| Run  `wikipedia_pretrain_kb.py`  | ||||
| * This takes as input the locations of a **Wikipedia and a Wikidata dump**, and produces a **KB directory** + **training file** | ||||
|   * WikiData: get `latest-all.json.bz2` from https://dumps.wikimedia.org/wikidatawiki/entities/ | ||||
|   * Wikipedia: get `enwiki-latest-pages-articles-multistream.xml.bz2` from https://dumps.wikimedia.org/enwiki/latest/ (or for any other language) | ||||
| * You can set the filtering parameters for KB construction: | ||||
|   * `max_per_alias`: maximum number of candidate entities kept in the KB per alias/synonym | ||||
|   * `min_freq`: minimum number of times an entity must occur in the corpus to be included in the KB | ||||
|   * `min_pair`: minimum number of times an entity+alias combination must occur in the corpus to be included in the KB | ||||
| * Further parameters to set: | ||||
|   * `descriptions_from_wikipedia`: whether to parse descriptions from Wikipedia (`True`) or Wikidata (`False`) | ||||
|   * `entity_vector_length`: length of the pre-trained entity description vectors | ||||
|   * `lang`: language for which to fetch Wikidata information (as the dump contains all languages) | ||||
| 
 | ||||
| Quick testing and rerunning:  | ||||
| * When trying out the pipeline for a quick test, set `limit_prior`, `limit_train` and/or `limit_wd` to read only parts of the dumps instead of everything.  | ||||
| * If you only want to (re)run certain parts of the pipeline, just remove the corresponding files and they will be recalculated or reparsed. | ||||
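For orientation, here is a minimal sketch of inspecting the Step 1 output before moving on. The `nlp` subdirectory name matches `OUTPUT_MODEL_DIR` in this PR, but the overall output directory, the KB file name and the example alias are assumptions; the loading calls mirror `read_kb()` in `kb_creator.py`.

```python
from pathlib import Path

import spacy
from spacy.kb import KnowledgeBase

# Hypothetical locations; adjust to wherever Step 1 wrote its output.
output_dir = Path("output")
nlp_dir = output_dir / "nlp"  # OUTPUT_MODEL_DIR in this PR
kb_file = output_dir / "kb"   # placeholder name for the serialized KB

# Load the model the KB was built with, then the KB itself
# (this mirrors read_kb() in kb_creator.py).
nlp = spacy.load(nlp_dir)
kb = KnowledgeBase(vocab=nlp.vocab)
kb.load_bulk(kb_file)

print("entities:", kb.get_size_entities())
print("aliases:", kb.get_size_aliases())
# Example alias, assuming it survived the frequency filtering.
for candidate in kb.get_candidates("Douglas Adams"):
    print(candidate.entity_, candidate.prior_prob)
```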
| 
 | ||||
| 
 | ||||
| ### Step 2: Train an Entity Linking model | ||||
| 
 | ||||
| Run  `wikidata_train_entity_linker.py`  | ||||
| * This takes the **KB directory** produced by Step 1, and trains an **Entity Linking model** | ||||
| * You can set the learning parameters for the EL training: | ||||
|   * `epochs`: number of training iterations | ||||
|   * `dropout`: dropout rate | ||||
|   * `lr`: learning rate | ||||
|   * `l2`: L2 regularization | ||||
| * Specify the number of training and dev testing entities with `train_inst` and `dev_inst` respectively | ||||
| * Further parameters to set: | ||||
|   * `labels_discard`: NER label types to discard during training | ||||
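Once Step 2 has finished, querying the trained pipeline might look like the sketch below. The model path and the example sentence are assumptions; `ent.kb_id_` is the attribute this PR's evaluation code compares against the gold standard, and an empty or `"NIL"` value means no link was predicted.

```python
import spacy

# Hypothetical output location of wikidata_train_entity_linker.py.
nlp = spacy.load("output/trained_nel_model")

doc = nlp("Douglas Adams wrote The Hitchhiker's Guide to the Galaxy.")
for ent in doc.ents:
    # kb_id_ holds the predicted Wikidata QID for each recognized entity.
    print(ent.text, ent.label_, ent.kb_id_)
```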
|  | @ -6,6 +6,7 @@ OUTPUT_MODEL_DIR = "nlp" | |||
| PRIOR_PROB_PATH = "prior_prob.csv" | ||||
| ENTITY_DEFS_PATH = "entity_defs.csv" | ||||
| ENTITY_FREQ_PATH = "entity_freq.csv" | ||||
| ENTITY_ALIAS_PATH = "entity_alias.csv" | ||||
| ENTITY_DESCR_PATH = "entity_descriptions.csv" | ||||
| 
 | ||||
| LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s' | ||||
|  |  | |||
|  | @ -15,10 +15,11 @@ class Metrics(object): | |||
|         candidate_is_correct = true_entity == candidate | ||||
| 
 | ||||
|         # Assume that we have no labeled negatives in the data (i.e. cases where true_entity is "NIL") | ||||
|         # Therefore, if candidate_is_correct then we have a true positive and never a true negative | ||||
|         # Therefore, if candidate_is_correct then we have a true positive and never a true negative. | ||||
|         self.true_pos += candidate_is_correct | ||||
|         self.false_neg += not candidate_is_correct | ||||
|         if candidate not in {"", "NIL"}: | ||||
|         if candidate and candidate not in {"", "NIL"}: | ||||
|             # A wrong prediction (e.g. Q42 != Q3) counts both as a FP as well as a FN. | ||||
|             self.false_pos += not candidate_is_correct | ||||
| 
 | ||||
|     def calculate_precision(self): | ||||
|  | @ -33,6 +34,14 @@ class Metrics(object): | |||
|         else: | ||||
|             return self.true_pos / (self.true_pos + self.false_neg) | ||||
| 
 | ||||
|     def calculate_fscore(self): | ||||
|         p = self.calculate_precision() | ||||
|         r = self.calculate_recall() | ||||
|         if p + r == 0: | ||||
|             return 0.0 | ||||
|         else: | ||||
|             return 2 * p * r / (p + r) | ||||
| 
 | ||||
| 
 | ||||
| class EvaluationResults(object): | ||||
|     def __init__(self): | ||||
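To make the counting convention above concrete, here is a small standalone sketch (not the module's own `Metrics` class) that mirrors the same rules: a correct candidate is a true positive, an incorrect one is a false negative, and any non-empty, non-`NIL` wrong prediction additionally counts as a false positive; the F-score then follows from precision and recall in the usual way.

```python
def score(pairs):
    """pairs: list of (true_entity, candidate) strings, e.g. ("Q42", "Q3")."""
    tp = fp = fn = 0
    for true_entity, candidate in pairs:
        correct = true_entity == candidate
        tp += correct
        fn += not correct
        if candidate and candidate not in {"", "NIL"}:
            fp += not correct
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    fscore = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return precision, recall, fscore


# Two correct links, one wrong link (counts as FP and FN), one missed link (FN only):
print(score([("Q42", "Q42"), ("Q1", "Q1"), ("Q42", "Q3"), ("Q5", "NIL")]))
# -> precision 2/3, recall 2/4, F-score ~0.571
```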
|  | @ -43,18 +52,20 @@ class EvaluationResults(object): | |||
|         self.metrics.update_results(true_entity, candidate) | ||||
|         self.metrics_by_label[ent_label].update_results(true_entity, candidate) | ||||
| 
 | ||||
|     def increment_false_negatives(self): | ||||
|         self.metrics.false_neg += 1 | ||||
| 
 | ||||
|     def report_metrics(self, model_name): | ||||
|         model_str = model_name.title() | ||||
|         recall = self.metrics.calculate_recall() | ||||
|         precision = self.metrics.calculate_precision() | ||||
|         return ("{}: ".format(model_str) + | ||||
|                 "Recall = {} | ".format(round(recall, 3)) + | ||||
|                 "Precision = {} | ".format(round(precision, 3)) + | ||||
|                 "Precision by label = {}".format({k: v.calculate_precision() | ||||
|                                                   for k, v in self.metrics_by_label.items()})) | ||||
|         fscore = self.metrics.calculate_fscore() | ||||
|         return ( | ||||
|             "{}: ".format(model_str) | ||||
|             + "F-score = {} | ".format(round(fscore, 3)) | ||||
|             + "Recall = {} | ".format(round(recall, 3)) | ||||
|             + "Precision = {} | ".format(round(precision, 3)) | ||||
|             + "F-score by label = {}".format( | ||||
|                 {k: v.calculate_fscore() for k, v in sorted(self.metrics_by_label.items())} | ||||
|             ) | ||||
|         ) | ||||
| 
 | ||||
| 
 | ||||
| class BaselineResults(object): | ||||
|  | @ -63,40 +74,51 @@ class BaselineResults(object): | |||
|         self.prior = EvaluationResults() | ||||
|         self.oracle = EvaluationResults() | ||||
| 
 | ||||
|     def report_accuracy(self, model): | ||||
|     def report_performance(self, model): | ||||
|         results = getattr(self, model) | ||||
|         return results.report_metrics(model) | ||||
| 
 | ||||
|     def update_baselines(self, true_entity, ent_label, random_candidate, prior_candidate, oracle_candidate): | ||||
|     def update_baselines( | ||||
|         self, | ||||
|         true_entity, | ||||
|         ent_label, | ||||
|         random_candidate, | ||||
|         prior_candidate, | ||||
|         oracle_candidate, | ||||
|     ): | ||||
|         self.oracle.update_metrics(ent_label, true_entity, oracle_candidate) | ||||
|         self.prior.update_metrics(ent_label, true_entity, prior_candidate) | ||||
|         self.random.update_metrics(ent_label, true_entity, random_candidate) | ||||
| 
 | ||||
| 
 | ||||
| def measure_performance(dev_data, kb, el_pipe): | ||||
|     baseline_accuracies = measure_baselines( | ||||
|         dev_data, kb | ||||
|     ) | ||||
| def measure_performance(dev_data, kb, el_pipe, baseline=True, context=True): | ||||
|     if baseline: | ||||
|         baseline_accuracies, counts = measure_baselines(dev_data, kb) | ||||
|         logger.info("Counts: {}".format({k: v for k, v in sorted(counts.items())})) | ||||
|         logger.info(baseline_accuracies.report_performance("random")) | ||||
|         logger.info(baseline_accuracies.report_performance("prior")) | ||||
|         logger.info(baseline_accuracies.report_performance("oracle")) | ||||
| 
 | ||||
|     logger.info(baseline_accuracies.report_accuracy("random")) | ||||
|     logger.info(baseline_accuracies.report_accuracy("prior")) | ||||
|     logger.info(baseline_accuracies.report_accuracy("oracle")) | ||||
|     if context: | ||||
|         # using only context | ||||
|         el_pipe.cfg["incl_context"] = True | ||||
|         el_pipe.cfg["incl_prior"] = False | ||||
|         results = get_eval_results(dev_data, el_pipe) | ||||
|         logger.info(results.report_metrics("context only")) | ||||
| 
 | ||||
|     # using only context | ||||
|     el_pipe.cfg["incl_context"] = True | ||||
|     el_pipe.cfg["incl_prior"] = False | ||||
|     results = get_eval_results(dev_data, el_pipe) | ||||
|     logger.info(results.report_metrics("context only")) | ||||
| 
 | ||||
|     # measuring combined accuracy (prior + context) | ||||
|     el_pipe.cfg["incl_context"] = True | ||||
|     el_pipe.cfg["incl_prior"] = True | ||||
|     results = get_eval_results(dev_data, el_pipe) | ||||
|     logger.info(results.report_metrics("context and prior")) | ||||
|         # measuring combined accuracy (prior + context) | ||||
|         el_pipe.cfg["incl_context"] = True | ||||
|         el_pipe.cfg["incl_prior"] = True | ||||
|         results = get_eval_results(dev_data, el_pipe) | ||||
|         logger.info(results.report_metrics("context and prior")) | ||||
| 
 | ||||
| 
 | ||||
| def get_eval_results(data, el_pipe=None): | ||||
|     # If the docs in the data require further processing with an entity linker, set el_pipe | ||||
|     """ | ||||
|     Evaluate the ent.kb_id_ annotations against the gold standard. | ||||
|     Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL. | ||||
|     If the docs in the data require further processing with an entity linker, set el_pipe. | ||||
|     """ | ||||
|     from tqdm import tqdm | ||||
| 
 | ||||
|     docs = [] | ||||
|  | @ -111,18 +133,15 @@ def get_eval_results(data, el_pipe=None): | |||
| 
 | ||||
|     results = EvaluationResults() | ||||
|     for doc, gold in zip(docs, golds): | ||||
|         tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents} | ||||
|         try: | ||||
|             correct_entries_per_article = dict() | ||||
|             for entity, kb_dict in gold.links.items(): | ||||
|                 start, end = entity | ||||
|                 # only evaluating on positive examples | ||||
|                 for gold_kb, value in kb_dict.items(): | ||||
|                     if value: | ||||
|                         # only evaluating on positive examples | ||||
|                         offset = _offset(start, end) | ||||
|                         correct_entries_per_article[offset] = gold_kb | ||||
|                         if offset not in tagged_entries_per_article: | ||||
|                             results.increment_false_negatives() | ||||
| 
 | ||||
|             for ent in doc.ents: | ||||
|                 ent_label = ent.label_ | ||||
|  | @ -142,7 +161,11 @@ def get_eval_results(data, el_pipe=None): | |||
| 
 | ||||
| 
 | ||||
| def measure_baselines(data, kb): | ||||
|     # Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound | ||||
|     """ | ||||
|     Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound. | ||||
|     Only evaluate entities that overlap between gold and NER, to isolate the performance of the NEL. | ||||
|     Also return a dictionary of counts by entity label. | ||||
|     """ | ||||
|     counts_d = dict() | ||||
| 
 | ||||
|     baseline_results = BaselineResults() | ||||
|  | @ -152,7 +175,6 @@ def measure_baselines(data, kb): | |||
| 
 | ||||
|     for doc, gold in zip(docs, golds): | ||||
|         correct_entries_per_article = dict() | ||||
|         tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents} | ||||
|         for entity, kb_dict in gold.links.items(): | ||||
|             start, end = entity | ||||
|             for gold_kb, value in kb_dict.items(): | ||||
|  | @ -160,10 +182,6 @@ def measure_baselines(data, kb): | |||
|                 if value: | ||||
|                     offset = _offset(start, end) | ||||
|                     correct_entries_per_article[offset] = gold_kb | ||||
|                     if offset not in tagged_entries_per_article: | ||||
|                         baseline_results.random.increment_false_negatives() | ||||
|                         baseline_results.oracle.increment_false_negatives() | ||||
|                         baseline_results.prior.increment_false_negatives() | ||||
| 
 | ||||
|         for ent in doc.ents: | ||||
|             ent_label = ent.label_ | ||||
|  | @ -176,7 +194,7 @@ def measure_baselines(data, kb): | |||
|             if gold_entity is not None: | ||||
|                 candidates = kb.get_candidates(ent.text) | ||||
|                 oracle_candidate = "" | ||||
|                 best_candidate = "" | ||||
|                 prior_candidate = "" | ||||
|                 random_candidate = "" | ||||
|                 if candidates: | ||||
|                     scores = [] | ||||
|  | @ -187,13 +205,21 @@ def measure_baselines(data, kb): | |||
|                             oracle_candidate = c.entity_ | ||||
| 
 | ||||
|                     best_index = scores.index(max(scores)) | ||||
|                     best_candidate = candidates[best_index].entity_ | ||||
|                     prior_candidate = candidates[best_index].entity_ | ||||
|                     random_candidate = random.choice(candidates).entity_ | ||||
| 
 | ||||
|                 baseline_results.update_baselines(gold_entity, ent_label, | ||||
|                                                   random_candidate, best_candidate, oracle_candidate) | ||||
|                 current_count = counts_d.get(ent_label, 0) | ||||
|                 counts_d[ent_label] = current_count + 1 | ||||
| 
 | ||||
|     return baseline_results | ||||
|                 baseline_results.update_baselines( | ||||
|                     gold_entity, | ||||
|                     ent_label, | ||||
|                     random_candidate, | ||||
|                     prior_candidate, | ||||
|                     oracle_candidate, | ||||
|                 ) | ||||
| 
 | ||||
|     return baseline_results, counts_d | ||||
| 
 | ||||
| 
 | ||||
| def _offset(start, end): | ||||
|  |  | |||
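As a usage note, a driver for these helpers, written as if it lived in this same module, might look like the following sketch; `dev_data`, `kb` and `el_pipe` are assumed to be the list of `(Doc, GoldParse)` pairs used throughout this file, a loaded `KnowledgeBase`, and the trained model's entity linker pipe.

```python
def run_full_evaluation(dev_data, kb, el_pipe):
    # Random / prior-probability / oracle baselines, plus per-label mention counts.
    baseline_results, counts = measure_baselines(dev_data, kb)
    logger.info("Counts per NER label: {}".format(counts))
    logger.info(baseline_results.report_performance("prior"))
    logger.info(baseline_results.report_performance("oracle"))

    # The trained entity linker, with and without the prior probabilities.
    measure_performance(dev_data, kb, el_pipe, baseline=False, context=True)
```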
|  | @ -1,17 +1,12 @@ | |||
| # coding: utf-8 | ||||
| from __future__ import unicode_literals | ||||
| 
 | ||||
| import csv | ||||
| import logging | ||||
| import spacy | ||||
| import sys | ||||
| 
 | ||||
| from spacy.kb import KnowledgeBase | ||||
| 
 | ||||
| from bin.wiki_entity_linking import wikipedia_processor as wp | ||||
| from bin.wiki_entity_linking.train_descriptions import EntityEncoder | ||||
| 
 | ||||
| csv.field_size_limit(sys.maxsize) | ||||
| from bin.wiki_entity_linking import wiki_io as io | ||||
| 
 | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
|  | @ -22,18 +17,24 @@ def create_kb( | |||
|     max_entities_per_alias, | ||||
|     min_entity_freq, | ||||
|     min_occ, | ||||
|     entity_def_input, | ||||
|     entity_def_path, | ||||
|     entity_descr_path, | ||||
|     count_input, | ||||
|     prior_prob_input, | ||||
|     entity_alias_path, | ||||
|     entity_freq_path, | ||||
|     prior_prob_path, | ||||
|     entity_vector_length, | ||||
| ): | ||||
|     # Create the knowledge base from Wikidata entries | ||||
|     kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=entity_vector_length) | ||||
|     entity_list, filtered_title_to_id = _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_freq, entity_freq_path, entity_vector_length) | ||||
|     _define_aliases(kb, entity_alias_path, entity_list, filtered_title_to_id, max_entities_per_alias, min_occ, prior_prob_path) | ||||
|     return kb | ||||
| 
 | ||||
| 
 | ||||
| def _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_freq, entity_freq_path, entity_vector_length): | ||||
|     # read the mappings from file | ||||
|     title_to_id = get_entity_to_id(entity_def_input) | ||||
|     id_to_descr = get_id_to_description(entity_descr_path) | ||||
|     title_to_id = io.read_title_to_id(entity_def_path) | ||||
|     id_to_descr = io.read_id_to_descr(entity_descr_path) | ||||
| 
 | ||||
|     # check the length of the nlp vectors | ||||
|     if "vectors" in nlp.meta and nlp.vocab.vectors.size: | ||||
|  | @ -45,10 +46,8 @@ def create_kb( | |||
|             " cf. https://spacy.io/usage/models#languages." | ||||
|         ) | ||||
| 
 | ||||
|     logger.info("Get entity frequencies") | ||||
|     entity_frequencies = wp.get_all_frequencies(count_input=count_input) | ||||
| 
 | ||||
|     logger.info("Filtering entities with fewer than {} mentions".format(min_entity_freq)) | ||||
|     entity_frequencies = io.read_entity_to_count(entity_freq_path) | ||||
|     # filter the entities for in the KB by frequency, because there's just too much data (8M entities) otherwise | ||||
|     filtered_title_to_id, entity_list, description_list, frequency_list = get_filtered_entities( | ||||
|         title_to_id, | ||||
|  | @ -56,36 +55,33 @@ def create_kb( | |||
|         entity_frequencies, | ||||
|         min_entity_freq | ||||
|     ) | ||||
|     logger.info("Left with {} entities".format(len(description_list))) | ||||
|     logger.info("Kept {} entities from the set of {}".format(len(description_list), len(title_to_id.keys()))) | ||||
| 
 | ||||
|     logger.info("Train entity encoder") | ||||
|     logger.info("Training entity encoder") | ||||
|     encoder = EntityEncoder(nlp, input_dim, entity_vector_length) | ||||
|     encoder.train(description_list=description_list, to_print=True) | ||||
| 
 | ||||
|     logger.info("Get entity embeddings:") | ||||
|     logger.info("Getting entity embeddings") | ||||
|     embeddings = encoder.apply_encoder(description_list) | ||||
| 
 | ||||
|     logger.info("Adding {} entities".format(len(entity_list))) | ||||
|     kb.set_entities( | ||||
|         entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings | ||||
|     ) | ||||
|     return entity_list, filtered_title_to_id | ||||
| 
 | ||||
|     logger.info("Adding aliases") | ||||
| 
 | ||||
| def _define_aliases(kb, entity_alias_path, entity_list, filtered_title_to_id, max_entities_per_alias, min_occ, prior_prob_path): | ||||
|     logger.info("Adding aliases from Wikipedia and Wikidata") | ||||
|     _add_aliases( | ||||
|         kb, | ||||
|         entity_list=entity_list, | ||||
|         title_to_id=filtered_title_to_id, | ||||
|         max_entities_per_alias=max_entities_per_alias, | ||||
|         min_occ=min_occ, | ||||
|         prior_prob_input=prior_prob_input, | ||||
|         prior_prob_path=prior_prob_path, | ||||
|     ) | ||||
| 
 | ||||
|     logger.info("KB size: {} entities, {} aliases".format( | ||||
|         kb.get_size_entities(), | ||||
|         kb.get_size_aliases())) | ||||
| 
 | ||||
|     logger.info("Done with kb") | ||||
|     return kb | ||||
| 
 | ||||
| 
 | ||||
| def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies, | ||||
|                           min_entity_freq: int = 10): | ||||
|  | @ -104,34 +100,13 @@ def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies, | |||
|     return filtered_title_to_id, entity_list, description_list, frequency_list | ||||
| 
 | ||||
| 
 | ||||
| def get_entity_to_id(entity_def_output): | ||||
|     entity_to_id = dict() | ||||
|     with entity_def_output.open("r", encoding="utf8") as csvfile: | ||||
|         csvreader = csv.reader(csvfile, delimiter="|") | ||||
|         # skip header | ||||
|         next(csvreader) | ||||
|         for row in csvreader: | ||||
|             entity_to_id[row[0]] = row[1] | ||||
|     return entity_to_id | ||||
| 
 | ||||
| 
 | ||||
| def get_id_to_description(entity_descr_path): | ||||
|     id_to_desc = dict() | ||||
|     with entity_descr_path.open("r", encoding="utf8") as csvfile: | ||||
|         csvreader = csv.reader(csvfile, delimiter="|") | ||||
|         # skip header | ||||
|         next(csvreader) | ||||
|         for row in csvreader: | ||||
|             id_to_desc[row[0]] = row[1] | ||||
|     return id_to_desc | ||||
| 
 | ||||
| 
 | ||||
| def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input): | ||||
| def _add_aliases(kb, entity_list, title_to_id, max_entities_per_alias, min_occ, prior_prob_path): | ||||
|     wp_titles = title_to_id.keys() | ||||
| 
 | ||||
|     # adding aliases with prior probabilities | ||||
|     # we can read this file sequentially, it's sorted by alias, and then by count | ||||
|     with prior_prob_input.open("r", encoding="utf8") as prior_file: | ||||
|     logger.info("Adding WP aliases") | ||||
|     with prior_prob_path.open("r", encoding="utf8") as prior_file: | ||||
|         # skip header | ||||
|         prior_file.readline() | ||||
|         line = prior_file.readline() | ||||
|  | @ -180,10 +155,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in | |||
|             line = prior_file.readline() | ||||
| 
 | ||||
| 
 | ||||
| def read_nlp_kb(model_dir, kb_file): | ||||
|     nlp = spacy.load(model_dir) | ||||
| def read_kb(nlp, kb_file): | ||||
|     kb = KnowledgeBase(vocab=nlp.vocab) | ||||
|     kb.load_bulk(kb_file) | ||||
|     logger.info("kb entities: {}".format(kb.get_size_entities())) | ||||
|     logger.info("kb aliases: {}".format(kb.get_size_aliases())) | ||||
|     return nlp, kb | ||||
|     return kb | ||||
|  |  | |||
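For orientation, wiring the intermediate files into `create_kb` might look like the sketch below. Only the keyword parameters and the imported constants are confirmed by this diff; the leading `nlp` argument name, the numeric thresholds, the model name and the output directory are assumptions.

```python
from pathlib import Path

import spacy

from bin.wiki_entity_linking import (
    ENTITY_ALIAS_PATH,
    ENTITY_DEFS_PATH,
    ENTITY_DESCR_PATH,
    ENTITY_FREQ_PATH,
    PRIOR_PROB_PATH,
)
from bin.wiki_entity_linking import kb_creator

output_dir = Path("output")         # hypothetical working directory with the parsed dumps
nlp = spacy.load("en_core_web_lg")  # any model with word vectors will do

kb = kb_creator.create_kb(
    nlp,                            # assumed first parameter (not visible in this hunk)
    max_entities_per_alias=10,
    min_entity_freq=20,
    min_occ=5,
    entity_def_path=output_dir / ENTITY_DEFS_PATH,
    entity_descr_path=output_dir / ENTITY_DESCR_PATH,
    entity_alias_path=output_dir / ENTITY_ALIAS_PATH,
    entity_freq_path=output_dir / ENTITY_FREQ_PATH,
    prior_prob_path=output_dir / PRIOR_PROB_PATH,
    entity_vector_length=64,
)
print(kb.get_size_entities(), kb.get_size_aliases())
```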
|  | @ -53,7 +53,7 @@ class EntityEncoder: | |||
| 
 | ||||
|             start = start + batch_size | ||||
|             stop = min(stop + batch_size, len(description_list)) | ||||
|             logger.info("encoded: {} entities".format(stop)) | ||||
|             logger.info("Encoded: {} entities".format(stop)) | ||||
| 
 | ||||
|         return encodings | ||||
| 
 | ||||
|  | @ -62,7 +62,7 @@ class EntityEncoder: | |||
|         if to_print: | ||||
|             logger.info( | ||||
|                 "Trained entity descriptions on {} ".format(processed) + | ||||
|                 "(non-unique) entities across {} ".format(self.epochs) + | ||||
|                 "(non-unique) descriptions across {} ".format(self.epochs) + | ||||
|                 "epochs" | ||||
|             ) | ||||
|             logger.info("Final loss: {}".format(loss)) | ||||
|  |  | |||
|  | @ -1,395 +0,0 @@ | |||
| # coding: utf-8 | ||||
| from __future__ import unicode_literals | ||||
| 
 | ||||
| import logging | ||||
| import random | ||||
| import re | ||||
| import bz2 | ||||
| import json | ||||
| 
 | ||||
| from functools import partial | ||||
| 
 | ||||
| from spacy.gold import GoldParse | ||||
| from bin.wiki_entity_linking import kb_creator | ||||
| 
 | ||||
| """ | ||||
| Process Wikipedia interlinks to generate a training dataset for the EL algorithm. | ||||
| Gold-standard entities are stored in one file in standoff format (by character offset). | ||||
| """ | ||||
| 
 | ||||
| ENTITY_FILE = "gold_entities.csv" | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| def create_training_examples_and_descriptions(wikipedia_input, | ||||
|                                               entity_def_input, | ||||
|                                               description_output, | ||||
|                                               training_output, | ||||
|                                               parse_descriptions, | ||||
|                                               limit=None): | ||||
|     wp_to_id = kb_creator.get_entity_to_id(entity_def_input) | ||||
|     _process_wikipedia_texts(wikipedia_input, | ||||
|                              wp_to_id, | ||||
|                              description_output, | ||||
|                              training_output, | ||||
|                              parse_descriptions, | ||||
|                              limit) | ||||
| 
 | ||||
| 
 | ||||
| def _process_wikipedia_texts(wikipedia_input, | ||||
|                              wp_to_id, | ||||
|                              output, | ||||
|                              training_output, | ||||
|                              parse_descriptions, | ||||
|                              limit=None): | ||||
|     """ | ||||
|     Read the XML wikipedia data to parse out training data: | ||||
|     raw text data + positive instances | ||||
|     """ | ||||
|     title_regex = re.compile(r"(?<=<title>).*(?=</title>)") | ||||
|     id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)") | ||||
| 
 | ||||
|     read_ids = set() | ||||
| 
 | ||||
|     with output.open("a", encoding="utf8") as descr_file, training_output.open("w", encoding="utf8") as entity_file: | ||||
|         if parse_descriptions: | ||||
|             _write_training_description(descr_file, "WD_id", "description") | ||||
|         with bz2.open(wikipedia_input, mode="rb") as file: | ||||
|             article_count = 0 | ||||
|             article_text = "" | ||||
|             article_title = None | ||||
|             article_id = None | ||||
|             reading_text = False | ||||
|             reading_revision = False | ||||
| 
 | ||||
|             logger.info("Processed {} articles".format(article_count)) | ||||
| 
 | ||||
|             for line in file: | ||||
|                 clean_line = line.strip().decode("utf-8") | ||||
| 
 | ||||
|                 if clean_line == "<revision>": | ||||
|                     reading_revision = True | ||||
|                 elif clean_line == "</revision>": | ||||
|                     reading_revision = False | ||||
| 
 | ||||
|                 # Start reading new page | ||||
|                 if clean_line == "<page>": | ||||
|                     article_text = "" | ||||
|                     article_title = None | ||||
|                     article_id = None | ||||
|                 # finished reading this page | ||||
|                 elif clean_line == "</page>": | ||||
|                     if article_id: | ||||
|                         clean_text, entities = _process_wp_text( | ||||
|                             article_title, | ||||
|                             article_text, | ||||
|                             wp_to_id | ||||
|                         ) | ||||
|                         if clean_text is not None and entities is not None: | ||||
|                             _write_training_entities(entity_file, | ||||
|                                                      article_id, | ||||
|                                                      clean_text, | ||||
|                                                      entities) | ||||
| 
 | ||||
|                             if article_title in wp_to_id and parse_descriptions: | ||||
|                                 description = " ".join(clean_text[:1000].split(" ")[:-1]) | ||||
|                                 _write_training_description( | ||||
|                                     descr_file, | ||||
|                                     wp_to_id[article_title], | ||||
|                                     description | ||||
|                                 ) | ||||
|                             article_count += 1 | ||||
|                             if article_count % 10000 == 0: | ||||
|                                 logger.info("Processed {} articles".format(article_count)) | ||||
|                             if limit and article_count >= limit: | ||||
|                                 break | ||||
|                     article_text = "" | ||||
|                     article_title = None | ||||
|                     article_id = None | ||||
|                     reading_text = False | ||||
|                     reading_revision = False | ||||
| 
 | ||||
|                 # start reading text within a page | ||||
|                 if "<text" in clean_line: | ||||
|                     reading_text = True | ||||
| 
 | ||||
|                 if reading_text: | ||||
|                     article_text += " " + clean_line | ||||
| 
 | ||||
|                 # stop reading text within a page (we assume a new page doesn't start on the same line) | ||||
|                 if "</text" in clean_line: | ||||
|                     reading_text = False | ||||
| 
 | ||||
|                 # read the ID of this article (outside the revision portion of the document) | ||||
|                 if not reading_revision: | ||||
|                     ids = id_regex.search(clean_line) | ||||
|                     if ids: | ||||
|                         article_id = ids[0] | ||||
|                         if article_id in read_ids: | ||||
|                             logger.info( | ||||
|                                 "Found duplicate article ID", article_id, clean_line | ||||
|                             )  # This should never happen ... | ||||
|                         read_ids.add(article_id) | ||||
| 
 | ||||
|                 # read the title of this article (outside the revision portion of the document) | ||||
|                 if not reading_revision: | ||||
|                     titles = title_regex.search(clean_line) | ||||
|                     if titles: | ||||
|                         article_title = titles[0].strip() | ||||
|     logger.info("Finished. Processed {} articles".format(article_count)) | ||||
| 
 | ||||
| 
 | ||||
| text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)") | ||||
| info_regex = re.compile(r"{[^{]*?}") | ||||
| htlm_regex = re.compile(r"<!--[^-]*-->") | ||||
| category_regex = re.compile(r"\[\[Category:[^\[]*]]") | ||||
| file_regex = re.compile(r"\[\[File:[^[\]]+]]") | ||||
| ref_regex = re.compile(r"<ref.*?>")  # non-greedy | ||||
| ref_2_regex = re.compile(r"</ref.*?>")  # non-greedy | ||||
| 
 | ||||
| 
 | ||||
| def _process_wp_text(article_title, article_text, wp_to_id): | ||||
|     # ignore meta Wikipedia pages | ||||
|     if ( | ||||
|         article_title.startswith("Wikipedia:") or | ||||
|         article_title.startswith("Kategori:") | ||||
|     ): | ||||
|         return None, None | ||||
| 
 | ||||
|     # remove the text tags | ||||
|     text_search = text_regex.search(article_text) | ||||
|     if text_search is None: | ||||
|         return None, None | ||||
|     text = text_search.group(0) | ||||
| 
 | ||||
|     # stop processing if this is a redirect page | ||||
|     if text.startswith("#REDIRECT"): | ||||
|         return None, None | ||||
| 
 | ||||
|     # get the raw text without markup etc, keeping only interwiki links | ||||
|     clean_text, entities = _remove_links(_get_clean_wp_text(text), wp_to_id) | ||||
|     return clean_text, entities | ||||
| 
 | ||||
| 
 | ||||
| def _get_clean_wp_text(article_text): | ||||
|     clean_text = article_text.strip() | ||||
| 
 | ||||
|     # remove bolding & italic markup | ||||
|     clean_text = clean_text.replace("'''", "") | ||||
|     clean_text = clean_text.replace("''", "") | ||||
| 
 | ||||
|     # remove nested {{info}} statements by removing the inner/smallest ones first and iterating | ||||
|     try_again = True | ||||
|     previous_length = len(clean_text) | ||||
|     while try_again: | ||||
|         clean_text = info_regex.sub( | ||||
|             "", clean_text | ||||
|         )  # non-greedy match excluding a nested { | ||||
|         if len(clean_text) < previous_length: | ||||
|             try_again = True | ||||
|         else: | ||||
|             try_again = False | ||||
|         previous_length = len(clean_text) | ||||
| 
 | ||||
|     # remove HTML comments | ||||
|     clean_text = htlm_regex.sub("", clean_text) | ||||
| 
 | ||||
|     # remove Category and File statements | ||||
|     clean_text = category_regex.sub("", clean_text) | ||||
|     clean_text = file_regex.sub("", clean_text) | ||||
| 
 | ||||
|     # remove multiple = | ||||
|     while "==" in clean_text: | ||||
|         clean_text = clean_text.replace("==", "=") | ||||
| 
 | ||||
|     clean_text = clean_text.replace(". =", ".") | ||||
|     clean_text = clean_text.replace(" = ", ". ") | ||||
|     clean_text = clean_text.replace("= ", ".") | ||||
|     clean_text = clean_text.replace(" =", "") | ||||
| 
 | ||||
|     # remove refs (non-greedy match) | ||||
|     clean_text = ref_regex.sub("", clean_text) | ||||
|     clean_text = ref_2_regex.sub("", clean_text) | ||||
| 
 | ||||
|     # remove additional wikiformatting | ||||
|     clean_text = re.sub(r"<blockquote>", "", clean_text) | ||||
|     clean_text = re.sub(r"</blockquote>", "", clean_text) | ||||
| 
 | ||||
|     # change special characters back to normal ones | ||||
|     clean_text = clean_text.replace(r"&lt;", "<") | ||||
|     clean_text = clean_text.replace(r"&gt;", ">") | ||||
|     clean_text = clean_text.replace(r"&quot;", '"') | ||||
|     clean_text = clean_text.replace(r"&amp;nbsp;", " ") | ||||
|     clean_text = clean_text.replace(r"&amp;", "&") | ||||
| 
 | ||||
|     # remove multiple spaces | ||||
|     while "  " in clean_text: | ||||
|         clean_text = clean_text.replace("  ", " ") | ||||
| 
 | ||||
|     return clean_text.strip() | ||||
| 
 | ||||
| 
 | ||||
| def _remove_links(clean_text, wp_to_id): | ||||
|     # read the text char by char to get the right offsets for the interwiki links | ||||
|     entities = [] | ||||
|     final_text = "" | ||||
|     open_read = 0 | ||||
|     reading_text = True | ||||
|     reading_entity = False | ||||
|     reading_mention = False | ||||
|     reading_special_case = False | ||||
|     entity_buffer = "" | ||||
|     mention_buffer = "" | ||||
|     for index, letter in enumerate(clean_text): | ||||
|         if letter == "[": | ||||
|             open_read += 1 | ||||
|         elif letter == "]": | ||||
|             open_read -= 1 | ||||
|         elif letter == "|": | ||||
|             if reading_text: | ||||
|                 final_text += letter | ||||
|             # switch from reading entity to mention in the [[entity|mention]] pattern | ||||
|             elif reading_entity: | ||||
|                 reading_text = False | ||||
|                 reading_entity = False | ||||
|                 reading_mention = True | ||||
|             else: | ||||
|                 reading_special_case = True | ||||
|         else: | ||||
|             if reading_entity: | ||||
|                 entity_buffer += letter | ||||
|             elif reading_mention: | ||||
|                 mention_buffer += letter | ||||
|             elif reading_text: | ||||
|                 final_text += letter | ||||
|             else: | ||||
|                 raise ValueError("Not sure at point", clean_text[index - 2: index + 2]) | ||||
| 
 | ||||
|         if open_read > 2: | ||||
|             reading_special_case = True | ||||
| 
 | ||||
|         if open_read == 2 and reading_text: | ||||
|             reading_text = False | ||||
|             reading_entity = True | ||||
|             reading_mention = False | ||||
| 
 | ||||
|         # we just finished reading an entity | ||||
|         if open_read == 0 and not reading_text: | ||||
|             if "#" in entity_buffer or entity_buffer.startswith(":"): | ||||
|                 reading_special_case = True | ||||
|             # Ignore cases with nested structures like File: handles etc | ||||
|             if not reading_special_case: | ||||
|                 if not mention_buffer: | ||||
|                     mention_buffer = entity_buffer | ||||
|                 start = len(final_text) | ||||
|                 end = start + len(mention_buffer) | ||||
|                 qid = wp_to_id.get(entity_buffer, None) | ||||
|                 if qid: | ||||
|                     entities.append((mention_buffer, qid, start, end)) | ||||
|                 final_text += mention_buffer | ||||
| 
 | ||||
|             entity_buffer = "" | ||||
|             mention_buffer = "" | ||||
| 
 | ||||
|             reading_text = True | ||||
|             reading_entity = False | ||||
|             reading_mention = False | ||||
|             reading_special_case = False | ||||
|     return final_text, entities | ||||
| 
 | ||||
| 
 | ||||
| def _write_training_description(outputfile, qid, description): | ||||
|     if description is not None: | ||||
|         line = str(qid) + "|" + description + "\n" | ||||
|         outputfile.write(line) | ||||
| 
 | ||||
| 
 | ||||
| def _write_training_entities(outputfile, article_id, clean_text, entities): | ||||
|     entities_data = [{"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]} for ent in entities] | ||||
|     line = json.dumps( | ||||
|         { | ||||
|             "article_id": article_id, | ||||
|             "clean_text": clean_text, | ||||
|             "entities": entities_data | ||||
|         }, | ||||
|         ensure_ascii=False) + "\n" | ||||
|     outputfile.write(line) | ||||
| 
 | ||||
| 
 | ||||
| def read_training(nlp, entity_file_path, dev, limit, kb): | ||||
|     """ This method provides training examples that correspond to the entity annotations found by the nlp object. | ||||
|      For training, it will include negative training examples by using the candidate generator, | ||||
|      and it will only keep positive training examples that can be found by using the candidate generator. | ||||
|      For testing, it will include all positive examples only.""" | ||||
| 
 | ||||
|     from tqdm import tqdm | ||||
|     data = [] | ||||
|     num_entities = 0 | ||||
|     get_gold_parse = partial(_get_gold_parse, dev=dev, kb=kb) | ||||
| 
 | ||||
|     logger.info("Reading {} data with limit {}".format('dev' if dev else 'train', limit)) | ||||
|     with entity_file_path.open("r", encoding="utf8") as file: | ||||
|         with tqdm(total=limit, leave=False) as pbar: | ||||
|             for i, line in enumerate(file): | ||||
|                 example = json.loads(line) | ||||
|                 article_id = example["article_id"] | ||||
|                 clean_text = example["clean_text"] | ||||
|                 entities = example["entities"] | ||||
| 
 | ||||
|                 if dev != is_dev(article_id) or len(clean_text) >= 30000: | ||||
|                     continue | ||||
| 
 | ||||
|                 doc = nlp(clean_text) | ||||
|                 gold = get_gold_parse(doc, entities) | ||||
|                 if gold and len(gold.links) > 0: | ||||
|                     data.append((doc, gold)) | ||||
|                     num_entities += len(gold.links) | ||||
|                     pbar.update(len(gold.links)) | ||||
|                 if limit and num_entities >= limit: | ||||
|                     break | ||||
|     logger.info("Read {} entities in {} articles".format(num_entities, len(data))) | ||||
|     return data | ||||
| 
 | ||||
| 
 | ||||
| def _get_gold_parse(doc, entities, dev, kb): | ||||
|     gold_entities = {} | ||||
|     tagged_ent_positions = set( | ||||
|         [(ent.start_char, ent.end_char) for ent in doc.ents] | ||||
|     ) | ||||
| 
 | ||||
|     for entity in entities: | ||||
|         entity_id = entity["entity"] | ||||
|         alias = entity["alias"] | ||||
|         start = entity["start"] | ||||
|         end = entity["end"] | ||||
| 
 | ||||
|         candidates = kb.get_candidates(alias) | ||||
|         candidate_ids = [ | ||||
|             c.entity_ for c in candidates | ||||
|         ] | ||||
| 
 | ||||
|         should_add_ent = ( | ||||
|             dev or | ||||
|             ( | ||||
|                 (start, end) in tagged_ent_positions and | ||||
|                 entity_id in candidate_ids and | ||||
|                 len(candidates) > 1 | ||||
|             ) | ||||
|         ) | ||||
| 
 | ||||
|         if should_add_ent: | ||||
|             value_by_id = {entity_id: 1.0} | ||||
|             if not dev: | ||||
|                 random.shuffle(candidate_ids) | ||||
|                 value_by_id.update({ | ||||
|                     kb_id: 0.0 | ||||
|                     for kb_id in candidate_ids | ||||
|                     if kb_id != entity_id | ||||
|                 }) | ||||
|             gold_entities[(start, end)] = value_by_id | ||||
| 
 | ||||
|     return GoldParse(doc, links=gold_entities) | ||||
| 
 | ||||
| 
 | ||||
| def is_dev(article_id): | ||||
|     return article_id.endswith("3") | ||||
							
								
								
									
127  bin/wiki_entity_linking/wiki_io.py  (new file)
							|  | @ -0,0 +1,127 @@ | |||
| # coding: utf-8 | ||||
| from __future__ import unicode_literals | ||||
| 
 | ||||
| import sys | ||||
| import csv | ||||
| 
 | ||||
| # min() needed to prevent error on windows, cf https://stackoverflow.com/questions/52404416/ | ||||
| csv.field_size_limit(min(sys.maxsize, 2147483646)) | ||||
| 
 | ||||
| """ This module provides reading/writing methods for temp files """ | ||||
| 
 | ||||
| 
 | ||||
| # Entity definition: WP title -> WD ID # | ||||
| def write_title_to_id(entity_def_output, title_to_id): | ||||
|     with entity_def_output.open("w", encoding="utf8") as id_file: | ||||
|         id_file.write("WP_title" + "|" + "WD_id" + "\n") | ||||
|         for title, qid in title_to_id.items(): | ||||
|             id_file.write(title + "|" + str(qid) + "\n") | ||||
| 
 | ||||
| 
 | ||||
| def read_title_to_id(entity_def_output): | ||||
|     title_to_id = dict() | ||||
|     with entity_def_output.open("r", encoding="utf8") as id_file: | ||||
|         csvreader = csv.reader(id_file, delimiter="|") | ||||
|         # skip header | ||||
|         next(csvreader) | ||||
|         for row in csvreader: | ||||
|             title_to_id[row[0]] = row[1] | ||||
|     return title_to_id | ||||
| 
 | ||||
| 
 | ||||
| # Entity aliases from WD: WD ID -> WD alias # | ||||
| def write_id_to_alias(entity_alias_path, id_to_alias): | ||||
|     with entity_alias_path.open("w", encoding="utf8") as alias_file: | ||||
|         alias_file.write("WD_id" + "|" + "alias" + "\n") | ||||
|         for qid, alias_list in id_to_alias.items(): | ||||
|             for alias in alias_list: | ||||
|                 alias_file.write(str(qid) + "|" + alias + "\n") | ||||
| 
 | ||||
| 
 | ||||
| def read_id_to_alias(entity_alias_path): | ||||
|     id_to_alias = dict() | ||||
|     with entity_alias_path.open("r", encoding="utf8") as alias_file: | ||||
|         csvreader = csv.reader(alias_file, delimiter="|") | ||||
|         # skip header | ||||
|         next(csvreader) | ||||
|         for row in csvreader: | ||||
|             qid = row[0] | ||||
|             alias = row[1] | ||||
|             alias_list = id_to_alias.get(qid, []) | ||||
|             alias_list.append(alias) | ||||
|             id_to_alias[qid] = alias_list | ||||
|     return id_to_alias | ||||
| 
 | ||||
| 
 | ||||
| def read_alias_to_id_generator(entity_alias_path): | ||||
|     """ Read (aliases, qid) tuples """ | ||||
| 
 | ||||
|     with entity_alias_path.open("r", encoding="utf8") as alias_file: | ||||
|         csvreader = csv.reader(alias_file, delimiter="|") | ||||
|         # skip header | ||||
|         next(csvreader) | ||||
|         for row in csvreader: | ||||
|             qid = row[0] | ||||
|             alias = row[1] | ||||
|             yield alias, qid | ||||
| 
 | ||||
| 
 | ||||
| # Entity descriptions from WD: WD ID -> WD description # | ||||
| def write_id_to_descr(entity_descr_output, id_to_descr): | ||||
|     with entity_descr_output.open("w", encoding="utf8") as descr_file: | ||||
|         descr_file.write("WD_id" + "|" + "description" + "\n") | ||||
|         for qid, descr in id_to_descr.items(): | ||||
|             descr_file.write(str(qid) + "|" + descr + "\n") | ||||
| 
 | ||||
| 
 | ||||
| def read_id_to_descr(entity_desc_path): | ||||
|     id_to_desc = dict() | ||||
|     with entity_desc_path.open("r", encoding="utf8") as descr_file: | ||||
|         csvreader = csv.reader(descr_file, delimiter="|") | ||||
|         # skip header | ||||
|         next(csvreader) | ||||
|         for row in csvreader: | ||||
|             id_to_desc[row[0]] = row[1] | ||||
|     return id_to_desc | ||||
| 
 | ||||
| 
 | ||||
| # Entity counts from WP: WP title -> count # | ||||
| def write_entity_to_count(prior_prob_input, count_output): | ||||
|     # Write entity counts for quick access later | ||||
|     entity_to_count = dict() | ||||
|     total_count = 0 | ||||
| 
 | ||||
|     with prior_prob_input.open("r", encoding="utf8") as prior_file: | ||||
|         # skip header | ||||
|         prior_file.readline() | ||||
|         line = prior_file.readline() | ||||
| 
 | ||||
|         while line: | ||||
|             splits = line.replace("\n", "").split(sep="|") | ||||
|             # alias = splits[0] | ||||
|             count = int(splits[1]) | ||||
|             entity = splits[2] | ||||
| 
 | ||||
|             current_count = entity_to_count.get(entity, 0) | ||||
|             entity_to_count[entity] = current_count + count | ||||
| 
 | ||||
|             total_count += count | ||||
| 
 | ||||
|             line = prior_file.readline() | ||||
| 
 | ||||
|     with count_output.open("w", encoding="utf8") as entity_file: | ||||
|         entity_file.write("entity" + "|" + "count" + "\n") | ||||
|         for entity, count in entity_to_count.items(): | ||||
|             entity_file.write(entity + "|" + str(count) + "\n") | ||||
| 
 | ||||
| 
 | ||||
| def read_entity_to_count(count_input): | ||||
|     entity_to_count = dict() | ||||
|     with count_input.open("r", encoding="utf8") as csvfile: | ||||
|         csvreader = csv.reader(csvfile, delimiter="|") | ||||
|         # skip header | ||||
|         next(csvreader) | ||||
|         for row in csvreader: | ||||
|             entity_to_count[row[0]] = int(row[1]) | ||||
| 
 | ||||
|     return entity_to_count | ||||
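A quick round-trip through two of these helpers, assuming the module is imported the same way as elsewhere in this PR (`from bin.wiki_entity_linking import wiki_io as io`); the temporary directory and the example mappings are placeholders.

```python
from pathlib import Path
from tempfile import TemporaryDirectory

from bin.wiki_entity_linking import wiki_io as io

with TemporaryDirectory() as tmp:
    entity_defs = Path(tmp) / "entity_defs.csv"

    # Write the WP title -> WD ID mapping in the pipe-delimited format above ...
    io.write_title_to_id(entity_defs, {"Douglas Adams": "Q42", "London": "Q84"})

    # ... and read it back.
    assert io.read_title_to_id(entity_defs) == {"Douglas Adams": "Q42", "London": "Q84"}
```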
							
								
								
									
128  bin/wiki_entity_linking/wiki_namespaces.py  (new file)
							|  | @ -0,0 +1,128 @@ | |||
| # coding: utf8 | ||||
| from __future__ import unicode_literals | ||||
| 
 | ||||
| # List of meta pages in Wikidata that should be kept out of the knowledge base | ||||
| WD_META_ITEMS = [ | ||||
|     "Q163875", | ||||
|     "Q191780", | ||||
|     "Q224414", | ||||
|     "Q4167836", | ||||
|     "Q4167410", | ||||
|     "Q4663903", | ||||
|     "Q11266439", | ||||
|     "Q13406463", | ||||
|     "Q15407973", | ||||
|     "Q18616576", | ||||
|     "Q19887878", | ||||
|     "Q22808320", | ||||
|     "Q23894233", | ||||
|     "Q33120876", | ||||
|     "Q42104522", | ||||
|     "Q47460393", | ||||
|     "Q64875536", | ||||
|     "Q66480449", | ||||
| ] | ||||
| 
 | ||||
| 
 | ||||
| # TODO: add more cases from non-English WP's | ||||
| 
 | ||||
| # List of prefixes that refer to Wikipedia "file" pages | ||||
| WP_FILE_NAMESPACE = ["Bestand", "File"] | ||||
| 
 | ||||
| # List of prefixes that refer to Wikipedia "category" pages | ||||
| WP_CATEGORY_NAMESPACE = ["Kategori", "Category", "Categorie"] | ||||
| 
 | ||||
| # List of prefixes that refer to Wikipedia "meta" pages | ||||
| # these will/should be matched ignoring case | ||||
| WP_META_NAMESPACE = ( | ||||
|     WP_FILE_NAMESPACE | ||||
|     + WP_CATEGORY_NAMESPACE | ||||
|     + [ | ||||
|         "b", | ||||
|         "betawikiversity", | ||||
|         "Book", | ||||
|         "c", | ||||
|         "Commons", | ||||
|         "d", | ||||
|         "dbdump", | ||||
|         "download", | ||||
|         "Draft", | ||||
|         "Education", | ||||
|         "Foundation", | ||||
|         "Gadget", | ||||
|         "Gadget definition", | ||||
|         "Gebruiker", | ||||
|         "gerrit", | ||||
|         "Help", | ||||
|         "Image", | ||||
|         "Incubator", | ||||
|         "m", | ||||
|         "mail", | ||||
|         "mailarchive", | ||||
|         "media", | ||||
|         "MediaWiki", | ||||
|         "MediaWiki talk", | ||||
|         "Mediawikiwiki", | ||||
|         "MediaZilla", | ||||
|         "Meta", | ||||
|         "Metawikipedia", | ||||
|         "Module", | ||||
|         "mw", | ||||
|         "n", | ||||
|         "nost", | ||||
|         "oldwikisource", | ||||
|         "otrs", | ||||
|         "OTRSwiki", | ||||
|         "Overleg gebruiker", | ||||
|         "outreach", | ||||
|         "outreachwiki", | ||||
|         "Portal", | ||||
|         "phab", | ||||
|         "Phabricator", | ||||
|         "Project", | ||||
|         "q", | ||||
|         "quality", | ||||
|         "rev", | ||||
|         "s", | ||||
|         "spcom", | ||||
|         "Special", | ||||
|         "species", | ||||
|         "Strategy", | ||||
|         "sulutil", | ||||
|         "svn", | ||||
|         "Talk", | ||||
|         "Template", | ||||
|         "Template talk", | ||||
|         "Testwiki", | ||||
|         "ticket", | ||||
|         "TimedText", | ||||
|         "Toollabs", | ||||
|         "tools", | ||||
|         "tswiki", | ||||
|         "User", | ||||
|         "User talk", | ||||
|         "v", | ||||
|         "voy", | ||||
|         "w", | ||||
|         "Wikibooks", | ||||
|         "Wikidata", | ||||
|         "wikiHow", | ||||
|         "Wikinvest", | ||||
|         "wikilivres", | ||||
|         "Wikimedia", | ||||
|         "Wikinews", | ||||
|         "Wikipedia", | ||||
|         "Wikipedia talk", | ||||
|         "Wikiquote", | ||||
|         "Wikisource", | ||||
|         "Wikispecies", | ||||
|         "Wikitech", | ||||
|         "Wikiversity", | ||||
|         "Wikivoyage", | ||||
|         "wikt", | ||||
|         "wiktionary", | ||||
|         "wmf", | ||||
|         "wmania", | ||||
|         "WP", | ||||
|     ] | ||||
| ) | ||||
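As a quick illustration of how a namespace list like this is consumed downstream — a minimal sketch, not code from this commit; `is_meta_page` is a made-up helper name, and `wikipedia_processor.py` below builds an equivalent case-insensitive regex:

import re

# abbreviated stand-in for the full WP_META_NAMESPACE list above
WP_META_NAMESPACE = ["Category", "File", "Help", "Template", "User"]

ns_pattern = r":?[a-z][a-z]:"  # interwiki prefixes such as "en:" or ":fr:"
for ns in WP_META_NAMESPACE:
    ns_pattern += "|" + ":?" + ns + ":"
ns_regex = re.compile(ns_pattern, re.IGNORECASE)

def is_meta_page(title):
    # "Category:Linguistics" -> True, "Paris" -> False
    return ns_regex.match(title) is not None

print(is_meta_page("Category:Linguistics"), is_meta_page("Paris"))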
|  | @ -18,11 +18,12 @@ from pathlib import Path | |||
| import plac | ||||
| 
 | ||||
| from bin.wiki_entity_linking import wikipedia_processor as wp, wikidata_processor as wd | ||||
| from bin.wiki_entity_linking import wiki_io as io | ||||
| from bin.wiki_entity_linking import kb_creator | ||||
| from bin.wiki_entity_linking import training_set_creator | ||||
| from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_FILE, ENTITY_DESCR_PATH, KB_MODEL_DIR, LOG_FORMAT | ||||
| from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH | ||||
| from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH, ENTITY_ALIAS_PATH | ||||
| import spacy | ||||
| from bin.wiki_entity_linking.kb_creator import read_kb | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
|  | @ -39,9 +40,11 @@ logger = logging.getLogger(__name__) | |||
|     loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path), | ||||
|     loc_entity_defs=("Location to file with entity definitions", "option", "d", Path), | ||||
|     loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path), | ||||
|     descriptions_from_wikipedia=("Flag for using wp descriptions not wd", "flag", "wp"), | ||||
|     limit=("Optional threshold to limit lines read from dumps", "option", "l", int), | ||||
|     lang=("Optional language for which to get wikidata titles. Defaults to 'en'", "option", "la", str), | ||||
|     descr_from_wp=("Flag for using wp descriptions not wd", "flag", "wp"), | ||||
|     limit_prior=("Threshold to limit lines read from WP for prior probabilities", "option", "lp", int), | ||||
|     limit_train=("Threshold to limit lines read from WP for training set", "option", "lt", int), | ||||
|     limit_wd=("Threshold to limit lines read from WD", "option", "lw", int), | ||||
|     lang=("Optional language for which to get Wikidata titles. Defaults to 'en'", "option", "la", str), | ||||
| ) | ||||
| def main( | ||||
|     wd_json, | ||||
|  | @ -54,13 +57,16 @@ def main( | |||
|     entity_vector_length=64, | ||||
|     loc_prior_prob=None, | ||||
|     loc_entity_defs=None, | ||||
|     loc_entity_alias=None, | ||||
|     loc_entity_desc=None, | ||||
|     descriptions_from_wikipedia=False, | ||||
|     limit=None, | ||||
|     descr_from_wp=False, | ||||
|     limit_prior=None, | ||||
|     limit_train=None, | ||||
|     limit_wd=None, | ||||
|     lang="en", | ||||
| ): | ||||
| 
 | ||||
|     entity_defs_path = loc_entity_defs if loc_entity_defs else output_dir / ENTITY_DEFS_PATH | ||||
|     entity_alias_path = loc_entity_alias if loc_entity_alias else output_dir / ENTITY_ALIAS_PATH | ||||
|     entity_descr_path = loc_entity_desc if loc_entity_desc else output_dir / ENTITY_DESCR_PATH | ||||
|     entity_freq_path = output_dir / ENTITY_FREQ_PATH | ||||
|     prior_prob_path = loc_prior_prob if loc_prior_prob else output_dir / PRIOR_PROB_PATH | ||||
|  | @ -69,15 +75,12 @@ def main( | |||
| 
 | ||||
|     logger.info("Creating KB with Wikipedia and WikiData") | ||||
| 
 | ||||
|     if limit is not None: | ||||
|         logger.warning("Warning: reading only {} lines of Wikipedia/Wikidata dumps.".format(limit)) | ||||
| 
 | ||||
|     # STEP 0: set up IO | ||||
|     if not output_dir.exists(): | ||||
|         output_dir.mkdir(parents=True) | ||||
| 
 | ||||
|     # STEP 1: create the NLP object | ||||
|     logger.info("STEP 1: Loading model {}".format(model)) | ||||
|     # STEP 1: Load the NLP object | ||||
|     logger.info("STEP 1: Loading NLP model {}".format(model)) | ||||
|     nlp = spacy.load(model) | ||||
| 
 | ||||
|     # check the length of the nlp vectors | ||||
|  | @ -90,62 +93,83 @@ def main( | |||
|     # STEP 2: create prior probabilities from WP | ||||
|     if not prior_prob_path.exists(): | ||||
|         # It takes about 2h to process 1000M lines of Wikipedia XML dump | ||||
|         logger.info("STEP 2: writing prior probabilities to {}".format(prior_prob_path)) | ||||
|         wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit) | ||||
|     logger.info("STEP 2: reading prior probabilities from {}".format(prior_prob_path)) | ||||
|         logger.info("STEP 2: Writing prior probabilities to {}".format(prior_prob_path)) | ||||
|         if limit_prior is not None: | ||||
|             logger.warning("Warning: reading only {} lines of Wikipedia dump".format(limit_prior)) | ||||
|         wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit_prior) | ||||
|     else: | ||||
|         logger.info("STEP 2: Reading prior probabilities from {}".format(prior_prob_path)) | ||||
| 
 | ||||
|     # STEP 3: deduce entity frequencies from WP (takes only a few minutes) | ||||
|     logger.info("STEP 3: calculating entity frequencies") | ||||
|     wp.write_entity_counts(prior_prob_path, entity_freq_path, to_print=False) | ||||
|     # STEP 3: calculate entity frequencies | ||||
|     if not entity_freq_path.exists(): | ||||
|         logger.info("STEP 3: Calculating and writing entity frequencies to {}".format(entity_freq_path)) | ||||
|         io.write_entity_to_count(prior_prob_path, entity_freq_path) | ||||
|     else: | ||||
|         logger.info("STEP 3: Reading entity frequencies from {}".format(entity_freq_path)) | ||||
| 
 | ||||
|     # STEP 4: reading definitions and (possibly) descriptions from WikiData or from file | ||||
|     message = " and descriptions" if not descriptions_from_wikipedia else "" | ||||
|     if (not entity_defs_path.exists()) or (not descriptions_from_wikipedia and not entity_descr_path.exists()): | ||||
|     if (not entity_defs_path.exists()) or (not descr_from_wp and not entity_descr_path.exists()): | ||||
|         # It takes about 10h to process 55M lines of Wikidata JSON dump | ||||
|         logger.info("STEP 4: parsing wikidata for entity definitions" + message) | ||||
|         title_to_id, id_to_descr = wd.read_wikidata_entities_json( | ||||
|         logger.info("STEP 4: Parsing and writing Wikidata entity definitions to {}".format(entity_defs_path)) | ||||
|         if limit_wd is not None: | ||||
|             logger.warning("Warning: reading only {} lines of Wikidata dump".format(limit_wd)) | ||||
|         title_to_id, id_to_descr, id_to_alias = wd.read_wikidata_entities_json( | ||||
|             wd_json, | ||||
|             limit, | ||||
|             limit_wd, | ||||
|             to_print=False, | ||||
|             lang=lang, | ||||
|             parse_descriptions=(not descriptions_from_wikipedia), | ||||
|             parse_descr=(not descr_from_wp), | ||||
|         ) | ||||
|         wd.write_entity_files(entity_defs_path, title_to_id) | ||||
|         if not descriptions_from_wikipedia: | ||||
|             wd.write_entity_description_files(entity_descr_path, id_to_descr) | ||||
|     logger.info("STEP 4: read entity definitions" + message) | ||||
|         io.write_title_to_id(entity_defs_path, title_to_id) | ||||
| 
 | ||||
|     # STEP 5: Getting gold entities from wikipedia | ||||
|     message = " and descriptions" if descriptions_from_wikipedia else "" | ||||
|     if (not training_entities_path.exists()) or (descriptions_from_wikipedia and not entity_descr_path.exists()): | ||||
|         logger.info("STEP 5: parsing wikipedia for gold entities" + message) | ||||
|         training_set_creator.create_training_examples_and_descriptions( | ||||
|             wp_xml, | ||||
|             entity_defs_path, | ||||
|             entity_descr_path, | ||||
|             training_entities_path, | ||||
|             parse_descriptions=descriptions_from_wikipedia, | ||||
|             limit=limit, | ||||
|         ) | ||||
|     logger.info("STEP 5: read gold entities" + message) | ||||
|         logger.info("STEP 4b: Writing Wikidata entity aliases to {}".format(entity_alias_path)) | ||||
|         io.write_id_to_alias(entity_alias_path, id_to_alias) | ||||
| 
 | ||||
|         if not descr_from_wp: | ||||
|             logger.info("STEP 4c: Writing Wikidata entity descriptions to {}".format(entity_descr_path)) | ||||
|             io.write_id_to_descr(entity_descr_path, id_to_descr) | ||||
|     else: | ||||
|         logger.info("STEP 4: Reading entity definitions from {}".format(entity_defs_path)) | ||||
|         logger.info("STEP 4b: Reading entity aliases from {}".format(entity_alias_path)) | ||||
|         if not descr_from_wp: | ||||
|             logger.info("STEP 4c: Reading entity descriptions from {}".format(entity_descr_path)) | ||||
| 
 | ||||
|     # STEP 5: Getting gold entities from Wikipedia | ||||
|     if (not training_entities_path.exists()) or (descr_from_wp and not entity_descr_path.exists()): | ||||
|         logger.info("STEP 5: Parsing and writing Wikipedia gold entities to {}".format(training_entities_path)) | ||||
|         if limit_train is not None: | ||||
|             logger.warning("Warning: reading only {} lines of Wikipedia dump".format(limit_train)) | ||||
|         wp.create_training_and_desc(wp_xml, entity_defs_path, entity_descr_path, | ||||
|                                     training_entities_path, descr_from_wp, limit_train) | ||||
|         if descr_from_wp: | ||||
|             logger.info("STEP 5b: Parsing and writing Wikipedia descriptions to {}".format(entity_descr_path)) | ||||
|     else: | ||||
|         logger.info("STEP 5: Reading gold entities from {}".format(training_entities_path)) | ||||
|         if descr_from_wp: | ||||
|             logger.info("STEP 5b: Reading entity descriptions from {}".format(entity_descr_path)) | ||||
| 
 | ||||
|     # STEP 6: creating the actual KB | ||||
|     # It takes ca. 30 minutes to pretrain the entity embeddings | ||||
|     logger.info("STEP 6: creating the KB at {}".format(kb_path)) | ||||
|     kb = kb_creator.create_kb( | ||||
|         nlp=nlp, | ||||
|         max_entities_per_alias=max_per_alias, | ||||
|         min_entity_freq=min_freq, | ||||
|         min_occ=min_pair, | ||||
|         entity_def_input=entity_defs_path, | ||||
|         entity_descr_path=entity_descr_path, | ||||
|         count_input=entity_freq_path, | ||||
|         prior_prob_input=prior_prob_path, | ||||
|         entity_vector_length=entity_vector_length, | ||||
|     ) | ||||
| 
 | ||||
|     kb.dump(kb_path) | ||||
|     nlp.to_disk(output_dir / KB_MODEL_DIR) | ||||
|     if not kb_path.exists(): | ||||
|         logger.info("STEP 6: Creating the KB at {}".format(kb_path)) | ||||
|         kb = kb_creator.create_kb( | ||||
|             nlp=nlp, | ||||
|             max_entities_per_alias=max_per_alias, | ||||
|             min_entity_freq=min_freq, | ||||
|             min_occ=min_pair, | ||||
|             entity_def_path=entity_defs_path, | ||||
|             entity_descr_path=entity_descr_path, | ||||
|             entity_alias_path=entity_alias_path, | ||||
|             entity_freq_path=entity_freq_path, | ||||
|             prior_prob_path=prior_prob_path, | ||||
|             entity_vector_length=entity_vector_length, | ||||
|         ) | ||||
|         kb.dump(kb_path) | ||||
|         logger.info("kb entities: {}".format(kb.get_size_entities())) | ||||
|         logger.info("kb aliases: {}".format(kb.get_size_aliases())) | ||||
|         nlp.to_disk(output_dir / KB_MODEL_DIR) | ||||
|     else: | ||||
|         logger.info("STEP 6: KB already exists at {}".format(kb_path)) | ||||
| 
 | ||||
|     logger.info("Done!") | ||||
| 
 | ||||
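The restructured main() above gates every step on whether its output file already exists, so an interrupted run can be resumed without redoing finished work. A minimal sketch of that pattern, assuming a hypothetical run_step helper (not part of the script's API):

import logging
from pathlib import Path

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def run_step(name, output_path, build_fn):
    # Re-run a pipeline step only if its output file is missing; otherwise reuse it.
    output_path = Path(output_path)
    if not output_path.exists():
        logger.info("%s: writing %s", name, output_path)
        build_fn(output_path)
    else:
        logger.info("%s: reading existing %s", name, output_path)

# example with a dummy step
run_step("STEP 2", "prior_prob.csv", lambda p: p.write_text("alias|count|entity\n"))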
|  |  | |||
|  | @ -1,40 +1,52 @@ | |||
| # coding: utf-8 | ||||
| from __future__ import unicode_literals | ||||
| 
 | ||||
| import gzip | ||||
| import bz2 | ||||
| import json | ||||
| import logging | ||||
| import datetime | ||||
| 
 | ||||
| from bin.wiki_entity_linking.wiki_namespaces import WD_META_ITEMS | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descriptions=True): | ||||
|     # Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines. | ||||
| def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descr=True): | ||||
|     # Read the JSON wiki data and parse out the entities. Takes about 7-10h to parse 55M lines. | ||||
|     # get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/ | ||||
| 
 | ||||
|     site_filter = '{}wiki'.format(lang) | ||||
| 
 | ||||
|     # properties filter (currently disabled to get ALL data) | ||||
|     prop_filter = dict() | ||||
|     # prop_filter = {'P31': {'Q5', 'Q15632617'}}     # currently defined as OR: one property suffices to be selected | ||||
|     # filter: currently defined as OR: one hit suffices to be removed from further processing | ||||
|     exclude_list = WD_META_ITEMS | ||||
| 
 | ||||
|     # punctuation | ||||
|     exclude_list.extend(["Q1383557", "Q10617810"]) | ||||
| 
 | ||||
|     # letters etc | ||||
|     exclude_list.extend(["Q188725", "Q19776628", "Q3841820", "Q17907810", "Q9788", "Q9398093"]) | ||||
| 
 | ||||
|     neg_prop_filter = { | ||||
|         'P31': exclude_list,    # instance of | ||||
|         'P279': exclude_list    # subclass | ||||
|     } | ||||
| 
 | ||||
|     title_to_id = dict() | ||||
|     id_to_descr = dict() | ||||
|     id_to_alias = dict() | ||||
| 
 | ||||
|     # parse appropriate fields - depending on what we need in the KB | ||||
|     parse_properties = False | ||||
|     parse_sitelinks = True | ||||
|     parse_labels = False | ||||
|     parse_aliases = False | ||||
|     parse_claims = False | ||||
|     parse_aliases = True | ||||
|     parse_claims = True | ||||
| 
 | ||||
|     with gzip.open(wikidata_file, mode='rb') as file: | ||||
|     with bz2.open(wikidata_file, mode='rb') as file: | ||||
|         for cnt, line in enumerate(file): | ||||
|             if limit and cnt >= limit: | ||||
|                 break | ||||
|             if cnt % 500000 == 0: | ||||
|                 logger.info("processed {} lines of WikiData dump".format(cnt)) | ||||
|             if cnt % 500000 == 0 and cnt > 0: | ||||
|                 logger.info("processed {} lines of WikiData JSON dump".format(cnt)) | ||||
|             clean_line = line.strip() | ||||
|             if clean_line.endswith(b","): | ||||
|                 clean_line = clean_line[:-1] | ||||
|  | @ -43,13 +55,11 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang= | |||
|                 entry_type = obj["type"] | ||||
| 
 | ||||
|                 if entry_type == "item": | ||||
|                     # filtering records on their properties (currently disabled to get ALL data) | ||||
|                     # keep = False | ||||
|                     keep = True | ||||
| 
 | ||||
|                     claims = obj["claims"] | ||||
|                     if parse_claims: | ||||
|                         for prop, value_set in prop_filter.items(): | ||||
|                         for prop, value_set in neg_prop_filter.items(): | ||||
|                             claim_property = claims.get(prop, None) | ||||
|                             if claim_property: | ||||
|                                 for cp in claim_property: | ||||
|  | @ -61,7 +71,7 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang= | |||
|                                     ) | ||||
|                                     cp_rank = cp["rank"] | ||||
|                                     if cp_rank != "deprecated" and cp_id in value_set: | ||||
|                                         keep = True | ||||
|                                         keep = False | ||||
| 
 | ||||
|                     if keep: | ||||
|                         unique_id = obj["id"] | ||||
|  | @ -108,7 +118,7 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang= | |||
|                                             "label (" + lang + "):", lang_label["value"] | ||||
|                                         ) | ||||
| 
 | ||||
|                         if found_link and parse_descriptions: | ||||
|                         if found_link and parse_descr: | ||||
|                             descriptions = obj["descriptions"] | ||||
|                             if descriptions: | ||||
|                                 lang_descr = descriptions.get(lang, None) | ||||
|  | @ -130,22 +140,15 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang= | |||
|                                             print( | ||||
|                                                 "alias (" + lang + "):", item["value"] | ||||
|                                             ) | ||||
|                                         alias_list = id_to_alias.get(unique_id, []) | ||||
|                                         alias_list.append(item["value"]) | ||||
|                                         id_to_alias[unique_id] = alias_list | ||||
| 
 | ||||
|                         if to_print: | ||||
|                             print() | ||||
| 
 | ||||
|     return title_to_id, id_to_descr | ||||
|     # log final number of lines processed | ||||
|     logger.info("Finished. Processed {} lines of WikiData JSON dump".format(cnt)) | ||||
|     return title_to_id, id_to_descr, id_to_alias | ||||
| 
 | ||||
| 
 | ||||
| def write_entity_files(entity_def_output, title_to_id): | ||||
|     with entity_def_output.open("w", encoding="utf8") as id_file: | ||||
|         id_file.write("WP_title" + "|" + "WD_id" + "\n") | ||||
|         for title, qid in title_to_id.items(): | ||||
|             id_file.write(title + "|" + str(qid) + "\n") | ||||
| 
 | ||||
| 
 | ||||
| def write_entity_description_files(entity_descr_output, id_to_descr): | ||||
|     with entity_descr_output.open("w", encoding="utf8") as descr_file: | ||||
|         descr_file.write("WD_id" + "|" + "description" + "\n") | ||||
|         for qid, descr in id_to_descr.items(): | ||||
|             descr_file.write(str(qid) + "|" + descr + "\n") | ||||
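For context on the reader above: latest-all.json.bz2 is one large JSON array with one Wikidata entity per line, which is why the code strips trailing commas and parses line by line. A stripped-down sketch of that streaming pattern (the function name and the commented usage are illustrative):

import bz2
import json

def iter_wikidata_items(dump_path, limit=None):
    # Stream entities from a bz2-compressed Wikidata JSON dump without loading it fully.
    with bz2.open(dump_path, mode="rb") as f:
        for cnt, line in enumerate(f):
            if limit and cnt >= limit:
                break
            clean_line = line.strip()
            if clean_line.endswith(b","):
                clean_line = clean_line[:-1]
            if len(clean_line) > 1:  # skip the opening "[" and closing "]"
                yield json.loads(clean_line)

# for obj in iter_wikidata_items("latest-all.json.bz2", limit=100):
#     print(obj.get("sitelinks", {}).get("enwiki", {}).get("title"))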
|  |  | |||
|  | @ -6,19 +6,19 @@ as created by the script `wikidata_create_kb`. | |||
| 
 | ||||
| For the Wikipedia dump: get enwiki-latest-pages-articles-multistream.xml.bz2 | ||||
| from https://dumps.wikimedia.org/enwiki/latest/ | ||||
| 
 | ||||
| """ | ||||
| from __future__ import unicode_literals | ||||
| 
 | ||||
| import random | ||||
| import logging | ||||
| import spacy | ||||
| from pathlib import Path | ||||
| import plac | ||||
| 
 | ||||
| from bin.wiki_entity_linking import training_set_creator | ||||
| from bin.wiki_entity_linking import wikipedia_processor | ||||
| from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR | ||||
| from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance, measure_baselines | ||||
| from bin.wiki_entity_linking.kb_creator import read_nlp_kb | ||||
| from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance | ||||
| from bin.wiki_entity_linking.kb_creator import read_kb | ||||
| 
 | ||||
| from spacy.util import minibatch, compounding | ||||
| 
 | ||||
|  | @ -35,6 +35,7 @@ logger = logging.getLogger(__name__) | |||
|     l2=("L2 regularization", "option", "r", float), | ||||
|     train_inst=("# training instances (default 90% of all)", "option", "t", int), | ||||
|     dev_inst=("# test instances (default 10% of all)", "option", "d", int), | ||||
|     labels_discard=("NER labels to discard (default None)", "option", "l", str), | ||||
| ) | ||||
| def main( | ||||
|     dir_kb, | ||||
|  | @ -46,13 +47,14 @@ def main( | |||
|     l2=1e-6, | ||||
|     train_inst=None, | ||||
|     dev_inst=None, | ||||
|     labels_discard=None | ||||
| ): | ||||
|     logger.info("Creating Entity Linker with Wikipedia and WikiData") | ||||
| 
 | ||||
|     output_dir = Path(output_dir) if output_dir else dir_kb | ||||
|     training_path = loc_training if loc_training else output_dir / TRAINING_DATA_FILE | ||||
|     training_path = loc_training if loc_training else dir_kb / TRAINING_DATA_FILE | ||||
|     nlp_dir = dir_kb / KB_MODEL_DIR | ||||
|     kb_path = output_dir / KB_FILE | ||||
|     kb_path = dir_kb / KB_FILE | ||||
|     nlp_output_dir = output_dir / OUTPUT_MODEL_DIR | ||||
| 
 | ||||
|     # STEP 0: set up IO | ||||
|  | @ -60,38 +62,47 @@ def main( | |||
|         output_dir.mkdir() | ||||
| 
 | ||||
|     # STEP 1 : load the NLP object | ||||
|     logger.info("STEP 1: loading model from {}".format(nlp_dir)) | ||||
|     nlp, kb = read_nlp_kb(nlp_dir, kb_path) | ||||
|     logger.info("STEP 1a: Loading model from {}".format(nlp_dir)) | ||||
|     nlp = spacy.load(nlp_dir) | ||||
|     logger.info("STEP 1b: Loading KB from {}".format(kb_path)) | ||||
|     kb = read_kb(nlp, kb_path) | ||||
| 
 | ||||
|     # check that there is a NER component in the pipeline | ||||
|     if "ner" not in nlp.pipe_names: | ||||
|         raise ValueError("The `nlp` object should have a pretrained `ner` component.") | ||||
| 
 | ||||
|     # STEP 2: create a training dataset from WP | ||||
|     logger.info("STEP 2: reading training dataset from {}".format(training_path)) | ||||
|     # STEP 2: read the training dataset previously created from WP | ||||
|     logger.info("STEP 2: Reading training dataset from {}".format(training_path)) | ||||
| 
 | ||||
|     train_data = training_set_creator.read_training( | ||||
|     if labels_discard: | ||||
|         labels_discard = [x.strip() for x in labels_discard.split(",")] | ||||
|         logger.info("Discarding {} NER types: {}".format(len(labels_discard), labels_discard)) | ||||
| 
 | ||||
|     train_data = wikipedia_processor.read_training( | ||||
|         nlp=nlp, | ||||
|         entity_file_path=training_path, | ||||
|         dev=False, | ||||
|         limit=train_inst, | ||||
|         kb=kb, | ||||
|         labels_discard=labels_discard | ||||
|     ) | ||||
| 
 | ||||
|     # for testing, get all pos instances, whether or not they are in the kb | ||||
|     dev_data = training_set_creator.read_training( | ||||
|     # for testing, get all pos instances (independently of KB) | ||||
|     dev_data = wikipedia_processor.read_training( | ||||
|         nlp=nlp, | ||||
|         entity_file_path=training_path, | ||||
|         dev=True, | ||||
|         limit=dev_inst, | ||||
|         kb=kb, | ||||
|         kb=None, | ||||
|         labels_discard=labels_discard | ||||
|     ) | ||||
| 
 | ||||
|     # STEP 3: create and train the entity linking pipe | ||||
|     logger.info("STEP 3: training Entity Linking pipe") | ||||
|     # STEP 3: create and train an entity linking pipe | ||||
|     logger.info("STEP 3: Creating and training an Entity Linking pipe") | ||||
| 
 | ||||
|     el_pipe = nlp.create_pipe( | ||||
|         name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name} | ||||
|         name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name, | ||||
|                                       "labels_discard": labels_discard} | ||||
|     ) | ||||
|     el_pipe.set_kb(kb) | ||||
|     nlp.add_pipe(el_pipe, last=True) | ||||
|  | @ -105,14 +116,9 @@ def main( | |||
|     logger.info("Training on {} articles".format(len(train_data))) | ||||
|     logger.info("Dev testing on {} articles".format(len(dev_data))) | ||||
| 
 | ||||
|     dev_baseline_accuracies = measure_baselines( | ||||
|         dev_data, kb | ||||
|     ) | ||||
| 
 | ||||
|     # baseline performance on dev data | ||||
|     logger.info("Dev Baseline Accuracies:") | ||||
|     logger.info(dev_baseline_accuracies.report_accuracy("random")) | ||||
|     logger.info(dev_baseline_accuracies.report_accuracy("prior")) | ||||
|     logger.info(dev_baseline_accuracies.report_accuracy("oracle")) | ||||
|     measure_performance(dev_data, kb, el_pipe, baseline=True, context=False) | ||||
| 
 | ||||
|     for itn in range(epochs): | ||||
|         random.shuffle(train_data) | ||||
|  | @ -136,18 +142,18 @@ def main( | |||
|                     logger.error("Error updating batch:" + str(e)) | ||||
|         if batchnr > 0: | ||||
|             logging.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2))) | ||||
|         measure_performance(dev_data, kb, el_pipe) | ||||
|             measure_performance(dev_data, kb, el_pipe, baseline=False, context=True) | ||||
| 
 | ||||
|     # STEP 4: measure the performance of our trained pipe on an independent dev set | ||||
|     logger.info("STEP 4: performance measurement of Entity Linking pipe") | ||||
|     logger.info("STEP 4: Final performance measurement of Entity Linking pipe") | ||||
|     measure_performance(dev_data, kb, el_pipe) | ||||
| 
 | ||||
|     # STEP 5: apply the EL pipe on a toy example | ||||
|     logger.info("STEP 5: applying Entity Linking to toy example") | ||||
|     logger.info("STEP 5: Applying Entity Linking to toy example") | ||||
|     run_el_toy_example(nlp=nlp) | ||||
| 
 | ||||
|     if output_dir: | ||||
|         # STEP 6: write the NLP pipeline (including entity linker) to file | ||||
|         # STEP 6: write the NLP pipeline (now including an EL model) to file | ||||
|         logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir)) | ||||
|         nlp.to_disk(nlp_output_dir) | ||||
| 
 | ||||
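The loop above relies on spaCy's minibatch and compounding utilities to grow the batch size during training. A generic spaCy 2.x sketch of that idiom (hyperparameters and the begin_training call are illustrative, not the exact settings of this script):

import random
from spacy.util import minibatch, compounding

def train(nlp, train_data, n_epochs=10, dropout=0.5):
    # train_data: list of (Doc, GoldParse) pairs such as those returned by read_training
    optimizer = nlp.begin_training()
    for epoch in range(n_epochs):
        random.shuffle(train_data)
        losses = {}
        for batch in minibatch(train_data, size=compounding(4.0, 128.0, 1.001)):
            docs, golds = zip(*batch)
            nlp.update(docs, golds, sgd=optimizer, drop=dropout, losses=losses)
        print("epoch", epoch, "losses", losses)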
|  |  | |||
|  | @ -3,147 +3,104 @@ from __future__ import unicode_literals | |||
| 
 | ||||
| import re | ||||
| import bz2 | ||||
| import csv | ||||
| import datetime | ||||
| import logging | ||||
| import random | ||||
| import json | ||||
| 
 | ||||
| from bin.wiki_entity_linking import LOG_FORMAT | ||||
| from functools import partial | ||||
| 
 | ||||
| from spacy.gold import GoldParse | ||||
| from bin.wiki_entity_linking import wiki_io as io | ||||
| from bin.wiki_entity_linking.wiki_namespaces import ( | ||||
|     WP_META_NAMESPACE, | ||||
|     WP_FILE_NAMESPACE, | ||||
|     WP_CATEGORY_NAMESPACE, | ||||
| ) | ||||
| 
 | ||||
| """ | ||||
| Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions. | ||||
| Write these results to file for downstream KB and training data generation. | ||||
| 
 | ||||
| Process Wikipedia interlinks to generate a training dataset for the EL algorithm. | ||||
| """ | ||||
| 
 | ||||
| ENTITY_FILE = "gold_entities.csv" | ||||
| 
 | ||||
| map_alias_to_link = dict() | ||||
| 
 | ||||
| logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| # these will/should be matched ignoring case | ||||
| wiki_namespaces = [ | ||||
|     "b", | ||||
|     "betawikiversity", | ||||
|     "Book", | ||||
|     "c", | ||||
|     "Category", | ||||
|     "Commons", | ||||
|     "d", | ||||
|     "dbdump", | ||||
|     "download", | ||||
|     "Draft", | ||||
|     "Education", | ||||
|     "Foundation", | ||||
|     "Gadget", | ||||
|     "Gadget definition", | ||||
|     "gerrit", | ||||
|     "File", | ||||
|     "Help", | ||||
|     "Image", | ||||
|     "Incubator", | ||||
|     "m", | ||||
|     "mail", | ||||
|     "mailarchive", | ||||
|     "media", | ||||
|     "MediaWiki", | ||||
|     "MediaWiki talk", | ||||
|     "Mediawikiwiki", | ||||
|     "MediaZilla", | ||||
|     "Meta", | ||||
|     "Metawikipedia", | ||||
|     "Module", | ||||
|     "mw", | ||||
|     "n", | ||||
|     "nost", | ||||
|     "oldwikisource", | ||||
|     "outreach", | ||||
|     "outreachwiki", | ||||
|     "otrs", | ||||
|     "OTRSwiki", | ||||
|     "Portal", | ||||
|     "phab", | ||||
|     "Phabricator", | ||||
|     "Project", | ||||
|     "q", | ||||
|     "quality", | ||||
|     "rev", | ||||
|     "s", | ||||
|     "spcom", | ||||
|     "Special", | ||||
|     "species", | ||||
|     "Strategy", | ||||
|     "sulutil", | ||||
|     "svn", | ||||
|     "Talk", | ||||
|     "Template", | ||||
|     "Template talk", | ||||
|     "Testwiki", | ||||
|     "ticket", | ||||
|     "TimedText", | ||||
|     "Toollabs", | ||||
|     "tools", | ||||
|     "tswiki", | ||||
|     "User", | ||||
|     "User talk", | ||||
|     "v", | ||||
|     "voy", | ||||
|     "w", | ||||
|     "Wikibooks", | ||||
|     "Wikidata", | ||||
|     "wikiHow", | ||||
|     "Wikinvest", | ||||
|     "wikilivres", | ||||
|     "Wikimedia", | ||||
|     "Wikinews", | ||||
|     "Wikipedia", | ||||
|     "Wikipedia talk", | ||||
|     "Wikiquote", | ||||
|     "Wikisource", | ||||
|     "Wikispecies", | ||||
|     "Wikitech", | ||||
|     "Wikiversity", | ||||
|     "Wikivoyage", | ||||
|     "wikt", | ||||
|     "wiktionary", | ||||
|     "wmf", | ||||
|     "wmania", | ||||
|     "WP", | ||||
| ] | ||||
| title_regex = re.compile(r"(?<=<title>).*(?=</title>)") | ||||
| id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)") | ||||
| text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)") | ||||
| info_regex = re.compile(r"{[^{]*?}") | ||||
| html_regex = re.compile(r"&lt;!--[^-]*--&gt;") | ||||
| ref_regex = re.compile(r"&lt;ref.*?&gt;")  # non-greedy | ||||
| ref_2_regex = re.compile(r"&lt;/ref.*?&gt;")  # non-greedy | ||||
| 
 | ||||
| # find the links | ||||
| link_regex = re.compile(r"\[\[[^\[\]]*\]\]") | ||||
| 
 | ||||
| # match on interwiki links, e.g. `en:` or `:fr:` | ||||
| ns_regex = r":?" + "[a-z][a-z]" + ":" | ||||
| 
 | ||||
| # match on Namespace: optionally preceded by a : | ||||
| for ns in wiki_namespaces: | ||||
| for ns in WP_META_NAMESPACE: | ||||
|     ns_regex += "|" + ":?" + ns + ":" | ||||
| 
 | ||||
| ns_regex = re.compile(ns_regex, re.IGNORECASE) | ||||
| 
 | ||||
| files = r"" | ||||
| for f in WP_FILE_NAMESPACE: | ||||
|     files += "\[\[" + f + ":[^[\]]+]]" + "|" | ||||
| files = files[0 : len(files) - 1] | ||||
| file_regex = re.compile(files) | ||||
| 
 | ||||
| cats = r"" | ||||
| for c in WP_CATEGORY_NAMESPACE: | ||||
|     cats += "\[\[" + c + ":[^\[]*]]" + "|" | ||||
| cats = cats[0 : len(cats) - 1] | ||||
| category_regex = re.compile(cats) | ||||
| 
 | ||||
| 
 | ||||
| def read_prior_probs(wikipedia_input, prior_prob_output, limit=None): | ||||
|     """ | ||||
|     Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities. | ||||
|     The full file takes about 2h to parse 1100M lines. | ||||
|     It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from. | ||||
|     The full file takes about 2-3h to parse 1100M lines. | ||||
|     It works relatively fast because it runs line by line, irrespective of which article the intra-wiki link is from, | ||||
|     though dev test articles are excluded in order not to get an artificially strong baseline. | ||||
|     """ | ||||
|     cnt = 0 | ||||
|     read_id = False | ||||
|     current_article_id = None | ||||
|     with bz2.open(wikipedia_input, mode="rb") as file: | ||||
|         line = file.readline() | ||||
|         cnt = 0 | ||||
|         while line and (not limit or cnt < limit): | ||||
|             if cnt % 25000000 == 0: | ||||
|             if cnt % 25000000 == 0 and cnt > 0: | ||||
|                 logger.info("processed {} lines of Wikipedia XML dump".format(cnt)) | ||||
|             clean_line = line.strip().decode("utf-8") | ||||
| 
 | ||||
|             aliases, entities, normalizations = get_wp_links(clean_line) | ||||
|             for alias, entity, norm in zip(aliases, entities, normalizations): | ||||
|                 _store_alias(alias, entity, normalize_alias=norm, normalize_entity=True) | ||||
|                 _store_alias(alias, entity, normalize_alias=norm, normalize_entity=True) | ||||
|             # we attempt to read the article's ID (but not the revision or contributor ID) | ||||
|             if "<revision>" in clean_line or "<contributor>" in clean_line: | ||||
|                 read_id = False | ||||
|             if "<page>" in clean_line: | ||||
|                 read_id = True | ||||
| 
 | ||||
|             if read_id: | ||||
|                 ids = id_regex.search(clean_line) | ||||
|                 if ids: | ||||
|                     current_article_id = ids[0] | ||||
| 
 | ||||
|             # only processing prior probabilities from true training (non-dev) articles | ||||
|             if not is_dev(current_article_id): | ||||
|                 aliases, entities, normalizations = get_wp_links(clean_line) | ||||
|                 for alias, entity, norm in zip(aliases, entities, normalizations): | ||||
|                     _store_alias( | ||||
|                         alias, entity, normalize_alias=norm, normalize_entity=True | ||||
|                     ) | ||||
| 
 | ||||
|             line = file.readline() | ||||
|             cnt += 1 | ||||
|         logger.info("processed {} lines of Wikipedia XML dump".format(cnt)) | ||||
|     logger.info("Finished. processed {} lines of Wikipedia XML dump".format(cnt)) | ||||
| 
 | ||||
|     # write all aliases and their entities and count occurrences to file | ||||
|     with prior_prob_output.open("w", encoding="utf8") as outputfile: | ||||
|  | @ -182,7 +139,7 @@ def get_wp_links(text): | |||
|         match = match[2:][:-2].replace("_", " ").strip() | ||||
| 
 | ||||
|         if ns_regex.match(match): | ||||
|             pass  # ignore namespaces at the beginning of the string | ||||
|             pass  # ignore the entity if it points to a "meta" page | ||||
| 
 | ||||
|         # this is a simple [[link]], with the alias the same as the mention | ||||
|         elif "|" not in match: | ||||
|  | @ -218,47 +175,382 @@ def _capitalize_first(text): | |||
|     return result | ||||
| 
 | ||||
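read_prior_probs above writes one alias|count|entity row per distinct link; the prior probability P(entity | alias) then falls out of normalising those counts per alias. A hedged sketch of that post-processing step (function name and normalisation are illustrative, not code from this commit):

from collections import defaultdict
from pathlib import Path

def read_prior_probabilities(prior_prob_path):
    # Aggregate the "alias|count|entity" file and normalise the counts per alias.
    alias_to_counts = defaultdict(dict)
    with Path(prior_prob_path).open("r", encoding="utf8") as f:
        next(f)  # skip the header line
        for line in f:
            splits = line.rstrip("\n").split(sep="|")
            alias, count, entity = splits[0], int(splits[1]), splits[2]
            counts = alias_to_counts[alias]
            counts[entity] = counts.get(entity, 0) + count
    priors = {}
    for alias, counts in alias_to_counts.items():
        total = sum(counts.values())
        priors[alias] = {entity: c / total for entity, c in counts.items()}
    return priors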
| 
 | ||||
| def write_entity_counts(prior_prob_input, count_output, to_print=False): | ||||
|     # Write entity counts for quick access later | ||||
|     entity_to_count = dict() | ||||
|     total_count = 0 | ||||
| 
 | ||||
|     with prior_prob_input.open("r", encoding="utf8") as prior_file: | ||||
|         # skip header | ||||
|         prior_file.readline() | ||||
|         line = prior_file.readline() | ||||
| 
 | ||||
|         while line: | ||||
|             splits = line.replace("\n", "").split(sep="|") | ||||
|             # alias = splits[0] | ||||
|             count = int(splits[1]) | ||||
|             entity = splits[2] | ||||
| 
 | ||||
|             current_count = entity_to_count.get(entity, 0) | ||||
|             entity_to_count[entity] = current_count + count | ||||
| 
 | ||||
|             total_count += count | ||||
| 
 | ||||
|             line = prior_file.readline() | ||||
| 
 | ||||
|     with count_output.open("w", encoding="utf8") as entity_file: | ||||
|         entity_file.write("entity" + "|" + "count" + "\n") | ||||
|         for entity, count in entity_to_count.items(): | ||||
|             entity_file.write(entity + "|" + str(count) + "\n") | ||||
| 
 | ||||
|     if to_print: | ||||
|         for entity, count in entity_to_count.items(): | ||||
|             print("Entity count:", entity, count) | ||||
|         print("Total count:", total_count) | ||||
| def create_training_and_desc( | ||||
|     wp_input, def_input, desc_output, training_output, parse_desc, limit=None | ||||
| ): | ||||
|     wp_to_id = io.read_title_to_id(def_input) | ||||
|     _process_wikipedia_texts( | ||||
|         wp_input, wp_to_id, desc_output, training_output, parse_desc, limit | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| def get_all_frequencies(count_input): | ||||
|     entity_to_count = dict() | ||||
|     with count_input.open("r", encoding="utf8") as csvfile: | ||||
|         csvreader = csv.reader(csvfile, delimiter="|") | ||||
|         # skip header | ||||
|         next(csvreader) | ||||
|         for row in csvreader: | ||||
|             entity_to_count[row[0]] = int(row[1]) | ||||
| def _process_wikipedia_texts( | ||||
|     wikipedia_input, wp_to_id, output, training_output, parse_descriptions, limit=None | ||||
| ): | ||||
|     """ | ||||
|     Read the XML wikipedia data to parse out training data: | ||||
|     raw text data + positive instances | ||||
|     """ | ||||
| 
 | ||||
|     return entity_to_count | ||||
|     read_ids = set() | ||||
| 
 | ||||
|     with output.open("a", encoding="utf8") as descr_file, training_output.open( | ||||
|         "w", encoding="utf8" | ||||
|     ) as entity_file: | ||||
|         if parse_descriptions: | ||||
|             _write_training_description(descr_file, "WD_id", "description") | ||||
|         with bz2.open(wikipedia_input, mode="rb") as file: | ||||
|             article_count = 0 | ||||
|             article_text = "" | ||||
|             article_title = None | ||||
|             article_id = None | ||||
|             reading_text = False | ||||
|             reading_revision = False | ||||
| 
 | ||||
|             for line in file: | ||||
|                 clean_line = line.strip().decode("utf-8") | ||||
| 
 | ||||
|                 if clean_line == "<revision>": | ||||
|                     reading_revision = True | ||||
|                 elif clean_line == "</revision>": | ||||
|                     reading_revision = False | ||||
| 
 | ||||
|                 # Start reading new page | ||||
|                 if clean_line == "<page>": | ||||
|                     article_text = "" | ||||
|                     article_title = None | ||||
|                     article_id = None | ||||
|                 # finished reading this page | ||||
|                 elif clean_line == "</page>": | ||||
|                     if article_id: | ||||
|                         clean_text, entities = _process_wp_text( | ||||
|                             article_title, article_text, wp_to_id | ||||
|                         ) | ||||
|                         if clean_text is not None and entities is not None: | ||||
|                             _write_training_entities( | ||||
|                                 entity_file, article_id, clean_text, entities | ||||
|                             ) | ||||
| 
 | ||||
|                             if article_title in wp_to_id and parse_descriptions: | ||||
|                                 description = " ".join( | ||||
|                                     clean_text[:1000].split(" ")[:-1] | ||||
|                                 ) | ||||
|                                 _write_training_description( | ||||
|                                     descr_file, wp_to_id[article_title], description | ||||
|                                 ) | ||||
|                             article_count += 1 | ||||
|                             if article_count % 10000 == 0 and article_count > 0: | ||||
|                                 logger.info( | ||||
|                                     "Processed {} articles".format(article_count) | ||||
|                                 ) | ||||
|                             if limit and article_count >= limit: | ||||
|                                 break | ||||
|                     article_text = "" | ||||
|                     article_title = None | ||||
|                     article_id = None | ||||
|                     reading_text = False | ||||
|                     reading_revision = False | ||||
| 
 | ||||
|                 # start reading text within a page | ||||
|                 if "<text" in clean_line: | ||||
|                     reading_text = True | ||||
| 
 | ||||
|                 if reading_text: | ||||
|                     article_text += " " + clean_line | ||||
| 
 | ||||
|                 # stop reading text within a page (we assume a new page doesn't start on the same line) | ||||
|                 if "</text" in clean_line: | ||||
|                     reading_text = False | ||||
| 
 | ||||
|                 # read the ID of this article (outside the revision portion of the document) | ||||
|                 if not reading_revision: | ||||
|                     ids = id_regex.search(clean_line) | ||||
|                     if ids: | ||||
|                         article_id = ids[0] | ||||
|                         if article_id in read_ids: | ||||
|                             logger.info( | ||||
|                                 "Found duplicate article ID {} {}".format(article_id, clean_line) | ||||
|                             )  # This should never happen ... | ||||
|                         read_ids.add(article_id) | ||||
| 
 | ||||
|                 # read the title of this article (outside the revision portion of the document) | ||||
|                 if not reading_revision: | ||||
|                     titles = title_regex.search(clean_line) | ||||
|                     if titles: | ||||
|                         article_title = titles[0].strip() | ||||
|     logger.info("Finished. Processed {} articles".format(article_count)) | ||||
| 
 | ||||
| 
 | ||||
| def _process_wp_text(article_title, article_text, wp_to_id): | ||||
|     # ignore meta Wikipedia pages | ||||
|     if ns_regex.match(article_title): | ||||
|         return None, None | ||||
| 
 | ||||
|     # remove the text tags | ||||
|     text_search = text_regex.search(article_text) | ||||
|     if text_search is None: | ||||
|         return None, None | ||||
|     text = text_search.group(0) | ||||
| 
 | ||||
|     # stop processing if this is a redirect page | ||||
|     if text.startswith("#REDIRECT"): | ||||
|         return None, None | ||||
| 
 | ||||
|     # get the raw text without markup etc, keeping only interwiki links | ||||
|     clean_text, entities = _remove_links(_get_clean_wp_text(text), wp_to_id) | ||||
|     return clean_text, entities | ||||
| 
 | ||||
| 
 | ||||
| def _get_clean_wp_text(article_text): | ||||
|     clean_text = article_text.strip() | ||||
| 
 | ||||
|     # remove bolding & italic markup | ||||
|     clean_text = clean_text.replace("'''", "") | ||||
|     clean_text = clean_text.replace("''", "") | ||||
| 
 | ||||
|     # remove nested {{info}} statements by removing the inner/smallest ones first and iterating | ||||
|     try_again = True | ||||
|     previous_length = len(clean_text) | ||||
|     while try_again: | ||||
|         clean_text = info_regex.sub( | ||||
|             "", clean_text | ||||
|         )  # non-greedy match excluding a nested { | ||||
|         if len(clean_text) < previous_length: | ||||
|             try_again = True | ||||
|         else: | ||||
|             try_again = False | ||||
|         previous_length = len(clean_text) | ||||
| 
 | ||||
|     # remove HTML comments | ||||
|     clean_text = html_regex.sub("", clean_text) | ||||
| 
 | ||||
|     # remove Category and File statements | ||||
|     clean_text = category_regex.sub("", clean_text) | ||||
|     clean_text = file_regex.sub("", clean_text) | ||||
| 
 | ||||
|     # remove multiple = | ||||
|     while "==" in clean_text: | ||||
|         clean_text = clean_text.replace("==", "=") | ||||
| 
 | ||||
|     clean_text = clean_text.replace(". =", ".") | ||||
|     clean_text = clean_text.replace(" = ", ". ") | ||||
|     clean_text = clean_text.replace("= ", ".") | ||||
|     clean_text = clean_text.replace(" =", "") | ||||
| 
 | ||||
|     # remove refs (non-greedy match) | ||||
|     clean_text = ref_regex.sub("", clean_text) | ||||
|     clean_text = ref_2_regex.sub("", clean_text) | ||||
| 
 | ||||
|     # remove additional wikiformatting | ||||
|     clean_text = re.sub(r"&lt;blockquote&gt;", "", clean_text) | ||||
|     clean_text = re.sub(r"&lt;/blockquote&gt;", "", clean_text) | ||||
| 
 | ||||
|     # change special characters back to normal ones | ||||
|     clean_text = clean_text.replace(r"&lt;", "<") | ||||
|     clean_text = clean_text.replace(r"&gt;", ">") | ||||
|     clean_text = clean_text.replace(r"&quot;", '"') | ||||
|     clean_text = clean_text.replace(r"&amp;nbsp;", " ") | ||||
|     clean_text = clean_text.replace(r"&amp;", "&") | ||||
| 
 | ||||
|     # remove multiple spaces | ||||
|     while "  " in clean_text: | ||||
|         clean_text = clean_text.replace("  ", " ") | ||||
| 
 | ||||
|     return clean_text.strip() | ||||
| 
 | ||||
| 
 | ||||
| def _remove_links(clean_text, wp_to_id): | ||||
|     # read the text char by char to get the right offsets for the interwiki links | ||||
|     entities = [] | ||||
|     final_text = "" | ||||
|     open_read = 0 | ||||
|     reading_text = True | ||||
|     reading_entity = False | ||||
|     reading_mention = False | ||||
|     reading_special_case = False | ||||
|     entity_buffer = "" | ||||
|     mention_buffer = "" | ||||
|     for index, letter in enumerate(clean_text): | ||||
|         if letter == "[": | ||||
|             open_read += 1 | ||||
|         elif letter == "]": | ||||
|             open_read -= 1 | ||||
|         elif letter == "|": | ||||
|             if reading_text: | ||||
|                 final_text += letter | ||||
|             # switch from reading entity to mention in the [[entity|mention]] pattern | ||||
|             elif reading_entity: | ||||
|                 reading_text = False | ||||
|                 reading_entity = False | ||||
|                 reading_mention = True | ||||
|             else: | ||||
|                 reading_special_case = True | ||||
|         else: | ||||
|             if reading_entity: | ||||
|                 entity_buffer += letter | ||||
|             elif reading_mention: | ||||
|                 mention_buffer += letter | ||||
|             elif reading_text: | ||||
|                 final_text += letter | ||||
|             else: | ||||
|                 raise ValueError("Not sure at point", clean_text[index - 2 : index + 2]) | ||||
| 
 | ||||
|         if open_read > 2: | ||||
|             reading_special_case = True | ||||
| 
 | ||||
|         if open_read == 2 and reading_text: | ||||
|             reading_text = False | ||||
|             reading_entity = True | ||||
|             reading_mention = False | ||||
| 
 | ||||
|         # we just finished reading an entity | ||||
|         if open_read == 0 and not reading_text: | ||||
|             if "#" in entity_buffer or entity_buffer.startswith(":"): | ||||
|                 reading_special_case = True | ||||
|             # Ignore cases with nested structures like File: handles etc | ||||
|             if not reading_special_case: | ||||
|                 if not mention_buffer: | ||||
|                     mention_buffer = entity_buffer | ||||
|                 start = len(final_text) | ||||
|                 end = start + len(mention_buffer) | ||||
|                 qid = wp_to_id.get(entity_buffer, None) | ||||
|                 if qid: | ||||
|                     entities.append((mention_buffer, qid, start, end)) | ||||
|                 final_text += mention_buffer | ||||
| 
 | ||||
|             entity_buffer = "" | ||||
|             mention_buffer = "" | ||||
| 
 | ||||
|             reading_text = True | ||||
|             reading_entity = False | ||||
|             reading_mention = False | ||||
|             reading_special_case = False | ||||
|     return final_text, entities | ||||
| 
 | ||||
| 
 | ||||
| def _write_training_description(outputfile, qid, description): | ||||
|     if description is not None: | ||||
|         line = str(qid) + "|" + description + "\n" | ||||
|         outputfile.write(line) | ||||
| 
 | ||||
| 
 | ||||
| def _write_training_entities(outputfile, article_id, clean_text, entities): | ||||
|     entities_data = [ | ||||
|         {"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]} | ||||
|         for ent in entities | ||||
|     ] | ||||
|     line = ( | ||||
|         json.dumps( | ||||
|             { | ||||
|                 "article_id": article_id, | ||||
|                 "clean_text": clean_text, | ||||
|                 "entities": entities_data, | ||||
|             }, | ||||
|             ensure_ascii=False, | ||||
|         ) | ||||
|         + "\n" | ||||
|     ) | ||||
|     outputfile.write(line) | ||||
| 
 | ||||
| 
 | ||||
| def read_training(nlp, entity_file_path, dev, limit, kb, labels_discard=None): | ||||
|     """ This method provides training examples that correspond to the entity annotations found by the nlp object. | ||||
|      For training, it will include both positive and negative examples by using the candidate generator from the kb. | ||||
|      For testing (kb=None), it will include all positive examples only.""" | ||||
| 
 | ||||
|     from tqdm import tqdm | ||||
| 
 | ||||
|     if not labels_discard: | ||||
|         labels_discard = [] | ||||
| 
 | ||||
|     data = [] | ||||
|     num_entities = 0 | ||||
|     get_gold_parse = partial( | ||||
|         _get_gold_parse, dev=dev, kb=kb, labels_discard=labels_discard | ||||
|     ) | ||||
| 
 | ||||
|     logger.info( | ||||
|         "Reading {} data with limit {}".format("dev" if dev else "train", limit) | ||||
|     ) | ||||
|     with entity_file_path.open("r", encoding="utf8") as file: | ||||
|         with tqdm(total=limit, leave=False) as pbar: | ||||
|             for i, line in enumerate(file): | ||||
|                 example = json.loads(line) | ||||
|                 article_id = example["article_id"] | ||||
|                 clean_text = example["clean_text"] | ||||
|                 entities = example["entities"] | ||||
| 
 | ||||
|                 if dev != is_dev(article_id) or not is_valid_article(clean_text): | ||||
|                     continue | ||||
| 
 | ||||
|                 doc = nlp(clean_text) | ||||
|                 gold = get_gold_parse(doc, entities) | ||||
|                 if gold and len(gold.links) > 0: | ||||
|                     data.append((doc, gold)) | ||||
|                     num_entities += len(gold.links) | ||||
|                     pbar.update(len(gold.links)) | ||||
|                 if limit and num_entities >= limit: | ||||
|                     break | ||||
|     logger.info("Read {} entities in {} articles".format(num_entities, len(data))) | ||||
|     return data | ||||
| 
 | ||||
| 
 | ||||
| def _get_gold_parse(doc, entities, dev, kb, labels_discard): | ||||
|     gold_entities = {} | ||||
|     tagged_ent_positions = { | ||||
|         (ent.start_char, ent.end_char): ent | ||||
|         for ent in doc.ents | ||||
|         if ent.label_ not in labels_discard | ||||
|     } | ||||
| 
 | ||||
|     for entity in entities: | ||||
|         entity_id = entity["entity"] | ||||
|         alias = entity["alias"] | ||||
|         start = entity["start"] | ||||
|         end = entity["end"] | ||||
| 
 | ||||
|         candidate_ids = [] | ||||
|         if kb and not dev: | ||||
|             candidates = kb.get_candidates(alias) | ||||
|             candidate_ids = [cand.entity_ for cand in candidates] | ||||
| 
 | ||||
|         tagged_ent = tagged_ent_positions.get((start, end), None) | ||||
|         if tagged_ent: | ||||
|             # TODO: check that alias == doc.text[start:end] | ||||
|             should_add_ent = (dev or entity_id in candidate_ids) and is_valid_sentence( | ||||
|                 tagged_ent.sent.text | ||||
|             ) | ||||
| 
 | ||||
|             if should_add_ent: | ||||
|                 value_by_id = {entity_id: 1.0} | ||||
|                 if not dev: | ||||
|                     random.shuffle(candidate_ids) | ||||
|                     value_by_id.update( | ||||
|                         {kb_id: 0.0 for kb_id in candidate_ids if kb_id != entity_id} | ||||
|                     ) | ||||
|                 gold_entities[(start, end)] = value_by_id | ||||
| 
 | ||||
|     return GoldParse(doc, links=gold_entities) | ||||
| 
 | ||||
| 
 | ||||
| def is_dev(article_id): | ||||
|     if not article_id: | ||||
|         return False | ||||
|     return article_id.endswith("3") | ||||
| 
 | ||||
| 
 | ||||
| def is_valid_article(doc_text): | ||||
|     # custom length cut-off | ||||
|     return 10 < len(doc_text) < 30000 | ||||
| 
 | ||||
| 
 | ||||
| def is_valid_sentence(sent_text): | ||||
|     if not 10 < len(sent_text) < 3000: | ||||
|         # custom length cut-off | ||||
|         return False | ||||
| 
 | ||||
|     if sent_text.strip().startswith("*") or sent_text.strip().startswith("#"): | ||||
|         # remove 'enumeration' sentences (occurs often on Wikipedia) | ||||
|         return False | ||||
| 
 | ||||
|     return True | ||||
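The gold annotations built in _get_gold_parse boil down to a dict keyed by character offsets, mapping each candidate QID to 1.0 (correct) or 0.0 (negative). A tiny hedged example of that structure (the sentence and QIDs are made up):

import spacy
from spacy.gold import GoldParse

nlp = spacy.blank("en")
doc = nlp("Douglas Adams wrote the guide.")

# characters 0-13 cover "Douglas Adams"; Q42 is the gold entity,
# Q1000 stands in for a negative candidate returned by the KB
links = {(0, 13): {"Q42": 1.0, "Q1000": 0.0}}
gold = GoldParse(doc, links=links)
print(gold.links)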
|  |  | |||
|  | @ -7,7 +7,7 @@ dependency tree to find the noun phrase they are referring to – for example: | |||
| $9.4 million --> Net income. | ||||
| 
 | ||||
| Compatible with: spaCy v2.0.0+ | ||||
| Last tested with: v2.1.0 | ||||
| Last tested with: v2.2.1 | ||||
| """ | ||||
| from __future__ import unicode_literals, print_function | ||||
| 
 | ||||
|  | @ -38,14 +38,17 @@ def main(model="en_core_web_sm"): | |||
| 
 | ||||
| def filter_spans(spans): | ||||
|     # Filter a sequence of spans so they don't contain overlaps | ||||
|     get_sort_key = lambda span: (span.end - span.start, span.start) | ||||
|     # For spaCy 2.1.4+: this function is available as spacy.util.filter_spans() | ||||
|     get_sort_key = lambda span: (span.end - span.start, -span.start) | ||||
|     sorted_spans = sorted(spans, key=get_sort_key, reverse=True) | ||||
|     result = [] | ||||
|     seen_tokens = set() | ||||
|     for span in sorted_spans: | ||||
|         # Check for end - 1 here because boundaries are inclusive | ||||
|         if span.start not in seen_tokens and span.end - 1 not in seen_tokens: | ||||
|             result.append(span) | ||||
|             seen_tokens.update(range(span.start, span.end)) | ||||
|         seen_tokens.update(range(span.start, span.end)) | ||||
|     result = sorted(result, key=lambda span: span.start) | ||||
|     return result | ||||
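A quick hedged check of the adjusted sort key (tokenisation shown assumes spaCy's default English tokenizer; spacy.util.filter_spans behaves like the helper above in spaCy 2.1.4+):

import spacy
from spacy.util import filter_spans

nlp = spacy.blank("en")
doc = nlp("Net income was $9.4 million")   # tokens: Net | income | was | $ | 9.4 | million
spans = [doc[0:2], doc[0:1], doc[3:6]]     # "Net income", "Net", "$9.4 million"
print([span.text for span in filter_spans(spans)])
# longest non-overlapping spans, ordered by start: ['Net income', '$9.4 million']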
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -91,8 +91,8 @@ def demo(shape): | |||
|     nlp = spacy.load("en_vectors_web_lg") | ||||
|     nlp.add_pipe(KerasSimilarityShim.load(nlp.path / "similarity", nlp, shape[0])) | ||||
| 
 | ||||
|     doc1 = nlp(u"The king of France is bald.") | ||||
|     doc2 = nlp(u"France has no king.") | ||||
|     doc1 = nlp("The king of France is bald.") | ||||
|     doc2 = nlp("France has no king.") | ||||
| 
 | ||||
|     print("Sentence 1:", doc1) | ||||
|     print("Sentence 2:", doc2) | ||||
|  |  | |||
|  | @ -8,7 +8,7 @@ | |||
|             { | ||||
|               "tokens": [ | ||||
|                 { | ||||
|                   "head": 4, | ||||
|                   "head": 44, | ||||
|                   "dep": "prep", | ||||
|                   "tag": "IN", | ||||
|                   "orth": "In", | ||||
|  |  | |||
|  | @ -11,7 +11,7 @@ numpy>=1.15.0 | |||
| requests>=2.13.0,<3.0.0 | ||||
| plac<1.0.0,>=0.9.6 | ||||
| pathlib==1.0.1; python_version < "3.4" | ||||
| importlib_metadata>=0.23; python_version < "3.8" | ||||
| importlib_metadata>=0.20; python_version < "3.8" | ||||
| # Optional dependencies | ||||
| jsonschema>=2.6.0,<3.1.0 | ||||
| # Development dependencies | ||||
|  |  | |||
|  | @ -51,7 +51,7 @@ install_requires = | |||
|     wasabi>=0.2.0,<1.1.0 | ||||
|     srsly>=0.1.0,<1.1.0 | ||||
|     pathlib==1.0.1; python_version < "3.4" | ||||
|     importlib_metadata>=0.23; python_version < "3.8" | ||||
|     importlib_metadata>=0.20; python_version < "3.8" | ||||
| 
 | ||||
| [options.extras_require] | ||||
| lookups = | ||||
|  |  | |||
|  | @ -57,7 +57,8 @@ def convert( | |||
|     is written to stdout, so you can pipe them forward to a JSON file: | ||||
|     $ spacy convert some_file.conllu > some_file.json | ||||
|     """ | ||||
|     msg = Printer() | ||||
|     no_print = (output_dir == "-") | ||||
|     msg = Printer(no_print=no_print) | ||||
|     input_path = Path(input_file) | ||||
|     if file_type not in FILE_TYPES: | ||||
|         msg.fail( | ||||
|  | @ -102,6 +103,7 @@ def convert( | |||
|         use_morphology=morphology, | ||||
|         lang=lang, | ||||
|         model=model, | ||||
|         no_print=no_print, | ||||
|     ) | ||||
|     if output_dir != "-": | ||||
|         # Export data to a file | ||||
|  |  | |||
|  | @ -9,7 +9,7 @@ from ...tokens.doc import Doc | |||
| from ...util import load_model | ||||
| 
 | ||||
| 
 | ||||
| def conll_ner2json(input_data, n_sents=10, seg_sents=False, model=None, **kwargs): | ||||
| def conll_ner2json(input_data, n_sents=10, seg_sents=False, model=None, no_print=False, **kwargs): | ||||
|     """ | ||||
|     Convert files in the CoNLL-2003 NER format and similar | ||||
|     whitespace-separated columns into JSON format for use with train cli. | ||||
|  | @ -34,7 +34,7 @@ def conll_ner2json(input_data, n_sents=10, seg_sents=False, model=None, **kwargs | |||
|     . O | ||||
| 
 | ||||
|     """ | ||||
|     msg = Printer() | ||||
|     msg = Printer(no_print=no_print) | ||||
|     doc_delimiter = "-DOCSTART- -X- O O" | ||||
|     # check for existing delimiters, which should be preserved | ||||
|     if "\n\n" in input_data and seg_sents: | ||||
|  |  | |||
|  | @ -8,7 +8,7 @@ from ...util import minibatch | |||
| from .conll_ner2json import n_sents_info | ||||
| 
 | ||||
| 
 | ||||
| def iob2json(input_data, n_sents=10, *args, **kwargs): | ||||
| def iob2json(input_data, n_sents=10, no_print=False, *args, **kwargs): | ||||
|     """ | ||||
|     Convert IOB files with one sentence per line and tags separated with '|' | ||||
|     into JSON format for use with train cli. IOB and IOB2 are accepted. | ||||
|  | @ -20,7 +20,7 @@ def iob2json(input_data, n_sents=10, *args, **kwargs): | |||
|     I|PRP|O like|VBP|O London|NNP|I-GPE and|CC|O New|NNP|B-GPE York|NNP|I-GPE City|NNP|I-GPE .|.|O | ||||
|     I|PRP|O like|VBP|O London|NNP|B-GPE and|CC|O New|NNP|B-GPE York|NNP|I-GPE City|NNP|I-GPE .|.|O | ||||
|     """ | ||||
|     msg = Printer() | ||||
|     msg = Printer(no_print=no_print) | ||||
|     docs = read_iob(input_data.split("\n")) | ||||
|     if n_sents > 0: | ||||
|         n_sents_info(msg, n_sents) | ||||
|  |  | |||
|  | @ -360,6 +360,16 @@ def debug_data( | |||
|             ) | ||||
|         ) | ||||
| 
 | ||||
|         # check for documents with multiple sentences | ||||
|         sents_per_doc = gold_train_data["n_sents"] / len(gold_train_data["texts"]) | ||||
|         if sents_per_doc < 1.1: | ||||
|             msg.warn( | ||||
|                 "The training data contains {:.2f} sentences per " | ||||
|                 "document. When there are very few documents containing more " | ||||
|                 "than one sentence, the parser will not learn how to segment " | ||||
|                 "longer texts into sentences.".format(sents_per_doc) | ||||
|             ) | ||||
| 
 | ||||
|         # profile labels | ||||
|         labels_train = [label for label in gold_train_data["deps"]] | ||||
|         labels_train_unpreprocessed = [ | ||||
|  |  | |||
|  | @ -246,7 +246,7 @@ def make_update(model, docs, optimizer, drop=0.0, objective="L2"): | |||
|     """Perform an update over a single batch of documents. | ||||
| 
 | ||||
|     docs (iterable): A batch of `Doc` objects. | ||||
|     drop (float): The droput rate. | ||||
|     drop (float): The dropout rate. | ||||
|     optimizer (callable): An optimizer. | ||||
|     RETURNS loss: A float for the loss. | ||||
|     """ | ||||
|  |  | |||
|  | @ -80,8 +80,8 @@ class Warnings(object): | |||
|             "the v2.x models cannot release the global interpreter lock. " | ||||
|             "Future versions may introduce a `n_process` argument for " | ||||
|             "parallel inference via multiprocessing.") | ||||
|     W017 = ("Alias '{alias}' already exists in the Knowledge base.") | ||||
|     W018 = ("Entity '{entity}' already exists in the Knowledge base.") | ||||
|     W017 = ("Alias '{alias}' already exists in the Knowledge Base.") | ||||
|     W018 = ("Entity '{entity}' already exists in the Knowledge Base.") | ||||
|     W019 = ("Changing vectors name from {old} to {new}, to avoid clash with " | ||||
|             "previously loaded vectors. See Issue #3853.") | ||||
|     W020 = ("Unnamed vectors. This won't allow multiple vectors models to be " | ||||
|  | @ -95,7 +95,10 @@ class Warnings(object): | |||
|             "you can ignore this warning by setting SPACY_WARNING_IGNORE=W022. " | ||||
|             "If this is surprising, make sure you have the spacy-lookups-data " | ||||
|             "package installed.") | ||||
|     W023 = ("Multiprocessing of Language.pipe is not supported in Python2. 'n_process' will be set to 1.") | ||||
|     W023 = ("Multiprocessing of Language.pipe is not supported in Python 2. " | ||||
|             "'n_process' will be set to 1.") | ||||
|     W024 = ("Entity '{entity}' - Alias '{alias}' combination already exists in " | ||||
|             "the Knowledge Base.") | ||||
| 
 | ||||
| 
 | ||||
| @add_codes | ||||
|  | @ -408,7 +411,7 @@ class Errors(object): | |||
|             "{probabilities_length} respectively.") | ||||
|     E133 = ("The sum of prior probabilities for alias '{alias}' should not " | ||||
|             "exceed 1, but found {sum}.") | ||||
|     E134 = ("Alias '{alias}' defined for unknown entity '{entity}'.") | ||||
|     E134 = ("Entity '{entity}' is not defined in the Knowledge Base.") | ||||
|     E135 = ("If you meant to replace a built-in component, use `create_pipe`: " | ||||
|             "`nlp.replace_pipe('{name}', nlp.create_pipe('{name}'))`") | ||||
|     E136 = ("This additional feature requires the jsonschema library to be " | ||||
|  | @ -420,7 +423,7 @@ class Errors(object): | |||
|     E138 = ("Invalid JSONL format for raw text '{text}'. Make sure the input " | ||||
|             "includes either the `text` or `tokens` key. For more info, see " | ||||
|             "the docs:\nhttps://spacy.io/api/cli#pretrain-jsonl") | ||||
|     E139 = ("Knowledge base for component '{name}' not initialized. Did you " | ||||
|     E139 = ("Knowledge Base for component '{name}' not initialized. Did you " | ||||
|             "forget to call set_kb()?") | ||||
|     E140 = ("The list of entities, prior probabilities and entity vectors " | ||||
|             "should be of equal length.") | ||||
|  | @ -498,6 +501,8 @@ class Errors(object): | |||
|             "details: https://spacy.io/api/lemmatizer#init") | ||||
|     E174 = ("Architecture '{name}' not found in registry. Available " | ||||
|             "names: {names}") | ||||
|     E175 = ("Can't remove rule for unknown match pattern ID: {key}") | ||||
|     E176 = ("Alias '{alias}' is not defined in the Knowledge Base.") | ||||
| 
 | ||||
| 
 | ||||
| @add_codes | ||||
|  |  | |||
|  | @ -743,7 +743,8 @@ def docs_to_json(docs, id=0): | |||
| 
 | ||||
|     docs (iterable / Doc): The Doc object(s) to convert. | ||||
|     id (int): Id for the JSON. | ||||
|     RETURNS (list): The data in spaCy's JSON format. | ||||
|     RETURNS (dict): The data in spaCy's JSON format. Each input Doc is | ||||
|         treated as a paragraph in the output JSON. | ||||
|     """ | ||||
|     if isinstance(docs, Doc): | ||||
|         docs = [docs] | ||||
|  |  | |||
							
								
								
									
spacy/kb.pyx (69 lines changed)
									
									
									
									
									
								
							|  | @ -142,6 +142,7 @@ cdef class KnowledgeBase: | |||
| 
 | ||||
|         i = 0 | ||||
|         cdef KBEntryC entry | ||||
|         cdef hash_t entity_hash | ||||
|         while i < nr_entities: | ||||
|             entity_vector = vector_list[i] | ||||
|             if len(entity_vector) != self.entity_vector_length: | ||||
|  | @ -161,6 +162,14 @@ cdef class KnowledgeBase: | |||
| 
 | ||||
|             i += 1 | ||||
| 
 | ||||
|     def contains_entity(self, unicode entity): | ||||
|         cdef hash_t entity_hash = self.vocab.strings.add(entity) | ||||
|         return entity_hash in self._entry_index | ||||
| 
 | ||||
|     def contains_alias(self, unicode alias): | ||||
|         cdef hash_t alias_hash = self.vocab.strings.add(alias) | ||||
|         return alias_hash in self._alias_index | ||||
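A short sketch of how the new membership checks could be used. The KnowledgeBase constructor and the add_entity/add_alias calls below follow the v2.2 API as an assumption, with toy entity IDs and vectors:

    from spacy.kb import KnowledgeBase
    from spacy.vocab import Vocab

    kb = KnowledgeBase(vocab=Vocab(), entity_vector_length=3)
    kb.add_entity(entity="Q42", freq=12, entity_vector=[1.0, 0.0, 0.0])
    kb.add_alias(alias="Douglas", entities=["Q42"], probabilities=[0.8])

    assert kb.contains_entity("Q42")
    assert kb.contains_alias("Douglas")
    assert not kb.contains_entity("Q1")  # never added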
| 
 | ||||
|     def add_alias(self, unicode alias, entities, probabilities): | ||||
|         """ | ||||
|         For a given alias, add its potential entities and prior probabilities to the KB. | ||||
|  | @ -190,7 +199,7 @@ cdef class KnowledgeBase: | |||
|         for entity, prob in zip(entities, probabilities): | ||||
|             entity_hash = self.vocab.strings[entity] | ||||
|             if not entity_hash in self._entry_index: | ||||
|                 raise ValueError(Errors.E134.format(alias=alias, entity=entity)) | ||||
|                 raise ValueError(Errors.E134.format(entity=entity)) | ||||
| 
 | ||||
|             entry_index = <int64_t>self._entry_index.get(entity_hash) | ||||
|             entry_indices.push_back(int(entry_index)) | ||||
|  | @ -201,8 +210,63 @@ cdef class KnowledgeBase: | |||
| 
 | ||||
|         return alias_hash | ||||
| 
 | ||||
|     def get_candidates(self, unicode alias): | ||||
|     def append_alias(self, unicode alias, unicode entity, float prior_prob, ignore_warnings=False): | ||||
|         """ | ||||
|         For an alias already existing in the KB, extend its potential entities with one more. | ||||
|         Raise an error if the alias or the entity is unknown to the KB, | ||||
|         or if the new prior probability would push the sum for this alias over 1. | ||||
|         Throw a warning if the alias-entity combination has already been recorded. | ||||
|         For efficiency, it's best to use the method `add_alias` as much as possible instead of this one. | ||||
|         """ | ||||
|         # Check if the alias exists in the KB | ||||
|         cdef hash_t alias_hash = self.vocab.strings[alias] | ||||
|         if not alias_hash in self._alias_index: | ||||
|             raise ValueError(Errors.E176.format(alias=alias)) | ||||
| 
 | ||||
|         # Check if the entity exists in the KB | ||||
|         cdef hash_t entity_hash = self.vocab.strings[entity] | ||||
|         if not entity_hash in self._entry_index: | ||||
|             raise ValueError(Errors.E134.format(entity=entity)) | ||||
|         entry_index = <int64_t>self._entry_index.get(entity_hash) | ||||
| 
 | ||||
|         # Throw an error if the prior probabilities (including the new one) sum up to more than 1 | ||||
|         alias_index = <int64_t>self._alias_index.get(alias_hash) | ||||
|         alias_entry = self._aliases_table[alias_index] | ||||
|         current_sum = sum([p for p in alias_entry.probs]) | ||||
|         new_sum = current_sum + prior_prob | ||||
| 
 | ||||
|         if new_sum > 1.00001: | ||||
|             raise ValueError(Errors.E133.format(alias=alias, sum=new_sum)) | ||||
| 
 | ||||
|         entry_indices = alias_entry.entry_indices | ||||
| 
 | ||||
|         is_present = False | ||||
|         for i in range(entry_indices.size()): | ||||
|             if entry_indices[i] == int(entry_index): | ||||
|                 is_present = True | ||||
| 
 | ||||
|         if is_present: | ||||
|             if not ignore_warnings: | ||||
|                 user_warning(Warnings.W024.format(entity=entity, alias=alias)) | ||||
|         else: | ||||
|             entry_indices.push_back(int(entry_index)) | ||||
|             alias_entry.entry_indices = entry_indices | ||||
| 
 | ||||
|             probs = alias_entry.probs | ||||
|             probs.push_back(float(prior_prob)) | ||||
|             alias_entry.probs = probs | ||||
|             self._aliases_table[alias_index] = alias_entry | ||||
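Continuing the toy KB sketched earlier, `append_alias` could then extend an existing alias with one more candidate; a duplicate alias-entity pair only triggers W024, and a prior-probability sum above 1 raises E133. Entity IDs and probabilities here are illustrative:

    kb.add_entity(entity="Q7251", freq=5, entity_vector=[0.0, 1.0, 0.0])
    kb.append_alias(alias="Douglas", entity="Q7251", prior_prob=0.1)   # "Douglas" now has two candidates
    kb.append_alias(alias="Douglas", entity="Q7251", prior_prob=0.05)  # duplicate pair: warns (W024), nothing added
    # kb.append_alias(alias="Douglas", entity="Q7251", prior_prob=0.5) would raise E133 (sum > 1)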
| 
 | ||||
| 
 | ||||
|     def get_candidates(self, unicode alias): | ||||
|         """ | ||||
|         Return candidate entities for an alias. Each candidate defines the entity, the original alias, | ||||
|         and the prior probability of that alias resolving to that entity. | ||||
|         If the alias is not known in the KB, an empty list is returned. | ||||
|         """ | ||||
|         cdef hash_t alias_hash = self.vocab.strings[alias] | ||||
|         if not alias_hash in self._alias_index: | ||||
|             return [] | ||||
|         alias_index = <int64_t>self._alias_index.get(alias_hash) | ||||
|         alias_entry = self._aliases_table[alias_index] | ||||
| 
 | ||||
|  | @ -341,7 +405,6 @@ cdef class KnowledgeBase: | |||
|         assert nr_entities == self.get_size_entities() | ||||
| 
 | ||||
|         # STEP 3: load aliases | ||||
| 
 | ||||
|         cdef int64_t nr_aliases | ||||
|         reader.read_alias_length(&nr_aliases) | ||||
|         self._alias_index = PreshMap(nr_aliases+1) | ||||
|  |  | |||
							
								
								
									
spacy/lang/lb/__init__.py (34 lines, new file)
									
								
							|  | @ -0,0 +1,34 @@ | |||
| # coding: utf8 | ||||
| from __future__ import unicode_literals | ||||
| 
 | ||||
| from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS | ||||
| from .norm_exceptions import NORM_EXCEPTIONS | ||||
| from .lex_attrs import LEX_ATTRS | ||||
| from .tag_map import TAG_MAP | ||||
| from .stop_words import STOP_WORDS | ||||
| 
 | ||||
| from ..tokenizer_exceptions import BASE_EXCEPTIONS | ||||
| from ..norm_exceptions import BASE_NORMS | ||||
| from ...language import Language | ||||
| from ...attrs import LANG, NORM | ||||
| from ...util import update_exc, add_lookups | ||||
| 
 | ||||
| 
 | ||||
| class LuxembourgishDefaults(Language.Defaults): | ||||
|     lex_attr_getters = dict(Language.Defaults.lex_attr_getters) | ||||
|     lex_attr_getters.update(LEX_ATTRS) | ||||
|     lex_attr_getters[LANG] = lambda text: "lb" | ||||
|     lex_attr_getters[NORM] = add_lookups( | ||||
|         Language.Defaults.lex_attr_getters[NORM], NORM_EXCEPTIONS, BASE_NORMS | ||||
|     ) | ||||
|     tokenizer_exceptions = update_exc(BASE_EXCEPTIONS, TOKENIZER_EXCEPTIONS) | ||||
|     stop_words = STOP_WORDS | ||||
|     tag_map = TAG_MAP | ||||
| 
 | ||||
| 
 | ||||
| class Luxembourgish(Language): | ||||
|     lang = "lb" | ||||
|     Defaults = LuxembourgishDefaults | ||||
| 
 | ||||
| 
 | ||||
| __all__ = ["Luxembourgish"] | ||||
							
								
								
									
spacy/lang/lb/examples.py (18 lines, new file)
									
								
							|  | @ -0,0 +1,18 @@ | |||
| # coding: utf8 | ||||
| from __future__ import unicode_literals | ||||
| 
 | ||||
| """ | ||||
| Example sentences to test spaCy and its language models. | ||||
| 
 | ||||
| >>> from spacy.lang.lb.examples import sentences | ||||
| >>> docs = nlp.pipe(sentences) | ||||
| """ | ||||
| 
 | ||||
| sentences = [ | ||||
|     "An der Zäit hunn sech den Nordwand an d’Sonn gestridden, wie vun hinnen zwee wuel méi staark wier, wéi e Wanderer, deen an ee waarme Mantel agepak war, iwwert de Wee koum.", | ||||
|     "Si goufen sech eens, dass deejéinege fir de Stäerkste gëlle sollt, deen de Wanderer forcéiere géif, säi Mantel auszedoen.", | ||||
|     "Den Nordwand huet mat aller Force geblosen, awer wat e méi geblosen huet, wat de Wanderer sech méi a säi Mantel agewéckelt huet.", | ||||
|     "Um Enn huet den Nordwand säi Kampf opginn.", | ||||
|     "Dunn huet d’Sonn d’Loft mat hire frëndleche Strale gewiermt, a schonn no kuerzer Zäit huet de Wanderer säi Mantel ausgedoen.", | ||||
|     "Do huet den Nordwand missen zouginn, dass d’Sonn vun hinnen zwee de Stäerkste wier.", | ||||
| ] | ||||
							
								
								
									
spacy/lang/lb/lex_attrs.py (44 lines, new file)
									
								
							|  | @ -0,0 +1,44 @@ | |||
| # coding: utf8 | ||||
| from __future__ import unicode_literals | ||||
| 
 | ||||
| from ...attrs import LIKE_NUM | ||||
| 
 | ||||
| 
 | ||||
| _num_words = set( | ||||
|     """ | ||||
| null eent zwee dräi véier fënnef sechs ziwen aacht néng zéng eelef zwielef dräizéng | ||||
| véierzéng foffzéng siechzéng siwwenzéng uechtzeng uechzeng nonnzéng nongzéng zwanzeg drësseg véierzeg foffzeg sechzeg siechzeg siwenzeg achtzeg achzeg uechtzeg uechzeg nonnzeg | ||||
| honnert dausend millioun milliard billioun billiard trillioun triliard | ||||
| """.split() | ||||
| ) | ||||
| 
 | ||||
| _ordinal_words = set( | ||||
|     """ | ||||
| éischten zweeten drëtten véierten fënneften sechsten siwenten aachten néngten zéngten eeleften | ||||
| zwieleften dräizéngten véierzéngten foffzéngten siechzéngten uechtzéngen uechzéngten nonnzéngten nongzéngten zwanzegsten | ||||
| drëssegsten véierzegsten foffzegsten siechzegsten siwenzegsten uechzegsten nonnzegsten | ||||
| honnertsten dausendsten milliounsten | ||||
| milliardsten billiounsten billiardsten trilliounsten trilliardsten | ||||
| """.split() | ||||
| ) | ||||
| 
 | ||||
| 
 | ||||
| def like_num(text): | ||||
|     """ | ||||
|     check if text resembles a number | ||||
|     """ | ||||
|     text = text.replace(",", "").replace(".", "") | ||||
|     if text.isdigit(): | ||||
|         return True | ||||
|     if text.count("/") == 1: | ||||
|         num, denom = text.split("/") | ||||
|         if num.isdigit() and denom.isdigit(): | ||||
|             return True | ||||
|     if text in _num_words: | ||||
|         return True | ||||
|     if text in _ordinal_words: | ||||
|         return True | ||||
|     return False | ||||
| 
 | ||||
| 
 | ||||
| LEX_ATTRS = {LIKE_NUM: like_num} | ||||
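A few illustrative checks against the heuristic as written above (matching is case-sensitive here, so only the lower-case forms from the two word sets are recognised):

    from spacy.lang.lb.lex_attrs import like_num

    assert like_num("véier")      # cardinal number word
    assert like_num("drëtten")    # ordinal number word
    assert like_num("1.000")      # "." and "," are stripped before the digit check
    assert like_num("3/4")        # simple fractions are accepted
    assert not like_num("Mantel")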
							
								
								
									
spacy/lang/lb/norm_exceptions.py (16 lines, new file)
									
								
							|  | @ -0,0 +1,16 @@ | |||
| # coding: utf8 | ||||
| from __future__ import unicode_literals | ||||
| 
 | ||||
| # TODO | ||||
| # norm exceptions: find a way to deal with the zillions of spelling | ||||
| # variants (vläicht = vlaicht, vleicht, viläicht, viläischt, etc. etc.) | ||||
| # here one could include the most common spelling mistakes | ||||
| 
 | ||||
| _exc = {"datt": "dass", "wgl.": "weg.", "vläicht": "viläicht"} | ||||
| 
 | ||||
| 
 | ||||
| NORM_EXCEPTIONS = {} | ||||
| 
 | ||||
| for string, norm in _exc.items(): | ||||
|     NORM_EXCEPTIONS[string] = norm | ||||
|     NORM_EXCEPTIONS[string.title()] = norm | ||||
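As a small check of the loop above, both the lower-case form and its title-cased variant end up in the table, which `__init__.py` then merges into the NORM getter via `add_lookups`:

    from spacy.lang.lb.norm_exceptions import NORM_EXCEPTIONS

    assert NORM_EXCEPTIONS["datt"] == "dass"
    assert NORM_EXCEPTIONS["Datt"] == "dass"         # added by the .title() pass
    assert NORM_EXCEPTIONS["vläicht"] == "viläicht"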
							
								
								
									
spacy/lang/lb/stop_words.py (214 lines, new file)
									
								
							|  | @ -0,0 +1,214 @@ | |||
| # coding: utf8 | ||||
| from __future__ import unicode_literals | ||||
| 
 | ||||
| STOP_WORDS = set( | ||||
|     """ | ||||
| a | ||||
| à | ||||
| äis | ||||
| är | ||||
| ärt | ||||
| äert | ||||
| ären | ||||
| all | ||||
| allem | ||||
| alles | ||||
| alleguer | ||||
| als | ||||
| also | ||||
| am | ||||
| an | ||||
| anerefalls | ||||
| ass | ||||
| aus | ||||
| awer | ||||
| bei | ||||
| beim | ||||
| bis | ||||
| bis | ||||
| d' | ||||
| dach | ||||
| datt | ||||
| däin | ||||
| där | ||||
| dat | ||||
| de | ||||
| dee | ||||
| den | ||||
| deel | ||||
| deem | ||||
| deen | ||||
| deene | ||||
| déi | ||||
| den | ||||
| deng | ||||
| denger | ||||
| dem | ||||
| der | ||||
| dësem | ||||
| di | ||||
| dir | ||||
| do | ||||
| da | ||||
| dann | ||||
| domat | ||||
| dozou | ||||
| drop | ||||
| du | ||||
| duerch | ||||
| duerno | ||||
| e | ||||
| ee | ||||
| em | ||||
| een | ||||
| eent | ||||
| ë | ||||
| en | ||||
| ënner | ||||
| ëm | ||||
| ech | ||||
| eis | ||||
| eise | ||||
| eisen | ||||
| eiser | ||||
| eises | ||||
| eisereen | ||||
| esou | ||||
| een | ||||
| eng | ||||
| enger | ||||
| engem | ||||
| entweder | ||||
| et | ||||
| eréischt | ||||
| falls | ||||
| fir | ||||
| géint | ||||
| géif | ||||
| gëtt | ||||
| gët | ||||
| geet | ||||
| gi | ||||
| ginn | ||||
| gouf | ||||
| gouff | ||||
| goung | ||||
| hat | ||||
| haten | ||||
| hatt | ||||
| hätt | ||||
| hei | ||||
| hu | ||||
| huet | ||||
| hun | ||||
| hunn | ||||
| hiren | ||||
| hien | ||||
| hin | ||||
| hier | ||||
| hir | ||||
| jidderen | ||||
| jiddereen | ||||
| jiddwereen | ||||
| jiddereng | ||||
| jiddwerengen | ||||
| jo | ||||
| ins | ||||
| iech | ||||
| iwwer | ||||
| kann | ||||
| kee | ||||
| keen | ||||
| kënne | ||||
| kënnt | ||||
| kéng | ||||
| kéngen | ||||
| kéngem | ||||
| koum | ||||
| kuckt | ||||
| mam | ||||
| mat | ||||
| ma | ||||
| mä | ||||
| mech | ||||
| méi | ||||
| mécht | ||||
| meng | ||||
| menger | ||||
| mer | ||||
| mir | ||||
| muss | ||||
| nach | ||||
| nämmlech | ||||
| nämmelech | ||||
| näischt | ||||
| nawell | ||||
| nëmme | ||||
| nëmmen | ||||
| net | ||||
| nees | ||||
| nee | ||||
| no | ||||
| nu | ||||
| nom | ||||
| och | ||||
| oder | ||||
| ons | ||||
| onsen | ||||
| onser | ||||
| onsereen | ||||
| onst | ||||
| om | ||||
| op | ||||
| ouni | ||||
| säi | ||||
| säin | ||||
| schonn | ||||
| schonns | ||||
| si | ||||
| sid | ||||
| sie | ||||
| se | ||||
| sech | ||||
| seng | ||||
| senge | ||||
| sengem | ||||
| senger | ||||
| selwecht | ||||
| selwer | ||||
| sinn | ||||
| sollten | ||||
| souguer | ||||
| sou | ||||
| soss | ||||
| sot | ||||
| 't | ||||
| tëscht | ||||
| u | ||||
| un | ||||
| um | ||||
| virdrun | ||||
| vu | ||||
| vum | ||||
| vun | ||||
| wann | ||||
| war | ||||
| waren | ||||
| was | ||||
| wat | ||||
| wëllt | ||||
| weider | ||||
| wéi | ||||
| wéini | ||||
| wéinst | ||||
| wi | ||||
| wollt | ||||
| wou | ||||
| wouhin | ||||
| zanter | ||||
| ze | ||||
| zu | ||||
| zum | ||||
| zwar | ||||
| """.split() | ||||
| ) | ||||
							
								
								
									
spacy/lang/lb/tag_map.py (28 lines, new file)
									
								
							|  | @ -0,0 +1,28 @@ | |||
| # coding: utf8 | ||||
| from __future__ import unicode_literals | ||||
| 
 | ||||
| from ...symbols import POS, PUNCT, ADJ, CONJ, NUM, DET, ADV, ADP, X, VERB | ||||
| from ...symbols import NOUN, PART, SPACE, AUX | ||||
| 
 | ||||
| # TODO: tag map is still using POS tags from an internal training set. | ||||
| # These POS tags have to be modified to match those from Universal Dependencies | ||||
| 
 | ||||
| TAG_MAP = { | ||||
|     "$": {POS: PUNCT}, | ||||
|     "ADJ": {POS: ADJ}, | ||||
|     "AV": {POS: ADV}, | ||||
|     "APPR": {POS: ADP, "AdpType": "prep"}, | ||||
|     "APPRART": {POS: ADP, "AdpType": "prep", "PronType": "art"}, | ||||
|     "D": {POS: DET, "PronType": "art"}, | ||||
|     "KO": {POS: CONJ}, | ||||
|     "N": {POS: NOUN}, | ||||
|     "P": {POS: ADV}, | ||||
|     "TRUNC": {POS: X, "Hyph": "yes"}, | ||||
|     "AUX": {POS: AUX}, | ||||
|     "V": {POS: VERB}, | ||||
|     "MV": {POS: VERB, "VerbType": "mod"}, | ||||
|     "PTK": {POS: PART}, | ||||
|     "INTER": {POS: PART}, | ||||
|     "NUM": {POS: NUM}, | ||||
|     "_SP": {POS: SPACE}, | ||||
| } | ||||
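For reference, a minimal sketch of how the custom tags above resolve to coarse-grained POS values (this mapping may shift once the tags are aligned with Universal Dependencies, as the TODO notes):

    from spacy.symbols import POS, VERB
    from spacy.lang.lb.tag_map import TAG_MAP

    assert TAG_MAP["MV"][POS] == VERB            # modal verbs map to VERB with VerbType=mod
    assert TAG_MAP["APPR"]["AdpType"] == "prep"  # prepositions carry an AdpType feature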
							
								
								
									
spacy/lang/lb/tokenizer_exceptions.py (69 lines, new file)
									
								
							|  | @ -0,0 +1,69 @@ | |||
| # coding: utf8 | ||||
| from __future__ import unicode_literals | ||||
| 
 | ||||
| from ...symbols import ORTH, LEMMA, NORM | ||||
| from ..punctuation import TOKENIZER_PREFIXES | ||||
| 
 | ||||
| # TODO | ||||
| # tokenize the cliticised definite article "d'" as a token of its own: d'Kanner > [d'] [Kanner] | ||||
| # treat other apostrophes within words as part of the word: [op d'mannst], [fir d'éischt] (= exceptions) | ||||
| 
 | ||||
| # how to write the tokenisation exception for the articles d' / D'? This one is not working yet. | ||||
| _prefixes = [ | ||||
|     prefix for prefix in TOKENIZER_PREFIXES if prefix not in ["d'", "D'", "d’", "D’"] | ||||
| ] | ||||
| 
 | ||||
| 
 | ||||
| _exc = { | ||||
|     "d'mannst": [ | ||||
|         {ORTH: "d'", LEMMA: "d'"}, | ||||
|         {ORTH: "mannst", LEMMA: "mann", NORM: "mann"}, | ||||
|     ], | ||||
|     "d'éischt": [ | ||||
|         {ORTH: "d'", LEMMA: "d'"}, | ||||
|         {ORTH: "éischt", LEMMA: "éischt", NORM: "éischt"}, | ||||
|     ], | ||||
| } | ||||
| 
 | ||||
| # translate / delete what is not necessary | ||||
| # what does PRON_LEMMA mean? | ||||
| for exc_data in [ | ||||
|     {ORTH: "wgl.", LEMMA: "wann ech gelift", NORM: "wann ech gelieft"}, | ||||
|     {ORTH: "M.", LEMMA: "Monsieur", NORM: "Monsieur"}, | ||||
|     {ORTH: "Mme.", LEMMA: "Madame", NORM: "Madame"}, | ||||
|     {ORTH: "Dr.", LEMMA: "Dokter", NORM: "Dokter"}, | ||||
|     {ORTH: "Tel.", LEMMA: "Telefon", NORM: "Telefon"}, | ||||
|     {ORTH: "asw.", LEMMA: "an sou weider", NORM: "an sou weider"}, | ||||
|     {ORTH: "etc.", LEMMA: "et cetera", NORM: "et cetera"}, | ||||
|     {ORTH: "bzw.", LEMMA: "bezéiungsweis", NORM: "bezéiungsweis"}, | ||||
|     {ORTH: "Jan.", LEMMA: "Januar", NORM: "Januar"}, | ||||
| ]: | ||||
|     _exc[exc_data[ORTH]] = [exc_data] | ||||
| 
 | ||||
| 
 | ||||
| # to be extended | ||||
| for orth in [ | ||||
|     "z.B.", | ||||
|     "Dipl.", | ||||
|     "Dr.", | ||||
|     "etc.", | ||||
|     "i.e.", | ||||
|     "o.k.", | ||||
|     "O.K.", | ||||
|     "p.a.", | ||||
|     "p.s.", | ||||
|     "P.S.", | ||||
|     "phil.", | ||||
|     "q.e.d.", | ||||
|     "R.I.P.", | ||||
|     "rer.", | ||||
|     "sen.", | ||||
|     "ë.a.", | ||||
|     "U.S.", | ||||
|     "U.S.A.", | ||||
| ]: | ||||
|     _exc[orth] = [{ORTH: orth}] | ||||
| 
 | ||||
| 
 | ||||
| TOKENIZER_PREFIXES = _prefixes | ||||
| TOKENIZER_EXCEPTIONS = _exc | ||||
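A short sketch of the intended effect, reusing the Luxembourgish class from above: listed abbreviations stay single tokens, while the clitic article handling (d'Kanner > [d'] [Kanner]) remains the open TODO noted in this file:

    from spacy.lang.lb import Luxembourgish

    nlp = Luxembourgish()
    assert len(nlp("z.B.")) == 1   # abbreviation kept as one token
    assert len(nlp("wgl.")) == 1   # exception keeps "wgl." intact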
|  | @ -1,10 +1,8 @@ | |||
| # coding: utf8 | ||||
| from __future__ import absolute_import, unicode_literals | ||||
| 
 | ||||
| import atexit | ||||
| import random | ||||
| import itertools | ||||
| from warnings import warn | ||||
| from spacy.util import minibatch | ||||
| import weakref | ||||
| import functools | ||||
|  | @ -483,7 +481,7 @@ class Language(object): | |||
| 
 | ||||
|         docs (iterable): A batch of `Doc` objects. | ||||
|         golds (iterable): A batch of `GoldParse` objects. | ||||
|         drop (float): The droput rate. | ||||
|         drop (float): The dropout rate. | ||||
|         sgd (callable): An optimizer. | ||||
|         losses (dict): Dictionary to update with the loss, keyed by component. | ||||
|         component_cfg (dict): Config parameters for specific pipeline | ||||
|  | @ -531,7 +529,7 @@ class Language(object): | |||
|         even if you're updating it with a smaller set of examples. | ||||
| 
 | ||||
|         docs (iterable): A batch of `Doc` objects. | ||||
|         drop (float): The droput rate. | ||||
|         drop (float): The dropout rate. | ||||
|         sgd (callable): An optimizer. | ||||
|         RETURNS (dict): Results from the update. | ||||
| 
 | ||||
|  | @ -753,7 +751,8 @@ class Language(object): | |||
|             use. Experimental. | ||||
|         component_cfg (dict): An optional dictionary with extra keyword | ||||
|             arguments for specific components. | ||||
|         n_process (int): Number of processors to process texts, only supported in Python3. If -1, set `multiprocessing.cpu_count()`. | ||||
|         n_process (int): Number of processors to process texts, only supported | ||||
|             in Python3. If -1, set `multiprocessing.cpu_count()`. | ||||
|         YIELDS (Doc): Documents in the order of the original text. | ||||
| 
 | ||||
|         DOCS: https://spacy.io/api/language#pipe | ||||
|  | @ -1069,9 +1068,10 @@ def _pipe(docs, proc, kwargs): | |||
| def _apply_pipes(make_doc, pipes, reciever, sender): | ||||
|     """Worker for Language.pipe | ||||
| 
 | ||||
|     Args: | ||||
|         receiver (multiprocessing.Connection): Pipe to receive text. Usually created by `multiprocessing.Pipe()` | ||||
|         sender (multiprocessing.Connection): Pipe to send doc. Usually created by `multiprocessing.Pipe()` | ||||
|     receiver (multiprocessing.Connection): Pipe to receive text. Usually | ||||
|         created by `multiprocessing.Pipe()` | ||||
|     sender (multiprocessing.Connection): Pipe to send doc. Usually created by | ||||
|         `multiprocessing.Pipe()` | ||||
|     """ | ||||
|     while True: | ||||
|         texts = reciever.get() | ||||
|  | @ -1100,7 +1100,7 @@ class _Sender: | |||
|             q.put(item) | ||||
| 
 | ||||
|     def step(self): | ||||
|         """Tell sender that comsumed one item.  | ||||
|         """Tell sender that comsumed one item. | ||||
| 
 | ||||
|         Data is sent to the workers after every chunk_size calls.""" | ||||
|         self.count += 1 | ||||
|  |  | |||
|  | @ -133,13 +133,15 @@ cdef class Matcher: | |||
| 
 | ||||
|         key (unicode): The ID of the match rule. | ||||
|         """ | ||||
|         key = self._normalize_key(key) | ||||
|         self._patterns.pop(key) | ||||
|         self._callbacks.pop(key) | ||||
|         norm_key = self._normalize_key(key) | ||||
|         if not norm_key in self._patterns: | ||||
|             raise ValueError(Errors.E175.format(key=key)) | ||||
|         self._patterns.pop(norm_key) | ||||
|         self._callbacks.pop(norm_key) | ||||
|         cdef int i = 0 | ||||
|         while i < self.patterns.size(): | ||||
|             pattern_key = get_pattern_key(self.patterns.at(i)) | ||||
|             if pattern_key == key: | ||||
|             pattern_key = get_ent_id(self.patterns.at(i)) | ||||
|             if pattern_key == norm_key: | ||||
|                 self.patterns.erase(self.patterns.begin()+i) | ||||
|             else: | ||||
|                 i += 1 | ||||
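A minimal sketch of the new behaviour: removing a key that is not in the matcher now raises E175 up front instead of failing inside the pop calls (the pattern used is illustrative):

    import spacy
    from spacy.matcher import Matcher

    nlp = spacy.blank("en")
    matcher = Matcher(nlp.vocab)
    matcher.add("HELLO", None, [{"LOWER": "hello"}])
    matcher.remove("HELLO")      # fine, the rule exists
    try:
        matcher.remove("HELLO")  # already removed -> unknown key
    except ValueError as err:
        print(err)               # E175: Can't remove rule for unknown match pattern ID: HELLO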
|  | @ -293,18 +295,6 @@ cdef find_matches(TokenPatternC** patterns, int n, Doc doc, extensions=None, | |||
|     return output | ||||
| 
 | ||||
| 
 | ||||
| cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil: | ||||
|     # There have been a few bugs here. | ||||
|     # The code was originally designed to always have pattern[1].attrs.value | ||||
|     # be the ent_id when we get to the end of a pattern. However, Issue #2671 | ||||
|     # showed this wasn't the case when we had a reject-and-continue before a | ||||
|     # match. | ||||
|     # The patch to #2671 was wrong though, which came up in #3839. | ||||
|     while pattern.attrs.attr != ID: | ||||
|         pattern += 1 | ||||
|     return pattern.attrs.value | ||||
| 
 | ||||
| 
 | ||||
| cdef void transition_states(vector[PatternStateC]& states, vector[MatchC]& matches, | ||||
|                             char* cached_py_predicates, | ||||
|         Token token, const attr_t* extra_attrs, py_predicates) except *: | ||||
|  | @ -533,9 +523,10 @@ cdef char get_is_match(PatternStateC state, | |||
|         if predicate_matches[state.pattern.py_predicates[i]] == -1: | ||||
|             return 0 | ||||
|     spec = state.pattern | ||||
|     for attr in spec.attrs[:spec.nr_attr]: | ||||
|         if get_token_attr(token, attr.attr) != attr.value: | ||||
|             return 0 | ||||
|     if spec.nr_attr > 0: | ||||
|         for attr in spec.attrs[:spec.nr_attr]: | ||||
|             if get_token_attr(token, attr.attr) != attr.value: | ||||
|                 return 0 | ||||
|     for i in range(spec.nr_extra_attr): | ||||
|         if spec.extra_attrs[i].value != extra_attrs[spec.extra_attrs[i].index]: | ||||
|             return 0 | ||||
|  | @ -543,7 +534,11 @@ cdef char get_is_match(PatternStateC state, | |||
| 
 | ||||
| 
 | ||||
| cdef char get_is_final(PatternStateC state) nogil: | ||||
|     if state.pattern[1].attrs[0].attr == ID and state.pattern[1].nr_attr == 0: | ||||
|     if state.pattern[1].nr_attr == 0 and state.pattern[1].attrs != NULL: | ||||
|         id_attr = state.pattern[1].attrs[0] | ||||
|         if id_attr.attr != ID: | ||||
|             with gil: | ||||
|                 raise ValueError(Errors.E074.format(attr=ID, bad_attr=id_attr.attr)) | ||||
|         return 1 | ||||
|     else: | ||||
|         return 0 | ||||
|  | @ -558,7 +553,9 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs) | |||
|     cdef int i, index | ||||
|     for i, (quantifier, spec, extensions, predicates) in enumerate(token_specs): | ||||
|         pattern[i].quantifier = quantifier | ||||
|         pattern[i].attrs = <AttrValueC*>mem.alloc(len(spec), sizeof(AttrValueC)) | ||||
|         # Ensure attrs refers to a null pointer if nr_attr == 0 | ||||
|         if len(spec) > 0: | ||||
|             pattern[i].attrs = <AttrValueC*>mem.alloc(len(spec), sizeof(AttrValueC)) | ||||
|         pattern[i].nr_attr = len(spec) | ||||
|         for j, (attr, value) in enumerate(spec): | ||||
|             pattern[i].attrs[j].attr = attr | ||||
|  | @ -574,6 +571,7 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs) | |||
|         pattern[i].nr_py = len(predicates) | ||||
|         pattern[i].key = hash64(pattern[i].attrs, pattern[i].nr_attr * sizeof(AttrValueC), 0) | ||||
|     i = len(token_specs) | ||||
|     # Even though here, nr_attr == 0, we're storing the ID value in attrs[0] (bug-prone, tread carefully!) | ||||
|     pattern[i].attrs = <AttrValueC*>mem.alloc(2, sizeof(AttrValueC)) | ||||
|     pattern[i].attrs[0].attr = ID | ||||
|     pattern[i].attrs[0].value = entity_id | ||||
|  | @ -583,8 +581,26 @@ cdef TokenPatternC* init_pattern(Pool mem, attr_t entity_id, object token_specs) | |||
|     return pattern | ||||
| 
 | ||||
| 
 | ||||
| cdef attr_t get_pattern_key(const TokenPatternC* pattern) nogil: | ||||
|     while pattern.nr_attr != 0 or pattern.nr_extra_attr != 0 or pattern.nr_py != 0: | ||||
| cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil: | ||||
|     # There have been a few bugs here. We used to have two functions, | ||||
|     # get_ent_id and get_pattern_key that tried to do the same thing. These | ||||
|     # are now unified to try to solve the "ghost match" problem. | ||||
|     # Below is the previous implementation of get_ent_id and the comment on it, | ||||
|     # preserved for reference while we figure out whether the heisenbug in the | ||||
|     # matcher is resolved. | ||||
|     # | ||||
|     # | ||||
|     #     cdef attr_t get_ent_id(const TokenPatternC* pattern) nogil: | ||||
|     #         # The code was originally designed to always have pattern[1].attrs.value | ||||
|     #         # be the ent_id when we get to the end of a pattern. However, Issue #2671 | ||||
|     #         # showed this wasn't the case when we had a reject-and-continue before a | ||||
|     #         # match. | ||||
|     #         # The patch to #2671 was wrong though, which came up in #3839. | ||||
|     #         while pattern.attrs.attr != ID: | ||||
|     #             pattern += 1 | ||||
|     #         return pattern.attrs.value | ||||
|     while pattern.nr_attr != 0 or pattern.nr_extra_attr != 0 or pattern.nr_py != 0 \ | ||||
|             or pattern.quantifier != ZERO: | ||||
|         pattern += 1 | ||||
|     id_attr = pattern[0].attrs[0] | ||||
|     if id_attr.attr != ID: | ||||
|  | @ -642,7 +658,7 @@ def _get_attr_values(spec, string_store): | |||
|             value = string_store.add(value) | ||||
|         elif isinstance(value, bool): | ||||
|             value = int(value) | ||||
|         elif isinstance(value, dict): | ||||
|         elif isinstance(value, (dict, int)): | ||||
|             continue | ||||
|         else: | ||||
|             raise ValueError(Errors.E153.format(vtype=type(value).__name__)) | ||||
|  |  | |||
|  | @ -4,6 +4,7 @@ from cymem.cymem cimport Pool | |||
| from preshed.maps cimport key_t, MapStruct | ||||
| 
 | ||||
| from ..attrs cimport attr_id_t | ||||
| from ..structs cimport SpanC | ||||
| from ..tokens.doc cimport Doc | ||||
| from ..vocab cimport Vocab | ||||
| 
 | ||||
|  | @ -18,10 +19,4 @@ cdef class PhraseMatcher: | |||
|     cdef Pool mem | ||||
|     cdef key_t _terminal_hash | ||||
| 
 | ||||
|     cdef void find_matches(self, Doc doc, vector[MatchStruct] *matches) nogil | ||||
| 
 | ||||
| 
 | ||||
| cdef struct MatchStruct: | ||||
|     key_t match_id | ||||
|     int start | ||||
|     int end | ||||
|     cdef void find_matches(self, Doc doc, vector[SpanC] *matches) nogil | ||||
|  |  | |||
|  | @ -9,6 +9,7 @@ from preshed.maps cimport map_init, map_set, map_get, map_clear, map_iter | |||
| from ..attrs cimport ORTH, POS, TAG, DEP, LEMMA | ||||
| from ..structs cimport TokenC | ||||
| from ..tokens.token cimport Token | ||||
| from ..typedefs cimport attr_t | ||||
| 
 | ||||
| from ._schemas import TOKEN_PATTERN_SCHEMA | ||||
| from ..errors import Errors, Warnings, deprecation_warning, user_warning | ||||
|  | @ -102,8 +103,10 @@ cdef class PhraseMatcher: | |||
|         cdef vector[MapStruct*] path_nodes | ||||
|         cdef vector[key_t] path_keys | ||||
|         cdef key_t key_to_remove | ||||
|         for keyword in self._docs[key]: | ||||
|         for keyword in sorted(self._docs[key], key=lambda x: len(x), reverse=True): | ||||
|             current_node = self.c_map | ||||
|             path_nodes.clear() | ||||
|             path_keys.clear() | ||||
|             for token in keyword: | ||||
|                 result = map_get(current_node, token) | ||||
|                 if result: | ||||
|  | @ -220,17 +223,17 @@ cdef class PhraseMatcher: | |||
|             # if doc is empty or None just return empty list | ||||
|             return matches | ||||
| 
 | ||||
|         cdef vector[MatchStruct] c_matches | ||||
|         cdef vector[SpanC] c_matches | ||||
|         self.find_matches(doc, &c_matches) | ||||
|         for i in range(c_matches.size()): | ||||
|             matches.append((c_matches[i].match_id, c_matches[i].start, c_matches[i].end)) | ||||
|             matches.append((c_matches[i].label, c_matches[i].start, c_matches[i].end)) | ||||
|         for i, (ent_id, start, end) in enumerate(matches): | ||||
|             on_match = self._callbacks.get(self.vocab.strings[ent_id]) | ||||
|             if on_match is not None: | ||||
|                 on_match(self, doc, i, matches) | ||||
|         return matches | ||||
| 
 | ||||
|     cdef void find_matches(self, Doc doc, vector[MatchStruct] *matches) nogil: | ||||
|     cdef void find_matches(self, Doc doc, vector[SpanC] *matches) nogil: | ||||
|         cdef MapStruct* current_node = self.c_map | ||||
|         cdef int start = 0 | ||||
|         cdef int idx = 0 | ||||
|  | @ -238,7 +241,7 @@ cdef class PhraseMatcher: | |||
|         cdef key_t key | ||||
|         cdef void* value | ||||
|         cdef int i = 0 | ||||
|         cdef MatchStruct ms | ||||
|         cdef SpanC ms | ||||
|         cdef void* result | ||||
|         while idx < doc.length: | ||||
|             start = idx | ||||
|  | @ -253,7 +256,7 @@ cdef class PhraseMatcher: | |||
|                     if result: | ||||
|                         i = 0 | ||||
|                         while map_iter(<MapStruct*>result, &i, &key, &value): | ||||
|                             ms = make_matchstruct(key, start, idy) | ||||
|                             ms = make_spanstruct(key, start, idy) | ||||
|                             matches.push_back(ms) | ||||
|                     inner_token = Token.get_struct_attr(&doc.c[idy], self.attr) | ||||
|                     result = map_get(current_node, inner_token) | ||||
|  | @ -268,7 +271,7 @@ cdef class PhraseMatcher: | |||
|                     if result: | ||||
|                         i = 0 | ||||
|                         while map_iter(<MapStruct*>result, &i, &key, &value): | ||||
|                             ms = make_matchstruct(key, start, idy) | ||||
|                             ms = make_spanstruct(key, start, idy) | ||||
|                             matches.push_back(ms) | ||||
|             current_node = self.c_map | ||||
|             idx += 1 | ||||
|  | @ -318,9 +321,9 @@ def unpickle_matcher(vocab, docs, callbacks, attr): | |||
|     return matcher | ||||
| 
 | ||||
| 
 | ||||
| cdef MatchStruct make_matchstruct(key_t match_id, int start, int end) nogil: | ||||
|     cdef MatchStruct ms | ||||
|     ms.match_id = match_id | ||||
|     ms.start = start | ||||
|     ms.end = end | ||||
|     return ms | ||||
| cdef SpanC make_spanstruct(attr_t label, int start, int end) nogil: | ||||
|     cdef SpanC spanc | ||||
|     spanc.label = label | ||||
|     spanc.start = start | ||||
|     spanc.end = end | ||||
|     return spanc | ||||
|  |  | |||
|  | @ -183,7 +183,9 @@ class EntityRuler(object): | |||
|         # disable the nlp components after this one in case they hadn't been initialized / deserialised yet | ||||
|         try: | ||||
|             current_index = self.nlp.pipe_names.index(self.name) | ||||
|             subsequent_pipes = [pipe for pipe in self.nlp.pipe_names[current_index + 1:]] | ||||
|             subsequent_pipes = [ | ||||
|                 pipe for pipe in self.nlp.pipe_names[current_index + 1 :] | ||||
|             ] | ||||
|         except ValueError: | ||||
|             subsequent_pipes = [] | ||||
|         with self.nlp.disable_pipes(*subsequent_pipes): | ||||
|  |  | |||
|  | @ -1195,23 +1195,26 @@ class EntityLinker(Pipe): | |||
|             docs = [docs] | ||||
|             golds = [golds] | ||||
| 
 | ||||
|         context_docs = [] | ||||
|         sentence_docs = [] | ||||
| 
 | ||||
|         for doc, gold in zip(docs, golds): | ||||
|             ents_by_offset = dict() | ||||
|             for ent in doc.ents: | ||||
|                 ents_by_offset["{}_{}".format(ent.start_char, ent.end_char)] = ent | ||||
|                 ents_by_offset[(ent.start_char, ent.end_char)] = ent | ||||
| 
 | ||||
|             for entity, kb_dict in gold.links.items(): | ||||
|                 start, end = entity | ||||
|                 mention = doc.text[start:end] | ||||
|                 # the gold annotations should link to proper entities - if this fails, the dataset is likely corrupt | ||||
|                 ent = ents_by_offset[(start, end)] | ||||
| 
 | ||||
|                 for kb_id, value in kb_dict.items(): | ||||
|                     # Currently only training on the positive instances | ||||
|                     if value: | ||||
|                         context_docs.append(doc) | ||||
|                         sentence_docs.append(ent.sent.as_doc()) | ||||
| 
 | ||||
|         context_encodings, bp_context = self.model.begin_update(context_docs, drop=drop) | ||||
|         loss, d_scores = self.get_similarity_loss(scores=context_encodings, golds=golds, docs=None) | ||||
|         sentence_encodings, bp_context = self.model.begin_update(sentence_docs, drop=drop) | ||||
|         loss, d_scores = self.get_similarity_loss(scores=sentence_encodings, golds=golds, docs=None) | ||||
|         bp_context(d_scores, sgd=sgd) | ||||
| 
 | ||||
|         if losses is not None: | ||||
|  | @ -1280,50 +1283,68 @@ class EntityLinker(Pipe): | |||
|         if isinstance(docs, Doc): | ||||
|             docs = [docs] | ||||
| 
 | ||||
|         context_encodings = self.model(docs) | ||||
|         xp = get_array_module(context_encodings) | ||||
| 
 | ||||
|         for i, doc in enumerate(docs): | ||||
|             if len(doc) > 0: | ||||
|                 # currently, the context is the same for each entity in a sentence (should be refined) | ||||
|                 context_encoding = context_encodings[i] | ||||
|                 context_enc_t = context_encoding.T | ||||
|                 norm_1 = xp.linalg.norm(context_enc_t) | ||||
|                 for ent in doc.ents: | ||||
|                     entity_count += 1 | ||||
|                 # Looping through each sentence and each entity | ||||
|                 # This may go wrong if there are entities across sentences - because they might not get a KB ID | ||||
|                 for sent in doc.ents: | ||||
|                     sent_doc = sent.as_doc() | ||||
|                     # currently, the context is the same for each entity in a sentence (should be refined) | ||||
|                     sentence_encoding = self.model([sent_doc])[0] | ||||
|                     xp = get_array_module(sentence_encoding) | ||||
|                     sentence_encoding_t = sentence_encoding.T | ||||
|                     sentence_norm = xp.linalg.norm(sentence_encoding_t) | ||||
| 
 | ||||
|                     candidates = self.kb.get_candidates(ent.text) | ||||
|                     if not candidates: | ||||
|                         final_kb_ids.append(self.NIL)  # no prediction possible for this entity | ||||
|                         final_tensors.append(context_encoding) | ||||
|                     else: | ||||
|                         random.shuffle(candidates) | ||||
|                     for ent in sent_doc.ents: | ||||
|                         entity_count += 1 | ||||
| 
 | ||||
|                         # this will set all prior probabilities to 0 if they should be excluded from the model | ||||
|                         prior_probs = xp.asarray([c.prior_prob for c in candidates]) | ||||
|                         if not self.cfg.get("incl_prior", True): | ||||
|                             prior_probs = xp.asarray([0.0 for c in candidates]) | ||||
|                         scores = prior_probs | ||||
|                         if ent.label_ in self.cfg.get("labels_discard", []): | ||||
|                             # ignoring this entity - setting to NIL | ||||
|                             final_kb_ids.append(self.NIL) | ||||
|                             final_tensors.append(sentence_encoding) | ||||
| 
 | ||||
|                         # add in similarity from the context | ||||
|                         if self.cfg.get("incl_context", True): | ||||
|                             entity_encodings = xp.asarray([c.entity_vector for c in candidates]) | ||||
|                             norm_2 = xp.linalg.norm(entity_encodings, axis=1) | ||||
|                         else: | ||||
|                             candidates = self.kb.get_candidates(ent.text) | ||||
|                             if not candidates: | ||||
|                                 # no prediction possible for this entity - setting to NIL | ||||
|                                 final_kb_ids.append(self.NIL) | ||||
|                                 final_tensors.append(sentence_encoding) | ||||
| 
 | ||||
|                             if len(entity_encodings) != len(prior_probs): | ||||
|                                 raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length")) | ||||
|                             elif len(candidates) == 1: | ||||
|                                 # shortcut for efficiency reasons: take the 1 candidate | ||||
| 
 | ||||
|                              # cosine similarity | ||||
|                             sims = xp.dot(entity_encodings, context_enc_t) / (norm_1 * norm_2) | ||||
|                             if sims.shape != prior_probs.shape: | ||||
|                                 raise ValueError(Errors.E161) | ||||
|                             scores = prior_probs + sims - (prior_probs*sims) | ||||
|                                 # TODO: thresholding | ||||
|                                 final_kb_ids.append(candidates[0].entity_) | ||||
|                                 final_tensors.append(sentence_encoding) | ||||
| 
 | ||||
|                         # TODO: thresholding | ||||
|                         best_index = scores.argmax() | ||||
|                         best_candidate = candidates[best_index] | ||||
|                         final_kb_ids.append(best_candidate.entity_) | ||||
|                         final_tensors.append(context_encoding) | ||||
|                             else: | ||||
|                                 random.shuffle(candidates) | ||||
| 
 | ||||
|                                 # this will set all prior probabilities to 0 if they should be excluded from the model | ||||
|                                 prior_probs = xp.asarray([c.prior_prob for c in candidates]) | ||||
|                                 if not self.cfg.get("incl_prior", True): | ||||
|                                     prior_probs = xp.asarray([0.0 for c in candidates]) | ||||
|                                 scores = prior_probs | ||||
| 
 | ||||
|                                 # add in similarity from the context | ||||
|                                 if self.cfg.get("incl_context", True): | ||||
|                                     entity_encodings = xp.asarray([c.entity_vector for c in candidates]) | ||||
|                                     entity_norm = xp.linalg.norm(entity_encodings, axis=1) | ||||
| 
 | ||||
|                                     if len(entity_encodings) != len(prior_probs): | ||||
|                                         raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length")) | ||||
| 
 | ||||
|                                     # cosine similarity | ||||
|                                     sims = xp.dot(entity_encodings, sentence_encoding_t) / (sentence_norm * entity_norm) | ||||
|                                     if sims.shape != prior_probs.shape: | ||||
|                                         raise ValueError(Errors.E161) | ||||
|                                     scores = prior_probs + sims - (prior_probs*sims) | ||||
| 
 | ||||
|                                 # TODO: thresholding | ||||
|                                 best_index = scores.argmax() | ||||
|                                 best_candidate = candidates[best_index] | ||||
|                                 final_kb_ids.append(best_candidate.entity_) | ||||
|                                 final_tensors.append(sentence_encoding) | ||||
| 
 | ||||
|         if not (len(final_tensors) == len(final_kb_ids) == entity_count): | ||||
|             raise RuntimeError(Errors.E147.format(method="predict", msg="result variables not of equal length")) | ||||
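The candidate score above combines the KB prior with the sentence-context similarity as `prior + sim - prior * sim` (a probabilistic OR of the two signals). A small worked sketch with made-up numbers, independent of the pipeline code:

    import numpy

    prior_probs = numpy.asarray([0.7, 0.2])                    # prior probabilities from the KB
    entity_encodings = numpy.asarray([[1.0, 0.0], [0.6, 0.8]])
    sentence_encoding = numpy.asarray([0.8, 0.6])

    sims = entity_encodings.dot(sentence_encoding) / (
        numpy.linalg.norm(sentence_encoding) * numpy.linalg.norm(entity_encodings, axis=1)
    )
    scores = prior_probs + sims - prior_probs * sims           # higher = better candidate
    print(scores, scores.argmax())                             # the second candidate wins here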
|  |  | |||
|  | @ -219,7 +219,9 @@ class Scorer(object): | |||
|         DOCS: https://spacy.io/api/scorer#score | ||||
|         """ | ||||
|         if len(doc) != len(gold): | ||||
|             gold = GoldParse.from_annot_tuples(doc, zip(*gold.orig_annot)) | ||||
|             gold = GoldParse.from_annot_tuples( | ||||
|                 doc, tuple(zip(*gold.orig_annot)) + (gold.cats,) | ||||
|             ) | ||||
|         gold_deps = set() | ||||
|         gold_tags = set() | ||||
|         gold_ents = set(tags_to_entities([annot[-1] for annot in gold.orig_annot])) | ||||
|  |  | |||
|  | @ -47,11 +47,14 @@ cdef struct SerializedLexemeC: | |||
|     #    + sizeof(float) # l2_norm | ||||
| 
 | ||||
| 
 | ||||
| cdef struct Entity: | ||||
| cdef struct SpanC: | ||||
|     hash_t id | ||||
|     int start | ||||
|     int end | ||||
|     int start_char | ||||
|     int end_char | ||||
|     attr_t label | ||||
|     attr_t kb_id | ||||
| 
 | ||||
| 
 | ||||
| cdef struct TokenC: | ||||
|  |  | |||
|  | @ -7,7 +7,7 @@ from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno | |||
| from murmurhash.mrmr cimport hash64 | ||||
| 
 | ||||
| from ..vocab cimport EMPTY_LEXEME | ||||
| from ..structs cimport TokenC, Entity | ||||
| from ..structs cimport TokenC, SpanC | ||||
| from ..lexeme cimport Lexeme | ||||
| from ..symbols cimport punct | ||||
| from ..attrs cimport IS_SPACE | ||||
|  | @ -40,7 +40,7 @@ cdef cppclass StateC: | |||
|     int* _buffer | ||||
|     bint* shifted | ||||
|     TokenC* _sent | ||||
|     Entity* _ents | ||||
|     SpanC* _ents | ||||
|     TokenC _empty_token | ||||
|     RingBufferC _hist | ||||
|     int length | ||||
|  | @ -56,7 +56,7 @@ cdef cppclass StateC: | |||
|         this._stack = <int*>calloc(length + (PADDING * 2), sizeof(int)) | ||||
|         this.shifted = <bint*>calloc(length + (PADDING * 2), sizeof(bint)) | ||||
|         this._sent = <TokenC*>calloc(length + (PADDING * 2), sizeof(TokenC)) | ||||
|         this._ents = <Entity*>calloc(length + (PADDING * 2), sizeof(Entity)) | ||||
|         this._ents = <SpanC*>calloc(length + (PADDING * 2), sizeof(SpanC)) | ||||
|         if not (this._buffer and this._stack and this.shifted | ||||
|                 and this._sent and this._ents): | ||||
|             with gil: | ||||
|  | @ -406,7 +406,7 @@ cdef cppclass StateC: | |||
|         memcpy(this._sent, src._sent, this.length * sizeof(TokenC)) | ||||
|         memcpy(this._stack, src._stack, this.length * sizeof(int)) | ||||
|         memcpy(this._buffer, src._buffer, this.length * sizeof(int)) | ||||
|         memcpy(this._ents, src._ents, this.length * sizeof(Entity)) | ||||
|         memcpy(this._ents, src._ents, this.length * sizeof(SpanC)) | ||||
|         memcpy(this.shifted, src.shifted, this.length * sizeof(this.shifted[0])) | ||||
|         this._b_i = src._b_i | ||||
|         this._s_i = src._s_i | ||||
|  |  | |||
|  | @ -3,7 +3,7 @@ from libc.string cimport memcpy, memset | |||
| from cymem.cymem cimport Pool | ||||
| cimport cython | ||||
| 
 | ||||
| from ..structs cimport TokenC, Entity | ||||
| from ..structs cimport TokenC, SpanC | ||||
| from ..typedefs cimport attr_t | ||||
| 
 | ||||
| from ..vocab cimport EMPTY_LEXEME | ||||
|  |  | |||
|  | @ -135,6 +135,11 @@ def ko_tokenizer(): | |||
|     return get_lang_class("ko").Defaults.create_tokenizer() | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture(scope="session") | ||||
| def lb_tokenizer(): | ||||
|     return get_lang_class("lb").Defaults.create_tokenizer() | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture(scope="session") | ||||
| def lt_tokenizer(): | ||||
|     return get_lang_class("lt").Defaults.create_tokenizer() | ||||
|  |  | |||
|  | @ -253,3 +253,11 @@ def test_filter_spans(doc): | |||
|     assert len(filtered[1]) == 5 | ||||
|     assert filtered[0].start == 1 and filtered[0].end == 4 | ||||
|     assert filtered[1].start == 5 and filtered[1].end == 10 | ||||
|     # Test filtering overlaps with earlier preference for identical length | ||||
|     spans = [doc[1:4], doc[2:5], doc[5:10], doc[7:9], doc[1:4]] | ||||
|     filtered = filter_spans(spans) | ||||
|     assert len(filtered) == 2 | ||||
|     assert len(filtered[0]) == 3 | ||||
|     assert len(filtered[1]) == 5 | ||||
|     assert filtered[0].start == 1 and filtered[0].end == 4 | ||||
|     assert filtered[1].start == 5 and filtered[1].end == 10 | ||||
|  |  | |||
							
								
								
									
spacy/tests/lang/lb/test_exceptions.py (10 lines, new file)
									
								
							|  | @ -0,0 +1,10 @@ | |||
| # coding: utf-8 | ||||
| from __future__ import unicode_literals | ||||
| 
 | ||||
| import pytest | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize("text", ["z.B.", "Jan."]) | ||||
| def test_lb_tokenizer_handles_abbr(lb_tokenizer, text): | ||||
|     tokens = lb_tokenizer(text) | ||||
|     assert len(tokens) == 1 | ||||
22  spacy/tests/lang/lb/test_prefix_suffix_infix.py  (Normal file)
|  | @ -0,0 +1,22 @@ | |||
| # coding: utf-8 | ||||
| from __future__ import unicode_literals | ||||
| 
 | ||||
| import pytest | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize("text,length", [("z.B.", 1), ("zb.", 2), ("(z.B.", 2)]) | ||||
| def test_lb_tokenizer_splits_prefix_interact(lb_tokenizer, text, length): | ||||
|     tokens = lb_tokenizer(text) | ||||
|     assert len(tokens) == length | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize("text", ["z.B.)"]) | ||||
| def test_lb_tokenizer_splits_suffix_interact(lb_tokenizer, text): | ||||
|     tokens = lb_tokenizer(text) | ||||
|     assert len(tokens) == 2 | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize("text", ["(z.B.)"]) | ||||
| def test_lb_tokenizer_splits_even_wrap_interact(lb_tokenizer, text): | ||||
|     tokens = lb_tokenizer(text) | ||||
|     assert len(tokens) == 3 | ||||
31  spacy/tests/lang/lb/test_text.py  (Normal file)
|  | @ -0,0 +1,31 @@ | |||
| # coding: utf-8 | ||||
| from __future__ import unicode_literals | ||||
| 
 | ||||
| import pytest | ||||
| 
 | ||||
| 
 | ||||
| def test_lb_tokenizer_handles_long_text(lb_tokenizer): | ||||
|     text = """Den Nordwand an d'Sonn | ||||
| 
 | ||||
| An der Zäit hunn sech den Nordwand an d’Sonn gestridden, wie vun hinnen zwee wuel méi staark wier, wéi e Wanderer, deen an ee waarme Mantel agepak war, iwwert de Wee koum. Si goufen sech eens, dass deejéinege fir de Stäerkste gëlle sollt, deen de Wanderer forcéiere géif, säi Mantel auszedoen.", | ||||
| 
 | ||||
| Den Nordwand huet mat aller Force geblosen, awer wat e méi geblosen huet, wat de Wanderer sech méi a säi Mantel agewéckelt huet. Um Enn huet den Nordwand säi Kampf opginn. | ||||
| 
 | ||||
| Dunn huet d’Sonn d’Loft mat hire frëndleche Strale gewiermt, a schonn no kuerzer Zäit huet de Wanderer säi Mantel ausgedoen. | ||||
| 
 | ||||
| Do huet den Nordwand missen zouginn, dass d’Sonn vun hinnen zwee de Stäerkste wier.""" | ||||
| 
 | ||||
|     tokens = lb_tokenizer(text) | ||||
|     assert len(tokens) == 143 | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     "text,length", | ||||
|     [ | ||||
|         ("»Wat ass mat mir geschitt?«, huet hie geduecht.", 13), | ||||
|         ("“Dëst fréi Opstoen”, denkt hien, “mécht ee ganz duercherneen. ", 15), | ||||
|     ], | ||||
| ) | ||||
| def test_lb_tokenizer_handles_examples(lb_tokenizer, text, length): | ||||
|     tokens = lb_tokenizer(text) | ||||
|     assert len(tokens) == length | ||||
|  | @ -3,6 +3,8 @@ from __future__ import unicode_literals | |||
| 
 | ||||
| import pytest | ||||
| import re | ||||
| 
 | ||||
| from spacy.lang.en import English | ||||
| from spacy.matcher import Matcher | ||||
| from spacy.tokens import Doc, Span | ||||
| 
 | ||||
|  | @ -143,3 +145,29 @@ def test_matcher_sets_return_correct_tokens(en_vocab): | |||
|     matches = matcher(doc) | ||||
|     texts = [Span(doc, s, e, label=L).text for L, s, e in matches] | ||||
|     assert texts == ["zero", "one", "two"] | ||||
| 
 | ||||
| 
 | ||||
| def test_matcher_remove(): | ||||
|     nlp = English() | ||||
|     matcher = Matcher(nlp.vocab) | ||||
|     text = "This is a test case." | ||||
| 
 | ||||
|     pattern = [{"ORTH": "test"}, {"OP": "?"}] | ||||
|     assert len(matcher) == 0 | ||||
|     matcher.add("Rule", None, pattern) | ||||
|     assert "Rule" in matcher | ||||
| 
 | ||||
|     # should give two matches | ||||
|     results1 = matcher(nlp(text)) | ||||
|     assert len(results1) == 2 | ||||
| 
 | ||||
|     # removing once should work | ||||
|     matcher.remove("Rule") | ||||
| 
 | ||||
|     # should not return any matches anymore | ||||
|     results2 = matcher(nlp(text)) | ||||
|     assert len(results2) == 0 | ||||
| 
 | ||||
|     # removing again should throw an error | ||||
|     with pytest.raises(ValueError): | ||||
|         matcher.remove("Rule") | ||||
|  |  | |||
|  | @ -12,24 +12,25 @@ from spacy.util import get_json_validator, validate_json | |||
| TEST_PATTERNS = [ | ||||
|     # Bad patterns flagged in all cases | ||||
|     ([{"XX": "foo"}], 1, 1), | ||||
|     ([{"LENGTH": "2", "TEXT": 2}, {"LOWER": "test"}], 2, 1), | ||||
|     ([{"IS_ALPHA": {"==": True}}, {"LIKE_NUM": None}], 2, 1), | ||||
|     ([{"IS_PUNCT": True, "OP": "$"}], 1, 1), | ||||
|     ([{"IS_DIGIT": -1}], 1, 1), | ||||
|     ([{"ORTH": -1}], 1, 1), | ||||
|     ([{"_": "foo"}], 1, 1), | ||||
|     ('[{"TEXT": "foo"}, {"LOWER": "bar"}]', 1, 1), | ||||
|     ([1, 2, 3], 3, 1), | ||||
|     # Bad patterns flagged outside of Matcher | ||||
|     ([{"_": {"foo": "bar", "baz": {"IN": "foo"}}}], 1, 0), | ||||
|     # Bad patterns not flagged with minimal checks | ||||
|     ([{"LENGTH": "2", "TEXT": 2}, {"LOWER": "test"}], 2, 0), | ||||
|     ([{"LENGTH": {"IN": [1, 2, "3"]}}, {"POS": {"IN": "VERB"}}], 2, 0), | ||||
|     ([{"LENGTH": {"VALUE": 5}}], 1, 0), | ||||
|     ([{"TEXT": {"VALUE": "foo"}}], 1, 0), | ||||
|     ([{"IS_DIGIT": -1}], 1, 0), | ||||
|     ([{"ORTH": -1}], 1, 0), | ||||
|     # Good patterns | ||||
|     ([{"TEXT": "foo"}, {"LOWER": "bar"}], 0, 0), | ||||
|     ([{"LEMMA": {"IN": ["love", "like"]}}, {"POS": "DET", "OP": "?"}], 0, 0), | ||||
|     ([{"LIKE_NUM": True, "LENGTH": {">=": 5}}], 0, 0), | ||||
|     ([{"LENGTH": 2}], 0, 0), | ||||
|     ([{"LOWER": {"REGEX": "^X", "NOT_IN": ["XXX", "XY"]}}], 0, 0), | ||||
|     ([{"NORM": "a"}, {"POS": {"IN": ["NOUN"]}}], 0, 0), | ||||
|     ([{"_": {"foo": {"NOT_IN": ["bar", "baz"]}, "a": 5, "b": {">": 10}}}], 0, 0), | ||||
|  |  | |||
|  | @ -226,3 +226,13 @@ def test_phrase_matcher_callback(en_vocab): | |||
|     matcher.add("COMPANY", mock, pattern) | ||||
|     matches = matcher(doc) | ||||
|     mock.assert_called_once_with(matcher, doc, 0, matches) | ||||
| 
 | ||||
| 
 | ||||
| def test_phrase_matcher_remove_overlapping_patterns(en_vocab): | ||||
|     matcher = PhraseMatcher(en_vocab) | ||||
|     pattern1 = Doc(en_vocab, words=["this"]) | ||||
|     pattern2 = Doc(en_vocab, words=["this", "is"]) | ||||
|     pattern3 = Doc(en_vocab, words=["this", "is", "a"]) | ||||
|     pattern4 = Doc(en_vocab, words=["this", "is", "a", "word"]) | ||||
|     matcher.add("THIS", None, pattern1, pattern2, pattern3, pattern4) | ||||
|     matcher.remove("THIS") | ||||
|  |  | |||
|  | @ -103,7 +103,7 @@ def test_oracle_moves_missing_B(en_vocab): | |||
|             moves.add_action(move_types.index("L"), label) | ||||
|             moves.add_action(move_types.index("U"), label) | ||||
|     moves.preprocess_gold(gold) | ||||
|     seq = moves.get_oracle_sequence(doc, gold) | ||||
|     moves.get_oracle_sequence(doc, gold) | ||||
| 
 | ||||
| 
 | ||||
| def test_oracle_moves_whitespace(en_vocab): | ||||
|  |  | |||
|  | @ -131,6 +131,53 @@ def test_candidate_generation(nlp): | |||
|     assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9) | ||||
| 
 | ||||
| 
 | ||||
| def test_append_alias(nlp): | ||||
|     """Test that we can append additional alias-entity pairs""" | ||||
|     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) | ||||
| 
 | ||||
|     # adding entities | ||||
|     mykb.add_entity(entity="Q1", freq=27, entity_vector=[1]) | ||||
|     mykb.add_entity(entity="Q2", freq=12, entity_vector=[2]) | ||||
|     mykb.add_entity(entity="Q3", freq=5, entity_vector=[3]) | ||||
| 
 | ||||
|     # adding aliases | ||||
|     mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.4, 0.1]) | ||||
|     mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9]) | ||||
| 
 | ||||
|     # test the size of the relevant candidates | ||||
|     assert len(mykb.get_candidates("douglas")) == 2 | ||||
| 
 | ||||
|     # append an alias | ||||
|     mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.2) | ||||
| 
 | ||||
|     # test that the number of relevant candidates has been incremented | ||||
|     assert len(mykb.get_candidates("douglas")) == 3 | ||||
| 
 | ||||
|     # appending the same alias-entity pair again should not work (it will throw a warning) | ||||
|     mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.3) | ||||
| 
 | ||||
|     # test that the number of relevant candidates remains unchanged | ||||
|     assert len(mykb.get_candidates("douglas")) == 3 | ||||
| 
 | ||||
| 
 | ||||
| def test_append_invalid_alias(nlp): | ||||
|     """Test that append an alias will throw an error if prior probs are exceeding 1""" | ||||
|     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) | ||||
| 
 | ||||
|     # adding entities | ||||
|     mykb.add_entity(entity="Q1", freq=27, entity_vector=[1]) | ||||
|     mykb.add_entity(entity="Q2", freq=12, entity_vector=[2]) | ||||
|     mykb.add_entity(entity="Q3", freq=5, entity_vector=[3]) | ||||
| 
 | ||||
|     # adding aliases | ||||
|     mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1]) | ||||
|     mykb.add_alias(alias="adam", entities=["Q2"], probabilities=[0.9]) | ||||
| 
 | ||||
|     # appending this alias should fail because the total prior probability for "douglas" (0.8 + 0.1 + 0.2) would exceed 1 | ||||
|     with pytest.raises(ValueError): | ||||
|         mykb.append_alias(alias="douglas", entity="Q1", prior_prob=0.2) | ||||
| 
 | ||||
| 
 | ||||
| def test_preserving_links_asdoc(nlp): | ||||
|     """Test that Span.as_doc preserves the existing entity links""" | ||||
|     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) | ||||
|  |  | |||
|  | @ -430,7 +430,7 @@ def test_issue957(en_tokenizer): | |||
| def test_issue999(train_data): | ||||
|     """Test that adding entities and resuming training works passably OK. | ||||
|     There are two issues here: | ||||
|     1) We have to read labels. This isn't very nice. | ||||
|     1) We have to re-add labels. This isn't very nice. | ||||
|     2) There's no way to set the learning rate for the weight update, so we | ||||
|         end up out-of-scale, causing it to learn too fast. | ||||
|     """ | ||||
|  |  | |||
|  | @ -323,7 +323,7 @@ def test_issue3456(): | |||
|     nlp = English() | ||||
|     nlp.add_pipe(nlp.create_pipe("tagger")) | ||||
|     nlp.begin_training() | ||||
|     list(nlp.pipe(['hi', ''])) | ||||
|     list(nlp.pipe(["hi", ""])) | ||||
| 
 | ||||
| 
 | ||||
| def test_issue3468(): | ||||
|  |  | |||
|  | @ -76,7 +76,6 @@ def test_issue4042_bug2(): | |||
|             output_dir.mkdir() | ||||
|         ner1.to_disk(output_dir) | ||||
| 
 | ||||
|         nlp2 = English(vocab) | ||||
|         ner2 = EntityRecognizer(vocab) | ||||
|         ner2.from_disk(output_dir) | ||||
|         assert len(ner2.labels) == 2 | ||||
|  |  | |||
|  | @ -1,13 +1,8 @@ | |||
| # coding: utf8 | ||||
| from __future__ import unicode_literals | ||||
| 
 | ||||
| import pytest | ||||
| 
 | ||||
| import spacy | ||||
| 
 | ||||
| from spacy.lang.en import English | ||||
| from spacy.pipeline import EntityRuler | ||||
| from spacy.tokens import Span | ||||
| 
 | ||||
| 
 | ||||
| def test_issue4267(): | ||||
|  |  | |||
|  | @ -6,6 +6,6 @@ from spacy.tokens import DocBin | |||
| 
 | ||||
| def test_issue4367(): | ||||
|     """Test that docbin init goes well""" | ||||
|     doc_bin_1 = DocBin() | ||||
|     doc_bin_2 = DocBin(attrs=["LEMMA"]) | ||||
|     doc_bin_3 = DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"]) | ||||
|     DocBin() | ||||
|     DocBin(attrs=["LEMMA"]) | ||||
|     DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"]) | ||||
|  |  | |||
|  | @ -74,4 +74,4 @@ def test_serialize_doc_bin(): | |||
|     # Deserialize later, e.g. in a new process | ||||
|     nlp = spacy.blank("en") | ||||
|     doc_bin = DocBin().from_bytes(bytes_data) | ||||
|     docs = list(doc_bin.get_docs(nlp.vocab)) | ||||
|     list(doc_bin.get_docs(nlp.vocab)) | ||||
|  |  | |||
|  | @ -48,8 +48,13 @@ URLS_SHOULD_MATCH = [ | |||
|     "http://a.b--c.de/",  # this is a legit domain name see: https://gist.github.com/dperini/729294 comment on 9/9/2014 | ||||
|     "ssh://login@server.com:12345/repository.git", | ||||
|     "svn+ssh://user@ssh.yourdomain.com/path", | ||||
|     pytest.param("chrome://extensions/?id=mhjfbmdgcfjbbpaeojofohoefgiehjai", marks=pytest.mark.xfail()), | ||||
|     pytest.param("chrome-extension://mhjfbmdgcfjbbpaeojofohoefgiehjai", marks=pytest.mark.xfail()), | ||||
|     pytest.param( | ||||
|         "chrome://extensions/?id=mhjfbmdgcfjbbpaeojofohoefgiehjai", | ||||
|         marks=pytest.mark.xfail(), | ||||
|     ), | ||||
|     pytest.param( | ||||
|         "chrome-extension://mhjfbmdgcfjbbpaeojofohoefgiehjai", marks=pytest.mark.xfail() | ||||
|     ), | ||||
|     pytest.param("http://foo.com/blah_blah_(wikipedia)", marks=pytest.mark.xfail()), | ||||
|     pytest.param( | ||||
|         "http://foo.com/blah_blah_(wikipedia)_(again)", marks=pytest.mark.xfail() | ||||
|  |  | |||
|  | @ -51,6 +51,14 @@ def data(): | |||
|     return numpy.asarray([[0.0, 1.0, 2.0], [3.0, -2.0, 4.0]], dtype="f") | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def most_similar_vectors_data(): | ||||
|     return numpy.asarray( | ||||
|         [[0.0, 1.0, 2.0], [1.0, -2.0, 4.0], [1.0, 1.0, -1.0], [2.0, 3.0, 1.0]], | ||||
|         dtype="f", | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture | ||||
| def resize_data(): | ||||
|     return numpy.asarray([[0.0, 1.0], [2.0, 3.0]], dtype="f") | ||||
|  | @ -127,6 +135,12 @@ def test_set_vector(strings, data): | |||
|     assert list(v[strings[0]]) != list(orig[0]) | ||||
| 
 | ||||
| 
 | ||||
| def test_vectors_most_similar(most_similar_vectors_data): | ||||
|     v = Vectors(data=most_similar_vectors_data) | ||||
|     _, best_rows, _ = v.most_similar(v.data, batch_size=2, n=2, sort=True) | ||||
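|     # with sort=True, each vector's best match (column 0 of best_rows) should be the vector itself | ||||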
|     assert all(row[0] == i for i, row in enumerate(best_rows)) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize("text", ["apple and orange"]) | ||||
| def test_vectors_token_vector(tokenizer_v, vectors, text): | ||||
|     doc = tokenizer_v(text) | ||||
|  | @ -284,7 +298,7 @@ def test_vocab_prune_vectors(): | |||
|     vocab.set_vector("dog", data[1]) | ||||
|     vocab.set_vector("kitten", data[2]) | ||||
| 
 | ||||
|     remap = vocab.prune_vectors(2) | ||||
|     remap = vocab.prune_vectors(2, batch_size=2) | ||||
|     assert list(remap.keys()) == ["kitten"] | ||||
|     neighbour, similarity = list(remap.values())[0] | ||||
|     assert neighbour == "cat", remap | ||||
|  |  | |||
|  | @ -666,7 +666,7 @@ def filter_spans(spans): | |||
|     spans (iterable): The spans to filter. | ||||
|     RETURNS (list): The filtered spans. | ||||
|     """ | ||||
|     get_sort_key = lambda span: (span.end - span.start, span.start) | ||||
|     get_sort_key = lambda span: (span.end - span.start, -span.start) | ||||
|     sorted_spans = sorted(spans, key=get_sort_key, reverse=True) | ||||
|     result = [] | ||||
|     seen_tokens = set() | ||||
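The only functional change above is the sort key: negating `span.start` means that, when the spans are sorted in descending order, ties on length are broken in favour of the span that starts earlier (the behaviour exercised by the new `test_filter_spans` case). A minimal standalone sketch of the effect, using plain `(start, end)` tuples instead of `Span` objects:

```python
# Hypothetical (start, end) pairs standing in for spaCy Span objects.
spans = [(1, 4), (2, 5), (5, 10), (7, 9), (1, 4)]

# Same key as the patched filter_spans: longest span first, then earliest start.
get_sort_key = lambda span: (span[1] - span[0], -span[0])

for start, end in sorted(spans, key=get_sort_key, reverse=True):
    print(start, end)
# (5, 10) prints first (longest); (1, 4) now comes before (2, 5) because,
# for equal length, the smaller start has the larger -start value.
```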
|  |  | |||
|  | @ -336,8 +336,8 @@ cdef class Vectors: | |||
|             best_rows[i:i+batch_size] = xp.argpartition(sims, -n, axis=1)[:,-n:] | ||||
|             scores[i:i+batch_size] = xp.partition(sims, -n, axis=1)[:,-n:] | ||||
| 
 | ||||
|             if sort: | ||||
|                 sorted_index = xp.arange(scores.shape[0])[:,None],xp.argsort(scores[i:i+batch_size], axis=1)[:,::-1] | ||||
|             if sort and n >= 2: | ||||
|                 sorted_index = xp.arange(scores.shape[0])[:,None][i:i+batch_size],xp.argsort(scores[i:i+batch_size], axis=1)[:,::-1] | ||||
|                 scores[i:i+batch_size] = scores[sorted_index] | ||||
|                 best_rows[i:i+batch_size] = best_rows[sorted_index] | ||||
| 
 | ||||
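Two things change in `most_similar` above: sorting is skipped when fewer than two neighbours are requested (there is nothing to reorder), and the row index is sliced to the current batch so its shape matches the per-batch `argsort` result. Below is a standalone numpy sketch of the select-top-n-then-sort step, on toy similarity scores rather than spaCy's `Vectors` internals, and ignoring however the real method computes the similarity matrix:

```python
import numpy as np

# Toy similarity matrix: one row per query vector, one column per table row.
sims = np.array([[0.1, 0.9, 0.4, 0.7],
                 [0.8, 0.2, 0.6, 0.3]], dtype="f")
n = 2

# argpartition/partition move the n largest entries of each row into the
# last n columns, in no particular order -- the same trick used above.
best_rows = np.argpartition(sims, -n, axis=1)[:, -n:]
scores = np.partition(sims, -n, axis=1)[:, -n:]

# Order those n columns descending, row by row; only meaningful for n >= 2.
order = np.argsort(scores, axis=1)[:, ::-1]
rows = np.arange(scores.shape[0])[:, None]  # row index with the same height as the batch
scores = scores[rows, order]
best_rows = best_rows[rows, order]

print(best_rows)  # rows [1 3] and [0 2] -- indices of the two best matches per row
print(scores)     # rows [0.9 0.7] and [0.8 0.6]
```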
|  |  | |||
|  | @ -62,7 +62,7 @@ Whether the provided syntactic annotations form a projective dependency tree. | |||
| 
 | ||||
| Convert a list of Doc objects into the | ||||
| [JSON-serializable format](/api/annotation#json-input) used by the | ||||
| [`spacy train`](/api/cli#train) command. | ||||
| [`spacy train`](/api/cli#train) command. Each input doc will be treated as a 'paragraph' in the output data. | ||||
| 
 | ||||
| > #### Example | ||||
| > | ||||
|  | @ -77,7 +77,7 @@ Convert a list of Doc objects into the | |||
| | ----------- | ---------------- | ------------------------------------------ | | ||||
| | `docs`      | iterable / `Doc` | The `Doc` object(s) to convert.            | | ||||
| | `id`        | int              | ID to assign to the JSON. Defaults to `0`. | | ||||
| | **RETURNS** | list             | The data in spaCy's JSON format.           | | ||||
| | **RETURNS** | dict             | The data in spaCy's JSON format.           | | ||||
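A hedged usage sketch of the call documented above (spaCy 2.x API). It assumes sentence boundaries are required and provides them via the `sentencizer` pipe; the exact pipeline requirements may differ:

```python
import spacy
from spacy.gold import docs_to_json

nlp = spacy.blank("en")
# assumption: the converter groups tokens by sentence, so set boundaries first
nlp.add_pipe(nlp.create_pipe("sentencizer"))

docs = [nlp("I like London."), nlp("Berlin is nice.")]
json_data = docs_to_json(docs, id=0)

assert isinstance(json_data, dict)        # RETURNS is a dict, as documented above
assert len(json_data["paragraphs"]) == 2  # one 'paragraph' per input Doc
```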
| 
 | ||||
| ### gold.align {#align tag="function"} | ||||
| 
 | ||||
|  |  | |||
|  | @ -54,7 +54,7 @@ Lemmatize a string. | |||
| > ```python | ||||
| > from spacy.lemmatizer import Lemmatizer | ||||
| > from spacy.lookups import Lookups | ||||
| > lookups = Loookups() | ||||
| > lookups = Lookups() | ||||
| > lookups.add_table("lemma_rules", {"noun": [["s", ""]]}) | ||||
| > lemmatizer = Lemmatizer(lookups) | ||||
| > lemmas = lemmatizer("ducks", "NOUN") | ||||
|  |  | |||