mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-25 05:01:02 +03:00 
			
		
		
		
	CLI scripts for entity linking (wikipedia & generic) (#4091)
* document token ent_kb_id * document span kb_id * update pipeline documentation * prior and context weights as bool's instead * entitylinker api documentation * drop for both models * finish entitylinker documentation * small fixes * documentation for KB * candidate documentation * links to api pages in code * small fix * frequency examples as counts for consistency * consistent documentation about tensors returned by predict * add entity linking to usage 101 * add entity linking infobox and KB section to 101 * entity-linking in linguistic features * small typo corrections * training example and docs for entity_linker * predefined nlp and kb * revert back to similarity encodings for simplicity (for now) * set prior probabilities to 0 when excluded * code clean up * bugfix: deleting kb ID from tokens when entities were removed * refactor train el example to use either model or vocab * pretrain_kb example for example kb generation * add to training docs for KB + EL example scripts * small fixes * error numbering * ensure the language of vocab and nlp stay consistent across serialization * equality with = * avoid conflict in errors file * add error 151 * final adjustements to the train scripts - consistency * update of goldparse documentation * small corrections * push commit * turn kb_creator into CLI script (wip) * proper parameters for training entity vectors * wikidata pipeline split up into two executable scripts * remove context_width * move wikidata scripts in bin directory, remove old dummy script * refine KB script with logs and preprocessing options * small edits * small improvements to logging of EL CLI script
This commit is contained in:
		
							parent
							
								
									5196dbd89d
								
							
						
					
					
						commit
						0ba1b5eebc
					
				|  | @ -1,16 +1,14 @@ | ||||||
| # coding: utf-8 | # coding: utf-8 | ||||||
| from __future__ import unicode_literals | from __future__ import unicode_literals | ||||||
| 
 | 
 | ||||||
| from .train_descriptions import EntityEncoder | from bin.wiki_entity_linking.train_descriptions import EntityEncoder | ||||||
| from . import wikidata_processor as wd, wikipedia_processor as wp | from bin.wiki_entity_linking import wikidata_processor as wd, wikipedia_processor as wp | ||||||
| from spacy.kb import KnowledgeBase | from spacy.kb import KnowledgeBase | ||||||
| 
 | 
 | ||||||
| import csv | import csv | ||||||
| import datetime | import datetime | ||||||
| 
 | 
 | ||||||
| 
 | from spacy import Errors | ||||||
| INPUT_DIM = 300  # dimension of pre-trained input vectors |  | ||||||
| DESC_WIDTH = 64  # dimension of output entity vectors |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def create_kb( | def create_kb( | ||||||
|  | @ -23,17 +21,27 @@ def create_kb( | ||||||
|     count_input, |     count_input, | ||||||
|     prior_prob_input, |     prior_prob_input, | ||||||
|     wikidata_input, |     wikidata_input, | ||||||
|  |     entity_vector_length, | ||||||
|  |     limit=None, | ||||||
|  |     read_raw_data=True, | ||||||
| ): | ): | ||||||
|     # Create the knowledge base from Wikidata entries |     # Create the knowledge base from Wikidata entries | ||||||
|     kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH) |     kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=entity_vector_length) | ||||||
|  | 
 | ||||||
|  |     # check the length of the nlp vectors | ||||||
|  |     if "vectors" in nlp.meta and nlp.vocab.vectors.size: | ||||||
|  |         input_dim = nlp.vocab.vectors_length | ||||||
|  |         print("Loaded pre-trained vectors of size %s" % input_dim) | ||||||
|  |     else: | ||||||
|  |         raise ValueError(Errors.E155) | ||||||
| 
 | 
 | ||||||
|     # disable this part of the pipeline when rerunning the KB generation from preprocessed files |     # disable this part of the pipeline when rerunning the KB generation from preprocessed files | ||||||
|     read_raw_data = True |  | ||||||
| 
 |  | ||||||
|     if read_raw_data: |     if read_raw_data: | ||||||
|         print() |         print() | ||||||
|         print(" * _read_wikidata_entities", datetime.datetime.now()) |         print(now(), " * read wikidata entities:") | ||||||
|         title_to_id, id_to_descr = wd.read_wikidata_entities_json(wikidata_input) |         title_to_id, id_to_descr = wd.read_wikidata_entities_json( | ||||||
|  |             wikidata_input, limit=limit | ||||||
|  |         ) | ||||||
| 
 | 
 | ||||||
|         # write the title-ID and ID-description mappings to file |         # write the title-ID and ID-description mappings to file | ||||||
|         _write_entity_files( |         _write_entity_files( | ||||||
|  | @ -46,7 +54,7 @@ def create_kb( | ||||||
|         id_to_descr = get_id_to_description(entity_descr_output) |         id_to_descr = get_id_to_description(entity_descr_output) | ||||||
| 
 | 
 | ||||||
|     print() |     print() | ||||||
|     print(" * _get_entity_frequencies", datetime.datetime.now()) |     print(now(), " *  get entity frequencies:") | ||||||
|     print() |     print() | ||||||
|     entity_frequencies = wp.get_all_frequencies(count_input=count_input) |     entity_frequencies = wp.get_all_frequencies(count_input=count_input) | ||||||
| 
 | 
 | ||||||
|  | @ -65,40 +73,41 @@ def create_kb( | ||||||
|             filtered_title_to_id[title] = entity |             filtered_title_to_id[title] = entity | ||||||
| 
 | 
 | ||||||
|     print(len(title_to_id.keys()), "original titles") |     print(len(title_to_id.keys()), "original titles") | ||||||
|     print("kept", len(filtered_title_to_id.keys()), " with frequency", min_entity_freq) |     kept_nr = len(filtered_title_to_id.keys()) | ||||||
|  |     print("kept", kept_nr, "entities with min. frequency", min_entity_freq) | ||||||
| 
 | 
 | ||||||
|     print() |     print() | ||||||
|     print(" * train entity encoder", datetime.datetime.now()) |     print(now(), " * train entity encoder:") | ||||||
|     print() |     print() | ||||||
|     encoder = EntityEncoder(nlp, INPUT_DIM, DESC_WIDTH) |     encoder = EntityEncoder(nlp, input_dim, entity_vector_length) | ||||||
|     encoder.train(description_list=description_list, to_print=True) |     encoder.train(description_list=description_list, to_print=True) | ||||||
| 
 | 
 | ||||||
|     print() |     print() | ||||||
|     print(" * get entity embeddings", datetime.datetime.now()) |     print(now(), " * get entity embeddings:") | ||||||
|     print() |     print() | ||||||
|     embeddings = encoder.apply_encoder(description_list) |     embeddings = encoder.apply_encoder(description_list) | ||||||
| 
 | 
 | ||||||
|     print() |     print(now(), " * adding", len(entity_list), "entities") | ||||||
|     print(" * adding", len(entity_list), "entities", datetime.datetime.now()) |  | ||||||
|     kb.set_entities( |     kb.set_entities( | ||||||
|         entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings |         entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings | ||||||
|     ) |     ) | ||||||
| 
 | 
 | ||||||
|     print() |     alias_cnt = _add_aliases( | ||||||
|     print(" * adding aliases", datetime.datetime.now()) |  | ||||||
|     print() |  | ||||||
|     _add_aliases( |  | ||||||
|         kb, |         kb, | ||||||
|         title_to_id=filtered_title_to_id, |         title_to_id=filtered_title_to_id, | ||||||
|         max_entities_per_alias=max_entities_per_alias, |         max_entities_per_alias=max_entities_per_alias, | ||||||
|         min_occ=min_occ, |         min_occ=min_occ, | ||||||
|         prior_prob_input=prior_prob_input, |         prior_prob_input=prior_prob_input, | ||||||
|     ) |     ) | ||||||
|  |     print() | ||||||
|  |     print(now(), " * adding", alias_cnt, "aliases") | ||||||
|  |     print() | ||||||
| 
 | 
 | ||||||
|     print() |     print() | ||||||
|     print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases()) |     print("# of entities in kb:", kb.get_size_entities()) | ||||||
|  |     print("# of aliases in kb:", kb.get_size_aliases()) | ||||||
| 
 | 
 | ||||||
|     print("done with kb", datetime.datetime.now()) |     print(now(), "Done with kb") | ||||||
|     return kb |     return kb | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -140,6 +149,7 @@ def get_id_to_description(entity_descr_output): | ||||||
| 
 | 
 | ||||||
| def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input): | def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input): | ||||||
|     wp_titles = title_to_id.keys() |     wp_titles = title_to_id.keys() | ||||||
|  |     cnt = 0 | ||||||
| 
 | 
 | ||||||
|     # adding aliases with prior probabilities |     # adding aliases with prior probabilities | ||||||
|     # we can read this file sequentially, it's sorted by alias, and then by count |     # we can read this file sequentially, it's sorted by alias, and then by count | ||||||
|  | @ -176,6 +186,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in | ||||||
|                                 entities=selected_entities, |                                 entities=selected_entities, | ||||||
|                                 probabilities=prior_probs, |                                 probabilities=prior_probs, | ||||||
|                             ) |                             ) | ||||||
|  |                             cnt += 1 | ||||||
|                         except ValueError as e: |                         except ValueError as e: | ||||||
|                             print(e) |                             print(e) | ||||||
|                 total_count = 0 |                 total_count = 0 | ||||||
|  | @ -190,3 +201,8 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in | ||||||
|             previous_alias = new_alias |             previous_alias = new_alias | ||||||
| 
 | 
 | ||||||
|             line = prior_file.readline() |             line = prior_file.readline() | ||||||
|  |     return cnt | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def now(): | ||||||
|  |     return datetime.datetime.now() | ||||||
|  |  | ||||||
|  | @ -18,15 +18,19 @@ class EntityEncoder: | ||||||
|     """ |     """ | ||||||
| 
 | 
 | ||||||
|     DROP = 0 |     DROP = 0 | ||||||
|     EPOCHS = 5 |  | ||||||
|     STOP_THRESHOLD = 0.04 |  | ||||||
| 
 |  | ||||||
|     BATCH_SIZE = 1000 |     BATCH_SIZE = 1000 | ||||||
| 
 | 
 | ||||||
|     def __init__(self, nlp, input_dim, desc_width): |     # Set min. acceptable loss to avoid a 'mean of empty slice' warning by numpy | ||||||
|  |     MIN_LOSS = 0.01 | ||||||
|  | 
 | ||||||
|  |     # Reasonable default to stop training when things are not improving | ||||||
|  |     MAX_NO_IMPROVEMENT = 20 | ||||||
|  | 
 | ||||||
|  |     def __init__(self, nlp, input_dim, desc_width, epochs=5): | ||||||
|         self.nlp = nlp |         self.nlp = nlp | ||||||
|         self.input_dim = input_dim |         self.input_dim = input_dim | ||||||
|         self.desc_width = desc_width |         self.desc_width = desc_width | ||||||
|  |         self.epochs = epochs | ||||||
| 
 | 
 | ||||||
|     def apply_encoder(self, description_list): |     def apply_encoder(self, description_list): | ||||||
|         if self.encoder is None: |         if self.encoder is None: | ||||||
|  | @ -46,32 +50,41 @@ class EntityEncoder: | ||||||
| 
 | 
 | ||||||
|             start = start + batch_size |             start = start + batch_size | ||||||
|             stop = min(stop + batch_size, len(description_list)) |             stop = min(stop + batch_size, len(description_list)) | ||||||
|  |             print("encoded:", stop, "entities") | ||||||
| 
 | 
 | ||||||
|         return encodings |         return encodings | ||||||
| 
 | 
 | ||||||
|     def train(self, description_list, to_print=False): |     def train(self, description_list, to_print=False): | ||||||
|         processed, loss = self._train_model(description_list) |         processed, loss = self._train_model(description_list) | ||||||
|         if to_print: |         if to_print: | ||||||
|             print("Trained on", processed, "entities across", self.EPOCHS, "epochs") |             print( | ||||||
|  |                 "Trained entity descriptions on", | ||||||
|  |                 processed, | ||||||
|  |                 "(non-unique) entities across", | ||||||
|  |                 self.epochs, | ||||||
|  |                 "epochs", | ||||||
|  |             ) | ||||||
|             print("Final loss:", loss) |             print("Final loss:", loss) | ||||||
| 
 | 
 | ||||||
|     def _train_model(self, description_list): |     def _train_model(self, description_list): | ||||||
|         # TODO: when loss gets too low, a 'mean of empty slice' warning is thrown by numpy |         best_loss = 1.0 | ||||||
| 
 |         iter_since_best = 0 | ||||||
|         self._build_network(self.input_dim, self.desc_width) |         self._build_network(self.input_dim, self.desc_width) | ||||||
| 
 | 
 | ||||||
|         processed = 0 |         processed = 0 | ||||||
|         loss = 1 |         loss = 1 | ||||||
|         descriptions = description_list.copy()   # copy this list so that shuffling does not affect other functions |         # copy this list so that shuffling does not affect other functions | ||||||
|  |         descriptions = description_list.copy() | ||||||
|  |         to_continue = True | ||||||
| 
 | 
 | ||||||
|         for i in range(self.EPOCHS): |         for i in range(self.epochs): | ||||||
|             shuffle(descriptions) |             shuffle(descriptions) | ||||||
| 
 | 
 | ||||||
|             batch_nr = 0 |             batch_nr = 0 | ||||||
|             start = 0 |             start = 0 | ||||||
|             stop = min(self.BATCH_SIZE, len(descriptions)) |             stop = min(self.BATCH_SIZE, len(descriptions)) | ||||||
| 
 | 
 | ||||||
|             while loss > self.STOP_THRESHOLD and start < len(descriptions): |             while to_continue and start < len(descriptions): | ||||||
|                 batch = [] |                 batch = [] | ||||||
|                 for descr in descriptions[start:stop]: |                 for descr in descriptions[start:stop]: | ||||||
|                     doc = self.nlp(descr) |                     doc = self.nlp(descr) | ||||||
|  | @ -79,9 +92,24 @@ class EntityEncoder: | ||||||
|                     batch.append(doc_vector) |                     batch.append(doc_vector) | ||||||
| 
 | 
 | ||||||
|                 loss = self._update(batch) |                 loss = self._update(batch) | ||||||
|                 print(i, batch_nr, loss) |                 if batch_nr % 25 == 0: | ||||||
|  |                     print("loss:", loss) | ||||||
|                 processed += len(batch) |                 processed += len(batch) | ||||||
| 
 | 
 | ||||||
|  |                 # in general, continue training if we haven't reached our ideal min yet | ||||||
|  |                 to_continue = loss > self.MIN_LOSS | ||||||
|  | 
 | ||||||
|  |                 # store the best loss and track how long it's been | ||||||
|  |                 if loss < best_loss: | ||||||
|  |                     best_loss = loss | ||||||
|  |                     iter_since_best = 0 | ||||||
|  |                 else: | ||||||
|  |                     iter_since_best += 1 | ||||||
|  | 
 | ||||||
|  |                 # stop learning if we haven't seen improvement since the last few iterations | ||||||
|  |                 if iter_since_best > self.MAX_NO_IMPROVEMENT: | ||||||
|  |                     to_continue = False | ||||||
|  | 
 | ||||||
