diff --git a/bin/wiki_entity_linking/kb_creator.py b/bin/wiki_entity_linking/kb_creator.py
index d8cdf6dd7..5b25475b2 100644
--- a/bin/wiki_entity_linking/kb_creator.py
+++ b/bin/wiki_entity_linking/kb_creator.py
@@ -13,9 +13,17 @@
 INPUT_DIM = 300  # dimension of pre-trained input vectors
 DESC_WIDTH = 64  # dimension of output entity vectors
 
-def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
-              entity_def_output, entity_descr_output,
-              count_input, prior_prob_input, wikidata_input):
+def create_kb(
+    nlp,
+    max_entities_per_alias,
+    min_entity_freq,
+    min_occ,
+    entity_def_output,
+    entity_descr_output,
+    count_input,
+    prior_prob_input,
+    wikidata_input,
+):
     # Create the knowledge base from Wikidata entries
     kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH)
 
@@ -28,7 +36,9 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
         title_to_id, id_to_descr = wd.read_wikidata_entities_json(wikidata_input)
 
         # write the title-ID and ID-description mappings to file
-        _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr)
+        _write_entity_files(
+            entity_def_output, entity_descr_output, title_to_id, id_to_descr
+        )
 
     else:
         # read the mappings from file
@@ -54,8 +64,8 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
             frequency_list.append(freq)
             filtered_title_to_id[title] = entity
 
-    print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()),
-          "titles with filter frequency", min_entity_freq)
+    print(len(title_to_id.keys()), "original titles")
+    print("kept", len(filtered_title_to_id.keys()), " with frequency", min_entity_freq)
 
     print()
     print(" * train entity encoder", datetime.datetime.now())
@@ -70,14 +80,20 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
 
     print()
     print(" * adding", len(entity_list), "entities", datetime.datetime.now())
-    kb.set_entities(entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings)
+    kb.set_entities(
+        entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
+    )
 
     print()
     print(" * adding aliases", datetime.datetime.now())
     print()
-    _add_aliases(kb, title_to_id=filtered_title_to_id,
-                 max_entities_per_alias=max_entities_per_alias, min_occ=min_occ,
-                 prior_prob_input=prior_prob_input)
+    _add_aliases(
+        kb,
+        title_to_id=filtered_title_to_id,
+        max_entities_per_alias=max_entities_per_alias,
+        min_occ=min_occ,
+        prior_prob_input=prior_prob_input,
+    )
 
     print()
     print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases())
@@ -86,13 +102,15 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
     return kb
 
 
-def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr):
-    with open(entity_def_output, mode='w', encoding='utf8') as id_file:
+def _write_entity_files(
+    entity_def_output, entity_descr_output, title_to_id, id_to_descr
+):
+    with entity_def_output.open("w", encoding="utf8") as id_file:
         id_file.write("WP_title" + "|" + "WD_id" + "\n")
         for title, qid in title_to_id.items():
             id_file.write(title + "|" + str(qid) + "\n")
 
-    with open(entity_descr_output, mode='w', encoding='utf8') as descr_file:
+    with entity_descr_output.open("w", encoding="utf8") as descr_file:
         descr_file.write("WD_id" + "|" + "description" + "\n")
         for qid, descr in id_to_descr.items():
             descr_file.write(str(qid) + "|" + descr + "\n")
@@ -100,8 +118,8 @@ def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_
 
 def get_entity_to_id(entity_def_output):
     entity_to_id = dict()
-    with open(entity_def_output, 'r', encoding='utf8') as csvfile:
-        csvreader = csv.reader(csvfile, delimiter='|')
+    with entity_def_output.open("r", encoding="utf8") as csvfile:
+        csvreader = csv.reader(csvfile, delimiter="|")
         # skip header
         next(csvreader)
         for row in csvreader:
@@ -111,8 +129,8 @@ def get_entity_to_id(entity_def_output):
 
 def get_id_to_description(entity_descr_output):
     id_to_desc = dict()
-    with open(entity_descr_output, 'r', encoding='utf8') as csvfile:
-        csvreader = csv.reader(csvfile, delimiter='|')
+    with entity_descr_output.open("r", encoding="utf8") as csvfile:
+        csvreader = csv.reader(csvfile, delimiter="|")
         # skip header
         next(csvreader)
         for row in csvreader:
@@ -125,7 +143,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
 
     # adding aliases with prior probabilities
     # we can read this file sequentially, it's sorted by alias, and then by count
-    with open(prior_prob_input, mode='r', encoding='utf8') as prior_file:
+    with prior_prob_input.open("r", encoding="utf8") as prior_file:
         # skip header
         prior_file.readline()
         line = prior_file.readline()
@@ -134,7 +152,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
         counts = []
         entities = []
         while line:
-            splits = line.replace('\n', "").split(sep='|')
+            splits = line.replace("\n", "").split(sep="|")
             new_alias = splits[0]
             count = int(splits[1])
             entity = splits[2]
@@ -153,7 +171,11 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
 
                     if selected_entities:
                         try:
-                            kb.add_alias(alias=previous_alias, entities=selected_entities, probabilities=prior_probs)
+                            kb.add_alias(
+                                alias=previous_alias,
+                                entities=selected_entities,
+                                probabilities=prior_probs,
+                            )
                         except ValueError as e:
                             print(e)
                 total_count = 0
@@ -168,4 +190,3 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
             previous_alias = new_alias
 
             line = prior_file.readline()
-
diff --git a/bin/wiki_entity_linking/training_set_creator.py b/bin/wiki_entity_linking/training_set_creator.py
index 74bdbe9fb..b090d7659 100644
--- a/bin/wiki_entity_linking/training_set_creator.py
+++ b/bin/wiki_entity_linking/training_set_creator.py
@@ -1,7 +1,6 @@
 # coding: utf-8
 from __future__ import unicode_literals
 
-import os
 import random
 import re
 import bz2
@@ -37,7 +36,7 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
 
     read_ids = set()
     entityfile_loc = training_output / ENTITY_FILE
-    with open(entityfile_loc, mode="w", encoding="utf8") as entityfile:
+    with entityfile_loc.open("w", encoding="utf8") as entityfile:
         # write entity training header file
         _write_training_entity(
             outputfile=entityfile,
@@ -301,8 +300,8 @@ def _get_clean_wp_text(article_text):
 
 
 def _write_training_article(article_id, clean_text, training_output):
-    file_loc = training_output / str(article_id) + ".txt"
-    with open(file_loc, mode="w", encoding="utf8") as outputfile:
+    file_loc = training_output / "{}.txt".format(article_id)
+    with file_loc.open("w", encoding="utf8") as outputfile:
         outputfile.write(clean_text)
 
 
@@ -330,7 +329,7 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
     skip_articles = set()
     total_entities = 0
 
-    with open(entityfile_loc, mode="r", encoding="utf8") as file:
+    with entityfile_loc.open("r", encoding="utf8") as file:
         for line in file:
             if not limit or len(data) < limit:
                 fields = line.replace("\n", "").split(sep="|")
@@ -349,11 +348,8 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
                         # parse the new article text
                         file_name = article_id + ".txt"
                         try:
-                            with open(
-                                os.path.join(training_dir, file_name),
-                                mode="r",
-                                encoding="utf8",
-                            ) as f:
+                            training_file = training_dir / file_name
+                            with training_file.open("r", encoding="utf8") as f:
                                 text = f.read()
                                 # threshold for convenience / speed of processing
                                 if len(text) < 30000:
@@ -364,7 +360,9 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
                                         sent_length = len(ent.sent)
                                         # custom filtering to avoid too long or too short sentences
                                         if 5 < sent_length < 100:
-                                            offset = "{}_{}".format(ent.start_char, ent.end_char)
+                                            offset = "{}_{}".format(
+                                                ent.start_char, ent.end_char
+                                            )
                                             ents_by_offset[offset] = ent
                                 else:
                                     skip_articles.add(article_id)
diff --git a/bin/wiki_entity_linking/wikipedia_processor.py b/bin/wiki_entity_linking/wikipedia_processor.py
index 4d11aee61..80d75b013 100644
--- a/bin/wiki_entity_linking/wikipedia_processor.py
+++ b/bin/wiki_entity_linking/wikipedia_processor.py
@@ -143,7 +143,7 @@ def read_prior_probs(wikipedia_input, prior_prob_output):
             cnt += 1
 
     # write all aliases and their entities and count occurrences to file
-    with open(prior_prob_output, mode="w", encoding="utf8") as outputfile:
+    with prior_prob_output.open("w", encoding="utf8") as outputfile:
         outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n")
         for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]):
             s_dict = sorted(alias_dict.items(), key=lambda x: x[1], reverse=True)
@@ -220,7 +220,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
     entity_to_count = dict()
     total_count = 0
 
-    with open(prior_prob_input, mode="r", encoding="utf8") as prior_file:
+    with prior_prob_input.open("r", encoding="utf8") as prior_file:
         # skip header
         prior_file.readline()
         line = prior_file.readline()
@@ -238,7 +238,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
 
             line = prior_file.readline()
 
-    with open(count_output, mode="w", encoding="utf8") as entity_file:
+    with count_output.open("w", encoding="utf8") as entity_file:
         entity_file.write("entity" + "|" + "count" + "\n")
         for entity, count in entity_to_count.items():
             entity_file.write(entity + "|" + str(count) + "\n")
@@ -251,7 +251,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
 
 def get_all_frequencies(count_input):
     entity_to_count = dict()
-    with open(count_input, "r", encoding="utf8") as csvfile:
+    with count_input.open("r", encoding="utf8") as csvfile:
         csvreader = csv.reader(csvfile, delimiter="|")
         # skip header
         next(csvreader)
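The common thread in these hunks is that the helpers now call .open() on their path arguments instead of the open() builtin, so callers are expected to pass pathlib.Path objects rather than plain path strings (the existing joins such as training_output / ENTITY_FILE already assume this). A minimal sketch of the resulting calling convention; the import path and file location below are placeholders, not values taken from the actual scripts:

    from pathlib import Path

    import kb_creator  # assumes bin/wiki_entity_linking is importable from sys.path

    # Hypothetical location; the real scripts receive these paths as arguments.
    entity_def_path = Path("output/entity_defs.csv")

    # get_entity_to_id() now calls entity_def_output.open(...) internally,
    # so it needs a pathlib.Path; a plain str would raise
    # AttributeError: 'str' object has no attribute 'open'.
    entity_to_id = kb_creator.get_entity_to_id(entity_def_path)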