mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-11 12:18:04 +03:00
2d249a9502
* fix overflow error on windows * more documentation & logging fixes * md fix * 3 different limit parameters to play with execution time * bug fixes directory locations * small fixes * exclude dev test articles from prior probabilities stats * small fixes * filtering wikidata entities, removing numeric and meta items * adding aliases from wikidata also to the KB * fix adding WD aliases * adding also new aliases to previously added entities * fixing comma's * small doc fixes * adding subclassof filtering * append alias functionality in KB * prevent appending the same entity-alias pair * fix for appending WD aliases * remove date filter * remove unnecessary import * small corrections and reformatting * remove WD aliases for now (too slow) * removing numeric entities from training and evaluation * small fixes * shortcut during prediction if there is only one candidate * add counts and fscore logging, remove FP NER from evaluation * fix entity_linker.predict to take docs instead of single sentences * remove enumeration sentences from the WP dataset * entity_linker.update to process full doc instead of single sentence * spelling corrections and dump locations in readme * NLP IO fix * reading KB is unnecessary at the end of the pipeline * small logging fix * remove empty files
162 lines
5.9 KiB
Python
162 lines
5.9 KiB
Python
# coding: utf-8
|
|
from __future__ import unicode_literals
|
|
|
|
import logging
|
|
|
|
from spacy.kb import KnowledgeBase
|
|
|
|
from bin.wiki_entity_linking.train_descriptions import EntityEncoder
|
|
from bin.wiki_entity_linking import wiki_io as io
|
|
|
|
|
|
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def create_kb(
    nlp,
    max_entities_per_alias,
    min_entity_freq,
    min_occ,
    entity_def_path,
    entity_descr_path,
    entity_alias_path,
    entity_freq_path,
    prior_prob_path,
    entity_vector_length,
):
    """Build a knowledge base and fill it with entities and their aliases.

    A fresh `KnowledgeBase` is created on the vocab of `nlp`. The entities are
    added first (see `_define_entities`), then the aliases pointing to them
    (see `_define_aliases`).

    nlp: pipeline providing the vocab; it is also passed on to the entity step.
    max_entities_per_alias: forwarded to the alias step.
    min_entity_freq: forwarded to the entity step.
    min_occ: forwarded to the alias step.
    entity_def_path: forwarded to the entity step.
    entity_descr_path: forwarded to the entity step.
    entity_alias_path: forwarded to the alias step.
    entity_freq_path: forwarded to the entity step.
    prior_prob_path: forwarded to the alias step.
    entity_vector_length: vector size of the KB and of the entity step.
    RETURNS: the populated `KnowledgeBase`.
    """
    # Create the knowledge base from Wikidata entries
    knowledge_base = KnowledgeBase(
        vocab=nlp.vocab,
        entity_vector_length=entity_vector_length,
    )

    # Step 1: entities. The kept IDs and titles are needed for the alias step.
    kept_entities, kept_title_to_id = _define_entities(
        nlp,
        knowledge_base,
        entity_def_path,
        entity_descr_path,
        min_entity_freq,
        entity_freq_path,
        entity_vector_length,
    )

    # Step 2: aliases, restricted to the entities kept in step 1.
    _define_aliases(
        knowledge_base,
        entity_alias_path,
        kept_entities,
        kept_title_to_id,
        max_entities_per_alias,
        min_occ,
        prior_prob_path,
    )
    return knowledge_base
