Merge branch 'feature/ud-script-update' into bugfix/tokenizer-special-cases-matcher

Adriane Boyd 2019-09-16 14:24:33 +02:00
commit e7e7c942c7
29 changed files with 995 additions and 823 deletions

.github/contributors/tamuhey.md

@ -0,0 +1,106 @@
# spaCy contributor agreement
This spaCy Contributor Agreement (**"SCA"**) is based on the
[Oracle Contributor Agreement](http://www.oracle.com/technetwork/oca-405177.pdf).
The SCA applies to any contribution that you make to any product or project
managed by us (the **"project"**), and sets out the intellectual property rights
you grant to us in the contributed materials. The term **"us"** shall mean
[ExplosionAI GmbH](https://explosion.ai/legal). The term
**"you"** shall mean the person or entity identified below.
If you agree to be bound by these terms, fill in the information requested
below and include the filled-in version with your first pull request, under the
folder [`.github/contributors/`](/.github/contributors/). The name of the file
should be your GitHub username, with the extension `.md`. For example, the user
example_user would create the file `.github/contributors/example_user.md`.
Read this agreement carefully before signing. These terms and conditions
constitute a binding legal agreement.
## Contributor Agreement
1. The term "contribution" or "contributed materials" means any source code,
object code, patch, tool, sample, graphic, specification, manual,
documentation, or any other material posted or submitted by you to the project.
2. With respect to any worldwide copyrights, or copyright applications and
registrations, in your contribution:
* you hereby assign to us joint ownership, and to the extent that such
assignment is or becomes invalid, ineffective or unenforceable, you hereby
grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge,
royalty-free, unrestricted license to exercise all rights under those
copyrights. This includes, at our option, the right to sublicense these same
rights to third parties through multiple levels of sublicensees or other
licensing arrangements;
* you agree that each of us can do all things in relation to your
contribution as if each of us were the sole owners, and if one of us makes
a derivative work of your contribution, the one who makes the derivative
work (or has it made) will be the sole owner of that derivative work;
* you agree that you will not assert any moral rights in your contribution
against us, our licensees or transferees;
* you agree that we may register a copyright in your contribution and
exercise all ownership rights associated with it; and
* you agree that neither of us has any duty to consult with, obtain the
consent of, pay or render an accounting to the other for any use or
distribution of your contribution.
3. With respect to any patents you own, or that you can license without payment
to any third party, you hereby grant to us a perpetual, irrevocable,
non-exclusive, worldwide, no-charge, royalty-free license to:
* make, have made, use, sell, offer to sell, import, and otherwise transfer
your contribution in whole or in part, alone or in combination with or
included in any product, work or materials arising out of the project to
which your contribution was submitted, and
* at our option, to sublicense these same rights to third parties through
multiple levels of sublicensees or other licensing arrangements.
4. Except as set out above, you keep all right, title, and interest in your
contribution. The rights that you grant to us under these terms are effective
on the date you first submitted a contribution to us, even if your submission
took place before the date you sign these terms.
5. You covenant, represent, warrant and agree that:
* Each contribution that you submit is and shall be an original work of
authorship and you can legally grant the rights set out in this SCA;
* to the best of your knowledge, each contribution will not violate any
third party's copyrights, trademarks, patents, or other intellectual
property rights; and
* each contribution shall be in compliance with U.S. export control laws and
other applicable export and import laws. You agree to notify us if you
become aware of any circumstance which would make any of the foregoing
representations inaccurate in any respect. We may publicly disclose your
participation in the project, including the fact that you have signed the SCA.
6. This SCA is governed by the laws of the State of California and applicable
U.S. Federal law. Any choice of law rules will not apply.
7. Please place an “x” on one of the applicable statements below. Please do NOT
mark both statements:
* [x] I am signing on behalf of myself as an individual and no other person
or entity, including my employer, has or will have rights with respect to my
contributions.
* [ ] I am signing on behalf of my employer or a legal entity and I have the
actual authority to contractually bind that entity.
## Contributor Details
| Field | Entry |
|------------------------------- | -------------------- |
| Name | Yohei Tamura |
| Company name (if applicable) | PKSHA |
| Title or role (if applicable) | |
| Date | 2019/9/12 |
| GitHub username | tamuhey |
| Website (optional) | |


@ -7,14 +7,16 @@ import datetime
from pathlib import Path
import xml.etree.ElementTree as ET
from spacy.cli.ud import conll17_ud_eval
from spacy.cli.ud.ud_train import write_conllu
import conll17_ud_eval
from ud_train import write_conllu
from spacy.lang.lex_attrs import word_shape
from spacy.util import get_lang_class
# All languages in spaCy - in UD format (note that Norwegian is 'no' instead of 'nb')
ALL_LANGUAGES = "ar, ca, da, de, el, en, es, fa, fi, fr, ga, he, hi, hr, hu, id, " \
"it, ja, no, nl, pl, pt, ro, ru, sv, tr, ur, vi, zh"
ALL_LANGUAGES = ("af, ar, bg, bn, ca, cs, da, de, el, en, es, et, fa, fi, fr,"
"ga, he, hi, hr, hu, id, is, it, ja, kn, ko, lt, lv, mr, no,"
"nl, pl, pt, ro, ru, si, sk, sl, sq, sr, sv, ta, te, th, tl,"
"tr, tt, uk, ur, vi, zh")
# Non-parsing tasks that will be evaluated (works for default models)
EVAL_NO_PARSE = ['Tokens', 'Words', 'Lemmas', 'Sentences', 'Feats']
@ -73,10 +75,10 @@ def _contains_blinded_text(stats_xml):
tree = ET.parse(stats_xml)
root = tree.getroot()
total_tokens = int(root.find('size/total/tokens').text)
unique_lemmas = int(root.find('lemmas').get('unique'))
unique_forms = int(root.find('forms').get('unique'))
# assume the corpus is largely blinded when there are less than 1% unique tokens
return (unique_lemmas / total_tokens) < 0.01
return (unique_forms / total_tokens) < 0.01
def fetch_all_treebanks(ud_dir, languages, corpus, best_per_language):
@ -262,22 +264,26 @@ def main(out_path, ud_dir, check_parse=False, langs=ALL_LANGUAGES, exclude_train
if not exclude_trained_models:
if 'de' in models:
models['de'].append(load_model('de_core_news_sm'))
if 'es' in models:
models['es'].append(load_model('es_core_news_sm'))
models['es'].append(load_model('es_core_news_md'))
if 'pt' in models:
models['pt'].append(load_model('pt_core_news_sm'))
if 'it' in models:
models['it'].append(load_model('it_core_news_sm'))
if 'nl' in models:
models['nl'].append(load_model('nl_core_news_sm'))
models['de'].append(load_model('de_core_news_md'))
if 'el' in models:
models['el'].append(load_model('el_core_news_sm'))
models['el'].append(load_model('el_core_news_md'))
if 'en' in models:
models['en'].append(load_model('en_core_web_sm'))
models['en'].append(load_model('en_core_web_md'))
models['en'].append(load_model('en_core_web_lg'))
if 'es' in models:
models['es'].append(load_model('es_core_news_sm'))
models['es'].append(load_model('es_core_news_md'))
if 'fr' in models:
models['fr'].append(load_model('fr_core_news_sm'))
models['fr'].append(load_model('fr_core_news_md'))
if 'it' in models:
models['it'].append(load_model('it_core_news_sm'))
if 'nl' in models:
models['nl'].append(load_model('nl_core_news_sm'))
if 'pt' in models:
models['pt'].append(load_model('pt_core_news_sm'))
with out_path.open(mode='w', encoding='utf-8') as out_file:
run_all_evals(models, treebanks, out_file, check_parse, print_freq_tasks)
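
Note on the `_contains_blinded_text` hunk above: the check now compares unique word forms, rather than unique lemmas, against the total token count. The following is a minimal, stdlib-only sketch of that ratio test. The XML string is a hypothetical stand-in for a treebank's `stats.xml` and contains only the elements the check reads; the `threshold` parameter is added here for illustration and is not part of the script.

```python
import xml.etree.ElementTree as ET

# Hypothetical stats.xml content (assumption): only the elements read by the check.
SAMPLE_STATS = """<treebank>
  <size><total><sentences>2</sentences><tokens>1000</tokens></total></size>
  <lemmas unique="400"/>
  <forms unique="5"/>
</treebank>"""


def contains_blinded_text(stats_root, threshold=0.01):
    """Return True when the share of unique word forms is below `threshold`."""
    total_tokens = int(stats_root.find("size/total/tokens").text)
    unique_forms = int(stats_root.find("forms").get("unique"))
    return (unique_forms / total_tokens) < threshold


root = ET.fromstring(SAMPLE_STATS)
# 5 unique forms over 1000 tokens is 0.5%, so this corpus counts as blinded.
print(contains_blinded_text(root))
```

With lemmas as the numerator, the same sample would give 400 / 1000 = 40% and would not be flagged, which illustrates why the choice of attribute matters for corpora whose surface forms are masked.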


@ -109,15 +109,13 @@ def write_conllu(docs, file_):
merger = Matcher(docs[0].vocab)
merger.add("SUBTOK", None, [{"DEP": "subtok", "op": "+"}])
for i, doc in enumerate(docs):
matches = merger(doc)
matches = []
if doc.is_parsed:
matches = merger(doc)
spans = [doc[start : end + 1] for _, start, end in matches]
with doc.retokenize() as retokenizer:
for span in spans:
retokenizer.merge(span)
# TODO: This shouldn't be necessary? Should be handled in merge
for word in doc:
if word.i == word.head.i:
word.dep_ = "ROOT"
file_.write("# newdoc id = {i}\n".format(i=i))
for j, sent in enumerate(doc.sents):
file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j))


@ -25,7 +25,7 @@ import itertools
import random
import numpy.random
from . import conll17_ud_eval
import conll17_ud_eval
from spacy import lang
from spacy.lang import zh
@ -214,7 +214,9 @@ def write_conllu(docs, file_):
merger = Matcher(docs[0].vocab)
merger.add("SUBTOK", None, [{"DEP": "subtok", "op": "+"}])
for i, doc in enumerate(docs):
matches = merger(doc)
matches = []
if doc.is_parsed:
matches = merger(doc)
spans = [doc[start : end + 1] for _, start, end in matches]
with doc.retokenize() as retokenizer:
for span in spans:
@ -298,9 +300,9 @@ def get_token_conllu(token, i):
return "\n".join(lines)
Token.set_extension("get_conllu_lines", method=get_token_conllu)
Token.set_extension("begins_fused", default=False)
Token.set_extension("inside_fused", default=False)
Token.set_extension("get_conllu_lines", method=get_token_conllu, force=True)
Token.set_extension("begins_fused", default=False, force=True)
Token.set_extension("inside_fused", default=False, force=True)
##################


@ -0,0 +1,11 @@
TRAINING_DATA_FILE = "gold_entities.jsonl"
KB_FILE = "kb"
KB_MODEL_DIR = "nlp_kb"
OUTPUT_MODEL_DIR = "nlp"
PRIOR_PROB_PATH = "prior_prob.csv"
ENTITY_DEFS_PATH = "entity_defs.csv"
ENTITY_FREQ_PATH = "entity_freq.csv"
ENTITY_DESCR_PATH = "entity_descriptions.csv"
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
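
Several hunks in this commit replace `print` calls with module-level loggers. As a rough sketch of how the `LOG_FORMAT` constant above could be wired to such a logger, assuming a stream handler is acceptable (the helper `make_logger` and the logger name `kb_creator_demo` are hypothetical, not part of the commit):

```python
import io
import logging

LOG_FORMAT = "%(asctime)s - %(levelname)s - %(name)s - %(message)s"


def make_logger(name, stream):
    """Hypothetical helper: attach a handler that uses LOG_FORMAT to a named logger."""
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)
    logger.propagate = False  # keep the demo output out of the root logger
    handler = logging.StreamHandler(stream)
    handler.setFormatter(logging.Formatter(LOG_FORMAT))
    logger.addHandler(handler)
    return logger


buffer = io.StringIO()
logger = make_logger("kb_creator_demo", buffer)

# Pre-formatted message, in the style used throughout this commit.
logger.info("Left with {} entities".format(3))
# Lazy %-style arguments; extra positional arguments need matching placeholders.
logger.info("Adding %s entities", 3)

log_lines = buffer.getvalue().splitlines()
print(log_lines)
```

Unlike `print`, a logging call treats extra positional arguments as values for `%`-style placeholders in the message, so a call converted from `print("text", value)` needs either a placeholder or an explicit `format`.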


@ -0,0 +1,200 @@
import logging
import random
from collections import defaultdict
logger = logging.getLogger(__name__)

class Metrics(object):
    true_pos = 0
    false_pos = 0
    false_neg = 0

    def update_results(self, true_entity, candidate):
        candidate_is_correct = true_entity == candidate

        # Assume that we have no labeled negatives in the data (i.e. cases where true_entity is "NIL")
        # Therefore, if candidate_is_correct then we have a true positive and never a true negative
        self.true_pos += candidate_is_correct
        self.false_neg += not candidate_is_correct
        if candidate not in {"", "NIL"}:
            self.false_pos += not candidate_is_correct

    def calculate_precision(self):
        if self.true_pos == 0:
            return 0.0
        else:
            return self.true_pos / (self.true_pos + self.false_pos)

    def calculate_recall(self):
        if self.true_pos == 0:
            return 0.0
        else:
            return self.true_pos / (self.true_pos + self.false_neg)

class EvaluationResults(object):
    def __init__(self):
        self.metrics = Metrics()
        self.metrics_by_label = defaultdict(Metrics)

    def update_metrics(self, ent_label, true_entity, candidate):
        self.metrics.update_results(true_entity, candidate)
        self.metrics_by_label[ent_label].update_results(true_entity, candidate)

    def increment_false_negatives(self):
        self.metrics.false_neg += 1

    def report_metrics(self, model_name):
        model_str = model_name.title()
        recall = self.metrics.calculate_recall()
        precision = self.metrics.calculate_precision()
        return ("{}: ".format(model_str) +
                "Recall = {} | ".format(round(recall, 3)) +
                "Precision = {} | ".format(round(precision, 3)) +
                "Precision by label = {}".format({k: v.calculate_precision()
                                                  for k, v in self.metrics_by_label.items()}))

class BaselineResults(object):
    def __init__(self):
        self.random = EvaluationResults()
        self.prior = EvaluationResults()
        self.oracle = EvaluationResults()

    def report_accuracy(self, model):
        results = getattr(self, model)
        return results.report_metrics(model)

    def update_baselines(self, true_entity, ent_label, random_candidate, prior_candidate, oracle_candidate):
        self.oracle.update_metrics(ent_label, true_entity, oracle_candidate)
        self.prior.update_metrics(ent_label, true_entity, prior_candidate)
        self.random.update_metrics(ent_label, true_entity, random_candidate)

def measure_performance(dev_data, kb, el_pipe):
    baseline_accuracies = measure_baselines(
        dev_data, kb
    )

    logger.info(baseline_accuracies.report_accuracy("random"))
    logger.info(baseline_accuracies.report_accuracy("prior"))
    logger.info(baseline_accuracies.report_accuracy("oracle"))

    # using only context
    el_pipe.cfg["incl_context"] = True
    el_pipe.cfg["incl_prior"] = False
    results = get_eval_results(dev_data, el_pipe)
    logger.info(results.report_metrics("context only"))

    # measuring combined accuracy (prior + context)
    el_pipe.cfg["incl_context"] = True
    el_pipe.cfg["incl_prior"] = True
    results = get_eval_results(dev_data, el_pipe)
    logger.info(results.report_metrics("context and prior"))

def get_eval_results(data, el_pipe=None):
    # If the docs in the data require further processing with an entity linker, set el_pipe
    from tqdm import tqdm

    docs = []
    golds = []
    for d, g in tqdm(data, leave=False):
        if len(d) > 0:
            golds.append(g)
            if el_pipe is not None:
                docs.append(el_pipe(d))
            else:
                docs.append(d)

    results = EvaluationResults()
    for doc, gold in zip(docs, golds):
        tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
        try:
            correct_entries_per_article = dict()
            for entity, kb_dict in gold.links.items():
                start, end = entity
                # only evaluating on positive examples
                for gold_kb, value in kb_dict.items():
                    if value:
                        offset = _offset(start, end)
                        correct_entries_per_article[offset] = gold_kb
                        if offset not in tagged_entries_per_article:
                            results.increment_false_negatives()

            for ent in doc.ents:
                ent_label = ent.label_
                pred_entity = ent.kb_id_
                start = ent.start_char
                end = ent.end_char
                offset = _offset(start, end)
                gold_entity = correct_entries_per_article.get(offset, None)
                # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
                if gold_entity is not None:
                    results.update_metrics(ent_label, gold_entity, pred_entity)

        except Exception as e:
            # use the module-level logger so the message carries this module's name
            logger.error("Error assessing accuracy " + str(e))

    return results

def measure_baselines(data, kb):
    # Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
    counts_d = dict()

    baseline_results = BaselineResults()

    docs = [d for d, g in data if len(d) > 0]
    golds = [g for d, g in data if len(d) > 0]

    for doc, gold in zip(docs, golds):
        correct_entries_per_article = dict()
        tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
        for entity, kb_dict in gold.links.items():
            start, end = entity
            for gold_kb, value in kb_dict.items():
                # only evaluating on positive examples
                if value:
                    offset = _offset(start, end)
                    correct_entries_per_article[offset] = gold_kb
                    if offset not in tagged_entries_per_article:
                        baseline_results.random.increment_false_negatives()
                        baseline_results.oracle.increment_false_negatives()
                        baseline_results.prior.increment_false_negatives()

        for ent in doc.ents:
            ent_label = ent.label_
            start = ent.start_char
            end = ent.end_char
            offset = _offset(start, end)
            gold_entity = correct_entries_per_article.get(offset, None)

            # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
            if gold_entity is not None:
                candidates = kb.get_candidates(ent.text)
                oracle_candidate = ""
                best_candidate = ""
                random_candidate = ""
                if candidates:
                    scores = []

                    for c in candidates:
                        scores.append(c.prior_prob)
                        if c.entity_ == gold_entity:
                            oracle_candidate = c.entity_

                    best_index = scores.index(max(scores))
                    best_candidate = candidates[best_index].entity_
                    random_candidate = random.choice(candidates).entity_

                baseline_results.update_baselines(gold_entity, ent_label,
                                                  random_candidate, best_candidate, oracle_candidate)

    return baseline_results


def _offset(start, end):
    return "{}_{}".format(start, end)
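
To make the counting rules in `Metrics.update_results` concrete, here is a small standalone sketch that mirrors them on made-up gold/predicted pairs. The class name `Counts` and the entity IDs are hypothetical; the update logic follows the file above, where a wrong prediction always counts as a false negative and counts as a false positive only when the candidate is neither empty nor `"NIL"`.

```python
class Counts:
    """Standalone mirror of the counting rules in Metrics (illustrative only)."""

    def __init__(self):
        self.true_pos = 0
        self.false_pos = 0
        self.false_neg = 0

    def update(self, true_entity, candidate):
        correct = true_entity == candidate
        self.true_pos += correct
        self.false_neg += not correct
        # An empty or "NIL" candidate is a miss, not a wrong link.
        if candidate not in {"", "NIL"}:
            self.false_pos += not correct

    def precision(self):
        if self.true_pos == 0:
            return 0.0
        return self.true_pos / (self.true_pos + self.false_pos)

    def recall(self):
        if self.true_pos == 0:
            return 0.0
        return self.true_pos / (self.true_pos + self.false_neg)


counts = Counts()
# Made-up (gold, predicted) pairs: two hits, one wrong link, one NIL prediction.
for gold, predicted in [("Q1", "Q1"), ("Q2", "Q3"), ("Q4", "NIL"), ("Q5", "Q5")]:
    counts.update(gold, predicted)

print(counts.true_pos, counts.false_pos, counts.false_neg)
print(counts.precision(), counts.recall())
```

In this sample the `"NIL"` prediction lowers recall but leaves precision untouched, which matches the comment in the file that the data has no labeled negatives.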


@ -1,12 +1,20 @@
# coding: utf-8
from __future__ import unicode_literals
from bin.wiki_entity_linking.train_descriptions import EntityEncoder
from bin.wiki_entity_linking import wikidata_processor as wd, wikipedia_processor as wp
import csv
import logging
import spacy
import sys
from spacy.kb import KnowledgeBase
import csv
import datetime
from bin.wiki_entity_linking import wikipedia_processor as wp
from bin.wiki_entity_linking.train_descriptions import EntityEncoder
csv.field_size_limit(sys.maxsize)
logger = logging.getLogger(__name__)
def create_kb(
@ -14,52 +22,73 @@ def create_kb(
max_entities_per_alias,
min_entity_freq,
min_occ,
entity_def_output,
entity_descr_output,
entity_def_input,
entity_descr_path,
count_input,
prior_prob_input,
wikidata_input,
entity_vector_length,
limit=None,
read_raw_data=True,
):
# Create the knowledge base from Wikidata entries
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=entity_vector_length)
# read the mappings from file
title_to_id = get_entity_to_id(entity_def_input)
id_to_descr = get_id_to_description(entity_descr_path)
# check the length of the nlp vectors
if "vectors" in nlp.meta and nlp.vocab.vectors.size:
input_dim = nlp.vocab.vectors_length
print("Loaded pre-trained vectors of size %s" % input_dim)
logger.info("Loaded pre-trained vectors of size %s" % input_dim)
else:
raise ValueError(
"The `nlp` object should have access to pre-trained word vectors, "
" cf. https://spacy.io/usage/models#languages."
)
# disable this part of the pipeline when rerunning the KB generation from preprocessed files
if read_raw_data:
print()
print(now(), " * read wikidata entities:")
title_to_id, id_to_descr = wd.read_wikidata_entities_json(
wikidata_input, limit=limit
)
# write the title-ID and ID-description mappings to file
_write_entity_files(
entity_def_output, entity_descr_output, title_to_id, id_to_descr
)
else:
# read the mappings from file
title_to_id = get_entity_to_id(entity_def_output)
id_to_descr = get_id_to_description(entity_descr_output)
print()
print(now(), " * get entity frequencies:")
print()
logger.info("Get entity frequencies")
entity_frequencies = wp.get_all_frequencies(count_input=count_input)
logger.info("Filtering entities with fewer than {} mentions".format(min_entity_freq))
# filter the entities in the KB by frequency, because there's just too much data (8M entities) otherwise
filtered_title_to_id, entity_list, description_list, frequency_list = get_filtered_entities(
title_to_id,
id_to_descr,
entity_frequencies,
min_entity_freq
)
logger.info("Left with {} entities".format(len(description_list)))
logger.info("Train entity encoder")
encoder = EntityEncoder(nlp, input_dim, entity_vector_length)
encoder.train(description_list=description_list, to_print=True)
logger.info("Get entity embeddings:")
embeddings = encoder.apply_encoder(description_list)
logger.info("Adding {} entities".format(len(entity_list)))
kb.set_entities(
entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
)
logger.info("Adding aliases")
_add_aliases(
kb,
title_to_id=filtered_title_to_id,
max_entities_per_alias=max_entities_per_alias,
min_occ=min_occ,
prior_prob_input=prior_prob_input,
)
logger.info("KB size: {} entities, {} aliases".format(
kb.get_size_entities(),
kb.get_size_aliases()))
logger.info("Done with kb")
return kb
def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies,
min_entity_freq: int = 10):
filtered_title_to_id = dict()
entity_list = []
description_list = []
@ -72,58 +101,7 @@ def create_kb(
description_list.append(desc)
frequency_list.append(freq)
filtered_title_to_id[title] = entity
print(len(title_to_id.keys()), "original titles")
kept_nr = len(filtered_title_to_id.keys())
print("kept", kept_nr, "entities with min. frequency", min_entity_freq)
print()
print(now(), " * train entity encoder:")
print()
encoder = EntityEncoder(nlp, input_dim, entity_vector_length)
encoder.train(description_list=description_list, to_print=True)
print()
print(now(), " * get entity embeddings:")
print()
embeddings = encoder.apply_encoder(description_list)
print(now(), " * adding", len(entity_list), "entities")
kb.set_entities(
entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
)
alias_cnt = _add_aliases(
kb,
title_to_id=filtered_title_to_id,
max_entities_per_alias=max_entities_per_alias,
min_occ=min_occ,
prior_prob_input=prior_prob_input,
)
print()
print(now(), " * adding", alias_cnt, "aliases")
print()
print()
print("# of entities in kb:", kb.get_size_entities())
print("# of aliases in kb:", kb.get_size_aliases())
print(now(), "Done with kb")
return kb
def _write_entity_files(
entity_def_output, entity_descr_output, title_to_id, id_to_descr
):
with entity_def_output.open("w", encoding="utf8") as id_file:
id_file.write("WP_title" + "|" + "WD_id" + "\n")
for title, qid in title_to_id.items():
id_file.write(title + "|" + str(qid) + "\n")
with entity_descr_output.open("w", encoding="utf8") as descr_file:
descr_file.write("WD_id" + "|" + "description" + "\n")
for qid, descr in id_to_descr.items():
descr_file.write(str(qid) + "|" + descr + "\n")
return filtered_title_to_id, entity_list, description_list, frequency_list
def get_entity_to_id(entity_def_output):
@ -137,9 +115,9 @@ def get_entity_to_id(entity_def_output):
return entity_to_id
def get_id_to_description(entity_descr_output):
def get_id_to_description(entity_descr_path):
id_to_desc = dict()
with entity_descr_output.open("r", encoding="utf8") as csvfile:
with entity_descr_path.open("r", encoding="utf8") as csvfile:
csvreader = csv.reader(csvfile, delimiter="|")
# skip header
next(csvreader)
@ -150,7 +128,6 @@ def get_id_to_description(entity_descr_output):
def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input):
wp_titles = title_to_id.keys()
cnt = 0
# adding aliases with prior probabilities
# we can read this file sequentially, it's sorted by alias, and then by count
@ -187,9 +164,8 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
entities=selected_entities,
probabilities=prior_probs,
)
cnt += 1
except ValueError as e:
print(e)
logger.error(e)
total_count = 0
counts = []
entities = []
@ -202,8 +178,12 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
previous_alias = new_alias
line = prior_file.readline()
return cnt
def now():
return datetime.datetime.now()
def read_nlp_kb(model_dir, kb_file):
nlp = spacy.load(model_dir)
kb = KnowledgeBase(vocab=nlp.vocab)
kb.load_bulk(kb_file)
logger.info("kb entities: {}".format(kb.get_size_entities()))
logger.info("kb aliases: {}".format(kb.get_size_aliases()))
return nlp, kb


@ -1,6 +1,7 @@
# coding: utf-8
from random import shuffle
import logging
import numpy as np
from spacy._ml import zero_init, create_default_optimizer
@ -10,6 +11,8 @@ from thinc.v2v import Model
from thinc.api import chain
from thinc.neural._classes.affine import Affine
logger = logging.getLogger(__name__)
class EntityEncoder:
"""
@ -50,21 +53,19 @@ class EntityEncoder:
start = start + batch_size
stop = min(stop + batch_size, len(description_list))
print("encoded:", stop, "entities")
logger.info("encoded: {} entities".format(stop))
return encodings
def train(self, description_list, to_print=False):
processed, loss = self._train_model(description_list)
if to_print:
print(
"Trained entity descriptions on",
processed,
"(non-unique) entities across",
self.epochs,
"epochs",
logger.info(
"Trained entity descriptions on {} ".format(processed) +
"(non-unique) entities across {} ".format(self.epochs) +
"epochs"
)
print("Final loss:", loss)
logger.info("Final loss: {}".format(loss))
def _train_model(self, description_list):
best_loss = 1.0
@ -93,7 +94,7 @@ class EntityEncoder:
loss = self._update(batch)
if batch_nr % 25 == 0:
print("loss:", loss)
logger.info("loss: {} ".format(loss))
processed += len(batch)
# in general, continue training if we haven't reached our ideal min yet


@ -1,10 +1,13 @@
# coding: utf-8
from __future__ import unicode_literals
import logging
import random
import re
import bz2
import datetime
import json
from functools import partial
from spacy.gold import GoldParse
from bin.wiki_entity_linking import kb_creator
@ -15,18 +18,30 @@ Gold-standard entities are stored in one file in standoff format (by character o
"""
ENTITY_FILE = "gold_entities.csv"
logger = logging.getLogger(__name__)
def now():
return datetime.datetime.now()
def create_training(wikipedia_input, entity_def_input, training_output, limit=None):
def create_training_examples_and_descriptions(wikipedia_input,
entity_def_input,
description_output,
training_output,
parse_descriptions,
limit=None):
wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
_process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=limit)
_process_wikipedia_texts(wikipedia_input,
wp_to_id,
description_output,
training_output,
parse_descriptions,
limit)
def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None):
def _process_wikipedia_texts(wikipedia_input,
wp_to_id,
output,
training_output,
parse_descriptions,
limit=None):
"""
Read the XML wikipedia data to parse out training data:
raw text data + positive instances
@ -35,29 +50,21 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
read_ids = set()
entityfile_loc = training_output / ENTITY_FILE
with entityfile_loc.open("w", encoding="utf8") as entityfile:
# write entity training header file
_write_training_entity(
outputfile=entityfile,
article_id="article_id",
alias="alias",
entity="WD_id",
start="start",
end="end",
)
with output.open("a", encoding="utf8") as descr_file, training_output.open("w", encoding="utf8") as entity_file:
if parse_descriptions:
_write_training_description(descr_file, "WD_id", "description")
with bz2.open(wikipedia_input, mode="rb") as file:
line = file.readline()
cnt = 0
article_count = 0
article_text = ""
article_title = None
article_id = None
reading_text = False
reading_revision = False
while line and (not limit or cnt < limit):
if cnt % 1000000 == 0:
print(now(), "processed", cnt, "lines of Wikipedia dump")
logger.info("Processed {} articles".format(article_count))
for line in file:
clean_line = line.strip().decode("utf-8")
if clean_line == "<revision>":
@ -70,28 +77,32 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
article_text = ""
article_title = None
article_id = None
# finished reading this page
elif clean_line == "</page>":
if article_id:
try:
_process_wp_text(
wp_to_id,
entityfile,
article_id,
article_title,
article_text.strip(),
training_output,
)
except Exception as e:
print(
"Error processing article", article_id, article_title, e
)
else:
print(
"Done processing a page, but couldn't find an article_id ?",
clean_text, entities = _process_wp_text(
article_title,
article_text,
wp_to_id
)
if clean_text is not None and entities is not None:
_write_training_entities(entity_file,
article_id,
clean_text,
entities)
if article_title in wp_to_id and parse_descriptions:
description = " ".join(clean_text[:1000].split(" ")[:-1])
_write_training_description(
descr_file,
wp_to_id[article_title],
description
)
article_count += 1
if article_count % 10000 == 0:
logger.info("Processed {} articles".format(article_count))
if limit and article_count >= limit:
break
article_text = ""
article_title = None
article_id = None
@ -115,7 +126,7 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
if ids:
article_id = ids[0]
if article_id in read_ids:
print(
logger.info(
"Found duplicate article ID %s %s", article_id, clean_line
) # This should never happen ...
read_ids.add(article_id)
@ -125,115 +136,10 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
titles = title_regex.search(clean_line)
if titles:
article_title = titles[0].strip()
line = file.readline()
cnt += 1
print(now(), "processed", cnt, "lines of Wikipedia dump")
logger.info("Finished. Processed {} articles".format(article_count))
text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")
def _process_wp_text(
wp_to_id, entityfile, article_id, article_title, article_text, training_output
):
found_entities = False
# ignore meta Wikipedia pages
if article_title.startswith("Wikipedia:"):
return
# remove the text tags
text = text_regex.search(article_text).group(0)
# stop processing if this is a redirect page
if text.startswith("#REDIRECT"):
return
# get the raw text without markup etc, keeping only interwiki links
clean_text = _get_clean_wp_text(text)
# read the text char by char to get the right offsets for the interwiki links
final_text = ""
open_read = 0
reading_text = True
reading_entity = False
reading_mention = False
reading_special_case = False
entity_buffer = ""
mention_buffer = ""
for index, letter in enumerate(clean_text):
if letter == "[":
open_read += 1
elif letter == "]":
open_read -= 1
elif letter == "|":
if reading_text:
final_text += letter
# switch from reading entity to mention in the [[entity|mention]] pattern
elif reading_entity:
reading_text = False
reading_entity = False
reading_mention = True
else:
reading_special_case = True
else:
if reading_entity:
entity_buffer += letter
elif reading_mention:
mention_buffer += letter
elif reading_text:
final_text += letter
else:
raise ValueError("Not sure at point", clean_text[index - 2 : index + 2])
if open_read > 2:
reading_special_case = True
if open_read == 2 and reading_text:
reading_text = False
reading_entity = True
reading_mention = False
# we just finished reading an entity
if open_read == 0 and not reading_text:
if "#" in entity_buffer or entity_buffer.startswith(":"):
reading_special_case = True
# Ignore cases with nested structures like File: handles etc
if not reading_special_case:
if not mention_buffer:
mention_buffer = entity_buffer
start = len(final_text)
end = start + len(mention_buffer)
qid = wp_to_id.get(entity_buffer, None)
if qid:
_write_training_entity(
outputfile=entityfile,
article_id=article_id,
alias=mention_buffer,
entity=qid,
start=start,
end=end,
)
found_entities = True
final_text += mention_buffer
entity_buffer = ""
mention_buffer = ""
reading_text = True
reading_entity = False
reading_mention = False
reading_special_case = False
if found_entities:
_write_training_article(
article_id=article_id,
clean_text=final_text,
training_output=training_output,
)
info_regex = re.compile(r"{[^{]*?}")
htlm_regex = re.compile(r"&lt;!--[^-]*--&gt;")
category_regex = re.compile(r"\[\[Category:[^\[]*]]")
@ -242,6 +148,29 @@ ref_regex = re.compile(r"&lt;ref.*?&gt;") # non-greedy
ref_2_regex = re.compile(r"&lt;/ref.*?&gt;") # non-greedy
def _process_wp_text(article_title, article_text, wp_to_id):
    # ignore meta Wikipedia pages
    if (
        article_title.startswith("Wikipedia:") or
        article_title.startswith("Kategori:")
    ):
        return None, None

    # remove the text tags
    text_search = text_regex.search(article_text)
    if text_search is None:
        return None, None
    text = text_search.group(0)

    # stop processing if this is a redirect page
    if text.startswith("#REDIRECT"):
        return None, None

    # get the raw text without markup etc, keeping only interwiki links
    clean_text, entities = _remove_links(_get_clean_wp_text(text), wp_to_id)
    return clean_text, entities
def _get_clean_wp_text(article_text):
clean_text = article_text.strip()
@ -300,130 +229,167 @@ def _get_clean_wp_text(article_text):
return clean_text.strip()
def _write_training_article(article_id, clean_text, training_output):
file_loc = training_output / "{}.txt".format(article_id)
with file_loc.open("w", encoding="utf8") as outputfile:
outputfile.write(clean_text)
def _remove_links(clean_text, wp_to_id):
    # read the text char by char to get the right offsets for the interwiki links
    entities = []
    final_text = ""
    open_read = 0
    reading_text = True
    reading_entity = False
    reading_mention = False
    reading_special_case = False
    entity_buffer = ""
    mention_buffer = ""
    for index, letter in enumerate(clean_text):
        if letter == "[":
            open_read += 1
        elif letter == "]":
            open_read -= 1
        elif letter == "|":
            if reading_text:
                final_text += letter
            # switch from reading entity to mention in the [[entity|mention]] pattern
            elif reading_entity:
                reading_text = False
                reading_entity = False
                reading_mention = True
            else:
                reading_special_case = True
        else:
            if reading_entity:
                entity_buffer += letter
            elif reading_mention:
                mention_buffer += letter
            elif reading_text:
                final_text += letter
            else:
                raise ValueError("Not sure at point", clean_text[index - 2: index + 2])

        if open_read > 2:
            reading_special_case = True

        if open_read == 2 and reading_text:
            reading_text = False
            reading_entity = True
            reading_mention = False

        # we just finished reading an entity
        if open_read == 0 and not reading_text:
            if "#" in entity_buffer or entity_buffer.startswith(":"):
                reading_special_case = True
            # Ignore cases with nested structures like File: handles etc
            if not reading_special_case:
                if not mention_buffer:
                    mention_buffer = entity_buffer
                start = len(final_text)
                end = start + len(mention_buffer)
                qid = wp_to_id.get(entity_buffer, None)
                if qid:
                    entities.append((mention_buffer, qid, start, end))
                final_text += mention_buffer

            entity_buffer = ""
            mention_buffer = ""

            reading_text = True
            reading_entity = False
            reading_mention = False
            reading_special_case = False
    return final_text, entities
def _write_training_entity(outputfile, article_id, alias, entity, start, end):
line = "{}|{}|{}|{}|{}\n".format(article_id, alias, entity, start, end)
def _write_training_description(outputfile, qid, description):
if description is not None:
line = str(qid) + "|" + description + "\n"
outputfile.write(line)
def _write_training_entities(outputfile, article_id, clean_text, entities):
entities_data = [{"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]} for ent in entities]
line = json.dumps(
{
"article_id": article_id,
"clean_text": clean_text,
"entities": entities_data
},
ensure_ascii=False) + "\n"
outputfile.write(line)
def read_training(nlp, entity_file_path, dev, limit, kb):
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
For training, it will include negative training examples by using the candidate generator,
and it will only keep positive training examples that can be found by using the candidate generator.
For testing, it will include all positive examples only."""
from tqdm import tqdm
data = []
num_entities = 0
get_gold_parse = partial(_get_gold_parse, dev=dev, kb=kb)
logger.info("Reading {} data with limit {}".format('dev' if dev else 'train', limit))
with entity_file_path.open("r", encoding="utf8") as file:
with tqdm(total=limit, leave=False) as pbar:
for i, line in enumerate(file):
example = json.loads(line)
article_id = example["article_id"]
clean_text = example["clean_text"]
entities = example["entities"]
if dev != is_dev(article_id) or len(clean_text) >= 30000:
continue
doc = nlp(clean_text)
gold = get_gold_parse(doc, entities)
if gold and len(gold.links) > 0:
data.append((doc, gold))
num_entities += len(gold.links)
pbar.update(len(gold.links))
if limit and num_entities >= limit:
break
logger.info("Read {} entities in {} articles".format(num_entities, len(data)))
return data
def _get_gold_parse(doc, entities, dev, kb):
gold_entities = {}
tagged_ent_positions = set(
[(ent.start_char, ent.end_char) for ent in doc.ents]
)
for entity in entities:
entity_id = entity["entity"]
alias = entity["alias"]
start = entity["start"]
end = entity["end"]
candidates = kb.get_candidates(alias)
candidate_ids = [
c.entity_ for c in candidates
]
should_add_ent = (
dev or
(
(start, end) in tagged_ent_positions and
entity_id in candidate_ids and
len(candidates) > 1
)
)
if should_add_ent:
value_by_id = {entity_id: 1.0}
if not dev:
random.shuffle(candidate_ids)
value_by_id.update({
kb_id: 0.0
for kb_id in candidate_ids
if kb_id != entity_id
})
gold_entities[(start, end)] = value_by_id
return GoldParse(doc, links=gold_entities)
def is_dev(article_id):
return article_id.endswith("3")
def read_training(nlp, training_dir, dev, limit, kb=None):
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
When kb is provided (for training), it will include negative training examples by using the candidate generator,
and it will only keep positive training examples that can be found in the KB.
When kb=None (for testing), it will include all positive examples only."""
entityfile_loc = training_dir / ENTITY_FILE
data = []
# assume the data is written sequentially, so we can reuse the article docs
current_article_id = None
current_doc = None
ents_by_offset = dict()
skip_articles = set()
total_entities = 0
with entityfile_loc.open("r", encoding="utf8") as file:
for line in file:
if not limit or len(data) < limit:
fields = line.replace("\n", "").split(sep="|")
article_id = fields[0]
alias = fields[1]
wd_id = fields[2]
start = fields[3]
end = fields[4]
if (
dev == is_dev(article_id)
and article_id != "article_id"
and article_id not in skip_articles
):
if not current_doc or (current_article_id != article_id):
# parse the new article text
file_name = article_id + ".txt"
try:
training_file = training_dir / file_name
with training_file.open("r", encoding="utf8") as f:
text = f.read()
# threshold for convenience / speed of processing
if len(text) < 30000:
current_doc = nlp(text)
current_article_id = article_id
ents_by_offset = dict()
for ent in current_doc.ents:
sent_length = len(ent.sent)
# custom filtering to avoid too long or too short sentences
if 5 < sent_length < 100:
offset = "{}_{}".format(
ent.start_char, ent.end_char
)
ents_by_offset[offset] = ent
else:
skip_articles.add(article_id)
current_doc = None
except Exception as e:
print("Problem parsing article", article_id, e)
skip_articles.add(article_id)
# repeat checking this condition in case an exception was thrown
if current_doc and (current_article_id == article_id):
offset = "{}_{}".format(start, end)
found_ent = ents_by_offset.get(offset, None)
if found_ent:
if found_ent.text != alias:
skip_articles.add(article_id)
current_doc = None
else:
sent = found_ent.sent.as_doc()
gold_start = int(start) - found_ent.sent.start_char
gold_end = int(end) - found_ent.sent.start_char
gold_entities = {}
found_useful = False
for ent in sent.ents:
entry = (ent.start_char, ent.end_char)
gold_entry = (gold_start, gold_end)
if entry == gold_entry:
# add both pos and neg examples (in random order)
# this will exclude examples not in the KB
if kb:
value_by_id = {}
candidates = kb.get_candidates(alias)
candidate_ids = [
c.entity_ for c in candidates
]
random.shuffle(candidate_ids)
for kb_id in candidate_ids:
found_useful = True
if kb_id != wd_id:
value_by_id[kb_id] = 0.0
else:
value_by_id[kb_id] = 1.0
gold_entities[entry] = value_by_id
# if no KB, keep all positive examples
else:
found_useful = True
value_by_id = {wd_id: 1.0}
gold_entities[entry] = value_by_id
# currently feeding the gold data one entity per sentence at a time
# setting all other entities to empty gold dictionary
else:
gold_entities[entry] = {}
if found_useful:
gold = GoldParse(doc=sent, links=gold_entities)
data.append((sent, gold))
total_entities += 1
if len(data) % 2500 == 0:
print(" -read", total_entities, "entities")
print(" -read", total_entities, "entities")
return data
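
Review note: the offset bookkeeping in `_remove_links` above is easier to follow on a small example. The sketch below is not the code from this commit: it swaps the char-by-char state machine for a regex, so it only handles flat `[[entity]]` and `[[entity|mention]]` links and none of the special cases (nested brackets, `#` anchors, leading `:`). The name `extract_links` and the sample mapping are hypothetical.

```python
import re

# Matches flat [[entity]] or [[entity|mention]] links only (no nesting).
LINK_RE = re.compile(r"\[\[([^\[\]|]+)(?:\|([^\[\]|]+))?\]\]")


def extract_links(text, wp_to_id):
    """Strip link markup and return (clean_text, entities) with char offsets."""
    entities = []
    parts = []
    length = 0  # length of the clean text built so far
    pos = 0  # position in the raw text after the previous link
    for match in LINK_RE.finditer(text):
        before = text[pos:match.start()]
        parts.append(before)
        length += len(before)
        entity = match.group(1)
        # without an explicit mention, the entity title doubles as the mention
        mention = match.group(2) or entity
        qid = wp_to_id.get(entity)
        if qid:
            # offsets refer to the clean text, not the raw markup
            entities.append((mention, qid, length, length + len(mention)))
        parts.append(mention)
        length += len(mention)
        pos = match.end()
    parts.append(text[pos:])
    return "".join(parts), entities
```

As in the diff, a link whose title is missing from `wp_to_id` still contributes its mention to the clean text; it just produces no entity tuple.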


@ -13,27 +13,25 @@ from https://dumps.wikimedia.org/enwiki/latest/
"""
from __future__ import unicode_literals
import datetime
import logging
from pathlib import Path
import plac
from bin.wiki_entity_linking import wikipedia_processor as wp
from bin.wiki_entity_linking import wikipedia_processor as wp, wikidata_processor as wd
from bin.wiki_entity_linking import kb_creator
from bin.wiki_entity_linking import training_set_creator
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_FILE, ENTITY_DESCR_PATH, KB_MODEL_DIR, LOG_FORMAT
from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH
import spacy
from spacy import Errors
def now():
return datetime.datetime.now()
logger = logging.getLogger(__name__)
@plac.annotations(
wd_json=("Path to the downloaded WikiData JSON dump.", "positional", None, Path),
wp_xml=("Path to the downloaded Wikipedia XML dump.", "positional", None, Path),
output_dir=("Output directory", "positional", None, Path),
model=("Model name, should include pretrained vectors.", "positional", None, str),
model=("Model name or path, should include pretrained vectors.", "positional", None, str),
max_per_alias=("Max. # entities per alias (default 10)", "option", "a", int),
min_freq=("Min. count of an entity in the corpus (default 20)", "option", "f", int),
min_pair=("Min. count of entity-alias pairs (default 5)", "option", "c", int),
@ -41,7 +39,9 @@ def now():
loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path),
loc_entity_defs=("Location to file with entity definitions", "option", "d", Path),
loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path),
descriptions_from_wikipedia=("Flag for using wp descriptions not wd", "flag", "wp"),
limit=("Optional threshold to limit lines read from dumps", "option", "l", int),
lang=("Optional language for which to get wikidata titles. Defaults to 'en'", "option", "la", str),
)
def main(
wd_json,
@ -55,20 +55,29 @@ def main(
loc_prior_prob=None,
loc_entity_defs=None,
loc_entity_desc=None,
descriptions_from_wikipedia=False,
limit=None,
lang="en",
):
print(now(), "Creating KB with Wikipedia and WikiData")
print()
entity_defs_path = loc_entity_defs if loc_entity_defs else output_dir / ENTITY_DEFS_PATH
entity_descr_path = loc_entity_desc if loc_entity_desc else output_dir / ENTITY_DESCR_PATH
entity_freq_path = output_dir / ENTITY_FREQ_PATH
prior_prob_path = loc_prior_prob if loc_prior_prob else output_dir / PRIOR_PROB_PATH
training_entities_path = output_dir / TRAINING_DATA_FILE
kb_path = output_dir / KB_FILE
logger.info("Creating KB with Wikipedia and WikiData")
if limit is not None:
print("Warning: reading only", limit, "lines of Wikipedia/Wikidata dumps.")
logger.warning("Warning: reading only {} lines of Wikipedia/Wikidata dumps.".format(limit))
# STEP 0: set up IO
if not output_dir.exists():
output_dir.mkdir()
output_dir.mkdir(parents=True)
# STEP 1: create the NLP object
print(now(), "STEP 1: loaded model", model)
logger.info("STEP 1: Loading model {}".format(model))
nlp = spacy.load(model)
# check the length of the nlp vectors
@ -79,64 +88,68 @@ def main(
)
# STEP 2: create prior probabilities from WP
print()
if loc_prior_prob:
print(now(), "STEP 2: reading prior probabilities from", loc_prior_prob)
else:
if not prior_prob_path.exists():
# It takes about 2h to process 1000M lines of Wikipedia XML dump
loc_prior_prob = output_dir / "prior_prob.csv"
print(now(), "STEP 2: writing prior probabilities at", loc_prior_prob)
wp.read_prior_probs(wp_xml, loc_prior_prob, limit=limit)
logger.info("STEP 2: writing prior probabilities to {}".format(prior_prob_path))
wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit)
logger.info("STEP 2: reading prior probabilities from {}".format(prior_prob_path))
# STEP 3: deduce entity frequencies from WP (takes only a few minutes)
print()
print(now(), "STEP 3: calculating entity frequencies")
loc_entity_freq = output_dir / "entity_freq.csv"
wp.write_entity_counts(loc_prior_prob, loc_entity_freq, to_print=False)
logger.info("STEP 3: calculating entity frequencies")
wp.write_entity_counts(prior_prob_path, entity_freq_path, to_print=False)
loc_kb = output_dir / "kb"
# STEP 4: reading entity descriptions and definitions from WikiData or from file
print()
if loc_entity_defs and loc_entity_desc:
read_raw = False
print(now(), "STEP 4a: reading entity definitions from", loc_entity_defs)
print(now(), "STEP 4b: reading entity descriptions from", loc_entity_desc)
else:
# STEP 4: reading definitions and (possibly) descriptions from WikiData or from file
message = " and descriptions" if not descriptions_from_wikipedia else ""
if (not entity_defs_path.exists()) or (not descriptions_from_wikipedia and not entity_descr_path.exists()):
# It takes about 10h to process 55M lines of Wikidata JSON dump
read_raw = True
loc_entity_defs = output_dir / "entity_defs.csv"
loc_entity_desc = output_dir / "entity_descriptions.csv"
print(now(), "STEP 4: parsing wikidata for entity definitions and descriptions")
logger.info("STEP 4: parsing wikidata for entity definitions" + message)
title_to_id, id_to_descr = wd.read_wikidata_entities_json(
wd_json,
limit,
to_print=False,
lang=lang,
parse_descriptions=(not descriptions_from_wikipedia),
)
wd.write_entity_files(entity_defs_path, title_to_id)
if not descriptions_from_wikipedia:
wd.write_entity_description_files(entity_descr_path, id_to_descr)
logger.info("STEP 4: read entity definitions" + message)
# STEP 5: creating the actual KB
# STEP 5: Getting gold entities from wikipedia
message = " and descriptions" if descriptions_from_wikipedia else ""
if (not training_entities_path.exists()) or (descriptions_from_wikipedia and not entity_descr_path.exists()):
logger.info("STEP 5: parsing wikipedia for gold entities" + message)
training_set_creator.create_training_examples_and_descriptions(
wp_xml,
entity_defs_path,
entity_descr_path,
training_entities_path,
parse_descriptions=descriptions_from_wikipedia,
limit=limit,
)
logger.info("STEP 5: read gold entities" + message)
# STEP 6: creating the actual KB
# It takes ca. 30 minutes to pretrain the entity embeddings
print()
print(now(), "STEP 5: creating the KB at", loc_kb)
logger.info("STEP 6: creating the KB at {}".format(kb_path))
kb = kb_creator.create_kb(
nlp=nlp,
max_entities_per_alias=max_per_alias,
min_entity_freq=min_freq,
min_occ=min_pair,
entity_def_output=loc_entity_defs,
entity_descr_output=loc_entity_desc,
count_input=loc_entity_freq,
prior_prob_input=loc_prior_prob,
wikidata_input=wd_json,
entity_def_input=entity_defs_path,
entity_descr_path=entity_descr_path,
count_input=entity_freq_path,
prior_prob_input=prior_prob_path,
entity_vector_length=entity_vector_length,
limit=limit,
read_raw_data=read_raw,
)
if read_raw:
print(" - wrote entity definitions to", loc_entity_defs)
print(" - wrote entity descriptions to", loc_entity_desc)
kb.dump(loc_kb)
nlp.to_disk(output_dir / "nlp")
kb.dump(kb_path)
nlp.to_disk(output_dir / KB_MODEL_DIR)
print()
print(now(), "Done!")
logger.info("Done!")
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
plac.call(main)
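
Review note: steps 2, 4 and 5 of the rewritten script all follow the same pattern, namely to run the expensive parse only when its output file is missing and to reuse the file otherwise. A minimal sketch of that pattern, assuming a hypothetical helper `ensure_file` that does not exist in this commit:

```python
def ensure_file(path, build):
    """Run `build(path)` only when `path` does not exist yet.

    `path` is expected to be a pathlib.Path; returns True if the file was built.
    """
    if path.exists():
        # reuse the output of an earlier run
        return False
    path.parent.mkdir(parents=True, exist_ok=True)
    build(path)
    return True
```

One consequence of this design is that a stale or truncated intermediate file is silently reused, so after changing `limit` or `lang` the old outputs presumably have to be deleted by hand.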


@ -1,17 +1,19 @@
# coding: utf-8
from __future__ import unicode_literals
import bz2
import gzip
import json
import logging
import datetime
logger = logging.getLogger(__name__)
def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False):
def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descriptions=True):
# Read the JSON wiki data and parse out the entities. Takes about 7h30 to parse 55M lines.
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
lang = "en"
site_filter = "enwiki"
site_filter = '{}wiki'.format(lang)
# properties filter (currently disabled to get ALL data)
prop_filter = dict()
@ -24,18 +26,15 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False):
parse_properties = False
parse_sitelinks = True
parse_labels = False
parse_descriptions = True
parse_aliases = False
parse_claims = False
with bz2.open(wikidata_file, mode="rb") as file:
line = file.readline()
cnt = 0
while line and (not limit or cnt < limit):
if cnt % 1000000 == 0:
print(
datetime.datetime.now(), "processed", cnt, "lines of WikiData JSON dump"
)
with gzip.open(wikidata_file, mode='rb') as file:
for cnt, line in enumerate(file):
if limit and cnt >= limit:
break
if cnt % 500000 == 0:
logger.info("processed {} lines of WikiData dump".format(cnt))
clean_line = line.strip()
if clean_line.endswith(b","):
clean_line = clean_line[:-1]
@ -134,8 +133,19 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False):
if to_print:
print()
line = file.readline()
cnt += 1
print(datetime.datetime.now(), "processed", cnt, "lines of WikiData JSON dump")
return title_to_id, id_to_descr
def write_entity_files(entity_def_output, title_to_id):
with entity_def_output.open("w", encoding="utf8") as id_file:
id_file.write("WP_title" + "|" + "WD_id" + "\n")
for title, qid in title_to_id.items():
id_file.write(title + "|" + str(qid) + "\n")
def write_entity_description_files(entity_descr_output, id_to_descr):
with entity_descr_output.open("w", encoding="utf8") as descr_file:
descr_file.write("WD_id" + "|" + "description" + "\n")
for qid, descr in id_to_descr.items():
descr_file.write(str(qid) + "|" + descr + "\n")
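
Review note: the reading loop in `read_wikidata_entities_json` relies on the dump being one large JSON array with one entity per line, each followed by a comma. The sketch below isolates that loop. It assumes a gzip-compressed input, matching the `gzip.open` call in this hunk (the comment in the function still mentions `latest-all.json.bz2`), and it assumes the bracket lines can be skipped with a length check. `iter_dump_entities` is a hypothetical name.

```python
import gzip
import json


def iter_dump_entities(path, limit=None):
    """Yield one parsed JSON object per line of a gzipped, array-style dump."""
    with gzip.open(path, mode="rb") as file:
        for cnt, line in enumerate(file):
            if limit and cnt >= limit:
                break
            clean_line = line.strip()
            # entity lines end with a comma because they are array elements
            if clean_line.endswith(b","):
                clean_line = clean_line[:-1]
            # skip the "[" and "]" lines that open and close the array
            if len(clean_line) > 1:
                yield json.loads(clean_line)
```

Note that `limit` counts lines read, including the opening bracket, not entities yielded.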


@ -11,124 +11,84 @@ from https://dumps.wikimedia.org/enwiki/latest/
from __future__ import unicode_literals
import random
import datetime
import logging
from pathlib import Path
import plac
from bin.wiki_entity_linking import training_set_creator
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR
from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance, measure_baselines
from bin.wiki_entity_linking.kb_creator import read_nlp_kb
import spacy
from spacy.kb import KnowledgeBase
from spacy.util import minibatch, compounding
def now():
return datetime.datetime.now()
logger = logging.getLogger(__name__)
@plac.annotations(
dir_kb=("Directory with KB, NLP and related files", "positional", None, Path),
output_dir=("Output directory", "option", "o", Path),
loc_training=("Location to training data", "option", "k", Path),
wp_xml=("Path to the downloaded Wikipedia XML dump.", "option", "w", Path),
epochs=("Number of training iterations (default 10)", "option", "e", int),
dropout=("Dropout to prevent overfitting (default 0.5)", "option", "p", float),
lr=("Learning rate (default 0.005)", "option", "n", float),
l2=("L2 regularization", "option", "r", float),
train_inst=("# training instances (default 90% of all)", "option", "t", int),
dev_inst=("# test instances (default 10% of all)", "option", "d", int),
limit=("Optional threshold to limit lines read from WP dump", "option", "l", int),
)
def main(
dir_kb,
output_dir=None,
loc_training=None,
wp_xml=None,
epochs=10,
dropout=0.5,
lr=0.005,
l2=1e-6,
train_inst=None,
dev_inst=None,
limit=None,
):
print(now(), "Creating Entity Linker with Wikipedia and WikiData")
print()
logger.info("Creating Entity Linker with Wikipedia and WikiData")
output_dir = Path(output_dir) if output_dir else dir_kb
training_path = loc_training if loc_training else output_dir / TRAINING_DATA_FILE
nlp_dir = dir_kb / KB_MODEL_DIR
kb_path = output_dir / KB_FILE
nlp_output_dir = output_dir / OUTPUT_MODEL_DIR
# STEP 0: set up IO
if output_dir and not output_dir.exists():
if not output_dir.exists():
output_dir.mkdir()
# STEP 1 : load the NLP object
nlp_dir = dir_kb / "nlp"
print(now(), "STEP 1: loading model from", nlp_dir)
nlp = spacy.load(nlp_dir)
logger.info("STEP 1: loading model from {}".format(nlp_dir))
nlp, kb = read_nlp_kb(nlp_dir, kb_path)
# check that there is a NER component in the pipeline
if "ner" not in nlp.pipe_names:
raise ValueError("The `nlp` object should have a pre-trained `ner` component.")
# STEP 2 : read the KB
print()
print(now(), "STEP 2: reading the KB from", dir_kb / "kb")
kb = KnowledgeBase(vocab=nlp.vocab)
kb.load_bulk(dir_kb / "kb")
# STEP 2: create a training dataset from WP
logger.info("STEP 2: reading training dataset from {}".format(training_path))
# STEP 3: create a training dataset from WP
print()
if loc_training:
print(now(), "STEP 3: reading training dataset from", loc_training)
else:
if not wp_xml:
raise ValueError(
"Either provide a path to a preprocessed training directory, "
"or to the original Wikipedia XML dump."
)
if output_dir:
loc_training = output_dir / "training_data"
else:
loc_training = dir_kb / "training_data"
if not loc_training.exists():
loc_training.mkdir()
print(now(), "STEP 3: creating training dataset at", loc_training)
if limit is not None:
print("Warning: reading only", limit, "lines of Wikipedia dump.")
loc_entity_defs = dir_kb / "entity_defs.csv"
training_set_creator.create_training(
wikipedia_input=wp_xml,
entity_def_input=loc_entity_defs,
training_output=loc_training,
limit=limit,
)
# STEP 4: parse the training data
print()
print(now(), "STEP 4: parse the training & evaluation data")
# for training, get pos & neg instances that correspond to entries in the kb
print("Parsing training data, limit =", train_inst)
train_data = training_set_creator.read_training(
nlp=nlp, training_dir=loc_training, dev=False, limit=train_inst, kb=kb
nlp=nlp,
entity_file_path=training_path,
dev=False,
limit=train_inst,
kb=kb,
)
print("Training on", len(train_data), "articles")
print()
print("Parsing dev testing data, limit =", dev_inst)
# for testing, get all pos instances, whether or not they are in the kb
dev_data = training_set_creator.read_training(
nlp=nlp, training_dir=loc_training, dev=True, limit=dev_inst, kb=None
nlp=nlp,
entity_file_path=training_path,
dev=True,
limit=dev_inst,
kb=kb,
)
print("Dev testing on", len(dev_data), "articles")
print()
# STEP 5: create and train the entity linking pipe
print()
print(now(), "STEP 5: training Entity Linking pipe")
# STEP 3: create and train the entity linking pipe
logger.info("STEP 3: training Entity Linking pipe")
el_pipe = nlp.create_pipe(
name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name}
@ -142,275 +102,70 @@ def main(
optimizer.learn_rate = lr
optimizer.L2 = l2
if not train_data:
print("Did not find any training data")
else:
for itn in range(epochs):
random.shuffle(train_data)
losses = {}
batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
batchnr = 0
logger.info("Training on {} articles".format(len(train_data)))
logger.info("Dev testing on {} articles".format(len(dev_data)))
with nlp.disable_pipes(*other_pipes):
for batch in batches:
try:
docs, golds = zip(*batch)
nlp.update(
docs=docs,
golds=golds,
sgd=optimizer,
drop=dropout,
losses=losses,
)
batchnr += 1
except Exception as e:
print("Error updating batch:", e)
if batchnr > 0:
el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = True
dev_acc_context, _ = _measure_acc(dev_data, el_pipe)
losses["entity_linker"] = losses["entity_linker"] / batchnr
print(
"Epoch, train loss",
itn,
round(losses["entity_linker"], 2),
" / dev accuracy avg",
round(dev_acc_context, 3),
)
# STEP 6: measure the performance of our trained pipe on an independent dev set
print()
if len(dev_data):
print()
print(now(), "STEP 6: performance measurement of Entity Linking pipe")
print()
counts, acc_r, acc_r_d, acc_p, acc_p_d, acc_o, acc_o_d = _measure_baselines(
dev_data, kb
)
print("dev counts:", sorted(counts.items(), key=lambda x: x[0]))
oracle_by_label = [(x, round(y, 3)) for x, y in acc_o_d.items()]
print("dev accuracy oracle:", round(acc_o, 3), oracle_by_label)
random_by_label = [(x, round(y, 3)) for x, y in acc_r_d.items()]
print("dev accuracy random:", round(acc_r, 3), random_by_label)
prior_by_label = [(x, round(y, 3)) for x, y in acc_p_d.items()]
print("dev accuracy prior:", round(acc_p, 3), prior_by_label)
# using only context
el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = False
dev_acc_context, dev_acc_cont_d = _measure_acc(dev_data, el_pipe)
context_by_label = [(x, round(y, 3)) for x, y in dev_acc_cont_d.items()]
print("dev accuracy context:", round(dev_acc_context, 3), context_by_label)
# measuring combined accuracy (prior + context)
el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = True
dev_acc_combo, dev_acc_combo_d = _measure_acc(dev_data, el_pipe)
combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_d.items()]
print("dev accuracy prior+context:", round(dev_acc_combo, 3), combo_by_label)
# STEP 7: apply the EL pipe on a toy example
print()
print(now(), "STEP 7: applying Entity Linking to toy example")
print()
run_el_toy_example(nlp=nlp)
# STEP 8: write the NLP pipeline (including entity linker) to file
if output_dir:
print()
nlp_loc = output_dir / "nlp"
print(now(), "STEP 8: Writing trained NLP to", nlp_loc)
nlp.to_disk(nlp_loc)
print()
print()
print(now(), "Done!")
def _measure_acc(data, el_pipe=None, error_analysis=False):
# If the docs in the data require further processing with an entity linker, set el_pipe
correct_by_label = dict()
incorrect_by_label = dict()
docs = [d for d, g in data if len(d) > 0]
if el_pipe is not None:
docs = list(el_pipe.pipe(docs))
golds = [g for d, g in data if len(d) > 0]
for doc, gold in zip(docs, golds):
try:
correct_entries_per_article = dict()
for entity, kb_dict in gold.links.items():
start, end = entity
# only evaluating on positive examples
for gold_kb, value in kb_dict.items():
if value:
offset = _offset(start, end)
correct_entries_per_article[offset] = gold_kb
for ent in doc.ents:
ent_label = ent.label_
pred_entity = ent.kb_id_
start = ent.start_char
end = ent.end_char
offset = _offset(start, end)
gold_entity = correct_entries_per_article.get(offset, None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None:
if gold_entity == pred_entity:
correct = correct_by_label.get(ent_label, 0)
correct_by_label[ent_label] = correct + 1
else:
incorrect = incorrect_by_label.get(ent_label, 0)
incorrect_by_label[ent_label] = incorrect + 1
if error_analysis:
print(ent.text, "in", doc)
print(
"Predicted",
pred_entity,
"should have been",
gold_entity,
)
print()
except Exception as e:
print("Error assessing accuracy", e)
acc, acc_by_label = calculate_acc(correct_by_label, incorrect_by_label)
return acc, acc_by_label
def _measure_baselines(data, kb):
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
counts_d = dict()
random_correct_d = dict()
random_incorrect_d = dict()
oracle_correct_d = dict()
oracle_incorrect_d = dict()
prior_correct_d = dict()
prior_incorrect_d = dict()
docs = [d for d, g in data if len(d) > 0]
golds = [g for d, g in data if len(d) > 0]
for doc, gold in zip(docs, golds):
try:
correct_entries_per_article = dict()
for entity, kb_dict in gold.links.items():
start, end = entity
for gold_kb, value in kb_dict.items():
# only evaluating on positive examples
if value:
offset = _offset(start, end)
correct_entries_per_article[offset] = gold_kb
for ent in doc.ents:
label = ent.label_
start = ent.start_char
end = ent.end_char
offset = _offset(start, end)
gold_entity = correct_entries_per_article.get(offset, None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None:
counts_d[label] = counts_d.get(label, 0) + 1
candidates = kb.get_candidates(ent.text)
oracle_candidate = ""
best_candidate = ""
random_candidate = ""
if candidates:
scores = []
for c in candidates:
scores.append(c.prior_prob)
if c.entity_ == gold_entity:
oracle_candidate = c.entity_
best_index = scores.index(max(scores))
best_candidate = candidates[best_index].entity_
random_candidate = random.choice(candidates).entity_
if gold_entity == best_candidate:
prior_correct_d[label] = prior_correct_d.get(label, 0) + 1
else:
prior_incorrect_d[label] = prior_incorrect_d.get(label, 0) + 1
if gold_entity == random_candidate:
random_correct_d[label] = random_correct_d.get(label, 0) + 1
else:
random_incorrect_d[label] = random_incorrect_d.get(label, 0) + 1
if gold_entity == oracle_candidate:
oracle_correct_d[label] = oracle_correct_d.get(label, 0) + 1
else:
oracle_incorrect_d[label] = oracle_incorrect_d.get(label, 0) + 1
except Exception as e:
print("Error assessing accuracy", e)
acc_prior, acc_prior_d = calculate_acc(prior_correct_d, prior_incorrect_d)
acc_rand, acc_rand_d = calculate_acc(random_correct_d, random_incorrect_d)
acc_oracle, acc_oracle_d = calculate_acc(oracle_correct_d, oracle_incorrect_d)
return (
counts_d,
acc_rand,
acc_rand_d,
acc_prior,
acc_prior_d,
acc_oracle,
acc_oracle_d,
dev_baseline_accuracies = measure_baselines(
dev_data, kb
)
logger.info("Dev Baseline Accuracies:")
logger.info(dev_baseline_accuracies.report_accuracy("random"))
logger.info(dev_baseline_accuracies.report_accuracy("prior"))
logger.info(dev_baseline_accuracies.report_accuracy("oracle"))
def _offset(start, end):
return "{}_{}".format(start, end)
for itn in range(epochs):
random.shuffle(train_data)
losses = {}
batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
batchnr = 0
with nlp.disable_pipes(*other_pipes):
for batch in batches:
try:
docs, golds = zip(*batch)
nlp.update(
docs=docs,
golds=golds,
sgd=optimizer,
drop=dropout,
losses=losses,
)
batchnr += 1
except Exception as e:
logger.error("Error updating batch: " + str(e))
if batchnr > 0:
logger.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2)))
measure_performance(dev_data, kb, el_pipe)
def calculate_acc(correct_by_label, incorrect_by_label):
acc_by_label = dict()
total_correct = 0
total_incorrect = 0
all_keys = set()
all_keys.update(correct_by_label.keys())
all_keys.update(incorrect_by_label.keys())
for label in sorted(all_keys):
correct = correct_by_label.get(label, 0)
incorrect = incorrect_by_label.get(label, 0)
total_correct += correct
total_incorrect += incorrect
if correct == incorrect == 0:
acc_by_label[label] = 0
else:
acc_by_label[label] = correct / (correct + incorrect)
acc = 0
if not (total_correct == total_incorrect == 0):
acc = total_correct / (total_correct + total_incorrect)
return acc, acc_by_label
# STEP 4: measure the performance of our trained pipe on an independent dev set
logger.info("STEP 4: performance measurement of Entity Linking pipe")
measure_performance(dev_data, kb, el_pipe)
# STEP 5: apply the EL pipe on a toy example
logger.info("STEP 5: applying Entity Linking to toy example")
run_el_toy_example(nlp=nlp)
if output_dir:
# STEP 6: write the NLP pipeline (including entity linker) to file
logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir))
nlp.to_disk(nlp_output_dir)
logger.info("Done!")
def check_kb(kb):
for mention in ("Bush", "Douglas Adams", "Homer", "Brazil", "China"):
candidates = kb.get_candidates(mention)
print("generating candidates for " + mention + " :")
logger.info("generating candidates for " + mention + " :")
for c in candidates:
print(
" ",
c.prior_prob,
logger.info(" ".join([
str(c.prior_prob),
c.alias_,
"-->",
c.entity_ + " (freq=" + str(c.entity_freq) + ")",
)
print()
c.entity_ + " (freq=" + str(c.entity_freq) + ")"
]))
def run_el_toy_example(nlp):
@ -421,11 +176,11 @@ def run_el_toy_example(nlp):
"but Dougledydoug doesn't write about George Washington or Homer Simpson."
)
doc = nlp(text)
print(text)
logger.info(text)
for ent in doc.ents:
print(" ent", ent.text, ent.label_, ent.kb_id_)
print()
logger.info(" ".join(["ent", ent.text, ent.label_, ent.kb_id_]))
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
plac.call(main)
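
Review note: the training loop batches with `minibatch(train_data, size=compounding(4.0, 128.0, 1.001))`, so batch sizes start at 4 and grow by 0.1% per batch up to 128. The sketch below is a standalone approximation of that schedule, not spaCy's implementation: it only covers the increasing case, and the names `compounding_sizes` and `minibatches` are hypothetical.

```python
from itertools import islice


def compounding_sizes(start, stop, compound):
    """Yield start, start * compound, ... capped at stop (infinite generator)."""
    size = start
    while True:
        yield min(size, stop)
        size *= compound


def minibatches(items, sizes):
    """Split items into consecutive batches whose sizes come from `sizes`."""
    items = iter(items)
    for size in sizes:
        batch = list(islice(items, int(size)))
        if not batch:
            break
        yield batch
```

With a factor of 1.001 the growth is slow, so early updates in each epoch use small batches and the size only approaches the cap on large training sets.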


@ -5,6 +5,9 @@ import re
import bz2
import csv
import datetime
import logging
from bin.wiki_entity_linking import LOG_FORMAT
"""
Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions.
@ -13,6 +16,9 @@ Write these results to file for downstream KB and training data generation.
map_alias_to_link = dict()
logger = logging.getLogger(__name__)
# these will/should be matched ignoring case
wiki_namespaces = [
"b",
@ -116,10 +122,6 @@ for ns in wiki_namespaces:
ns_regex = re.compile(ns_regex, re.IGNORECASE)
def now():
return datetime.datetime.now()
def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
"""
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
@ -131,7 +133,7 @@ def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
cnt = 0
while line and (not limit or cnt < limit):
if cnt % 25000000 == 0:
print(now(), "processed", cnt, "lines of Wikipedia XML dump")
logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
clean_line = line.strip().decode("utf-8")
aliases, entities, normalizations = get_wp_links(clean_line)
@ -141,7 +143,7 @@ def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
line = file.readline()
cnt += 1
print(now(), "processed", cnt, "lines of Wikipedia XML dump")
logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
# write all aliases and their entities and count occurrences to file
with prior_prob_output.open("w", encoding="utf8") as outputfile:

@ -140,7 +140,7 @@ def gzip_language_data(root, source):
base = Path(root) / source
for jsonfile in base.glob("**/*.json"):
outfile = jsonfile.with_suffix(jsonfile.suffix + ".gz")
if outfile.is_file() and outfile.stat().st_ctime > jsonfile.stat().st_ctime:
if outfile.is_file() and outfile.stat().st_mtime > jsonfile.stat().st_mtime:
# If the gz is newer it doesn't need updating
print("Skipping {}, already compressed".format(jsonfile))
continue
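
Review note on the `st_ctime` to `st_mtime` change: `st_mtime` records when the file content was last modified, while `st_ctime` is the metadata-change time on Unix and the creation time on Windows, so it is a less reliable "is the archive newer?" signal. A minimal sketch of the same freshness check, assuming a hypothetical helper name `needs_compression` and a temporary directory in place of the real language data:

```python
import os
import tempfile
from pathlib import Path


def needs_compression(jsonfile):
    """Return True if the .gz sibling is missing or not newer than the source."""
    outfile = jsonfile.with_suffix(jsonfile.suffix + ".gz")
    if not outfile.is_file():
        return True
    # Compare content modification times, mirroring the updated hunk.
    return outfile.stat().st_mtime <= jsonfile.stat().st_mtime


with tempfile.TemporaryDirectory() as tmp:
    source = Path(tmp) / "lookup.json"
    source.write_text("{}")
    target = Path(tmp) / "lookup.json.gz"
    target.write_bytes(b"")
    # Set explicit timestamps so the comparison does not depend on timing.
    os.utime(str(source), (1000, 1000))
    os.utime(str(target), (2000, 2000))
    before_touch = needs_compression(source)
    # Simulate editing the source after the archive was written.
    os.utime(str(source), (3000, 3000))
    after_touch = needs_compression(source)
```

With the archive newer than the source the helper reports that nothing needs doing; once the source's modification time moves past the archive's, it asks for recompression.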

@ -6,7 +6,7 @@ import re
from ...gold import iob_to_biluo
def conllu2json(input_data, n_sents=10, use_morphology=False, lang=None):
def conllu2json(input_data, n_sents=10, use_morphology=False, lang=None, **_):
"""
Convert conllu files into JSON format for use with train cli.
use_morphology parameter enables appending morphology to tags, which is
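
Review note on the added `**_`: the diff does not state the motivation, but a trailing catch-all like this typically lets a caller pass one shared set of keyword options to several converters without each converter having to declare all of them. A minimal sketch under that assumption; `convert_conllu`, `convert_iob`, `CONVERTERS` and `convert` are hypothetical stand-ins, not spaCy functions:

```python
def convert_conllu(input_data, n_sents=10, use_morphology=False, lang=None, **_):
    # Hypothetical converter: unknown keyword options are swallowed by **_.
    return {"format": "conllu", "n_sents": n_sents, "lang": lang}


def convert_iob(input_data, n_sents=10, **_):
    # Hypothetical converter that only cares about n_sents.
    return {"format": "iob", "n_sents": n_sents}


CONVERTERS = {"conllu": convert_conllu, "iob": convert_iob}


def convert(kind, input_data, **options):
    # The dispatcher forwards every option; each converter keeps what it needs.
    return CONVERTERS[kind](input_data, **options)


shared = {"n_sents": 5, "lang": "en", "use_morphology": True}
results = [convert(kind, "", **shared) for kind in ("conllu", "iob")]
```

Without the `**_` parameter, `convert_iob` would raise a `TypeError` for the unexpected `lang` and `use_morphology` arguments.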

@ -455,7 +455,9 @@ class Errors(object):
E158 = ("Can't add table '{name}' to lookups because it already exists.")
E159 = ("Can't find table '{name}' in lookups. Available tables: {tables}")
E160 = ("Can't find language data file: {path}")
E161 = ("Tokenizer special cases are not allowed to modify the text. "
E161 = ("Found an internal inconsistency when predicting entity links. "
"This is likely a bug in spaCy, so feel free to open an issue.")
E162 = ("Tokenizer special cases are not allowed to modify the text. "
"This would map '{chunk}' to '{orth}' given token attributes "
"'{token_attrs}'.")

@ -11,6 +11,12 @@ _hebrew = r"\u0591-\u05F4\uFB1D-\uFB4F"
_hindi = r"\u0900-\u097F"
_kannada = r"\u0C80-\u0CFF"
_tamil = r"\u0B80-\u0BFF"
_telugu = r"\u0C00-\u0C7F"
# Latin standard
_latin_u_standard = r"A-Z"
_latin_l_standard = r"a-z"
@ -195,7 +201,7 @@ _ukrainian = r"а-щюяіїєґА-ЩЮЯІЇЄҐ"
_upper = LATIN_UPPER + _russian_upper + _tatar_upper + _greek_upper + _ukrainian_upper
_lower = LATIN_LOWER + _russian_lower + _tatar_lower + _greek_lower + _ukrainian_lower
_uncased = _bengali + _hebrew + _persian + _sinhala + _hindi
_uncased = _bengali + _hebrew + _persian + _sinhala + _hindi + _kannada + _tamil + _telugu
ALPHA = group_chars(LATIN + _russian + _tatar + _greek + _ukrainian + _uncased)
ALPHA_LOWER = group_chars(_lower + _uncased)

@ -208,6 +208,7 @@ class Language(object):
"name": self.vocab.vectors.name,
}
self._meta["pipeline"] = self.pipe_names
self._meta["labels"] = self.pipe_labels
return self._meta
@meta.setter
@ -248,6 +249,18 @@ class Language(object):
"""
return [pipe_name for pipe_name, _ in self.pipeline]
@property
def pipe_labels(self):
"""Get the labels set by the pipeline components, if available.
RETURNS (dict): Labels keyed by component name.
"""
labels = OrderedDict()
for name, pipe in self.pipeline:
if hasattr(pipe, "labels"):
labels[name] = list(pipe.labels)
return labels
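
Review note on `pipe_labels`: the property only reports components that expose a `labels` attribute, and it converts each value to a plain list so the result can be stored in the meta. A minimal sketch of that behavior outside of spaCy, assuming two hypothetical stand-in components (`FakeTagger` with labels, `FakeSentencizer` without):

```python
from collections import OrderedDict


class FakeTagger(object):
    # Hypothetical component that exposes labels as a tuple.
    labels = ("NOUN", "VERB")


class FakeSentencizer(object):
    # Hypothetical component without a labels attribute.
    pass


def collect_pipe_labels(pipeline):
    """Mirror the loop in the hunk above for a list of (name, component)."""
    labels = OrderedDict()
    for name, pipe in pipeline:
        if hasattr(pipe, "labels"):
            labels[name] = list(pipe.labels)
    return labels


pipeline = [("tagger", FakeTagger()), ("sentencizer", FakeSentencizer())]
collected = collect_pipe_labels(pipeline)
```

Components without labels are skipped rather than mapped to an empty list, so the number of keys can be smaller than the number of pipeline components.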
def get_pipe(self, name):
"""Get a pipeline component for a given component name.

@ -67,7 +67,7 @@ class Pipe(object):
"""
self.require_model()
predictions = self.predict([doc])
if isinstance(predictions, tuple) and len(tuple) == 2:
if isinstance(predictions, tuple) and len(predictions) == 2:
scores, tensors = predictions
self.set_annotations([doc], scores, tensor=tensors)
else:
@ -1062,8 +1062,15 @@ cdef class DependencyParser(Parser):
@property
def labels(self):
labels = set()
# Get the labels from the model by looking at the available moves
return tuple(set(move.split("-")[1] for move in self.move_names))
for move in self.move_names:
if "-" in move:
label = move.split("-")[1]
if "||" in label:
label = label.split("||")[1]
labels.add(label)
return tuple(sorted(labels))
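
Review note on the `DependencyParser.labels` rewrite: the old one-liner indexed `move.split("-")[1]` for every move name, which raises `IndexError` for names without a hyphen. The new loop skips those and keeps only the part after `||` when that delimiter is present. A minimal sketch of both versions; the move names used here are illustrative assumptions, not taken from a real model:

```python
def labels_old(move_names):
    # Old behavior: assumes every move name contains a hyphen.
    return tuple(set(move.split("-")[1] for move in move_names))


def labels_new(move_names):
    # New behavior, following the hunk above.
    labels = set()
    for move in move_names:
        if "-" in move:
            label = move.split("-")[1]
            if "||" in label:
                label = label.split("||")[1]
            labels.add(label)
    return tuple(sorted(labels))


# Illustrative move names (assumed for this sketch).
move_names = ["S", "D", "L-nsubj", "R-dobj", "B-ROOT", "R-amod||dobj"]

try:
    labels_old(move_names)
    old_error = None
except IndexError as err:
    old_error = err

new_labels = labels_new(move_names)
```

Sorting the result also makes the returned tuple deterministic, which the previous `tuple(set(...))` did not guarantee.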
cdef class EntityRecognizer(Parser):
@ -1098,8 +1105,9 @@ cdef class EntityRecognizer(Parser):
def labels(self):
# Get the labels from the model by looking at the available moves, e.g.
# B-PERSON, I-PERSON, L-PERSON, U-PERSON
return tuple(set(move.split("-")[1] for move in self.move_names
if move[0] in ("B", "I", "L", "U")))
labels = set(move.split("-")[1] for move in self.move_names
if move[0] in ("B", "I", "L", "U"))
return tuple(sorted(labels))
class EntityLinker(Pipe):
@ -1275,7 +1283,7 @@ class EntityLinker(Pipe):
# this will set all prior probabilities to 0 if they should be excluded from the model
prior_probs = xp.asarray([c.prior_prob for c in candidates])
if not self.cfg.get("incl_prior", True):
prior_probs = xp.asarray([[0.0] for c in candidates])
prior_probs = xp.asarray([0.0 for c in candidates])
scores = prior_probs
# add in similarity from the context
@ -1288,6 +1296,8 @@ class EntityLinker(Pipe):
# cosine similarity
sims = xp.dot(entity_encodings, context_enc_t) / (norm_1 * norm_2)
if sims.shape != prior_probs.shape:
raise ValueError(Errors.E161)
scores = prior_probs + sims - (prior_probs*sims)
# TODO: thresholding
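
Review note on the `prior_probs` fix and the new E161 check: building the zeros as `[[0.0] for c in candidates]` yields an array of shape `(n, 1)`, and adding that to a similarity vector of shape `(n,)` silently broadcasts to `(n, n)` instead of failing. A minimal sketch with NumPy standing in for `xp`; the similarity values are made up for illustration:

```python
import numpy as np

n_candidates = 3
# Made-up cosine similarities, one per candidate.
sims = np.asarray([0.2, 0.5, 0.9])

# Old construction: a column of zeros with shape (3, 1).
old_priors = np.asarray([[0.0] for _ in range(n_candidates)])
# New construction: a flat vector with shape (3,).
new_priors = np.asarray([0.0 for _ in range(n_candidates)])

# Same combination as in the hunk above.
old_scores = old_priors + sims - (old_priors * sims)
new_scores = new_priors + sims - (new_priors * sims)

# The guard added in the hunk compares shapes before combining.
old_would_raise = sims.shape != old_priors.shape
new_would_raise = sims.shape != new_priors.shape
```

With the flat vector the scores keep one entry per candidate, and with all priors at zero they reduce to the similarities themselves.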
@ -1361,7 +1371,16 @@ class Sentencizer(object):
"""
name = "sentencizer"
default_punct_chars = [".", "!", "?"]
default_punct_chars = ['!', '.', '?', '։', '؟', '۔', '܀', '܁', '܂', '߹',
'', '', '', '', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '', '', '', '', '', '', '', '᱿',
'', '', '', '', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '𐩖', '𐩗', '𑁇', '𑁈', '𑂾', '𑂿', '𑃀',
'𑃁', '𑅁', '𑅂', '𑅃', '𑇅', '𑇆', '𑇍', '𑇞', '𑇟', '𑈸', '𑈹', '𑈻', '𑈼',
'𑊩', '𑑋', '𑑌', '𑗂', '𑗃', '𑗉', '𑗊', '𑗋', '𑗌', '𑗍', '𑗎', '𑗏', '𑗐',
'𑗑', '𑗒', '𑗓', '𑗔', '𑗕', '𑗖', '𑗗', '𑙁', '𑙂', '𑜼', '𑜽', '𑜾', '𑩂',
'𑩃', '𑪛', '𑪜', '𑱁', '𑱂', '𖩮', '𖩯', '𖫵', '𖬷', '𖬸', '𖭄', '𛲟', '𝪈']
def __init__(self, punct_chars=None, **kwargs):
"""Initialize the sentencizer.
@ -1372,7 +1391,10 @@ class Sentencizer(object):
DOCS: https://spacy.io/api/sentencizer#init
"""
self.punct_chars = punct_chars or self.default_punct_chars
if punct_chars:
self.punct_chars = set(punct_chars)
else:
self.punct_chars = set(self.default_punct_chars)
def __call__(self, doc):
"""Apply the sentencizer to a Doc and set Token.is_sent_start.
@ -1404,7 +1426,7 @@ class Sentencizer(object):
DOCS: https://spacy.io/api/sentencizer#to_bytes
"""
return srsly.msgpack_dumps({"punct_chars": self.punct_chars})
return srsly.msgpack_dumps({"punct_chars": list(self.punct_chars)})
def from_bytes(self, bytes_data, **kwargs):
"""Load the sentencizer from a bytestring.
@ -1415,7 +1437,7 @@ class Sentencizer(object):
DOCS: https://spacy.io/api/sentencizer#from_bytes
"""
cfg = srsly.msgpack_loads(bytes_data)
self.punct_chars = cfg.get("punct_chars", self.default_punct_chars)
self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars))
return self
def to_disk(self, path, exclude=tuple(), **kwargs):
@ -1425,7 +1447,7 @@ class Sentencizer(object):
"""
path = util.ensure_path(path)
path = path.with_suffix(".json")
srsly.write_json(path, {"punct_chars": self.punct_chars})
srsly.write_json(path, {"punct_chars": list(self.punct_chars)})
def from_disk(self, path, exclude=tuple(), **kwargs):
@ -1436,7 +1458,7 @@ class Sentencizer(object):
path = util.ensure_path(path)
path = path.with_suffix(".json")
cfg = srsly.read_json(path)
self.punct_chars = cfg.get("punct_chars", self.default_punct_chars)
self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars))
return self
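
Review note on the `Sentencizer` serialization changes: once `punct_chars` is stored as a `set`, it has to be converted to a list before being written out and back to a set after loading. A minimal sketch of that round trip, using the standard-library `json` module as a stand-in for `srsly` (an assumption made to keep the example self-contained):

```python
import json

punct_chars = set([".", "!", "?"])

# A set cannot be written directly by the json module.
try:
    json.dumps({"punct_chars": punct_chars})
    direct_error = None
except TypeError as err:
    direct_error = err

# Convert to a list on the way out (sorted here only for stable output) ...
payload = json.dumps({"punct_chars": sorted(punct_chars)})
# ... and back to a set on the way in.
restored = set(json.loads(payload)["punct_chars"])
```

Storing a set makes membership checks cheap for the much longer default list, at the cost of this explicit conversion in every serialization method.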

@ -173,6 +173,21 @@ def test_span_as_doc(doc):
assert span_doc[0].idx == 0
def test_span_as_doc_user_data(doc):
"""Test that the user_data can be preserved (but not by default). """
my_key = "my_info"
my_value = 342
doc.user_data[my_key] = my_value
span = doc[4:10]
span_doc_with = span.as_doc(copy_user_data=True)
span_doc_without = span.as_doc()
assert doc.user_data.get(my_key, None) == my_value
assert span_doc_with.user_data.get(my_key, None) == my_value
assert span_doc_without.user_data.get(my_key, None) is None
def test_span_string_label_kb_id(doc):
span = Span(doc, 0, 1, label="hello", kb_id="Q342")
assert span.label_ == "hello"

@ -68,3 +68,20 @@ def test_add_label_deserializes_correctly():
assert ner1.moves.n_moves == ner2.moves.n_moves
for i in range(ner1.moves.n_moves):
assert ner1.moves.get_class_name(i) == ner2.moves.get_class_name(i)
@pytest.mark.parametrize(
"pipe_cls,n_moves", [(DependencyParser, 5), (EntityRecognizer, 4)]
)
def test_add_label_get_label(pipe_cls, n_moves):
"""Test that added labels are returned correctly. This test was added to
test for a bug in DependencyParser.labels that'd cause it to fail when
splitting the move names.
"""
labels = ["A", "B", "C"]
pipe = pipe_cls(Vocab())
for label in labels:
pipe.add_label(label)
assert len(pipe.move_names) == len(labels) * n_moves
pipe_labels = sorted(list(pipe.labels))
assert pipe_labels == labels

@ -128,3 +128,19 @@ def test_pipe_base_class_add_label(nlp, component):
assert label in pipe.labels
else:
assert pipe.labels == (label,)
def test_pipe_labels(nlp):
input_labels = {
"ner": ["PERSON", "ORG", "GPE"],
"textcat": ["POSITIVE", "NEGATIVE"],
}
for name, labels in input_labels.items():
pipe = nlp.create_pipe(name)
for label in labels:
pipe.add_label(label)
assert len(pipe.labels) == len(labels)
nlp.add_pipe(pipe)
assert len(nlp.pipe_labels) == len(input_labels)
for name, labels in nlp.pipe_labels.items():
assert sorted(input_labels[name]) == sorted(labels)

@ -81,7 +81,7 @@ def test_sentencizer_custom_punct(en_vocab, punct_chars, words, sent_starts, n_s
def test_sentencizer_serialize_bytes(en_vocab):
punct_chars = [".", "~", "+"]
sentencizer = Sentencizer(punct_chars=punct_chars)
assert sentencizer.punct_chars == punct_chars
assert sentencizer.punct_chars == set(punct_chars)
bytes_data = sentencizer.to_bytes()
new_sentencizer = Sentencizer().from_bytes(bytes_data)
assert new_sentencizer.punct_chars == punct_chars
assert new_sentencizer.punct_chars == set(punct_chars)

@ -0,0 +1,28 @@
# coding: utf8
from __future__ import unicode_literals
import pytest
from spacy.language import Language
from spacy.pipeline import Pipe
class DummyPipe(Pipe):
def __init__(self):
self.model = "dummy_model"
def predict(self, docs):
return ([1, 2, 3], [4, 5, 6])
def set_annotations(self, docs, scores, tensor=None):
return docs
@pytest.fixture
def nlp():
return Language()
def test_multiple_predictions(nlp):
doc = nlp.make_doc("foo")
dummy_pipe = DummyPipe()
dummy_pipe(doc)
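
Review note on this regression test: it covers the `len(tuple)` to `len(predictions)` fix in `Pipe.__call__`. The old check asked for the length of the built-in `tuple` type, which raises `TypeError` as soon as `predict` returns a tuple. A minimal sketch of the two checks in isolation; `unpack_buggy` and `unpack_fixed` are hypothetical helpers written for this note:

```python
def unpack_buggy(predictions):
    # Bug: len(tuple) is called on the type, not on the predictions.
    if isinstance(predictions, tuple) and len(tuple) == 2:
        return predictions
    return predictions, None


def unpack_fixed(predictions):
    # Fix: measure the actual return value.
    if isinstance(predictions, tuple) and len(predictions) == 2:
        return predictions
    return predictions, None


try:
    unpack_buggy(([1, 2, 3], [4, 5, 6]))
    buggy_error = None
except TypeError as err:
    buggy_error = err

scores, tensors = unpack_fixed(([1, 2, 3], [4, 5, 6]))
single_scores, no_tensors = unpack_fixed([1, 2, 3])
```

Because `and` short-circuits, the buggy version only fails when `predict` returns a tuple, which is why components returning a single value never hit it.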

@ -495,7 +495,7 @@ cdef class Tokenizer:
attrs = [intify_attrs(spec, _do_deprecated=True) for spec in substrings]
orth = "".join([spec[ORTH] for spec in attrs])
if chunk != orth:
raise ValueError(Errors.E161.format(chunk=chunk, orth=orth, token_attrs=substrings))
raise ValueError(Errors.E162.format(chunk=chunk, orth=orth, token_attrs=substrings))
def add_special_case(self, unicode string, substrings):
"""Add a special-case tokenization rule.

@ -200,13 +200,15 @@ cdef class Span:
return Underscore(Underscore.span_extensions, self,
start=self.start_char, end=self.end_char)
def as_doc(self):
def as_doc(self, bint copy_user_data=False):
"""Create a `Doc` object with a copy of the `Span`'s data.
copy_user_data (bool): Whether or not to copy the original doc's user data.
RETURNS (Doc): The `Doc` copy of the span.
DOCS: https://spacy.io/api/span#as_doc
"""
# TODO: make copy_user_data a keyword-only argument (Python 3 only)
words = [t.text for t in self]
spaces = [bool(t.whitespace_) for t in self]
cdef Doc doc = Doc(self.doc.vocab, words=words, spaces=spaces)
@ -235,6 +237,8 @@ cdef class Span:
cat_start, cat_end, cat_label = key
if cat_start == self.start_char and cat_end == self.end_char:
doc.cats[cat_label] = value
if copy_user_data:
doc.user_data = self.doc.user_data
return doc
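
Review note on `copy_user_data`: despite the name, the assignment `doc.user_data = self.doc.user_data` binds the same dict object to the new `Doc` rather than copying it, so later changes on either side are visible on both. A minimal sketch of the difference using plain dicts; the `copy.deepcopy` variant is an alternative shown for comparison, not what the diff implements:

```python
import copy

original_user_data = {"my_info": 342}

# What the hunk above does: both names refer to the same dict.
shared = original_user_data
shared["extra"] = "added on the span doc"

# An independent copy (an assumption for comparison, not in the diff).
independent = copy.deepcopy(original_user_data)
independent["only_here"] = True
```

If callers are expected to mutate the span doc's user data without affecting the original document, a real copy would be the safer choice.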
def _fix_dep_copy(self, attrs, array):

@ -292,9 +292,10 @@ Create a new `Doc` object corresponding to the `Span`, with a copy of the data.
> assert doc2.text == u"New York"
> ```
| Name | Type | Description |
| ----------- | ----- | --------------------------------------- |
| **RETURNS** | `Doc` | A `Doc` object of the `Span`'s content. |
| Name | Type | Description |
| ----------------- | ----- | ---------------------------------------------------- |
| `copy_user_data` | bool | Whether or not to copy the original doc's user data. |
| **RETURNS** | `Doc` | A `Doc` object of the `Span`'s content. |
## Span.root {#root tag="property" model="parser"}

@ -107,7 +107,7 @@ process.
<Infobox>
**Usage:** [Models directory](/models) [Benchmarks](#benchmarks)
**Usage:** [Models directory](/models)
</Infobox>

@ -10,10 +10,7 @@
"modelsRepo": "explosion/spacy-models",
"social": {
"twitter": "spacy_io",
"github": "explosion",
"reddit": "spacynlp",
"codepen": "explosion",
"gitter": "explosion/spaCy"
"github": "explosion"
},
"theme": "#09a3d5",
"analytics": "UA-58931649-1",
@ -69,6 +66,7 @@
"items": [
{ "text": "Twitter", "url": "https://twitter.com/spacy_io" },
{ "text": "GitHub", "url": "https://github.com/explosion/spaCy" },
{ "text": "YouTube", "url": "https://youtube.com/c/ExplosionAI" },
{ "text": "Blog", "url": "https://explosion.ai/blog" }
]
}