diff --git a/bin/__init__.py b/bin/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bin/ud/conll17_ud_eval.py b/bin/ud/conll17_ud_eval.py index 78a976a6d..88acfabac 100644 --- a/bin/ud/conll17_ud_eval.py +++ b/bin/ud/conll17_ud_eval.py @@ -292,8 +292,8 @@ def evaluate(gold_ud, system_ud, deprel_weights=None, check_parse=True): def spans_score(gold_spans, system_spans): correct, gi, si = 0, 0, 0 - undersegmented = list() - oversegmented = list() + undersegmented = [] + oversegmented = [] combo = 0 previous_end_si_earlier = False previous_end_gi_earlier = False diff --git a/bin/wiki_entity_linking/__init__.py b/bin/wiki_entity_linking/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/bin/wiki_entity_linking/kb_creator.py b/bin/wiki_entity_linking/kb_creator.py new file mode 100644 index 000000000..e8e081cef --- /dev/null +++ b/bin/wiki_entity_linking/kb_creator.py @@ -0,0 +1,171 @@ +# coding: utf-8 +from __future__ import unicode_literals + +from .train_descriptions import EntityEncoder +from . import wikidata_processor as wd, wikipedia_processor as wp +from spacy.kb import KnowledgeBase + +import csv +import datetime + + +INPUT_DIM = 300 # dimension of pre-trained input vectors +DESC_WIDTH = 64 # dimension of output entity vectors + + +def create_kb(nlp, max_entities_per_alias, min_entity_freq, min_occ, + entity_def_output, entity_descr_output, + count_input, prior_prob_input, wikidata_input): + # Create the knowledge base from Wikidata entries + kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=DESC_WIDTH) + + # disable this part of the pipeline when rerunning the KB generation from preprocessed files + read_raw_data = True + + if read_raw_data: + print() + print(" * _read_wikidata_entities", datetime.datetime.now()) + title_to_id, id_to_descr = wd.read_wikidata_entities_json(wikidata_input) + + # write the title-ID and ID-description mappings to file + _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr) + + else: + # read the mappings from file + title_to_id = get_entity_to_id(entity_def_output) + id_to_descr = get_id_to_description(entity_descr_output) + + print() + print(" * _get_entity_frequencies", datetime.datetime.now()) + print() + entity_frequencies = wp.get_all_frequencies(count_input=count_input) + + # filter the entities for in the KB by frequency, because there's just too much data (8M entities) otherwise + filtered_title_to_id = dict() + entity_list = [] + description_list = [] + frequency_list = [] + for title, entity in title_to_id.items(): + freq = entity_frequencies.get(title, 0) + desc = id_to_descr.get(entity, None) + if desc and freq > min_entity_freq: + entity_list.append(entity) + description_list.append(desc) + frequency_list.append(freq) + filtered_title_to_id[title] = entity + + print("Kept", len(filtered_title_to_id.keys()), "out of", len(title_to_id.keys()), + "titles with filter frequency", min_entity_freq) + + print() + print(" * train entity encoder", datetime.datetime.now()) + print() + encoder = EntityEncoder(nlp, INPUT_DIM, DESC_WIDTH) + encoder.train(description_list=description_list, to_print=True) + + print() + print(" * get entity embeddings", datetime.datetime.now()) + print() + embeddings = encoder.apply_encoder(description_list) + + print() + print(" * adding", len(entity_list), "entities", datetime.datetime.now()) + kb.set_entities(entity_list=entity_list, prob_list=frequency_list, vector_list=embeddings) + + print() + print(" * adding aliases", 
datetime.datetime.now()) + print() + _add_aliases(kb, title_to_id=filtered_title_to_id, + max_entities_per_alias=max_entities_per_alias, min_occ=min_occ, + prior_prob_input=prior_prob_input) + + print() + print("kb size:", len(kb), kb.get_size_entities(), kb.get_size_aliases()) + + print("done with kb", datetime.datetime.now()) + return kb + + +def _write_entity_files(entity_def_output, entity_descr_output, title_to_id, id_to_descr): + with open(entity_def_output, mode='w', encoding='utf8') as id_file: + id_file.write("WP_title" + "|" + "WD_id" + "\n") + for title, qid in title_to_id.items(): + id_file.write(title + "|" + str(qid) + "\n") + + with open(entity_descr_output, mode='w', encoding='utf8') as descr_file: + descr_file.write("WD_id" + "|" + "description" + "\n") + for qid, descr in id_to_descr.items(): + descr_file.write(str(qid) + "|" + descr + "\n") + + +def get_entity_to_id(entity_def_output): + entity_to_id = dict() + with open(entity_def_output, 'r', encoding='utf8') as csvfile: + csvreader = csv.reader(csvfile, delimiter='|') + # skip header + next(csvreader) + for row in csvreader: + entity_to_id[row[0]] = row[1] + return entity_to_id + + +def get_id_to_description(entity_descr_output): + id_to_desc = dict() + with open(entity_descr_output, 'r', encoding='utf8') as csvfile: + csvreader = csv.reader(csvfile, delimiter='|') + # skip header + next(csvreader) + for row in csvreader: + id_to_desc[row[0]] = row[1] + return id_to_desc + + +def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input): + wp_titles = title_to_id.keys() + + # adding aliases with prior probabilities + # we can read this file sequentially, it's sorted by alias, and then by count + with open(prior_prob_input, mode='r', encoding='utf8') as prior_file: + # skip header + prior_file.readline() + line = prior_file.readline() + previous_alias = None + total_count = 0 + counts = [] + entities = [] + while line: + splits = line.replace('\n', "").split(sep='|') + new_alias = splits[0] + count = int(splits[1]) + entity = splits[2] + + if new_alias != previous_alias and previous_alias: + # done reading the previous alias --> output + if len(entities) > 0: + selected_entities = [] + prior_probs = [] + for ent_count, ent_string in zip(counts, entities): + if ent_string in wp_titles: + wd_id = title_to_id[ent_string] + p_entity_givenalias = ent_count / total_count + selected_entities.append(wd_id) + prior_probs.append(p_entity_givenalias) + + if selected_entities: + try: + kb.add_alias(alias=previous_alias, entities=selected_entities, probabilities=prior_probs) + except ValueError as e: + print(e) + total_count = 0 + counts = [] + entities = [] + + total_count += count + + if len(entities) < max_entities_per_alias and count >= min_occ: + counts.append(count) + entities.append(entity) + previous_alias = new_alias + + line = prior_file.readline() + diff --git a/bin/wiki_entity_linking/train_descriptions.py b/bin/wiki_entity_linking/train_descriptions.py new file mode 100644 index 000000000..6a4d046e5 --- /dev/null +++ b/bin/wiki_entity_linking/train_descriptions.py @@ -0,0 +1,121 @@ +# coding: utf-8 +from random import shuffle + +import numpy as np + +from spacy._ml import zero_init, create_default_optimizer +from spacy.cli.pretrain import get_cossim_loss + +from thinc.v2v import Model +from thinc.api import chain +from thinc.neural._classes.affine import Affine + + +class EntityEncoder: + """ + Train the embeddings of entity descriptions to fit a fixed-size entity vector (e.g. 64D). 
+ This entity vector will be stored in the KB, for further downstream use in the entity model. + """ + + DROP = 0 + EPOCHS = 5 + STOP_THRESHOLD = 0.04 + + BATCH_SIZE = 1000 + + def __init__(self, nlp, input_dim, desc_width): + self.nlp = nlp + self.input_dim = input_dim + self.desc_width = desc_width + + def apply_encoder(self, description_list): + if self.encoder is None: + raise ValueError("Can not apply encoder before training it") + + batch_size = 100000 + + start = 0 + stop = min(batch_size, len(description_list)) + encodings = [] + + while start < len(description_list): + docs = list(self.nlp.pipe(description_list[start:stop])) + doc_embeddings = [self._get_doc_embedding(doc) for doc in docs] + enc = self.encoder(np.asarray(doc_embeddings)) + encodings.extend(enc.tolist()) + + start = start + batch_size + stop = min(stop + batch_size, len(description_list)) + + return encodings + + def train(self, description_list, to_print=False): + processed, loss = self._train_model(description_list) + if to_print: + print("Trained on", processed, "entities across", self.EPOCHS, "epochs") + print("Final loss:", loss) + + def _train_model(self, description_list): + # TODO: when loss gets too low, a 'mean of empty slice' warning is thrown by numpy + + self._build_network(self.input_dim, self.desc_width) + + processed = 0 + loss = 1 + descriptions = description_list.copy() # copy this list so that shuffling does not affect other functions + + for i in range(self.EPOCHS): + shuffle(descriptions) + + batch_nr = 0 + start = 0 + stop = min(self.BATCH_SIZE, len(descriptions)) + + while loss > self.STOP_THRESHOLD and start < len(descriptions): + batch = [] + for descr in descriptions[start:stop]: + doc = self.nlp(descr) + doc_vector = self._get_doc_embedding(doc) + batch.append(doc_vector) + + loss = self._update(batch) + print(i, batch_nr, loss) + processed += len(batch) + + batch_nr += 1 + start = start + self.BATCH_SIZE + stop = min(stop + self.BATCH_SIZE, len(descriptions)) + + return processed, loss + + @staticmethod + def _get_doc_embedding(doc): + indices = np.zeros((len(doc),), dtype="i") + for i, word in enumerate(doc): + if word.orth in doc.vocab.vectors.key2row: + indices[i] = doc.vocab.vectors.key2row[word.orth] + else: + indices[i] = 0 + word_vectors = doc.vocab.vectors.data[indices] + doc_vector = np.mean(word_vectors, axis=0) + return doc_vector + + def _build_network(self, orig_width, hidden_with): + with Model.define_operators({">>": chain}): + # very simple encoder-decoder model + self.encoder = ( + Affine(hidden_with, orig_width) + ) + self.model = self.encoder >> zero_init(Affine(orig_width, hidden_with, drop_factor=0.0)) + self.sgd = create_default_optimizer(self.model.ops) + + def _update(self, vectors): + predictions, bp_model = self.model.begin_update(np.asarray(vectors), drop=self.DROP) + loss, d_scores = self._get_loss(scores=predictions, golds=np.asarray(vectors)) + bp_model(d_scores, sgd=self.sgd) + return loss / len(vectors) + + @staticmethod + def _get_loss(golds, scores): + loss, gradients = get_cossim_loss(scores, golds) + return loss, gradients diff --git a/bin/wiki_entity_linking/training_set_creator.py b/bin/wiki_entity_linking/training_set_creator.py new file mode 100644 index 000000000..5d401bb3f --- /dev/null +++ b/bin/wiki_entity_linking/training_set_creator.py @@ -0,0 +1,353 @@ +# coding: utf-8 +from __future__ import unicode_literals + +import os +import re +import bz2 +import datetime + +from spacy.gold import GoldParse +from bin.wiki_entity_linking import 
kb_creator + +""" +Process Wikipedia interlinks to generate a training dataset for the EL algorithm. +Gold-standard entities are stored in one file in standoff format (by character offset). +""" + +ENTITY_FILE = "gold_entities.csv" + + +def create_training(wikipedia_input, entity_def_input, training_output): + wp_to_id = kb_creator.get_entity_to_id(entity_def_input) + _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None) + + +def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None): + """ + Read the XML wikipedia data to parse out training data: + raw text data + positive instances + """ + title_regex = re.compile(r'(?<=).*(?=)') + id_regex = re.compile(r'(?<=)\d*(?=)') + + read_ids = set() + entityfile_loc = training_output / ENTITY_FILE + with open(entityfile_loc, mode="w", encoding='utf8') as entityfile: + # write entity training header file + _write_training_entity(outputfile=entityfile, + article_id="article_id", + alias="alias", + entity="WD_id", + start="start", + end="end") + + with bz2.open(wikipedia_input, mode='rb') as file: + line = file.readline() + cnt = 0 + article_text = "" + article_title = None + article_id = None + reading_text = False + reading_revision = False + while line and (not limit or cnt < limit): + if cnt % 1000000 == 0: + print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump") + clean_line = line.strip().decode("utf-8") + + if clean_line == "": + reading_revision = True + elif clean_line == "": + reading_revision = False + + # Start reading new page + if clean_line == "": + article_text = "" + article_title = None + article_id = None + + # finished reading this page + elif clean_line == "": + if article_id: + try: + _process_wp_text(wp_to_id, entityfile, article_id, article_title, article_text.strip(), + training_output) + except Exception as e: + print("Error processing article", article_id, article_title, e) + else: + print("Done processing a page, but couldn't find an article_id ?", article_title) + article_text = "" + article_title = None + article_id = None + reading_text = False + reading_revision = False + + # start reading text within a page + if ").*(?= 2: + reading_special_case = True + + if open_read == 2 and reading_text: + reading_text = False + reading_entity = True + reading_mention = False + + # we just finished reading an entity + if open_read == 0 and not reading_text: + if '#' in entity_buffer or entity_buffer.startswith(':'): + reading_special_case = True + # Ignore cases with nested structures like File: handles etc + if not reading_special_case: + if not mention_buffer: + mention_buffer = entity_buffer + start = len(final_text) + end = start + len(mention_buffer) + qid = wp_to_id.get(entity_buffer, None) + if qid: + _write_training_entity(outputfile=entityfile, + article_id=article_id, + alias=mention_buffer, + entity=qid, + start=start, + end=end) + found_entities = True + final_text += mention_buffer + + entity_buffer = "" + mention_buffer = "" + + reading_text = True + reading_entity = False + reading_mention = False + reading_special_case = False + + if found_entities: + _write_training_article(article_id=article_id, clean_text=final_text, training_output=training_output) + + +info_regex = re.compile(r'{[^{]*?}') +htlm_regex = re.compile(r'<!--[^-]*-->') +category_regex = re.compile(r'\[\[Category:[^\[]*]]') +file_regex = re.compile(r'\[\[File:[^[\]]+]]') +ref_regex = re.compile(r'<ref.*?>') # non-greedy +ref_2_regex = re.compile(r'</ref.*?>') # non-greedy + + 
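Editorial sketch (not part of the patch): the regexes defined above drive the markup cleanup in _get_clean_wp_text below. A minimal, self-contained illustration of how they behave, assuming a made-up snippet of wiki text (the sample string and variable names are illustrative only):

import re

info_regex = re.compile(r'{[^{]*?}')                    # innermost {{...}} templates, removed iteratively
htlm_regex = re.compile(r'<!--[^-]*-->')                # HTML comments
category_regex = re.compile(r'\[\[Category:[^\[]*]]')   # [[Category:...]] statements

sample = "Douglas Adams {{Infobox writer|name=Douglas Adams}} wrote novels.<!-- stub -->[[Category:English writers]]"
cleaned = sample
previous_length = -1
while previous_length != len(cleaned):   # repeat until no nested template braces are left
    previous_length = len(cleaned)
    cleaned = info_regex.sub('', cleaned)
cleaned = htlm_regex.sub('', cleaned)
cleaned = category_regex.sub('', cleaned)
print(cleaned)   # "Douglas Adams  wrote novels." (the double space is collapsed later in _get_clean_wp_text)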
+def _get_clean_wp_text(article_text): + clean_text = article_text.strip() + + # remove bolding & italic markup + clean_text = clean_text.replace('\'\'\'', '') + clean_text = clean_text.replace('\'\'', '') + + # remove nested {{info}} statements by removing the inner/smallest ones first and iterating + try_again = True + previous_length = len(clean_text) + while try_again: + clean_text = info_regex.sub('', clean_text) # non-greedy match excluding a nested { + if len(clean_text) < previous_length: + try_again = True + else: + try_again = False + previous_length = len(clean_text) + + # remove HTML comments + clean_text = htlm_regex.sub('', clean_text) + + # remove Category and File statements + clean_text = category_regex.sub('', clean_text) + clean_text = file_regex.sub('', clean_text) + + # remove multiple = + while '==' in clean_text: + clean_text = clean_text.replace("==", "=") + + clean_text = clean_text.replace(". =", ".") + clean_text = clean_text.replace(" = ", ". ") + clean_text = clean_text.replace("= ", ".") + clean_text = clean_text.replace(" =", "") + + # remove refs (non-greedy match) + clean_text = ref_regex.sub('', clean_text) + clean_text = ref_2_regex.sub('', clean_text) + + # remove additional wikiformatting + clean_text = re.sub(r'<blockquote>', '', clean_text) + clean_text = re.sub(r'</blockquote>', '', clean_text) + + # change special characters back to normal ones + clean_text = clean_text.replace(r'<', '<') + clean_text = clean_text.replace(r'>', '>') + clean_text = clean_text.replace(r'"', '"') + clean_text = clean_text.replace(r'&nbsp;', ' ') + clean_text = clean_text.replace(r'&', '&') + + # remove multiple spaces + while ' ' in clean_text: + clean_text = clean_text.replace(' ', ' ') + + return clean_text.strip() + + +def _write_training_article(article_id, clean_text, training_output): + file_loc = training_output / str(article_id) + ".txt" + with open(file_loc, mode='w', encoding='utf8') as outputfile: + outputfile.write(clean_text) + + +def _write_training_entity(outputfile, article_id, alias, entity, start, end): + outputfile.write(article_id + "|" + alias + "|" + entity + "|" + str(start) + "|" + str(end) + "\n") + + +def is_dev(article_id): + return article_id.endswith("3") + + +def read_training(nlp, training_dir, dev, limit): + # This method provides training examples that correspond to the entity annotations found by the nlp object + entityfile_loc = training_dir / ENTITY_FILE + data = [] + + # assume the data is written sequentially, so we can reuse the article docs + current_article_id = None + current_doc = None + ents_by_offset = dict() + skip_articles = set() + total_entities = 0 + + with open(entityfile_loc, mode='r', encoding='utf8') as file: + for line in file: + if not limit or len(data) < limit: + fields = line.replace('\n', "").split(sep='|') + article_id = fields[0] + alias = fields[1] + wp_title = fields[2] + start = fields[3] + end = fields[4] + + if dev == is_dev(article_id) and article_id != "article_id" and article_id not in skip_articles: + if not current_doc or (current_article_id != article_id): + # parse the new article text + file_name = article_id + ".txt" + try: + with open(os.path.join(training_dir, file_name), mode="r", encoding='utf8') as f: + text = f.read() + if len(text) < 30000: # threshold for convenience / speed of processing + current_doc = nlp(text) + current_article_id = article_id + ents_by_offset = dict() + for ent in current_doc.ents: + sent_length = len(ent.sent) + # custom filtering to avoid too long or too short 
sentences + if 5 < sent_length < 100: + ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent + else: + skip_articles.add(article_id) + current_doc = None + except Exception as e: + print("Problem parsing article", article_id, e) + skip_articles.add(article_id) + raise e + + # repeat checking this condition in case an exception was thrown + if current_doc and (current_article_id == article_id): + found_ent = ents_by_offset.get(start + "_" + end, None) + if found_ent: + if found_ent.text != alias: + skip_articles.add(article_id) + current_doc = None + else: + sent = found_ent.sent.as_doc() + # currently feeding the gold data one entity per sentence at a time + gold_start = int(start) - found_ent.sent.start_char + gold_end = int(end) - found_ent.sent.start_char + gold_entities = [(gold_start, gold_end, wp_title)] + gold = GoldParse(doc=sent, links=gold_entities) + data.append((sent, gold)) + total_entities += 1 + if len(data) % 2500 == 0: + print(" -read", total_entities, "entities") + + print(" -read", total_entities, "entities") + return data diff --git a/bin/wiki_entity_linking/wikidata_processor.py b/bin/wiki_entity_linking/wikidata_processor.py new file mode 100644 index 000000000..a32a0769a --- /dev/null +++ b/bin/wiki_entity_linking/wikidata_processor.py @@ -0,0 +1,119 @@ +# coding: utf-8 +from __future__ import unicode_literals + +import bz2 +import json +import datetime + + +def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False): + # Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines. + # get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/ + + lang = 'en' + site_filter = 'enwiki' + + # properties filter (currently disabled to get ALL data) + prop_filter = dict() + # prop_filter = {'P31': {'Q5', 'Q15632617'}} # currently defined as OR: one property suffices to be selected + + title_to_id = dict() + id_to_descr = dict() + + # parse appropriate fields - depending on what we need in the KB + parse_properties = False + parse_sitelinks = True + parse_labels = False + parse_descriptions = True + parse_aliases = False + parse_claims = False + + with bz2.open(wikidata_file, mode='rb') as file: + line = file.readline() + cnt = 0 + while line and (not limit or cnt < limit): + if cnt % 500000 == 0: + print(datetime.datetime.now(), "processed", cnt, "lines of WikiData dump") + clean_line = line.strip() + if clean_line.endswith(b","): + clean_line = clean_line[:-1] + if len(clean_line) > 1: + obj = json.loads(clean_line) + entry_type = obj["type"] + + if entry_type == "item": + # filtering records on their properties (currently disabled to get ALL data) + # keep = False + keep = True + + claims = obj["claims"] + if parse_claims: + for prop, value_set in prop_filter.items(): + claim_property = claims.get(prop, None) + if claim_property: + for cp in claim_property: + cp_id = cp['mainsnak'].get('datavalue', {}).get('value', {}).get('id') + cp_rank = cp['rank'] + if cp_rank != "deprecated" and cp_id in value_set: + keep = True + + if keep: + unique_id = obj["id"] + + if to_print: + print("ID:", unique_id) + print("type:", entry_type) + + # parsing all properties that refer to other entities + if parse_properties: + for prop, claim_property in claims.items(): + cp_dicts = [cp['mainsnak']['datavalue'].get('value') for cp in claim_property + if cp['mainsnak'].get('datavalue')] + cp_values = [cp_dict.get('id') for cp_dict in cp_dicts if isinstance(cp_dict, dict) + if cp_dict.get('id') is not None] + if 
cp_values: + if to_print: + print("prop:", prop, cp_values) + + found_link = False + if parse_sitelinks: + site_value = obj["sitelinks"].get(site_filter, None) + if site_value: + site = site_value['title'] + if to_print: + print(site_filter, ":", site) + title_to_id[site] = unique_id + found_link = True + + if parse_labels: + labels = obj["labels"] + if labels: + lang_label = labels.get(lang, None) + if lang_label: + if to_print: + print("label (" + lang + "):", lang_label["value"]) + + if found_link and parse_descriptions: + descriptions = obj["descriptions"] + if descriptions: + lang_descr = descriptions.get(lang, None) + if lang_descr: + if to_print: + print("description (" + lang + "):", lang_descr["value"]) + id_to_descr[unique_id] = lang_descr["value"] + + if parse_aliases: + aliases = obj["aliases"] + if aliases: + lang_aliases = aliases.get(lang, None) + if lang_aliases: + for item in lang_aliases: + if to_print: + print("alias (" + lang + "):", item["value"]) + + if to_print: + print() + line = file.readline() + cnt += 1 + + return title_to_id, id_to_descr diff --git a/bin/wiki_entity_linking/wikipedia_processor.py b/bin/wiki_entity_linking/wikipedia_processor.py new file mode 100644 index 000000000..c02e472bc --- /dev/null +++ b/bin/wiki_entity_linking/wikipedia_processor.py @@ -0,0 +1,182 @@ +# coding: utf-8 +from __future__ import unicode_literals + +import re +import bz2 +import csv +import datetime + +""" +Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions. +Write these results to file for downstream KB and training data generation. +""" + +map_alias_to_link = dict() + +# these will/should be matched ignoring case +wiki_namespaces = ["b", "betawikiversity", "Book", "c", "Category", "Commons", + "d", "dbdump", "download", "Draft", "Education", "Foundation", + "Gadget", "Gadget definition", "gerrit", "File", "Help", "Image", "Incubator", + "m", "mail", "mailarchive", "media", "MediaWiki", "MediaWiki talk", "Mediawikiwiki", + "MediaZilla", "Meta", "Metawikipedia", "Module", + "mw", "n", "nost", "oldwikisource", "outreach", "outreachwiki", "otrs", "OTRSwiki", + "Portal", "phab", "Phabricator", "Project", "q", "quality", "rev", + "s", "spcom", "Special", "species", "Strategy", "sulutil", "svn", + "Talk", "Template", "Template talk", "Testwiki", "ticket", "TimedText", "Toollabs", "tools", + "tswiki", "User", "User talk", "v", "voy", + "w", "Wikibooks", "Wikidata", "wikiHow", "Wikinvest", "wikilivres", "Wikimedia", "Wikinews", + "Wikipedia", "Wikipedia talk", "Wikiquote", "Wikisource", "Wikispecies", "Wikitech", + "Wikiversity", "Wikivoyage", "wikt", "wiktionary", "wmf", "wmania", "WP"] + +# find the links +link_regex = re.compile(r'\[\[[^\[\]]*\]\]') + +# match on interwiki links, e.g. `en:` or `:fr:` +ns_regex = r":?" + "[a-z][a-z]" + ":" + +# match on Namespace: optionally preceded by a : +for ns in wiki_namespaces: + ns_regex += "|" + ":?" + ns + ":" + +ns_regex = re.compile(ns_regex, re.IGNORECASE) + + +def read_wikipedia_prior_probs(wikipedia_input, prior_prob_output): + """ + Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities. + The full file takes about 2h to parse 1100M lines. + It works relatively fast because it runs line by line, irrelevant of which article the intrawiki is from. 
+ """ + with bz2.open(wikipedia_input, mode='rb') as file: + line = file.readline() + cnt = 0 + while line: + if cnt % 5000000 == 0: + print(datetime.datetime.now(), "processed", cnt, "lines of Wikipedia dump") + clean_line = line.strip().decode("utf-8") + + aliases, entities, normalizations = get_wp_links(clean_line) + for alias, entity, norm in zip(aliases, entities, normalizations): + _store_alias(alias, entity, normalize_alias=norm, normalize_entity=True) + _store_alias(alias, entity, normalize_alias=norm, normalize_entity=True) + + line = file.readline() + cnt += 1 + + # write all aliases and their entities and count occurrences to file + with open(prior_prob_output, mode='w', encoding='utf8') as outputfile: + outputfile.write("alias" + "|" + "count" + "|" + "entity" + "\n") + for alias, alias_dict in sorted(map_alias_to_link.items(), key=lambda x: x[0]): + for entity, count in sorted(alias_dict.items(), key=lambda x: x[1], reverse=True): + outputfile.write(alias + "|" + str(count) + "|" + entity + "\n") + + +def _store_alias(alias, entity, normalize_alias=False, normalize_entity=True): + alias = alias.strip() + entity = entity.strip() + + # remove everything after # as this is not part of the title but refers to a specific paragraph + if normalize_entity: + # wikipedia titles are always capitalized + entity = _capitalize_first(entity.split("#")[0]) + if normalize_alias: + alias = alias.split("#")[0] + + if alias and entity: + alias_dict = map_alias_to_link.get(alias, dict()) + entity_count = alias_dict.get(entity, 0) + alias_dict[entity] = entity_count + 1 + map_alias_to_link[alias] = alias_dict + + +def get_wp_links(text): + aliases = [] + entities = [] + normalizations = [] + + matches = link_regex.findall(text) + for match in matches: + match = match[2:][:-2].replace("_", " ").strip() + + if ns_regex.match(match): + pass # ignore namespaces at the beginning of the string + + # this is a simple [[link]], with the alias the same as the mention + elif "|" not in match: + aliases.append(match) + entities.append(match) + normalizations.append(True) + + # in wiki format, the link is written as [[entity|alias]] + else: + splits = match.split("|") + entity = splits[0].strip() + alias = splits[1].strip() + # specific wiki format [[alias (specification)|]] + if len(alias) == 0 and "(" in entity: + alias = entity.split("(")[0] + aliases.append(alias) + entities.append(entity) + normalizations.append(False) + else: + aliases.append(alias) + entities.append(entity) + normalizations.append(False) + + return aliases, entities, normalizations + + +def _capitalize_first(text): + if not text: + return None + result = text[0].capitalize() + if len(result) > 0: + result += text[1:] + return result + + +def write_entity_counts(prior_prob_input, count_output, to_print=False): + # Write entity counts for quick access later + entity_to_count = dict() + total_count = 0 + + with open(prior_prob_input, mode='r', encoding='utf8') as prior_file: + # skip header + prior_file.readline() + line = prior_file.readline() + + while line: + splits = line.replace('\n', "").split(sep='|') + # alias = splits[0] + count = int(splits[1]) + entity = splits[2] + + current_count = entity_to_count.get(entity, 0) + entity_to_count[entity] = current_count + count + + total_count += count + + line = prior_file.readline() + + with open(count_output, mode='w', encoding='utf8') as entity_file: + entity_file.write("entity" + "|" + "count" + "\n") + for entity, count in entity_to_count.items(): + entity_file.write(entity + "|" + 
str(count) + "\n") + + if to_print: + for entity, count in entity_to_count.items(): + print("Entity count:", entity, count) + print("Total count:", total_count) + + +def get_all_frequencies(count_input): + entity_to_count = dict() + with open(count_input, 'r', encoding='utf8') as csvfile: + csvreader = csv.reader(csvfile, delimiter='|') + # skip header + next(csvreader) + for row in csvreader: + entity_to_count[row[0]] = int(row[1]) + + return entity_to_count + diff --git a/examples/pipeline/dummy_entity_linking.py b/examples/pipeline/dummy_entity_linking.py index 88415d040..0e59db304 100644 --- a/examples/pipeline/dummy_entity_linking.py +++ b/examples/pipeline/dummy_entity_linking.py @@ -9,26 +9,26 @@ from spacy.kb import KnowledgeBase def create_kb(vocab): - kb = KnowledgeBase(vocab=vocab) + kb = KnowledgeBase(vocab=vocab, entity_vector_length=1) # adding entities entity_0 = "Q1004791_Douglas" print("adding entity", entity_0) - kb.add_entity(entity=entity_0, prob=0.5) + kb.add_entity(entity=entity_0, prob=0.5, entity_vector=[0]) entity_1 = "Q42_Douglas_Adams" print("adding entity", entity_1) - kb.add_entity(entity=entity_1, prob=0.5) + kb.add_entity(entity=entity_1, prob=0.5, entity_vector=[1]) entity_2 = "Q5301561_Douglas_Haig" print("adding entity", entity_2) - kb.add_entity(entity=entity_2, prob=0.5) + kb.add_entity(entity=entity_2, prob=0.5, entity_vector=[2]) # adding aliases print() alias_0 = "Douglas" print("adding alias", alias_0) - kb.add_alias(alias=alias_0, entities=[entity_0, entity_1, entity_2], probabilities=[0.1, 0.6, 0.2]) + kb.add_alias(alias=alias_0, entities=[entity_0, entity_1, entity_2], probabilities=[0.6, 0.1, 0.2]) alias_1 = "Douglas Adams" print("adding alias", alias_1) @@ -41,8 +41,12 @@ def create_kb(vocab): def add_el(kb, nlp): - el_pipe = nlp.create_pipe(name='entity_linker', config={"kb": kb}) + el_pipe = nlp.create_pipe(name='entity_linker', config={"context_width": 64}) + el_pipe.set_kb(kb) nlp.add_pipe(el_pipe, last=True) + nlp.begin_training() + el_pipe.context_weight = 0 + el_pipe.prior_weight = 1 for alias in ["Douglas Adams", "Douglas"]: candidates = nlp.linker.kb.get_candidates(alias) @@ -66,6 +70,6 @@ def add_el(kb, nlp): if __name__ == "__main__": - nlp = spacy.load('en_core_web_sm') - my_kb = create_kb(nlp.vocab) - add_el(my_kb, nlp) + my_nlp = spacy.load('en_core_web_sm') + my_kb = create_kb(my_nlp.vocab) + add_el(my_kb, my_nlp) diff --git a/examples/pipeline/wikidata_entity_linking.py b/examples/pipeline/wikidata_entity_linking.py new file mode 100644 index 000000000..17c2976dd --- /dev/null +++ b/examples/pipeline/wikidata_entity_linking.py @@ -0,0 +1,442 @@ +# coding: utf-8 +from __future__ import unicode_literals + +import random +import datetime +from pathlib import Path + +from bin.wiki_entity_linking import training_set_creator, kb_creator, wikipedia_processor as wp +from bin.wiki_entity_linking.kb_creator import DESC_WIDTH + +import spacy +from spacy.kb import KnowledgeBase +from spacy.util import minibatch, compounding + +""" +Demonstrate how to build a knowledge base from WikiData and run an Entity Linking algorithm. 
+""" + +ROOT_DIR = Path("C:/Users/Sofie/Documents/data/") +OUTPUT_DIR = ROOT_DIR / 'wikipedia' +TRAINING_DIR = OUTPUT_DIR / 'training_data_nel' + +PRIOR_PROB = OUTPUT_DIR / 'prior_prob.csv' +ENTITY_COUNTS = OUTPUT_DIR / 'entity_freq.csv' +ENTITY_DEFS = OUTPUT_DIR / 'entity_defs.csv' +ENTITY_DESCR = OUTPUT_DIR / 'entity_descriptions.csv' + +KB_FILE = OUTPUT_DIR / 'kb_1' / 'kb' +NLP_1_DIR = OUTPUT_DIR / 'nlp_1' +NLP_2_DIR = OUTPUT_DIR / 'nlp_2' + +# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/ +WIKIDATA_JSON = ROOT_DIR / 'wikidata' / 'wikidata-20190304-all.json.bz2' + +# get enwiki-latest-pages-articles-multistream.xml.bz2 from https://dumps.wikimedia.org/enwiki/latest/ +ENWIKI_DUMP = ROOT_DIR / 'wikipedia' / 'enwiki-20190320-pages-articles-multistream.xml.bz2' + +# KB construction parameters +MAX_CANDIDATES = 10 +MIN_ENTITY_FREQ = 20 +MIN_PAIR_OCC = 5 + +# model training parameters +EPOCHS = 10 +DROPOUT = 0.5 +LEARN_RATE = 0.005 +L2 = 1e-6 +CONTEXT_WIDTH = 128 + + +def run_pipeline(): + # set the appropriate booleans to define which parts of the pipeline should be re(run) + print("START", datetime.datetime.now()) + print() + nlp_1 = spacy.load('en_core_web_lg') + nlp_2 = None + kb_2 = None + + # one-time methods to create KB and write to file + to_create_prior_probs = False + to_create_entity_counts = False + to_create_kb = False + + # read KB back in from file + to_read_kb = True + to_test_kb = False + + # create training dataset + create_wp_training = False + + # train the EL pipe + train_pipe = True + measure_performance = True + + # test the EL pipe on a simple example + to_test_pipeline = True + + # write the NLP object, read back in and test again + to_write_nlp = True + to_read_nlp = True + test_from_file = False + + # STEP 1 : create prior probabilities from WP (run only once) + if to_create_prior_probs: + print("STEP 1: to_create_prior_probs", datetime.datetime.now()) + wp.read_wikipedia_prior_probs(wikipedia_input=ENWIKI_DUMP, prior_prob_output=PRIOR_PROB) + print() + + # STEP 2 : deduce entity frequencies from WP (run only once) + if to_create_entity_counts: + print("STEP 2: to_create_entity_counts", datetime.datetime.now()) + wp.write_entity_counts(prior_prob_input=PRIOR_PROB, count_output=ENTITY_COUNTS, to_print=False) + print() + + # STEP 3 : create KB and write to file (run only once) + if to_create_kb: + print("STEP 3a: to_create_kb", datetime.datetime.now()) + kb_1 = kb_creator.create_kb(nlp_1, + max_entities_per_alias=MAX_CANDIDATES, + min_entity_freq=MIN_ENTITY_FREQ, + min_occ=MIN_PAIR_OCC, + entity_def_output=ENTITY_DEFS, + entity_descr_output=ENTITY_DESCR, + count_input=ENTITY_COUNTS, + prior_prob_input=PRIOR_PROB, + wikidata_input=WIKIDATA_JSON) + print("kb entities:", kb_1.get_size_entities()) + print("kb aliases:", kb_1.get_size_aliases()) + print() + + print("STEP 3b: write KB and NLP", datetime.datetime.now()) + kb_1.dump(KB_FILE) + nlp_1.to_disk(NLP_1_DIR) + print() + + # STEP 4 : read KB back in from file + if to_read_kb: + print("STEP 4: to_read_kb", datetime.datetime.now()) + nlp_2 = spacy.load(NLP_1_DIR) + kb_2 = KnowledgeBase(vocab=nlp_2.vocab, entity_vector_length=DESC_WIDTH) + kb_2.load_bulk(KB_FILE) + print("kb entities:", kb_2.get_size_entities()) + print("kb aliases:", kb_2.get_size_aliases()) + print() + + # test KB + if to_test_kb: + check_kb(kb_2) + print() + + # STEP 5: create a training dataset from WP + if create_wp_training: + print("STEP 5: create training dataset", datetime.datetime.now()) + 
training_set_creator.create_training(wikipedia_input=ENWIKI_DUMP, + entity_def_input=ENTITY_DEFS, + training_output=TRAINING_DIR) + + # STEP 6: create and train the entity linking pipe + if train_pipe: + print("STEP 6: training Entity Linking pipe", datetime.datetime.now()) + type_to_int = {label: i for i, label in enumerate(nlp_2.entity.labels)} + print(" -analysing", len(type_to_int), "different entity types") + el_pipe = nlp_2.create_pipe(name='entity_linker', + config={"context_width": CONTEXT_WIDTH, + "pretrained_vectors": nlp_2.vocab.vectors.name, + "type_to_int": type_to_int}) + el_pipe.set_kb(kb_2) + nlp_2.add_pipe(el_pipe, last=True) + + other_pipes = [pipe for pipe in nlp_2.pipe_names if pipe != "entity_linker"] + with nlp_2.disable_pipes(*other_pipes): # only train Entity Linking + optimizer = nlp_2.begin_training() + optimizer.learn_rate = LEARN_RATE + optimizer.L2 = L2 + + # define the size (nr of entities) of training and dev set + train_limit = 5000 + dev_limit = 5000 + + train_data = training_set_creator.read_training(nlp=nlp_2, + training_dir=TRAINING_DIR, + dev=False, + limit=train_limit) + + print("Training on", len(train_data), "articles") + print() + + dev_data = training_set_creator.read_training(nlp=nlp_2, + training_dir=TRAINING_DIR, + dev=True, + limit=dev_limit) + + print("Dev testing on", len(dev_data), "articles") + print() + + if not train_data: + print("Did not find any training data") + else: + for itn in range(EPOCHS): + random.shuffle(train_data) + losses = {} + batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001)) + batchnr = 0 + + with nlp_2.disable_pipes(*other_pipes): + for batch in batches: + try: + docs, golds = zip(*batch) + nlp_2.update( + docs, + golds, + sgd=optimizer, + drop=DROPOUT, + losses=losses, + ) + batchnr += 1 + except Exception as e: + print("Error updating batch:", e) + + if batchnr > 0: + el_pipe.cfg["context_weight"] = 1 + el_pipe.cfg["prior_weight"] = 1 + dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe) + losses['entity_linker'] = losses['entity_linker'] / batchnr + print("Epoch, train loss", itn, round(losses['entity_linker'], 2), + " / dev acc avg", round(dev_acc_context, 3)) + + # STEP 7: measure the performance of our trained pipe on an independent dev set + if len(dev_data) and measure_performance: + print() + print("STEP 7: performance measurement of Entity Linking pipe", datetime.datetime.now()) + print() + + counts, acc_r, acc_r_label, acc_p, acc_p_label, acc_o, acc_o_label = _measure_baselines(dev_data, kb_2) + print("dev counts:", sorted(counts.items(), key=lambda x: x[0])) + print("dev acc oracle:", round(acc_o, 3), [(x, round(y, 3)) for x, y in acc_o_label.items()]) + print("dev acc random:", round(acc_r, 3), [(x, round(y, 3)) for x, y in acc_r_label.items()]) + print("dev acc prior:", round(acc_p, 3), [(x, round(y, 3)) for x, y in acc_p_label.items()]) + + # using only context + el_pipe.cfg["context_weight"] = 1 + el_pipe.cfg["prior_weight"] = 0 + dev_acc_context, dev_acc_context_dict = _measure_accuracy(dev_data, el_pipe) + print("dev acc context avg:", round(dev_acc_context, 3), + [(x, round(y, 3)) for x, y in dev_acc_context_dict.items()]) + + # measuring combined accuracy (prior + context) + el_pipe.cfg["context_weight"] = 1 + el_pipe.cfg["prior_weight"] = 1 + dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe, error_analysis=False) + print("dev acc combo avg:", round(dev_acc_combo, 3), + [(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()]) + + # 
STEP 8: apply the EL pipe on a toy example + if to_test_pipeline: + print() + print("STEP 8: applying Entity Linking to toy example", datetime.datetime.now()) + print() + run_el_toy_example(nlp=nlp_2) + + # STEP 9: write the NLP pipeline (including entity linker) to file + if to_write_nlp: + print() + print("STEP 9: testing NLP IO", datetime.datetime.now()) + print() + print("writing to", NLP_2_DIR) + nlp_2.to_disk(NLP_2_DIR) + print() + + # verify that the IO has gone correctly + if to_read_nlp: + print("reading from", NLP_2_DIR) + nlp_3 = spacy.load(NLP_2_DIR) + + print("running toy example with NLP 3") + run_el_toy_example(nlp=nlp_3) + + # testing performance with an NLP model from file + if test_from_file: + nlp_2 = spacy.load(NLP_1_DIR) + nlp_3 = spacy.load(NLP_2_DIR) + el_pipe = nlp_3.get_pipe("entity_linker") + + dev_limit = 5000 + dev_data = training_set_creator.read_training(nlp=nlp_2, + training_dir=TRAINING_DIR, + dev=True, + limit=dev_limit) + + print("Dev testing from file on", len(dev_data), "articles") + print() + + dev_acc_combo, dev_acc_combo_dict = _measure_accuracy(dev_data, el_pipe=el_pipe, error_analysis=False) + print("dev acc combo avg:", round(dev_acc_combo, 3), + [(x, round(y, 3)) for x, y in dev_acc_combo_dict.items()]) + + print() + print("STOP", datetime.datetime.now()) + + +def _measure_accuracy(data, el_pipe=None, error_analysis=False): + # If the docs in the data require further processing with an entity linker, set el_pipe + correct_by_label = dict() + incorrect_by_label = dict() + + docs = [d for d, g in data if len(d) > 0] + if el_pipe is not None: + docs = list(el_pipe.pipe(docs)) + golds = [g for d, g in data if len(d) > 0] + + for doc, gold in zip(docs, golds): + try: + correct_entries_per_article = dict() + for entity in gold.links: + start, end, gold_kb = entity + correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb + + for ent in doc.ents: + ent_label = ent.label_ + pred_entity = ent.kb_id_ + start = ent.start_char + end = ent.end_char + gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None) + # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong' + if gold_entity is not None: + if gold_entity == pred_entity: + correct = correct_by_label.get(ent_label, 0) + correct_by_label[ent_label] = correct + 1 + else: + incorrect = incorrect_by_label.get(ent_label, 0) + incorrect_by_label[ent_label] = incorrect + 1 + if error_analysis: + print(ent.text, "in", doc) + print("Predicted", pred_entity, "should have been", gold_entity) + print() + + except Exception as e: + print("Error assessing accuracy", e) + + acc, acc_by_label = calculate_acc(correct_by_label, incorrect_by_label) + return acc, acc_by_label + + +def _measure_baselines(data, kb): + # Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound + counts_by_label = dict() + + random_correct_by_label = dict() + random_incorrect_by_label = dict() + + oracle_correct_by_label = dict() + oracle_incorrect_by_label = dict() + + prior_correct_by_label = dict() + prior_incorrect_by_label = dict() + + docs = [d for d, g in data if len(d) > 0] + golds = [g for d, g in data if len(d) > 0] + + for doc, gold in zip(docs, golds): + try: + correct_entries_per_article = dict() + for entity in gold.links: + start, end, gold_kb = entity + correct_entries_per_article[str(start) + "-" + str(end)] = gold_kb + + for ent in doc.ents: + ent_label = ent.label_ + start = ent.start_char + end = 
ent.end_char + gold_entity = correct_entries_per_article.get(str(start) + "-" + str(end), None) + + # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong' + if gold_entity is not None: + counts_by_label[ent_label] = counts_by_label.get(ent_label, 0) + 1 + candidates = kb.get_candidates(ent.text) + oracle_candidate = "" + best_candidate = "" + random_candidate = "" + if candidates: + scores = [] + + for c in candidates: + scores.append(c.prior_prob) + if c.entity_ == gold_entity: + oracle_candidate = c.entity_ + + best_index = scores.index(max(scores)) + best_candidate = candidates[best_index].entity_ + random_candidate = random.choice(candidates).entity_ + + if gold_entity == best_candidate: + prior_correct_by_label[ent_label] = prior_correct_by_label.get(ent_label, 0) + 1 + else: + prior_incorrect_by_label[ent_label] = prior_incorrect_by_label.get(ent_label, 0) + 1 + + if gold_entity == random_candidate: + random_correct_by_label[ent_label] = random_correct_by_label.get(ent_label, 0) + 1 + else: + random_incorrect_by_label[ent_label] = random_incorrect_by_label.get(ent_label, 0) + 1 + + if gold_entity == oracle_candidate: + oracle_correct_by_label[ent_label] = oracle_correct_by_label.get(ent_label, 0) + 1 + else: + oracle_incorrect_by_label[ent_label] = oracle_incorrect_by_label.get(ent_label, 0) + 1 + + except Exception as e: + print("Error assessing accuracy", e) + + acc_prior, acc_prior_by_label = calculate_acc(prior_correct_by_label, prior_incorrect_by_label) + acc_rand, acc_rand_by_label = calculate_acc(random_correct_by_label, random_incorrect_by_label) + acc_oracle, acc_oracle_by_label = calculate_acc(oracle_correct_by_label, oracle_incorrect_by_label) + + return counts_by_label, acc_rand, acc_rand_by_label, acc_prior, acc_prior_by_label, acc_oracle, acc_oracle_by_label + + +def calculate_acc(correct_by_label, incorrect_by_label): + acc_by_label = dict() + total_correct = 0 + total_incorrect = 0 + all_keys = set() + all_keys.update(correct_by_label.keys()) + all_keys.update(incorrect_by_label.keys()) + for label in sorted(all_keys): + correct = correct_by_label.get(label, 0) + incorrect = incorrect_by_label.get(label, 0) + total_correct += correct + total_incorrect += incorrect + if correct == incorrect == 0: + acc_by_label[label] = 0 + else: + acc_by_label[label] = correct / (correct + incorrect) + acc = 0 + if not (total_correct == total_incorrect == 0): + acc = total_correct / (total_correct + total_incorrect) + return acc, acc_by_label + + +def check_kb(kb): + for mention in ("Bush", "Douglas Adams", "Homer", "Brazil", "China"): + candidates = kb.get_candidates(mention) + + print("generating candidates for " + mention + " :") + for c in candidates: + print(" ", c.prior_prob, c.alias_, "-->", c.entity_ + " (freq=" + str(c.entity_freq) + ")") + print() + + +def run_el_toy_example(nlp): + text = "In The Hitchhiker's Guide to the Galaxy, written by Douglas Adams, " \ + "Douglas reminds us to always bring our towel, even in China or Brazil. " \ + "The main character in Doug's novel is the man Arthur Dent, " \ + "but Douglas doesn't write about George Washington or Homer Simpson." 
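+    # (editorial note, not part of the original patch) running the pipeline on this text also runs the
+    # entity_linker pipe, so each recognized entity span carries the predicted KB identifier in
+    # ent.kb_id_, which the loop below prints next to the entity text and NER label.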
+ doc = nlp(text) + print(text) + for ent in doc.ents: + print(" ent", ent.text, ent.label_, ent.kb_id_) + print() + + +if __name__ == "__main__": + run_pipeline() diff --git a/spacy/_ml.py b/spacy/_ml.py index 349b88df9..cca324b45 100644 --- a/spacy/_ml.py +++ b/spacy/_ml.py @@ -652,6 +652,38 @@ def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes=False, return model +def build_nel_encoder(embed_width, hidden_width, ner_types, **cfg): + # TODO proper error + if "entity_width" not in cfg: + raise ValueError("entity_width not found") + if "context_width" not in cfg: + raise ValueError("context_width not found") + + conv_depth = cfg.get("conv_depth", 2) + cnn_maxout_pieces = cfg.get("cnn_maxout_pieces", 3) + pretrained_vectors = cfg.get("pretrained_vectors") # self.nlp.vocab.vectors.name + context_width = cfg.get("context_width") + entity_width = cfg.get("entity_width") + + with Model.define_operators({">>": chain, "**": clone}): + model = Affine(entity_width, entity_width+context_width+1+ner_types)\ + >> Affine(1, entity_width, drop_factor=0.0)\ + >> logistic + + # context encoder + tok2vec = Tok2Vec(width=hidden_width, embed_size=embed_width, pretrained_vectors=pretrained_vectors, + cnn_maxout_pieces=cnn_maxout_pieces, subword_features=True, conv_depth=conv_depth, + bilstm_depth=0) >> flatten_add_lengths >> Pooling(mean_pool)\ + >> Residual(zero_init(Maxout(hidden_width, hidden_width))) \ + >> zero_init(Affine(context_width, hidden_width)) + + model.tok2vec = tok2vec + + model.tok2vec = tok2vec + model.tok2vec.nO = context_width + model.nO = 1 + return model + @layerize def flatten(seqs, drop=0.0): ops = Model.ops diff --git a/spacy/attrs.pxd b/spacy/attrs.pxd index 79a177ba9..c5ba8d765 100644 --- a/spacy/attrs.pxd +++ b/spacy/attrs.pxd @@ -82,6 +82,7 @@ cdef enum attr_id_t: DEP ENT_IOB ENT_TYPE + ENT_KB_ID HEAD SENT_START SPACY diff --git a/spacy/attrs.pyx b/spacy/attrs.pyx index ed1f39a3f..8eeea363f 100644 --- a/spacy/attrs.pyx +++ b/spacy/attrs.pyx @@ -84,6 +84,7 @@ IDS = { "DEP": DEP, "ENT_IOB": ENT_IOB, "ENT_TYPE": ENT_TYPE, + "ENT_KB_ID": ENT_KB_ID, "HEAD": HEAD, "SENT_START": SENT_START, "SPACY": SPACY, diff --git a/spacy/cli/pretrain.py b/spacy/cli/pretrain.py index 678f12be1..57c26fcbd 100644 --- a/spacy/cli/pretrain.py +++ b/spacy/cli/pretrain.py @@ -301,7 +301,7 @@ def get_vectors_loss(ops, docs, prediction, objective="L2"): elif objective == "cosine": loss, d_target = get_cossim_loss(prediction, target) else: - raise ValueError(Errors.E139.format(loss_func=objective)) + raise ValueError(Errors.E142.format(loss_func=objective)) return loss, d_target diff --git a/spacy/errors.py b/spacy/errors.py index 176003e79..8f2eab3a1 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -399,7 +399,10 @@ class Errors(object): E138 = ("Invalid JSONL format for raw text '{text}'. Make sure the input includes either the " "`text` or `tokens` key. For more info, see the docs:\n" "https://spacy.io/api/cli#pretrain-jsonl") - E139 = ("Unsupported loss_function '{loss_func}'. Use either 'L2' or 'cosine'") + E139 = ("Knowledge base for component '{name}' not initialized. Did you forget to call set_kb()?") + E140 = ("The list of entities, prior probabilities and entity vectors should be of equal length.") + E141 = ("Entity vectors should be of length {required} instead of the provided {found}.") + E142 = ("Unsupported loss_function '{loss_func}'. 
Use either 'L2' or 'cosine'") @add_codes diff --git a/spacy/gold.pxd b/spacy/gold.pxd index a1550b1ef..8943a155a 100644 --- a/spacy/gold.pxd +++ b/spacy/gold.pxd @@ -31,6 +31,7 @@ cdef class GoldParse: cdef public list ents cdef public dict brackets cdef public object cats + cdef public list links cdef readonly list cand_to_gold cdef readonly list gold_to_cand diff --git a/spacy/gold.pyx b/spacy/gold.pyx index 569979a5f..4fb22f3f0 100644 --- a/spacy/gold.pyx +++ b/spacy/gold.pyx @@ -427,7 +427,7 @@ cdef class GoldParse: def __init__(self, doc, annot_tuples=None, words=None, tags=None, heads=None, deps=None, entities=None, make_projective=False, - cats=None, **_): + cats=None, links=None, **_): """Create a GoldParse. doc (Doc): The document the annotations refer to. @@ -450,6 +450,8 @@ cdef class GoldParse: examples of a label to have the value 0.0. Labels not in the dictionary are treated as missing - the gradient for those labels will be zero. + links (iterable): A sequence of `(start_char, end_char, kb_id)` tuples, + representing the external ID of an entity in a knowledge base. RETURNS (GoldParse): The newly constructed object. """ if words is None: @@ -485,6 +487,7 @@ cdef class GoldParse: self.c.ner = self.mem.alloc(len(doc), sizeof(Transition)) self.cats = {} if cats is None else dict(cats) + self.links = links self.words = [None] * len(doc) self.tags = [None] * len(doc) self.heads = [None] * len(doc) diff --git a/spacy/kb.pxd b/spacy/kb.pxd index e34a0a9ba..40b22b275 100644 --- a/spacy/kb.pxd +++ b/spacy/kb.pxd @@ -1,53 +1,27 @@ """Knowledge-base for entity or concept linking.""" from cymem.cymem cimport Pool from preshed.maps cimport PreshMap + from libcpp.vector cimport vector from libc.stdint cimport int32_t, int64_t +from libc.stdio cimport FILE from spacy.vocab cimport Vocab from .typedefs cimport hash_t - -# Internal struct, for storage and disambiguation. This isn't what we return -# to the user as the answer to "here's your entity". It's the minimum number -# of bits we need to keep track of the answers. -cdef struct _EntryC: - - # The hash of this entry's unique ID and name in the kB - hash_t entity_hash - - # Allows retrieval of one or more vectors. - # Each element of vector_rows should be an index into a vectors table. - # Every entry should have the same number of vectors, so we can avoid storing - # the number of vectors in each knowledge-base struct - int32_t* vector_rows - - # Allows retrieval of a struct of non-vector features. We could make this a - # pointer, but we have 32 bits left over in the struct after prob, so we'd - # like this to only be 32 bits. We can also set this to -1, for the common - # case where there are no features. - int32_t feats_row - - # log probability of entity, based on corpus frequency - float prob - - -# Each alias struct stores a list of Entry pointers with their prior probabilities -# for this specific mention/alias. -cdef struct _AliasC: - - # All entry candidates for this alias - vector[int64_t] entry_indices - - # Prior probability P(entity|alias) - should sum up to (at most) 1. - vector[float] probs +from .structs cimport KBEntryC, AliasC +ctypedef vector[KBEntryC] entry_vec +ctypedef vector[AliasC] alias_vec +ctypedef vector[float] float_vec +ctypedef vector[float_vec] float_matrix # Object used by the Entity Linker that summarizes one entity-alias candidate combination. 
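# (editorial note, not in the original patch) A Candidate bundles the KB-internal hashes of the entity
# and the alias with the entity's corpus-frequency-based probability, its encoded description vector,
# and the prior probability P(entity|alias) taken from the aliases table.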
cdef class Candidate: - cdef readonly KnowledgeBase kb cdef hash_t entity_hash + cdef float entity_freq + cdef vector[float] entity_vector cdef hash_t alias_hash cdef float prior_prob @@ -55,9 +29,10 @@ cdef class Candidate: cdef class KnowledgeBase: cdef Pool mem cpdef readonly Vocab vocab + cdef int64_t entity_vector_length # This maps 64bit keys (hash of unique entity string) - # to 64bit values (position of the _EntryC struct in the _entries vector). + # to 64bit values (position of the _KBEntryC struct in the _entries vector). # The PreshMap is pretty space efficient, as it uses open addressing. So # the only overhead is the vacancy rate, which is approximately 30%. cdef PreshMap _entry_index @@ -66,7 +41,7 @@ cdef class KnowledgeBase: # over allocation. # In total we end up with (N*128*1.3)+(N*128*1.3) bits for N entries. # Storing 1m entries would take 41.6mb under this scheme. - cdef vector[_EntryC] _entries + cdef entry_vec _entries # This maps 64bit keys (hash of unique alias string) # to 64bit values (position of the _AliasC struct in the _aliases_table vector). @@ -76,7 +51,7 @@ cdef class KnowledgeBase: # should be P(entity | mention), which is pretty important to know. # We can pack both pieces of information into a 64-bit value, to keep things # efficient. - cdef vector[_AliasC] _aliases_table + cdef alias_vec _aliases_table # This is the part which might take more space: storing various # categorical features for the entries, and storing vectors for disambiguation @@ -87,7 +62,7 @@ cdef class KnowledgeBase: # model, that embeds different features of the entities into vectors. We'll # still want some per-entity features, like the Wikipedia text or entity # co-occurrence. Hopefully those vectors can be narrow, e.g. 64 dimensions. - cdef object _vectors_table + cdef float_matrix _vectors_table # It's very useful to track categorical features, at least for output, even # if they're not useful in the model itself. For instance, we should be @@ -96,53 +71,102 @@ cdef class KnowledgeBase: # optional data, we can let users configure a DB as the backend for this. cdef object _features_table + + cdef inline int64_t c_add_vector(self, vector[float] entity_vector) nogil: + """Add an entity vector to the vectors table.""" + cdef int64_t new_index = self._vectors_table.size() + self._vectors_table.push_back(entity_vector) + return new_index + + cdef inline int64_t c_add_entity(self, hash_t entity_hash, float prob, - int32_t* vector_rows, int feats_row): - """Add an entry to the knowledge base.""" - # This is what we'll map the hash key to. It's where the entry will sit + int32_t vector_index, int feats_row) nogil: + """Add an entry to the vector of entries. + After calling this method, make sure to update also the _entry_index using the return value""" + # This is what we'll map the entity hash key to. It's where the entry will sit # in the vector of entries, so we can get it later. 
cdef int64_t new_index = self._entries.size() - self._entries.push_back( - _EntryC( - entity_hash=entity_hash, - vector_rows=vector_rows, - feats_row=feats_row, - prob=prob - )) - self._entry_index[entity_hash] = new_index + + # Avoid struct initializer to enable nogil, cf https://github.com/cython/cython/issues/1642 + cdef KBEntryC entry + entry.entity_hash = entity_hash + entry.vector_index = vector_index + entry.feats_row = feats_row + entry.prob = prob + + self._entries.push_back(entry) return new_index - cdef inline int64_t c_add_aliases(self, hash_t alias_hash, vector[int64_t] entry_indices, vector[float] probs): - """Connect a mention to a list of potential entities with their prior probabilities .""" + cdef inline int64_t c_add_aliases(self, hash_t alias_hash, vector[int64_t] entry_indices, vector[float] probs) nogil: + """Connect a mention to a list of potential entities with their prior probabilities . + After calling this method, make sure to update also the _alias_index using the return value""" + # This is what we'll map the alias hash key to. It's where the alias will be defined + # in the vector of aliases. cdef int64_t new_index = self._aliases_table.size() - self._aliases_table.push_back( - _AliasC( - entry_indices=entry_indices, - probs=probs - )) - self._alias_index[alias_hash] = new_index + # Avoid struct initializer to enable nogil + cdef AliasC alias + alias.entry_indices = entry_indices + alias.probs = probs + + self._aliases_table.push_back(alias) return new_index - cdef inline _create_empty_vectors(self): + cdef inline void _create_empty_vectors(self, hash_t dummy_hash) nogil: """ - Making sure the first element of each vector is a dummy, + Initializing the vectors and making sure the first element of each vector is a dummy, because the PreshMap maps pointing to indices in these vectors can not contain 0 as value cf. 
https://github.com/explosion/preshed/issues/17 """ cdef int32_t dummy_value = 0 - self.vocab.strings.add("") - self._entries.push_back( - _EntryC( - entity_hash=self.vocab.strings[""], - vector_rows=&dummy_value, - feats_row=dummy_value, - prob=dummy_value - )) - self._aliases_table.push_back( - _AliasC( - entry_indices=[dummy_value], - probs=[dummy_value] - )) + + # Avoid struct initializer to enable nogil + cdef KBEntryC entry + entry.entity_hash = dummy_hash + entry.vector_index = dummy_value + entry.feats_row = dummy_value + entry.prob = dummy_value + + # Avoid struct initializer to enable nogil + cdef vector[int64_t] dummy_entry_indices + dummy_entry_indices.push_back(0) + cdef vector[float] dummy_probs + dummy_probs.push_back(0) + + cdef AliasC alias + alias.entry_indices = dummy_entry_indices + alias.probs = dummy_probs + + self._entries.push_back(entry) + self._aliases_table.push_back(alias) + + cpdef load_bulk(self, loc) + cpdef set_entities(self, entity_list, prob_list, vector_list) +cdef class Writer: + cdef FILE* _fp + + cdef int write_header(self, int64_t nr_entries, int64_t entity_vector_length) except -1 + cdef int write_vector_element(self, float element) except -1 + cdef int write_entry(self, hash_t entry_hash, float entry_prob, int32_t vector_index) except -1 + + cdef int write_alias_length(self, int64_t alias_length) except -1 + cdef int write_alias_header(self, hash_t alias_hash, int64_t candidate_length) except -1 + cdef int write_alias(self, int64_t entry_index, float prob) except -1 + + cdef int _write(self, void* value, size_t size) except -1 + +cdef class Reader: + cdef FILE* _fp + + cdef int read_header(self, int64_t* nr_entries, int64_t* entity_vector_length) except -1 + cdef int read_vector_element(self, float* element) except -1 + cdef int read_entry(self, hash_t* entity_hash, float* prob, int32_t* vector_index) except -1 + + cdef int read_alias_length(self, int64_t* alias_length) except -1 + cdef int read_alias_header(self, hash_t* alias_hash, int64_t* candidate_length) except -1 + cdef int read_alias(self, int64_t* entry_index, float* prob) except -1 + + cdef int _read(self, void* value, size_t size) except -1 + diff --git a/spacy/kb.pyx b/spacy/kb.pyx index 3a0a8b918..7c2daa659 100644 --- a/spacy/kb.pyx +++ b/spacy/kb.pyx @@ -1,13 +1,30 @@ +# cython: infer_types=True # cython: profile=True # coding: utf8 from spacy.errors import Errors, Warnings, user_warning +from pathlib import Path +from cymem.cymem cimport Pool +from preshed.maps cimport PreshMap + +from cpython.exc cimport PyErr_SetFromErrno + +from libc.stdio cimport fopen, fclose, fread, fwrite, feof, fseek +from libc.stdint cimport int32_t, int64_t + +from .typedefs cimport hash_t + +from os import path +from libcpp.vector cimport vector + cdef class Candidate: - def __init__(self, KnowledgeBase kb, entity_hash, alias_hash, prior_prob): + def __init__(self, KnowledgeBase kb, entity_hash, entity_freq, entity_vector, alias_hash, prior_prob): self.kb = kb self.entity_hash = entity_hash + self.entity_freq = entity_freq + self.entity_vector = entity_vector self.alias_hash = alias_hash self.prior_prob = prior_prob @@ -19,7 +36,7 @@ cdef class Candidate: @property def entity_(self): """RETURNS (unicode): ID/name of this entity in the KB""" - return self.kb.vocab.strings[self.entity] + return self.kb.vocab.strings[self.entity_hash] @property def alias(self): @@ -29,7 +46,15 @@ cdef class Candidate: @property def alias_(self): """RETURNS (unicode): ID of the original alias""" - return 
self.kb.vocab.strings[self.alias] + return self.kb.vocab.strings[self.alias_hash] + + @property + def entity_freq(self): + return self.entity_freq + + @property + def entity_vector(self): + return self.entity_vector @property def prior_prob(self): @@ -38,26 +63,41 @@ cdef class Candidate: cdef class KnowledgeBase: - def __init__(self, Vocab vocab): + def __init__(self, Vocab vocab, entity_vector_length): self.vocab = vocab + self.mem = Pool() + self.entity_vector_length = entity_vector_length + self._entry_index = PreshMap() self._alias_index = PreshMap() - self.mem = Pool() - self._create_empty_vectors() + + self.vocab.strings.add("") + self._create_empty_vectors(dummy_hash=self.vocab.strings[""]) + + @property + def entity_vector_length(self): + """RETURNS (uint64): length of the entity vectors""" + return self.entity_vector_length def __len__(self): return self.get_size_entities() def get_size_entities(self): - return self._entries.size() - 1 # not counting dummy element on index 0 + return len(self._entry_index) + + def get_entity_strings(self): + return [self.vocab.strings[x] for x in self._entry_index] def get_size_aliases(self): - return self._aliases_table.size() - 1 # not counting dummy element on index 0 + return len(self._alias_index) - def add_entity(self, unicode entity, float prob=0.5, vectors=None, features=None): + def get_alias_strings(self): + return [self.vocab.strings[x] for x in self._alias_index] + + def add_entity(self, unicode entity, float prob, vector[float] entity_vector): """ - Add an entity to the KB. - Return the hash of the entity ID at the end + Add an entity to the KB, optionally specifying its log probability based on corpus frequency + Return the hash of the entity ID/name at the end. """ cdef hash_t entity_hash = self.vocab.strings.add(entity) @@ -66,40 +106,72 @@ cdef class KnowledgeBase: user_warning(Warnings.W018.format(entity=entity)) return - cdef int32_t dummy_value = 342 - self.c_add_entity(entity_hash=entity_hash, prob=prob, - vector_rows=&dummy_value, feats_row=dummy_value) - # TODO self._vectors_table.get_pointer(vectors), - # self._features_table.get(features)) + # Raise an error if the provided entity vector is not of the correct length + if len(entity_vector) != self.entity_vector_length: + raise ValueError(Errors.E141.format(found=len(entity_vector), required=self.entity_vector_length)) + + vector_index = self.c_add_vector(entity_vector=entity_vector) + + new_index = self.c_add_entity(entity_hash=entity_hash, + prob=prob, + vector_index=vector_index, + feats_row=-1) # Features table currently not implemented + self._entry_index[entity_hash] = new_index return entity_hash + cpdef set_entities(self, entity_list, prob_list, vector_list): + if len(entity_list) != len(prob_list) or len(entity_list) != len(vector_list): + raise ValueError(Errors.E140) + + nr_entities = len(entity_list) + self._entry_index = PreshMap(nr_entities+1) + self._entries = entry_vec(nr_entities+1) + + i = 0 + cdef KBEntryC entry + while i < nr_entities: + entity_vector = vector_list[i] + if len(entity_vector) != self.entity_vector_length: + raise ValueError(Errors.E141.format(found=len(entity_vector), required=self.entity_vector_length)) + + entity_hash = self.vocab.strings.add(entity_list[i]) + entry.entity_hash = entity_hash + entry.prob = prob_list[i] + + vector_index = self.c_add_vector(entity_vector=vector_list[i]) + entry.vector_index = vector_index + + entry.feats_row = -1 # Features table currently not implemented + + self._entries[i+1] = entry + 
self._entry_index[entity_hash] = i+1 + + i += 1 + def add_alias(self, unicode alias, entities, probabilities): """ For a given alias, add its potential entities and prior probabilies to the KB. Return the alias_hash at the end """ - # Throw an error if the length of entities and probabilities are not the same if not len(entities) == len(probabilities): raise ValueError(Errors.E132.format(alias=alias, entities_length=len(entities), probabilities_length=len(probabilities))) - # Throw an error if the probabilities sum up to more than 1 + # Throw an error if the probabilities sum up to more than 1 (allow for some rounding errors) prob_sum = sum(probabilities) - if prob_sum > 1: + if prob_sum > 1.00001: raise ValueError(Errors.E133.format(alias=alias, sum=prob_sum)) cdef hash_t alias_hash = self.vocab.strings.add(alias) - # Return if this alias was added before + # Check whether this alias was added before if alias_hash in self._alias_index: user_warning(Warnings.W017.format(alias=alias)) return - cdef hash_t entity_hash - cdef vector[int64_t] entry_indices cdef vector[float] probs @@ -112,20 +184,295 @@ cdef class KnowledgeBase: entry_indices.push_back(int(entry_index)) probs.push_back(float(prob)) - self.c_add_aliases(alias_hash=alias_hash, entry_indices=entry_indices, probs=probs) + new_index = self.c_add_aliases(alias_hash=alias_hash, entry_indices=entry_indices, probs=probs) + self._alias_index[alias_hash] = new_index return alias_hash - def get_candidates(self, unicode alias): - """ TODO: where to put this functionality ?""" cdef hash_t alias_hash = self.vocab.strings[alias] alias_index = self._alias_index.get(alias_hash) alias_entry = self._aliases_table[alias_index] return [Candidate(kb=self, entity_hash=self._entries[entry_index].entity_hash, + entity_freq=self._entries[entry_index].prob, + entity_vector=self._vectors_table[self._entries[entry_index].vector_index], alias_hash=alias_hash, prior_prob=prob) for (entry_index, prob) in zip(alias_entry.entry_indices, alias_entry.probs) if entry_index != 0] + + + def dump(self, loc): + cdef Writer writer = Writer(loc) + writer.write_header(self.get_size_entities(), self.entity_vector_length) + + # dumping the entity vectors in their original order + i = 0 + for entity_vector in self._vectors_table: + for element in entity_vector: + writer.write_vector_element(element) + i = i+1 + + # dumping the entry records in the order in which they are in the _entries vector. + # index 0 is a dummy object not stored in the _entry_index and can be ignored. + i = 1 + for entry_hash, entry_index in sorted(self._entry_index.items(), key=lambda x: x[1]): + entry = self._entries[entry_index] + assert entry.entity_hash == entry_hash + assert entry_index == i + writer.write_entry(entry.entity_hash, entry.prob, entry.vector_index) + i = i+1 + + writer.write_alias_length(self.get_size_aliases()) + + # dumping the aliases in the order in which they are in the _alias_index vector. + # index 0 is a dummy object not stored in the _aliases_table and can be ignored. 
+ i = 1 + for alias_hash, alias_index in sorted(self._alias_index.items(), key=lambda x: x[1]): + alias = self._aliases_table[alias_index] + assert alias_index == i + + candidate_length = len(alias.entry_indices) + writer.write_alias_header(alias_hash, candidate_length) + + for j in range(0, candidate_length): + writer.write_alias(alias.entry_indices[j], alias.probs[j]) + + i = i+1 + + writer.close() + + cpdef load_bulk(self, loc): + cdef hash_t entity_hash + cdef hash_t alias_hash + cdef int64_t entry_index + cdef float prob + cdef int32_t vector_index + cdef KBEntryC entry + cdef AliasC alias + cdef float vector_element + + cdef Reader reader = Reader(loc) + + # STEP 0: load header and initialize KB + cdef int64_t nr_entities + cdef int64_t entity_vector_length + reader.read_header(&nr_entities, &entity_vector_length) + + self.entity_vector_length = entity_vector_length + self._entry_index = PreshMap(nr_entities+1) + self._entries = entry_vec(nr_entities+1) + self._vectors_table = float_matrix(nr_entities+1) + + # STEP 1: load entity vectors + cdef int i = 0 + cdef int j = 0 + while i < nr_entities: + entity_vector = float_vec(entity_vector_length) + j = 0 + while j < entity_vector_length: + reader.read_vector_element(&vector_element) + entity_vector[j] = vector_element + j = j+1 + self._vectors_table[i] = entity_vector + i = i+1 + + # STEP 2: load entities + # we assume that the entity data was written in sequence + # index 0 is a dummy object not stored in the _entry_index and can be ignored. + i = 1 + while i <= nr_entities: + reader.read_entry(&entity_hash, &prob, &vector_index) + + entry.entity_hash = entity_hash + entry.prob = prob + entry.vector_index = vector_index + entry.feats_row = -1 # Features table currently not implemented + + self._entries[i] = entry + self._entry_index[entity_hash] = i + + i += 1 + + # check that all entities were read in properly + assert nr_entities == self.get_size_entities() + + # STEP 3: load aliases + + cdef int64_t nr_aliases + reader.read_alias_length(&nr_aliases) + self._alias_index = PreshMap(nr_aliases+1) + self._aliases_table = alias_vec(nr_aliases+1) + + cdef int64_t nr_candidates + cdef vector[int64_t] entry_indices + cdef vector[float] probs + + i = 1 + # we assume the alias data was written in sequence + # index 0 is a dummy object not stored in the _entry_index and can be ignored. + while i <= nr_aliases: + reader.read_alias_header(&alias_hash, &nr_candidates) + entry_indices = vector[int64_t](nr_candidates) + probs = vector[float](nr_candidates) + + for j in range(0, nr_candidates): + reader.read_alias(&entry_index, &prob) + entry_indices[j] = entry_index + probs[j] = prob + + alias.entry_indices = entry_indices + alias.probs = probs + + self._aliases_table[i] = alias + self._alias_index[alias_hash] = i + + i += 1 + + # check that all aliases were read in properly + assert nr_aliases == self.get_size_aliases() + + +cdef class Writer: + def __init__(self, object loc): + if path.exists(loc): + assert not path.isdir(loc), "%s is directory." 
% loc + if isinstance(loc, Path): + loc = bytes(loc) + cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc + self._fp = fopen(bytes_loc, 'wb') + assert self._fp != NULL + fseek(self._fp, 0, 0) + + def close(self): + cdef size_t status = fclose(self._fp) + assert status == 0 + + cdef int write_header(self, int64_t nr_entries, int64_t entity_vector_length) except -1: + self._write(&nr_entries, sizeof(nr_entries)) + self._write(&entity_vector_length, sizeof(entity_vector_length)) + + cdef int write_vector_element(self, float element) except -1: + self._write(&element, sizeof(element)) + + cdef int write_entry(self, hash_t entry_hash, float entry_prob, int32_t vector_index) except -1: + self._write(&entry_hash, sizeof(entry_hash)) + self._write(&entry_prob, sizeof(entry_prob)) + self._write(&vector_index, sizeof(vector_index)) + # Features table currently not implemented and not written to file + + cdef int write_alias_length(self, int64_t alias_length) except -1: + self._write(&alias_length, sizeof(alias_length)) + + cdef int write_alias_header(self, hash_t alias_hash, int64_t candidate_length) except -1: + self._write(&alias_hash, sizeof(alias_hash)) + self._write(&candidate_length, sizeof(candidate_length)) + + cdef int write_alias(self, int64_t entry_index, float prob) except -1: + self._write(&entry_index, sizeof(entry_index)) + self._write(&prob, sizeof(prob)) + + cdef int _write(self, void* value, size_t size) except -1: + status = fwrite(value, size, 1, self._fp) + assert status == 1, status + + +cdef class Reader: + def __init__(self, object loc): + assert path.exists(loc) + assert not path.isdir(loc) + if isinstance(loc, Path): + loc = bytes(loc) + cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc + self._fp = fopen(bytes_loc, 'rb') + if not self._fp: + PyErr_SetFromErrno(IOError) + status = fseek(self._fp, 0, 0) # this can be 0 if there is no header + + def __dealloc__(self): + fclose(self._fp) + + cdef int read_header(self, int64_t* nr_entries, int64_t* entity_vector_length) except -1: + status = self._read(nr_entries, sizeof(int64_t)) + if status < 1: + if feof(self._fp): + return 0 # end of file + raise IOError("error reading header from input file") + + status = self._read(entity_vector_length, sizeof(int64_t)) + if status < 1: + if feof(self._fp): + return 0 # end of file + raise IOError("error reading header from input file") + + cdef int read_vector_element(self, float* element) except -1: + status = self._read(element, sizeof(float)) + if status < 1: + if feof(self._fp): + return 0 # end of file + raise IOError("error reading entity vector from input file") + + cdef int read_entry(self, hash_t* entity_hash, float* prob, int32_t* vector_index) except -1: + status = self._read(entity_hash, sizeof(hash_t)) + if status < 1: + if feof(self._fp): + return 0 # end of file + raise IOError("error reading entity hash from input file") + + status = self._read(prob, sizeof(float)) + if status < 1: + if feof(self._fp): + return 0 # end of file + raise IOError("error reading entity prob from input file") + + status = self._read(vector_index, sizeof(int32_t)) + if status < 1: + if feof(self._fp): + return 0 # end of file + raise IOError("error reading entity vector from input file") + + if feof(self._fp): + return 0 + else: + return 1 + + cdef int read_alias_length(self, int64_t* alias_length) except -1: + status = self._read(alias_length, sizeof(int64_t)) + if status < 1: + if feof(self._fp): + return 0 # end of file + raise IOError("error 
reading alias length from input file") + + cdef int read_alias_header(self, hash_t* alias_hash, int64_t* candidate_length) except -1: + status = self._read(alias_hash, sizeof(hash_t)) + if status < 1: + if feof(self._fp): + return 0 # end of file + raise IOError("error reading alias hash from input file") + + status = self._read(candidate_length, sizeof(int64_t)) + if status < 1: + if feof(self._fp): + return 0 # end of file + raise IOError("error reading candidate length from input file") + + cdef int read_alias(self, int64_t* entry_index, float* prob) except -1: + status = self._read(entry_index, sizeof(int64_t)) + if status < 1: + if feof(self._fp): + return 0 # end of file + raise IOError("error reading entry index for alias from input file") + + status = self._read(prob, sizeof(float)) + if status < 1: + if feof(self._fp): + return 0 # end of file + raise IOError("error reading prob for entity/alias from input file") + + cdef int _read(self, void* value, size_t size) except -1: + status = fread(value, size, 1, self._fp) + return status + + diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index 1f4dd4253..d99a1f73e 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -3,16 +3,18 @@ # coding: utf8 from __future__ import unicode_literals -cimport numpy as np - import numpy import srsly +import random from collections import OrderedDict from thinc.api import chain from thinc.v2v import Affine, Maxout, Softmax from thinc.misc import LayerNorm -from thinc.neural.util import to_categorical, copy_array +from thinc.neural.util import to_categorical +from thinc.neural.util import get_array_module +from spacy.kb import KnowledgeBase +from ..cli.pretrain import get_cossim_loss from .functions import merge_subtokens from ..tokens.doc cimport Doc from ..syntax.nn_parser cimport Parser @@ -24,9 +26,9 @@ from ..vocab cimport Vocab from ..syntax import nonproj from ..attrs import POS, ID from ..parts_of_speech import X -from .._ml import Tok2Vec, build_tagger_model +from .._ml import Tok2Vec, build_tagger_model, cosine from .._ml import build_text_classifier, build_simple_cnn_text_classifier -from .._ml import build_bow_text_classifier +from .._ml import build_bow_text_classifier, build_nel_encoder from .._ml import link_vectors_to_models, zero_init, flatten from .._ml import masked_language_model, create_default_optimizer from ..errors import Errors, TempErrors @@ -229,7 +231,7 @@ class Tensorizer(Pipe): vocab (Vocab): A `Vocab` instance. The model must share the same `Vocab` instance with the `Doc` objects it will process. - model (Model): A `Model` instance or `True` allocate one later. + model (Model): A `Model` instance or `True` to allocate one later. **cfg: Config parameters. EXAMPLE: @@ -294,7 +296,7 @@ class Tensorizer(Pipe): docs (iterable): A batch of `Doc` objects. golds (iterable): A batch of `GoldParse` objects. - drop (float): The droput rate. + drop (float): The dropout rate. sgd (callable): An optimizer. RETURNS (dict): Results from the update. """ @@ -386,7 +388,7 @@ class Tagger(Pipe): def predict(self, docs): self.require_model() if not any(len(doc) for doc in docs): - # Handle case where there are no tokens in any docs. + # Handle cases where there are no tokens in any docs. 
n_labels = len(self.labels) guesses = [self.model.ops.allocate((0, n_labels)) for doc in docs] tokvecs = self.model.ops.allocate((0, self.model.tok2vec.nO)) @@ -1063,52 +1065,252 @@ cdef class EntityRecognizer(Parser): class EntityLinker(Pipe): + """Pipeline component for named entity linking. + + DOCS: TODO + """ name = 'entity_linker' @classmethod - def Model(cls, nr_class=1, **cfg): - # TODO: non-dummy EL implementation - return None + def Model(cls, **cfg): + embed_width = cfg.get("embed_width", 300) + hidden_width = cfg.get("hidden_width", 128) + type_to_int = cfg.get("type_to_int", dict()) - def __init__(self, model=True, **cfg): - self.model = False + model = build_nel_encoder(embed_width=embed_width, hidden_width=hidden_width, ner_types=len(type_to_int), **cfg) + return model + + def __init__(self, vocab, **cfg): + self.vocab = vocab + self.model = True + self.kb = None self.cfg = dict(cfg) - self.kb = self.cfg["kb"] + self.sgd_context = None + + def set_kb(self, kb): + self.kb = kb + + def require_model(self): + # Raise an error if the component's model is not initialized. + if getattr(self, "model", None) in (None, True, False): + raise ValueError(Errors.E109.format(name=self.name)) + + def require_kb(self): + # Raise an error if the knowledge base is not initialized. + if getattr(self, "kb", None) in (None, True, False): + raise ValueError(Errors.E139.format(name=self.name)) + + def begin_training(self, get_gold_tuples=lambda: [], pipeline=None, sgd=None, **kwargs): + self.require_kb() + self.cfg["entity_width"] = self.kb.entity_vector_length + + if self.model is True: + self.model = self.Model(**self.cfg) + self.sgd_context = self.create_optimizer() + + if sgd is None: + sgd = self.create_optimizer() + + return sgd + + def update(self, docs, golds, state=None, drop=0.0, sgd=None, losses=None): + self.require_model() + self.require_kb() + + if losses is not None: + losses.setdefault(self.name, 0.0) + + if not docs or not golds: + return 0 + + if len(docs) != len(golds): + raise ValueError(Errors.E077.format(value="EL training", n_docs=len(docs), + n_golds=len(golds))) + + if isinstance(docs, Doc): + docs = [docs] + golds = [golds] + + context_docs = [] + entity_encodings = [] + cats = [] + priors = [] + type_vectors = [] + + type_to_int = self.cfg.get("type_to_int", dict()) + + for doc, gold in zip(docs, golds): + ents_by_offset = dict() + for ent in doc.ents: + ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] = ent + for entity in gold.links: + start, end, gold_kb = entity + mention = doc.text[start:end] + + gold_ent = ents_by_offset[str(ent.start_char) + "_" + str(ent.end_char)] + assert gold_ent is not None + type_vector = [0 for i in range(len(type_to_int))] + if len(type_to_int) > 0: + type_vector[type_to_int[gold_ent.label_]] = 1 + + candidates = self.kb.get_candidates(mention) + random.shuffle(candidates) + nr_neg = 0 + for c in candidates: + kb_id = c.entity_ + entity_encoding = c.entity_vector + entity_encodings.append(entity_encoding) + context_docs.append(doc) + type_vectors.append(type_vector) + + if self.cfg.get("prior_weight", 1) > 0: + priors.append([c.prior_prob]) + else: + priors.append([0]) + + if kb_id == gold_kb: + cats.append([1]) + else: + nr_neg += 1 + cats.append([0]) + + if len(entity_encodings) > 0: + assert len(priors) == len(entity_encodings) == len(context_docs) == len(cats) == len(type_vectors) + + context_encodings, bp_context = self.model.tok2vec.begin_update(context_docs, drop=drop) + entity_encodings = 
self.model.ops.asarray(entity_encodings, dtype="float32") + + mention_encodings = [list(context_encodings[i]) + list(entity_encodings[i]) + priors[i] + type_vectors[i] + for i in range(len(entity_encodings))] + pred, bp_mention = self.model.begin_update(self.model.ops.asarray(mention_encodings, dtype="float32"), drop=drop) + cats = self.model.ops.asarray(cats, dtype="float32") + + loss, d_scores = self.get_loss(prediction=pred, golds=cats, docs=None) + mention_gradient = bp_mention(d_scores, sgd=sgd) + + context_gradients = [list(x[0:self.cfg.get("context_width")]) for x in mention_gradient] + bp_context(self.model.ops.asarray(context_gradients, dtype="float32"), sgd=self.sgd_context) + + if losses is not None: + losses[self.name] += loss + return loss + return 0 + + def get_loss(self, docs, golds, prediction): + d_scores = (prediction - golds) + loss = (d_scores ** 2).sum() + loss = loss / len(golds) + return loss, d_scores + + def get_loss_old(self, docs, golds, scores): + # this loss function assumes we're only using positive examples + loss, gradients = get_cossim_loss(yh=scores, y=golds) + loss = loss / len(golds) + return loss, gradients def __call__(self, doc): - self.set_annotations([doc], scores=None, tensors=None) + entities, kb_ids = self.predict([doc]) + self.set_annotations([doc], entities, kb_ids) return doc def pipe(self, stream, batch_size=128, n_threads=-1): - """Apply the pipe to a stream of documents. - Both __call__ and pipe should delegate to the `predict()` - and `set_annotations()` methods. - """ for docs in util.minibatch(stream, size=batch_size): docs = list(docs) - self.set_annotations(docs, scores=None, tensors=None) + entities, kb_ids = self.predict(docs) + self.set_annotations(docs, entities, kb_ids) yield from docs - def set_annotations(self, docs, scores, tensors=None): - """ - Currently implemented as taking the KB entry with highest prior probability for each named entity - TODO: actually use context etc - """ - for i, doc in enumerate(docs): - for ent in doc.ents: - candidates = self.kb.get_candidates(ent.text) - if candidates: - best_candidate = max(candidates, key=lambda c: c.prior_prob) - for token in ent: - token.ent_kb_id_ = best_candidate.entity_ + def predict(self, docs): + self.require_model() + self.require_kb() - def get_loss(self, docs, golds, scores): - # TODO - pass + final_entities = [] + final_kb_ids = [] + + if not docs: + return final_entities, final_kb_ids + + if isinstance(docs, Doc): + docs = [docs] + + context_encodings = self.model.tok2vec(docs) + xp = get_array_module(context_encodings) + + type_to_int = self.cfg.get("type_to_int", dict()) + + for i, doc in enumerate(docs): + if len(doc) > 0: + context_encoding = context_encodings[i] + for ent in doc.ents: + type_vector = [0 for i in range(len(type_to_int))] + if len(type_to_int) > 0: + type_vector[type_to_int[ent.label_]] = 1 + + candidates = self.kb.get_candidates(ent.text) + if candidates: + random.shuffle(candidates) + + # this will set the prior probabilities to 0 (just like in training) if their weight is 0 + prior_probs = xp.asarray([[c.prior_prob] for c in candidates]) + prior_probs *= self.cfg.get("prior_weight", 1) + scores = prior_probs + + if self.cfg.get("context_weight", 1) > 0: + entity_encodings = xp.asarray([c.entity_vector for c in candidates]) + assert len(entity_encodings) == len(prior_probs) + mention_encodings = [list(context_encoding) + list(entity_encodings[i]) + + list(prior_probs[i]) + type_vector + for i in range(len(entity_encodings))] + scores = 
self.model(self.model.ops.asarray(mention_encodings, dtype="float32")) + + # TODO: thresholding + best_index = scores.argmax() + best_candidate = candidates[best_index] + final_entities.append(ent) + final_kb_ids.append(best_candidate.entity_) + + return final_entities, final_kb_ids + + def set_annotations(self, docs, entities, kb_ids=None): + for entity, kb_id in zip(entities, kb_ids): + for token in entity: + token.ent_kb_id_ = kb_id + + def to_disk(self, path, exclude=tuple(), **kwargs): + serialize = OrderedDict() + serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg) + serialize["vocab"] = lambda p: self.vocab.to_disk(p) + serialize["kb"] = lambda p: self.kb.dump(p) + if self.model not in (None, True, False): + serialize["model"] = lambda p: p.open("wb").write(self.model.to_bytes()) + exclude = util.get_serialization_exclude(serialize, exclude, kwargs) + util.to_disk(path, serialize, exclude) + + def from_disk(self, path, exclude=tuple(), **kwargs): + def load_model(p): + if self.model is True: + self.model = self.Model(**self.cfg) + self.model.from_bytes(p.open("rb").read()) + + def load_kb(p): + kb = KnowledgeBase(vocab=self.vocab, entity_vector_length=self.cfg["entity_width"]) + kb.load_bulk(p) + self.set_kb(kb) + + deserialize = OrderedDict() + deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p)) + deserialize["vocab"] = lambda p: self.vocab.from_disk(p) + deserialize["kb"] = load_kb + deserialize["model"] = load_model + exclude = util.get_serialization_exclude(deserialize, exclude, kwargs) + util.from_disk(path, deserialize, exclude) + return self + + def rehearse(self, docs, sgd=None, losses=None, **config): + raise NotImplementedError def add_label(self, label): - # TODO - pass + raise NotImplementedError class Sentencizer(object): diff --git a/spacy/structs.pxd b/spacy/structs.pxd index 154202c0d..e80b1b4d6 100644 --- a/spacy/structs.pxd +++ b/spacy/structs.pxd @@ -3,6 +3,10 @@ from libc.stdint cimport uint8_t, uint32_t, int32_t, uint64_t from .typedefs cimport flags_t, attr_t, hash_t from .parts_of_speech cimport univ_pos_t +from libcpp.vector cimport vector +from libc.stdint cimport int32_t, int64_t + + cdef struct LexemeC: flags_t flags @@ -72,3 +76,32 @@ cdef struct TokenC: attr_t ent_type # TODO: Is there a better way to do this? Multiple sources of truth.. attr_t ent_kb_id hash_t ent_id + + +# Internal struct, for storage and disambiguation of entities. +cdef struct KBEntryC: + + # The hash of this entry's unique ID/name in the kB + hash_t entity_hash + + # Allows retrieval of the entity vector, as an index into a vectors table of the KB. + # Can be expanded later to refer to multiple rows (compositional model to reduce storage footprint). + int32_t vector_index + + # Allows retrieval of a struct of non-vector features. + # This is currently not implemented and set to -1 for the common case where there are no features. + int32_t feats_row + + # log probability of entity, based on corpus frequency + float prob + + +# Each alias struct stores a list of Entry pointers with their prior probabilities +# for this specific mention/alias. +cdef struct AliasC: + + # All entry candidates for this alias + vector[int64_t] entry_indices + + # Prior probability P(entity|alias) - should sum up to (at most) 1. 
+ vector[float] probs diff --git a/spacy/symbols.pxd b/spacy/symbols.pxd index 051b92edb..4501861a2 100644 --- a/spacy/symbols.pxd +++ b/spacy/symbols.pxd @@ -81,6 +81,7 @@ cdef enum symbol_t: DEP ENT_IOB ENT_TYPE + ENT_KB_ID HEAD SENT_START SPACY diff --git a/spacy/symbols.pyx b/spacy/symbols.pyx index 949621820..b65ae9628 100644 --- a/spacy/symbols.pyx +++ b/spacy/symbols.pyx @@ -86,6 +86,7 @@ IDS = { "DEP": DEP, "ENT_IOB": ENT_IOB, "ENT_TYPE": ENT_TYPE, + "ENT_KB_ID": ENT_KB_ID, "HEAD": HEAD, "SENT_START": SENT_START, "SPACY": SPACY, diff --git a/spacy/tests/pipeline/test_el.py b/spacy/tests/pipeline/test_el.py deleted file mode 100644 index 61baece68..000000000 --- a/spacy/tests/pipeline/test_el.py +++ /dev/null @@ -1,91 +0,0 @@ -# coding: utf-8 -from __future__ import unicode_literals - -import pytest - -from spacy.kb import KnowledgeBase -from spacy.lang.en import English - - -@pytest.fixture -def nlp(): - return English() - - -def test_kb_valid_entities(nlp): - """Test the valid construction of a KB with 3 entities and two aliases""" - mykb = KnowledgeBase(nlp.vocab) - - # adding entities - mykb.add_entity(entity=u'Q1', prob=0.9) - mykb.add_entity(entity=u'Q2') - mykb.add_entity(entity=u'Q3', prob=0.5) - - # adding aliases - mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q3'], probabilities=[0.8, 0.2]) - mykb.add_alias(alias=u'adam', entities=[u'Q2'], probabilities=[0.9]) - - # test the size of the corresponding KB - assert(mykb.get_size_entities() == 3) - assert(mykb.get_size_aliases() == 2) - - -def test_kb_invalid_entities(nlp): - """Test the invalid construction of a KB with an alias linked to a non-existing entity""" - mykb = KnowledgeBase(nlp.vocab) - - # adding entities - mykb.add_entity(entity=u'Q1', prob=0.9) - mykb.add_entity(entity=u'Q2', prob=0.2) - mykb.add_entity(entity=u'Q3', prob=0.5) - - # adding aliases - should fail because one of the given IDs is not valid - with pytest.raises(ValueError): - mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q342'], probabilities=[0.8, 0.2]) - - -def test_kb_invalid_probabilities(nlp): - """Test the invalid construction of a KB with wrong prior probabilities""" - mykb = KnowledgeBase(nlp.vocab) - - # adding entities - mykb.add_entity(entity=u'Q1', prob=0.9) - mykb.add_entity(entity=u'Q2', prob=0.2) - mykb.add_entity(entity=u'Q3', prob=0.5) - - # adding aliases - should fail because the sum of the probabilities exceeds 1 - with pytest.raises(ValueError): - mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q3'], probabilities=[0.8, 0.4]) - - -def test_kb_invalid_combination(nlp): - """Test the invalid construction of a KB with non-matching entity and probability lists""" - mykb = KnowledgeBase(nlp.vocab) - - # adding entities - mykb.add_entity(entity=u'Q1', prob=0.9) - mykb.add_entity(entity=u'Q2', prob=0.2) - mykb.add_entity(entity=u'Q3', prob=0.5) - - # adding aliases - should fail because the entities and probabilities vectors are not of equal length - with pytest.raises(ValueError): - mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q3'], probabilities=[0.3, 0.4, 0.1]) - - -def test_candidate_generation(nlp): - """Test correct candidate generation""" - mykb = KnowledgeBase(nlp.vocab) - - # adding entities - mykb.add_entity(entity=u'Q1', prob=0.9) - mykb.add_entity(entity=u'Q2', prob=0.2) - mykb.add_entity(entity=u'Q3', prob=0.5) - - # adding aliases - mykb.add_alias(alias=u'douglas', entities=[u'Q2', u'Q3'], probabilities=[0.8, 0.2]) - mykb.add_alias(alias=u'adam', entities=[u'Q2'], probabilities=[0.9]) - - # test 
the size of the relevant candidates - assert(len(mykb.get_candidates(u'douglas')) == 2) - assert(len(mykb.get_candidates(u'adam')) == 1) - assert(len(mykb.get_candidates(u'shrubbery')) == 0) diff --git a/spacy/tests/pipeline/test_entity_linker.py b/spacy/tests/pipeline/test_entity_linker.py new file mode 100644 index 000000000..cafc380ba --- /dev/null +++ b/spacy/tests/pipeline/test_entity_linker.py @@ -0,0 +1,145 @@ +# coding: utf-8 +from __future__ import unicode_literals + +import pytest + +from spacy.kb import KnowledgeBase +from spacy.lang.en import English +from spacy.pipeline import EntityRuler + + +@pytest.fixture +def nlp(): + return English() + + +def test_kb_valid_entities(nlp): + """Test the valid construction of a KB with 3 entities and two aliases""" + mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) + + # adding entities + mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1]) + mykb.add_entity(entity='Q2', prob=0.5, entity_vector=[2]) + mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3]) + + # adding aliases + mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.2]) + mykb.add_alias(alias='adam', entities=['Q2'], probabilities=[0.9]) + + # test the size of the corresponding KB + assert(mykb.get_size_entities() == 3) + assert(mykb.get_size_aliases() == 2) + + +def test_kb_invalid_entities(nlp): + """Test the invalid construction of a KB with an alias linked to a non-existing entity""" + mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) + + # adding entities + mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1]) + mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2]) + mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3]) + + # adding aliases - should fail because one of the given IDs is not valid + with pytest.raises(ValueError): + mykb.add_alias(alias='douglas', entities=['Q2', 'Q342'], probabilities=[0.8, 0.2]) + + +def test_kb_invalid_probabilities(nlp): + """Test the invalid construction of a KB with wrong prior probabilities""" + mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) + + # adding entities + mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1]) + mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2]) + mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3]) + + # adding aliases - should fail because the sum of the probabilities exceeds 1 + with pytest.raises(ValueError): + mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.4]) + + +def test_kb_invalid_combination(nlp): + """Test the invalid construction of a KB with non-matching entity and probability lists""" + mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) + + # adding entities + mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1]) + mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2]) + mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3]) + + # adding aliases - should fail because the entities and probabilities vectors are not of equal length + with pytest.raises(ValueError): + mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.3, 0.4, 0.1]) + + +def test_kb_invalid_entity_vector(nlp): + """Test the invalid construction of a KB with non-matching entity vector lengths""" + mykb = KnowledgeBase(nlp.vocab, entity_vector_length=3) + + # adding entities + mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1, 2, 3]) + + # this should fail because the kb's expected entity vector length is 3 + with pytest.raises(ValueError): + mykb.add_entity(entity='Q2', prob=0.2, 
entity_vector=[2]) + + +def test_candidate_generation(nlp): + """Test correct candidate generation""" + mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) + + # adding entities + mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1]) + mykb.add_entity(entity='Q2', prob=0.2, entity_vector=[2]) + mykb.add_entity(entity='Q3', prob=0.5, entity_vector=[3]) + + # adding aliases + mykb.add_alias(alias='douglas', entities=['Q2', 'Q3'], probabilities=[0.8, 0.2]) + mykb.add_alias(alias='adam', entities=['Q2'], probabilities=[0.9]) + + # test the size of the relevant candidates + assert(len(mykb.get_candidates('douglas')) == 2) + assert(len(mykb.get_candidates('adam')) == 1) + assert(len(mykb.get_candidates('shrubbery')) == 0) + + +def test_preserving_links_asdoc(nlp): + """Test that Span.as_doc preserves the existing entity links""" + mykb = KnowledgeBase(nlp.vocab, entity_vector_length=1) + + # adding entities + mykb.add_entity(entity='Q1', prob=0.9, entity_vector=[1]) + mykb.add_entity(entity='Q2', prob=0.8, entity_vector=[1]) + + # adding aliases + mykb.add_alias(alias='Boston', entities=['Q1'], probabilities=[0.7]) + mykb.add_alias(alias='Denver', entities=['Q2'], probabilities=[0.6]) + + # set up pipeline with NER (Entity Ruler) and NEL (prior probability only, model not trained) + sentencizer = nlp.create_pipe("sentencizer") + nlp.add_pipe(sentencizer) + + ruler = EntityRuler(nlp) + patterns = [{"label": "GPE", "pattern": "Boston"}, + {"label": "GPE", "pattern": "Denver"}] + ruler.add_patterns(patterns) + nlp.add_pipe(ruler) + + el_pipe = nlp.create_pipe(name='entity_linker', config={"context_width": 64}) + el_pipe.set_kb(mykb) + el_pipe.begin_training() + el_pipe.context_weight = 0 + el_pipe.prior_weight = 1 + nlp.add_pipe(el_pipe, last=True) + + # test whether the entity links are preserved by the `as_doc()` function + text = "She lives in Boston. He lives in Denver." 
+ doc = nlp(text) + for ent in doc.ents: + orig_text = ent.text + orig_kb_id = ent.kb_id_ + sent_doc = ent.sent.as_doc() + for s_ent in sent_doc.ents: + if s_ent.text == orig_text: + assert s_ent.kb_id_ == orig_kb_id diff --git a/spacy/tests/serialize/test_serialize_kb.py b/spacy/tests/serialize/test_serialize_kb.py new file mode 100644 index 000000000..fa7253fa1 --- /dev/null +++ b/spacy/tests/serialize/test_serialize_kb.py @@ -0,0 +1,74 @@ +# coding: utf-8 +from __future__ import unicode_literals + +from ..util import make_tempdir +from ...util import ensure_path + +from spacy.kb import KnowledgeBase + + +def test_serialize_kb_disk(en_vocab): + # baseline assertions + kb1 = _get_dummy_kb(en_vocab) + _check_kb(kb1) + + # dumping to file & loading back in + with make_tempdir() as d: + dir_path = ensure_path(d) + if not dir_path.exists(): + dir_path.mkdir() + file_path = dir_path / "kb" + kb1.dump(str(file_path)) + + kb2 = KnowledgeBase(vocab=en_vocab, entity_vector_length=3) + kb2.load_bulk(str(file_path)) + + # final assertions + _check_kb(kb2) + + +def _get_dummy_kb(vocab): + kb = KnowledgeBase(vocab=vocab, entity_vector_length=3) + + kb.add_entity(entity='Q53', prob=0.33, entity_vector=[0, 5, 3]) + kb.add_entity(entity='Q17', prob=0.2, entity_vector=[7, 1, 0]) + kb.add_entity(entity='Q007', prob=0.7, entity_vector=[0, 0, 7]) + kb.add_entity(entity='Q44', prob=0.4, entity_vector=[4, 4, 4]) + + kb.add_alias(alias='double07', entities=['Q17', 'Q007'], probabilities=[0.1, 0.9]) + kb.add_alias(alias='guy', entities=['Q53', 'Q007', 'Q17', 'Q44'], probabilities=[0.3, 0.3, 0.2, 0.1]) + kb.add_alias(alias='random', entities=['Q007'], probabilities=[1.0]) + + return kb + + +def _check_kb(kb): + # check entities + assert kb.get_size_entities() == 4 + for entity_string in ['Q53', 'Q17', 'Q007', 'Q44']: + assert entity_string in kb.get_entity_strings() + for entity_string in ['', 'Q0']: + assert entity_string not in kb.get_entity_strings() + + # check aliases + assert kb.get_size_aliases() == 3 + for alias_string in ['double07', 'guy', 'random']: + assert alias_string in kb.get_alias_strings() + for alias_string in ['nothingness', '', 'randomnoise']: + assert alias_string not in kb.get_alias_strings() + + # check candidates & probabilities + candidates = sorted(kb.get_candidates('double07'), key=lambda x: x.entity_) + assert len(candidates) == 2 + + assert candidates[0].entity_ == 'Q007' + assert 0.6999 < candidates[0].entity_freq < 0.701 + assert candidates[0].entity_vector == [0, 0, 7] + assert candidates[0].alias_ == 'double07' + assert 0.899 < candidates[0].prior_prob < 0.901 + + assert candidates[1].entity_ == 'Q17' + assert 0.199 < candidates[1].entity_freq < 0.201 + assert candidates[1].entity_vector == [7, 1, 0] + assert candidates[1].alias_ == 'double07' + assert 0.099 < candidates[1].prior_prob < 0.101 diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index 373771247..a040cdc67 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -22,7 +22,7 @@ from ..lexeme cimport Lexeme, EMPTY_LEXEME from ..typedefs cimport attr_t, flags_t from ..attrs cimport ID, ORTH, NORM, LOWER, SHAPE, PREFIX, SUFFIX, CLUSTER from ..attrs cimport LENGTH, POS, LEMMA, TAG, DEP, HEAD, SPACY, ENT_IOB -from ..attrs cimport ENT_TYPE, SENT_START, attr_id_t +from ..attrs cimport ENT_TYPE, ENT_KB_ID, SENT_START, attr_id_t from ..parts_of_speech cimport CCONJ, PUNCT, NOUN, univ_pos_t from ..attrs import intify_attrs, IDS @@ -64,6 +64,8 @@ cdef attr_t get_token_attr(const TokenC* token, attr_id_t feat_name) 
nogil: return token.ent_iob elif feat_name == ENT_TYPE: return token.ent_type + elif feat_name == ENT_KB_ID: + return token.ent_kb_id else: return Lexeme.get_struct_attr(token.lex, feat_name) @@ -851,7 +853,7 @@ cdef class Doc: DOCS: https://spacy.io/api/doc#to_bytes """ - array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE] + array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE] # TODO: ENT_KB_ID ? if self.is_tagged: array_head.append(TAG) # If doc parsed add head and dep attribute @@ -1005,6 +1007,7 @@ cdef class Doc: """ cdef unicode tag, lemma, ent_type deprecation_warning(Warnings.W013.format(obj="Doc")) + # TODO: ENT_KB_ID ? if len(args) == 3: deprecation_warning(Warnings.W003) tag, lemma, ent_type = args diff --git a/spacy/tokens/span.pyx b/spacy/tokens/span.pyx index 97b6a1adc..3f4f4418b 100644 --- a/spacy/tokens/span.pyx +++ b/spacy/tokens/span.pyx @@ -210,7 +210,7 @@ cdef class Span: words = [t.text for t in self] spaces = [bool(t.whitespace_) for t in self] cdef Doc doc = Doc(self.doc.vocab, words=words, spaces=spaces) - array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE] + array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE, ENT_KB_ID] if self.doc.is_tagged: array_head.append(TAG) # If doc parsed add head and dep attribute diff --git a/spacy/tokens/token.pxd b/spacy/tokens/token.pxd index bb9f7d070..ec5df3fac 100644 --- a/spacy/tokens/token.pxd +++ b/spacy/tokens/token.pxd @@ -53,6 +53,8 @@ cdef class Token: return token.ent_iob elif feat_name == ENT_TYPE: return token.ent_type + elif feat_name == ENT_KB_ID: + return token.ent_kb_id elif feat_name == SENT_START: return token.sent_start else: @@ -79,5 +81,7 @@ cdef class Token: token.ent_iob = value elif feat_name == ENT_TYPE: token.ent_type = value + elif feat_name == ENT_KB_ID: + token.ent_kb_id = value elif feat_name == SENT_START: token.sent_start = value
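
A minimal usage sketch of the KnowledgeBase / entity_linker API introduced in this diff, assembled from the new tests (spacy/tests/pipeline/test_entity_linker.py and spacy/tests/serialize/test_serialize_kb.py). It reflects the development-branch interface shown above rather than a released spaCy version; the entity IDs, probabilities, vectors and the "kb_output" path are illustrative placeholders only.

# coding: utf-8
from spacy.lang.en import English
from spacy.kb import KnowledgeBase

nlp = English()

# A KB whose entity vectors are 3-dimensional (entity_vector_length is now required).
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3)

# Entities carry a corpus-frequency-based probability and a fixed-size vector.
kb.add_entity(entity="Q1", prob=0.9, entity_vector=[0, 5, 3])
kb.add_entity(entity="Q2", prob=0.2, entity_vector=[7, 1, 0])

# An alias maps a mention string to candidate entities with prior probabilities
# P(entity|alias); the priors may sum to at most 1.
kb.add_alias(alias="douglas", entities=["Q1", "Q2"], probabilities=[0.6, 0.3])

# Candidates now expose the entity frequency and entity vector as well.
for c in kb.get_candidates("douglas"):
    print(c.entity_, c.entity_freq, c.entity_vector, c.alias_, c.prior_prob)

# Binary serialization through the new Writer/Reader classes.
kb.dump("kb_output")
kb2 = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3)
kb2.load_bulk("kb_output")

# Plug the KB into the new entity_linker pipe; an NER component (e.g. an
# EntityRuler) and a sentencizer are expected earlier in the pipeline so that
# doc.ents is populated before linking.
el_pipe = nlp.create_pipe(name="entity_linker", config={"context_width": 64})
el_pipe.set_kb(kb2)
el_pipe.begin_training()
nlp.add_pipe(el_pipe, last=True)

With only prior probabilities (an untrained model and prior_weight left at its default of 1), the pipe behaves like the earlier max-prior baseline; training via EntityLinker.update() additionally folds in the context encoding, entity vector and optional type vector.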