diff --git a/.github/contributors/tamuhey.md b/.github/contributors/tamuhey.md
new file mode 100644
index 000000000..6a63e3b53
--- /dev/null
+++ b/.github/contributors/tamuhey.md
@@ -0,0 +1,106 @@
+# spaCy contributor agreement
+
+This spaCy Contributor Agreement (**"SCA"**) is based on the
+[Oracle Contributor Agreement](http://www.oracle.com/technetwork/oca-405177.pdf).
+The SCA applies to any contribution that you make to any product or project
+managed by us (the **"project"**), and sets out the intellectual property rights
+you grant to us in the contributed materials. The term **"us"** shall mean
+[ExplosionAI GmbH](https://explosion.ai/legal). The term
+**"you"** shall mean the person or entity identified below.
+
+If you agree to be bound by these terms, fill in the information requested
+below and include the filled-in version with your first pull request, under the
+folder [`.github/contributors/`](/.github/contributors/). The name of the file
+should be your GitHub username, with the extension `.md`. For example, the user
+example_user would create the file `.github/contributors/example_user.md`.
+
+Read this agreement carefully before signing. These terms and conditions
+constitute a binding legal agreement.
+
+## Contributor Agreement
+
+1. The term "contribution" or "contributed materials" means any source code,
+object code, patch, tool, sample, graphic, specification, manual,
+documentation, or any other material posted or submitted by you to the project.
+
+2. With respect to any worldwide copyrights, or copyright applications and
+registrations, in your contribution:
+
+ * you hereby assign to us joint ownership, and to the extent that such
+ assignment is or becomes invalid, ineffective or unenforceable, you hereby
+ grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge,
+ royalty-free, unrestricted license to exercise all rights under those
+ copyrights. This includes, at our option, the right to sublicense these same
+ rights to third parties through multiple levels of sublicensees or other
+ licensing arrangements;
+
+ * you agree that each of us can do all things in relation to your
+ contribution as if each of us were the sole owners, and if one of us makes
+ a derivative work of your contribution, the one who makes the derivative
+ work (or has it made) will be the sole owner of that derivative work;
+
+ * you agree that you will not assert any moral rights in your contribution
+ against us, our licensees or transferees;
+
+ * you agree that we may register a copyright in your contribution and
+ exercise all ownership rights associated with it; and
+
+ * you agree that neither of us has any duty to consult with, obtain the
+ consent of, pay or render an accounting to the other for any use or
+ distribution of your contribution.
+
+3. With respect to any patents you own, or that you can license without payment
+to any third party, you hereby grant to us a perpetual, irrevocable,
+non-exclusive, worldwide, no-charge, royalty-free license to:
+
+ * make, have made, use, sell, offer to sell, import, and otherwise transfer
+ your contribution in whole or in part, alone or in combination with or
+ included in any product, work or materials arising out of the project to
+ which your contribution was submitted, and
+
+ * at our option, to sublicense these same rights to third parties through
+ multiple levels of sublicensees or other licensing arrangements.
+
+4. Except as set out above, you keep all right, title, and interest in your
+contribution. The rights that you grant to us under these terms are effective
+on the date you first submitted a contribution to us, even if your submission
+took place before the date you sign these terms.
+
+5. You covenant, represent, warrant and agree that:
+
+ * Each contribution that you submit is and shall be an original work of
+ authorship and you can legally grant the rights set out in this SCA;
+
+ * to the best of your knowledge, each contribution will not violate any
+ third party's copyrights, trademarks, patents, or other intellectual
+ property rights; and
+
+ * each contribution shall be in compliance with U.S. export control laws and
+ other applicable export and import laws. You agree to notify us if you
+ become aware of any circumstance which would make any of the foregoing
+ representations inaccurate in any respect. We may publicly disclose your
+ participation in the project, including the fact that you have signed the SCA.
+
+6. This SCA is governed by the laws of the State of California and applicable
+U.S. Federal law. Any choice of law rules will not apply.
+
+7. Please place an “x” on one of the applicable statements below. Please do NOT
+mark both statements:
+
+ * [x] I am signing on behalf of myself as an individual and no other person
+ or entity, including my employer, has or will have rights with respect to my
+ contributions.
+
+ * [ ] I am signing on behalf of my employer or a legal entity and I have the
+ actual authority to contractually bind that entity.
+
+## Contributor Details
+
+| Field | Entry |
+|------------------------------- | -------------------- |
+| Name | Yohei Tamura |
+| Company name (if applicable) | PKSHA |
+| Title or role (if applicable) | |
+| Date | 2019/9/12 |
+| GitHub username | tamuhey |
+| Website (optional) | |
diff --git a/bin/ud/run_eval.py b/bin/ud/run_eval.py
index 171687980..2da476721 100644
--- a/bin/ud/run_eval.py
+++ b/bin/ud/run_eval.py
@@ -7,14 +7,16 @@ import datetime
from pathlib import Path
import xml.etree.ElementTree as ET
-from spacy.cli.ud import conll17_ud_eval
-from spacy.cli.ud.ud_train import write_conllu
+import conll17_ud_eval
+from ud_train import write_conllu
from spacy.lang.lex_attrs import word_shape
from spacy.util import get_lang_class
# All languages in spaCy - in UD format (note that Norwegian is 'no' instead of 'nb')
-ALL_LANGUAGES = "ar, ca, da, de, el, en, es, fa, fi, fr, ga, he, hi, hr, hu, id, " \
- "it, ja, no, nl, pl, pt, ro, ru, sv, tr, ur, vi, zh"
+ALL_LANGUAGES = ("af, ar, bg, bn, ca, cs, da, de, el, en, es, et, fa, fi, fr,"
+ "ga, he, hi, hr, hu, id, is, it, ja, kn, ko, lt, lv, mr, no,"
+ "nl, pl, pt, ro, ru, si, sk, sl, sq, sr, sv, ta, te, th, tl,"
+ "tr, tt, uk, ur, vi, zh")
# Non-parsing tasks that will be evaluated (works for default models)
EVAL_NO_PARSE = ['Tokens', 'Words', 'Lemmas', 'Sentences', 'Feats']
@@ -73,10 +75,10 @@ def _contains_blinded_text(stats_xml):
tree = ET.parse(stats_xml)
root = tree.getroot()
total_tokens = int(root.find('size/total/tokens').text)
- unique_lemmas = int(root.find('lemmas').get('unique'))
+ unique_forms = int(root.find('forms').get('unique'))
# assume the corpus is largely blinded when there are less than 1% unique tokens
- return (unique_lemmas / total_tokens) < 0.01
+ return (unique_forms / total_tokens) < 0.01
def fetch_all_treebanks(ud_dir, languages, corpus, best_per_language):
@@ -262,22 +264,26 @@ def main(out_path, ud_dir, check_parse=False, langs=ALL_LANGUAGES, exclude_train
if not exclude_trained_models:
if 'de' in models:
models['de'].append(load_model('de_core_news_sm'))
- if 'es' in models:
- models['es'].append(load_model('es_core_news_sm'))
- models['es'].append(load_model('es_core_news_md'))
- if 'pt' in models:
- models['pt'].append(load_model('pt_core_news_sm'))
- if 'it' in models:
- models['it'].append(load_model('it_core_news_sm'))
- if 'nl' in models:
- models['nl'].append(load_model('nl_core_news_sm'))
+ models['de'].append(load_model('de_core_news_md'))
+ if 'el' in models:
+ models['el'].append(load_model('el_core_news_sm'))
+ models['el'].append(load_model('el_core_news_md'))
if 'en' in models:
models['en'].append(load_model('en_core_web_sm'))
models['en'].append(load_model('en_core_web_md'))
models['en'].append(load_model('en_core_web_lg'))
+ if 'es' in models:
+ models['es'].append(load_model('es_core_news_sm'))
+ models['es'].append(load_model('es_core_news_md'))
if 'fr' in models:
models['fr'].append(load_model('fr_core_news_sm'))
models['fr'].append(load_model('fr_core_news_md'))
+ if 'it' in models:
+ models['it'].append(load_model('it_core_news_sm'))
+ if 'nl' in models:
+ models['nl'].append(load_model('nl_core_news_sm'))
+ if 'pt' in models:
+ models['pt'].append(load_model('pt_core_news_sm'))
with out_path.open(mode='w', encoding='utf-8') as out_file:
run_all_evals(models, treebanks, out_file, check_parse, print_freq_tasks)
diff --git a/bin/ud/ud_run_test.py b/bin/ud/ud_run_test.py
index 1c529c831..de01cf350 100644
--- a/bin/ud/ud_run_test.py
+++ b/bin/ud/ud_run_test.py
@@ -109,15 +109,13 @@ def write_conllu(docs, file_):
merger = Matcher(docs[0].vocab)
merger.add("SUBTOK", None, [{"DEP": "subtok", "op": "+"}])
for i, doc in enumerate(docs):
- matches = merger(doc)
+ matches = []
+ if doc.is_parsed:
+ matches = merger(doc)
spans = [doc[start : end + 1] for _, start, end in matches]
with doc.retokenize() as retokenizer:
for span in spans:
retokenizer.merge(span)
- # TODO: This shouldn't be necessary? Should be handled in merge
- for word in doc:
- if word.i == word.head.i:
- word.dep_ = "ROOT"
file_.write("# newdoc id = {i}\n".format(i=i))
for j, sent in enumerate(doc.sents):
file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j))
diff --git a/bin/ud/ud_train.py b/bin/ud/ud_train.py
index 8f699db4f..e7fcdf871 100644
--- a/bin/ud/ud_train.py
+++ b/bin/ud/ud_train.py
@@ -25,7 +25,7 @@ import itertools
import random
import numpy.random
-from . import conll17_ud_eval
+import conll17_ud_eval
from spacy import lang
from spacy.lang import zh
@@ -214,7 +214,9 @@ def write_conllu(docs, file_):
merger = Matcher(docs[0].vocab)
merger.add("SUBTOK", None, [{"DEP": "subtok", "op": "+"}])
for i, doc in enumerate(docs):
- matches = merger(doc)
+ matches = []
+ if doc.is_parsed:
+ matches = merger(doc)
spans = [doc[start : end + 1] for _, start, end in matches]
with doc.retokenize() as retokenizer:
for span in spans:
@@ -298,9 +300,9 @@ def get_token_conllu(token, i):
return "\n".join(lines)
-Token.set_extension("get_conllu_lines", method=get_token_conllu)
-Token.set_extension("begins_fused", default=False)
-Token.set_extension("inside_fused", default=False)
+Token.set_extension("get_conllu_lines", method=get_token_conllu, force=True)
+Token.set_extension("begins_fused", default=False, force=True)
+Token.set_extension("inside_fused", default=False, force=True)
##################
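
# Illustrative sketch (not part of the patch) of the two guards introduced above in
# ud_run_test.py and ud_train.py: spaCy raises an error when a custom extension is
# registered a second time without force=True, and the "subtok" dependency pattern is
# only meaningful on parsed docs, so unparsed docs should simply yield no matches.
import spacy
from spacy.matcher import Matcher
from spacy.tokens import Token

Token.set_extension("begins_fused", default=False)
Token.set_extension("begins_fused", default=False, force=True)  # re-registering is now safe

nlp = spacy.blank("en")  # no parser in the pipeline, so doc.is_parsed stays False
doc = nlp("A short example sentence.")
merger = Matcher(nlp.vocab)
merger.add("SUBTOK", None, [{"DEP": "subtok", "op": "+"}])
matches = merger(doc) if doc.is_parsed else []  # skip dependency matching on unparsed docs
print(matches)
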
diff --git a/bin/wiki_entity_linking/__init__.py b/bin/wiki_entity_linking/__init__.py
index e69de29bb..a604bcc2f 100644
--- a/bin/wiki_entity_linking/__init__.py
+++ b/bin/wiki_entity_linking/__init__.py
@@ -0,0 +1,11 @@
+TRAINING_DATA_FILE = "gold_entities.jsonl"
+KB_FILE = "kb"
+KB_MODEL_DIR = "nlp_kb"
+OUTPUT_MODEL_DIR = "nlp"
+
+PRIOR_PROB_PATH = "prior_prob.csv"
+ENTITY_DEFS_PATH = "entity_defs.csv"
+ENTITY_FREQ_PATH = "entity_freq.csv"
+ENTITY_DESCR_PATH = "entity_descriptions.csv"
+
+LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
diff --git a/bin/wiki_entity_linking/entity_linker_evaluation.py b/bin/wiki_entity_linking/entity_linker_evaluation.py
new file mode 100644
index 000000000..1b1200564
--- /dev/null
+++ b/bin/wiki_entity_linking/entity_linker_evaluation.py
@@ -0,0 +1,200 @@
+import logging
+import random
+
+from collections import defaultdict
+
+logger = logging.getLogger(__name__)
+
+
+class Metrics(object):
+ true_pos = 0
+ false_pos = 0
+ false_neg = 0
+
+ def update_results(self, true_entity, candidate):
+ candidate_is_correct = true_entity == candidate
+
+ # Assume that we have no labeled negatives in the data (i.e. cases where true_entity is "NIL")
+ # Therefore, if candidate_is_correct then we have a true positive and never a true negative
+ self.true_pos += candidate_is_correct
+ self.false_neg += not candidate_is_correct
+ if candidate not in {"", "NIL"}:
+ self.false_pos += not candidate_is_correct
+
+ def calculate_precision(self):
+ if self.true_pos == 0:
+ return 0.0
+ else:
+ return self.true_pos / (self.true_pos + self.false_pos)
+
+ def calculate_recall(self):
+ if self.true_pos == 0:
+ return 0.0
+ else:
+ return self.true_pos / (self.true_pos + self.false_neg)
+
+
+class EvaluationResults(object):
+ def __init__(self):
+ self.metrics = Metrics()
+ self.metrics_by_label = defaultdict(Metrics)
+
+ def update_metrics(self, ent_label, true_entity, candidate):
+ self.metrics.update_results(true_entity, candidate)
+ self.metrics_by_label[ent_label].update_results(true_entity, candidate)
+
+ def increment_false_negatives(self):
+ self.metrics.false_neg += 1
+
+ def report_metrics(self, model_name):
+ model_str = model_name.title()
+ recall = self.metrics.calculate_recall()
+ precision = self.metrics.calculate_precision()
+ return ("{}: ".format(model_str) +
+ "Recall = {} | ".format(round(recall, 3)) +
+ "Precision = {} | ".format(round(precision, 3)) +
+ "Precision by label = {}".format({k: v.calculate_precision()
+ for k, v in self.metrics_by_label.items()}))
+
+
+class BaselineResults(object):
+ def __init__(self):
+ self.random = EvaluationResults()
+ self.prior = EvaluationResults()
+ self.oracle = EvaluationResults()
+
+ def report_accuracy(self, model):
+ results = getattr(self, model)
+ return results.report_metrics(model)
+
+ def update_baselines(self, true_entity, ent_label, random_candidate, prior_candidate, oracle_candidate):
+ self.oracle.update_metrics(ent_label, true_entity, oracle_candidate)
+ self.prior.update_metrics(ent_label, true_entity, prior_candidate)
+ self.random.update_metrics(ent_label, true_entity, random_candidate)
+
+
+def measure_performance(dev_data, kb, el_pipe):
+    baseline_accuracies = measure_baselines(dev_data, kb)
+
+ logger.info(baseline_accuracies.report_accuracy("random"))
+ logger.info(baseline_accuracies.report_accuracy("prior"))
+ logger.info(baseline_accuracies.report_accuracy("oracle"))
+
+ # using only context
+ el_pipe.cfg["incl_context"] = True
+ el_pipe.cfg["incl_prior"] = False
+ results = get_eval_results(dev_data, el_pipe)
+ logger.info(results.report_metrics("context only"))
+
+ # measuring combined accuracy (prior + context)
+ el_pipe.cfg["incl_context"] = True
+ el_pipe.cfg["incl_prior"] = True
+ results = get_eval_results(dev_data, el_pipe)
+ logger.info(results.report_metrics("context and prior"))
+
+
+def get_eval_results(data, el_pipe=None):
+ # If the docs in the data require further processing with an entity linker, set el_pipe
+ from tqdm import tqdm
+
+ docs = []
+ golds = []
+ for d, g in tqdm(data, leave=False):
+ if len(d) > 0:
+ golds.append(g)
+ if el_pipe is not None:
+ docs.append(el_pipe(d))
+ else:
+ docs.append(d)
+
+ results = EvaluationResults()
+ for doc, gold in zip(docs, golds):
+ tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
+ try:
+ correct_entries_per_article = dict()
+ for entity, kb_dict in gold.links.items():
+ start, end = entity
+ # only evaluating on positive examples
+ for gold_kb, value in kb_dict.items():
+ if value:
+ offset = _offset(start, end)
+ correct_entries_per_article[offset] = gold_kb
+ if offset not in tagged_entries_per_article:
+ results.increment_false_negatives()
+
+ for ent in doc.ents:
+ ent_label = ent.label_
+ pred_entity = ent.kb_id_
+ start = ent.start_char
+ end = ent.end_char
+ offset = _offset(start, end)
+ gold_entity = correct_entries_per_article.get(offset, None)
+ # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
+ if gold_entity is not None:
+ results.update_metrics(ent_label, gold_entity, pred_entity)
+
+ except Exception as e:
+ logging.error("Error assessing accuracy " + str(e))
+
+ return results
+
+
+def measure_baselines(data, kb):
+ # Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
+ counts_d = dict()
+
+ baseline_results = BaselineResults()
+
+ docs = [d for d, g in data if len(d) > 0]
+ golds = [g for d, g in data if len(d) > 0]
+
+ for doc, gold in zip(docs, golds):
+ correct_entries_per_article = dict()
+ tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
+ for entity, kb_dict in gold.links.items():
+ start, end = entity
+ for gold_kb, value in kb_dict.items():
+ # only evaluating on positive examples
+ if value:
+ offset = _offset(start, end)
+ correct_entries_per_article[offset] = gold_kb
+ if offset not in tagged_entries_per_article:
+ baseline_results.random.increment_false_negatives()
+ baseline_results.oracle.increment_false_negatives()
+ baseline_results.prior.increment_false_negatives()
+
+ for ent in doc.ents:
+ ent_label = ent.label_
+ start = ent.start_char
+ end = ent.end_char
+ offset = _offset(start, end)
+ gold_entity = correct_entries_per_article.get(offset, None)
+
+ # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
+ if gold_entity is not None:
+ candidates = kb.get_candidates(ent.text)
+ oracle_candidate = ""
+ best_candidate = ""
+ random_candidate = ""
+ if candidates:
+ scores = []
+
+ for c in candidates:
+ scores.append(c.prior_prob)
+ if c.entity_ == gold_entity:
+ oracle_candidate = c.entity_
+
+ best_index = scores.index(max(scores))
+ best_candidate = candidates[best_index].entity_
+ random_candidate = random.choice(candidates).entity_
+
+ baseline_results.update_baselines(gold_entity, ent_label,
+ random_candidate, best_candidate, oracle_candidate)
+
+ return baseline_results
+
+
+def _offset(start, end):
+ return "{}_{}".format(start, end)
diff --git a/bin/wiki_entity_linking/kb_creator.py b/bin/wiki_entity_linking/kb_creator.py
index bd862f536..54ed7815e 100644
--- a/bin/wiki_entity_linking/kb_creator.py
+++ b/bin/wiki_entity_linking/kb_creator.py
@@ -1,12 +1,20 @@
# coding: utf-8
from __future__ import unicode_literals
-from bin.wiki_entity_linking.train_descriptions import EntityEncoder
-from bin.wiki_entity_linking import wikidata_processor as wd, wikipedia_processor as wp
+import csv
+import logging
+import spacy
+import sys
+
from spacy.kb import KnowledgeBase
-import csv
-import datetime
+from bin.wiki_entity_linking import wikipedia_processor as wp
+from bin.wiki_entity_linking.train_descriptions import EntityEncoder
+
+csv.field_size_limit(sys.maxsize)
+
+
+logger = logging.getLogger(__name__)
def create_kb(
@@ -14,52 +22,73 @@ def create_kb(
max_entities_per_alias,
min_entity_freq,
min_occ,
- entity_def_output,
- entity_descr_output,
+ entity_def_input,
+ entity_descr_path,
count_input,
prior_prob_input,
- wikidata_input,
entity_vector_length,
- limit=None,
- read_raw_data=True,
):
# Create the knowledge base from Wikidata entries
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=entity_vector_length)
+ # read the mappings from file
+ title_to_id = get_entity_to_id(entity_def_input)
+ id_to_descr = get_id_to_description(entity_descr_path)
+
# check the length of the nlp vectors
if "vectors" in nlp.meta and nlp.vocab.vectors.size:
input_dim = nlp.vocab.vectors_length
- print("Loaded pre-trained vectors of size %s" % input_dim)
+ logger.info("Loaded pre-trained vectors of size %s" % input_dim)
else:
raise ValueError(
"The `nlp` object should have access to pre-trained word vectors, "
" cf. https://spacy.io/usage/models#languages."
)
- # disable this part of the pipeline when rerunning the KB generation from preprocessed files
- if read_raw_data:
- print()
- print(now(), " * read wikidata entities:")
- title_to_id, id_to_descr = wd.read_wikidata_entities_json(
- wikidata_input, limit=limit
- )
-
- # write the title-ID and ID-description mappings to file
- _write_entity_files(
- entity_def_output, entity_descr_output, title_to_id, id_to_descr
- )
-
- else:
- # read the mappings from file
- title_to_id = get_entity_to_id(entity_def_output)
- id_to_descr = get_id_to_description(entity_descr_output)
-
- print()
- print(now(), " * get entity frequencies:")
- print()
+ logger.info("Get entity frequencies")
entity_frequencies = wp.get_all_frequencies(count_input=count_input)
+ logger.info("Filtering entities with fewer than {} mentions".format(min_entity_freq))
    # filter the entities in the KB by frequency, because there's just too much data (8M entities) otherwise
+ filtered_title_to_id, entity_list, description_list, frequency_list = get_filtered_entities(
+ title_to_id,
+ id_to_descr,
+ entity_frequencies,
+ min_entity_freq
+ )
+ logger.info("Left with {} entities".format(len(description_list)))
+
+ logger.info("Train entity encoder")
+ encoder = EntityEncoder(nlp, input_dim, entity_vector_length)
+ encoder.train(description_list=description_list, to_print=True)
+
+ logger.info("Get entity embeddings:")
+ embeddings = encoder.apply_encoder(description_list)
+
+ logger.info("Adding {} entities".format(len(entity_list)))
+ kb.set_entities(
+ entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
+ )
+
+ logger.info("Adding aliases")
+ _add_aliases(
+ kb,
+ title_to_id=filtered_title_to_id,
+ max_entities_per_alias=max_entities_per_alias,
+ min_occ=min_occ,
+ prior_prob_input=prior_prob_input,
+ )
+
+ logger.info("KB size: {} entities, {} aliases".format(
+ kb.get_size_entities(),
+ kb.get_size_aliases()))
+
+ logger.info("Done with kb")
+ return kb
+
+
+def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies,
+ min_entity_freq: int = 10):
filtered_title_to_id = dict()
entity_list = []
description_list = []
@@ -72,58 +101,7 @@ def create_kb(
description_list.append(desc)
frequency_list.append(freq)
filtered_title_to_id[title] = entity
-
- print(len(title_to_id.keys()), "original titles")
- kept_nr = len(filtered_title_to_id.keys())
- print("kept", kept_nr, "entities with min. frequency", min_entity_freq)
-
- print()
- print(now(), " * train entity encoder:")
- print()
- encoder = EntityEncoder(nlp, input_dim, entity_vector_length)
- encoder.train(description_list=description_list, to_print=True)
-
- print()
- print(now(), " * get entity embeddings:")
- print()
- embeddings = encoder.apply_encoder(description_list)
-
- print(now(), " * adding", len(entity_list), "entities")
- kb.set_entities(
- entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
- )
-
- alias_cnt = _add_aliases(
- kb,
- title_to_id=filtered_title_to_id,
- max_entities_per_alias=max_entities_per_alias,
- min_occ=min_occ,
- prior_prob_input=prior_prob_input,
- )
- print()
- print(now(), " * adding", alias_cnt, "aliases")
- print()
-
- print()
- print("# of entities in kb:", kb.get_size_entities())
- print("# of aliases in kb:", kb.get_size_aliases())
-
- print(now(), "Done with kb")
- return kb
-
-
-def _write_entity_files(
- entity_def_output, entity_descr_output, title_to_id, id_to_descr
-):
- with entity_def_output.open("w", encoding="utf8") as id_file:
- id_file.write("WP_title" + "|" + "WD_id" + "\n")
- for title, qid in title_to_id.items():
- id_file.write(title + "|" + str(qid) + "\n")
-
- with entity_descr_output.open("w", encoding="utf8") as descr_file:
- descr_file.write("WD_id" + "|" + "description" + "\n")
- for qid, descr in id_to_descr.items():
- descr_file.write(str(qid) + "|" + descr + "\n")
+ return filtered_title_to_id, entity_list, description_list, frequency_list
def get_entity_to_id(entity_def_output):
@@ -137,9 +115,9 @@ def get_entity_to_id(entity_def_output):
return entity_to_id
-def get_id_to_description(entity_descr_output):
+def get_id_to_description(entity_descr_path):
id_to_desc = dict()
- with entity_descr_output.open("r", encoding="utf8") as csvfile:
+ with entity_descr_path.open("r", encoding="utf8") as csvfile:
csvreader = csv.reader(csvfile, delimiter="|")
# skip header
next(csvreader)
@@ -150,7 +128,6 @@ def get_id_to_description(entity_descr_output):
def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input):
wp_titles = title_to_id.keys()
- cnt = 0
# adding aliases with prior probabilities
# we can read this file sequentially, it's sorted by alias, and then by count
@@ -187,9 +164,8 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
entities=selected_entities,
probabilities=prior_probs,
)
- cnt += 1
except ValueError as e:
- print(e)
+ logger.error(e)
total_count = 0
counts = []
entities = []
@@ -202,8 +178,12 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
previous_alias = new_alias
line = prior_file.readline()
- return cnt
-def now():
- return datetime.datetime.now()
+def read_nlp_kb(model_dir, kb_file):
+ nlp = spacy.load(model_dir)
+ kb = KnowledgeBase(vocab=nlp.vocab)
+ kb.load_bulk(kb_file)
+ logger.info("kb entities: {}".format(kb.get_size_entities()))
+ logger.info("kb aliases: {}".format(kb.get_size_aliases()))
+ return nlp, kb
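
# Hypothetical usage sketch of read_nlp_kb above; the directory layout is the one produced
# by wikidata_pretrain_kb.py (KB_MODEL_DIR and KB_FILE from __init__.py), and the output
# path and alias string are placeholders.
from pathlib import Path

from bin.wiki_entity_linking import KB_FILE, KB_MODEL_DIR
from bin.wiki_entity_linking.kb_creator import read_nlp_kb

output_dir = Path("/path/to/kb_output")  # placeholder
nlp, kb = read_nlp_kb(output_dir / KB_MODEL_DIR, output_dir / KB_FILE)
for candidate in kb.get_candidates("Douglas Adams"):
    print(candidate.entity_, candidate.prior_prob)
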
diff --git a/bin/wiki_entity_linking/train_descriptions.py b/bin/wiki_entity_linking/train_descriptions.py
index 0663296e4..2cb66909f 100644
--- a/bin/wiki_entity_linking/train_descriptions.py
+++ b/bin/wiki_entity_linking/train_descriptions.py
@@ -1,6 +1,7 @@
# coding: utf-8
from random import shuffle
+import logging
import numpy as np
from spacy._ml import zero_init, create_default_optimizer
@@ -10,6 +11,8 @@ from thinc.v2v import Model
from thinc.api import chain
from thinc.neural._classes.affine import Affine
+logger = logging.getLogger(__name__)
+
class EntityEncoder:
"""
@@ -50,21 +53,19 @@ class EntityEncoder:
start = start + batch_size
stop = min(stop + batch_size, len(description_list))
- print("encoded:", stop, "entities")
+ logger.info("encoded: {} entities".format(stop))
return encodings
def train(self, description_list, to_print=False):
processed, loss = self._train_model(description_list)
if to_print:
- print(
- "Trained entity descriptions on",
- processed,
- "(non-unique) entities across",
- self.epochs,
- "epochs",
+ logger.info(
+ "Trained entity descriptions on {} ".format(processed) +
+ "(non-unique) entities across {} ".format(self.epochs) +
+ "epochs"
)
- print("Final loss:", loss)
+ logger.info("Final loss: {}".format(loss))
def _train_model(self, description_list):
best_loss = 1.0
@@ -93,7 +94,7 @@ class EntityEncoder:
loss = self._update(batch)
if batch_nr % 25 == 0:
- print("loss:", loss)
+ logger.info("loss: {} ".format(loss))
processed += len(batch)
# in general, continue training if we haven't reached our ideal min yet
diff --git a/bin/wiki_entity_linking/training_set_creator.py b/bin/wiki_entity_linking/training_set_creator.py
index 7f45d9435..3f42f8bdd 100644
--- a/bin/wiki_entity_linking/training_set_creator.py
+++ b/bin/wiki_entity_linking/training_set_creator.py
@@ -1,10 +1,13 @@
# coding: utf-8
from __future__ import unicode_literals
+import logging
import random
import re
import bz2
-import datetime
+import json
+
+from functools import partial
from spacy.gold import GoldParse
from bin.wiki_entity_linking import kb_creator
@@ -15,18 +18,30 @@ Gold-standard entities are stored in one file in standoff format (by character o
"""
ENTITY_FILE = "gold_entities.csv"
+logger = logging.getLogger(__name__)
-def now():
- return datetime.datetime.now()
-
-
-def create_training(wikipedia_input, entity_def_input, training_output, limit=None):
+def create_training_examples_and_descriptions(wikipedia_input,
+ entity_def_input,
+ description_output,
+ training_output,
+ parse_descriptions,
+ limit=None):
wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
- _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=limit)
+ _process_wikipedia_texts(wikipedia_input,
+ wp_to_id,
+ description_output,
+ training_output,
+ parse_descriptions,
+ limit)
-def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None):
+def _process_wikipedia_texts(wikipedia_input,
+ wp_to_id,
+ output,
+ training_output,
+ parse_descriptions,
+ limit=None):
"""
Read the XML wikipedia data to parse out training data:
raw text data + positive instances
@@ -35,29 +50,21 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
    id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
read_ids = set()
- entityfile_loc = training_output / ENTITY_FILE
- with entityfile_loc.open("w", encoding="utf8") as entityfile:
- # write entity training header file
- _write_training_entity(
- outputfile=entityfile,
- article_id="article_id",
- alias="alias",
- entity="WD_id",
- start="start",
- end="end",
- )
+ with output.open("a", encoding="utf8") as descr_file, training_output.open("w", encoding="utf8") as entity_file:
+ if parse_descriptions:
+ _write_training_description(descr_file, "WD_id", "description")
with bz2.open(wikipedia_input, mode="rb") as file:
- line = file.readline()
- cnt = 0
+ article_count = 0
article_text = ""
article_title = None
article_id = None
reading_text = False
reading_revision = False
- while line and (not limit or cnt < limit):
- if cnt % 1000000 == 0:
- print(now(), "processed", cnt, "lines of Wikipedia dump")
+
+ logger.info("Processed {} articles".format(article_count))
+
+ for line in file:
clean_line = line.strip().decode("utf-8")
              if clean_line == "<revision>":
@@ -70,28 +77,32 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
article_text = ""
article_title = None
article_id = None
-
# finished reading this page
         elif clean_line == "</page>":
if article_id:
- try:
- _process_wp_text(
- wp_to_id,
- entityfile,
- article_id,
- article_title,
- article_text.strip(),
- training_output,
- )
- except Exception as e:
- print(
- "Error processing article", article_id, article_title, e
- )
- else:
- print(
- "Done processing a page, but couldn't find an article_id ?",
+ clean_text, entities = _process_wp_text(
article_title,
+ article_text,
+ wp_to_id
)
+ if clean_text is not None and entities is not None:
+ _write_training_entities(entity_file,
+ article_id,
+ clean_text,
+ entities)
+
+ if article_title in wp_to_id and parse_descriptions:
+ description = " ".join(clean_text[:1000].split(" ")[:-1])
+ _write_training_description(
+ descr_file,
+ wp_to_id[article_title],
+ description
+ )
+ article_count += 1
+ if article_count % 10000 == 0:
+ logger.info("Processed {} articles".format(article_count))
+ if limit and article_count >= limit:
+ break
article_text = ""
article_title = None
article_id = None
@@ -115,7 +126,7 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
if ids:
article_id = ids[0]
if article_id in read_ids:
- print(
+ logger.info(
"Found duplicate article ID", article_id, clean_line
) # This should never happen ...
read_ids.add(article_id)
@@ -125,115 +136,10 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
titles = title_regex.search(clean_line)
if titles:
article_title = titles[0].strip()
-
- line = file.readline()
- cnt += 1
- print(now(), "processed", cnt, "lines of Wikipedia dump")
+ logger.info("Finished. Processed {} articles".format(article_count))
text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")
-        if open_read > 2:
- reading_special_case = True
-
- if open_read == 2 and reading_text:
- reading_text = False
- reading_entity = True
- reading_mention = False
-
- # we just finished reading an entity
- if open_read == 0 and not reading_text:
- if "#" in entity_buffer or entity_buffer.startswith(":"):
- reading_special_case = True
- # Ignore cases with nested structures like File: handles etc
- if not reading_special_case:
- if not mention_buffer:
- mention_buffer = entity_buffer
- start = len(final_text)
- end = start + len(mention_buffer)
- qid = wp_to_id.get(entity_buffer, None)
- if qid:
- _write_training_entity(
- outputfile=entityfile,
- article_id=article_id,
- alias=mention_buffer,
- entity=qid,
- start=start,
- end=end,
- )
- found_entities = True
- final_text += mention_buffer
-
- entity_buffer = ""
- mention_buffer = ""
-
- reading_text = True
- reading_entity = False
- reading_mention = False
- reading_special_case = False
-
- if found_entities:
- _write_training_article(
- article_id=article_id,
- clean_text=final_text,
- training_output=training_output,
- )
-
-
info_regex = re.compile(r"{[^{]*?}")
htlm_regex = re.compile(r"<!--[^-]*-->")
category_regex = re.compile(r"\[\[Category:[^\[]*]]")
@@ -242,6 +148,29 @@ ref_regex = re.compile(r"<ref.*?>") # non-greedy
ref_2_regex = re.compile(r"</ref.*?>") # non-greedy
+def _process_wp_text(article_title, article_text, wp_to_id):
+ # ignore meta Wikipedia pages
+ if (
+ article_title.startswith("Wikipedia:") or
+ article_title.startswith("Kategori:")
+ ):
+ return None, None
+
+ # remove the text tags
+ text_search = text_regex.search(article_text)
+ if text_search is None:
+ return None, None
+ text = text_search.group(0)
+
+ # stop processing if this is a redirect page
+ if text.startswith("#REDIRECT"):
+ return None, None
+
+ # get the raw text without markup etc, keeping only interwiki links
+ clean_text, entities = _remove_links(_get_clean_wp_text(text), wp_to_id)
+ return clean_text, entities
+
+
def _get_clean_wp_text(article_text):
clean_text = article_text.strip()
@@ -300,130 +229,167 @@ def _get_clean_wp_text(article_text):
return clean_text.strip()
-def _write_training_article(article_id, clean_text, training_output):
- file_loc = training_output / "{}.txt".format(article_id)
- with file_loc.open("w", encoding="utf8") as outputfile:
- outputfile.write(clean_text)
+def _remove_links(clean_text, wp_to_id):
+ # read the text char by char to get the right offsets for the interwiki links
+ entities = []
+ final_text = ""
+ open_read = 0
+ reading_text = True
+ reading_entity = False
+ reading_mention = False
+ reading_special_case = False
+ entity_buffer = ""
+ mention_buffer = ""
+ for index, letter in enumerate(clean_text):
+ if letter == "[":
+ open_read += 1
+ elif letter == "]":
+ open_read -= 1
+ elif letter == "|":
+ if reading_text:
+ final_text += letter
+ # switch from reading entity to mention in the [[entity|mention]] pattern
+ elif reading_entity:
+ reading_text = False
+ reading_entity = False
+ reading_mention = True
+ else:
+ reading_special_case = True
+ else:
+ if reading_entity:
+ entity_buffer += letter
+ elif reading_mention:
+ mention_buffer += letter
+ elif reading_text:
+ final_text += letter
+ else:
+ raise ValueError("Not sure at point", clean_text[index - 2: index + 2])
+
+ if open_read > 2:
+ reading_special_case = True
+
+ if open_read == 2 and reading_text:
+ reading_text = False
+ reading_entity = True
+ reading_mention = False
+
+ # we just finished reading an entity
+ if open_read == 0 and not reading_text:
+ if "#" in entity_buffer or entity_buffer.startswith(":"):
+ reading_special_case = True
+ # Ignore cases with nested structures like File: handles etc
+ if not reading_special_case:
+ if not mention_buffer:
+ mention_buffer = entity_buffer
+ start = len(final_text)
+ end = start + len(mention_buffer)
+ qid = wp_to_id.get(entity_buffer, None)
+ if qid:
+ entities.append((mention_buffer, qid, start, end))
+ final_text += mention_buffer
+
+ entity_buffer = ""
+ mention_buffer = ""
+
+ reading_text = True
+ reading_entity = False
+ reading_mention = False
+ reading_special_case = False
+ return final_text, entities
-def _write_training_entity(outputfile, article_id, alias, entity, start, end):
- line = "{}|{}|{}|{}|{}\n".format(article_id, alias, entity, start, end)
+def _write_training_description(outputfile, qid, description):
+ if description is not None:
+ line = str(qid) + "|" + description + "\n"
+ outputfile.write(line)
+
+
+def _write_training_entities(outputfile, article_id, clean_text, entities):
+ entities_data = [{"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]} for ent in entities]
+ line = json.dumps(
+ {
+ "article_id": article_id,
+ "clean_text": clean_text,
+ "entities": entities_data
+ },
+ ensure_ascii=False) + "\n"
outputfile.write(line)
+def read_training(nlp, entity_file_path, dev, limit, kb):
+ """ This method provides training examples that correspond to the entity annotations found by the nlp object.
+    For training, it will include negative training examples by using the candidate generator,
+ and it will only keep positive training examples that can be found by using the candidate generator.
+ For testing, it will include all positive examples only."""
+
+ from tqdm import tqdm
+ data = []
+ num_entities = 0
+ get_gold_parse = partial(_get_gold_parse, dev=dev, kb=kb)
+
+ logger.info("Reading {} data with limit {}".format('dev' if dev else 'train', limit))
+ with entity_file_path.open("r", encoding="utf8") as file:
+ with tqdm(total=limit, leave=False) as pbar:
+ for i, line in enumerate(file):
+ example = json.loads(line)
+ article_id = example["article_id"]
+ clean_text = example["clean_text"]
+ entities = example["entities"]
+
+ if dev != is_dev(article_id) or len(clean_text) >= 30000:
+ continue
+
+ doc = nlp(clean_text)
+ gold = get_gold_parse(doc, entities)
+ if gold and len(gold.links) > 0:
+ data.append((doc, gold))
+ num_entities += len(gold.links)
+ pbar.update(len(gold.links))
+ if limit and num_entities >= limit:
+ break
+ logger.info("Read {} entities in {} articles".format(num_entities, len(data)))
+ return data
+
+
+def _get_gold_parse(doc, entities, dev, kb):
+ gold_entities = {}
+ tagged_ent_positions = set(
+ [(ent.start_char, ent.end_char) for ent in doc.ents]
+ )
+
+ for entity in entities:
+ entity_id = entity["entity"]
+ alias = entity["alias"]
+ start = entity["start"]
+ end = entity["end"]
+
+ candidates = kb.get_candidates(alias)
+ candidate_ids = [
+ c.entity_ for c in candidates
+ ]
+
+ should_add_ent = (
+ dev or
+ (
+ (start, end) in tagged_ent_positions and
+ entity_id in candidate_ids and
+ len(candidates) > 1
+ )
+ )
+
+ if should_add_ent:
+ value_by_id = {entity_id: 1.0}
+ if not dev:
+ random.shuffle(candidate_ids)
+ value_by_id.update({
+ kb_id: 0.0
+ for kb_id in candidate_ids
+ if kb_id != entity_id
+ })
+ gold_entities[(start, end)] = value_by_id
+
+ return GoldParse(doc, links=gold_entities)
+
+
def is_dev(article_id):
return article_id.endswith("3")
-
-
-def read_training(nlp, training_dir, dev, limit, kb=None):
- """ This method provides training examples that correspond to the entity annotations found by the nlp object.
- When kb is provided (for training), it will include negative training examples by using the candidate generator,
- and it will only keep positive training examples that can be found in the KB.
- When kb=None (for testing), it will include all positive examples only."""
- entityfile_loc = training_dir / ENTITY_FILE
- data = []
-
- # assume the data is written sequentially, so we can reuse the article docs
- current_article_id = None
- current_doc = None
- ents_by_offset = dict()
- skip_articles = set()
- total_entities = 0
-
- with entityfile_loc.open("r", encoding="utf8") as file:
- for line in file:
- if not limit or len(data) < limit:
- fields = line.replace("\n", "").split(sep="|")
- article_id = fields[0]
- alias = fields[1]
- wd_id = fields[2]
- start = fields[3]
- end = fields[4]
-
- if (
- dev == is_dev(article_id)
- and article_id != "article_id"
- and article_id not in skip_articles
- ):
- if not current_doc or (current_article_id != article_id):
- # parse the new article text
- file_name = article_id + ".txt"
- try:
- training_file = training_dir / file_name
- with training_file.open("r", encoding="utf8") as f:
- text = f.read()
- # threshold for convenience / speed of processing
- if len(text) < 30000:
- current_doc = nlp(text)
- current_article_id = article_id
- ents_by_offset = dict()
- for ent in current_doc.ents:
- sent_length = len(ent.sent)
- # custom filtering to avoid too long or too short sentences
- if 5 < sent_length < 100:
- offset = "{}_{}".format(
- ent.start_char, ent.end_char
- )
- ents_by_offset[offset] = ent
- else:
- skip_articles.add(article_id)
- current_doc = None
- except Exception as e:
- print("Problem parsing article", article_id, e)
- skip_articles.add(article_id)
-
- # repeat checking this condition in case an exception was thrown
- if current_doc and (current_article_id == article_id):
- offset = "{}_{}".format(start, end)
- found_ent = ents_by_offset.get(offset, None)
- if found_ent:
- if found_ent.text != alias:
- skip_articles.add(article_id)
- current_doc = None
- else:
- sent = found_ent.sent.as_doc()
-
- gold_start = int(start) - found_ent.sent.start_char
- gold_end = int(end) - found_ent.sent.start_char
-
- gold_entities = {}
- found_useful = False
- for ent in sent.ents:
- entry = (ent.start_char, ent.end_char)
- gold_entry = (gold_start, gold_end)
- if entry == gold_entry:
- # add both pos and neg examples (in random order)
- # this will exclude examples not in the KB
- if kb:
- value_by_id = {}
- candidates = kb.get_candidates(alias)
- candidate_ids = [
- c.entity_ for c in candidates
- ]
- random.shuffle(candidate_ids)
- for kb_id in candidate_ids:
- found_useful = True
- if kb_id != wd_id:
- value_by_id[kb_id] = 0.0
- else:
- value_by_id[kb_id] = 1.0
- gold_entities[entry] = value_by_id
- # if no KB, keep all positive examples
- else:
- found_useful = True
- value_by_id = {wd_id: 1.0}
-
- gold_entities[entry] = value_by_id
- # currently feeding the gold data one entity per sentence at a time
- # setting all other entities to empty gold dictionary
- else:
- gold_entities[entry] = {}
- if found_useful:
- gold = GoldParse(doc=sent, links=gold_entities)
- data.append((sent, gold))
- total_entities += 1
- if len(data) % 2500 == 0:
- print(" -read", total_entities, "entities")
-
- print(" -read", total_entities, "entities")
- return data
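
# Illustrative sketch (assumed model name and placeholder negative QID) of the gold-link
# structure that _get_gold_parse builds and read_training feeds to the entity linker: each
# mention's character offsets map to candidate KB ids, with the correct id scored 1.0 and,
# during training, the other candidates returned by the KB scored 0.0.
import spacy
from spacy.gold import GoldParse

nlp = spacy.load("en_core_web_lg")  # assumption: any model with an NER component
doc = nlp("Douglas Adams wrote The Hitchhiker's Guide to the Galaxy.")
links = {
    (0, 13): {"Q42": 1.0, "Q12345": 0.0},  # correct QID 1.0, hypothetical negative candidate 0.0
}
gold = GoldParse(doc, links=links)
print(gold.links)
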
diff --git a/bin/wiki_entity_linking/wikidata_pretrain_kb.py b/bin/wiki_entity_linking/wikidata_pretrain_kb.py
index c5261cada..56107f3a2 100644
--- a/bin/wiki_entity_linking/wikidata_pretrain_kb.py
+++ b/bin/wiki_entity_linking/wikidata_pretrain_kb.py
@@ -13,27 +13,25 @@ from https://dumps.wikimedia.org/enwiki/latest/
"""
from __future__ import unicode_literals
-import datetime
+import logging
from pathlib import Path
import plac
-from bin.wiki_entity_linking import wikipedia_processor as wp
+from bin.wiki_entity_linking import wikipedia_processor as wp, wikidata_processor as wd
from bin.wiki_entity_linking import kb_creator
-
+from bin.wiki_entity_linking import training_set_creator
+from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_FILE, ENTITY_DESCR_PATH, KB_MODEL_DIR, LOG_FORMAT
+from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH
import spacy
-from spacy import Errors
-
-
-def now():
- return datetime.datetime.now()
+logger = logging.getLogger(__name__)
@plac.annotations(
wd_json=("Path to the downloaded WikiData JSON dump.", "positional", None, Path),
wp_xml=("Path to the downloaded Wikipedia XML dump.", "positional", None, Path),
output_dir=("Output directory", "positional", None, Path),
- model=("Model name, should include pretrained vectors.", "positional", None, str),
+ model=("Model name or path, should include pretrained vectors.", "positional", None, str),
max_per_alias=("Max. # entities per alias (default 10)", "option", "a", int),
min_freq=("Min. count of an entity in the corpus (default 20)", "option", "f", int),
min_pair=("Min. count of entity-alias pairs (default 5)", "option", "c", int),
@@ -41,7 +39,9 @@ def now():
loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path),
loc_entity_defs=("Location to file with entity definitions", "option", "d", Path),
loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path),
+ descriptions_from_wikipedia=("Flag for using wp descriptions not wd", "flag", "wp"),
limit=("Optional threshold to limit lines read from dumps", "option", "l", int),
+ lang=("Optional language for which to get wikidata titles. Defaults to 'en'", "option", "la", str),
)
def main(
wd_json,
@@ -55,20 +55,29 @@ def main(
loc_prior_prob=None,
loc_entity_defs=None,
loc_entity_desc=None,
+ descriptions_from_wikipedia=False,
limit=None,
+ lang="en",
):
- print(now(), "Creating KB with Wikipedia and WikiData")
- print()
+
+ entity_defs_path = loc_entity_defs if loc_entity_defs else output_dir / ENTITY_DEFS_PATH
+ entity_descr_path = loc_entity_desc if loc_entity_desc else output_dir / ENTITY_DESCR_PATH
+ entity_freq_path = output_dir / ENTITY_FREQ_PATH
+ prior_prob_path = loc_prior_prob if loc_prior_prob else output_dir / PRIOR_PROB_PATH
+ training_entities_path = output_dir / TRAINING_DATA_FILE
+ kb_path = output_dir / KB_FILE
+
+ logger.info("Creating KB with Wikipedia and WikiData")
if limit is not None:
- print("Warning: reading only", limit, "lines of Wikipedia/Wikidata dumps.")
+ logger.warning("Warning: reading only {} lines of Wikipedia/Wikidata dumps.".format(limit))
# STEP 0: set up IO
if not output_dir.exists():
- output_dir.mkdir()
+ output_dir.mkdir(parents=True)
# STEP 1: create the NLP object
- print(now(), "STEP 1: loaded model", model)
+ logger.info("STEP 1: Loading model {}".format(model))
nlp = spacy.load(model)
# check the length of the nlp vectors
@@ -79,64 +88,68 @@ def main(
)
# STEP 2: create prior probabilities from WP
- print()
- if loc_prior_prob:
- print(now(), "STEP 2: reading prior probabilities from", loc_prior_prob)
- else:
+ if not prior_prob_path.exists():
# It takes about 2h to process 1000M lines of Wikipedia XML dump
- loc_prior_prob = output_dir / "prior_prob.csv"
- print(now(), "STEP 2: writing prior probabilities at", loc_prior_prob)
- wp.read_prior_probs(wp_xml, loc_prior_prob, limit=limit)
+ logger.info("STEP 2: writing prior probabilities to {}".format(prior_prob_path))
+ wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit)
+ logger.info("STEP 2: reading prior probabilities from {}".format(prior_prob_path))
# STEP 3: deduce entity frequencies from WP (takes only a few minutes)
- print()
- print(now(), "STEP 3: calculating entity frequencies")
- loc_entity_freq = output_dir / "entity_freq.csv"
- wp.write_entity_counts(loc_prior_prob, loc_entity_freq, to_print=False)
+ logger.info("STEP 3: calculating entity frequencies")
+ wp.write_entity_counts(prior_prob_path, entity_freq_path, to_print=False)
- loc_kb = output_dir / "kb"
-
- # STEP 4: reading entity descriptions and definitions from WikiData or from file
- print()
- if loc_entity_defs and loc_entity_desc:
- read_raw = False
- print(now(), "STEP 4a: reading entity definitions from", loc_entity_defs)
- print(now(), "STEP 4b: reading entity descriptions from", loc_entity_desc)
- else:
+ # STEP 4: reading definitions and (possibly) descriptions from WikiData or from file
+ message = " and descriptions" if not descriptions_from_wikipedia else ""
+ if (not entity_defs_path.exists()) or (not descriptions_from_wikipedia and not entity_descr_path.exists()):
# It takes about 10h to process 55M lines of Wikidata JSON dump
- read_raw = True
- loc_entity_defs = output_dir / "entity_defs.csv"
- loc_entity_desc = output_dir / "entity_descriptions.csv"
- print(now(), "STEP 4: parsing wikidata for entity definitions and descriptions")
+ logger.info("STEP 4: parsing wikidata for entity definitions" + message)
+ title_to_id, id_to_descr = wd.read_wikidata_entities_json(
+ wd_json,
+ limit,
+ to_print=False,
+ lang=lang,
+ parse_descriptions=(not descriptions_from_wikipedia),
+ )
+ wd.write_entity_files(entity_defs_path, title_to_id)
+ if not descriptions_from_wikipedia:
+ wd.write_entity_description_files(entity_descr_path, id_to_descr)
+ logger.info("STEP 4: read entity definitions" + message)
- # STEP 5: creating the actual KB
+ # STEP 5: Getting gold entities from wikipedia
+ message = " and descriptions" if descriptions_from_wikipedia else ""
+ if (not training_entities_path.exists()) or (descriptions_from_wikipedia and not entity_descr_path.exists()):
+ logger.info("STEP 5: parsing wikipedia for gold entities" + message)
+ training_set_creator.create_training_examples_and_descriptions(
+ wp_xml,
+ entity_defs_path,
+ entity_descr_path,
+ training_entities_path,
+ parse_descriptions=descriptions_from_wikipedia,
+ limit=limit,
+ )
+ logger.info("STEP 5: read gold entities" + message)
+
+ # STEP 6: creating the actual KB
# It takes ca. 30 minutes to pretrain the entity embeddings
- print()
- print(now(), "STEP 5: creating the KB at", loc_kb)
+ logger.info("STEP 6: creating the KB at {}".format(kb_path))
kb = kb_creator.create_kb(
nlp=nlp,
max_entities_per_alias=max_per_alias,
min_entity_freq=min_freq,
min_occ=min_pair,
- entity_def_output=loc_entity_defs,
- entity_descr_output=loc_entity_desc,
- count_input=loc_entity_freq,
- prior_prob_input=loc_prior_prob,
- wikidata_input=wd_json,
+ entity_def_input=entity_defs_path,
+ entity_descr_path=entity_descr_path,
+ count_input=entity_freq_path,
+ prior_prob_input=prior_prob_path,
entity_vector_length=entity_vector_length,
- limit=limit,
- read_raw_data=read_raw,
)
- if read_raw:
- print(" - wrote entity definitions to", loc_entity_defs)
- print(" - wrote writing entity descriptions to", loc_entity_desc)
- kb.dump(loc_kb)
- nlp.to_disk(output_dir / "nlp")
+ kb.dump(kb_path)
+ nlp.to_disk(output_dir / KB_MODEL_DIR)
- print()
- print(now(), "Done!")
+ logger.info("Done!")
if __name__ == "__main__":
+ logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
plac.call(main)
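
# Hypothetical invocation sketch for the KB pretraining script above; the paths and model
# name are placeholders, and the options not passed here (max_per_alias, min_freq, etc.)
# are assumed to keep their documented defaults. Roughly equivalent command line:
#   python wikidata_pretrain_kb.py latest-all.json.gz enwiki-latest-pages-articles.xml.bz2 /output/kb en_core_web_lg
import logging
from pathlib import Path

from bin.wiki_entity_linking import LOG_FORMAT
from bin.wiki_entity_linking.wikidata_pretrain_kb import main

logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
main(
    wd_json=Path("/dumps/latest-all.json.gz"),                   # gzipped WikiData JSON dump
    wp_xml=Path("/dumps/enwiki-latest-pages-articles.xml.bz2"),  # bz2 Wikipedia XML dump
    output_dir=Path("/output/kb"),
    model="en_core_web_lg",                                      # must ship pretrained vectors
    descriptions_from_wikipedia=False,                           # take descriptions from WikiData
    limit=None,
)
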
diff --git a/bin/wiki_entity_linking/wikidata_processor.py b/bin/wiki_entity_linking/wikidata_processor.py
index 660eab28e..b4034cb1a 100644
--- a/bin/wiki_entity_linking/wikidata_processor.py
+++ b/bin/wiki_entity_linking/wikidata_processor.py
@@ -1,17 +1,19 @@
# coding: utf-8
from __future__ import unicode_literals
-import bz2
+import gzip
import json
+import logging
import datetime
+logger = logging.getLogger(__name__)
-def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False):
+
+def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descriptions=True):
    # Read the JSON wiki data and parse out the entities. Takes about 7h30m to parse 55M lines.
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
- lang = "en"
- site_filter = "enwiki"
+ site_filter = '{}wiki'.format(lang)
# properties filter (currently disabled to get ALL data)
prop_filter = dict()
@@ -24,18 +26,15 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False):
parse_properties = False
parse_sitelinks = True
parse_labels = False
- parse_descriptions = True
parse_aliases = False
parse_claims = False
- with bz2.open(wikidata_file, mode="rb") as file:
- line = file.readline()
- cnt = 0
- while line and (not limit or cnt < limit):
- if cnt % 1000000 == 0:
- print(
- datetime.datetime.now(), "processed", cnt, "lines of WikiData JSON dump"
- )
+ with gzip.open(wikidata_file, mode='rb') as file:
+ for cnt, line in enumerate(file):
+ if limit and cnt >= limit:
+ break
+ if cnt % 500000 == 0:
+ logger.info("processed {} lines of WikiData dump".format(cnt))
clean_line = line.strip()
if clean_line.endswith(b","):
clean_line = clean_line[:-1]
@@ -134,8 +133,19 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False):
if to_print:
print()
- line = file.readline()
- cnt += 1
- print(datetime.datetime.now(), "processed", cnt, "lines of WikiData JSON dump")
return title_to_id, id_to_descr
+
+
+def write_entity_files(entity_def_output, title_to_id):
+ with entity_def_output.open("w", encoding="utf8") as id_file:
+ id_file.write("WP_title" + "|" + "WD_id" + "\n")
+ for title, qid in title_to_id.items():
+ id_file.write(title + "|" + str(qid) + "\n")
+
+
+def write_entity_description_files(entity_descr_output, id_to_descr):
+ with entity_descr_output.open("w", encoding="utf8") as descr_file:
+ descr_file.write("WD_id" + "|" + "description" + "\n")
+ for qid, descr in id_to_descr.items():
+ descr_file.write(str(qid) + "|" + descr + "\n")
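
# Minimal sketch (illustrative line content) of how read_wikidata_entities_json above handles
# each line of the gzipped dump: the dump is one large JSON array, so entity lines end with a
# trailing comma that is stripped before json.loads, and the '{lang}wiki' sitelink gives the
# Wikipedia title that ends up in the WP_title|WD_id mapping.
import json

lang = "en"
site_filter = "{}wiki".format(lang)
line = b'{"id": "Q42", "sitelinks": {"enwiki": {"site": "enwiki", "title": "Douglas Adams"}}},'
clean_line = line.strip()
if clean_line.endswith(b","):
    clean_line = clean_line[:-1]
if len(clean_line) > 1:
    obj = json.loads(clean_line)
    title = obj.get("sitelinks", {}).get(site_filter, {}).get("title")
    print(obj["id"], title)  # Q42 Douglas Adams
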
diff --git a/bin/wiki_entity_linking/wikidata_train_entity_linker.py b/bin/wiki_entity_linking/wikidata_train_entity_linker.py
index 770919112..d9ed641d6 100644
--- a/bin/wiki_entity_linking/wikidata_train_entity_linker.py
+++ b/bin/wiki_entity_linking/wikidata_train_entity_linker.py
@@ -11,124 +11,84 @@ from https://dumps.wikimedia.org/enwiki/latest/
from __future__ import unicode_literals
import random
-import datetime
+import logging
from pathlib import Path
import plac
from bin.wiki_entity_linking import training_set_creator
+from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR
+from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance, measure_baselines
+from bin.wiki_entity_linking.kb_creator import read_nlp_kb
-import spacy
-from spacy.kb import KnowledgeBase
from spacy.util import minibatch, compounding
-
-def now():
- return datetime.datetime.now()
+logger = logging.getLogger(__name__)
@plac.annotations(
dir_kb=("Directory with KB, NLP and related files", "positional", None, Path),
output_dir=("Output directory", "option", "o", Path),
loc_training=("Location to training data", "option", "k", Path),
- wp_xml=("Path to the downloaded Wikipedia XML dump.", "option", "w", Path),
epochs=("Number of training iterations (default 10)", "option", "e", int),
dropout=("Dropout to prevent overfitting (default 0.5)", "option", "p", float),
lr=("Learning rate (default 0.005)", "option", "n", float),
l2=("L2 regularization", "option", "r", float),
train_inst=("# training instances (default 90% of all)", "option", "t", int),
dev_inst=("# test instances (default 10% of all)", "option", "d", int),
- limit=("Optional threshold to limit lines read from WP dump", "option", "l", int),
)
def main(
dir_kb,
output_dir=None,
loc_training=None,
- wp_xml=None,
epochs=10,
dropout=0.5,
lr=0.005,
l2=1e-6,
train_inst=None,
dev_inst=None,
- limit=None,
):
- print(now(), "Creating Entity Linker with Wikipedia and WikiData")
- print()
+ logger.info("Creating Entity Linker with Wikipedia and WikiData")
+
+ output_dir = Path(output_dir) if output_dir else dir_kb
+ training_path = loc_training if loc_training else output_dir / TRAINING_DATA_FILE
+ nlp_dir = dir_kb / KB_MODEL_DIR
+ kb_path = output_dir / KB_FILE
+ nlp_output_dir = output_dir / OUTPUT_MODEL_DIR
# STEP 0: set up IO
- if output_dir and not output_dir.exists():
+ if not output_dir.exists():
output_dir.mkdir()
# STEP 1 : load the NLP object
- nlp_dir = dir_kb / "nlp"
- print(now(), "STEP 1: loading model from", nlp_dir)
- nlp = spacy.load(nlp_dir)
+ logger.info("STEP 1: loading model from {}".format(nlp_dir))
+ nlp, kb = read_nlp_kb(nlp_dir, kb_path)
# check that there is a NER component in the pipeline
if "ner" not in nlp.pipe_names:
raise ValueError("The `nlp` object should have a pre-trained `ner` component.")
- # STEP 2 : read the KB
- print()
- print(now(), "STEP 2: reading the KB from", dir_kb / "kb")
- kb = KnowledgeBase(vocab=nlp.vocab)
- kb.load_bulk(dir_kb / "kb")
+ # STEP 2: create a training dataset from WP
+ logger.info("STEP 2: reading training dataset from {}".format(training_path))
- # STEP 3: create a training dataset from WP
- print()
- if loc_training:
- print(now(), "STEP 3: reading training dataset from", loc_training)
- else:
- if not wp_xml:
- raise ValueError(
- "Either provide a path to a preprocessed training directory, "
- "or to the original Wikipedia XML dump."
- )
-
- if output_dir:
- loc_training = output_dir / "training_data"
- else:
- loc_training = dir_kb / "training_data"
- if not loc_training.exists():
- loc_training.mkdir()
- print(now(), "STEP 3: creating training dataset at", loc_training)
-
- if limit is not None:
- print("Warning: reading only", limit, "lines of Wikipedia dump.")
-
- loc_entity_defs = dir_kb / "entity_defs.csv"
- training_set_creator.create_training(
- wikipedia_input=wp_xml,
- entity_def_input=loc_entity_defs,
- training_output=loc_training,
- limit=limit,
- )
-
- # STEP 4: parse the training data
- print()
- print(now(), "STEP 4: parse the training & evaluation data")
-
- # for training, get pos & neg instances that correspond to entries in the kb
- print("Parsing training data, limit =", train_inst)
train_data = training_set_creator.read_training(
- nlp=nlp, training_dir=loc_training, dev=False, limit=train_inst, kb=kb
+ nlp=nlp,
+ entity_file_path=training_path,
+ dev=False,
+ limit=train_inst,
+ kb=kb,
)
- print("Training on", len(train_data), "articles")
- print()
-
- print("Parsing dev testing data, limit =", dev_inst)
# for testing, get all pos instances, whether or not they are in the kb
dev_data = training_set_creator.read_training(
- nlp=nlp, training_dir=loc_training, dev=True, limit=dev_inst, kb=None
+ nlp=nlp,
+ entity_file_path=training_path,
+ dev=True,
+ limit=dev_inst,
+ kb=kb,
)
- print("Dev testing on", len(dev_data), "articles")
- print()
-
- # STEP 5: create and train the entity linking pipe
- print()
- print(now(), "STEP 5: training Entity Linking pipe")
+ # STEP 3: create and train the entity linking pipe
+ logger.info("STEP 3: training Entity Linking pipe")
el_pipe = nlp.create_pipe(
name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name}
@@ -142,275 +102,70 @@ def main(
optimizer.learn_rate = lr
optimizer.L2 = l2
- if not train_data:
- print("Did not find any training data")
- else:
- for itn in range(epochs):
- random.shuffle(train_data)
- losses = {}
- batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
- batchnr = 0
+ logger.info("Training on {} articles".format(len(train_data)))
+ logger.info("Dev testing on {} articles".format(len(dev_data)))
- with nlp.disable_pipes(*other_pipes):
- for batch in batches:
- try:
- docs, golds = zip(*batch)
- nlp.update(
- docs=docs,
- golds=golds,
- sgd=optimizer,
- drop=dropout,
- losses=losses,
- )
- batchnr += 1
- except Exception as e:
- print("Error updating batch:", e)
-
- if batchnr > 0:
- el_pipe.cfg["incl_context"] = True
- el_pipe.cfg["incl_prior"] = True
- dev_acc_context, _ = _measure_acc(dev_data, el_pipe)
- losses["entity_linker"] = losses["entity_linker"] / batchnr
- print(
- "Epoch, train loss",
- itn,
- round(losses["entity_linker"], 2),
- " / dev accuracy avg",
- round(dev_acc_context, 3),
- )
-
- # STEP 6: measure the performance of our trained pipe on an independent dev set
- print()
- if len(dev_data):
- print()
- print(now(), "STEP 6: performance measurement of Entity Linking pipe")
- print()
-
- counts, acc_r, acc_r_d, acc_p, acc_p_d, acc_o, acc_o_d = _measure_baselines(
- dev_data, kb
- )
- print("dev counts:", sorted(counts.items(), key=lambda x: x[0]))
-
- oracle_by_label = [(x, round(y, 3)) for x, y in acc_o_d.items()]
- print("dev accuracy oracle:", round(acc_o, 3), oracle_by_label)
-
- random_by_label = [(x, round(y, 3)) for x, y in acc_r_d.items()]
- print("dev accuracy random:", round(acc_r, 3), random_by_label)
-
- prior_by_label = [(x, round(y, 3)) for x, y in acc_p_d.items()]
- print("dev accuracy prior:", round(acc_p, 3), prior_by_label)
-
- # using only context
- el_pipe.cfg["incl_context"] = True
- el_pipe.cfg["incl_prior"] = False
- dev_acc_context, dev_acc_cont_d = _measure_acc(dev_data, el_pipe)
- context_by_label = [(x, round(y, 3)) for x, y in dev_acc_cont_d.items()]
- print("dev accuracy context:", round(dev_acc_context, 3), context_by_label)
-
- # measuring combined accuracy (prior + context)
- el_pipe.cfg["incl_context"] = True
- el_pipe.cfg["incl_prior"] = True
- dev_acc_combo, dev_acc_combo_d = _measure_acc(dev_data, el_pipe)
- combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_d.items()]
- print("dev accuracy prior+context:", round(dev_acc_combo, 3), combo_by_label)
-
- # STEP 7: apply the EL pipe on a toy example
- print()
- print(now(), "STEP 7: applying Entity Linking to toy example")
- print()
- run_el_toy_example(nlp=nlp)
-
- # STEP 8: write the NLP pipeline (including entity linker) to file
- if output_dir:
- print()
- nlp_loc = output_dir / "nlp"
- print(now(), "STEP 8: Writing trained NLP to", nlp_loc)
- nlp.to_disk(nlp_loc)
- print()
-
- print()
- print(now(), "Done!")
-
-
-def _measure_acc(data, el_pipe=None, error_analysis=False):
- # If the docs in the data require further processing with an entity linker, set el_pipe
- correct_by_label = dict()
- incorrect_by_label = dict()
-
- docs = [d for d, g in data if len(d) > 0]
- if el_pipe is not None:
- docs = list(el_pipe.pipe(docs))
- golds = [g for d, g in data if len(d) > 0]
-
- for doc, gold in zip(docs, golds):
- try:
- correct_entries_per_article = dict()
- for entity, kb_dict in gold.links.items():
- start, end = entity
- # only evaluating on positive examples
- for gold_kb, value in kb_dict.items():
- if value:
- offset = _offset(start, end)
- correct_entries_per_article[offset] = gold_kb
-
- for ent in doc.ents:
- ent_label = ent.label_
- pred_entity = ent.kb_id_
- start = ent.start_char
- end = ent.end_char
- offset = _offset(start, end)
- gold_entity = correct_entries_per_article.get(offset, None)
- # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
- if gold_entity is not None:
- if gold_entity == pred_entity:
- correct = correct_by_label.get(ent_label, 0)
- correct_by_label[ent_label] = correct + 1
- else:
- incorrect = incorrect_by_label.get(ent_label, 0)
- incorrect_by_label[ent_label] = incorrect + 1
- if error_analysis:
- print(ent.text, "in", doc)
- print(
- "Predicted",
- pred_entity,
- "should have been",
- gold_entity,
- )
- print()
-
- except Exception as e:
- print("Error assessing accuracy", e)
-
- acc, acc_by_label = calculate_acc(correct_by_label, incorrect_by_label)
- return acc, acc_by_label
-
-
-def _measure_baselines(data, kb):
- # Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
- counts_d = dict()
-
- random_correct_d = dict()
- random_incorrect_d = dict()
-
- oracle_correct_d = dict()
- oracle_incorrect_d = dict()
-
- prior_correct_d = dict()
- prior_incorrect_d = dict()
-
- docs = [d for d, g in data if len(d) > 0]
- golds = [g for d, g in data if len(d) > 0]
-
- for doc, gold in zip(docs, golds):
- try:
- correct_entries_per_article = dict()
- for entity, kb_dict in gold.links.items():
- start, end = entity
- for gold_kb, value in kb_dict.items():
- # only evaluating on positive examples
- if value:
- offset = _offset(start, end)
- correct_entries_per_article[offset] = gold_kb
-
- for ent in doc.ents:
- label = ent.label_
- start = ent.start_char
- end = ent.end_char
- offset = _offset(start, end)
- gold_entity = correct_entries_per_article.get(offset, None)
-
- # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
- if gold_entity is not None:
- counts_d[label] = counts_d.get(label, 0) + 1
- candidates = kb.get_candidates(ent.text)
- oracle_candidate = ""
- best_candidate = ""
- random_candidate = ""
- if candidates:
- scores = []
-
- for c in candidates:
- scores.append(c.prior_prob)
- if c.entity_ == gold_entity:
- oracle_candidate = c.entity_
-
- best_index = scores.index(max(scores))
- best_candidate = candidates[best_index].entity_
- random_candidate = random.choice(candidates).entity_
-
- if gold_entity == best_candidate:
- prior_correct_d[label] = prior_correct_d.get(label, 0) + 1
- else:
- prior_incorrect_d[label] = prior_incorrect_d.get(label, 0) + 1
-
- if gold_entity == random_candidate:
- random_correct_d[label] = random_correct_d.get(label, 0) + 1
- else:
- random_incorrect_d[label] = random_incorrect_d.get(label, 0) + 1
-
- if gold_entity == oracle_candidate:
- oracle_correct_d[label] = oracle_correct_d.get(label, 0) + 1
- else:
- oracle_incorrect_d[label] = oracle_incorrect_d.get(label, 0) + 1
-
- except Exception as e:
- print("Error assessing accuracy", e)
-
- acc_prior, acc_prior_d = calculate_acc(prior_correct_d, prior_incorrect_d)
- acc_rand, acc_rand_d = calculate_acc(random_correct_d, random_incorrect_d)
- acc_oracle, acc_oracle_d = calculate_acc(oracle_correct_d, oracle_incorrect_d)
-
- return (
- counts_d,
- acc_rand,
- acc_rand_d,
- acc_prior,
- acc_prior_d,
- acc_oracle,
- acc_oracle_d,
+ dev_baseline_accuracies = measure_baselines(
+ dev_data, kb
)
+ logger.info("Dev Baseline Accuracies:")
+ logger.info(dev_baseline_accuracies.report_accuracy("random"))
+ logger.info(dev_baseline_accuracies.report_accuracy("prior"))
+ logger.info(dev_baseline_accuracies.report_accuracy("oracle"))
-def _offset(start, end):
- return "{}_{}".format(start, end)
+ for itn in range(epochs):
+ random.shuffle(train_data)
+ losses = {}
+ batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
+ batchnr = 0
+ with nlp.disable_pipes(*other_pipes):
+ for batch in batches:
+ try:
+ docs, golds = zip(*batch)
+ nlp.update(
+ docs=docs,
+ golds=golds,
+ sgd=optimizer,
+ drop=dropout,
+ losses=losses,
+ )
+ batchnr += 1
+ except Exception as e:
+                    logger.error("Error updating batch: {}".format(e))
+ if batchnr > 0:
+            logger.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2)))
+ measure_performance(dev_data, kb, el_pipe)
-def calculate_acc(correct_by_label, incorrect_by_label):
- acc_by_label = dict()
- total_correct = 0
- total_incorrect = 0
- all_keys = set()
- all_keys.update(correct_by_label.keys())
- all_keys.update(incorrect_by_label.keys())
- for label in sorted(all_keys):
- correct = correct_by_label.get(label, 0)
- incorrect = incorrect_by_label.get(label, 0)
- total_correct += correct
- total_incorrect += incorrect
- if correct == incorrect == 0:
- acc_by_label[label] = 0
- else:
- acc_by_label[label] = correct / (correct + incorrect)
- acc = 0
- if not (total_correct == total_incorrect == 0):
- acc = total_correct / (total_correct + total_incorrect)
- return acc, acc_by_label
+ # STEP 4: measure the performance of our trained pipe on an independent dev set
+ logger.info("STEP 4: performance measurement of Entity Linking pipe")
+ measure_performance(dev_data, kb, el_pipe)
+
+ # STEP 5: apply the EL pipe on a toy example
+ logger.info("STEP 5: applying Entity Linking to toy example")
+ run_el_toy_example(nlp=nlp)
+
+ if output_dir:
+ # STEP 6: write the NLP pipeline (including entity linker) to file
+ logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir))
+ nlp.to_disk(nlp_output_dir)
+
+ logger.info("Done!")
def check_kb(kb):
for mention in ("Bush", "Douglas Adams", "Homer", "Brazil", "China"):
candidates = kb.get_candidates(mention)
- print("generating candidates for " + mention + " :")
+ logger.info("generating candidates for " + mention + " :")
for c in candidates:
- print(
- " ",
- c.prior_prob,
+            logger.info(" ".join([
+ str(c.prior_prob),
c.alias_,
"-->",
- c.entity_ + " (freq=" + str(c.entity_freq) + ")",
- )
- print()
+ c.entity_ + " (freq=" + str(c.entity_freq) + ")"
+            ]))
def run_el_toy_example(nlp):
@@ -421,11 +176,11 @@ def run_el_toy_example(nlp):
"but Dougledydoug doesn't write about George Washington or Homer Simpson."
)
doc = nlp(text)
- print(text)
+ logger.info(text)
for ent in doc.ents:
- print(" ent", ent.text, ent.label_, ent.kb_id_)
- print()
+ logger.info(" ".join(["ent", ent.text, ent.label_, ent.kb_id_]))
if __name__ == "__main__":
+ logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
plac.call(main)
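The retrofitted training script relies on a module-level `logger` and a shared `LOG_FORMAT` constant (imported from `bin.wiki_entity_linking` in the next file of this diff). Below is a minimal sketch of that assumed setup; the exact format string and path are assumptions, not taken from this patch.

```python
# Minimal sketch of the assumed logging setup for the training script.
# LOG_FORMAT is only shown being imported and passed to basicConfig in the
# patch; the concrete format string here is an assumption.
import logging

LOG_FORMAT = "%(asctime)s - %(levelname)s - %(name)s - %(message)s"  # assumed

logger = logging.getLogger(__name__)

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
    logger.info("STEP 1: loading model from %s", "output/nlp")  # hypothetical path
```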
diff --git a/bin/wiki_entity_linking/wikipedia_processor.py b/bin/wiki_entity_linking/wikipedia_processor.py
index fca600368..8f928723e 100644
--- a/bin/wiki_entity_linking/wikipedia_processor.py
+++ b/bin/wiki_entity_linking/wikipedia_processor.py
@@ -5,6 +5,9 @@ import re
import bz2
import csv
import datetime
+import logging
+
+from bin.wiki_entity_linking import LOG_FORMAT
"""
Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions.
@@ -13,6 +16,9 @@ Write these results to file for downstream KB and training data generation.
map_alias_to_link = dict()
+logger = logging.getLogger(__name__)
+
+
# these will/should be matched ignoring case
wiki_namespaces = [
"b",
@@ -116,10 +122,6 @@ for ns in wiki_namespaces:
ns_regex = re.compile(ns_regex, re.IGNORECASE)
-def now():
- return datetime.datetime.now()
-
-
def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
"""
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
@@ -131,7 +133,7 @@ def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
cnt = 0
while line and (not limit or cnt < limit):
if cnt % 25000000 == 0:
- print(now(), "processed", cnt, "lines of Wikipedia XML dump")
+ logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
clean_line = line.strip().decode("utf-8")
aliases, entities, normalizations = get_wp_links(clean_line)
@@ -141,7 +143,7 @@ def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
line = file.readline()
cnt += 1
- print(now(), "processed", cnt, "lines of Wikipedia XML dump")
+ logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
# write all aliases and their entities and count occurrences to file
with prior_prob_output.open("w", encoding="utf8") as outputfile:
diff --git a/setup.py b/setup.py
index 1d2aa084b..984de2250 100755
--- a/setup.py
+++ b/setup.py
@@ -140,7 +140,7 @@ def gzip_language_data(root, source):
base = Path(root) / source
for jsonfile in base.glob("**/*.json"):
outfile = jsonfile.with_suffix(jsonfile.suffix + ".gz")
- if outfile.is_file() and outfile.stat().st_ctime > jsonfile.stat().st_ctime:
+ if outfile.is_file() and outfile.stat().st_mtime > jsonfile.stat().st_mtime:
# If the gz is newer it doesn't need updating
print("Skipping {}, already compressed".format(jsonfile))
continue
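The `setup.py` change swaps `st_ctime` for `st_mtime`: on Linux `st_ctime` is the inode-change time rather than a creation time, so comparing modification times is the reliable way to decide whether the compressed copy is stale. A standalone sketch of the same freshness check, with hypothetical paths:

```python
# Sketch of the mtime-based freshness check: only regenerate the output
# when the source has been modified more recently than the output.
from pathlib import Path

def is_stale(source: Path, output: Path) -> bool:
    if not output.is_file():
        return True
    return output.stat().st_mtime <= source.stat().st_mtime

# Hypothetical usage:
# if is_stale(Path("lookups.json"), Path("lookups.json.gz")):
#     recompress(Path("lookups.json"))
```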
diff --git a/spacy/cli/converters/conllu2json.py b/spacy/cli/converters/conllu2json.py
index 3a7a68e4a..8f2900a9b 100644
--- a/spacy/cli/converters/conllu2json.py
+++ b/spacy/cli/converters/conllu2json.py
@@ -6,7 +6,7 @@ import re
from ...gold import iob_to_biluo
-def conllu2json(input_data, n_sents=10, use_morphology=False, lang=None):
+def conllu2json(input_data, n_sents=10, use_morphology=False, lang=None, **_):
"""
Convert conllu files into JSON format for use with train cli.
use_morphology parameter enables appending morphology to tags, which is
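Adding `**_` to the `conllu2json` signature lets the caller pass the same keyword arguments to every converter without each converter having to declare options it does not use; unrecognised keywords are simply swallowed. A small sketch of the pattern (the extra keyword names in the call are hypothetical):

```python
# Sketch: a converter that tolerates extra keyword arguments it does not use.
def convert(input_data, n_sents=10, use_morphology=False, lang=None, **_):
    return {"n_sents": n_sents, "lang": lang}

# A hypothetical caller can pass converter-specific options uniformly:
convert("raw conllu text", n_sents=5, lang="en", merge_subtokens=True)
```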
diff --git a/spacy/errors.py b/spacy/errors.py
index 1be6dd6df..963024c32 100644
--- a/spacy/errors.py
+++ b/spacy/errors.py
@@ -455,7 +455,9 @@ class Errors(object):
E158 = ("Can't add table '{name}' to lookups because it already exists.")
E159 = ("Can't find table '{name}' in lookups. Available tables: {tables}")
E160 = ("Can't find language data file: {path}")
- E161 = ("Tokenizer special cases are not allowed to modify the text. "
+ E161 = ("Found an internal inconsistency when predicting entity links. "
+ "This is likely a bug in spaCy, so feel free to open an issue.")
+ E162 = ("Tokenizer special cases are not allowed to modify the text. "
"This would map '{chunk}' to '{orth}' given token attributes "
"'{token_attrs}'.")
diff --git a/spacy/lang/char_classes.py b/spacy/lang/char_classes.py
index 9f6c3266e..131bdcd51 100644
--- a/spacy/lang/char_classes.py
+++ b/spacy/lang/char_classes.py
@@ -11,6 +11,12 @@ _hebrew = r"\u0591-\u05F4\uFB1D-\uFB4F"
_hindi = r"\u0900-\u097F"
+_kannada = r"\u0C80-\u0CFF"
+
+_tamil = r"\u0B80-\u0BFF"
+
+_telugu = r"\u0C00-\u0C7F"
+
# Latin standard
_latin_u_standard = r"A-Z"
_latin_l_standard = r"a-z"
@@ -195,7 +201,7 @@ _ukrainian = r"а-щюяіїєґА-ЩЮЯІЇЄҐ"
_upper = LATIN_UPPER + _russian_upper + _tatar_upper + _greek_upper + _ukrainian_upper
_lower = LATIN_LOWER + _russian_lower + _tatar_lower + _greek_lower + _ukrainian_lower
-_uncased = _bengali + _hebrew + _persian + _sinhala + _hindi
+_uncased = _bengali + _hebrew + _persian + _sinhala + _hindi + _kannada + _tamil + _telugu
ALPHA = group_chars(LATIN + _russian + _tatar + _greek + _ukrainian + _uncased)
ALPHA_LOWER = group_chars(_lower + _uncased)
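The Kannada, Tamil and Telugu blocks are added to `_uncased` so tokens in these caseless scripts still count as alphabetic when the `ALPHA*` character classes are built. A quick sketch of what one of these ranges matches, reusing the Tamil block from the diff:

```python
# Sketch: the Tamil block added to _uncased matches Tamil letters,
# which have no upper/lower case distinction.
import re

_tamil = r"\u0B80-\u0BFF"
tamil_re = re.compile("[{}]".format(_tamil))

assert tamil_re.search("தமிழ்")        # Tamil text falls in the block
assert not tamil_re.search("spaCy")   # Latin text does not
```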
diff --git a/spacy/language.py b/spacy/language.py
index 10381573d..09dd22cf2 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -208,6 +208,7 @@ class Language(object):
"name": self.vocab.vectors.name,
}
self._meta["pipeline"] = self.pipe_names
+ self._meta["labels"] = self.pipe_labels
return self._meta
@meta.setter
@@ -248,6 +249,18 @@ class Language(object):
"""
return [pipe_name for pipe_name, _ in self.pipeline]
+ @property
+ def pipe_labels(self):
+ """Get the labels set by the pipeline components, if available.
+
+ RETURNS (dict): Labels keyed by component name.
+ """
+ labels = OrderedDict()
+ for name, pipe in self.pipeline:
+ if hasattr(pipe, "labels"):
+ labels[name] = list(pipe.labels)
+ return labels
+
def get_pipe(self, name):
"""Get a pipeline component for a given component name.
diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx
index 90ccc2fbf..190116a2e 100644
--- a/spacy/pipeline/pipes.pyx
+++ b/spacy/pipeline/pipes.pyx
@@ -67,7 +67,7 @@ class Pipe(object):
"""
self.require_model()
predictions = self.predict([doc])
- if isinstance(predictions, tuple) and len(tuple) == 2:
+ if isinstance(predictions, tuple) and len(predictions) == 2:
scores, tensors = predictions
self.set_annotations([doc], scores, tensor=tensors)
else:
@@ -1062,8 +1062,15 @@ cdef class DependencyParser(Parser):
@property
def labels(self):
+ labels = set()
# Get the labels from the model by looking at the available moves
- return tuple(set(move.split("-")[1] for move in self.move_names))
+ for move in self.move_names:
+ if "-" in move:
+ label = move.split("-")[1]
+ if "||" in label:
+ label = label.split("||")[1]
+ labels.add(label)
+ return tuple(sorted(labels))
cdef class EntityRecognizer(Parser):
@@ -1098,8 +1105,9 @@ cdef class EntityRecognizer(Parser):
def labels(self):
# Get the labels from the model by looking at the available moves, e.g.
# B-PERSON, I-PERSON, L-PERSON, U-PERSON
- return tuple(set(move.split("-")[1] for move in self.move_names
- if move[0] in ("B", "I", "L", "U")))
+ labels = set(move.split("-")[1] for move in self.move_names
+ if move[0] in ("B", "I", "L", "U"))
+ return tuple(sorted(labels))
class EntityLinker(Pipe):
@@ -1275,7 +1283,7 @@ class EntityLinker(Pipe):
# this will set all prior probabilities to 0 if they should be excluded from the model
prior_probs = xp.asarray([c.prior_prob for c in candidates])
if not self.cfg.get("incl_prior", True):
- prior_probs = xp.asarray([[0.0] for c in candidates])
+ prior_probs = xp.asarray([0.0 for c in candidates])
scores = prior_probs
# add in similarity from the context
@@ -1288,6 +1296,8 @@ class EntityLinker(Pipe):
# cosine similarity
sims = xp.dot(entity_encodings, context_enc_t) / (norm_1 * norm_2)
+ if sims.shape != prior_probs.shape:
+ raise ValueError(Errors.E161)
scores = prior_probs + sims - (prior_probs*sims)
# TODO: thresholding
@@ -1361,7 +1371,16 @@ class Sentencizer(object):
"""
name = "sentencizer"
- default_punct_chars = [".", "!", "?"]
+ default_punct_chars = ['!', '.', '?', '։', '؟', '۔', '܀', '܁', '܂', '߹',
+ '।', '॥', '၊', '။', '።', '፧', '፨', '᙮', '᜵', '᜶', '᠃', '᠉', '᥄',
+ '᥅', '᪨', '᪩', '᪪', '᪫', '᭚', '᭛', '᭞', '᭟', '᰻', '᰼', '᱾', '᱿',
+ '‼', '‽', '⁇', '⁈', '⁉', '⸮', '⸼', '꓿', '꘎', '꘏', '꛳', '꛷', '꡶',
+ '꡷', '꣎', '꣏', '꤯', '꧈', '꧉', '꩝', '꩞', '꩟', '꫰', '꫱', '꯫', '﹒',
+ '﹖', '﹗', '!', '.', '?', '𐩖', '𐩗', '𑁇', '𑁈', '𑂾', '𑂿', '𑃀',
+ '𑃁', '𑅁', '𑅂', '𑅃', '𑇅', '𑇆', '𑇍', '𑇞', '𑇟', '𑈸', '𑈹', '𑈻', '𑈼',
+ '𑊩', '𑑋', '𑑌', '𑗂', '𑗃', '𑗉', '𑗊', '𑗋', '𑗌', '𑗍', '𑗎', '𑗏', '𑗐',
+ '𑗑', '𑗒', '𑗓', '𑗔', '𑗕', '𑗖', '𑗗', '𑙁', '𑙂', '𑜼', '𑜽', '𑜾', '𑩂',
+ '𑩃', '𑪛', '𑪜', '𑱁', '𑱂', '𖩮', '𖩯', '𖫵', '𖬷', '𖬸', '𖭄', '𛲟', '𝪈']
def __init__(self, punct_chars=None, **kwargs):
"""Initialize the sentencizer.
@@ -1372,7 +1391,10 @@ class Sentencizer(object):
DOCS: https://spacy.io/api/sentencizer#init
"""
- self.punct_chars = punct_chars or self.default_punct_chars
+ if punct_chars:
+ self.punct_chars = set(punct_chars)
+ else:
+ self.punct_chars = set(self.default_punct_chars)
def __call__(self, doc):
"""Apply the sentencizer to a Doc and set Token.is_sent_start.
@@ -1404,7 +1426,7 @@ class Sentencizer(object):
DOCS: https://spacy.io/api/sentencizer#to_bytes
"""
- return srsly.msgpack_dumps({"punct_chars": self.punct_chars})
+ return srsly.msgpack_dumps({"punct_chars": list(self.punct_chars)})
def from_bytes(self, bytes_data, **kwargs):
"""Load the sentencizer from a bytestring.
@@ -1415,7 +1437,7 @@ class Sentencizer(object):
DOCS: https://spacy.io/api/sentencizer#from_bytes
"""
cfg = srsly.msgpack_loads(bytes_data)
- self.punct_chars = cfg.get("punct_chars", self.default_punct_chars)
+ self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars))
return self
def to_disk(self, path, exclude=tuple(), **kwargs):
@@ -1425,7 +1447,7 @@ class Sentencizer(object):
"""
path = util.ensure_path(path)
path = path.with_suffix(".json")
- srsly.write_json(path, {"punct_chars": self.punct_chars})
+ srsly.write_json(path, {"punct_chars": list(self.punct_chars)})
def from_disk(self, path, exclude=tuple(), **kwargs):
@@ -1436,7 +1458,7 @@ class Sentencizer(object):
path = util.ensure_path(path)
path = path.with_suffix(".json")
cfg = srsly.read_json(path)
- self.punct_chars = cfg.get("punct_chars", self.default_punct_chars)
+ self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars))
return self
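`Sentencizer.punct_chars` is now stored as a set in memory but serialized as a list, so the msgpack and JSON payloads stay valid and round-tripping preserves the characters. A short sketch mirroring the updated test further down in this diff:

```python
# Sketch: punct_chars is kept as a set and round-trips via to_bytes().
from spacy.pipeline import Sentencizer

sentencizer = Sentencizer(punct_chars=[".", "~", "+"])
assert sentencizer.punct_chars == {".", "~", "+"}

restored = Sentencizer().from_bytes(sentencizer.to_bytes())
assert restored.punct_chars == {".", "~", "+"}
```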
diff --git a/spacy/tests/doc/test_span.py b/spacy/tests/doc/test_span.py
index 60b711741..c8c809d24 100644
--- a/spacy/tests/doc/test_span.py
+++ b/spacy/tests/doc/test_span.py
@@ -173,6 +173,21 @@ def test_span_as_doc(doc):
assert span_doc[0].idx == 0
+def test_span_as_doc_user_data(doc):
+ """Test that the user_data can be preserved (but not by default). """
+ my_key = "my_info"
+ my_value = 342
+ doc.user_data[my_key] = my_value
+
+ span = doc[4:10]
+ span_doc_with = span.as_doc(copy_user_data=True)
+ span_doc_without = span.as_doc()
+
+ assert doc.user_data.get(my_key, None) is my_value
+ assert span_doc_with.user_data.get(my_key, None) is my_value
+ assert span_doc_without.user_data.get(my_key, None) is None
+
+
def test_span_string_label_kb_id(doc):
span = Span(doc, 0, 1, label="hello", kb_id="Q342")
assert span.label_ == "hello"
diff --git a/spacy/tests/parser/test_add_label.py b/spacy/tests/parser/test_add_label.py
index 45a51ac8e..4ab9c1e70 100644
--- a/spacy/tests/parser/test_add_label.py
+++ b/spacy/tests/parser/test_add_label.py
@@ -68,3 +68,20 @@ def test_add_label_deserializes_correctly():
assert ner1.moves.n_moves == ner2.moves.n_moves
for i in range(ner1.moves.n_moves):
assert ner1.moves.get_class_name(i) == ner2.moves.get_class_name(i)
+
+
+@pytest.mark.parametrize(
+ "pipe_cls,n_moves", [(DependencyParser, 5), (EntityRecognizer, 4)]
+)
+def test_add_label_get_label(pipe_cls, n_moves):
+ """Test that added labels are returned correctly. This test was added to
+ test for a bug in DependencyParser.labels that'd cause it to fail when
+ splitting the move names.
+ """
+ labels = ["A", "B", "C"]
+ pipe = pipe_cls(Vocab())
+ for label in labels:
+ pipe.add_label(label)
+ assert len(pipe.move_names) == len(labels) * n_moves
+ pipe_labels = sorted(list(pipe.labels))
+ assert pipe_labels == labels
diff --git a/spacy/tests/pipeline/test_pipe_methods.py b/spacy/tests/pipeline/test_pipe_methods.py
index 8755cc27a..5f1fa5cfe 100644
--- a/spacy/tests/pipeline/test_pipe_methods.py
+++ b/spacy/tests/pipeline/test_pipe_methods.py
@@ -128,3 +128,19 @@ def test_pipe_base_class_add_label(nlp, component):
assert label in pipe.labels
else:
assert pipe.labels == (label,)
+
+
+def test_pipe_labels(nlp):
+ input_labels = {
+ "ner": ["PERSON", "ORG", "GPE"],
+ "textcat": ["POSITIVE", "NEGATIVE"],
+ }
+ for name, labels in input_labels.items():
+ pipe = nlp.create_pipe(name)
+ for label in labels:
+ pipe.add_label(label)
+ assert len(pipe.labels) == len(labels)
+ nlp.add_pipe(pipe)
+ assert len(nlp.pipe_labels) == len(input_labels)
+ for name, labels in nlp.pipe_labels.items():
+ assert sorted(input_labels[name]) == sorted(labels)
diff --git a/spacy/tests/pipeline/test_sentencizer.py b/spacy/tests/pipeline/test_sentencizer.py
index c1b3eba45..1e03dc743 100644
--- a/spacy/tests/pipeline/test_sentencizer.py
+++ b/spacy/tests/pipeline/test_sentencizer.py
@@ -81,7 +81,7 @@ def test_sentencizer_custom_punct(en_vocab, punct_chars, words, sent_starts, n_s
def test_sentencizer_serialize_bytes(en_vocab):
punct_chars = [".", "~", "+"]
sentencizer = Sentencizer(punct_chars=punct_chars)
- assert sentencizer.punct_chars == punct_chars
+ assert sentencizer.punct_chars == set(punct_chars)
bytes_data = sentencizer.to_bytes()
new_sentencizer = Sentencizer().from_bytes(bytes_data)
- assert new_sentencizer.punct_chars == punct_chars
+ assert new_sentencizer.punct_chars == set(punct_chars)
diff --git a/spacy/tests/regression/test_issue4278.py b/spacy/tests/regression/test_issue4278.py
new file mode 100644
index 000000000..4c85d15c4
--- /dev/null
+++ b/spacy/tests/regression/test_issue4278.py
@@ -0,0 +1,28 @@
+# coding: utf8
+from __future__ import unicode_literals
+
+import pytest
+from spacy.language import Language
+from spacy.pipeline import Pipe
+
+
+class DummyPipe(Pipe):
+ def __init__(self):
+ self.model = "dummy_model"
+
+ def predict(self, docs):
+ return ([1, 2, 3], [4, 5, 6])
+
+ def set_annotations(self, docs, scores, tensor=None):
+ return docs
+
+
+@pytest.fixture
+def nlp():
+ return Language()
+
+
+def test_multiple_predictions(nlp):
+ doc = nlp.make_doc("foo")
+ dummy_pipe = DummyPipe()
+ dummy_pipe(doc)
diff --git a/spacy/tokenizer.pyx b/spacy/tokenizer.pyx
index 0ce29ae6b..0aff68d7a 100644
--- a/spacy/tokenizer.pyx
+++ b/spacy/tokenizer.pyx
@@ -495,7 +495,7 @@ cdef class Tokenizer:
attrs = [intify_attrs(spec, _do_deprecated=True) for spec in substrings]
orth = "".join([spec[ORTH] for spec in attrs])
if chunk != orth:
- raise ValueError(Errors.E161.format(chunk=chunk, orth=orth, token_attrs=substrings))
+ raise ValueError(Errors.E162.format(chunk=chunk, orth=orth, token_attrs=substrings))
def add_special_case(self, unicode string, substrings):
"""Add a special-case tokenization rule.
diff --git a/spacy/tokens/span.pyx b/spacy/tokens/span.pyx
index f702133af..9e99392a9 100644
--- a/spacy/tokens/span.pyx
+++ b/spacy/tokens/span.pyx
@@ -200,13 +200,15 @@ cdef class Span:
return Underscore(Underscore.span_extensions, self,
start=self.start_char, end=self.end_char)
- def as_doc(self):
+ def as_doc(self, bint copy_user_data=False):
"""Create a `Doc` object with a copy of the `Span`'s data.
+ copy_user_data (bool): Whether or not to copy the original doc's user data.
RETURNS (Doc): The `Doc` copy of the span.
DOCS: https://spacy.io/api/span#as_doc
"""
+ # TODO: make copy_user_data a keyword-only argument (Python 3 only)
words = [t.text for t in self]
spaces = [bool(t.whitespace_) for t in self]
cdef Doc doc = Doc(self.doc.vocab, words=words, spaces=spaces)
@@ -235,6 +237,8 @@ cdef class Span:
cat_start, cat_end, cat_label = key
if cat_start == self.start_char and cat_end == self.end_char:
doc.cats[cat_label] = value
+ if copy_user_data:
+ doc.user_data = self.doc.user_data
return doc
def _fix_dep_copy(self, attrs, array):
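`Span.as_doc` now accepts a `copy_user_data` flag; the parent doc's `user_data` is only carried over when the flag is set. A small usage sketch along the lines of the regression test added above, assuming a blank English pipeline:

```python
# Sketch: user_data is only copied to the new Doc when copy_user_data=True.
import spacy

nlp = spacy.blank("en")
doc = nlp("New York is a city in the United States")
doc.user_data["my_info"] = 342

span = doc[0:2]
assert span.as_doc(copy_user_data=True).user_data.get("my_info") == 342
assert span.as_doc().user_data.get("my_info") is None
```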
diff --git a/website/docs/api/span.md b/website/docs/api/span.md
index 0af305b37..c807c7bbf 100644
--- a/website/docs/api/span.md
+++ b/website/docs/api/span.md
@@ -292,9 +292,10 @@ Create a new `Doc` object corresponding to the `Span`, with a copy of the data.
> assert doc2.text == u"New York"
> ```
-| Name | Type | Description |
-| ----------- | ----- | --------------------------------------- |
-| **RETURNS** | `Doc` | A `Doc` object of the `Span`'s content. |
+| Name | Type | Description |
+| ----------------- | ----- | ---------------------------------------------------- |
+| `copy_user_data` | bool | Whether or not to copy the original doc's user data. |
+| **RETURNS** | `Doc` | A `Doc` object of the `Span`'s content. |
## Span.root {#root tag="property" model="parser"}
diff --git a/website/docs/usage/v2.md b/website/docs/usage/v2.md
index 9e54106c7..a412eeba4 100644
--- a/website/docs/usage/v2.md
+++ b/website/docs/usage/v2.md
@@ -107,7 +107,7 @@ process.
-**Usage:** [Models directory](/models) [Benchmarks](#benchmarks)
+**Usage:** [Models directory](/models)
diff --git a/website/meta/site.json b/website/meta/site.json
index 2b02ef953..edb60ab0c 100644
--- a/website/meta/site.json
+++ b/website/meta/site.json
@@ -10,10 +10,7 @@
"modelsRepo": "explosion/spacy-models",
"social": {
"twitter": "spacy_io",
- "github": "explosion",
- "reddit": "spacynlp",
- "codepen": "explosion",
- "gitter": "explosion/spaCy"
+ "github": "explosion"
},
"theme": "#09a3d5",
"analytics": "UA-58931649-1",
@@ -69,6 +66,7 @@
"items": [
{ "text": "Twitter", "url": "https://twitter.com/spacy_io" },
{ "text": "GitHub", "url": "https://github.com/explosion/spaCy" },
+ { "text": "YouTube", "url": "https://youtube.com/c/ExplosionAI" },
{ "text": "Blog", "url": "https://explosion.ai/blog" }
]
}