|                 batch_nr += 1 |                 batch_nr += 1 | ||||||
|                 start = start + self.BATCH_SIZE |                 start = start + self.BATCH_SIZE | ||||||
|                 stop = min(stop + self.BATCH_SIZE, len(descriptions)) |                 stop = min(stop + self.BATCH_SIZE, len(descriptions)) | ||||||
|  | @ -103,14 +131,16 @@ class EntityEncoder: | ||||||
|     def _build_network(self, orig_width, hidden_with): |     def _build_network(self, orig_width, hidden_with): | ||||||
|         with Model.define_operators({">>": chain}): |         with Model.define_operators({">>": chain}): | ||||||
|             # very simple encoder-decoder model |             # very simple encoder-decoder model | ||||||
|             self.encoder = ( |             self.encoder = Affine(hidden_with, orig_width) | ||||||
|                 Affine(hidden_with, orig_width) |             self.model = self.encoder >> zero_init( | ||||||
|  |                 Affine(orig_width, hidden_with, drop_factor=0.0) | ||||||
|             ) |             ) | ||||||
|             self.model = self.encoder >> zero_init(Affine(orig_width, hidden_with, drop_factor=0.0)) |  | ||||||
|         self.sgd = create_default_optimizer(self.model.ops) |         self.sgd = create_default_optimizer(self.model.ops) | ||||||
| 
 | 
 | ||||||
|     def _update(self, vectors): |     def _update(self, vectors): | ||||||
|         predictions, bp_model = self.model.begin_update(np.asarray(vectors), drop=self.DROP) |         predictions, bp_model = self.model.begin_update( | ||||||
|  |             np.asarray(vectors), drop=self.DROP | ||||||
|  |         ) | ||||||
|         loss, d_scores = self._get_loss(scores=predictions, golds=np.asarray(vectors)) |         loss, d_scores = self._get_loss(scores=predictions, golds=np.asarray(vectors)) | ||||||
|         bp_model(d_scores, sgd=self.sgd) |         bp_model(d_scores, sgd=self.sgd) | ||||||
|         return loss / len(vectors) |         return loss / len(vectors) | ||||||
|  |  | ||||||
|  | @ -21,9 +21,9 @@ def now(): | ||||||
|     return datetime.datetime.now() |     return datetime.datetime.now() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def create_training(wikipedia_input, entity_def_input, training_output): | def create_training(wikipedia_input, entity_def_input, training_output, limit=None): | ||||||
|     wp_to_id = kb_creator.get_entity_to_id(entity_def_input) |     wp_to_id = kb_creator.get_entity_to_id(entity_def_input) | ||||||
|     _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None) |     _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=limit) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None): | def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None): | ||||||
|  | @ -128,6 +128,7 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N | ||||||
| 
 | 
 | ||||||
|                 line = file.readline() |                 line = file.readline() | ||||||
|                 cnt += 1 |                 cnt += 1 | ||||||
|  |             print(now(), "processed", cnt, "lines of Wikipedia dump") | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)") | text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)") | ||||||
|  |  | ||||||
							
								
								
									
										139
									
								
								bin/wiki_entity_linking/wikidata_pretrain_kb.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										139
									
								
								bin/wiki_entity_linking/wikidata_pretrain_kb.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,139 @@ | ||||||
|  | # coding: utf-8 | ||||||
|  | """Script to process Wikipedia and Wikidata dumps and create a knowledge base (KB) | ||||||
|  | with specific parameters. Intermediate files are written to disk. | ||||||
|  | 
 | ||||||
|  | Running the full pipeline on a standard laptop may take up to 13 hours of processing. | ||||||
|  | Use the -p, -d and -s options to speed up processing using the intermediate files | ||||||
|  | from a previous run. | ||||||
|  | 
 | ||||||
|  | For the Wikidata dump: get the latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/ | ||||||
|  | For the Wikipedia dump: get enwiki-latest-pages-articles-multistream.xml.bz2 | ||||||
|  | from https://dumps.wikimedia.org/enwiki/latest/ | ||||||
|  | 
 | ||||||
|  | """ | ||||||
|  | from __future__ import unicode_literals | ||||||
|  | 
 | ||||||
|  | import datetime | ||||||
|  | from pathlib import Path | ||||||
|  | import plac | ||||||
|  | 
 | ||||||
|  | from bin.wiki_entity_linking import wikipedia_processor as wp | ||||||
|  | from bin.wiki_entity_linking import kb_creator | ||||||
|  | 
 | ||||||
|  | import spacy | ||||||
|  | 
 | ||||||
|  | from spacy import Errors | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | def now(): | ||||||
|  |     return datetime.datetime.now() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | @plac.annotations( | ||||||
|  |     wd_json=("Path to the downloaded WikiData JSON dump.", "positional", None, Path), | ||||||
|  |     wp_xml=("Path to the downloaded Wikipedia XML dump.", "positional", None, Path), | ||||||
|  |     output_dir=("Output directory", "positional", None, Path), | ||||||
|  |     model=("Model name, should include pretrained vectors.", "positional", None, str), | ||||||
|  |     max_per_alias=("Max. # entities per alias (default 10)", "option", "a", int), | ||||||
|  |     min_freq=("Min. count of an entity in the corpus (default 20)", "option", "f", int), | ||||||
|  |     min_pair=("Min. count of entity-alias pairs (default 5)", "option", "c", int), | ||||||
|  |     entity_vector_length=("Length of entity vectors (default 64)", "option", "v", int), | ||||||
|  |     loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path), | ||||||
|  |     loc_entity_defs=("Location to file with entity definitions", "option", "d", Path), | ||||||
|  |     loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path), | ||||||
|  |     limit=("Optional threshold to limit lines read from dumps", "option", "l", int), | ||||||
|  | ) | ||||||
|  | def main( | ||||||
|  |     wd_json, | ||||||
|  |     wp_xml, | ||||||
|  |     output_dir, | ||||||
|  |     model, | ||||||
|  |     max_per_alias=10, | ||||||
|  |     min_freq=20, | ||||||
|  |     min_pair=5, | ||||||
|  |     entity_vector_length=64, | ||||||
|  |     loc_prior_prob=None, | ||||||
|  |     loc_entity_defs=None, | ||||||
|  |     loc_entity_desc=None, | ||||||
|  |     limit=None, | ||||||
|  | ): | ||||||
|  |     print(now(), "Creating KB with Wikipedia and WikiData") | ||||||
|  |     print() | ||||||
|  | 
 | ||||||
|  |     if limit is not None: | ||||||
|  |         print("Warning: reading only", limit, "lines of Wikipedia/Wikidata dumps.") | ||||||
|  | 
 | ||||||
|  |     # STEP 0: set up IO | ||||||
|  |     if not output_dir.exists(): | ||||||
|  |         output_dir.mkdir() | ||||||
|  | 
 | ||||||
|  |     # STEP 1: create the NLP object | ||||||
|  |     print(now(), "STEP 1: loaded model", model) | ||||||
|  |     nlp = spacy.load(model) | ||||||
|  | 
 | ||||||
|  |     # check the length of the nlp vectors | ||||||
|  |     if "vectors" not in nlp.meta or not nlp.vocab.vectors.size: | ||||||
|  |         raise ValueError(Errors.E155) | ||||||
|  | 
 | ||||||
|  |     # STEP 2: create prior probabilities from WP | ||||||
|  |     print() | ||||||
|  |     if loc_prior_prob: | ||||||
|  |         print(now(), "STEP 2: reading prior probabilities from", loc_prior_prob) | ||||||
|  |     else: | ||||||
|  |         # It takes about 2h to process 1000M lines of Wikipedia XML dump | ||||||
|  |         loc_prior_prob = output_dir / "prior_prob.csv" | ||||||
|  |         print(now(), "STEP 2: writing prior probabilities at", loc_prior_prob) | ||||||
|  |         wp.read_prior_probs(wp_xml, loc_prior_prob, limit=limit) | ||||||
|  | 
 | ||||||
|  |     # STEP 3: deduce entity frequencies from WP (takes only a few minutes) | ||||||
|  |     print() | ||||||
|  |     print(now(), "STEP 3: calculating entity frequencies") | ||||||
|  |     loc_entity_freq = output_dir / "entity_freq.csv" | ||||||
|  |     wp.write_entity_counts(loc_prior_prob, loc_entity_freq, to_print=False) | ||||||
|  | 
 | ||||||
|  |     loc_kb = output_dir / "kb" | ||||||
|  | 
 | ||||||
|  |     # STEP 4: reading entity descriptions and definitions from WikiData or from file | ||||||
|  |     print() | ||||||
|  |     if loc_entity_defs and loc_entity_desc: | ||||||
|  |         read_raw = False | ||||||
|  |         print(now(), "STEP 4a: reading entity definitions from", loc_entity_defs) | ||||||
|  |         print(now(), "STEP 4b: reading entity descriptions from", loc_entity_desc) | ||||||
|  |     else: | ||||||
|  |         # It takes about 10h to process 55M lines of Wikidata JSON dump | ||||||
|  |         read_raw = True | ||||||
|  |         loc_entity_defs = output_dir / "entity_defs.csv" | ||||||
|  |         loc_entity_desc = output_dir / "entity_descriptions.csv" | ||||||
|  |         print(now(), "STEP 4: parsing wikidata for entity definitions and descriptions") | ||||||
|  | 
 | ||||||
|  |     # STEP 5: creating the actual KB | ||||||
|  |     # It takes ca. 30 minutes to pretrain the entity embeddings | ||||||
|  |     print() | ||||||
|  |     print(now(), "STEP 5: creating the KB at", loc_kb) | ||||||
|  |     kb = kb_creator.create_kb( | ||||||
|  |         nlp=nlp, | ||||||
|  |         max_entities_per_alias=max_per_alias, | ||||||
|  |         min_entity_freq=min_freq, | ||||||
|  |         min_occ=min_pair, | ||||||
|  |         entity_def_output=loc_entity_defs, | ||||||
|  |         entity_descr_output=loc_entity_desc, | ||||||
|  |         count_input=loc_entity_freq, | ||||||
|  |         prior_prob_input=loc_prior_prob, | ||||||
|  |         wikidata_input=wd_json, | ||||||
|  |         entity_vector_length=entity_vector_length, | ||||||
|  |         limit=limit, | ||||||
|  |         read_raw_data=read_raw, | ||||||
|  |     ) | ||||||
|  |     if read_raw: | ||||||
|  |         print(" - wrote entity definitions to", loc_entity_defs) | ||||||
|  |         print(" - wrote writing entity descriptions to", loc_entity_desc) | ||||||
|  | 
 | ||||||
|  |     kb.dump(loc_kb) | ||||||
|  |     nlp.to_disk(output_dir / "nlp") | ||||||
|  | 
 | ||||||
|  |     print() | ||||||
|  |     print(now(), "Done!") | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | if __name__ == "__main__": | ||||||
|  |     plac.call(main) | ||||||
|  | @ -10,8 +10,8 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False): | ||||||
|     # Read the JSON wiki data and parse out the entities. Takes about 7h30 to parse 55M lines. |     # Read the JSON wiki data and parse out the entities. Takes about 7h30 to parse 55M lines. | ||||||
|     # get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/ |     # get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/ | ||||||
| 
 | 
 | ||||||
|     lang = 'en' |     lang = "en" | ||||||
|     site_filter = 'enwiki' |     site_filter = "enwiki" | ||||||
| 
 | 
 | ||||||
|     # properties filter (currently disabled to get ALL data) |     # properties filter (currently disabled to get ALL data) | ||||||
|     prop_filter = dict() |     prop_filter = dict() | ||||||
|  | @ -28,12 +28,14 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False): | ||||||
|     parse_aliases = False |     parse_aliases = False | ||||||
|     parse_claims = False |     parse_claims = False | ||||||
| 
 | 
 | ||||||
|     with bz2.open(wikidata_file, mode='rb') as file: |     with bz2.open(wikidata_file, mode="rb") as file: | ||||||
|         line = file.readline() |         line = file.readline() | ||||||
|         cnt = 0 |         cnt = 0 | ||||||
|         while line and (not limit or cnt < limit): |         while line and (not limit or cnt < limit): | ||||||
|             if cnt % 500000 == 0: |             if cnt % 1000000 == 0: | ||||||
|                 print(datetime.datetime.now(), "processed", cnt, "lines of WikiData dump") |                 print( | ||||||
|  |                     datetime.datetime.now(), "processed", cnt, "lines of WikiData JSON dump" | ||||||
|  |                 ) | ||||||
|             clean_line = line.strip() |             clean_line = line.strip() | ||||||
|             if clean_line.endswith(b","): |             if clean_line.endswith(b","): | ||||||
|                 clean_line = clean_line[:-1] |                 clean_line = clean_line[:-1] | ||||||
|  | @ -52,8 +54,13 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False): | ||||||
|                             claim_property = claims.get(prop, None) |                             claim_property = claims.get(prop, None) | ||||||
|                             if claim_property: |                             if claim_property: | ||||||
|                                 for cp in claim_property: |                                 for cp in claim_property: | ||||||
|                                     cp_id = cp['mainsnak'].get('datavalue', {}).get('value', {}).get('id') |                                     cp_id = ( | ||||||
|                                     cp_rank = cp['rank'] |                                         cp["mainsnak"] | ||||||
|  |                                         .get("datavalue", {}) | ||||||
|  |                                         .get("value", {}) | ||||||
|  |                                         .get("id") | ||||||
|  |                                     ) | ||||||
|  |                                     cp_rank = cp["rank"] | ||||||
|                                     if cp_rank != "deprecated" and cp_id in value_set: |                                     if cp_rank != "deprecated" and cp_id in value_set: | ||||||
|                                         keep = True |                                         keep = True | ||||||
| 
 | 
 | ||||||
|  | @ -67,10 +74,17 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False): | ||||||
|                         # parsing all properties that refer to other entities |                         # parsing all properties that refer to other entities | ||||||
|                         if parse_properties: |                         if parse_properties: | ||||||
|                             for prop, claim_property in claims.items(): |                             for prop, claim_property in claims.items(): | ||||||
|                                 cp_dicts = [cp['mainsnak']['datavalue'].get('value') for cp in claim_property |                                 cp_dicts = [ | ||||||
|                                             if cp['mainsnak'].get('datavalue')] |                                     cp["mainsnak"]["datavalue"].get("value") | ||||||
|                                 cp_values = [cp_dict.get('id') for cp_dict in cp_dicts if isinstance(cp_dict, dict) |                                     for cp in claim_property | ||||||
|                                              if cp_dict.get('id') is not None] |                                     if cp["mainsnak"].get("datavalue") | ||||||
|  |                                 ] | ||||||
|  |                                 cp_values = [ | ||||||
|  |                                     cp_dict.get("id") | ||||||
|  |                                     for cp_dict in cp_dicts | ||||||
|  |                                     if isinstance(cp_dict, dict) | ||||||
|  |                                     if cp_dict.get("id") is not None | ||||||
|  |                                 ] | ||||||
|                                 if cp_values: |                                 if cp_values: | ||||||
|                                     if to_print: |                                     if to_print: | ||||||
|                                         print("prop:", prop, cp_values) |                                         print("prop:", prop, cp_values) | ||||||
|  | @ -79,7 +93,7 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False): | ||||||
|                         if parse_sitelinks: |                         if parse_sitelinks: | ||||||
|                             site_value = obj["sitelinks"].get(site_filter, None) |                             site_value = obj["sitelinks"].get(site_filter, None) | ||||||
|                             if site_value: |                             if site_value: | ||||||
|                                 site = site_value['title'] |                                 site = site_value["title"] | ||||||
|                                 if to_print: |                                 if to_print: | ||||||
|                                     print(site_filter, ":", site) |                                     print(site_filter, ":", site) | ||||||
|                                 title_to_id[site] = unique_id |                                 title_to_id[site] = unique_id | ||||||
|  | @ -91,7 +105,9 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False): | ||||||
|                                 lang_label = labels.get(lang, None) |                                 lang_label = labels.get(lang, None) | ||||||
|                                 if lang_label: |                                 if lang_label: | ||||||
|                                     if to_print: |                                     if to_print: | ||||||
|                                         print("label (" + lang + "):", lang_label["value"]) |                                         print( | ||||||
|  |                                             "label (" + lang + "):", lang_label["value"] | ||||||
|  |                                         ) | ||||||
| 
 | 
 | ||||||
