spaCy/bin/wiki_entity_linking/kb_creator.py


# coding: utf-8
from __future__ import unicode_literals

import logging

from spacy.kb import KnowledgeBase

from bin.wiki_entity_linking.train_descriptions import EntityEncoder
from bin.wiki_entity_linking import wiki_io as io

logger = logging.getLogger(__name__)


def create_kb(
    nlp,
    max_entities_per_alias,
    min_entity_freq,
    min_occ,
    entity_def_path,
    entity_descr_path,
    entity_alias_path,
    entity_freq_path,
    prior_prob_path,
    entity_vector_length,
):
    # Create the knowledge base from Wikidata entries
    kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=entity_vector_length)
    entity_list, filtered_title_to_id = _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_freq, entity_freq_path, entity_vector_length)
    _define_aliases(kb, entity_alias_path, entity_list, filtered_title_to_id, max_entities_per_alias, min_occ, prior_prob_path)
    return kb
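
# Example (hedged sketch, not part of the original script): `create_kb` is normally
# driven from the accompanying Wikidata/Wikipedia processing scripts in this
# directory. The paths and parameter values below are hypothetical placeholders
# for the intermediate files those scripts produce.
#
#     from pathlib import Path
#     import spacy
#
#     nlp = spacy.load("en_core_web_lg")  # needs a model with pretrained vectors
#     kb = create_kb(
#         nlp,
#         max_entities_per_alias=10,
#         min_entity_freq=20,
#         min_occ=5,
#         entity_def_path=Path("output/entity_defs.csv"),
#         entity_descr_path=Path("output/entity_descriptions.csv"),
#         entity_alias_path=Path("output/entity_alias.csv"),
#         entity_freq_path=Path("output/entity_freq.csv"),
#         prior_prob_path=Path("output/prior_prob.csv"),
#         entity_vector_length=64,
#     )
#     kb.dump(Path("output/kb"))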


def _define_entities(nlp, kb, entity_def_path, entity_descr_path, min_entity_freq, entity_freq_path, entity_vector_length):
    # read the mappings from file
    title_to_id = io.read_title_to_id(entity_def_path)
    id_to_descr = io.read_id_to_descr(entity_descr_path)

    # check the length of the nlp vectors
    if "vectors" in nlp.meta and nlp.vocab.vectors.size:
        input_dim = nlp.vocab.vectors_length
        logger.info("Loaded pretrained vectors of size %s" % input_dim)
    else:
        raise ValueError(
            "The `nlp` object should have access to pretrained word vectors, "
            "cf. https://spacy.io/usage/models#languages."
        )

    logger.info("Filtering entities with fewer than {} mentions or no description".format(min_entity_freq))
    entity_frequencies = io.read_entity_to_count(entity_freq_path)

    # filter the entities kept in the KB by frequency, because there's just too much data (8M entities) otherwise
    filtered_title_to_id, entity_list, description_list, frequency_list = get_filtered_entities(
        title_to_id,
        id_to_descr,
        entity_frequencies,
        min_entity_freq
    )
    logger.info("Kept {} entities from the set of {}".format(len(description_list), len(title_to_id.keys())))

    logger.info("Training entity encoder")
    encoder = EntityEncoder(nlp, input_dim, entity_vector_length)
    encoder.train(description_list=description_list, to_print=True)

    logger.info("Getting entity embeddings")
    embeddings = encoder.apply_encoder(description_list)

    logger.info("Adding {} entities".format(len(entity_list)))
    kb.set_entities(
        entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
    )
    return entity_list, filtered_title_to_id


def _define_aliases(kb, entity_alias_path, entity_list, filtered_title_to_id, max_entities_per_alias, min_occ, prior_prob_path):
    logger.info("Adding aliases from Wikipedia and Wikidata")
    _add_aliases(
        kb,
        entity_list=entity_list,
        title_to_id=filtered_title_to_id,
        max_entities_per_alias=max_entities_per_alias,
        min_occ=min_occ,
        prior_prob_path=prior_prob_path,
    )


def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies,
                          min_entity_freq: int = 10):
    filtered_title_to_id = dict()
    entity_list = []
    description_list = []
    frequency_list = []
    for title, entity in title_to_id.items():
        freq = entity_frequencies.get(title, 0)
        desc = id_to_descr.get(entity, None)
        if desc and freq > min_entity_freq:
            entity_list.append(entity)
            description_list.append(desc)
            frequency_list.append(freq)
            filtered_title_to_id[title] = entity
    return filtered_title_to_id, entity_list, description_list, frequency_list
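
# Toy illustration (made-up data, not from the original file): with
#     title_to_id = {"Douglas Adams": "Q42", "Obscure Page": "Q99999"}
#     id_to_descr = {"Q42": "English writer and humorist"}
#     entity_frequencies = {"Douglas Adams": 457}
# only ("Douglas Adams", "Q42") survives the default filter: "Q99999" has no
# description, and "Obscure Page" defaults to a frequency of 0, below the threshold.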


def _add_aliases(kb, entity_list, title_to_id, max_entities_per_alias, min_occ, prior_prob_path):
    wp_titles = title_to_id.keys()

    # adding aliases with prior probabilities
    # we can read this file sequentially, it's sorted by alias, and then by count
    logger.info("Adding WP aliases")
    with prior_prob_path.open("r", encoding="utf8") as prior_file:
        # skip header
        prior_file.readline()
        line = prior_file.readline()
        previous_alias = None
        total_count = 0
        counts = []
        entities = []
        while line:
            # each data line is pipe-delimited: alias|count|entity
            splits = line.replace("\n", "").split(sep="|")
            new_alias = splits[0]
            count = int(splits[1])
            entity = splits[2]
            if new_alias != previous_alias and previous_alias:
                # done reading the previous alias --> output
                if len(entities) > 0:
                    selected_entities = []
                    prior_probs = []
                    for ent_count, ent_string in zip(counts, entities):
                        if ent_string in wp_titles:
                            wd_id = title_to_id[ent_string]
                            p_entity_givenalias = ent_count / total_count
                            selected_entities.append(wd_id)
                            prior_probs.append(p_entity_givenalias)
                    if selected_entities:
                        try:
                            kb.add_alias(
                                alias=previous_alias,
                                entities=selected_entities,
                                probabilities=prior_probs,
                            )
                        except ValueError as e:
                            logger.error(e)
                total_count = 0
                counts = []
                entities = []

            total_count += count

            if len(entities) < max_entities_per_alias and count >= min_occ:
                counts.append(count)
                entities.append(entity)
            previous_alias = new_alias

            line = prior_file.readline()


def read_kb(nlp, kb_file):
    kb = KnowledgeBase(vocab=nlp.vocab)
    kb.load_bulk(kb_file)
    return kb
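
# Example (hedged sketch, hypothetical path): reloading a previously dumped KB for
# entity linker training or evaluation.
#
#     from pathlib import Path
#     import spacy
#
#     nlp = spacy.load("en_core_web_lg")
#     kb = read_kb(nlp, Path("output/kb"))
#     print(kb.get_size_entities(), "entities,", kb.get_size_aliases(), "aliases")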