Mirror of https://github.com/explosion/spaCy.git

further code cleanup

parent 478305cd3f
commit a31648d28b

@@ -1,31 +1,31 @@
 # coding: utf-8
 from __future__ import unicode_literals

-from bin.wiki_entity_linking.train_descriptions import EntityEncoder
+from .train_descriptions import EntityEncoder
+from . import wikidata_processor as wd, wikipedia_processor as wp
 from spacy.kb import KnowledgeBase

 import csv
 import datetime

-from bin.wiki_entity_linking import wikidata_processor as wd, wikipedia_processor as wp

-INPUT_DIM = 300  # dimension of pre-trained vectors
-DESC_WIDTH = 64
+INPUT_DIM = 300  # dimension of pre-trained input vectors
+DESC_WIDTH = 64  # dimension of output entity vectors


 def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
               entity_def_output, entity_descr_output,
-              count_input, prior_prob_input, to_print=False):
+              count_input, prior_prob_input, wikidata_input):
     # Create the knowledge base from Wikidata entries
     kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH)

     # disable this part of the pipeline when rerunning the KB generation from preprocessed files
-    read_raw_data = False
+    read_raw_data = True

     if read_raw_data:
         print()
         print(" * _read_wikidata_entities", datetime.datetime.now())
-        title_to_id, id_to_descr = wd.read_wikidata_entities_json(limit=None)
+        title_to_id, id_to_descr = wd.read_wikidata_entities_json(wikidata_input)

         # write the title-ID and ID-description mappings to file
         _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr)
@@ -40,7 +40,7 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
     print()
     entity_frequencies = wp.get_all_frequencies(count_input=count_input)

-    # filter the entities for in the KB by frequency, because there's just too much data otherwise
+    # filter the entities for in the KB by frequency, because there's just too much data (8M entities) otherwise
     filtered_title_to_id = dict()
     entity_list = list()
     description_list = list()
@@ -60,11 +60,10 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
     print()
     print(" * train entity encoder", datetime.datetime.now())
     print()
-
     encoder = EntityEncoder(nlp, INPUT_DIM, DESC_WIDTH)
     encoder.train(description_list=description_list, to_print=True)
-    print()

+    print()
     print(" * get entity embeddings", datetime.datetime.now())
     print()
     embeddings = encoder.apply_encoder(description_list)
@@ -80,12 +79,10 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
                  max_entities_per_alias=max_entities_per_alias, min_occ=min_occ,
                  prior_prob_input=prior_prob_input)

-    if to_print:
-        print()
-        print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases())
+    print()
+    print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases())

     print("done with kb", datetime.datetime.now())
-
     return kb


@@ -94,6 +91,7 @@ def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_
         id_file.write("WP_title" + "|" + "WD_id" + "\n")
         for title, qid in title_to_id.items():
             id_file.write(title + "|" + str(qid) + "\n")
+
     with open(entity_descr_output, mode='w', encoding='utf8') as descr_file:
         descr_file.write("WD_id" + "|" + "description" + "\n")
         for qid, descr in id_to_descr.items():
@@ -108,7 +106,6 @@ def get_entity_to_id(entity_def_output):
         next(csvreader)
         for row in csvreader:
             entity_to_id[row[0]] = row[1]
-
     return entity_to_id


@@ -120,16 +117,12 @@ def _get_id_to_description(entity_descr_output):
         next(csvreader)
         for row in csvreader:
             id_to_desc[row[0]] = row[1]
-
     return id_to_desc


-def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input, to_print=False):
+def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input):
     wp_titles = title_to_id.keys()

-    if to_print:
-        print("wp titles:", wp_titles)
-
     # adding aliases with prior probabilities
     # we can read this file sequentially, it's sorted by alias, and then by count
     with open(prior_prob_input, mode='r', encoding='utf8') as prior_file:
@@ -176,6 +169,3 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in

             line = prior_file.readline()

-    if to_print:
-        print("added", kb.get_size_aliases(), "aliases:", kb.get_alias_strings())
-

@@ -32,8 +32,6 @@ class EntityEncoder:
         if self.encoder is None:
             raise ValueError("Can not apply encoder before training it")

-        print("Encoding", len(description_list), "entities")
-
         batch_size = 100000

         start = 0
@@ -48,13 +46,11 @@ class EntityEncoder:

             start = start + batch_size
             stop = min(stop + batch_size, len(description_list))
-            print("encoded :", len(encodings))

         return encodings

     def train(self, description_list, to_print=False):
         processed, loss = self._train_model(description_list)
-
         if to_print:
             print("Trained on", processed, "entities across", self.EPOCHS, "epochs")
             print("Final loss:", loss)
@@ -111,15 +107,12 @@ class EntityEncoder:
                 Affine(hidden_with, orig_width)
             )
             self.model = self.encoder >> zero_init(Affine(orig_width, hidden_with, drop_factor=0.0))
-
         self.sgd = create_default_optimizer(self.model.ops)

     def _update(self, vectors):
         predictions, bp_model = self.model.begin_update(np.asarray(vectors), drop=self.DROP)
-
         loss, d_scores = self._get_loss(scores=predictions, golds=np.asarray(vectors))
         bp_model(d_scores, sgd=self.sgd)
-
         return loss / len(vectors)

     @staticmethod

@@ -18,23 +18,21 @@ Gold-standard entities are stored in one file in standoff format (by character o
 ENTITY_FILE = "gold_entities_1000000.csv"   # use this file for faster processing


-def create_training(entity_def_input, training_output):
+def create_training(wikipedia_input, entity_def_input, training_output):
     wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
-    _process_wikipedia_texts(wp_to_id, training_output, limit=None)
+    _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None)


-def _process_wikipedia_texts(wp_to_id, training_output, limit=None):
+def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None):
     """
     Read the XML wikipedia data to parse out training data:
     raw text data + positive instances
     """
-
     title_regex = re.compile(r'(?<=<title>).*(?=</title>)')
     id_regex = re.compile(r'(?<=<id>)\d*(?=</id>)')

     read_ids = set()
-
-    entityfile_loc = training_output + "/" + ENTITY_FILE
+    entityfile_loc = training_output / ENTITY_FILE
     with open(entityfile_loc, mode="w", encoding='utf8') as entityfile:
         # write entity training header file
         _write_training_entity(outputfile=entityfile,
@@ -44,7 +42,7 @@ def _process_wikipedia_texts(wp_to_id, training_output, limit=None):
                                start="start",
                                end="end")

-        with bz2.open(wp.ENWIKI_DUMP, mode='rb') as file:
+        with bz2.open(wikipedia_input, mode='rb') as file:
             line = file.readline()
             cnt = 0
             article_text = ""
@@ -134,7 +132,7 @@ def _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_te
     # get the raw text without markup etc, keeping only interwiki links
     clean_text = _get_clean_wp_text(text)

-    # read the text char by char to get the right offsets of the interwiki links
+    # read the text char by char to get the right offsets for the interwiki links
     final_text = ""
     open_read = 0
     reading_text = True
@@ -274,7 +272,7 @@ def _get_clean_wp_text(article_text):


 def _write_training_article(article_id, clean_text, training_output):
-    file_loc = training_output + "/" + str(article_id) + ".txt"
+    file_loc = training_output / str(article_id) + ".txt"
     with open(file_loc, mode='w', encoding='utf8') as outputfile:
         outputfile.write(clean_text)

@@ -289,11 +287,10 @@ def is_dev(article_id):

 def read_training(nlp, training_dir, dev, limit):
     # This method provides training examples that correspond to the entity annotations found by the nlp object
-
-    entityfile_loc = training_dir + "/" + ENTITY_FILE
+    entityfile_loc = training_dir / ENTITY_FILE
     data = []

-    # we assume the data is written sequentially
+    # assume the data is written sequentially, so we can reuse the article docs
     current_article_id = None
     current_doc = None
     ents_by_offset = dict()
@@ -347,10 +344,10 @@ def read_training(nlp, training_dir, dev, limit):
                                 gold_end = int(end) - found_ent.sent.start_char
                                 gold_entities = list()
                                 gold_entities.append((gold_start, gold_end, wp_title))
-                                gold = GoldParse(doc=current_doc, links=gold_entities)
+                                gold = GoldParse(doc=sent, links=gold_entities)
                                 data.append((sent, gold))
                                 total_entities += 1
-                                if len(data) % 500 == 0:
+                                if len(data) % 2500 == 0:
                                     print(" -read", total_entities, "entities")

     print(" -read", total_entities, "entities")

@@ -5,17 +5,15 @@ import bz2
 import json
 import datetime

-# TODO: remove hardcoded paths
-WIKIDATA_JSON = 'C:/Users/Sofie/Documents/data/wikidata/wikidata-20190304-all.json.bz2'

-
-def read_wikidata_entities_json(limit=None, to_print=False):
+def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False):
     # Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines.
+    # get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/

     lang = 'en'
     site_filter = 'enwiki'

-    # filter currently disabled to get ALL data
+    # properties filter (currently disabled to get ALL data)
     prop_filter = dict()
     # prop_filter = {'P31': {'Q5', 'Q15632617'}}     # currently defined as OR: one property suffices to be selected

@@ -30,7 +28,7 @@ def read_wikidata_entities_json(limit=None, to_print=False):
     parse_aliases = False
     parse_claims = False

-    with bz2.open(WIKIDATA_JSON, mode='rb') as file:
+    with bz2.open(wikidata_file, mode='rb') as file:
         line = file.readline()
         cnt = 0
         while line and (not limit or cnt < limit):

@@ -11,11 +11,6 @@ Process a Wikipedia dump to calculate entity frequencies and prior probabilities
 Write these results to file for downstream KB and training data generation.
 """

-
-# TODO: remove hardcoded paths
-ENWIKI_DUMP = 'C:/Users/Sofie/Documents/data/wikipedia/enwiki-20190320-pages-articles-multistream.xml.bz2'
-ENWIKI_INDEX = 'C:/Users/Sofie/Documents/data/wikipedia/enwiki-20190320-pages-articles-multistream-index.txt.bz2'
-
 map_alias_to_link = dict()

 # these will/should be matched ignoring case
@@ -46,15 +41,13 @@ for ns in wiki_namespaces:
 ns_regex = re.compile(ns_regex, re.IGNORECASE)


-def read_wikipedia_prior_probs(prior_prob_output):
+def read_wikipedia_prior_probs(wikipedia_input, prior_prob_output):
     """
-    Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities
-    The full file takes about 2h to parse 1100M lines (update printed every 5M lines).
-    It works relatively fast because we don't care about which article we parsed the interwiki from,
-    we just process line by line.
+    Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
+    The full file takes about 2h to parse 1100M lines.
+    It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from.
     """
-
-    with bz2.open(ENWIKI_DUMP, mode='rb') as file:
+    with bz2.open(wikipedia_input, mode='rb') as file:
         line = file.readline()
         cnt = 0
         while line:
@@ -70,7 +63,7 @@ def read_wikipedia_prior_probs(prior_prob_output):
             line = file.readline()
             cnt += 1

-    # write all aliases and their entities and occurrences to file
+    # write all aliases and their entities and count occurrences to file
     with open(prior_prob_output, mode='w', encoding='utf8') as outputfile:
         outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n")
         for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]):
@@ -108,7 +101,7 @@ def get_wp_links(text):
         if ns_regex.match(match):
             pass  # ignore namespaces at the beginning of the string

-        # this is a simple link, with the alias the same as the mention
+        # this is a simple [[link]], with the alias the same as the mention
         elif "|" not in match:
             aliases.append(match)
             entities.append(match)

@@ -2,35 +2,45 @@
 from __future__ import unicode_literals

 import random
-
-from spacy.util import minibatch, compounding
+import datetime
+from pathlib import Path

 from bin.wiki_entity_linking import training_set_creator, kb_creator, wikipedia_processor as wp
 from bin.wiki_entity_linking.kb_creator import DESC_WIDTH

 import spacy
 from spacy.kb import KnowledgeBase
-import datetime
+from spacy.util import minibatch, compounding

 """
 Demonstrate how to build a knowledge base from WikiData and run an Entity Linking algorithm.
 """

-PRIOR_PROB = 'C:/Users/Sofie/Documents/data/wikipedia/prior_prob.csv'
-ENTITY_COUNTS = 'C:/Users/Sofie/Documents/data/wikipedia/entity_freq.csv'
-ENTITY_DEFS = 'C:/Users/Sofie/Documents/data/wikipedia/entity_defs.csv'
-ENTITY_DESCR = 'C:/Users/Sofie/Documents/data/wikipedia/entity_descriptions.csv'
+ROOT_DIR = Path("C:/Users/Sofie/Documents/data/")
+OUTPUT_DIR = ROOT_DIR / 'wikipedia'
+TRAINING_DIR = OUTPUT_DIR / 'training_data_nel'

-KB_FILE = 'C:/Users/Sofie/Documents/data/wikipedia/kb_1/kb'
-NLP_1_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/nlp_1'
-NLP_2_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/nlp_2'
+PRIOR_PROB = OUTPUT_DIR / 'prior_prob.csv'
+ENTITY_COUNTS = OUTPUT_DIR / 'entity_freq.csv'
+ENTITY_DEFS = OUTPUT_DIR / 'entity_defs.csv'
+ENTITY_DESCR = OUTPUT_DIR / 'entity_descriptions.csv'

-TRAINING_DIR = 'C:/Users/Sofie/Documents/data/wikipedia/training_data_nel/'
+KB_FILE = OUTPUT_DIR / 'kb_1' / 'kb'
+NLP_1_DIR = OUTPUT_DIR / 'nlp_1'
+NLP_2_DIR = OUTPUT_DIR / 'nlp_2'

+# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
+WIKIDATA_JSON = ROOT_DIR / 'wikidata' / 'wikidata-20190304-all.json.bz2'
+
+# get enwiki-latest-pages-articles-multistream.xml.bz2 from https://dumps.wikimedia.org/enwiki/latest/
+ENWIKI_DUMP = ROOT_DIR / 'wikipedia' / 'enwiki-20190320-pages-articles-multistream.xml.bz2'
+
+# KB construction parameters
 MAX_CANDIDATES = 10
 MIN_ENTITY_FREQ = 20
 MIN_PAIR_OCC = 5

+# model training parameters
 EPOCHS = 10
 DROPOUT = 0.1
 LEARN_RATE = 0.005
@@ -38,6 +48,7 @@ L2 = 1e-6


 def run_pipeline():
+    # set the appropriate booleans to define which parts of the pipeline should be re(run)
     print("START", datetime.datetime.now())
     print()
     nlp_1 = spacy.load('en_core_web_lg')
@@ -67,22 +78,19 @@ def run_pipeline():
     to_write_nlp = False
     to_read_nlp = False

-    # STEP 1 : create prior probabilities from WP
-    # run only once !
+    # STEP 1 : create prior probabilities from WP (run only once)
     if to_create_prior_probs:
         print("STEP 1: to_create_prior_probs", datetime.datetime.now())
-        wp.read_wikipedia_prior_probs(prior_prob_output=PRIOR_PROB)
+        wp.read_wikipedia_prior_probs(wikipedia_input=ENWIKI_DUMP, prior_prob_output=PRIOR_PROB)
         print()

-    # STEP 2 : deduce entity frequencies from WP
-    # run only once !
+    # STEP 2 : deduce entity frequencies from WP (run only once)
     if to_create_entity_counts:
         print("STEP 2: to_create_entity_counts", datetime.datetime.now())
         wp.write_entity_counts(prior_prob_input=PRIOR_PROB, count_output=ENTITY_COUNTS, to_print=False)
         print()

-    # STEP 3 : create KB and write to file
-    # run only once !
+    # STEP 3 : create KB and write to file (run only once)
     if to_create_kb:
         print("STEP 3a: to_create_kb", datetime.datetime.now())
         kb_1 = kb_creator.create_kb(nlp_1,
@@ -93,7 +101,7 @@ def run_pipeline():
                                     entity_descr_output=ENTITY_DESCR,
                                     count_input=ENTITY_COUNTS,
                                     prior_prob_input=PRIOR_PROB,
-                                    to_print=False)
+                                    wikidata_input=WIKIDATA_JSON)
         print("kb entities:", kb_1.get_size_entities())
         print("kb aliases:", kb_1.get_size_aliases())
         print()
@@ -121,7 +129,9 @@ def run_pipeline():
     # STEP 5: create a training dataset from WP
     if create_wp_training:
         print("STEP 5: create training dataset", datetime.datetime.now())
-        training_set_creator.create_training(entity_def_input=ENTITY_DEFS, training_output=TRAINING_DIR)
+        training_set_creator.create_training(wikipedia_input=ENWIKI_DUMP,
+                                             entity_def_input=ENTITY_DEFS,
+                                             training_output=TRAINING_DIR)

     # STEP 6: create and train the entity linking pipe
     el_pipe = nlp_2.create_pipe(name='entity_linker', config={})
@@ -136,7 +146,8 @@ def run_pipeline():

     if train_pipe:
         print("STEP 6: training Entity Linking pipe", datetime.datetime.now())
-        train_limit = 25000
+        # define the size (nr of entities) of training and dev set
+        train_limit = 10000
         dev_limit = 5000

         train_data = training_set_creator.read_training(nlp=nlp_2,
@@ -157,7 +168,6 @@ def run_pipeline():

         if not train_data:
             print("Did not find any training data")
-
         else:
             for itn in range(EPOCHS):
                 random.shuffle(train_data)
@@ -196,7 +206,7 @@ def run_pipeline():
             print()

             counts, acc_r, acc_r_label, acc_p, acc_p_label, acc_o, acc_o_label = _measure_baselines(dev_data, kb_2)
-            print("dev counts:", sorted(counts))
+            print("dev counts:", sorted(counts.items(), key=lambda x: x[0]))
             print("dev acc oracle:", round(acc_o, 3), [(x, round(y, 3)) for x, y in acc_o_label.items()])
             print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_label.items()])
             print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_label.items()])
@@ -215,7 +225,6 @@ def run_pipeline():
                 dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe)
                 print("dev acc context avg:", round(dev_acc_context, 3),
                       [(x, round(y, 3)) for x, y in dev_acc_context_dict.items()])
-                print()

             # reset for follow-up tests
             el_pipe.context_weight = 1
@@ -227,7 +236,6 @@ def run_pipeline():
         print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now())
         print()
         run_el_toy_example(nlp=nlp_2)
-        print()

     # STEP 9: write the NLP pipeline (including entity linker) to file
     if to_write_nlp:
@@ -400,26 +408,9 @@ def run_el_toy_example(nlp):
     doc = nlp(text)
     print(text)
     for ent in doc.ents:
-        print("ent", ent.text, ent.label_, ent.kb_id_)
+        print(" ent", ent.text, ent.label_, ent.kb_id_)
     print()

-    # Q4426480 is her husband
-    text = "Ada Lovelace was the countess of Lovelace. She's known for her programming work on the analytical engine. "\
-           "She loved her husband William King dearly. "
-    doc = nlp(text)
-    print(text)
-    for ent in doc.ents:
-        print("ent", ent.text, ent.label_, ent.kb_id_)
-    print()
-
-    # Q3568763 is her tutor
-    text = "Ada Lovelace was the countess of Lovelace. She's known for her programming work on the analytical engine. "\
-           "She was tutored by her favorite physics tutor William King."
-    doc = nlp(text)
-    print(text)
-    for ent in doc.ents:
-        print("ent", ent.text, ent.label_, ent.kb_id_)
-

 if __name__ == "__main__":
     run_pipeline()

@@ -18,7 +18,6 @@ ctypedef vector[float_vec] float_matrix

 # Object used by the Entity Linker that summarizes one entity-alias candidate combination.
 cdef class Candidate:
-
     cdef readonly KnowledgeBase kb
     cdef hash_t entity_hash
     cdef float entity_freq
@@ -143,7 +142,6 @@ cdef class KnowledgeBase:

     cpdef load_bulk(self, loc)
     cpdef set_entities(self, entity_list, prob_list, vector_list)
-    cpdef set_aliases(self, alias_list, entities_list, probabilities_list)


 cdef class Writer:

spacy/kb.pyx (50 changed lines)

@@ -1,23 +1,16 @@
 # cython: infer_types=True
 # cython: profile=True
 # coding: utf8
-from collections import OrderedDict
-from pathlib import Path, WindowsPath
-
-from cpython.exc cimport PyErr_CheckSignals
-
-from spacy import util
 from spacy.errors import Errors, Warnings, user_warning

+from pathlib import Path
 from cymem.cymem cimport Pool
 from preshed.maps cimport PreshMap

-from cpython.mem cimport PyMem_Malloc
 from cpython.exc cimport PyErr_SetFromErrno

-from libc.stdio cimport FILE, fopen, fclose, fread, fwrite, feof, fseek
+from libc.stdio cimport fopen, fclose, fread, fwrite, feof, fseek
 from libc.stdint cimport int32_t, int64_t
-from libc.stdlib cimport qsort

 from .typedefs cimport hash_t

@@ -25,7 +18,6 @@ from os import path
 from libcpp.vector cimport vector


-
 cdef class Candidate:

     def __init__(self, KnowledgeBase kb, entity_hash, entity_freq, entity_vector, alias_hash, prior_prob):
@@ -79,8 +71,6 @@ cdef class KnowledgeBase:
         self._entry_index = PreshMap()
         self._alias_index = PreshMap()

-        # Should we initialize self._entries and self._aliases_table to specific starting size ?
-
         self.vocab.strings.add("")
         self._create_empty_vectors(dummy_hash=self.vocab.strings[""])

@@ -165,47 +155,11 @@ cdef class KnowledgeBase:

             i += 1

-    # TODO: this method is untested
-    cpdef set_aliases(self, alias_list, entities_list, probabilities_list):
-        nr_aliases = len(alias_list)
-        self._alias_index = PreshMap(nr_aliases+1)
-        self._aliases_table = alias_vec(nr_aliases+1)
-
-        i = 0
-        cdef AliasC alias
-        cdef int32_t dummy_value = 342
-        while i <= nr_aliases:
-            alias_hash = self.vocab.strings.add(alias_list[i])
-            entities = entities_list[i]
-            probabilities = probabilities_list[i]
-
-            nr_candidates = len(entities)
-            entry_indices = vector[int64_t](nr_candidates)
-            probs = vector[float](nr_candidates)
-
-            for j in range(0, nr_candidates):
-                entity = entities[j]
-                entity_hash = self.vocab.strings[entity]
-                if not entity_hash in self._entry_index:
-                    raise ValueError(Errors.E134.format(alias=alias, entity=entity))
-
-                entry_index = <int64_t>self._entry_index.get(entity_hash)
-                entry_indices[j] = entry_index
-
-            alias.entry_indices = entry_indices
-            alias.probs = probs
-
-            self._aliases_table[i] = alias
-            self._alias_index[alias_hash] = i
-
-            i += 1
-
     def add_alias(self, unicode alias, entities, probabilities):
         """
         For a given alias, add its potential entities and prior probabilies to the KB.
         Return the alias_hash at the end
         """
-
         # Throw an error if the length of entities and probabilities are not the same
         if not len(entities) == len(probabilities):
             raise ValueError(Errors.E132.format(alias=alias,

@@ -1068,8 +1068,6 @@ class EntityLinker(Pipe):
     DOCS: TODO
     """
     name = 'entity_linker'
-    context_weight = 1
-    prior_weight = 1

     @classmethod
     def Model(cls, **cfg):
@@ -1078,18 +1076,17 @@ class EntityLinker(Pipe):

         embed_width = cfg.get("embed_width", 300)
         hidden_width = cfg.get("hidden_width", 128)
-
-        # no default because this needs to correspond with the KB entity length
-        entity_width = cfg.get("entity_width")
+        entity_width = cfg.get("entity_width")  # this needs to correspond with the KB entity length

         model = build_nel_encoder(in_width=embed_width, hidden_width=hidden_width, end_width=entity_width, **cfg)
-
         return model

     def __init__(self, **cfg):
         self.model = True
         self.kb = None
         self.cfg = dict(cfg)
+        self.context_weight = cfg.get("context_weight", 1)
+        self.prior_weight = cfg.get("prior_weight", 1)

     def set_kb(self, kb):
         self.kb = kb
@@ -1162,7 +1159,6 @@ class EntityLinker(Pipe):
             if losses is not None:
                 losses[self.name] += loss
             return loss
-
        return 0

     def get_loss(self, docs, golds, scores):
@@ -1224,7 +1220,7 @@ class EntityLinker(Pipe):
                             kb_id = c.entity_
                             entity_encoding = c.entity_vector
                             sim = float(cosine(np.asarray([entity_encoding]), context_enc_t)) * self.context_weight
-                            score = prior_prob + sim - (prior_prob*sim)  # put weights on the different factors ?
+                            score = prior_prob + sim - (prior_prob*sim)
                             scores.append(score)

                         # TODO: thresholding

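Taken together, the cleanup threads the raw dump locations through as explicit arguments instead of hardcoded module-level constants. The sketch below (not part of the commit) shows roughly how the reworked entry points are meant to be wired together; the paths and filter values are the illustrative ones defined in the example script above, not defaults, and should be adjusted to your own data locations.

from pathlib import Path

import spacy

from bin.wiki_entity_linking import kb_creator, training_set_creator, wikipedia_processor as wp

# illustrative locations, mirroring the constants in the example script above
ROOT_DIR = Path("C:/Users/Sofie/Documents/data/")
OUTPUT_DIR = ROOT_DIR / "wikipedia"
ENWIKI_DUMP = ROOT_DIR / "wikipedia" / "enwiki-20190320-pages-articles-multistream.xml.bz2"
WIKIDATA_JSON = ROOT_DIR / "wikidata" / "wikidata-20190304-all.json.bz2"
PRIOR_PROB = OUTPUT_DIR / "prior_prob.csv"
ENTITY_COUNTS = OUTPUT_DIR / "entity_freq.csv"

nlp = spacy.load("en_core_web_lg")

# STEP 1: the Wikipedia dump is now an explicit argument instead of a module constant
wp.read_wikipedia_prior_probs(wikipedia_input=ENWIKI_DUMP, prior_prob_output=PRIOR_PROB)

# STEP 2: entity frequencies are derived from the prior-probability file
wp.write_entity_counts(prior_prob_input=PRIOR_PROB, count_output=ENTITY_COUNTS, to_print=False)

# STEP 3: create_kb() now receives the Wikidata dump as wikidata_input (to_print is gone)
kb = kb_creator.create_kb(nlp,
                          max_entities_per_alias=10,
                          min_entity_freq=20,
                          min_occ=5,
                          entity_def_output=OUTPUT_DIR / "entity_defs.csv",
                          entity_descr_output=OUTPUT_DIR / "entity_descriptions.csv",
                          count_input=ENTITY_COUNTS,
                          prior_prob_input=PRIOR_PROB,
                          wikidata_input=WIKIDATA_JSON)

# STEP 5: training-data creation likewise receives the Wikipedia dump explicitly
training_set_creator.create_training(wikipedia_input=ENWIKI_DUMP,
                                     entity_def_input=OUTPUT_DIR / "entity_defs.csv",
                                     training_output=OUTPUT_DIR / "training_data_nel")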