|                         if found_link and parse_descriptions: |                         if found_link and parse_descriptions: | ||||||
|                             descriptions = obj["descriptions"] |                             descriptions = obj["descriptions"] | ||||||
|  | @ -99,7 +115,10 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False): | ||||||
|                                 lang_descr = descriptions.get(lang, None) |                                 lang_descr = descriptions.get(lang, None) | ||||||
|                                 if lang_descr: |                                 if lang_descr: | ||||||
|                                     if to_print: |                                     if to_print: | ||||||
|                                         print("description (" + lang + "):", lang_descr["value"]) |                                         print( | ||||||
|  |                                             "description (" + lang + "):", | ||||||
|  |                                             lang_descr["value"], | ||||||
|  |                                         ) | ||||||
|                                     id_to_descr[unique_id] = lang_descr["value"] |                                     id_to_descr[unique_id] = lang_descr["value"] | ||||||
| 
 | 
 | ||||||
|                         if parse_aliases: |                         if parse_aliases: | ||||||
|  | @ -109,11 +128,14 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False): | ||||||
|                                 if lang_aliases: |                                 if lang_aliases: | ||||||
|                                     for item in lang_aliases: |                                     for item in lang_aliases: | ||||||
|                                         if to_print: |                                         if to_print: | ||||||
|                                             print("alias (" + lang + "):", item["value"]) |                                             print( | ||||||
|  |                                                 "alias (" + lang + "):", item["value"] | ||||||
|  |                                             ) | ||||||
| 
 | 
 | ||||||
|                         if to_print: |                         if to_print: | ||||||
|                             print() |                             print() | ||||||
|             line = file.readline() |             line = file.readline() | ||||||
|             cnt += 1 |             cnt += 1 | ||||||
|  |         print(datetime.datetime.now(), "processed", cnt, "lines of WikiData JSON dump") | ||||||
| 
 | 
 | ||||||
|     return title_to_id, id_to_descr |     return title_to_id, id_to_descr | ||||||
|  |  | ||||||
							
								
								
									
										430
									
								
								bin/wiki_entity_linking/wikidata_train_entity_linker.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										430
									
								
								bin/wiki_entity_linking/wikidata_train_entity_linker.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,430 @@ | ||||||
|  | # coding: utf-8 | ||||||
|  | """Script to take a previously created Knowledge Base and train an entity linking | ||||||
|  | pipeline. The provided KB directory should hold the kb, the original nlp object and | ||||||
|  | its vocab used to create the KB, and a few auxiliary files such as the entity definitions, | ||||||
|  | as created by the script `wikidata_create_kb`. | ||||||
|  | 
 | ||||||
|  | For the Wikipedia dump: get enwiki-latest-pages-articles-multistream.xml.bz2 | ||||||
|  | from https://dumps.wikimedia.org/enwiki/latest/ | ||||||
|  | 
 | ||||||
|  | """ | ||||||
|  | from __future__ import unicode_literals | ||||||
|  | 
 | ||||||
|  | import random | ||||||
|  | import datetime | ||||||
|  | from pathlib import Path | ||||||
|  | import plac | ||||||
|  | 
 | ||||||
|  | from bin.wiki_entity_linking import training_set_creator | ||||||
|  | 
 | ||||||
|  | import spacy | ||||||
|  | from spacy.kb import KnowledgeBase | ||||||
|  | 
 | ||||||
|  | from spacy import Errors | ||||||
|  | from spacy.util import minibatch, compounding | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
def now():
    """Return the current date and time, used to timestamp the log lines."""
    timestamp = datetime.datetime.now()
    return timestamp
|  | 
 | ||||||
|  | 
 | ||||||
@plac.annotations(
    dir_kb=("Directory with KB, NLP and related files", "positional", None, Path),
    output_dir=("Output directory", "option", "o", Path),
    loc_training=("Location to training data", "option", "k", Path),
    wp_xml=("Path to the downloaded Wikipedia XML dump.", "option", "w", Path),
    epochs=("Number of training iterations (default 10)", "option", "e", int),
    dropout=("Dropout to prevent overfitting (default 0.5)", "option", "p", float),
    lr=("Learning rate (default 0.005)", "option", "n", float),
    l2=("L2 regularization", "option", "r", float),
    train_inst=("# training instances (default 90% of all)", "option", "t", int),
    dev_inst=("# test instances (default 10% of all)", "option", "d", int),
    limit=("Optional threshold to limit lines read from WP dump", "option", "l", int),
)
def main(
    dir_kb,
    output_dir=None,
    loc_training=None,
    wp_xml=None,
    epochs=10,
    dropout=0.5,
    lr=0.005,
    l2=1e-6,
    train_inst=None,
    dev_inst=None,
    limit=None,
):
    """Train and evaluate an entity linking pipe on top of an existing KB.

    The model is loaded from `dir_kb / "nlp"` and the knowledge base from
    `dir_kb / "kb"`. If `loc_training` is not given, the training data is first
    created from the Wikipedia dump `wp_xml`. When `output_dir` is set, the
    trained pipeline is written to `output_dir / "nlp"`.

    Raises ValueError if the loaded pipeline has no "ner" component (E152), or
    if neither `loc_training` nor `wp_xml` is provided (E153).
    """
    print(now(), "Creating Entity Linker with Wikipedia and WikiData")
    print()

    # STEP 0: set up IO
    if output_dir and not output_dir.exists():
        # parents=True: a nested output path must not fail on a missing parent
        output_dir.mkdir(parents=True)

    # STEP 1 : load the NLP object
    nlp_dir = dir_kb / "nlp"
    print(now(), "STEP 1: loading model from", nlp_dir)
    nlp = spacy.load(nlp_dir)

    # check that there is a NER component in the pipeline
    if "ner" not in nlp.pipe_names:
        raise ValueError(Errors.E152)

    # STEP 2 : read the KB
    print()
    print(now(), "STEP 2: reading the KB from", dir_kb / "kb")
    kb = KnowledgeBase(vocab=nlp.vocab)
    kb.load_bulk(dir_kb / "kb")

    # STEP 3: create a training dataset from WP
    print()
    if loc_training:
        print(now(), "STEP 3: reading training dataset from", loc_training)
    else:
        if not wp_xml:
            raise ValueError(Errors.E153)

        # default location: next to the output if there is one, else next to the KB
        if output_dir:
            loc_training = output_dir / "training_data"
        else:
            loc_training = dir_kb / "training_data"
        if not loc_training.exists():
            loc_training.mkdir(parents=True)
        print(now(), "STEP 3: creating training dataset at", loc_training)

        if limit is not None:
            print("Warning: reading only", limit, "lines of Wikipedia dump.")

        loc_entity_defs = dir_kb / "entity_defs.csv"
        training_set_creator.create_training(
            wikipedia_input=wp_xml,
            entity_def_input=loc_entity_defs,
            training_output=loc_training,
            limit=limit,
        )

    # STEP 4: parse the training data
    print()
    print(now(), "STEP 4: parse the training & evaluation data")

    # for training, get pos & neg instances that correspond to entries in the kb
    print("Parsing training data, limit =", train_inst)
    train_data = training_set_creator.read_training(
        nlp=nlp, training_dir=loc_training, dev=False, limit=train_inst, kb=kb
    )

    print("Training on", len(train_data), "articles")
    print()

    print("Parsing dev testing data, limit =", dev_inst)
    # for testing, get all pos instances, whether or not they are in the kb
    dev_data = training_set_creator.read_training(
        nlp=nlp, training_dir=loc_training, dev=True, limit=dev_inst, kb=None
    )

    print("Dev testing on", len(dev_data), "articles")
    print()

    # STEP 5: create and train the entity linking pipe
    print()
    print(now(), "STEP 5: training Entity Linking pipe")

    el_pipe = nlp.create_pipe(
        name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name}
    )
    el_pipe.set_kb(kb)
    nlp.add_pipe(el_pipe, last=True)

    other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "entity_linker"]
    with nlp.disable_pipes(*other_pipes):  # only train Entity Linking
        optimizer = nlp.begin_training()
        optimizer.learn_rate = lr
        optimizer.L2 = l2

    if not train_data:
        print("Did not find any training data")
    else:
        for itn in range(epochs):
            random.shuffle(train_data)
            losses = {}
            batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
            batchnr = 0  # number of batches that were updated successfully

            with nlp.disable_pipes(*other_pipes):
                for batch in batches:
                    # best effort: a failing batch is reported and skipped
                    try:
                        docs, golds = zip(*batch)
                        nlp.update(
                            docs=docs,
                            golds=golds,
                            sgd=optimizer,
                            drop=dropout,
                            losses=losses,
                        )
                        batchnr += 1
                    except Exception as e:
                        print("Error updating batch:", e)

                if batchnr > 0:
                    # evaluate on the dev set with both prior and context enabled
                    el_pipe.cfg["incl_context"] = True
                    el_pipe.cfg["incl_prior"] = True
                    dev_acc_context, _ = _measure_acc(dev_data, el_pipe)
                    # report the loss averaged over the successful batches
                    losses["entity_linker"] = losses["entity_linker"] / batchnr
                    print(
                        "Epoch, train loss",
                        itn,
                        round(losses["entity_linker"], 2),
                        " / dev accuracy avg",
                        round(dev_acc_context, 3),
                    )

    # STEP 6: measure the performance of our trained pipe on an independent dev set
    print()
    if len(dev_data):
        print()
        print(now(), "STEP 6: performance measurement of Entity Linking pipe")
        print()

        counts, acc_r, acc_r_d, acc_p, acc_p_d, acc_o, acc_o_d = _measure_baselines(
            dev_data, kb
        )
        print("dev counts:", sorted(counts.items(), key=lambda x: x[0]))

        oracle_by_label = [(x, round(y, 3)) for x, y in acc_o_d.items()]
        print("dev accuracy oracle:", round(acc_o, 3), oracle_by_label)

        random_by_label = [(x, round(y, 3)) for x, y in acc_r_d.items()]
        print("dev accuracy random:", round(acc_r, 3), random_by_label)

        prior_by_label = [(x, round(y, 3)) for x, y in acc_p_d.items()]
        print("dev accuracy prior:", round(acc_p, 3), prior_by_label)

        # using only context
        el_pipe.cfg["incl_context"] = True
        el_pipe.cfg["incl_prior"] = False
        dev_acc_context, dev_acc_cont_d = _measure_acc(dev_data, el_pipe)
        context_by_label = [(x, round(y, 3)) for x, y in dev_acc_cont_d.items()]
        print("dev accuracy context:", round(dev_acc_context, 3), context_by_label)

        # measuring combined accuracy (prior + context)
        el_pipe.cfg["incl_context"] = True
        el_pipe.cfg["incl_prior"] = True
        dev_acc_combo, dev_acc_combo_d = _measure_acc(dev_data, el_pipe)
        combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_d.items()]
        print("dev accuracy prior+context:", round(dev_acc_combo, 3), combo_by_label)

    # STEP 7: apply the EL pipe on a toy example
    print()
    print(now(), "STEP 7: applying Entity Linking to toy example")
    print()
    run_el_toy_example(nlp=nlp)

    # STEP 8: write the NLP pipeline (including entity linker) to file
    if output_dir:
        print()
        nlp_loc = output_dir / "nlp"
        print(now(), "STEP 8: Writing trained NLP to", nlp_loc)
        nlp.to_disk(nlp_loc)
        print()

    print()
    print(now(), "Done!")
|  | 
 | ||||||
|  | 
 | ||||||
def _measure_acc(data, el_pipe=None, error_analysis=False):
    """Compute the entity linking accuracy over (doc, gold) pairs.

    data: iterable of (doc, gold) pairs; pairs with an empty doc are ignored.
    el_pipe: if set, the docs are first processed with this pipe.
    error_analysis: if True, print every wrongly predicted entity.
    RETURNS: the (accuracy, accuracy by label) pair produced by `calculate_acc`.
    """
    hits_per_label = {}
    misses_per_label = {}

    docs = [doc for doc, _ in data if len(doc) > 0]
    if el_pipe is not None:
        docs = list(el_pipe.pipe(docs))
    golds = [gold for doc, gold in data if len(doc) > 0]

    for doc, gold in zip(docs, golds):
        try:
            # map the offsets of each positive gold link to its KB identifier
            gold_by_offset = {}
            for (start, end), kb_dict in gold.links.items():
                for gold_kb, is_positive in kb_dict.items():
                    # only evaluating on positive examples
                    if is_positive:
                        gold_by_offset[_offset(start, end)] = gold_kb

            for ent in doc.ents:
                ent_label = ent.label_
                pred_entity = ent.kb_id_
                key = _offset(ent.start_char, ent.end_char)
                gold_entity = gold_by_offset.get(key, None)
                # the gold annotations are not complete, so an entity without
                # a gold link can't be counted as 'wrong'
                if gold_entity is None:
                    continue
                if gold_entity == pred_entity:
                    hits_per_label[ent_label] = hits_per_label.get(ent_label, 0) + 1
                    continue
                misses_per_label[ent_label] = misses_per_label.get(ent_label, 0) + 1
                if error_analysis:
                    print(ent.text, "in", doc)
                    print("Predicted", pred_entity, "should have been", gold_entity)
                    print()

        except Exception as e:
            print("Error assessing accuracy", e)

    return calculate_acc(hits_per_label, misses_per_label)
|  | 
 | ||||||
|  | 
 | ||||||
def _measure_baselines(data, kb):
    """Measure 3 performance baselines on (doc, gold) pairs: random selection,
    prior probabilities, and 'oracle' prediction for an upper bound.

    data: iterable of (doc, gold) pairs; pairs with an empty doc are ignored.
    kb: object whose `get_candidates` provides the candidates for a mention.
    RETURNS: (counts by label, random acc, random acc by label, prior acc,
        prior acc by label, oracle acc, oracle acc by label).
    """
    counts_d = {}
    random_correct_d, random_incorrect_d = {}, {}
    oracle_correct_d, oracle_incorrect_d = {}, {}
    prior_correct_d, prior_incorrect_d = {}, {}

    def tally(is_correct, correct_d, incorrect_d, label):
        # add one to the matching counter for this label
        target = correct_d if is_correct else incorrect_d
        target[label] = target.get(label, 0) + 1

    docs = [doc for doc, _ in data if len(doc) > 0]
    golds = [gold for doc, gold in data if len(doc) > 0]

    for doc, gold in zip(docs, golds):
        try:
            # map the offsets of each positive gold link to its KB identifier
            gold_by_offset = {}
            for (start, end), kb_dict in gold.links.items():
                for gold_kb, is_positive in kb_dict.items():
                    # only evaluating on positive examples
                    if is_positive:
                        gold_by_offset[_offset(start, end)] = gold_kb

            for ent in doc.ents:
                label = ent.label_
                key = _offset(ent.start_char, ent.end_char)
                gold_entity = gold_by_offset.get(key, None)

                # the gold annotations are not complete, so an entity without
                # a gold link can't be counted as 'wrong'
                if gold_entity is None:
                    continue

                counts_d[label] = counts_d.get(label, 0) + 1
                candidates = kb.get_candidates(ent.text)
                # without candidates, all three baselines predict ""
                oracle_candidate = best_candidate = random_candidate = ""
                if candidates:
                    scores = []
                    for cand in candidates:
                        scores.append(cand.prior_prob)
                        if cand.entity_ == gold_entity:
                            oracle_candidate = cand.entity_

                    # first candidate with the highest prior probability
                    top_index = scores.index(max(scores))
                    best_candidate = candidates[top_index].entity_
                    random_candidate = random.choice(candidates).entity_

                tally(
                    gold_entity == best_candidate,
                    prior_correct_d,
                    prior_incorrect_d,
                    label,
                )
                tally(
                    gold_entity == random_candidate,
                    random_correct_d,
                    random_incorrect_d,
                    label,
                )
                tally(
                    gold_entity == oracle_candidate,
                    oracle_correct_d,
                    oracle_incorrect_d,
                    label,
                )

        except Exception as e:
            print("Error assessing accuracy", e)

    acc_prior, acc_prior_d = calculate_acc(prior_correct_d, prior_incorrect_d)
    acc_rand, acc_rand_d = calculate_acc(random_correct_d, random_incorrect_d)
    acc_oracle, acc_oracle_d = calculate_acc(oracle_correct_d, oracle_incorrect_d)

    return (
        counts_d,
        acc_rand,
        acc_rand_d,
        acc_prior,
        acc_prior_d,
        acc_oracle,
        acc_oracle_d,
    )
