Mirror of https://github.com/explosion/spaCy.git, synced 2024-11-14 05:37:03 +03:00

Commit a037206f0a: "use pathlib instead"
Parent: 400ff342cf
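The diff below swaps the builtin open() on string paths for pathlib.Path.open() and pathlib's "/" join operator, and reflows several call sites into Black-style multi-line form. As a minimal sketch of the pattern only (the file names here are invented, not taken from the commit):

    from pathlib import Path

    output_dir = Path("output")               # a Path object instead of a plain string
    file_loc = output_dir / "entities.csv"    # "/" joins path components

    # before: open(file_loc, mode="w", encoding="utf8")
    with file_loc.open("w", encoding="utf8") as out_file:
        out_file.write("WP_title|WD_id\n")
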
@@ -13,9 +13,17 @@ INPUT_DIM = 300 # dimension of pre-trained input vectors
 DESC_WIDTH = 64  # dimension of output entity vectors
 
 
-def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
-              entity_def_output, entity_descr_output,
-              count_input, prior_prob_input, wikidata_input):
+def create_kb(
+    nlp,
+    max_entities_per_alias,
+    min_entity_freq,
+    min_occ,
+    entity_def_output,
+    entity_descr_output,
+    count_input,
+    prior_prob_input,
+    wikidata_input,
+):
     # Create the knowledge base from Wikidata entries
     kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH)
 
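Note that the entity_def_output, entity_descr_output, count_input, prior_prob_input, and wikidata_input arguments are used with .open() further down, so callers now need to pass pathlib.Path objects rather than plain strings. A hypothetical invocation, with all paths, values, and the module/model names assumed for illustration:

    from pathlib import Path

    import spacy
    from kb_creator import create_kb  # module name assumed for illustration

    nlp = spacy.load("en_core_web_lg")  # model choice assumed
    kb = create_kb(
        nlp,
        max_entities_per_alias=10,
        min_entity_freq=20,
        min_occ=5,
        entity_def_output=Path("output/entity_defs.csv"),
        entity_descr_output=Path("output/entity_descriptions.csv"),
        count_input=Path("input/entity_counts.csv"),
        prior_prob_input=Path("input/prior_probs.csv"),
        wikidata_input=Path("input/wikidata-latest-all.json.bz2"),
    )
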
@@ -28,7 +36,9 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
         title_to_id, id_to_descr = wd.read_wikidata_entities_json(wikidata_input)
 
         # write the title-ID and ID-description mappings to file
-        _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr)
+        _write_entity_files(
+            entity_def_output, entity_descr_output, title_to_id, id_to_descr
+        )
 
     else:
         # read the mappings from file
@@ -54,8 +64,8 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
             frequency_list.append(freq)
             filtered_title_to_id[title] = entity
 
-    print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()),
-          "titles with filter frequency", min_entity_freq)
+    print(len(title_to_id.keys()), "original titles")
+    print("kept", len(filtered_title_to_id.keys()), " with frequency", min_entity_freq)
 
     print()
     print(" * train entity encoder", datetime.datetime.now())
@@ -70,14 +80,20 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
 
     print()
     print(" * adding", len(entity_list), "entities", datetime.datetime.now())
-    kb.set_entities(entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings)
+    kb.set_entities(
+        entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
+    )
 
     print()
     print(" * adding aliases", datetime.datetime.now())
     print()
-    _add_aliases(kb, title_to_id=filtered_title_to_id,
-                 max_entities_per_alias=max_entities_per_alias, min_occ=min_occ,
-                 prior_prob_input=prior_prob_input)
+    _add_aliases(
+        kb,
+        title_to_id=filtered_title_to_id,
+        max_entities_per_alias=max_entities_per_alias,
+        min_occ=min_occ,
+        prior_prob_input=prior_prob_input,
+    )
 
     print()
     print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases())
@@ -86,13 +102,15 @@ def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ,
     return kb
 
 
-def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr):
-    with open(entity_def_output, mode='w', encoding='utf8') as id_file:
+def _write_entity_files(
+    entity_def_output, entity_descr_output, title_to_id, id_to_descr
+):
+    with entity_def_output.open("w", encoding="utf8") as id_file:
         id_file.write("WP_title" + "|" + "WD_id" + "\n")
         for title, qid in title_to_id.items():
             id_file.write(title + "|" + str(qid) + "\n")
 
-    with open(entity_descr_output, mode='w', encoding='utf8') as descr_file:
+    with entity_descr_output.open("w", encoding="utf8") as descr_file:
         descr_file.write("WD_id" + "|" + "description" + "\n")
         for qid, descr in id_to_descr.items():
             descr_file.write(str(qid) + "|" + descr + "\n")
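For reference, _write_entity_files produces two pipe-delimited files whose headers are spelled out above (WP_title|WD_id and WD_id|description). A small usage sketch, assuming the function is importable and with invented titles, QIDs, and paths:

    from pathlib import Path

    out_dir = Path("output")
    out_dir.mkdir(parents=True, exist_ok=True)

    _write_entity_files(
        out_dir / "entity_defs.csv",
        out_dir / "entity_descriptions.csv",
        title_to_id={"Douglas Adams": "Q42"},
        id_to_descr={"Q42": "English writer and humorist"},
    )
    # entity_defs.csv now contains:
    #   WP_title|WD_id
    #   Douglas Adams|Q42
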
@@ -100,8 +118,8 @@ def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_
 
 def get_entity_to_id(entity_def_output):
     entity_to_id = dict()
-    with open(entity_def_output, 'r', encoding='utf8') as csvfile:
-        csvreader = csv.reader(csvfile, delimiter='|')
+    with entity_def_output.open("r", encoding="utf8") as csvfile:
+        csvreader = csv.reader(csvfile, delimiter="|")
         # skip header
         next(csvreader)
         for row in csvreader:
@@ -111,8 +129,8 @@ def get_entity_to_id(entity_def_output):
 
 def get_id_to_description(entity_descr_output):
     id_to_desc = dict()
-    with open(entity_descr_output, 'r', encoding='utf8') as csvfile:
-        csvreader = csv.reader(csvfile, delimiter='|')
+    with entity_descr_output.open("r", encoding="utf8") as csvfile:
+        csvreader = csv.reader(csvfile, delimiter="|")
         # skip header
         next(csvreader)
         for row in csvreader:
@@ -125,7 +143,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
 
     # adding aliases with prior probabilities
     # we can read this file sequentially, it's sorted by alias, and then by count
-    with open(prior_prob_input, mode='r', encoding='utf8') as prior_file:
+    with prior_prob_input.open("r", encoding="utf8") as prior_file:
         # skip header
         prior_file.readline()
         line = prior_file.readline()
@@ -134,7 +152,7 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
         counts = []
         entities = []
         while line:
-            splits = line.replace('\n', "").split(sep='|')
+            splits = line.replace("\n", "").split(sep="|")
             new_alias = splits[0]
             count = int(splits[1])
             entity = splits[2]
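The prior-probability file parsed here is the pipe-delimited alias|count|entity file written by read_prior_probs further down in this diff. With an invented row, the parsing above works like this:

    line = "Douglas Adams|42|Q42\n"  # invented example row
    splits = line.replace("\n", "").split(sep="|")
    new_alias = splits[0]   # "Douglas Adams"
    count = int(splits[1])  # 42
    entity = splits[2]      # "Q42"
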
@@ -153,7 +171,11 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
 
                 if selected_entities:
                     try:
-                        kb.add_alias(alias=previous_alias, entities=selected_entities, probabilities=prior_probs)
+                        kb.add_alias(
+                            alias=previous_alias,
+                            entities=selected_entities,
+                            probabilities=prior_probs,
+                        )
                     except ValueError as e:
                         print(e)
                 total_count = 0
@@ -168,4 +190,3 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
             previous_alias = new_alias
 
             line = prior_file.readline()
-

@@ -1,7 +1,6 @@
 # coding: utf-8
 from __future__ import unicode_literals
 
-import os
 import random
 import re
 import bz2
@@ -37,7 +36,7 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
 
     read_ids = set()
     entityfile_loc = training_output / ENTITY_FILE
-    with open(entityfile_loc, mode="w", encoding="utf8") as entityfile:
+    with entityfile_loc.open("w", encoding="utf8") as entityfile:
         # write entity training header file
         _write_training_entity(
             outputfile=entityfile,
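The training_output / ENTITY_FILE expression only works because training_output is already a pathlib.Path; the "/" operator is not defined between two plain strings. A quick illustration (the names and file name are invented):

    from pathlib import Path

    ENTITY_FILE = "gold_entities.csv"  # invented file name

    entityfile_loc = Path("training_output") / ENTITY_FILE  # fine: Path / str -> Path
    # "training_output" / ENTITY_FILE would raise
    # TypeError: unsupported operand type(s) for /: 'str' and 'str'
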
@@ -301,8 +300,8 @@ def _get_clean_wp_text(article_text):
 
 
 def _write_training_article(article_id, clean_text, training_output):
-    file_loc = training_output / str(article_id) + ".txt"
-    with open(file_loc, mode="w", encoding="utf8") as outputfile:
+    file_loc = training_output / "{}.txt".format(article_id)
+    with file_loc.open("w", encoding="utf8") as outputfile:
         outputfile.write(clean_text)
 
 
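Besides the pathlib switch, the file_loc change fixes an operator-precedence problem: "/" binds tighter than "+", so the old expression tried to add a str to a Path, which raises a TypeError at runtime. A minimal before/after sketch (the id value is invented):

    from pathlib import Path

    training_output = Path("training")
    article_id = "12345"  # invented id

    # old: (training_output / str(article_id)) + ".txt"
    #      -> TypeError: unsupported operand type(s) for +: 'PosixPath' and 'str'

    file_loc = training_output / "{}.txt".format(article_id)
    print(file_loc)  # training/12345.txt
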
@@ -330,7 +329,7 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
     skip_articles = set()
     total_entities = 0
 
-    with open(entityfile_loc, mode="r", encoding="utf8") as file:
+    with entityfile_loc.open("r", encoding="utf8") as file:
         for line in file:
             if not limit or len(data) < limit:
                 fields = line.replace("\n", "").split(sep="|")
@@ -349,11 +348,8 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
                     # parse the new article text
                     file_name = article_id + ".txt"
                     try:
-                        with open(
-                            os.path.join(training_dir, file_name),
-                            mode="r",
-                            encoding="utf8",
-                        ) as f:
+                        training_file = training_dir / file_name
+                        with training_file.open("r", encoding="utf8") as f:
                             text = f.read()
                             # threshold for convenience / speed of processing
                             if len(text) < 30000:
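Here the nested open(os.path.join(...)) call collapses into a pathlib join plus Path.open(); the two spellings read the same file. A rough equivalence sketch, assuming training_dir is a Path and the file exists (the names are invented):

    import os
    from pathlib import Path

    training_dir = Path("training")
    file_name = "12345.txt"

    # old style
    with open(os.path.join(str(training_dir), file_name), mode="r", encoding="utf8") as f:
        text_old = f.read()

    # new style, as in the diff
    training_file = training_dir / file_name
    with training_file.open("r", encoding="utf8") as f:
        text_new = f.read()

    assert text_old == text_new
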
@@ -364,7 +360,9 @@ def read_training(nlp, training_dir, dev, limit, kb=None):
                                     sent_length = len(ent.sent)
                                     # custom filtering to avoid too long or too short sentences
                                     if 5 < sent_length < 100:
-                                        offset = "{}_{}".format(ent.start_char, ent.end_char)
+                                        offset = "{}_{}".format(
+                                            ent.start_char, ent.end_char
+                                        )
                                         ents_by_offset[offset] = ent
                             else:
                                 skip_articles.add(article_id)

@@ -143,7 +143,7 @@ def read_prior_probs(wikipedia_input, prior_prob_output):
             cnt += 1
 
     # write all aliases and their entities and count occurrences to file
-    with open(prior_prob_output, mode="w", encoding="utf8") as outputfile:
+    with prior_prob_output.open("w", encoding="utf8") as outputfile:
         outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n")
         for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]):
             s_dict = sorted(alias_dict.items(), key=lambda x: x[1], reverse=True)
@@ -220,7 +220,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
     entity_to_count = dict()
     total_count = 0
 
-    with open(prior_prob_input, mode="r", encoding="utf8") as prior_file:
+    with prior_prob_input.open("r", encoding="utf8") as prior_file:
         # skip header
         prior_file.readline()
         line = prior_file.readline()
@@ -238,7 +238,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
 
             line = prior_file.readline()
 
-    with open(count_output, mode="w", encoding="utf8") as entity_file:
+    with count_output.open("w", encoding="utf8") as entity_file:
         entity_file.write("entity" + "|" + "count" + "\n")
         for entity, count in entity_to_count.items():
             entity_file.write(entity + "|" + str(count) + "\n")
@@ -251,7 +251,7 @@ def write_entity_counts(prior_prob_input, count_output, to_print=False):
 
 def get_all_frequencies(count_input):
     entity_to_count = dict()
-    with open(count_input, "r", encoding="utf8") as csvfile:
+    with count_input.open("r", encoding="utf8") as csvfile:
         csvreader = csv.reader(csvfile, delimiter="|")
         # skip header
         next(csvreader)
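Since every helper in this diff now calls .open() directly on its path arguments, any driver script has to turn user-supplied strings into Path objects once, up front. A hedged sketch of that boundary conversion, assuming write_entity_counts and get_all_frequencies are importable and that the latter returns the entity-to-count dict (the wiring and default paths are invented):

    from pathlib import Path

    def main(prior_prob_path="output/prior_probs.csv", count_path="output/entity_counts.csv"):
        # convert plain strings to Path objects once, at the entry point
        prior_prob_input = Path(prior_prob_path)
        count_output = Path(count_path)

        count_output.parent.mkdir(parents=True, exist_ok=True)
        write_entity_counts(prior_prob_input, count_output, to_print=True)
        frequencies = get_all_frequencies(count_output)
        print(len(frequencies), "entities counted")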