spaCy/bin/wiki_entity_linking/kb_creator.py

193 lines
6.4 KiB
Python
Raw Normal View History

# coding: utf-8
from __future__ import unicode_literals
2019-06-19 10:15:43 +03:00
from .train_descriptions import EntityEncoder
from . import wikidata_processor as wd, wikipedia_processor as wp
from spacy.kb import KnowledgeBase
import csv
import datetime
2019-06-19 10:15:43 +03:00
# Width of the pre-trained input vectors the entity encoder consumes.
INPUT_DIM = 300
# Width of the entity vectors produced by the encoder and stored in the KB.
DESC_WIDTH = 64
2019-06-13 23:32:56 +03:00
2019-07-23 13:17:19 +03:00
def create_kb(
    nlp,
    max_entities_per_alias,
    min_entity_freq,
    min_occ,
    entity_def_output,
    entity_descr_output,
    count_input,
    prior_prob_input,
    wikidata_input,
    read_raw_data=True,
):
    """Build a KnowledgeBase of Wikidata entities and their aliases.

    Steps: obtain the title -> Wikidata ID and ID -> description mappings,
    keep only entities with a description and a high enough frequency, train
    an ``EntityEncoder`` on those descriptions and embed them, then add the
    entities and their prior-probability aliases to a new KB.

    nlp: pipeline whose vocab backs the KB; also passed to the encoder.
    max_entities_per_alias, min_occ, prior_prob_input: forwarded to
        ``_add_aliases``.
    min_entity_freq: an entity is kept only if its frequency is strictly
        greater than this value.
    entity_def_output, entity_descr_output: files holding the title -> ID and
        ID -> description mappings (opened via ``.open()`` by the helpers).
    count_input: passed to ``wp.get_all_frequencies``.
    wikidata_input: passed to ``wd.read_wikidata_entities_json``.
    read_raw_data: when True (the default, and the previous hard-coded
        behaviour) parse ``wikidata_input`` and (re)write the two mapping
        files; when False, read the mappings back from those files instead.

    Returns the populated KnowledgeBase.
    """
    # Create the knowledge base from Wikidata entries
    kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH)

    if read_raw_data:
        print()
        print(" * _read_wikidata_entities", datetime.datetime.now())
        title_to_id, id_to_descr = wd.read_wikidata_entities_json(wikidata_input)
        # write the title-ID and ID-description mappings to file, so a later
        # run can pass read_raw_data=False and skip parsing wikidata_input
        _write_entity_files(
            entity_def_output, entity_descr_output, title_to_id, id_to_descr
        )
    else:
        # read the mappings from the files written by an earlier run
        title_to_id = get_entity_to_id(entity_def_output)
        id_to_descr = get_id_to_description(entity_descr_output)

    print()
    print(" * _get_entity_frequencies", datetime.datetime.now())
    print()
    entity_frequencies = wp.get_all_frequencies(count_input=count_input)

    # filter the entities to include in the KB by frequency, because there's
    # just too much data (8M entities) otherwise
    (
        filtered_title_to_id,
        entity_list,
        description_list,
        frequency_list,
    ) = _filter_entities(title_to_id, id_to_descr, entity_frequencies, min_entity_freq)
    print(len(title_to_id), "original titles")
    print("kept", len(filtered_title_to_id), "with frequency >", min_entity_freq)

    print()
    print(" * train entity encoder", datetime.datetime.now())
    print()
    encoder = EntityEncoder(nlp, INPUT_DIM, DESC_WIDTH)
    encoder.train(description_list=description_list, to_print=True)

    print()
    print(" * get entity embeddings", datetime.datetime.now())
    print()
    embeddings = encoder.apply_encoder(description_list)

    print()
    print(" * adding", len(entity_list), "entities", datetime.datetime.now())
    kb.set_entities(
        entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
    )

    print()
    print(" * adding aliases", datetime.datetime.now())
    print()
    _add_aliases(
        kb,
        title_to_id=filtered_title_to_id,
        max_entities_per_alias=max_entities_per_alias,
        min_occ=min_occ,
        prior_prob_input=prior_prob_input,
    )

    print()
    print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases())
    print("done with kb", datetime.datetime.now())
    return kb


def _filter_entities(title_to_id, id_to_descr, entity_frequencies, min_entity_freq):
    """Select the entities that go into the KB.

    An entity is kept when its ID has a truthy (non-empty) description in
    ``id_to_descr`` and its title's frequency in ``entity_frequencies``
    (0 when the title is missing) is strictly greater than
    ``min_entity_freq``.

    Returns ``(filtered_title_to_id, entity_list, description_list,
    frequency_list)``. The three lists are parallel and follow the
    iteration order of ``title_to_id``.
    """
    filtered_title_to_id = dict()
    entity_list = []
    description_list = []
    frequency_list = []
    for title, entity in title_to_id.items():
        freq = entity_frequencies.get(title, 0)
        desc = id_to_descr.get(entity, None)
        if desc and freq > min_entity_freq:
            entity_list.append(entity)
            description_list.append(desc)
            frequency_list.append(freq)
            filtered_title_to_id[title] = entity
    return filtered_title_to_id, entity_list, description_list, frequency_list