|  | 
 | ||||||
|  | 
 | ||||||
|  | def _offset(start, end): | ||||||
|  |     return "{}_{}".format(start, end) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
def calculate_acc(correct_by_label, incorrect_by_label):
    """Compute the overall and the per-label accuracy from two count dicts.

    correct_by_label: dict mapping a label to its number of correct predictions.
    incorrect_by_label: dict mapping a label to its number of incorrect predictions.
    RETURNS: tuple (acc, acc_by_label). Both the overall accuracy and the
        accuracy of a label are 0 when there are no counts to compute them from.
    """
    acc_by_label = dict()
    total_correct = 0
    total_incorrect = 0
    all_keys = set(correct_by_label) | set(incorrect_by_label)
    for label in sorted(all_keys):
        correct = correct_by_label.get(label, 0)
        incorrect = incorrect_by_label.get(label, 0)
        total_correct += correct
        total_incorrect += incorrect
        if correct == incorrect == 0:
            acc_by_label[label] = 0
        else:
            # float(): without `from __future__ import division`, Python 2
            # would floor this ratio of two ints to 0 or 1
            acc_by_label[label] = float(correct) / (correct + incorrect)
    acc = 0
    if not (total_correct == total_incorrect == 0):
        acc = float(total_correct) / (total_correct + total_incorrect)
    return acc, acc_by_label
|  | 
 | ||||||
|  | 
 | ||||||
def check_kb(kb):
    """Print the candidates that the KB generates for a few fixed mentions."""
    mentions = ("Bush", "Douglas Adams", "Homer", "Brazil", "China")
    for mention in mentions:
        candidates = kb.get_candidates(mention)

        print("generating candidates for " + mention + " :")
        for cand in candidates:
            prior = cand.prior_prob
            alias = cand.alias_
            entity_info = cand.entity_ + " (freq=" + str(cand.entity_freq) + ")"
            print(" ", prior, alias, "-->", entity_info)
        print()
|  | 
 | ||||||
|  | 
 | ||||||
def run_el_toy_example(nlp):
    """Run the pipeline on a fixed toy text and print its entities with their
    label and KB identifier."""
    sentences = [
        "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, ",
        "Douglas reminds us to always bring our towel, even in China or Brazil. ",
        "The main character in Doug's novel is the man Arthur Dent, ",
        "but Dougledydoug doesn't write about George Washington or Homer Simpson.",
    ]
    text = "".join(sentences)
    doc = nlp(text)
    print(text)
    for entity in doc.ents:
        print(" ent", entity.text, entity.label_, entity.kb_id_)
    print()
|  | 
 | ||||||
|  | 
 | ||||||
# Only when executed as a script: hand `main` over to plac, which calls it
# (presumably with the parsed command line arguments, as declared in the
# plac annotations on `main`).
if __name__ == "__main__":
    plac.call(main)
|  | @ -120,7 +120,7 @@ def now(): | ||||||
|     return datetime.datetime.now() |     return datetime.datetime.now() | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def read_prior_probs(wikipedia_input, prior_prob_output): | def read_prior_probs(wikipedia_input, prior_prob_output, limit=None): | ||||||
|     """ |     """ | ||||||
|     Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities. |     Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities. | ||||||
|     The full file takes about 2h to parse 1100M lines. |     The full file takes about 2h to parse 1100M lines. | ||||||
|  | @ -129,9 +129,9 @@ def read_prior_probs(wikipedia_input, prior_prob_output): | ||||||
|     with bz2.open(wikipedia_input, mode="rb") as file: |     with bz2.open(wikipedia_input, mode="rb") as file: | ||||||
|         line = file.readline() |         line = file.readline() | ||||||
|         cnt = 0 |         cnt = 0 | ||||||
|         while line: |         while line and (not limit or cnt < limit): | ||||||
|             if cnt % 5000000 == 0: |             if cnt % 25000000 == 0: | ||||||
|                 print(now(), "processed", cnt, "lines of Wikipedia dump") |                 print(now(), "processed", cnt, "lines of Wikipedia XML dump") | ||||||
|             clean_line = line.strip().decode("utf-8") |             clean_line = line.strip().decode("utf-8") | ||||||
| 
 | 
 | ||||||
|             aliases, entities, normalizations = get_wp_links(clean_line) |             aliases, entities, normalizations = get_wp_links(clean_line) | ||||||
|  | @ -141,6 +141,7 @@ def read_prior_probs(wikipedia_input, prior_prob_output): | ||||||
| 
 | 
 | ||||||
|             line = file.readline() |             line = file.readline() | ||||||
|             cnt += 1 |             cnt += 1 | ||||||
|  |         print(now(), "processed", cnt, "lines of Wikipedia XML dump") | ||||||
| 
 | 
 | ||||||
|     # write all aliases and their entities and count occurrences to file |     # write all aliases and their entities and count occurrences to file | ||||||
|     with prior_prob_output.open("w", encoding="utf8") as outputfile: |     with prior_prob_output.open("w", encoding="utf8") as outputfile: | ||||||
|  |  | ||||||
|  | @ -1,75 +0,0 @@ | ||||||
| # coding: utf-8 |  | ||||||
| from __future__ import unicode_literals |  | ||||||
| 
 |  | ||||||
| """Demonstrate how to build a simple knowledge base and run an Entity Linking algorithm. |  | ||||||
| Currently still a bit of a dummy algorithm: taking simply the entity with highest probability for a given alias |  | ||||||
| """ |  | ||||||
| import spacy |  | ||||||
| from spacy.kb import KnowledgeBase |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
def create_kb(vocab):
    """Build a small demo knowledge base with 3 entities and 2 aliases.

    vocab: the vocab that is passed on to the KnowledgeBase.
    RETURNS: the filled KnowledgeBase.
    """
    kb = KnowledgeBase(vocab=vocab, entity_vector_length=1)

    # adding entities: each one gets its position as 1-dimensional vector
    entity_ids = ["Q1004791_Douglas", "Q42_Douglas_Adams", "Q5301561_Douglas_Haig"]
    for position, entity_id in enumerate(entity_ids):
        print("adding entity", entity_id)
        kb.add_entity(entity=entity_id, freq=0.5, entity_vector=[position])

    # adding aliases
    print()
    alias_defs = [
        ("Douglas", list(entity_ids), [0.6, 0.1, 0.2]),
        ("Douglas Adams", [entity_ids[1]], [0.9]),
    ]
    for alias, alias_entities, alias_probs in alias_defs:
        print("adding alias", alias)
        kb.add_alias(alias=alias, entities=alias_entities, probabilities=alias_probs)

    print()
    print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases())

    return kb
| 
 |  | ||||||
| 
 |  | ||||||
def add_el(kb, nlp):
    """Add an entity linker that uses `kb` to the pipeline, then print the
    candidates of two aliases and the annotations of a demo text."""
    linker_config = {"context_width": 64}
    el_pipe = nlp.create_pipe(name='entity_linker', config=linker_config)
    el_pipe.set_kb(kb)
    nlp.add_pipe(el_pipe, last=True)
    nlp.begin_training()
    el_pipe.context_weight = 0
    el_pipe.prior_weight = 1

    for alias in ("Douglas Adams", "Douglas"):
        candidates = nlp.linker.kb.get_candidates(alias)
        print()
        print(len(candidates), "candidate(s) for", alias, ":")
        for cand in candidates:
            print(" ", cand.entity_, cand.prior_prob)

    sentences = [
        "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, ",
        "Douglas reminds us to always bring our towel. ",
        "The main character in Doug's novel is called Arthur Dent.",
    ]
    text = "".join(sentences)
    doc = nlp(text)

    # the KB identifiers on token level
    print()
    for token in doc:
        print("token", token.text, token.ent_type_, token.ent_kb_id_)

    # the KB identifiers on entity level
    print()
    for entity in doc.ents:
        print("ent", entity.text, entity.label_, entity.kb_id_)
| 
 |  | ||||||
| 
 |  | ||||||
if __name__ == "__main__":
    # Build the toy KB on top of the loaded model's vocab, then plug it
    # into that same pipeline.
    my_nlp = spacy.load("en_core_web_sm")
    my_kb = create_kb(my_nlp.vocab)
    add_el(my_kb, my_nlp)
|  | @ -1,514 +0,0 @@ | ||||||
| # coding: utf-8 |  | ||||||
| from __future__ import unicode_literals |  | ||||||
| 
 |  | ||||||
| import os |  | ||||||
| from os import path |  | ||||||
| import random |  | ||||||
| import datetime |  | ||||||
| from pathlib import Path |  | ||||||
| 
 |  | ||||||
| from bin.wiki_entity_linking import wikipedia_processor as wp |  | ||||||
| from bin.wiki_entity_linking import training_set_creator, kb_creator |  | ||||||
| from bin.wiki_entity_linking.kb_creator import DESC_WIDTH |  | ||||||
| 
 |  | ||||||
| import spacy |  | ||||||
| from spacy.kb import KnowledgeBase |  | ||||||
| from spacy.util import minibatch, compounding |  | ||||||
| 
 |  | ||||||
| """ |  | ||||||
| Demonstrate how to build a knowledge base from WikiData and run an Entity Linking algorithm. |  | ||||||
| """ |  | ||||||
| 
 |  | ||||||
# --- locations on disk -------------------------------------------------
ROOT_DIR = Path("C:/Users/Sofie/Documents/data/")
OUTPUT_DIR = ROOT_DIR / "wikipedia"
TRAINING_DIR = OUTPUT_DIR / "training_data_nel"

# intermediate CSV files
PRIOR_PROB = OUTPUT_DIR / "prior_prob.csv"
ENTITY_COUNTS = OUTPUT_DIR / "entity_freq.csv"
ENTITY_DEFS = OUTPUT_DIR / "entity_defs.csv"
ENTITY_DESCR = OUTPUT_DIR / "entity_descriptions.csv"

# serialized KB and pipelines
KB_DIR = OUTPUT_DIR / "kb_1"
KB_FILE = "kb"
NLP_1_DIR = OUTPUT_DIR / "nlp_1"
NLP_2_DIR = OUTPUT_DIR / "nlp_2"

# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
WIKIDATA_JSON = ROOT_DIR / "wikidata" / "wikidata-20190304-all.json.bz2"

# get enwiki-latest-pages-articles-multistream.xml.bz2 from https://dumps.wikimedia.org/enwiki/latest/
# (OUTPUT_DIR is ROOT_DIR / "wikipedia", so this is the same location as before)
ENWIKI_DUMP = OUTPUT_DIR / "enwiki-20190320-pages-articles-multistream.xml.bz2"

# --- KB construction parameters ----------------------------------------
MAX_CANDIDATES = 10
MIN_ENTITY_FREQ = 20
MIN_PAIR_OCC = 5

# --- model training parameters -----------------------------------------
EPOCHS = 10
DROPOUT = 0.5
LEARN_RATE = 0.005
L2 = 1e-6
CONTEXT_WIDTH = 128
| 
 |  | ||||||
| 
 |  | ||||||
def now():
    """Return the current local date and time (used to timestamp log lines)."""
    timestamp = datetime.datetime.now()
    return timestamp
| 
 |  | ||||||
| 
 |  | ||||||
