mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
use pathlib instead
This commit is contained in:
parent
400ff342cf
commit
a037206f0a
|
@ -13,9 +13,17 @@ INPUT_DIM = 300 # dimension of pre-trained input vectors
|
|||
DESC_WIDTH = 64 # dimension of output entity vectors
|
||||
|
||||
|
||||
def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
|
||||
entity_def_output, entity_descr_output,
|
||||
count_input, prior_prob_input, wikidata_input):
|
||||
def create_kb(
|
||||
nlp,
|
||||
max_entities_per_alias,
|
||||
min_entity_freq,
|
||||
min_occ,
|
||||
entity_def_output,
|
||||
entity_descr_output,
|
||||
count_input,
|
||||
prior_prob_input,
|
||||
wikidata_input,
|
||||
):
|
||||
# Create the knowledge base from Wikidata entries
|
||||
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH)
|
||||
|
||||
|
@ -28,7 +36,9 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
|
|||
title_to_id, id_to_descr = wd.read_wikidata_entities_json(wikidata_input)
|
||||
|
||||
# write the title-ID and ID-description mappings to file
|
||||
_write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr)
|
||||
_write_entity_files(
|
||||
entity_def_output, entity_descr_output, title_to_id, id_to_descr
|
||||
)
|
||||
|
||||
else:
|
||||
# read the mappings from file
|
||||
|
@ -54,8 +64,8 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
|
|||
frequency_list.append(freq)
|
||||
filtered_title_to_id[title] = entity
|
||||
|
||||
print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()),
|
||||
"titles with filter frequency", min_entity_freq)
|
||||
print(len(title_to_id.keys()), "original titles")
|
||||
print("kept", len(filtered_title_to_id.keys()), " with frequency", min_entity_freq)
|
||||
|
||||
print()
|
||||
print(" * train entity encoder", datetime.datetime.now())
|
||||
|
@ -70,14 +80,20 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
|
|||
|
||||
print()
|
||||
print(" * adding", len(entity_list), "entities", datetime.datetime.now())
|
||||
kb.set_entities(entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings)
|
||||
kb.set_entities(
|
||||
entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
|
||||
)
|
||||
|
||||
print()
|
||||
print(" * adding aliases", datetime.datetime.now())
|
||||
print()
|
||||
_add_aliases(kb, title_to_id=filtered_title_to_id,
|
||||
max_entities_per_alias=max_entities_per_alias, min_occ=min_occ,
|
||||
prior_prob_input=prior_prob_input)
|
||||
_add_aliases(
|
||||
kb,
|
||||
title_to_id=filtered_title_to_id,
|
||||
max_entities_per_alias=max_entities_per_alias,
|
||||
min_occ=min_occ,
|
||||
prior_prob_input=prior_prob_input,
|
||||
)
|
||||
|
||||
print()
|
||||
print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases())
|
||||
|
@ -86,13 +102,15 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
|
|||
return kb
|
||||
|
||||
|
||||
def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr):
|
||||
with open(entity_def_output, mode='w', encoding='utf8') as id_file:
|
||||
def _write_entity_files(
|
||||
entity_def_output, entity_descr_output, title_to_id, id_to_descr
|
||||
):
|
||||
with entity_def_output.open("w", encoding="utf8") as id_file:
|
||||
id_file.write("WP_title" + "|" + "WD_id" + "\n")
|
||||
for title, qid in title_to_id.items():
|
||||
id_file.write(title + "|" + str(qid) + "\n")
|
||||
|
||||
with open(entity_descr_output, mode='w', encoding='utf8') as descr_file:
|
||||
with entity_descr_output.open("w", encoding="utf8") as descr_file:
|
||||
descr_file.write("WD_id" + "|" + "description" + "\n")
|
||||
for qid, descr in id_to_descr.items():
|
||||
descr_file.write(str(qid) + "|" + descr + "\n")
|
||||
|
@ -100,8 +118,8 @@ def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_
|
|||
|
||||
def get_entity_to_id(entity_def_output):
|
||||
entity_to_id = dict()
|
||||
with open(entity_def_output, 'r', encoding='utf8') as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter='|')
|
||||
with entity_def_output.open("r", encoding="utf8") as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter="|")
|
||||
# skip header
|
||||
next(csvreader)
|
||||
for row in csvreader:
|
||||
|
@ -111,8 +129,8 @@ def get_entity_to_id(entity_def_output):
|
|||
|
||||
def get_id_to_description(entity_descr_output):
|
||||
id_to_desc = dict()
|
||||
with open(entity_descr_output, 'r', encoding='utf8') as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter='|')
|
||||
with entity_descr_output.open("r", encoding="utf8") as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter="|")
|
||||
# skip header
|
||||
next(csvreader)
|
||||
for row in csvreader:
|
||||
|
@ -125,7 +143,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
|
|||
|
||||
# adding aliases with prior probabilities
|
||||
# we can read this file sequentially, it's sorted by alias, and then by count
|
||||
with open(prior_prob_input, mode='r', encoding='utf8') as prior_file:
|
||||
with prior_prob_input.open("r", encoding="utf8") as prior_file:
|
||||
# skip header
|
||||
prior_file.readline()
|
||||
line = prior_file.readline()
|
||||
|
@ -134,7 +152,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
|
|||
counts = []
|
||||
entities = []
|
||||
while line:
|
||||
splits = line.replace('\n', "").split(sep='|')
|
||||
splits = line.replace("\n", "").split(sep="|")
|
||||
new_alias = splits[0]
|
||||
count = int(splits[1])
|
||||
entity = splits[2]
|
||||
|
@ -153,7 +171,11 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
|
|||
|
||||
if selected_entities:
|
||||
try:
|
||||
kb.add_alias(alias=previous_alias, entities=selected_entities, probabilities=prior_probs)
|
||||
kb.add_alias(
|
||||
alias=previous_alias,
|
||||
entities=selected_entities,
|
||||
probabilities=prior_probs,
|
||||
)
|
||||
except ValueError as e:
|
||||
print(e)
|
||||
total_count = 0
|
||||
|
@ -168,4 +190,3 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
|
|||
previous_alias = new_alias
|
||||
|
||||
line = prior_file.readline()
|
||||
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import bz2
|
||||
|
@ -37,7 +36,7 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
|
|||
|
||||
read_ids = set()
|
||||
entityfile_loc = training_output / ENTITY_FILE
|
||||
with open(entityfile_loc, mode="w", encoding="utf8") as entityfile:
|
||||
with entityfile_loc.open("w", encoding="utf8") as entityfile:
|
||||
# write entity training header file
|
||||
_write_training_entity(
|
||||
outputfile=entityfile,
|
||||
|
@ -301,8 +300,8 @@ def _get_clean_wp_text(article_text):
|
|||
|
||||
|
||||
def _write_training_article(article_id, clean_text, training_output):
|
||||
file_loc = training_output / str(article_id) + ".txt"
|
||||
with open(file_loc, mode="w", encoding="utf8") as outputfile:
|
||||
file_loc = training_output / "{}.txt".format(article_id)
|
||||
with file_loc.open("w", encoding="utf8") as outputfile:
|
||||
outputfile.write(clean_text)
|
||||
|
||||
|
||||
|
@ -330,7 +329,7 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
|
|||
skip_articles = set()
|
||||
total_entities = 0
|
||||
|
||||
with open(entityfile_loc, mode="r", encoding="utf8") as file:
|
||||
with entityfile_loc.open("r", encoding="utf8") as file:
|
||||
for line in file:
|
||||
if not limit or len(data) < limit:
|
||||
fields = line.replace("\n", "").split(sep="|")
|
||||
|
@ -349,11 +348,8 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
|
|||
# parse the new article text
|
||||
file_name = article_id + ".txt"
|
||||
try:
|
||||
with open(
|
||||
os.path.join(training_dir, file_name),
|
||||
mode="r",
|
||||
encoding="utf8",
|
||||
) as f:
|
||||
training_file = training_dir / file_name
|
||||
with training_file.open("r", encoding="utf8") as f:
|
||||
text = f.read()
|
||||
# threshold for convenience / speed of processing
|
||||
if len(text) < 30000:
|
||||
|
@ -364,7 +360,9 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
|
|||
sent_length = len(ent.sent)
|
||||
# custom filtering to avoid too long or too short sentences
|
||||
if 5 < sent_length < 100:
|
||||
offset = "{}_{}".format(ent.start_char, ent.end_char)
|
||||
offset = "{}_{}".format(
|
||||
ent.start_char, ent.end_char
|
||||
)
|
||||
ents_by_offset[offset] = ent
|
||||
else:
|
||||
skip_articles.add(article_id)
|
||||
|
|
|
@ -143,7 +143,7 @@ def read_prior_probs(wikipedia_input, prior_prob_output):
|
|||
cnt += 1
|
||||
|
||||
# write all aliases and their entities and count occurrences to file
|
||||
with open(prior_prob_output, mode="w", encoding="utf8") as outputfile:
|
||||
with prior_prob_output.open("w", encoding="utf8") as outputfile:
|
||||
outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n")
|
||||
for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]):
|
||||
s_dict = sorted(alias_dict.items(), key=lambda x: x[1], reverse=True)
|
||||
|
@ -220,7 +220,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
|
|||
entity_to_count = dict()
|
||||
total_count = 0
|
||||
|
||||
with open(prior_prob_input, mode="r", encoding="utf8") as prior_file:
|
||||
with prior_prob_input.open("r", encoding="utf8") as prior_file:
|
||||
# skip header
|
||||
prior_file.readline()
|
||||
line = prior_file.readline()
|
||||
|
@ -238,7 +238,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
|
|||
|
||||
line = prior_file.readline()
|
||||
|
||||
with open(count_output, mode="w", encoding="utf8") as entity_file:
|
||||
with count_output.open("w", encoding="utf8") as entity_file:
|
||||
entity_file.write("entity" + "|" + "count" + "\n")
|
||||
for entity, count in entity_to_count.items():
|
||||
entity_file.write(entity + "|" + str(count) + "\n")
|
||||
|
@ -251,7 +251,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
|
|||
|
||||
def get_all_frequencies(count_input):
|
||||
entity_to_count = dict()
|
||||
with open(count_input, "r", encoding="utf8") as csvfile:
|
||||
with count_input.open("r", encoding="utf8") as csvfile:
|
||||
csvreader = csv.reader(csvfile, delimiter="|")
|
||||
# skip header
|
||||
next(csvreader)
|
||||
|
|
Loading…
Reference in New Issue
Block a user