|
|
|
|
|
|
def _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_freq, entity_freq_path, entity_vector_length):
    """Select the entities to keep, embed their descriptions and store them in `kb`.

    nlp: pipeline that must provide pretrained word vectors.
    kb: knowledge base that receives the entities through `set_entities`.
    entity_def_path: input of `io.read_title_to_id`.
    entity_descr_path: input of `io.read_id_to_descr`.
    min_entity_freq: threshold passed to `get_filtered_entities`.
    entity_freq_path: input of `io.read_entity_to_count`.
    entity_vector_length: second size argument of the `EntityEncoder`.
    RETURNS: tuple `(entity_list, filtered_title_to_id)` as produced by
        `get_filtered_entities`.
    RAISES: ValueError when `nlp` has no pretrained vectors.
    """
    # read the mappings from file (done before the vector check)
    all_title_to_id = io.read_title_to_id(entity_def_path)
    descr_by_id = io.read_id_to_descr(entity_descr_path)

    # check the length of the nlp vectors: the encoder input is sized after them
    has_vectors = "vectors" in nlp.meta and nlp.vocab.vectors.size
    if not has_vectors:
        raise ValueError(
            "The `nlp` object should have access to pretrained word vectors, "
            " cf. https://spacy.io/usage/models#languages."
        )
    input_dim = nlp.vocab.vectors_length
    logger.info("Loaded pretrained vectors of size %s" % input_dim)

    logger.info("Filtering entities with fewer than {} mentions".format(min_entity_freq))
    counts_by_title = io.read_entity_to_count(entity_freq_path)
    # filter the entities for in the KB by frequency, because there's just too much data (8M entities) otherwise
    filtered = get_filtered_entities(
        all_title_to_id,
        descr_by_id,
        counts_by_title,
        min_entity_freq,
    )
    kept_title_to_id, kept_ids, kept_descriptions, kept_frequencies = filtered
    logger.info("Kept {} entities from the set of {}".format(len(kept_descriptions), len(all_title_to_id)))

    logger.info("Training entity encoder")
    entity_encoder = EntityEncoder(nlp, input_dim, entity_vector_length)
    entity_encoder.train(description_list=kept_descriptions, to_print=True)

    logger.info("Getting entity embeddings")
    entity_vectors = entity_encoder.apply_encoder(kept_descriptions)

    logger.info("Adding {} entities".format(len(kept_ids)))
    kb.set_entities(
        entity_list=kept_ids,
        freq_list=kept_frequencies,
        vector_list=entity_vectors,
    )
    return kept_ids, kept_title_to_id
|
|
|
|
|
|
def _define_aliases(kb, entity_alias_path, entity_list, filtered_title_to_id, max_entities_per_alias, min_occ, prior_prob_path):
    """Log the alias step and delegate the actual work to `_add_aliases`.

    kb: knowledge base that receives the aliases.
    entity_alias_path: accepted but not used by this function.
    entity_list: passed through to `_add_aliases`.
    filtered_title_to_id: passed to `_add_aliases` as its `title_to_id`.
    max_entities_per_alias: passed through to `_add_aliases`.
    min_occ: passed through to `_add_aliases`.
    prior_prob_path: passed through to `_add_aliases`.
    """
    logger.info("Adding aliases from Wikipedia and Wikidata")
    alias_options = dict(
        entity_list=entity_list,
        title_to_id=filtered_title_to_id,
        max_entities_per_alias=max_entities_per_alias,
        min_occ=min_occ,
        prior_prob_path=prior_prob_path,
    )
    _add_aliases(kb, **alias_options)
|
|
|
|
|
|
def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies,
                          min_entity_freq: int = 10):
    """Keep only the entities that have a description and are frequent enough.

    An entry of `title_to_id` is kept when its ID has a non-empty description
    in `id_to_descr` and its title has a count in `entity_frequencies` that is
    strictly greater than `min_entity_freq`. Titles without a count are
    treated as having a count of 0.

    title_to_id: dict mapping a title to an entity ID.
    id_to_descr: dict mapping an entity ID to its description.
    entity_frequencies: dict mapping a title to its count.
    min_entity_freq: exclusive lower bound on the count.
    RETURNS: tuple of (dict of kept title -> ID, list of kept IDs, list of
        their descriptions, list of their counts); the three lists are aligned
        and follow the iteration order of `title_to_id`.
    """
    kept_title_to_id = {}
    kept_ids, kept_descriptions, kept_counts = [], [], []

    for wp_title, wd_id in title_to_id.items():
        mention_count = entity_frequencies.get(wp_title, 0)
        description = id_to_descr.get(wd_id)
        # skip entities without description or at/below the frequency threshold
        if not (description and mention_count > min_entity_freq):
            continue
        kept_title_to_id[wp_title] = wd_id
        kept_ids.append(wd_id)
        kept_descriptions.append(description)
        kept_counts.append(mention_count)

    return kept_title_to_id, kept_ids, kept_descriptions, kept_counts