def run_pipeline():
    """Run the Wikipedia/WikiData entity linking demo from start to finish.

    The boolean flags defined at the top of the function select which steps
    are executed:

    * STEP 1-3: one-time creation of prior probabilities, entity counts and
      the knowledge base, which is written to KB_DIR together with the
      pipeline (NLP_1_DIR).
    * STEP 4: read the pipeline and the KB back in from file.
    * STEP 5: create the training dataset from the Wikipedia dump.
    * STEP 6-9: train the entity linker, measure it against the baselines,
      run the toy example and write the resulting pipeline to NLP_2_DIR.
      These steps only run when `train_pipe` is set.
    * Finally, optionally reload the written pipeline and test it again.

    RAISES: ValueError if `train_pipe` is enabled while `to_read_kb` is not,
        because training needs the pipeline and KB loaded in STEP 4.
    """
    # set the appropriate booleans to define which parts of the pipeline should be (re)run
    print("START", now())
    print()
    nlp_1 = spacy.load("en_core_web_lg")
    nlp_2 = None
    kb_2 = None

    # one-time methods to create KB and write to file
    to_create_prior_probs = False
    to_create_entity_counts = False
    to_create_kb = False

    # read KB back in from file
    to_read_kb = True
    to_test_kb = False

    # create training dataset
    create_wp_training = False

    # train the EL pipe
    train_pipe = True
    measure_performance = True

    # test the EL pipe on a simple example
    to_test_pipeline = True

    # write the NLP object, read back in and test again
    to_write_nlp = True
    to_read_nlp = True
    test_from_file = False

    # STEP 1 : create prior probabilities from WP (run only once)
    if to_create_prior_probs:
        print("STEP 1: to_create_prior_probs", now())
        wp.read_prior_probs(ENWIKI_DUMP, PRIOR_PROB)
        print()

    # STEP 2 : deduce entity frequencies from WP (run only once)
    if to_create_entity_counts:
        print("STEP 2: to_create_entity_counts", now())
        wp.write_entity_counts(PRIOR_PROB, ENTITY_COUNTS, to_print=False)
        print()

    # STEP 3 : create KB and write to file (run only once)
    if to_create_kb:
        print("STEP 3a: to_create_kb", now())
        kb_1 = kb_creator.create_kb(
            nlp=nlp_1,
            max_entities_per_alias=MAX_CANDIDATES,
            min_entity_freq=MIN_ENTITY_FREQ,
            min_occ=MIN_PAIR_OCC,
            entity_def_output=ENTITY_DEFS,
            entity_descr_output=ENTITY_DESCR,
            count_input=ENTITY_COUNTS,
            prior_prob_input=PRIOR_PROB,
            wikidata_input=WIKIDATA_JSON,
        )
        print("kb entities:", kb_1.get_size_entities())
        print("kb aliases:", kb_1.get_size_aliases())
        print()

        print("STEP 3b: write KB and NLP", now())

        if not path.exists(KB_DIR):
            os.makedirs(KB_DIR)
        kb_1.dump(KB_DIR / KB_FILE)
        nlp_1.to_disk(NLP_1_DIR)
        print()

    # STEP 4 : read KB back in from file
    if to_read_kb:
        print("STEP 4: to_read_kb", now())
        nlp_2 = spacy.load(NLP_1_DIR)
        kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH)
        kb_2.load_bulk(KB_DIR / KB_FILE)
        print("kb entities:", kb_2.get_size_entities())
        print("kb aliases:", kb_2.get_size_aliases())
        print()

        # test KB
        if to_test_kb:
            check_kb(kb_2)
            print()

    # STEP 5: create a training dataset from WP
    if create_wp_training:
        print("STEP 5: create training dataset", now())
        training_set_creator.create_training(
            wikipedia_input=ENWIKI_DUMP,
            entity_def_input=ENTITY_DEFS,
            training_output=TRAINING_DIR,
        )

    # STEP 6: create and train the entity linking pipe
    if train_pipe:
        # nlp_2 and kb_2 are only set in STEP 4; fail with a clear message
        # instead of an AttributeError on None further down
        if nlp_2 is None or kb_2 is None:
            raise ValueError(
                "Training the entity linker requires the NLP model and KB "
                "from STEP 4: enable to_read_kb when train_pipe is enabled."
            )

        print("STEP 6: training Entity Linking pipe", now())
        type_to_int = {label: i for i, label in enumerate(nlp_2.entity.labels)}
        print(" -analysing", len(type_to_int), "different entity types")
        el_pipe = nlp_2.create_pipe(
            name="entity_linker",
            config={
                "context_width": CONTEXT_WIDTH,
                "pretrained_vectors": nlp_2.vocab.vectors.name,
                "type_to_int": type_to_int,
            },
        )
        el_pipe.set_kb(kb_2)
        nlp_2.add_pipe(el_pipe, last=True)

        other_pipes = [pipe for pipe in nlp_2.pipe_names if pipe != "entity_linker"]
        with nlp_2.disable_pipes(*other_pipes):  # only train Entity Linking
            optimizer = nlp_2.begin_training()
            optimizer.learn_rate = LEARN_RATE
            optimizer.L2 = L2

        # define the size (nr of entities) of training and dev set
        train_limit = 5000
        dev_limit = 5000

        # for training, get pos & neg instances that correspond to entries in the kb
        train_data = training_set_creator.read_training(
            nlp=nlp_2,
            training_dir=TRAINING_DIR,
            dev=False,
            limit=train_limit,
            kb=el_pipe.kb,
        )

        print("Training on", len(train_data), "articles")
        print()

        # for testing, get all pos instances, whether or not they are in the kb
        dev_data = training_set_creator.read_training(
            nlp=nlp_2, training_dir=TRAINING_DIR, dev=True, limit=dev_limit, kb=None
        )

        print("Dev testing on", len(dev_data), "articles")
        print()

        if not train_data:
            print("Did not find any training data")
        else:
            for itn in range(EPOCHS):
                random.shuffle(train_data)
                losses = {}
                batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
                batchnr = 0

                with nlp_2.disable_pipes(*other_pipes):
                    for batch in batches:
                        # best effort: a failing batch is reported and skipped
                        try:
                            docs, golds = zip(*batch)
                            nlp_2.update(
                                docs=docs,
                                golds=golds,
                                sgd=optimizer,
                                drop=DROPOUT,
                                losses=losses,
                            )
                            batchnr += 1
                        except Exception as e:
                            print("Error updating batch:", e)

                if batchnr > 0:
                    el_pipe.cfg["context_weight"] = 1
                    el_pipe.cfg["prior_weight"] = 1
                    dev_acc_context, _ = _measure_acc(dev_data, el_pipe)
                    # report the loss averaged over the successful batches
                    losses["entity_linker"] = losses["entity_linker"] / batchnr
                    print(
                        "Epoch, train loss",
                        itn,
                        round(losses["entity_linker"], 2),
                        " / dev acc avg",
                        round(dev_acc_context, 3),
                    )

        # STEP 7: measure the performance of our trained pipe on an independent dev set
        if len(dev_data) and measure_performance:
            print()
            print("STEP 7: performance measurement of Entity Linking pipe", now())
            print()

            counts, acc_r, acc_r_d, acc_p, acc_p_d, acc_o, acc_o_d = _measure_baselines(
                dev_data, kb_2
            )
            print("dev counts:", sorted(counts.items(), key=lambda x: x[0]))

            oracle_by_label = [(x, round(y, 3)) for x, y in acc_o_d.items()]
            print("dev acc oracle:", round(acc_o, 3), oracle_by_label)

            random_by_label = [(x, round(y, 3)) for x, y in acc_r_d.items()]
            print("dev acc random:", round(acc_r, 3), random_by_label)

            prior_by_label = [(x, round(y, 3)) for x, y in acc_p_d.items()]
            print("dev acc prior:", round(acc_p, 3), prior_by_label)

            # using only context
            el_pipe.cfg["context_weight"] = 1
            el_pipe.cfg["prior_weight"] = 0
            dev_acc_context, dev_acc_cont_d = _measure_acc(dev_data, el_pipe)
            context_by_label = [(x, round(y, 3)) for x, y in dev_acc_cont_d.items()]
            print("dev acc context avg:", round(dev_acc_context, 3), context_by_label)

            # measuring combined accuracy (prior + context)
            el_pipe.cfg["context_weight"] = 1
            el_pipe.cfg["prior_weight"] = 1
            dev_acc_combo, dev_acc_combo_d = _measure_acc(dev_data, el_pipe)
            combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_d.items()]
            print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label)

        # STEP 8: apply the EL pipe on a toy example
        if to_test_pipeline:
            print()
            print("STEP 8: applying Entity Linking to toy example", now())
            print()
            run_el_toy_example(nlp=nlp_2)

        # STEP 9: write the NLP pipeline (including entity linker) to file
        if to_write_nlp:
            print()
            print("STEP 9: testing NLP IO", now())
            print()
            print("writing to", NLP_2_DIR)
            nlp_2.to_disk(NLP_2_DIR)
            print()

    # verify that the IO has gone correctly
    if to_read_nlp:
        print("reading from", NLP_2_DIR)
        nlp_3 = spacy.load(NLP_2_DIR)

        print("running toy example with NLP 3")
        run_el_toy_example(nlp=nlp_3)

    # testing performance with an NLP model from file
    if test_from_file:
        nlp_2 = spacy.load(NLP_1_DIR)
        nlp_3 = spacy.load(NLP_2_DIR)
        el_pipe = nlp_3.get_pipe("entity_linker")

        dev_limit = 5000
        dev_data = training_set_creator.read_training(
            nlp=nlp_2, training_dir=TRAINING_DIR, dev=True, limit=dev_limit, kb=None
        )

        print("Dev testing from file on", len(dev_data), "articles")
        print()

        dev_acc_combo, dev_acc_combo_dict = _measure_acc(dev_data, el_pipe)
        combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()]
        print("dev acc combo avg:", round(dev_acc_combo, 3), combo_by_label)

    print()
    print("STOP", now())
| 
 |  | ||||||
| 
 |  | ||||||
def _measure_acc(data, el_pipe=None, error_analysis=False):
    """Measure entity linking accuracy, overall and per entity label.

    data: (doc, gold) pairs; pairs whose doc is empty are ignored.
    el_pipe: if the docs still need to be processed with an entity linker,
        pass it here and it is applied to the docs first.
    error_analysis: if True, print every wrong prediction.
    RETURNS: (overall accuracy, dict of accuracy by label), as computed by
        calculate_acc.
    """

    def _bump(tally, key):
        tally[key] = tally.get(key, 0) + 1

    hits_by_label = {}
    misses_by_label = {}

    docs = [d for d, g in data if len(d) > 0]
    if el_pipe is not None:
        docs = list(el_pipe.pipe(docs))
    golds = [g for d, g in data if len(d) > 0]

    for doc, gold in zip(docs, golds):
        try:
            # map character offsets to the gold KB id; only evaluating on positive examples
            gold_by_offset = {}
            for (start, end), kb_dict in gold.links.items():
                for gold_kb, value in kb_dict.items():
                    if value:
                        gold_by_offset[_offset(start, end)] = gold_kb

            for ent in doc.ents:
                key = _offset(ent.start_char, ent.end_char)
                gold_entity = gold_by_offset.get(key, None)
                # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
                if gold_entity is None:
                    continue

                pred_entity = ent.kb_id_
                if gold_entity == pred_entity:
                    _bump(hits_by_label, ent.label_)
                    continue

                _bump(misses_by_label, ent.label_)
                if error_analysis:
                    print(ent.text, "in", doc)
                    print(
                        "Predicted",
                        pred_entity,
                        "should have been",
                        gold_entity,
                    )
                    print()

        except Exception as e:
            print("Error assessing accuracy", e)

    return calculate_acc(hits_by_label, misses_by_label)
| 
 |  | ||||||
| 
 |  | ||||||
def _measure_baselines(data, kb):
    """Measure three baselines on the data using only the KB's candidates.

    The baselines are: random selection among the candidates, the candidate
    with the highest prior probability, and an 'oracle' that picks the gold
    entity whenever it is among the candidates (an upper bound).

    data: (doc, gold) pairs; pairs whose doc is empty are ignored.
    kb: the knowledge base used to generate candidates for each entity text.
    RETURNS: tuple of (counts by label, random acc, random acc by label,
        prior acc, prior acc by label, oracle acc, oracle acc by label).
    """

    def _bump(tally, key):
        tally[key] = tally.get(key, 0) + 1

    counts_d = {}
    random_correct_d, random_incorrect_d = {}, {}
    oracle_correct_d, oracle_incorrect_d = {}, {}
    prior_correct_d, prior_incorrect_d = {}, {}

    docs = [d for d, g in data if len(d) > 0]
    golds = [g for d, g in data if len(d) > 0]

    for doc, gold in zip(docs, golds):
        try:
            # map character offsets to the gold KB id; only evaluating on positive examples
            gold_by_offset = {}
            for (start, end), kb_dict in gold.links.items():
                for gold_kb, value in kb_dict.items():
                    if value:
                        gold_by_offset[_offset(start, end)] = gold_kb

            for ent in doc.ents:
                label = ent.label_
                key = _offset(ent.start_char, ent.end_char)
                gold_entity = gold_by_offset.get(key, None)

                # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
                if gold_entity is None:
                    continue

                _bump(counts_d, label)
                candidates = kb.get_candidates(ent.text)
                # without candidates, every baseline predicts the empty string
                oracle_candidate = best_candidate = random_candidate = ""
                if candidates:
                    scores = []
                    for cand in candidates:
                        scores.append(cand.prior_prob)
                        if cand.entity_ == gold_entity:
                            oracle_candidate = cand.entity_

                    # first candidate with the highest prior probability
                    best_candidate = candidates[scores.index(max(scores))].entity_
                    random_candidate = random.choice(candidates).entity_

                prior_hit = gold_entity == best_candidate
                _bump(prior_correct_d if prior_hit else prior_incorrect_d, label)

                random_hit = gold_entity == random_candidate
                _bump(random_correct_d if random_hit else random_incorrect_d, label)

                oracle_hit = gold_entity == oracle_candidate
                _bump(oracle_correct_d if oracle_hit else oracle_incorrect_d, label)

        except Exception as e:
            print("Error assessing accuracy", e)

    acc_prior, acc_prior_d = calculate_acc(prior_correct_d, prior_incorrect_d)
    acc_rand, acc_rand_d = calculate_acc(random_correct_d, random_incorrect_d)
    acc_oracle, acc_oracle_d = calculate_acc(oracle_correct_d, oracle_incorrect_d)

    return (
        counts_d,
        acc_rand,
        acc_rand_d,
        acc_prior,
        acc_prior_d,
        acc_oracle,
        acc_oracle_d,
    )
| 
 |  | ||||||
| 
 |  | ||||||
| def _offset(start, end): |  | ||||||
|     return "{}_{}".format(start, end) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
def calculate_acc(correct_by_label, incorrect_by_label):
    """Compute the overall accuracy and the accuracy per label.

    correct_by_label: dict mapping a label to its number of correct predictions.
    incorrect_by_label: dict mapping a label to its number of incorrect predictions.
    RETURNS: (acc, acc_by_label). Accuracies are fractions between 0 and 1;
        a label (or the total) without any counted prediction gets 0.
    """
    acc_by_label = dict()
    total_correct = 0
    total_incorrect = 0
    all_keys = set(correct_by_label) | set(incorrect_by_label)
    for label in sorted(all_keys):
        correct = correct_by_label.get(label, 0)
        incorrect = incorrect_by_label.get(label, 0)
        total_correct += correct
        total_incorrect += incorrect
        if correct == incorrect == 0:
            acc_by_label[label] = 0
        else:
            # float() forces true division: this module does not import
            # `division` from __future__, so plain `/` on two ints would
            # truncate the accuracy to 0 or 1 under Python 2
            acc_by_label[label] = float(correct) / (correct + incorrect)
    acc = 0
    if not (total_correct == total_incorrect == 0):
        acc = float(total_correct) / (total_correct + total_incorrect)
    return acc, acc_by_label
| 
 |  | ||||||
| 
 |  | ||||||
def check_kb(kb):
    """Print the candidates the KB generates for a few sample mentions.

    kb: the knowledge base to query with `get_candidates`.
    """
    sample_mentions = ("Bush", "Douglas Adams", "Homer", "Brazil", "China")
    for mention in sample_mentions:
        found = kb.get_candidates(mention)

        print("generating candidates for " + mention + " :")
        for cand in found:
            described = cand.entity_ + " (freq=" + str(cand.entity_freq) + ")"
            print(" ", cand.prior_prob, cand.alias_, "-->", described)
        print()
| 
 |  | ||||||
| 
 |  | ||||||
def run_el_toy_example(nlp):
    """Run `nlp` on a fixed sample text and print the entities it finds.

    For every entity in the processed doc, its text, label and KB id are
    printed.

    nlp: the pipeline to apply to the sample text.
    """
    sentences = [
        "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, ",
        "Douglas reminds us to always bring our towel, even in China or Brazil. ",
        "The main character in Doug's novel is the man Arthur Dent, ",
        "but Dougledydoug doesn't write about George Washington or Homer Simpson.",
    ]
    text = "".join(sentences)
    doc = nlp(text)
    print(text)
    for span in doc.ents:
        print(" ent", span.text, span.label_, span.kb_id_)
    print()
| 
 |  | ||||||
| 
 |  | ||||||
if __name__ == "__main__":
    # only run the pipeline when executed as a script, not on import
    run_pipeline()
							
								
								
									
										139
									
								
								examples/training/pretrain_kb.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										139
									
								
								examples/training/pretrain_kb.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,139 @@ | ||||||
|  | #!/usr/bin/env python | ||||||
|  | # coding: utf8 | ||||||
|  | 
 | ||||||
|  | """Example of defining and (pre)training spaCy's knowledge base, | ||||||
|  | which is needed to implement entity linking functionality. | ||||||
|  | 
 | ||||||
|  | For more details, see the documentation: | ||||||
|  | * Knowledge base: https://spacy.io/api/kb | ||||||
|  | * Entity Linking: https://spacy.io/usage/linguistic-features#entity-linking | ||||||
|  | 
 | ||||||
|  | Compatible with: spaCy vX.X | ||||||
|  | Last tested with: vX.X | ||||||
|  | """ | ||||||
|  | from __future__ import unicode_literals, print_function | ||||||
|  | 
 | ||||||
|  | import plac | ||||||
|  | from pathlib import Path | ||||||
|  | 
 | ||||||
|  | from spacy.vocab import Vocab | ||||||
|  | 
 | ||||||
|  | import spacy | ||||||
|  | from spacy.kb import KnowledgeBase | ||||||
|  | 
 | ||||||
