use pathlib instead

svlandeg 2019-07-23 12:17:19 +02:00
parent 400ff342cf
commit a037206f0a
3 changed files with 55 additions and 36 deletions
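The change swaps builtin open() calls and string-based path handling for pathlib.Path methods across the three Wikidata/Wikipedia processing scripts, so the path arguments are now expected to be Path objects rather than strings. A minimal sketch of the recurring pattern, using an illustrative filename that does not appear in the commit:

    from pathlib import Path

    entity_def_output = Path("entity_defs.csv")  # illustrative filename, not from the commit

    # before: open(entity_def_output, mode='w', encoding='utf8')
    # after:  the Path object opens itself
    with entity_def_output.open("w", encoding="utf8") as id_file:
        id_file.write("WP_title" + "|" + "WD_id" + "\n")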

View File

@@ -13,9 +13,17 @@ INPUT_DIM = 300 # dimension of pre-trained input vectors
DESC_WIDTH = 64 # dimension of output entity vectors
def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
entity_def_output, entity_descr_output,
count_input, prior_prob_input, wikidata_input):
def create_kb(
nlp,
max_entities_per_alias,
min_entity_freq,
min_occ,
entity_def_output,
entity_descr_output,
count_input,
prior_prob_input,
wikidata_input,
):
# Create the knowledge base from Wikidata entries
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH)
@@ -28,7 +36,9 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
title_to_id, id_to_descr = wd.read_wikidata_entities_json(wikidata_input)
# write the title-ID and ID-description mappings to file
_write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr)
_write_entity_files(
entity_def_output, entity_descr_output, title_to_id, id_to_descr
)
else:
# read the mappings from file
@@ -54,8 +64,8 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
frequency_list.append(freq)
filtered_title_to_id[title] = entity
print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()),
"titles with filter frequency", min_entity_freq)
print(len(title_to_id.keys()), "original titles")
print("kept", len(filtered_title_to_id.keys()), " with frequency", min_entity_freq)
print()
print(" * train entity encoder", datetime.datetime.now())
@@ -70,14 +80,20 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
print()
print(" * adding", len(entity_list), "entities", datetime.datetime.now())
kb.set_entities(entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings)
kb.set_entities(
entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
)
print()
print(" * adding aliases", datetime.datetime.now())
print()
_add_aliases(kb, title_to_id=filtered_title_to_id,
max_entities_per_alias=max_entities_per_alias, min_occ=min_occ,
prior_prob_input=prior_prob_input)
_add_aliases(
kb,
title_to_id=filtered_title_to_id,
max_entities_per_alias=max_entities_per_alias,
min_occ=min_occ,
prior_prob_input=prior_prob_input,
)
print()
print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases())
@@ -86,13 +102,15 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
return kb
def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr):
with open(entity_def_output, mode='w', encoding='utf8') as id_file:
def _write_entity_files(
entity_def_output, entity_descr_output, title_to_id, id_to_descr
):
with entity_def_output.open("w", encoding="utf8") as id_file:
id_file.write("WP_title" + "|" + "WD_id" + "\n")
for title, qid in title_to_id.items():
id_file.write(title + "|" + str(qid) + "\n")
with open(entity_descr_output, mode='w', encoding='utf8') as descr_file:
with entity_descr_output.open("w", encoding="utf8") as descr_file:
descr_file.write("WD_id" + "|" + "description" + "\n")
for qid, descr in id_to_descr.items():
descr_file.write(str(qid) + "|" + descr + "\n")
@@ -100,8 +118,8 @@ def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_
def get_entity_to_id(entity_def_output):
entity_to_id = dict()
with open(entity_def_output, 'r', encoding='utf8') as csvfile:
csvreader = csv.reader(csvfile, delimiter='|')
with entity_def_output.open("r", encoding="utf8") as csvfile:
csvreader = csv.reader(csvfile, delimiter="|")
# skip header
next(csvreader)
for row in csvreader:
@@ -111,8 +129,8 @@ def get_entity_to_id(entity_def_output):
def get_id_to_description(entity_descr_output):
id_to_desc = dict()
with open(entity_descr_output, 'r', encoding='utf8') as csvfile:
csvreader = csv.reader(csvfile, delimiter='|')
with entity_descr_output.open("r", encoding="utf8") as csvfile:
csvreader = csv.reader(csvfile, delimiter="|")
# skip header
next(csvreader)
for row in csvreader:
@@ -125,7 +143,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
# adding aliases with prior probabilities
# we can read this file sequentially, it's sorted by alias, and then by count
with open(prior_prob_input, mode='r', encoding='utf8') as prior_file:
with prior_prob_input.open("r", encoding="utf8") as prior_file:
# skip header
prior_file.readline()
line = prior_file.readline()
@@ -134,7 +152,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
counts = []
entities = []
while line:
splits = line.replace('\n', "").split(sep='|')
splits = line.replace("\n", "").split(sep="|")
new_alias = splits[0]
count = int(splits[1])
entity = splits[2]
@@ -153,7 +171,11 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
if selected_entities:
try:
kb.add_alias(alias=previous_alias, entities=selected_entities, probabilities=prior_probs)
kb.add_alias(
alias=previous_alias,
entities=selected_entities,
probabilities=prior_probs,
)
except ValueError as e:
print(e)
total_count = 0
@@ -168,4 +190,3 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
previous_alias = new_alias
line = prior_file.readline()

View File

@@ -1,7 +1,6 @@
# coding: utf-8
from __future__ import unicode_literals
import os
import random
import re
import bz2
@@ -37,7 +36,7 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
read_ids = set()
entityfile_loc = training_output / ENTITY_FILE
with open(entityfile_loc, mode="w", encoding="utf8") as entityfile:
with entityfile_loc.open("w", encoding="utf8") as entityfile:
# write entity training header file
_write_training_entity(
outputfile=entityfile,
@@ -301,8 +300,8 @@ def _get_clean_wp_text(article_text):
def _write_training_article(article_id, clean_text, training_output):
file_loc = training_output / str(article_id) + ".txt"
with open(file_loc, mode="w", encoding="utf8") as outputfile:
file_loc = training_output / "{}.txt".format(article_id)
with file_loc.open("w", encoding="utf8") as outputfile:
outputfile.write(clean_text)
@@ -330,7 +329,7 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
skip_articles = set()
total_entities = 0
with open(entityfile_loc, mode="r", encoding="utf8") as file:
with entityfile_loc.open("r", encoding="utf8") as file:
for line in file:
if not limit or len(data) < limit:
fields = line.replace("\n", "").split(sep="|")
@@ -349,11 +348,8 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
# parse the new article text
file_name = article_id + ".txt"
try:
with open(
os.path.join(training_dir, file_name),
mode="r",
encoding="utf8",
) as f:
training_file = training_dir / file_name
with training_file.open("r", encoding="utf8") as f:
text = f.read()
# threshold for convenience / speed of processing
if len(text) < 30000:
@@ -364,7 +360,9 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
sent_length = len(ent.sent)
# custom filtering to avoid too long or too short sentences
if 5 < sent_length < 100:
offset = "{}_{}".format(ent.start_char, ent.end_char)
offset = "{}_{}".format(
ent.start_char, ent.end_char
)
ents_by_offset[offset] = ent
else:
skip_articles.add(article_id)

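In the training-set script, os.path.join is also replaced by the / operator on Path objects, and the per-article filename is now built with str.format before joining. The old expression training_output / str(article_id) + ".txt" evaluates the / first and then tries Path + str, which raises a TypeError when training_output is a Path (as it is elsewhere in this file), so the reformulation doubles as a bugfix. A small sketch with illustrative names that are not part of the commit:

    from pathlib import Path

    training_dir = Path("training_data")  # illustrative directory, not from the commit
    article_id = "12345"                  # illustrative article id

    # os.path.join(training_dir, file_name) becomes the / operator:
    training_file = training_dir / (article_id + ".txt")

    # building the filename first avoids the Path + str TypeError of the old expression:
    file_loc = training_dir / "{}.txt".format(article_id)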
View File

@@ -143,7 +143,7 @@ def read_prior_probs(wikipedia_input, prior_prob_output):
cnt += 1
# write all aliases and their entities and count occurrences to file
with open(prior_prob_output, mode="w", encoding="utf8") as outputfile:
with prior_prob_output.open("w", encoding="utf8") as outputfile:
outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n")
for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]):
s_dict = sorted(alias_dict.items(), key=lambda x: x[1], reverse=True)
@@ -220,7 +220,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
entity_to_count = dict()
total_count = 0
with open(prior_prob_input, mode="r", encoding="utf8") as prior_file:
with prior_prob_input.open("r", encoding="utf8") as prior_file:
# skip header
prior_file.readline()
line = prior_file.readline()
@@ -238,7 +238,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
line = prior_file.readline()
with open(count_output, mode="w", encoding="utf8") as entity_file:
with count_output.open("w", encoding="utf8") as entity_file:
entity_file.write("entity" + "|" + "count" + "\n")
for entity, count in entity_to_count.items():
entity_file.write(entity + "|" + str(count) + "\n")
@@ -251,7 +251,7 @@ def get_all_frequencies(count_input):
def get_all_frequencies(count_input):
entity_to_count = dict()
with open(count_input, "r", encoding="utf8") as csvfile:
with count_input.open("r", encoding="utf8") as csvfile:
csvreader = csv.reader(csvfile, delimiter="|")
# skip header
next(csvreader)
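Reading the delimited files through Path.open behaves the same as through open, since csv.reader only needs an iterable of lines. A self-contained sketch mirroring get_all_frequencies, with a hypothetical filename that is not from the commit:

    import csv
    from pathlib import Path

    count_input = Path("entity_counts.csv")  # hypothetical filename, not from the commit
    with count_input.open("r", encoding="utf8") as csvfile:
        csvreader = csv.reader(csvfile, delimiter="|")
        next(csvreader)  # skip the "entity|count" header
        entity_to_count = {row[0]: int(row[1]) for row in csvreader}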