|
|
|
|
|
|
def _add_aliases(kb, entity_list, title_to_id, max_entities_per_alias, min_occ, prior_prob_path):
    """Add the aliases listed in the prior-probability file to `kb`.

    The file starts with a header line, followed by `alias|count|entity` rows.
    It is read sequentially: rows of the same alias must be adjacent (the file
    is sorted by alias, and then by count). For every alias, the prior
    probability of an entity is its count divided by the summed count of ALL
    rows of that alias, including rows that are not selected.

    kb: knowledge base; only its `add_alias` method is used.
    entity_list: accepted for interface compatibility, not used here.
    title_to_id: dict mapping an entity title to its ID; rows whose entity is
        not a key of this dict are not added.
    max_entities_per_alias: maximum number of rows collected per alias.
    min_occ: minimum count for a row to be collected.
    prior_prob_path: path-like object with an `open` method (e.g. `Path`).
    """
    # adding aliases with prior probabilities
    logger.info("Adding WP aliases")

    previous_alias = None
    total_count = 0
    counts = []
    entities = []
    with prior_prob_path.open("r", encoding="utf8") as prior_file:
        # skip header
        next(prior_file, None)
        for line in prior_file:
            # a blank line carries no data; skipping it avoids an IndexError below
            if not line.strip():
                continue
            splits = line.rstrip("\n").split(sep="|")
            new_alias = splits[0]
            count = int(splits[1])
            entity = splits[2]

            if previous_alias and new_alias != previous_alias:
                # done reading the previous alias --> output
                _store_alias(kb, previous_alias, counts, entities, total_count, title_to_id)
                total_count = 0
                counts = []
                entities = []

            # every row counts towards the denominator, selected or not
            total_count += count
            if len(entities) < max_entities_per_alias and count >= min_occ:
                counts.append(count)
                entities.append(entity)
            previous_alias = new_alias

    # The last alias has no following row to trigger its output, so write it
    # here; otherwise it would silently be missing from the KB.
    if previous_alias:
        _store_alias(kb, previous_alias, counts, entities, total_count, title_to_id)


def _store_alias(kb, alias, counts, entities, total_count, title_to_id):
    """Write one alias with the prior probabilities of its known entities to `kb`."""
    selected_entities = []
    prior_probs = []
    for ent_count, ent_string in zip(counts, entities):
        if ent_string in title_to_id:
            selected_entities.append(title_to_id[ent_string])
            prior_probs.append(ent_count / total_count)

    if not selected_entities:
        return
    try:
        kb.add_alias(
            alias=alias,
            entities=selected_entities,
            probabilities=prior_probs,
        )
    except ValueError as e:
        # best effort: a rejected alias is reported, the remaining ones still get added
        logger.error(e)
|
|
|
|
|
|
def read_kb(nlp, kb_file):
    """Create a `KnowledgeBase` on the vocab of `nlp` and load `kb_file` into it.

    nlp: pipeline whose vocab the knowledge base is built on.
    kb_file: location handed to `KnowledgeBase.load_bulk`.
    RETURNS: the loaded `KnowledgeBase`.
    """
    loaded_kb = KnowledgeBase(vocab=nlp.vocab)
    loaded_kb.load_bulk(kb_file)
    return loaded_kb
|