|  | from bin.wiki_entity_linking.train_descriptions import EntityEncoder | ||||||
|  | from spacy import Errors | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
# Toy entities, as KB id -> (description, frequency):
# Q2146908 (Russ Cochran): American golfer
# Q7381115 (Russ Cochran): publisher
ENTITIES = {
    "Q2146908": ("American golfer", 342),
    "Q7381115": ("publisher", 17),
}

INPUT_DIM = 300  # dimension of pre-trained input vectors
DESC_WIDTH = 64  # dimension of output entity vectors
|  | 
 | ||||||
|  | 
 | ||||||
@plac.annotations(
    vocab_path=("Path to the vocab for the kb", "option", "v", Path),
    model=("Model name, should have pretrained word embeddings", "option", "m", str),
    output_dir=("Optional output directory", "option", "o", Path),
    n_iter=("Number of training iterations", "option", "n", int),
)
def main(vocab_path=None, model=None, output_dir=None, n_iter=50):
    """Load the model, create the KB and pretrain the entity encodings.
    Either an nlp model or a vocab is needed to provide access to pre-trained word embeddings.
    If an output_dir is provided, the KB will be stored there in a file 'kb'.
    When providing an nlp model, the updated vocab will also be written to a directory in the output_dir.

    vocab_path: path to a vocab on disk, used when no model is given.
    model: name of the model to load; takes precedence over vocab_path.
    output_dir: directory to write the KB (and vocab) to; created if missing.
    n_iter: number of epochs for training the entity encoder.
    RAISES: ValueError if neither model nor vocab_path is provided.
    """
    if model is None and vocab_path is None:
        raise ValueError(Errors.E154)

    if model is not None:
        nlp = spacy.load(model)  # load existing spaCy model
        print("Loaded model '%s'" % model)
    else:
        vocab = Vocab().from_disk(vocab_path)
        # create blank Language class with specified vocab
        nlp = spacy.blank("en", vocab=vocab)
        print("Created blank 'en' model with vocab from '%s'" % vocab_path)

    # the KB's vector length has to match the width of the encoder output,
    # so pass it explicitly instead of relying on the constructor default
    kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH)

    # set up the data
    entity_ids = []
    descriptions = []
    freqs = []
    for key, (desc, freq) in ENTITIES.items():
        entity_ids.append(key)
        descriptions.append(desc)
        freqs.append(freq)

    # training entity description encodings
    # this part can easily be replaced with a custom entity encoder
    encoder = EntityEncoder(
        nlp=nlp,
        input_dim=INPUT_DIM,
        desc_width=DESC_WIDTH,
        epochs=n_iter,
        threshold=0.001,
    )
    encoder.train(description_list=descriptions, to_print=True)

    # get the pretrained entity vectors
    embeddings = encoder.apply_encoder(descriptions)

    # set the entities, can also be done by calling `kb.add_entity` for each entity
    kb.set_entities(entity_list=entity_ids, freq_list=freqs, vector_list=embeddings)

    # adding aliases, the entities need to be defined in the KB beforehand
    kb.add_alias(
        alias="Russ Cochran",
        entities=["Q2146908", "Q7381115"],
        probabilities=[0.24, 0.7],  # the sum of these probabilities should not exceed 1
    )

    # test the trained model
    print()
    _print_kb(kb)

    # save model to output directory
    if output_dir is not None:
        output_dir = Path(output_dir)
        if not output_dir.exists():
            # also create missing parent directories
            output_dir.mkdir(parents=True)
        kb_path = str(output_dir / "kb")
        kb.dump(kb_path)
        print()
        print("Saved KB to", kb_path)

        # only storing the vocab if we weren't already reading it from file
        if not vocab_path:
            vocab_path = output_dir / "vocab"
            kb.vocab.to_disk(vocab_path)
            print("Saved vocab to", vocab_path)

        print()

        # test the saved model
        # always reload a knowledge base with the same vocab instance!
        print("Loading vocab from", vocab_path)
        print("Loading KB from", kb_path)
        vocab2 = Vocab().from_disk(vocab_path)
        kb2 = KnowledgeBase(vocab=vocab2, entity_vector_length=DESC_WIDTH)
        kb2.load_bulk(kb_path)
        _print_kb(kb2)
        print()
|  | 
 | ||||||
|  | 
 | ||||||
|  | def _print_kb(kb): | ||||||
|  |     print(kb.get_size_entities(), "kb entities:", kb.get_entity_strings()) | ||||||
|  |     print(kb.get_size_aliases(), "kb aliases:", kb.get_alias_strings()) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
if __name__ == "__main__":
    # script entry point: plac invokes `main` with the command-line arguments
    plac.call(main)

    # Expected output:

    # 2 kb entities: ['Q2146908', 'Q7381115']
    # 1 kb aliases: ['Russ Cochran']
							
								
								
									
										173
									
								
								examples/training/train_entity_linker.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										173
									
								
								examples/training/train_entity_linker.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,173 @@ | ||||||
|  | #!/usr/bin/env python | ||||||
|  | # coding: utf8 | ||||||
|  | 
 | ||||||
|  | """Example of training spaCy's entity linker, starting off with an | ||||||
|  | existing model and a pre-defined knowledge base. | ||||||
|  | 
 | ||||||
|  | For more details, see the documentation: | ||||||
|  | * Training: https://spacy.io/usage/training | ||||||
|  | * Entity Linking: https://spacy.io/usage/linguistic-features#entity-linking | ||||||
|  | 
 | ||||||
|  | Compatible with: spaCy vX.X | ||||||
|  | Last tested with: vX.X | ||||||
|  | """ | ||||||
|  | from __future__ import unicode_literals, print_function | ||||||
|  | 
 | ||||||
|  | import plac | ||||||
|  | import random | ||||||
|  | from pathlib import Path | ||||||
|  | 
 | ||||||
|  | from spacy.symbols import PERSON | ||||||
|  | from spacy.vocab import Vocab | ||||||
|  | 
 | ||||||
|  | import spacy | ||||||
|  | from spacy.kb import KnowledgeBase | ||||||
|  | 
 | ||||||
|  | from spacy import Errors | ||||||
|  | from spacy.tokens import Span | ||||||
|  | from spacy.util import minibatch, compounding | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
def sample_train_data():
    """Build the toy training set for the entity linker.

    Every example is a `(text, {"links": ...})` tuple. The links map the
    character offsets of the 'Russ Cochran' mention to a dictionary that
    assigns 1.0 to the correct KB identifier and 0.0 to the other one.

    RETURNS (list): The four annotated training examples.
    """
    # Q2146908 (Russ Cochran): American golfer
    # Q7381115 (Russ Cochran): publisher
    golfer = "Q2146908"
    publisher = "Q7381115"

    # each sentence is paired with the identifier its mention should resolve to
    labelled_sentences = [
        ("Russ Cochran his reprints include EC Comics.", publisher),
        ("Russ Cochran has been publishing comic art.", publisher),
        (
            "Russ Cochran captured his first major title with his son as caddie.",
            golfer,
        ),
        (
            "Russ Cochran was a member of University of Kentucky's golf team.",
            golfer,
        ),
    ]

    train_data = []
    for sentence, gold_id in labelled_sentences:
        # a fresh dictionary per example, publisher first and golfer second
        scores = {
            publisher: 1.0 if gold_id == publisher else 0.0,
            golfer: 1.0 if gold_id == golfer else 0.0,
        }
        # (0, 12) are the character offsets of "Russ Cochran" in each sentence
        train_data.append((sentence, {"links": {(0, 12): scores}}))
    return train_data
|  | 
 | ||||||
|  | 
 | ||||||
# training data: built once at import time and shared by main() and _apply_model();
# main() shuffles this list in place and replaces its "links" dictionaries when it
# filters out identifiers that are not in the knowledge base
TRAIN_DATA = sample_train_data()
|  | 
 | ||||||
|  | 
 | ||||||
@plac.annotations(
    kb_path=("Path to the knowledge base", "positional", None, Path),
    vocab_path=("Path to the vocab for the kb", "positional", None, Path),
    output_dir=("Optional output directory", "option", "o", Path),
    n_iter=("Number of training iterations", "option", "n", int),
)
def main(kb_path, vocab_path=None, output_dir=None, n_iter=50):
    """Create a blank model with the specified vocab, set up the pipeline and train the entity linker.
    The `vocab` should be the one used during creation of the KB.

    kb_path: Path the knowledge base is loaded from.
    vocab_path: Path the vocab is loaded from. Required, despite its `None` default.
    output_dir: If set, the trained pipeline is saved there and loaded back as a check.
    n_iter: Number of passes over TRAIN_DATA.
    RAISES (ValueError): If no `vocab_path` is provided.
    """
    if vocab_path is None:
        # fail early with an explicit message instead of an opaque error while
        # trying to read the vocab from a missing path
        raise ValueError(
            "A `vocab_path` is required: the entity linker has to be trained with "
            "the same vocab that was used to create the knowledge base."
        )

    vocab = Vocab().from_disk(vocab_path)
    # create blank Language class with correct vocab
    nlp = spacy.blank("en", vocab=vocab)
    nlp.vocab.vectors.name = "spacy_pretrained_vectors"
    print("Created blank 'en' model with vocab from '%s'" % vocab_path)

    # create the built-in pipeline components and add them to the pipeline
    # nlp.create_pipe works for built-ins that are registered with spaCy
    if "entity_linker" not in nlp.pipe_names:
        entity_linker = nlp.create_pipe("entity_linker")
        kb = KnowledgeBase(vocab=nlp.vocab)
        kb.load_bulk(kb_path)
        print("Loaded Knowledge Base from '%s'" % kb_path)
        entity_linker.set_kb(kb)
        nlp.add_pipe(entity_linker, last=True)
    else:
        entity_linker = nlp.get_pipe("entity_linker")
        kb = entity_linker.kb

    # make sure the annotated examples correspond to known identifiers in the knowledge base
    # (a set keeps the membership test cheap for large knowledge bases)
    kb_ids = set(kb.get_entity_strings())
    for text, annotation in TRAIN_DATA:
        for offset, kb_id_dict in annotation["links"].items():
            new_dict = {}
            for kb_id, value in kb_id_dict.items():
                if kb_id in kb_ids:
                    new_dict[kb_id] = value
                else:
                    print(
                        "Removed", kb_id, "from training because it is not in the KB."
                    )
            # only the value of an existing key is replaced, so the dictionary
            # does not change size while it is being iterated
            annotation["links"][offset] = new_dict

    # get names of other pipes to disable them during training
    other_pipes = [pipe for pipe in nlp.pipe_names if pipe != "entity_linker"]
    with nlp.disable_pipes(*other_pipes):  # only train entity linker
        # reset and initialize the weights randomly
        optimizer = nlp.begin_training()
        for itn in range(n_iter):
            random.shuffle(TRAIN_DATA)
            losses = {}
            # batch up the examples using spaCy's minibatch
            batches = minibatch(TRAIN_DATA, size=compounding(4.0, 32.0, 1.001))
            for batch in batches:
                texts, annotations = zip(*batch)
                nlp.update(
                    texts,  # batch of texts
                    annotations,  # batch of annotations
                    drop=0.2,  # dropout - make it harder to memorise data
                    losses=losses,
                    sgd=optimizer,
                )
            print(itn, "Losses", losses)

    # test the trained model
    _apply_model(nlp)

    # save model to output directory
    if output_dir is not None:
        output_dir = Path(output_dir)
        if not output_dir.exists():
            # also create missing parent directories, so nested output paths work
            output_dir.mkdir(parents=True)
        nlp.to_disk(output_dir)
        print()
        print("Saved model to", output_dir)

        # test the saved model
        print("Loading from", output_dir)
        nlp2 = spacy.load(output_dir)
        _apply_model(nlp2)
|  | 
 | ||||||
|  | 
 | ||||||
def _apply_model(nlp):
    """Apply the entity linker of `nlp` to every text in TRAIN_DATA and print the links.

    nlp: Pipeline providing a tokenizer and an "entity_linker" component.
    """
    for sentence, _gold in TRAIN_DATA:
        parsed = nlp.tokenizer(sentence)

        # The mention is marked by hand so that the check does not depend on an
        # NER step: each example starts with the two tokens 'Russ Cochran'.
        mention = Span(parsed, 0, 2, label=PERSON)
        parsed.ents = [mention]

        # have the entity linker predict a KB identifier for the 'Russ Cochran' mention
        linker = nlp.get_pipe("entity_linker")
        parsed = linker(parsed)

        print()
        linked_ents = [(ent.text, ent.label_, ent.kb_id_) for ent in parsed.ents]
        print("Entities", linked_ents)
        linked_tokens = [(tok.text, tok.ent_type_, tok.ent_kb_id_) for tok in parsed]
        print("Tokens", linked_tokens)
|  | 
 | ||||||
|  | 
 | ||||||
if __name__ == "__main__":
    # script entry point: plac invokes `main` with the command-line arguments
    plac.call(main)

    # Expected output (can be shuffled):

    # Entities [('Russ Cochran', 'PERSON', 'Q7381115')]
    # Tokens [('Russ', 'PERSON', 'Q7381115'), ('Cochran', 'PERSON', 'Q7381115'), ('his', '', ''), ('reprints', '', ''), ('include', '', ''), ('EC', '', ''), ('Comics', '', ''), ('.', '', '')]

    # Entities [('Russ Cochran', 'PERSON', 'Q7381115')]
    # Tokens [('Russ', 'PERSON', 'Q7381115'), ('Cochran', 'PERSON', 'Q7381115'), ('has', '', ''), ('been', '', ''), ('publishing', '', ''), ('comic', '', ''), ('art', '', ''), ('.', '', '')]

    # Entities [('Russ Cochran', 'PERSON', 'Q2146908')]
    # Tokens [('Russ', 'PERSON', 'Q2146908'), ('Cochran', 'PERSON', 'Q2146908'), ('captured', '', ''), ('his', '', ''), ('first', '', ''), ('major', '', ''), ('title', '', ''), ('with', '', ''), ('his', '', ''), ('son', '', ''), ('as', '', ''), ('caddie', '', ''), ('.', '', '')]

    # Entities [('Russ Cochran', 'PERSON', 'Q2146908')]
    # Tokens [('Russ', 'PERSON', 'Q2146908'), ('Cochran', 'PERSON', 'Q2146908'), ('was', '', ''), ('a', '', ''), ('member', '', ''), ('of', '', ''), ('University', '', ''), ('of', '', ''), ('Kentucky', '', ''), ("'s", '', ''), ('golf', '', ''), ('team', '', ''), ('.', '', '')]
							
								
								
									
										24
									
								
								spacy/_ml.py
									
									
									
									
									
								
							
							
						
						
									
										24
									
								
								spacy/_ml.py
									
									
									
									
									
								
							|  | @ -665,25 +665,15 @@ def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=False, | ||||||
| def build_nel_encoder(embed_width, hidden_width, ner_types, **cfg): | def build_nel_encoder(embed_width, hidden_width, ner_types, **cfg): | ||||||
|     if "entity_width" not in cfg: |     if "entity_width" not in cfg: | ||||||
|         raise ValueError(Errors.E144.format(param="entity_width")) |         raise ValueError(Errors.E144.format(param="entity_width")) | ||||||
|     if "context_width" not in cfg: |  | ||||||
|         raise ValueError(Errors.E144.format(param="context_width")) |  | ||||||
| 
 | 
 | ||||||
