mirror of https://github.com/explosion/spaCy.git

use pathlib instead

commit a037206f0a (parent 400ff342cf)
@@ -13,9 +13,17 @@ INPUT_DIM = 300  # dimension of pre-trained input vectors
 DESC_WIDTH = 64  # dimension of output entity vectors
 
 
-def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
-              entity_def_output, entity_descr_output,
-              count_input, prior_prob_input, wikidata_input):
+def create_kb(
+    nlp,
+    max_entities_per_alias,
+    min_entity_freq,
+    min_occ,
+    entity_def_output,
+    entity_descr_output,
+    count_input,
+    prior_prob_input,
+    wikidata_input,
+):
     # Create the knowledge base from Wikidata entries
     kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH)
 
@@ -28,7 +36,9 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
         title_to_id, id_to_descr = wd.read_wikidata_entities_json(wikidata_input)
 
         # write the title-ID and ID-description mappings to file
-        _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr)
+        _write_entity_files(
+            entity_def_output, entity_descr_output, title_to_id, id_to_descr
+        )
 
     else:
         # read the mappings from file
@@ -54,8 +64,8 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
             frequency_list.append(freq)
             filtered_title_to_id[title] = entity
 
-    print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()),
-          "titles with filter frequency", min_entity_freq)
+    print(len(title_to_id.keys()), "original titles")
+    print("kept", len(filtered_title_to_id.keys()), " with frequency", min_entity_freq)
 
     print()
     print(" * train entity encoder", datetime.datetime.now())
@@ -70,14 +80,20 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
 
     print()
     print(" * adding", len(entity_list), "entities", datetime.datetime.now())
-    kb.set_entities(entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings)
+    kb.set_entities(
+        entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
+    )
 
     print()
     print(" * adding aliases", datetime.datetime.now())
     print()
-    _add_aliases(kb, title_to_id=filtered_title_to_id,
-                 max_entities_per_alias=max_entities_per_alias, min_occ=min_occ,
-                 prior_prob_input=prior_prob_input)
+    _add_aliases(
+        kb,
+        title_to_id=filtered_title_to_id,
+        max_entities_per_alias=max_entities_per_alias,
+        min_occ=min_occ,
+        prior_prob_input=prior_prob_input,
+    )
 
     print()
     print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases())
@@ -86,13 +102,15 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
     return kb
 
 
-def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr):
-    with open(entity_def_output, mode='w', encoding='utf8') as id_file:
+def _write_entity_files(
+    entity_def_output, entity_descr_output, title_to_id, id_to_descr
+):
+    with entity_def_output.open("w", encoding="utf8") as id_file:
         id_file.write("WP_title" + "|" + "WD_id" + "\n")
         for title, qid in title_to_id.items():
             id_file.write(title + "|" + str(qid) + "\n")
 
-    with open(entity_descr_output, mode='w', encoding='utf8') as descr_file:
+    with entity_descr_output.open("w", encoding="utf8") as descr_file:
         descr_file.write("WD_id" + "|" + "description" + "\n")
         for qid, descr in id_to_descr.items():
             descr_file.write(str(qid) + "|" + descr + "\n")
@@ -100,8 +118,8 @@ def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_
 
 def get_entity_to_id(entity_def_output):
     entity_to_id = dict()
-    with open(entity_def_output, 'r', encoding='utf8') as csvfile:
-        csvreader = csv.reader(csvfile, delimiter='|')
+    with entity_def_output.open("r", encoding="utf8") as csvfile:
+        csvreader = csv.reader(csvfile, delimiter="|")
         # skip header
         next(csvreader)
         for row in csvreader:
@@ -111,8 +129,8 @@ def get_entity_to_id(entity_def_output):
 
 def get_id_to_description(entity_descr_output):
     id_to_desc = dict()
-    with open(entity_descr_output, 'r', encoding='utf8') as csvfile:
-        csvreader = csv.reader(csvfile, delimiter='|')
+    with entity_descr_output.open("r", encoding="utf8") as csvfile:
+        csvreader = csv.reader(csvfile, delimiter="|")
         # skip header
         next(csvreader)
         for row in csvreader:
@@ -125,7 +143,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
 
     # adding aliases with prior probabilities
    # we can read this file sequentially, it's sorted by alias, and then by count
-    with open(prior_prob_input, mode='r', encoding='utf8') as prior_file:
+    with prior_prob_input.open("r", encoding="utf8") as prior_file:
         # skip header
         prior_file.readline()
         line = prior_file.readline()
@@ -134,7 +152,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
         counts = []
         entities = []
         while line:
-            splits = line.replace('\n', "").split(sep='|')
+            splits = line.replace("\n", "").split(sep="|")
             new_alias = splits[0]
             count = int(splits[1])
             entity = splits[2]
@@ -153,7 +171,11 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
 
                     if selected_entities:
                         try:
-                            kb.add_alias(alias=previous_alias, entities=selected_entities, probabilities=prior_probs)
+                            kb.add_alias(
+                                alias=previous_alias,
+                                entities=selected_entities,
+                                probabilities=prior_probs,
+                            )
                         except ValueError as e:
                             print(e)
                 total_count = 0
@@ -168,4 +190,3 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
             previous_alias = new_alias
 
             line = prior_file.readline()
-
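
After this change, `create_kb` and its helpers call `.open()` directly on `entity_def_output`, `entity_descr_output`, `count_input` and `prior_prob_input`, so callers are expected to pass `pathlib.Path` objects rather than plain string paths. A minimal sketch of the pipe-delimited write/read round trip used by `_write_entity_files` and `get_entity_to_id`; the file name and sample data are hypothetical, and the `row[0]`/`row[1]` assignment is assumed from the truncated hunk:

```python
import csv
from pathlib import Path

entity_def_output = Path("entity_defs.csv")  # hypothetical location

# Write a "WP_title|WD_id" mapping, as _write_entity_files does.
title_to_id = {"Douglas Adams": "Q42"}
with entity_def_output.open("w", encoding="utf8") as id_file:
    id_file.write("WP_title" + "|" + "WD_id" + "\n")
    for title, qid in title_to_id.items():
        id_file.write(title + "|" + str(qid) + "\n")

# Read it back, following the pattern of get_entity_to_id.
entity_to_id = dict()
with entity_def_output.open("r", encoding="utf8") as csvfile:
    csvreader = csv.reader(csvfile, delimiter="|")
    next(csvreader)  # skip header
    for row in csvreader:
        entity_to_id[row[0]] = row[1]

print(entity_to_id)  # {'Douglas Adams': 'Q42'}
```
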
@@ -1,7 +1,6 @@
 # coding: utf-8
 from __future__ import unicode_literals
 
-import os
 import random
 import re
 import bz2
@@ -37,7 +36,7 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
 
     read_ids = set()
     entityfile_loc = training_output / ENTITY_FILE
-    with open(entityfile_loc, mode="w", encoding="utf8") as entityfile:
+    with entityfile_loc.open("w", encoding="utf8") as entityfile:
         # write entity training header file
         _write_training_entity(
             outputfile=entityfile,
@@ -301,8 +300,8 @@ def _get_clean_wp_text(article_text):
 
 
 def _write_training_article(article_id, clean_text, training_output):
-    file_loc = training_output / str(article_id) + ".txt"
-    with open(file_loc, mode="w", encoding="utf8") as outputfile:
+    file_loc = training_output / "{}.txt".format(article_id)
+    with file_loc.open("w", encoding="utf8") as outputfile:
         outputfile.write(clean_text)
 
 
@@ -330,7 +329,7 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
     skip_articles = set()
     total_entities = 0
 
-    with open(entityfile_loc, mode="r", encoding="utf8") as file:
+    with entityfile_loc.open("r", encoding="utf8") as file:
         for line in file:
             if not limit or len(data) < limit:
                 fields = line.replace("\n", "").split(sep="|")
@@ -349,11 +348,8 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
                         # parse the new article text
                         file_name = article_id + ".txt"
                         try:
-                            with open(
-                                os.path.join(training_dir, file_name),
-                                mode="r",
-                                encoding="utf8",
-                            ) as f:
+                            training_file = training_dir / file_name
+                            with training_file.open("r", encoding="utf8") as f:
                                 text = f.read()
                                 # threshold for convenience / speed of processing
                                 if len(text) < 30000:
@@ -364,7 +360,9 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
                                         sent_length = len(ent.sent)
                                         # custom filtering to avoid too long or too short sentences
                                         if 5 < sent_length < 100:
-                                            offset = "{}_{}".format(ent.start_char, ent.end_char)
+                                            offset = "{}_{}".format(
+                                                ent.start_char, ent.end_char
+                                            )
                                             ents_by_offset[offset] = ent
                                 else:
                                     skip_articles.add(article_id)
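
Besides the switch to `Path.open()` and the dropped `import os`, the `_write_training_article` hunk fixes an operator-precedence bug: in `training_output / str(article_id) + ".txt"` the `/` is applied first, and the resulting `Path + str` raises a `TypeError` at runtime. Formatting the file name before joining keeps the whole expression a `Path` operation. A small sketch with a hypothetical directory and article id:

```python
from pathlib import Path

training_output = Path("training_data")  # hypothetical output directory
article_id = "12345"                     # hypothetical Wikipedia article id

# Old form: (training_output / str(article_id)) + ".txt" -> TypeError (Path + str).
# New form: build the file name first, then join it onto the directory.
file_loc = training_output / "{}.txt".format(article_id)

training_output.mkdir(parents=True, exist_ok=True)
with file_loc.open("w", encoding="utf8") as outputfile:
    outputfile.write("cleaned article text")  # placeholder content
```
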
@@ -143,7 +143,7 @@ def read_prior_probs(wikipedia_input, prior_prob_output):
             cnt += 1
 
     # write all aliases and their entities and count occurrences to file
-    with open(prior_prob_output, mode="w", encoding="utf8") as outputfile:
+    with prior_prob_output.open("w", encoding="utf8") as outputfile:
         outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n")
         for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]):
             s_dict = sorted(alias_dict.items(), key=lambda x: x[1], reverse=True)
@@ -220,7 +220,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
     entity_to_count = dict()
     total_count = 0
 
-    with open(prior_prob_input, mode="r", encoding="utf8") as prior_file:
+    with prior_prob_input.open("r", encoding="utf8") as prior_file:
         # skip header
         prior_file.readline()
         line = prior_file.readline()
@@ -238,7 +238,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
 
             line = prior_file.readline()
 
-    with open(count_output, mode="w", encoding="utf8") as entity_file:
+    with count_output.open("w", encoding="utf8") as entity_file:
         entity_file.write("entity" + "|" + "count" + "\n")
         for entity, count in entity_to_count.items():
             entity_file.write(entity + "|" + str(count) + "\n")
@@ -251,7 +251,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
 
 def get_all_frequencies(count_input):
     entity_to_count = dict()
-    with open(count_input, "r", encoding="utf8") as csvfile:
+    with count_input.open("r", encoding="utf8") as csvfile:
         csvreader = csv.reader(csvfile, delimiter="|")
         # skip header
         next(csvreader)
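
The functions in this last file now likewise expect `Path` arguments (`prior_prob_output`, `prior_prob_input`, `count_output`, `count_input`). A sketch of the sequential `readline` loop they run over the pipe-delimited `alias|count|entity` file; the per-entity aggregation shown here is an assumption, since that part of `write_entity_counts` is elided from the hunk above, and the file contents are made up:

```python
from pathlib import Path

# Hypothetical file in the "alias|count|entity" format written by read_prior_probs.
prior_prob_input = Path("prior_prob.csv")
prior_prob_input.write_text(
    "alias|count|entity\nDA|7|Q42\nDA|2|Q123\n", encoding="utf8"
)

entity_to_count = dict()
with prior_prob_input.open("r", encoding="utf8") as prior_file:
    prior_file.readline()  # skip header
    line = prior_file.readline()
    while line:
        splits = line.replace("\n", "").split(sep="|")
        count = int(splits[1])
        entity = splits[2]
        # Assumed aggregation: sum counts per entity across aliases.
        entity_to_count[entity] = entity_to_count.get(entity, 0) + count
        line = prior_file.readline()

print(entity_to_count)  # {'Q42': 7, 'Q123': 2}
```
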