2019-07-23 13:17:19 +03:00
def _write_entity_files(
    entity_def_output, entity_descr_output, title_to_id, id_to_descr
):
    """Write the title -> ID and ID -> description mappings as '|' files.

    ``entity_def_output`` receives a ``WP_title|WD_id`` header followed by one
    ``title|id`` row per entry of ``title_to_id``. ``entity_descr_output``
    receives a ``WD_id|description`` header followed by one ``id|description``
    row per entry of ``id_to_descr``.

    Rows are written with ``csv.writer`` so that a value containing the
    delimiter, a double quote or a newline is quoted. This lets
    ``csv.reader(..., delimiter="|")`` recover it unchanged. Rows without such
    characters come out as plain ``key|value`` lines ending in ``\\n``.
    """
    with entity_def_output.open("w", encoding="utf8") as id_file:
        id_writer = csv.writer(id_file, delimiter="|", lineterminator="\n")
        id_writer.writerow(["WP_title", "WD_id"])
        for title, qid in title_to_id.items():
            id_writer.writerow([title, str(qid)])

    with entity_descr_output.open("w", encoding="utf8") as descr_file:
        descr_writer = csv.writer(descr_file, delimiter="|", lineterminator="\n")
        descr_writer.writerow(["WD_id", "description"])
        for qid, descr in id_to_descr.items():
            descr_writer.writerow([str(qid), descr])
2019-06-06 21:22:14 +03:00
def get_entity_to_id(entity_def_output):
    """Load the title -> Wikidata ID mapping from a '|'-delimited file.

    The first row is treated as a header and skipped. Every following row
    maps ``row[0]`` (title) to ``row[1]`` (ID). Blank lines are ignored, and
    an empty file yields an empty dict.
    """
    entity_to_id = dict()
    # newline="" is what the csv module expects, so quoted fields that
    # contain line breaks are parsed correctly
    with entity_def_output.open("r", encoding="utf8", newline="") as csvfile:
        csvreader = csv.reader(csvfile, delimiter="|")
        # skip header; the default avoids a StopIteration on an empty file
        next(csvreader, None)
        for row in csvreader:
            if not row:
                # csv.reader yields [] for a blank line (e.g. a trailing one)
                continue
            entity_to_id[row[0]] = row[1]
    return entity_to_id
def get_id_to_description(entity_descr_output):
    """Load the Wikidata ID -> description mapping from a '|'-delimited file.

    The first row is treated as a header and skipped. Every following row
    maps ``row[0]`` (ID) to ``row[1]`` (description). Blank lines are
    ignored, and an empty file yields an empty dict.
    """
    id_to_desc = dict()
    # newline="" is what the csv module expects, so quoted fields that
    # contain line breaks are parsed correctly
    with entity_descr_output.open("r", encoding="utf8", newline="") as csvfile:
        csvreader = csv.reader(csvfile, delimiter="|")
        # skip header; the default avoids a StopIteration on an empty file
        next(csvreader, None)
        for row in csvreader:
            if not row:
                # csv.reader yields [] for a blank line (e.g. a trailing one)
                continue
            id_to_desc[row[0]] = row[1]
    return id_to_desc
2019-06-19 10:15:43 +03:00
def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input):
    """Add prior-probability aliases from ``prior_prob_input`` to ``kb``.

    The file is opened with ``.open()``. After a header line, each line is
    ``alias|count|entity_title``. Lines sharing an alias must be consecutive,
    because a group is closed as soon as a different alias appears. For each
    alias group:

    * ``total_count`` is the sum of the counts on all of its lines;
    * at most ``max_entities_per_alias`` lines with ``count >= min_occ`` are
      kept as candidates, in file order;
    * candidates whose title is a key of ``title_to_id`` are mapped to that ID
      and given the prior probability ``count / total_count``.

    Groups with at least one mapped candidate are added via ``kb.add_alias``.
    A ``ValueError`` raised by the KB is printed and that alias is skipped.
    """
    # adding aliases with prior probabilities
    # we can read this file sequentially, it's sorted by alias, and then by count
    with prior_prob_input.open("r", encoding="utf8") as prior_file:
        # skip header
        prior_file.readline()
        line = prior_file.readline()
        previous_alias = None
        total_count = 0
        counts = []
        entities = []
        while line:
            splits = line.replace("\n", "").split(sep="|")
            new_alias = splits[0]
            count = int(splits[1])
            entity = splits[2]

            # compare with None rather than relying on truthiness: an empty
            # alias must still close its group instead of merging into the next
            if previous_alias is not None and new_alias != previous_alias:
                # done reading the previous alias --> output
                _flush_alias(kb, previous_alias, counts, entities, total_count, title_to_id)
                total_count = 0
                counts = []
                entities = []

            total_count += count
            if len(entities) < max_entities_per_alias and count >= min_occ:
                counts.append(count)
                entities.append(entity)
            previous_alias = new_alias

            line = prior_file.readline()

        # groups are only flushed when the next alias starts, so the last
        # alias of the file must be flushed explicitly
        if previous_alias is not None:
            _flush_alias(kb, previous_alias, counts, entities, total_count, title_to_id)


def _flush_alias(kb, alias, counts, entities, total_count, title_to_id):
    """Add one alias group to ``kb``, keeping only titles in ``title_to_id``."""
    # nothing to add, and a zero total would make the priors undefined
    if not entities or total_count <= 0:
        return
    selected_entities = []
    prior_probs = []
    for ent_count, ent_string in zip(counts, entities):
        if ent_string in title_to_id:
            selected_entities.append(title_to_id[ent_string])
            # float() keeps this a true division on Python 2 as well
            prior_probs.append(float(ent_count) / total_count)
    if selected_entities:
        try:
            kb.add_alias(
                alias=alias,
                entities=selected_entities,
                probabilities=prior_probs,
            )
        except ValueError as e:
            print(e)