|     conv_depth = cfg.get("conv_depth", 2) |     conv_depth = cfg.get("conv_depth", 2) | ||||||
|     cnn_maxout_pieces = cfg.get("cnn_maxout_pieces", 3) |     cnn_maxout_pieces = cfg.get("cnn_maxout_pieces", 3) | ||||||
|     pretrained_vectors = cfg.get("pretrained_vectors", None) |     pretrained_vectors = cfg.get("pretrained_vectors", None) | ||||||
|     context_width = cfg.get("context_width") |     context_width = cfg.get("entity_width") | ||||||
|     entity_width = cfg.get("entity_width") |  | ||||||
| 
 | 
 | ||||||
|     with Model.define_operators({">>": chain, "**": clone}): |     with Model.define_operators({">>": chain, "**": clone}): | ||||||
|         model = ( |  | ||||||
|             Affine(entity_width, entity_width + context_width + 1 + ner_types) |  | ||||||
|             >> Affine(1, entity_width, drop_factor=0.0) |  | ||||||
|             >> logistic |  | ||||||
|         ) |  | ||||||
| 
 |  | ||||||
|         # context encoder |         # context encoder | ||||||
|         tok2vec = ( |         tok2vec = Tok2Vec( | ||||||
|             Tok2Vec( |  | ||||||
|                 width=hidden_width, |                 width=hidden_width, | ||||||
|                 embed_size=embed_width, |                 embed_size=embed_width, | ||||||
|                 pretrained_vectors=pretrained_vectors, |                 pretrained_vectors=pretrained_vectors, | ||||||
|  | @ -692,17 +682,17 @@ def build_nel_encoder(embed_width, hidden_width, ner_types, **cfg): | ||||||
|                 conv_depth=conv_depth, |                 conv_depth=conv_depth, | ||||||
|                 bilstm_depth=0, |                 bilstm_depth=0, | ||||||
|             ) |             ) | ||||||
|  | 
 | ||||||
|  |         model = ( | ||||||
|  |             tok2vec | ||||||
|             >> flatten_add_lengths |             >> flatten_add_lengths | ||||||
|             >> Pooling(mean_pool) |             >> Pooling(mean_pool) | ||||||
|             >> Residual(zero_init(Maxout(hidden_width, hidden_width))) |             >> Residual(zero_init(Maxout(hidden_width, hidden_width))) | ||||||
|             >> zero_init(Affine(context_width, hidden_width)) |             >> zero_init(Affine(context_width, hidden_width, drop_factor=0.0)) | ||||||
|         ) |         ) | ||||||
| 
 | 
 | ||||||
|         model.tok2vec = tok2vec |         model.tok2vec = tok2vec | ||||||
| 
 |         model.nO = context_width | ||||||
|     model.tok2vec = tok2vec |  | ||||||
|     model.tok2vec.nO = context_width |  | ||||||
|     model.nO = 1 |  | ||||||
|     return model |     return model | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -124,7 +124,8 @@ class Errors(object): | ||||||
|     E016 = ("MultitaskObjective target should be function or one of: dep, " |     E016 = ("MultitaskObjective target should be function or one of: dep, " | ||||||
|             "tag, ent, dep_tag_offset, ent_tag.") |             "tag, ent, dep_tag_offset, ent_tag.") | ||||||
|     E017 = ("Can only add unicode or bytes. Got type: {value_type}") |     E017 = ("Can only add unicode or bytes. Got type: {value_type}") | ||||||
|     E018 = ("Can't retrieve string for hash '{hash_value}'.") |     E018 = ("Can't retrieve string for hash '{hash_value}'. This usually refers " | ||||||
|  |             "to an issue with the `Vocab` or `StringStore`.") | ||||||
|     E019 = ("Can't create transition with unknown action ID: {action}. Action " |     E019 = ("Can't create transition with unknown action ID: {action}. Action " | ||||||
|             "IDs are enumerated in spacy/syntax/{src}.pyx.") |             "IDs are enumerated in spacy/syntax/{src}.pyx.") | ||||||
|     E020 = ("Could not find a gold-standard action to supervise the " |     E020 = ("Could not find a gold-standard action to supervise the " | ||||||
|  | @ -420,7 +421,12 @@ class Errors(object): | ||||||
|     E151 = ("Trying to call nlp.update without required annotation types. " |     E151 = ("Trying to call nlp.update without required annotation types. " | ||||||
|             "Expected top-level keys: {expected_keys}." |             "Expected top-level keys: {expected_keys}." | ||||||
|             " Got: {unexpected_keys}.") |             " Got: {unexpected_keys}.") | ||||||
| 
 |     E152 = ("The `nlp` object should have a pre-trained `ner` component.") | ||||||
|  |     E153 = ("Either provide a path to a preprocessed training directory, " | ||||||
|  |             "or to the original Wikipedia XML dump.") | ||||||
|  |     E154 = ("Either the `nlp` model or the `vocab` should be specified.") | ||||||
|  |     E155 = ("The `nlp` object should have access to pre-trained word vectors, cf. " | ||||||
|  |             "https://spacy.io/usage/models#languages.") | ||||||
| 
 | 
 | ||||||
| @add_codes | @add_codes | ||||||
| class TempErrors(object): | class TempErrors(object): | ||||||
|  |  | ||||||
							
								
								
									
										14
									
								
								spacy/kb.pyx
									
									
									
									
									
								
							
							
						
						
									
										14
									
								
								spacy/kb.pyx
									
									
									
									
									
								
							|  | @ -19,6 +19,13 @@ from libcpp.vector cimport vector | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| cdef class Candidate: | cdef class Candidate: | ||||||
|  |     """A `Candidate` object refers to a textual mention (`alias`) that may or may not be resolved | ||||||
|  |     to a specific `entity` from a Knowledge Base. This will be used as input for the entity linking | ||||||
|  |     algorithm which will disambiguate the various candidates to the correct one. | ||||||
|  |     Each candidate (alias, entity) pair is assigned to a certain prior probability. | ||||||
|  | 
 | ||||||
|  |     DOCS: https://spacy.io/api/candidate | ||||||
|  |     """ | ||||||
| 
 | 
 | ||||||
|     def __init__(self, KnowledgeBase kb, entity_hash, entity_freq, entity_vector, alias_hash, prior_prob): |     def __init__(self, KnowledgeBase kb, entity_hash, entity_freq, entity_vector, alias_hash, prior_prob): | ||||||
|         self.kb = kb |         self.kb = kb | ||||||
|  | @ -62,8 +69,13 @@ cdef class Candidate: | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| cdef class KnowledgeBase: | cdef class KnowledgeBase: | ||||||
|  |     """A `KnowledgeBase` instance stores unique identifiers for entities and their textual aliases, | ||||||
|  |     to support entity linking of named entities to real-world concepts. | ||||||
| 
 | 
 | ||||||
|     def __init__(self, Vocab vocab, entity_vector_length): |     DOCS: https://spacy.io/api/kb | ||||||
|  |     """ | ||||||
|  | 
 | ||||||
|  |     def __init__(self, Vocab vocab, entity_vector_length=64): | ||||||
|         self.vocab = vocab |         self.vocab = vocab | ||||||
|         self.mem = Pool() |         self.mem = Pool() | ||||||
|         self.entity_vector_length = entity_vector_length |         self.entity_vector_length = entity_vector_length | ||||||
|  |  | ||||||
|  | @ -14,6 +14,8 @@ from thinc.neural.util import to_categorical | ||||||
| from thinc.neural.util import get_array_module | from thinc.neural.util import get_array_module | ||||||
| 
 | 
 | ||||||
| from spacy.kb import KnowledgeBase | from spacy.kb import KnowledgeBase | ||||||
|  | 
 | ||||||
|  | from spacy.cli.pretrain import get_cossim_loss | ||||||
| from .functions import merge_subtokens | from .functions import merge_subtokens | ||||||
| from ..tokens.doc cimport Doc | from ..tokens.doc cimport Doc | ||||||
| from ..syntax.nn_parser cimport Parser | from ..syntax.nn_parser cimport Parser | ||||||
|  | @ -1102,7 +1104,7 @@ cdef class EntityRecognizer(Parser): | ||||||
| class EntityLinker(Pipe): | class EntityLinker(Pipe): | ||||||
|     """Pipeline component for named entity linking. |     """Pipeline component for named entity linking. | ||||||
| 
 | 
 | ||||||
|     DOCS: TODO |     DOCS: https://spacy.io/api/entitylinker | ||||||
|     """ |     """ | ||||||
|     name = 'entity_linker' |     name = 'entity_linker' | ||||||
|     NIL = "NIL"  # string used to refer to a non-existing link |     NIL = "NIL"  # string used to refer to a non-existing link | ||||||
|  | @ -1121,9 +1123,6 @@ class EntityLinker(Pipe): | ||||||
|         self.model = True |         self.model = True | ||||||
|         self.kb = None |         self.kb = None | ||||||
|         self.cfg = dict(cfg) |         self.cfg = dict(cfg) | ||||||
|         self.sgd_context = None |  | ||||||
|         if not self.cfg.get("context_width"): |  | ||||||
|             self.cfg["context_width"] = 128 |  | ||||||
| 
 | 
 | ||||||
|     def set_kb(self, kb): |     def set_kb(self, kb): | ||||||
|         self.kb = kb |         self.kb = kb | ||||||
|  | @ -1144,7 +1143,6 @@ class EntityLinker(Pipe): | ||||||
| 
 | 
 | ||||||
|         if self.model is True: |         if self.model is True: | ||||||
|             self.model = self.Model(**self.cfg) |             self.model = self.Model(**self.cfg) | ||||||
|             self.sgd_context = self.create_optimizer() |  | ||||||
| 
 | 
 | ||||||
|         if sgd is None: |         if sgd is None: | ||||||
|             sgd = self.create_optimizer() |             sgd = self.create_optimizer() | ||||||
|  | @ -1170,12 +1168,6 @@ class EntityLinker(Pipe): | ||||||
|             golds = [golds] |             golds = [golds] | ||||||
| 
 | 
 | ||||||
|         context_docs = [] |         context_docs = [] | ||||||
|         entity_encodings = [] |  | ||||||
| 
 |  | ||||||
|         priors = [] |  | ||||||
|         type_vectors = [] |  | ||||||
| 
 |  | ||||||
|         type_to_int = self.cfg.get("type_to_int", dict()) |  | ||||||
| 
 | 
 | ||||||
|         for doc, gold in zip(docs, golds): |         for doc, gold in zip(docs, golds): | ||||||
|             ents_by_offset = dict() |             ents_by_offset = dict() | ||||||
|  | @ -1184,49 +1176,38 @@ class EntityLinker(Pipe): | ||||||
|             for entity, kb_dict in gold.links.items(): |             for entity, kb_dict in gold.links.items(): | ||||||
|                 start, end = entity |                 start, end = entity | ||||||
|                 mention = doc.text[start:end] |                 mention = doc.text[start:end] | ||||||
|  | 
 | ||||||
|                 for kb_id, value in kb_dict.items(): |                 for kb_id, value in kb_dict.items(): | ||||||
|                     entity_encoding = self.kb.get_vector(kb_id) |                     # Currently only training on the positive instances | ||||||
|                     prior_prob = self.kb.get_prior_prob(kb_id, mention) |                     if value: | ||||||
| 
 |  | ||||||
|                     gold_ent = ents_by_offset["{}_{}".format(start, end)] |  | ||||||
|                     if gold_ent is None: |  | ||||||
|                         raise RuntimeError(Errors.E147.format(method="update", msg="gold entity not found")) |  | ||||||
| 
 |  | ||||||
|                     type_vector = [0 for i in range(len(type_to_int))] |  | ||||||
|                     if len(type_to_int) > 0: |  | ||||||
|                         type_vector[type_to_int[gold_ent.label_]] = 1 |  | ||||||
| 
 |  | ||||||
|                     # store data |  | ||||||
|                     entity_encodings.append(entity_encoding) |  | ||||||
|                         context_docs.append(doc) |                         context_docs.append(doc) | ||||||
|                     type_vectors.append(type_vector) |  | ||||||
| 
 | 
 | ||||||
|                     if self.cfg.get("prior_weight", 1) > 0: |         context_encodings, bp_context = self.model.begin_update(context_docs, drop=drop) | ||||||
|                         priors.append([prior_prob]) |         loss, d_scores = self.get_similarity_loss(scores=context_encodings, golds=golds, docs=None) | ||||||
|                     else: |         bp_context(d_scores, sgd=sgd) | ||||||
|                         priors.append([0]) |  | ||||||
| 
 |  | ||||||
|         if len(entity_encodings) > 0: |  | ||||||
|             if not (len(priors) == len(entity_encodings) == len(context_docs) == len(type_vectors)): |  | ||||||
|                 raise RuntimeError(Errors.E147.format(method="update", msg="vector lengths not equal")) |  | ||||||
| 
 |  | ||||||
|             entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32") |  | ||||||
| 
 |  | ||||||
|             context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop) |  | ||||||
|             mention_encodings = [list(context_encodings[i]) + list(entity_encodings[i]) + priors[i] + type_vectors[i] |  | ||||||
|                                  for i in range(len(entity_encodings))] |  | ||||||
|             pred, bp_mention = self.model.begin_update(self.model.ops.asarray(mention_encodings, dtype="float32"), drop=drop) |  | ||||||
| 
 |  | ||||||
|             loss, d_scores = self.get_loss(scores=pred, golds=golds, docs=docs) |  | ||||||
|             mention_gradient = bp_mention(d_scores, sgd=sgd) |  | ||||||
| 
 |  | ||||||
|             context_gradients = [list(x[0:self.cfg.get("context_width")]) for x in mention_gradient] |  | ||||||
|             bp_context(self.model.ops.asarray(context_gradients, dtype="float32"), sgd=self.sgd_context) |  | ||||||
| 
 | 
 | ||||||
|         if losses is not None: |         if losses is not None: | ||||||
|             losses[self.name] += loss |             losses[self.name] += loss | ||||||
|         return loss |         return loss | ||||||
|         return 0 | 
 | ||||||
|  |     def get_similarity_loss(self, docs, golds, scores): | ||||||
|  |         entity_encodings = [] | ||||||
|  |         for gold in golds: | ||||||
|  |             for entity, kb_dict in gold.links.items(): | ||||||
|  |                 for kb_id, value in kb_dict.items(): | ||||||
|  |                     # this loss function assumes we're only using positive examples | ||||||
|  |                     if value: | ||||||
|  |                         entity_encoding = self.kb.get_vector(kb_id) | ||||||
|  |                         entity_encodings.append(entity_encoding) | ||||||
|  | 
 | ||||||
|  |         entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32") | ||||||
|  | 
 | ||||||
|  |         if scores.shape != entity_encodings.shape: | ||||||
|  |             raise RuntimeError(Errors.E147.format(method="get_loss", msg="gold entities do not match up")) | ||||||
|  | 
 | ||||||
|  |         loss, gradients = get_cossim_loss(yh=scores, y=entity_encodings) | ||||||
|  |         loss = loss / len(entity_encodings) | ||||||
|  |         return loss, gradients | ||||||
| 
 | 
 | ||||||
|     def get_loss(self, docs, golds, scores): |     def get_loss(self, docs, golds, scores): | ||||||
|         cats = [] |         cats = [] | ||||||
|  | @ -1271,20 +1252,17 @@ class EntityLinker(Pipe): | ||||||
|         if isinstance(docs, Doc): |         if isinstance(docs, Doc): | ||||||
|             docs = [docs] |             docs = [docs] | ||||||
| 
 | 
 | ||||||
|         context_encodings = self.model.tok2vec(docs) |         context_encodings = self.model(docs) | ||||||
|         xp = get_array_module(context_encodings) |         xp = get_array_module(context_encodings) | ||||||
| 
 | 
 | ||||||
|         type_to_int = self.cfg.get("type_to_int", dict()) |  | ||||||
| 
 |  | ||||||
|         for i, doc in enumerate(docs): |         for i, doc in enumerate(docs): | ||||||
|             if len(doc) > 0: |             if len(doc) > 0: | ||||||
|                 # currently, the context is the same for each entity in a sentence (should be refined) |                 # currently, the context is the same for each entity in a sentence (should be refined) | ||||||
|                 context_encoding = context_encodings[i] |                 context_encoding = context_encodings[i] | ||||||
|  |                 context_enc_t = context_encoding.T | ||||||
|  |                 norm_1 = xp.linalg.norm(context_enc_t) | ||||||
|                 for ent in doc.ents: |                 for ent in doc.ents: | ||||||
|                     entity_count += 1 |                     entity_count += 1 | ||||||
|                     type_vector = [0 for i in range(len(type_to_int))] |  | ||||||
|                     if len(type_to_int) > 0: |  | ||||||
|                         type_vector[type_to_int[ent.label_]] = 1 |  | ||||||
| 
 | 
 | ||||||
|                     candidates = self.kb.get_candidates(ent.text) |                     candidates = self.kb.get_candidates(ent.text) | ||||||
|                     if not candidates: |                     if not candidates: | ||||||
|  | @ -1293,20 +1271,23 @@ class EntityLinker(Pipe): | ||||||
|                     else: |                     else: | ||||||
|                         random.shuffle(candidates) |                         random.shuffle(candidates) | ||||||
| 
 | 
 | ||||||
|                         # this will set the prior probabilities to 0 (just like in training) if their weight is 0 |                         # this will set all prior probabilities to 0 if they should be excluded from the model | ||||||
|                         prior_probs = xp.asarray([[c.prior_prob] for c in candidates]) |                         prior_probs = xp.asarray([c.prior_prob for c in candidates]) | ||||||
|                         prior_probs *= self.cfg.get("prior_weight", 1) |                         if not self.cfg.get("incl_prior", True): | ||||||
|  |                             prior_probs = xp.asarray([[0.0] for c in candidates]) | ||||||
|                         scores = prior_probs |                         scores = prior_probs | ||||||
| 
 | 
 | ||||||
|                         if self.cfg.get("context_weight", 1) > 0: |                         # add in similarity from the context | ||||||
|  |                         if self.cfg.get("incl_context", True): | ||||||
|                             entity_encodings = xp.asarray([c.entity_vector for c in candidates]) |                             entity_encodings = xp.asarray([c.entity_vector for c in candidates]) | ||||||
|  |                             norm_2 = xp.linalg.norm(entity_encodings, axis=1) | ||||||
|  | 
 | ||||||
|                             if len(entity_encodings) != len(prior_probs): |                             if len(entity_encodings) != len(prior_probs): | ||||||
|                                 raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length")) |                                 raise RuntimeError(Errors.E147.format(method="predict", msg="vectors not of equal length")) | ||||||
| 
 | 
 | ||||||
|                             mention_encodings = [list(context_encoding) + list(entity_encodings[i]) |                              # cosine similarity | ||||||
|                                                  + list(prior_probs[i]) + type_vector |                             sims = xp.dot(entity_encodings, context_enc_t) / (norm_1 * norm_2) | ||||||
|                                                  for i in range(len(entity_encodings))] |                             scores = prior_probs + sims - (prior_probs*sims) | ||||||
|                             scores = self.model(self.model.ops.asarray(mention_encodings, dtype="float32")) |  | ||||||
| 
 | 
 | ||||||
|                         # TODO: thresholding |                         # TODO: thresholding | ||||||
|                         best_index = scores.argmax() |                         best_index = scores.argmax() | ||||||
|  |  | ||||||
|  | @ -23,9 +23,9 @@ def test_kb_valid_entities(nlp): | ||||||
|     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3) |     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3) | ||||||
| 
 | 
 | ||||||
|     # adding entities |     # adding entities | ||||||
|     mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[8, 4, 3]) |     mykb.add_entity(entity="Q1", freq=19, entity_vector=[8, 4, 3]) | ||||||
|     mykb.add_entity(entity="Q2", freq=0.5, entity_vector=[2, 1, 0]) |     mykb.add_entity(entity="Q2", freq=5, entity_vector=[2, 1, 0]) | ||||||
|     mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[-1, -6, 5]) |     mykb.add_entity(entity="Q3", freq=25, entity_vector=[-1, -6, 5]) | ||||||
| 
 | 
 | ||||||
|     # adding aliases |     # adding aliases | ||||||
|     mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.2]) |     mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.2]) | ||||||
|  | @ -52,9 +52,9 @@ def test_kb_invalid_entities(nlp): | ||||||
|     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) |     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) | ||||||
| 
 | 
 | ||||||
|     # adding entities |     # adding entities | ||||||
|     mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1]) |     mykb.add_entity(entity="Q1", freq=19, entity_vector=[1]) | ||||||
|     mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2]) |     mykb.add_entity(entity="Q2", freq=5, entity_vector=[2]) | ||||||
|     mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3]) |     mykb.add_entity(entity="Q3", freq=25, entity_vector=[3]) | ||||||
| 
 | 
 | ||||||
|     # adding aliases - should fail because one of the given IDs is not valid |     # adding aliases - should fail because one of the given IDs is not valid | ||||||
|     with pytest.raises(ValueError): |     with pytest.raises(ValueError): | ||||||
|  | @ -68,9 +68,9 @@ def test_kb_invalid_probabilities(nlp): | ||||||
|     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) |     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) | ||||||
| 
 | 
 | ||||||
|     # adding entities |     # adding entities | ||||||
|     mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1]) |     mykb.add_entity(entity="Q1", freq=19, entity_vector=[1]) | ||||||
|     mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2]) |     mykb.add_entity(entity="Q2", freq=5, entity_vector=[2]) | ||||||
|     mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3]) |     mykb.add_entity(entity="Q3", freq=25, entity_vector=[3]) | ||||||
| 
 | 
 | ||||||
|     # adding aliases - should fail because the sum of the probabilities exceeds 1 |     # adding aliases - should fail because the sum of the probabilities exceeds 1 | ||||||
|     with pytest.raises(ValueError): |     with pytest.raises(ValueError): | ||||||
|  | @ -82,9 +82,9 @@ def test_kb_invalid_combination(nlp): | ||||||
|     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) |     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) | ||||||
| 
 | 
 | ||||||
|     # adding entities |     # adding entities | ||||||
|     mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1]) |     mykb.add_entity(entity="Q1", freq=19, entity_vector=[1]) | ||||||
|     mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2]) |     mykb.add_entity(entity="Q2", freq=5, entity_vector=[2]) | ||||||
|     mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3]) |     mykb.add_entity(entity="Q3", freq=25, entity_vector=[3]) | ||||||
| 
 | 
 | ||||||
|     # adding aliases - should fail because the entities and probabilities vectors are not of equal length |     # adding aliases - should fail because the entities and probabilities vectors are not of equal length | ||||||
|     with pytest.raises(ValueError): |     with pytest.raises(ValueError): | ||||||
|  | @ -98,11 +98,11 @@ def test_kb_invalid_entity_vector(nlp): | ||||||
|     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3) |     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3) | ||||||
| 
 | 
 | ||||||
|     # adding entities |     # adding entities | ||||||
|     mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1, 2, 3]) |     mykb.add_entity(entity="Q1", freq=19, entity_vector=[1, 2, 3]) | ||||||
| 
 | 
 | ||||||
|     # this should fail because the kb's expected entity vector length is 3 |     # this should fail because the kb's expected entity vector length is 3 | ||||||
|     with pytest.raises(ValueError): |     with pytest.raises(ValueError): | ||||||
|         mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2]) |         mykb.add_entity(entity="Q2", freq=5, entity_vector=[2]) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def test_candidate_generation(nlp): | def test_candidate_generation(nlp): | ||||||
|  | @ -110,9 +110,9 @@ def test_candidate_generation(nlp): | ||||||
|     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) |     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) | ||||||
| 
 | 
 | ||||||
|     # adding entities |     # adding entities | ||||||
|     mykb.add_entity(entity="Q1", freq=0.7, entity_vector=[1]) |     mykb.add_entity(entity="Q1", freq=27, entity_vector=[1]) | ||||||
|     mykb.add_entity(entity="Q2", freq=0.2, entity_vector=[2]) |     mykb.add_entity(entity="Q2", freq=12, entity_vector=[2]) | ||||||
|     mykb.add_entity(entity="Q3", freq=0.5, entity_vector=[3]) |     mykb.add_entity(entity="Q3", freq=5, entity_vector=[3]) | ||||||
| 
 | 
 | ||||||
|     # adding aliases |     # adding aliases | ||||||
|     mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1]) |     mykb.add_alias(alias="douglas", entities=["Q2", "Q3"], probabilities=[0.8, 0.1]) | ||||||
|  | @ -126,7 +126,7 @@ def test_candidate_generation(nlp): | ||||||
|     # test the content of the candidates |     # test the content of the candidates | ||||||
|     assert mykb.get_candidates("adam")[0].entity_ == "Q2" |     assert mykb.get_candidates("adam")[0].entity_ == "Q2" | ||||||
|     assert mykb.get_candidates("adam")[0].alias_ == "adam" |     assert mykb.get_candidates("adam")[0].alias_ == "adam" | ||||||
|     assert_almost_equal(mykb.get_candidates("adam")[0].entity_freq, 0.2) |     assert_almost_equal(mykb.get_candidates("adam")[0].entity_freq, 12) | ||||||
|     assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9) |     assert_almost_equal(mykb.get_candidates("adam")[0].prior_prob, 0.9) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | @ -135,8 +135,8 @@ def test_preserving_links_asdoc(nlp): | ||||||
|     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) |     mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) | ||||||
| 
 | 
 | ||||||
|     # adding entities |     # adding entities | ||||||
|     mykb.add_entity(entity="Q1", freq=0.9, entity_vector=[1]) |     mykb.add_entity(entity="Q1", freq=19, entity_vector=[1]) | ||||||
|     mykb.add_entity(entity="Q2", freq=0.8, entity_vector=[1]) |     mykb.add_entity(entity="Q2", freq=8, entity_vector=[1]) | ||||||
| 
 | 
 | ||||||
|     # adding aliases |     # adding aliases | ||||||
|     mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7]) |     mykb.add_alias(alias="Boston", entities=["Q1"], probabilities=[0.7]) | ||||||
|  | @ -154,11 +154,11 @@ def test_preserving_links_asdoc(nlp): | ||||||
|     ruler.add_patterns(patterns) |     ruler.add_patterns(patterns) | ||||||
|     nlp.add_pipe(ruler) |     nlp.add_pipe(ruler) | ||||||
| 
 | 
 | ||||||
|     el_pipe = nlp.create_pipe(name="entity_linker", config={"context_width": 64}) |     el_pipe = nlp.create_pipe(name="entity_linker") | ||||||
|     el_pipe.set_kb(mykb) |     el_pipe.set_kb(mykb) | ||||||
|     el_pipe.begin_training() |     el_pipe.begin_training() | ||||||
|     el_pipe.context_weight = 0 |     el_pipe.incl_context = False | ||||||
|     el_pipe.prior_weight = 1 |     el_pipe.incl_prior = True | ||||||
|     nlp.add_pipe(el_pipe, last=True) |     nlp.add_pipe(el_pipe, last=True) | ||||||
| 
 | 
 | ||||||
|     # test whether the entity links are preserved by the `as_doc()` function |     # test whether the entity links are preserved by the `as_doc()` function | ||||||
|  |  | ||||||
|  | @ -30,10 +30,10 @@ def test_serialize_kb_disk(en_vocab): | ||||||
| def _get_dummy_kb(vocab): | def _get_dummy_kb(vocab): | ||||||
|     kb = KnowledgeBase(vocab=vocab, entity_vector_length=3) |     kb = KnowledgeBase(vocab=vocab, entity_vector_length=3) | ||||||
| 
 | 
 | ||||||
|     kb.add_entity(entity='Q53', freq=0.33, entity_vector=[0, 5, 3]) |     kb.add_entity(entity='Q53', freq=33, entity_vector=[0, 5, 3]) | ||||||
|     kb.add_entity(entity='Q17', freq=0.2, entity_vector=[7, 1, 0]) |     kb.add_entity(entity='Q17', freq=2, entity_vector=[7, 1, 0]) | ||||||
|     kb.add_entity(entity='Q007', freq=0.7, entity_vector=[0, 0, 7]) |     kb.add_entity(entity='Q007', freq=7, entity_vector=[0, 0, 7]) | ||||||
|     kb.add_entity(entity='Q44', freq=0.4, entity_vector=[4, 4, 4]) |     kb.add_entity(entity='Q44', freq=342, entity_vector=[4, 4, 4]) | ||||||
| 
 | 
 | ||||||
|     kb.add_alias(alias='double07', entities=['Q17', 'Q007'], probabilities=[0.1, 0.9]) |     kb.add_alias(alias='double07', entities=['Q17', 'Q007'], probabilities=[0.1, 0.9]) | ||||||
|     kb.add_alias(alias='guy', entities=['Q53', 'Q007', 'Q17', 'Q44'], probabilities=[0.3, 0.3, 0.2, 0.1]) |     kb.add_alias(alias='guy', entities=['Q53', 'Q007', 'Q17', 'Q44'], probabilities=[0.3, 0.3, 0.2, 0.1]) | ||||||
|  | @ -62,13 +62,13 @@ def _check_kb(kb): | ||||||
|     assert len(candidates) == 2 |     assert len(candidates) == 2 | ||||||
| 
 | 
 | ||||||
|     assert candidates[0].entity_ == 'Q007' |     assert candidates[0].entity_ == 'Q007' | ||||||
|     assert 0.6999 < candidates[0].entity_freq < 0.701 |     assert 6.999 < candidates[0].entity_freq < 7.01 | ||||||
|     assert candidates[0].entity_vector == [0, 0, 7] |     assert candidates[0].entity_vector == [0, 0, 7] | ||||||
|     assert candidates[0].alias_ == 'double07' |     assert candidates[0].alias_ == 'double07' | ||||||
|     assert 0.899 < candidates[0].prior_prob < 0.901 |     assert 0.899 < candidates[0].prior_prob < 0.901 | ||||||
| 
 | 
 | ||||||
|     assert candidates[1].entity_ == 'Q17' |     assert candidates[1].entity_ == 'Q17' | ||||||
|     assert 0.199 < candidates[1].entity_freq < 0.201 |     assert 1.99 < candidates[1].entity_freq < 2.01 | ||||||
|     assert candidates[1].entity_vector == [7, 1, 0] |     assert candidates[1].entity_vector == [7, 1, 0] | ||||||
|     assert candidates[1].alias_ == 'double07' |     assert candidates[1].alias_ == 'double07' | ||||||
|     assert 0.099 < candidates[1].prior_prob < 0.101 |     assert 0.099 < candidates[1].prior_prob < 0.101 | ||||||
|  |  | ||||||
|  | @ -546,6 +546,7 @@ cdef class Doc: | ||||||
|             cdef int i |             cdef int i | ||||||
|             for i in range(self.length): |             for i in range(self.length): | ||||||
|                 self.c[i].ent_type = 0 |                 self.c[i].ent_type = 0 | ||||||
|  |                 self.c[i].ent_kb_id = 0 | ||||||
|                 self.c[i].ent_iob = 0  # Means missing. |                 self.c[i].ent_iob = 0  # Means missing. | ||||||
|             cdef attr_t ent_type |             cdef attr_t ent_type | ||||||
|             cdef int start, end |             cdef int start, end | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user