From af93997993bee4b2f09291009092337a90c0f662 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 11 Sep 2019 13:27:37 +0200 Subject: [PATCH 01/15] Fix conllu converter --- spacy/cli/converters/conllu2json.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/spacy/cli/converters/conllu2json.py b/spacy/cli/converters/conllu2json.py index 3a7a68e4a..8f2900a9b 100644 --- a/spacy/cli/converters/conllu2json.py +++ b/spacy/cli/converters/conllu2json.py @@ -6,7 +6,7 @@ import re from ...gold import iob_to_biluo -def conllu2json(input_data, n_sents=10, use_morphology=False, lang=None): +def conllu2json(input_data, n_sents=10, use_morphology=False, lang=None, **_): """ Convert conllu files into JSON format for use with train cli. use_morphology parameter enables appending morphology to tags, which is From 8ebc3711dc1ec065c39aeb6017d9ace129a28d3f Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Wed, 11 Sep 2019 18:29:35 +0200 Subject: [PATCH 02/15] Fix bug in Parser.labels and add test (#4275) --- spacy/pipeline/pipes.pyx | 2 +- spacy/tests/parser/test_add_label.py | 17 +++++++++++++++++ 2 files changed, 18 insertions(+), 1 deletion(-) diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index 90ccc2fbf..095021f00 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -1063,7 +1063,7 @@ cdef class DependencyParser(Parser): @property def labels(self): # Get the labels from the model by looking at the available moves - return tuple(set(move.split("-")[1] for move in self.move_names)) + return tuple(set(move.split("-")[1] for move in self.move_names if "-" in move)) cdef class EntityRecognizer(Parser): diff --git a/spacy/tests/parser/test_add_label.py b/spacy/tests/parser/test_add_label.py index 45a51ac8e..4ab9c1e70 100644 --- a/spacy/tests/parser/test_add_label.py +++ b/spacy/tests/parser/test_add_label.py @@ -68,3 +68,20 @@ def test_add_label_deserializes_correctly(): assert ner1.moves.n_moves == ner2.moves.n_moves for i in range(ner1.moves.n_moves): assert ner1.moves.get_class_name(i) == ner2.moves.get_class_name(i) + + +@pytest.mark.parametrize( + "pipe_cls,n_moves", [(DependencyParser, 5), (EntityRecognizer, 4)] +) +def test_add_label_get_label(pipe_cls, n_moves): + """Test that added labels are returned correctly. This test was added to + test for a bug in DependencyParser.labels that'd cause it to fail when + splitting the move names. 
+ """ + labels = ["A", "B", "C"] + pipe = pipe_cls(Vocab()) + for label in labels: + pipe.add_label(label) + assert len(pipe.move_names) == len(labels) * n_moves + pipe_labels = sorted(list(pipe.labels)) + assert pipe_labels == labels From 71909cdf22f9f590fcd4f37ed43b2058dd8f54c4 Mon Sep 17 00:00:00 2001 From: tamuhey Date: Thu, 12 Sep 2019 17:44:49 +0900 Subject: [PATCH 03/15] Fix iss4278 (#4279) * fix: len(tuple) == 2 * (#4278) add fail test * add contributor's aggreement --- .github/contributors/tamuhey.md | 106 +++++++++++++++++++++++ spacy/pipeline/pipes.pyx | 2 +- spacy/tests/regression/test_issue4278.py | 28 ++++++ 3 files changed, 135 insertions(+), 1 deletion(-) create mode 100644 .github/contributors/tamuhey.md create mode 100644 spacy/tests/regression/test_issue4278.py diff --git a/.github/contributors/tamuhey.md b/.github/contributors/tamuhey.md new file mode 100644 index 000000000..6a63e3b53 --- /dev/null +++ b/.github/contributors/tamuhey.md @@ -0,0 +1,106 @@ +# spaCy contributor agreement + +This spaCy Contributor Agreement (**"SCA"**) is based on the +[Oracle Contributor Agreement](http://www.oracle.com/technetwork/oca-405177.pdf). +The SCA applies to any contribution that you make to any product or project +managed by us (the **"project"**), and sets out the intellectual property rights +you grant to us in the contributed materials. The term **"us"** shall mean +[ExplosionAI GmbH](https://explosion.ai/legal). The term +**"you"** shall mean the person or entity identified below. + +If you agree to be bound by these terms, fill in the information requested +below and include the filled-in version with your first pull request, under the +folder [`.github/contributors/`](/.github/contributors/). The name of the file +should be your GitHub username, with the extension `.md`. For example, the user +example_user would create the file `.github/contributors/example_user.md`. + +Read this agreement carefully before signing. These terms and conditions +constitute a binding legal agreement. + +## Contributor Agreement + +1. The term "contribution" or "contributed materials" means any source code, +object code, patch, tool, sample, graphic, specification, manual, +documentation, or any other material posted or submitted by you to the project. + +2. With respect to any worldwide copyrights, or copyright applications and +registrations, in your contribution: + + * you hereby assign to us joint ownership, and to the extent that such + assignment is or becomes invalid, ineffective or unenforceable, you hereby + grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge, + royalty-free, unrestricted license to exercise all rights under those + copyrights. 
This includes, at our option, the right to sublicense these same + rights to third parties through multiple levels of sublicensees or other + licensing arrangements; + + * you agree that each of us can do all things in relation to your + contribution as if each of us were the sole owners, and if one of us makes + a derivative work of your contribution, the one who makes the derivative + work (or has it made will be the sole owner of that derivative work; + + * you agree that you will not assert any moral rights in your contribution + against us, our licensees or transferees; + + * you agree that we may register a copyright in your contribution and + exercise all ownership rights associated with it; and + + * you agree that neither of us has any duty to consult with, obtain the + consent of, pay or render an accounting to the other for any use or + distribution of your contribution. + +3. With respect to any patents you own, or that you can license without payment +to any third party, you hereby grant to us a perpetual, irrevocable, +non-exclusive, worldwide, no-charge, royalty-free license to: + + * make, have made, use, sell, offer to sell, import, and otherwise transfer + your contribution in whole or in part, alone or in combination with or + included in any product, work or materials arising out of the project to + which your contribution was submitted, and + + * at our option, to sublicense these same rights to third parties through + multiple levels of sublicensees or other licensing arrangements. + +4. Except as set out above, you keep all right, title, and interest in your +contribution. The rights that you grant to us under these terms are effective +on the date you first submitted a contribution to us, even if your submission +took place before the date you sign these terms. + +5. You covenant, represent, warrant and agree that: + + * Each contribution that you submit is and shall be an original work of + authorship and you can legally grant the rights set out in this SCA; + + * to the best of your knowledge, each contribution will not violate any + third party's copyrights, trademarks, patents, or other intellectual + property rights; and + + * each contribution shall be in compliance with U.S. export control laws and + other applicable export and import laws. You agree to notify us if you + become aware of any circumstance which would make any of the foregoing + representations inaccurate in any respect. We may publicly disclose your + participation in the project, including the fact that you have signed the SCA. + +6. This SCA is governed by the laws of the State of California and applicable +U.S. Federal law. Any choice of law rules will not apply. + +7. Please place an “x” on one of the applicable statement below. Please do NOT +mark both statements: + + * [x] I am signing on behalf of myself as an individual and no other person + or entity, including my employer, has or will have rights with respect to my + contributions. + + * [ ] I am signing on behalf of my employer or a legal entity and I have the + actual authority to contractually bind that entity. 
+ +## Contributor Details + +| Field | Entry | +|------------------------------- | -------------------- | +| Name | Yohei Tamura | +| Company name (if applicable) | PKSHA | +| Title or role (if applicable) | | +| Date | 2019/9/12 | +| GitHub username | tamuhey | +| Website (optional) | | diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index 095021f00..da376c396 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -67,7 +67,7 @@ class Pipe(object): """ self.require_model() predictions = self.predict([doc]) - if isinstance(predictions, tuple) and len(tuple) == 2: + if isinstance(predictions, tuple) and len(predictions) == 2: scores, tensors = predictions self.set_annotations([doc], scores, tensor=tensors) else: diff --git a/spacy/tests/regression/test_issue4278.py b/spacy/tests/regression/test_issue4278.py new file mode 100644 index 000000000..4c85d15c4 --- /dev/null +++ b/spacy/tests/regression/test_issue4278.py @@ -0,0 +1,28 @@ +# coding: utf8 +from __future__ import unicode_literals + +import pytest +from spacy.language import Language +from spacy.pipeline import Pipe + + +class DummyPipe(Pipe): + def __init__(self): + self.model = "dummy_model" + + def predict(self, docs): + return ([1, 2, 3], [4, 5, 6]) + + def set_annotations(self, docs, scores, tensor=None): + return docs + + +@pytest.fixture +def nlp(): + return Language() + + +def test_multiple_predictions(nlp): + doc = nlp.make_doc("foo") + dummy_pipe = DummyPipe() + dummy_pipe(doc) From ac0e27a825c1b26cb016b7107b18f0de1c7969ff Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Thu, 12 Sep 2019 10:56:28 +0200 Subject: [PATCH 04/15] =?UTF-8?q?=F0=9F=92=AB=20Add=20Language.pipe=5Flabe?= =?UTF-8?q?ls=20(#4276)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add Language.pipe_labels * Update spacy/language.py Co-Authored-By: Matthew Honnibal --- spacy/language.py | 12 ++++++++++++ spacy/tests/pipeline/test_pipe_methods.py | 16 ++++++++++++++++ 2 files changed, 28 insertions(+) diff --git a/spacy/language.py b/spacy/language.py index 10381573d..9dc48ca6f 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -248,6 +248,18 @@ class Language(object): """ return [pipe_name for pipe_name, _ in self.pipeline] + @property + def pipe_labels(self): + """Get the labels set by the pipeline components, if available. + + RETURNS (dict): Labels keyed by component name. + """ + labels = OrderedDict() + for name, pipe in self.pipeline: + if hasattr(pipe, "labels"): + labels[name] = list(pipe.labels) + return labels + def get_pipe(self, name): """Get a pipeline component for a given component name. 
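A minimal usage sketch of the new Language.pipe_labels property added in this patch, mirroring the accompanying test (assumes the spaCy v2 create_pipe/add_pipe API; the component and label names here are illustrative only and not part of the patch):

    import spacy

    nlp = spacy.blank("en")
    textcat = nlp.create_pipe("textcat")
    textcat.add_label("POSITIVE")
    nlp.add_pipe(textcat)
    # Labels are returned keyed by component name, e.g. {"textcat": ["POSITIVE"]}
    print(nlp.pipe_labels)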
diff --git a/spacy/tests/pipeline/test_pipe_methods.py b/spacy/tests/pipeline/test_pipe_methods.py index 8755cc27a..5f1fa5cfe 100644 --- a/spacy/tests/pipeline/test_pipe_methods.py +++ b/spacy/tests/pipeline/test_pipe_methods.py @@ -128,3 +128,19 @@ def test_pipe_base_class_add_label(nlp, component): assert label in pipe.labels else: assert pipe.labels == (label,) + + +def test_pipe_labels(nlp): + input_labels = { + "ner": ["PERSON", "ORG", "GPE"], + "textcat": ["POSITIVE", "NEGATIVE"], + } + for name, labels in input_labels.items(): + pipe = nlp.create_pipe(name) + for label in labels: + pipe.add_label(label) + assert len(pipe.labels) == len(labels) + nlp.add_pipe(pipe) + assert len(nlp.pipe_labels) == len(input_labels) + for name, labels in nlp.pipe_labels.items(): + assert sorted(input_labels[name]) == sorted(labels) From 4d4b3b0783bdca38493e27dee2939b3ded735c4e Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Thu, 12 Sep 2019 11:34:25 +0200 Subject: [PATCH 05/15] Add "labels" to Language.meta --- spacy/language.py | 1 + 1 file changed, 1 insertion(+) diff --git a/spacy/language.py b/spacy/language.py index 9dc48ca6f..09dd22cf2 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -208,6 +208,7 @@ class Language(object): "name": self.vocab.vectors.name, } self._meta["pipeline"] = self.pipe_names + self._meta["labels"] = self.pipe_labels return self._meta @meta.setter From 0760c41393a9be8d9af35fd6022283a0a65d504e Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Thu, 12 Sep 2019 15:35:01 +0200 Subject: [PATCH 06/15] Change st_ctime to st_mtime --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 1d2aa084b..984de2250 100755 --- a/setup.py +++ b/setup.py @@ -140,7 +140,7 @@ def gzip_language_data(root, source): base = Path(root) / source for jsonfile in base.glob("**/*.json"): outfile = jsonfile.with_suffix(jsonfile.suffix + ".gz") - if outfile.is_file() and outfile.stat().st_ctime > jsonfile.stat().st_ctime: + if outfile.is_file() and outfile.stat().st_mtime > jsonfile.stat().st_mtime: # If the gz is newer it doesn't need updating print("Skipping {}, already compressed".format(jsonfile)) continue From 9be4d1c105877df59e58e6b0cd7dcd13ed6c77ed Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Thu, 12 Sep 2019 17:08:14 +0200 Subject: [PATCH 07/15] Allow copying of user_data in as_doc (#4282) * Allow copying the user_data with as_doc + unit test * add option to docs * add typing * import fix * workaround to avoid bool clashing ... * bint instead of bool --- spacy/tests/doc/test_span.py | 15 +++++++++++++++ spacy/tokens/span.pyx | 6 +++++- website/docs/api/span.md | 7 ++++--- 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/spacy/tests/doc/test_span.py b/spacy/tests/doc/test_span.py index 60b711741..c8c809d24 100644 --- a/spacy/tests/doc/test_span.py +++ b/spacy/tests/doc/test_span.py @@ -173,6 +173,21 @@ def test_span_as_doc(doc): assert span_doc[0].idx == 0 +def test_span_as_doc_user_data(doc): + """Test that the user_data can be preserved (but not by default). 
""" + my_key = "my_info" + my_value = 342 + doc.user_data[my_key] = my_value + + span = doc[4:10] + span_doc_with = span.as_doc(copy_user_data=True) + span_doc_without = span.as_doc() + + assert doc.user_data.get(my_key, None) is my_value + assert span_doc_with.user_data.get(my_key, None) is my_value + assert span_doc_without.user_data.get(my_key, None) is None + + def test_span_string_label_kb_id(doc): span = Span(doc, 0, 1, label="hello", kb_id="Q342") assert span.label_ == "hello" diff --git a/spacy/tokens/span.pyx b/spacy/tokens/span.pyx index f702133af..9e99392a9 100644 --- a/spacy/tokens/span.pyx +++ b/spacy/tokens/span.pyx @@ -200,13 +200,15 @@ cdef class Span: return Underscore(Underscore.span_extensions, self, start=self.start_char, end=self.end_char) - def as_doc(self): + def as_doc(self, bint copy_user_data=False): """Create a `Doc` object with a copy of the `Span`'s data. + copy_user_data (bool): Whether or not to copy the original doc's user data. RETURNS (Doc): The `Doc` copy of the span. DOCS: https://spacy.io/api/span#as_doc """ + # TODO: make copy_user_data a keyword-only argument (Python 3 only) words = [t.text for t in self] spaces = [bool(t.whitespace_) for t in self] cdef Doc doc = Doc(self.doc.vocab, words=words, spaces=spaces) @@ -235,6 +237,8 @@ cdef class Span: cat_start, cat_end, cat_label = key if cat_start == self.start_char and cat_end == self.end_char: doc.cats[cat_label] = value + if copy_user_data: + doc.user_data = self.doc.user_data return doc def _fix_dep_copy(self, attrs, array): diff --git a/website/docs/api/span.md b/website/docs/api/span.md index 0af305b37..c807c7bbf 100644 --- a/website/docs/api/span.md +++ b/website/docs/api/span.md @@ -292,9 +292,10 @@ Create a new `Doc` object corresponding to the `Span`, with a copy of the data. > assert doc2.text == u"New York" > ``` -| Name | Type | Description | -| ----------- | ----- | --------------------------------------- | -| **RETURNS** | `Doc` | A `Doc` object of the `Span`'s content. | +| Name | Type | Description | +| ----------------- | ----- | ---------------------------------------------------- | +| `copy_user_data` | bool | Whether or not to copy the original doc's user data. | +| **RETURNS** | `Doc` | A `Doc` object of the `Span`'s content. | ## Span.root {#root tag="property" model="parser"} From 228bbf506dc5a76b6fdef5d9aa45ad0bf7b577c5 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Thu, 12 Sep 2019 18:02:44 +0200 Subject: [PATCH 08/15] Improve label properties on pipes --- spacy/pipeline/pipes.pyx | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index da376c396..3d799b3da 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -1062,8 +1062,15 @@ cdef class DependencyParser(Parser): @property def labels(self): + labels = set() # Get the labels from the model by looking at the available moves - return tuple(set(move.split("-")[1] for move in self.move_names if "-" in move)) + for move in self.move_names: + if "-" in move: + label = move.split("-")[1] + if "||" in label: + label = label.split("||")[1] + labels.add(label) + return tuple(sorted(labels)) cdef class EntityRecognizer(Parser): @@ -1098,8 +1105,9 @@ cdef class EntityRecognizer(Parser): def labels(self): # Get the labels from the model by looking at the available moves, e.g. 
# B-PERSON, I-PERSON, L-PERSON, U-PERSON - return tuple(set(move.split("-")[1] for move in self.move_names - if move[0] in ("B", "I", "L", "U"))) + labels = set(move.split("-")[1] for move in self.move_names + if move[0] in ("B", "I", "L", "U")) + return tuple(sorted(labels)) class EntityLinker(Pipe): From 2ae5db580edec1c83afdb52db76cc0763bed2d5c Mon Sep 17 00:00:00 2001 From: Sofie Van Landeghem Date: Fri, 13 Sep 2019 16:30:05 +0200 Subject: [PATCH 09/15] dim bugfix when incl_prior is False (#4285) --- spacy/errors.py | 2 ++ spacy/pipeline/pipes.pyx | 4 +++- 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/spacy/errors.py b/spacy/errors.py index b8a8dccba..587a6e700 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -455,6 +455,8 @@ class Errors(object): E158 = ("Can't add table '{name}' to lookups because it already exists.") E159 = ("Can't find table '{name}' in lookups. Available tables: {tables}") E160 = ("Can't find language data file: {path}") + E161 = ("Found an internal inconsistency when predicting entity links. " + "This is likely a bug in spaCy, so feel free to open an issue.") @add_codes diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index 3d799b3da..412433565 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -1283,7 +1283,7 @@ class EntityLinker(Pipe): # this will set all prior probabilities to 0 if they should be excluded from the model prior_probs = xp.asarray([c.prior_prob for c in candidates]) if not self.cfg.get("incl_prior", True): - prior_probs = xp.asarray([[0.0] for c in candidates]) + prior_probs = xp.asarray([0.0 for c in candidates]) scores = prior_probs # add in similarity from the context @@ -1296,6 +1296,8 @@ class EntityLinker(Pipe): # cosine similarity sims = xp.dot(entity_encodings, context_enc_t) / (norm_1 * norm_2) + if sims.shape != prior_probs.shape: + raise ValueError(Errors.E161) scores = prior_probs + sims - (prior_probs*sims) # TODO: thresholding From a6830d60e8c4acc6f378a3b4e6c48e851e729408 Mon Sep 17 00:00:00 2001 From: Euan Dowers Date: Fri, 13 Sep 2019 17:03:57 +0200 Subject: [PATCH 10/15] Changes to wiki_entity_linker (#4235) * Changes to wiki_entity_linker * No more f-strings * Make some requested changes * Add back option to get descriptions from wd not wp * Fix logs * Address comments and clean evaluation * Remove type hints * Refactor evaluation, add back metrics by label * Address comments * Log training performance as well as dev --- bin/wiki_entity_linking/__init__.py | 11 + .../entity_linker_evaluation.py | 200 +++++++ bin/wiki_entity_linking/kb_creator.py | 162 +++--- bin/wiki_entity_linking/train_descriptions.py | 19 +- .../training_set_creator.py | 502 ++++++++---------- .../wikidata_pretrain_kb.py | 123 +++-- bin/wiki_entity_linking/wikidata_processor.py | 42 +- .../wikidata_train_entity_linker.py | 407 +++----------- .../wikipedia_processor.py | 14 +- 9 files changed, 709 insertions(+), 771 deletions(-) create mode 100644 bin/wiki_entity_linking/entity_linker_evaluation.py diff --git a/bin/wiki_entity_linking/__init__.py b/bin/wiki_entity_linking/__init__.py index e69de29bb..a604bcc2f 100644 --- a/bin/wiki_entity_linking/__init__.py +++ b/bin/wiki_entity_linking/__init__.py @@ -0,0 +1,11 @@ +TRAINING_DATA_FILE = "gold_entities.jsonl" +KB_FILE = "kb" +KB_MODEL_DIR = "nlp_kb" +OUTPUT_MODEL_DIR = "nlp" + +PRIOR_PROB_PATH = "prior_prob.csv" +ENTITY_DEFS_PATH = "entity_defs.csv" +ENTITY_FREQ_PATH = "entity_freq.csv" +ENTITY_DESCR_PATH = "entity_descriptions.csv" + +LOG_FORMAT = 
'%(asctime)s - %(levelname)s - %(name)s - %(message)s' diff --git a/bin/wiki_entity_linking/entity_linker_evaluation.py b/bin/wiki_entity_linking/entity_linker_evaluation.py new file mode 100644 index 000000000..1b1200564 --- /dev/null +++ b/bin/wiki_entity_linking/entity_linker_evaluation.py @@ -0,0 +1,200 @@ +import logging +import random + +from collections import defaultdict + +logger = logging.getLogger(__name__) + + +class Metrics(object): + true_pos = 0 + false_pos = 0 + false_neg = 0 + + def update_results(self, true_entity, candidate): + candidate_is_correct = true_entity == candidate + + # Assume that we have no labeled negatives in the data (i.e. cases where true_entity is "NIL") + # Therefore, if candidate_is_correct then we have a true positive and never a true negative + self.true_pos += candidate_is_correct + self.false_neg += not candidate_is_correct + if candidate not in {"", "NIL"}: + self.false_pos += not candidate_is_correct + + def calculate_precision(self): + if self.true_pos == 0: + return 0.0 + else: + return self.true_pos / (self.true_pos + self.false_pos) + + def calculate_recall(self): + if self.true_pos == 0: + return 0.0 + else: + return self.true_pos / (self.true_pos + self.false_neg) + + +class EvaluationResults(object): + def __init__(self): + self.metrics = Metrics() + self.metrics_by_label = defaultdict(Metrics) + + def update_metrics(self, ent_label, true_entity, candidate): + self.metrics.update_results(true_entity, candidate) + self.metrics_by_label[ent_label].update_results(true_entity, candidate) + + def increment_false_negatives(self): + self.metrics.false_neg += 1 + + def report_metrics(self, model_name): + model_str = model_name.title() + recall = self.metrics.calculate_recall() + precision = self.metrics.calculate_precision() + return ("{}: ".format(model_str) + + "Recall = {} | ".format(round(recall, 3)) + + "Precision = {} | ".format(round(precision, 3)) + + "Precision by label = {}".format({k: v.calculate_precision() + for k, v in self.metrics_by_label.items()})) + + +class BaselineResults(object): + def __init__(self): + self.random = EvaluationResults() + self.prior = EvaluationResults() + self.oracle = EvaluationResults() + + def report_accuracy(self, model): + results = getattr(self, model) + return results.report_metrics(model) + + def update_baselines(self, true_entity, ent_label, random_candidate, prior_candidate, oracle_candidate): + self.oracle.update_metrics(ent_label, true_entity, oracle_candidate) + self.prior.update_metrics(ent_label, true_entity, prior_candidate) + self.random.update_metrics(ent_label, true_entity, random_candidate) + + +def measure_performance(dev_data, kb, el_pipe): + baseline_accuracies = measure_baselines( + dev_data, kb + ) + + logger.info(baseline_accuracies.report_accuracy("random")) + logger.info(baseline_accuracies.report_accuracy("prior")) + logger.info(baseline_accuracies.report_accuracy("oracle")) + + # using only context + el_pipe.cfg["incl_context"] = True + el_pipe.cfg["incl_prior"] = False + results = get_eval_results(dev_data, el_pipe) + logger.info(results.report_metrics("context only")) + + # measuring combined accuracy (prior + context) + el_pipe.cfg["incl_context"] = True + el_pipe.cfg["incl_prior"] = True + results = get_eval_results(dev_data, el_pipe) + logger.info(results.report_metrics("context and prior")) + + +def get_eval_results(data, el_pipe=None): + # If the docs in the data require further processing with an entity linker, set el_pipe + from tqdm import tqdm + + docs = [] + golds = 
[] + for d, g in tqdm(data, leave=False): + if len(d) > 0: + golds.append(g) + if el_pipe is not None: + docs.append(el_pipe(d)) + else: + docs.append(d) + + results = EvaluationResults() + for doc, gold in zip(docs, golds): + tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents} + try: + correct_entries_per_article = dict() + for entity, kb_dict in gold.links.items(): + start, end = entity + # only evaluating on positive examples + for gold_kb, value in kb_dict.items(): + if value: + offset = _offset(start, end) + correct_entries_per_article[offset] = gold_kb + if offset not in tagged_entries_per_article: + results.increment_false_negatives() + + for ent in doc.ents: + ent_label = ent.label_ + pred_entity = ent.kb_id_ + start = ent.start_char + end = ent.end_char + offset = _offset(start, end) + gold_entity = correct_entries_per_article.get(offset, None) + # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong' + if gold_entity is not None: + results.update_metrics(ent_label, gold_entity, pred_entity) + + except Exception as e: + logging.error("Error assessing accuracy " + str(e)) + + return results + + +def measure_baselines(data, kb): + # Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound + counts_d = dict() + + baseline_results = BaselineResults() + + docs = [d for d, g in data if len(d) > 0] + golds = [g for d, g in data if len(d) > 0] + + for doc, gold in zip(docs, golds): + correct_entries_per_article = dict() + tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents} + for entity, kb_dict in gold.links.items(): + start, end = entity + for gold_kb, value in kb_dict.items(): + # only evaluating on positive examples + if value: + offset = _offset(start, end) + correct_entries_per_article[offset] = gold_kb + if offset not in tagged_entries_per_article: + baseline_results.random.increment_false_negatives() + baseline_results.oracle.increment_false_negatives() + baseline_results.prior.increment_false_negatives() + + for ent in doc.ents: + ent_label = ent.label_ + start = ent.start_char + end = ent.end_char + offset = _offset(start, end) + gold_entity = correct_entries_per_article.get(offset, None) + + # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong' + if gold_entity is not None: + candidates = kb.get_candidates(ent.text) + oracle_candidate = "" + best_candidate = "" + random_candidate = "" + if candidates: + scores = [] + + for c in candidates: + scores.append(c.prior_prob) + if c.entity_ == gold_entity: + oracle_candidate = c.entity_ + + best_index = scores.index(max(scores)) + best_candidate = candidates[best_index].entity_ + random_candidate = random.choice(candidates).entity_ + + baseline_results.update_baselines(gold_entity, ent_label, + random_candidate, best_candidate, oracle_candidate) + + return baseline_results + + +def _offset(start, end): + return "{}_{}".format(start, end) diff --git a/bin/wiki_entity_linking/kb_creator.py b/bin/wiki_entity_linking/kb_creator.py index bd862f536..54ed7815e 100644 --- a/bin/wiki_entity_linking/kb_creator.py +++ b/bin/wiki_entity_linking/kb_creator.py @@ -1,12 +1,20 @@ # coding: utf-8 from __future__ import unicode_literals -from bin.wiki_entity_linking.train_descriptions import EntityEncoder -from bin.wiki_entity_linking import wikidata_processor as wd, wikipedia_processor as wp +import csv +import logging +import spacy +import 
sys + from spacy.kb import KnowledgeBase -import csv -import datetime +from bin.wiki_entity_linking import wikipedia_processor as wp +from bin.wiki_entity_linking.train_descriptions import EntityEncoder + +csv.field_size_limit(sys.maxsize) + + +logger = logging.getLogger(__name__) def create_kb( @@ -14,52 +22,73 @@ def create_kb( max_entities_per_alias, min_entity_freq, min_occ, - entity_def_output, - entity_descr_output, + entity_def_input, + entity_descr_path, count_input, prior_prob_input, - wikidata_input, entity_vector_length, - limit=None, - read_raw_data=True, ): # Create the knowledge base from Wikidata entries kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=entity_vector_length) + # read the mappings from file + title_to_id = get_entity_to_id(entity_def_input) + id_to_descr = get_id_to_description(entity_descr_path) + # check the length of the nlp vectors if "vectors" in nlp.meta and nlp.vocab.vectors.size: input_dim = nlp.vocab.vectors_length - print("Loaded pre-trained vectors of size %s" % input_dim) + logger.info("Loaded pre-trained vectors of size %s" % input_dim) else: raise ValueError( "The `nlp` object should have access to pre-trained word vectors, " " cf. https://spacy.io/usage/models#languages." ) - # disable this part of the pipeline when rerunning the KB generation from preprocessed files - if read_raw_data: - print() - print(now(), " * read wikidata entities:") - title_to_id, id_to_descr = wd.read_wikidata_entities_json( - wikidata_input, limit=limit - ) - - # write the title-ID and ID-description mappings to file - _write_entity_files( - entity_def_output, entity_descr_output, title_to_id, id_to_descr - ) - - else: - # read the mappings from file - title_to_id = get_entity_to_id(entity_def_output) - id_to_descr = get_id_to_description(entity_descr_output) - - print() - print(now(), " * get entity frequencies:") - print() + logger.info("Get entity frequencies") entity_frequencies = wp.get_all_frequencies(count_input=count_input) + logger.info("Filtering entities with fewer than {} mentions".format(min_entity_freq)) # filter the entities for in the KB by frequency, because there's just too much data (8M entities) otherwise + filtered_title_to_id, entity_list, description_list, frequency_list = get_filtered_entities( + title_to_id, + id_to_descr, + entity_frequencies, + min_entity_freq + ) + logger.info("Left with {} entities".format(len(description_list))) + + logger.info("Train entity encoder") + encoder = EntityEncoder(nlp, input_dim, entity_vector_length) + encoder.train(description_list=description_list, to_print=True) + + logger.info("Get entity embeddings:") + embeddings = encoder.apply_encoder(description_list) + + logger.info("Adding {} entities".format(len(entity_list))) + kb.set_entities( + entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings + ) + + logger.info("Adding aliases") + _add_aliases( + kb, + title_to_id=filtered_title_to_id, + max_entities_per_alias=max_entities_per_alias, + min_occ=min_occ, + prior_prob_input=prior_prob_input, + ) + + logger.info("KB size: {} entities, {} aliases".format( + kb.get_size_entities(), + kb.get_size_aliases())) + + logger.info("Done with kb") + return kb + + +def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies, + min_entity_freq: int = 10): filtered_title_to_id = dict() entity_list = [] description_list = [] @@ -72,58 +101,7 @@ def create_kb( description_list.append(desc) frequency_list.append(freq) filtered_title_to_id[title] = entity - - 
print(len(title_to_id.keys()), "original titles") - kept_nr = len(filtered_title_to_id.keys()) - print("kept", kept_nr, "entities with min. frequency", min_entity_freq) - - print() - print(now(), " * train entity encoder:") - print() - encoder = EntityEncoder(nlp, input_dim, entity_vector_length) - encoder.train(description_list=description_list, to_print=True) - - print() - print(now(), " * get entity embeddings:") - print() - embeddings = encoder.apply_encoder(description_list) - - print(now(), " * adding", len(entity_list), "entities") - kb.set_entities( - entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings - ) - - alias_cnt = _add_aliases( - kb, - title_to_id=filtered_title_to_id, - max_entities_per_alias=max_entities_per_alias, - min_occ=min_occ, - prior_prob_input=prior_prob_input, - ) - print() - print(now(), " * adding", alias_cnt, "aliases") - print() - - print() - print("# of entities in kb:", kb.get_size_entities()) - print("# of aliases in kb:", kb.get_size_aliases()) - - print(now(), "Done with kb") - return kb - - -def _write_entity_files( - entity_def_output, entity_descr_output, title_to_id, id_to_descr -): - with entity_def_output.open("w", encoding="utf8") as id_file: - id_file.write("WP_title" + "|" + "WD_id" + "\n") - for title, qid in title_to_id.items(): - id_file.write(title + "|" + str(qid) + "\n") - - with entity_descr_output.open("w", encoding="utf8") as descr_file: - descr_file.write("WD_id" + "|" + "description" + "\n") - for qid, descr in id_to_descr.items(): - descr_file.write(str(qid) + "|" + descr + "\n") + return filtered_title_to_id, entity_list, description_list, frequency_list def get_entity_to_id(entity_def_output): @@ -137,9 +115,9 @@ def get_entity_to_id(entity_def_output): return entity_to_id -def get_id_to_description(entity_descr_output): +def get_id_to_description(entity_descr_path): id_to_desc = dict() - with entity_descr_output.open("r", encoding="utf8") as csvfile: + with entity_descr_path.open("r", encoding="utf8") as csvfile: csvreader = csv.reader(csvfile, delimiter="|") # skip header next(csvreader) @@ -150,7 +128,6 @@ def get_id_to_description(entity_descr_output): def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input): wp_titles = title_to_id.keys() - cnt = 0 # adding aliases with prior probabilities # we can read this file sequentially, it's sorted by alias, and then by count @@ -187,9 +164,8 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in entities=selected_entities, probabilities=prior_probs, ) - cnt += 1 except ValueError as e: - print(e) + logger.error(e) total_count = 0 counts = [] entities = [] @@ -202,8 +178,12 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in previous_alias = new_alias line = prior_file.readline() - return cnt -def now(): - return datetime.datetime.now() +def read_nlp_kb(model_dir, kb_file): + nlp = spacy.load(model_dir) + kb = KnowledgeBase(vocab=nlp.vocab) + kb.load_bulk(kb_file) + logger.info("kb entities: {}".format(kb.get_size_entities())) + logger.info("kb aliases: {}".format(kb.get_size_aliases())) + return nlp, kb diff --git a/bin/wiki_entity_linking/train_descriptions.py b/bin/wiki_entity_linking/train_descriptions.py index 0663296e4..2cb66909f 100644 --- a/bin/wiki_entity_linking/train_descriptions.py +++ b/bin/wiki_entity_linking/train_descriptions.py @@ -1,6 +1,7 @@ # coding: utf-8 from random import shuffle +import logging import numpy as np from spacy._ml import zero_init, 
create_default_optimizer @@ -10,6 +11,8 @@ from thinc.v2v import Model from thinc.api import chain from thinc.neural._classes.affine import Affine +logger = logging.getLogger(__name__) + class EntityEncoder: """ @@ -50,21 +53,19 @@ class EntityEncoder: start = start + batch_size stop = min(stop + batch_size, len(description_list)) - print("encoded:", stop, "entities") + logger.info("encoded: {} entities".format(stop)) return encodings def train(self, description_list, to_print=False): processed, loss = self._train_model(description_list) if to_print: - print( - "Trained entity descriptions on", - processed, - "(non-unique) entities across", - self.epochs, - "epochs", + logger.info( + "Trained entity descriptions on {} ".format(processed) + + "(non-unique) entities across {} ".format(self.epochs) + + "epochs" ) - print("Final loss:", loss) + logger.info("Final loss: {}".format(loss)) def _train_model(self, description_list): best_loss = 1.0 @@ -93,7 +94,7 @@ class EntityEncoder: loss = self._update(batch) if batch_nr % 25 == 0: - print("loss:", loss) + logger.info("loss: {} ".format(loss)) processed += len(batch) # in general, continue training if we haven't reached our ideal min yet diff --git a/bin/wiki_entity_linking/training_set_creator.py b/bin/wiki_entity_linking/training_set_creator.py index 7f45d9435..3f42f8bdd 100644 --- a/bin/wiki_entity_linking/training_set_creator.py +++ b/bin/wiki_entity_linking/training_set_creator.py @@ -1,10 +1,13 @@ # coding: utf-8 from __future__ import unicode_literals +import logging import random import re import bz2 -import datetime +import json + +from functools import partial from spacy.gold import GoldParse from bin.wiki_entity_linking import kb_creator @@ -15,18 +18,30 @@ Gold-standard entities are stored in one file in standoff format (by character o """ ENTITY_FILE = "gold_entities.csv" +logger = logging.getLogger(__name__) -def now(): - return datetime.datetime.now() - - -def create_training(wikipedia_input, entity_def_input, training_output, limit=None): +def create_training_examples_and_descriptions(wikipedia_input, + entity_def_input, + description_output, + training_output, + parse_descriptions, + limit=None): wp_to_id = kb_creator.get_entity_to_id(entity_def_input) - _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=limit) + _process_wikipedia_texts(wikipedia_input, + wp_to_id, + description_output, + training_output, + parse_descriptions, + limit) -def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None): +def _process_wikipedia_texts(wikipedia_input, + wp_to_id, + output, + training_output, + parse_descriptions, + limit=None): """ Read the XML wikipedia data to parse out training data: raw text data + positive instances @@ -35,29 +50,21 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N id_regex = re.compile(r"(?<=)\d*(?=)") read_ids = set() - entityfile_loc = training_output / ENTITY_FILE - with entityfile_loc.open("w", encoding="utf8") as entityfile: - # write entity training header file - _write_training_entity( - outputfile=entityfile, - article_id="article_id", - alias="alias", - entity="WD_id", - start="start", - end="end", - ) + with output.open("a", encoding="utf8") as descr_file, training_output.open("w", encoding="utf8") as entity_file: + if parse_descriptions: + _write_training_description(descr_file, "WD_id", "description") with bz2.open(wikipedia_input, mode="rb") as file: - line = file.readline() - cnt = 0 + article_count = 0 article_text = 
"" article_title = None article_id = None reading_text = False reading_revision = False - while line and (not limit or cnt < limit): - if cnt % 1000000 == 0: - print(now(), "processed", cnt, "lines of Wikipedia dump") + + logger.info("Processed {} articles".format(article_count)) + + for line in file: clean_line = line.strip().decode("utf-8") if clean_line == "": @@ -70,28 +77,32 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N article_text = "" article_title = None article_id = None - # finished reading this page elif clean_line == "": if article_id: - try: - _process_wp_text( - wp_to_id, - entityfile, - article_id, - article_title, - article_text.strip(), - training_output, - ) - except Exception as e: - print( - "Error processing article", article_id, article_title, e - ) - else: - print( - "Done processing a page, but couldn't find an article_id ?", + clean_text, entities = _process_wp_text( article_title, + article_text, + wp_to_id ) + if clean_text is not None and entities is not None: + _write_training_entities(entity_file, + article_id, + clean_text, + entities) + + if article_title in wp_to_id and parse_descriptions: + description = " ".join(clean_text[:1000].split(" ")[:-1]) + _write_training_description( + descr_file, + wp_to_id[article_title], + description + ) + article_count += 1 + if article_count % 10000 == 0: + logger.info("Processed {} articles".format(article_count)) + if limit and article_count >= limit: + break article_text = "" article_title = None article_id = None @@ -115,7 +126,7 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N if ids: article_id = ids[0] if article_id in read_ids: - print( + logger.info( "Found duplicate article ID", article_id, clean_line ) # This should never happen ... read_ids.add(article_id) @@ -125,115 +136,10 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N titles = title_regex.search(clean_line) if titles: article_title = titles[0].strip() - - line = file.readline() - cnt += 1 - print(now(), "processed", cnt, "lines of Wikipedia dump") + logger.info("Finished. 
Processed {} articles".format(article_count)) text_regex = re.compile(r"(?<=).*(?= 2: - reading_special_case = True - - if open_read == 2 and reading_text: - reading_text = False - reading_entity = True - reading_mention = False - - # we just finished reading an entity - if open_read == 0 and not reading_text: - if "#" in entity_buffer or entity_buffer.startswith(":"): - reading_special_case = True - # Ignore cases with nested structures like File: handles etc - if not reading_special_case: - if not mention_buffer: - mention_buffer = entity_buffer - start = len(final_text) - end = start + len(mention_buffer) - qid = wp_to_id.get(entity_buffer, None) - if qid: - _write_training_entity( - outputfile=entityfile, - article_id=article_id, - alias=mention_buffer, - entity=qid, - start=start, - end=end, - ) - found_entities = True - final_text += mention_buffer - - entity_buffer = "" - mention_buffer = "" - - reading_text = True - reading_entity = False - reading_mention = False - reading_special_case = False - - if found_entities: - _write_training_article( - article_id=article_id, - clean_text=final_text, - training_output=training_output, - ) - - info_regex = re.compile(r"{[^{]*?}") htlm_regex = re.compile(r"<!--[^-]*-->") category_regex = re.compile(r"\[\[Category:[^\[]*]]") @@ -242,6 +148,29 @@ ref_regex = re.compile(r"<ref.*?>") # non-greedy ref_2_regex = re.compile(r"</ref.*?>") # non-greedy +def _process_wp_text(article_title, article_text, wp_to_id): + # ignore meta Wikipedia pages + if ( + article_title.startswith("Wikipedia:") or + article_title.startswith("Kategori:") + ): + return None, None + + # remove the text tags + text_search = text_regex.search(article_text) + if text_search is None: + return None, None + text = text_search.group(0) + + # stop processing if this is a redirect page + if text.startswith("#REDIRECT"): + return None, None + + # get the raw text without markup etc, keeping only interwiki links + clean_text, entities = _remove_links(_get_clean_wp_text(text), wp_to_id) + return clean_text, entities + + def _get_clean_wp_text(article_text): clean_text = article_text.strip() @@ -300,130 +229,167 @@ def _get_clean_wp_text(article_text): return clean_text.strip() -def _write_training_article(article_id, clean_text, training_output): - file_loc = training_output / "{}.txt".format(article_id) - with file_loc.open("w", encoding="utf8") as outputfile: - outputfile.write(clean_text) +def _remove_links(clean_text, wp_to_id): + # read the text char by char to get the right offsets for the interwiki links + entities = [] + final_text = "" + open_read = 0 + reading_text = True + reading_entity = False + reading_mention = False + reading_special_case = False + entity_buffer = "" + mention_buffer = "" + for index, letter in enumerate(clean_text): + if letter == "[": + open_read += 1 + elif letter == "]": + open_read -= 1 + elif letter == "|": + if reading_text: + final_text += letter + # switch from reading entity to mention in the [[entity|mention]] pattern + elif reading_entity: + reading_text = False + reading_entity = False + reading_mention = True + else: + reading_special_case = True + else: + if reading_entity: + entity_buffer += letter + elif reading_mention: + mention_buffer += letter + elif reading_text: + final_text += letter + else: + raise ValueError("Not sure at point", clean_text[index - 2: index + 2]) + + if open_read > 2: + reading_special_case = True + + if open_read == 2 and reading_text: + reading_text = False + reading_entity = True + reading_mention = False + 
+ # we just finished reading an entity + if open_read == 0 and not reading_text: + if "#" in entity_buffer or entity_buffer.startswith(":"): + reading_special_case = True + # Ignore cases with nested structures like File: handles etc + if not reading_special_case: + if not mention_buffer: + mention_buffer = entity_buffer + start = len(final_text) + end = start + len(mention_buffer) + qid = wp_to_id.get(entity_buffer, None) + if qid: + entities.append((mention_buffer, qid, start, end)) + final_text += mention_buffer + + entity_buffer = "" + mention_buffer = "" + + reading_text = True + reading_entity = False + reading_mention = False + reading_special_case = False + return final_text, entities -def _write_training_entity(outputfile, article_id, alias, entity, start, end): - line = "{}|{}|{}|{}|{}\n".format(article_id, alias, entity, start, end) +def _write_training_description(outputfile, qid, description): + if description is not None: + line = str(qid) + "|" + description + "\n" + outputfile.write(line) + + +def _write_training_entities(outputfile, article_id, clean_text, entities): + entities_data = [{"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]} for ent in entities] + line = json.dumps( + { + "article_id": article_id, + "clean_text": clean_text, + "entities": entities_data + }, + ensure_ascii=False) + "\n" outputfile.write(line) +def read_training(nlp, entity_file_path, dev, limit, kb): + """ This method provides training examples that correspond to the entity annotations found by the nlp object. + For training,, it will include negative training examples by using the candidate generator, + and it will only keep positive training examples that can be found by using the candidate generator. + For testing, it will include all positive examples only.""" + + from tqdm import tqdm + data = [] + num_entities = 0 + get_gold_parse = partial(_get_gold_parse, dev=dev, kb=kb) + + logger.info("Reading {} data with limit {}".format('dev' if dev else 'train', limit)) + with entity_file_path.open("r", encoding="utf8") as file: + with tqdm(total=limit, leave=False) as pbar: + for i, line in enumerate(file): + example = json.loads(line) + article_id = example["article_id"] + clean_text = example["clean_text"] + entities = example["entities"] + + if dev != is_dev(article_id) or len(clean_text) >= 30000: + continue + + doc = nlp(clean_text) + gold = get_gold_parse(doc, entities) + if gold and len(gold.links) > 0: + data.append((doc, gold)) + num_entities += len(gold.links) + pbar.update(len(gold.links)) + if limit and num_entities >= limit: + break + logger.info("Read {} entities in {} articles".format(num_entities, len(data))) + return data + + +def _get_gold_parse(doc, entities, dev, kb): + gold_entities = {} + tagged_ent_positions = set( + [(ent.start_char, ent.end_char) for ent in doc.ents] + ) + + for entity in entities: + entity_id = entity["entity"] + alias = entity["alias"] + start = entity["start"] + end = entity["end"] + + candidates = kb.get_candidates(alias) + candidate_ids = [ + c.entity_ for c in candidates + ] + + should_add_ent = ( + dev or + ( + (start, end) in tagged_ent_positions and + entity_id in candidate_ids and + len(candidates) > 1 + ) + ) + + if should_add_ent: + value_by_id = {entity_id: 1.0} + if not dev: + random.shuffle(candidate_ids) + value_by_id.update({ + kb_id: 0.0 + for kb_id in candidate_ids + if kb_id != entity_id + }) + gold_entities[(start, end)] = value_by_id + + return GoldParse(doc, links=gold_entities) + + def is_dev(article_id): return 
article_id.endswith("3") - - -def read_training(nlp, training_dir, dev, limit, kb=None): - """ This method provides training examples that correspond to the entity annotations found by the nlp object. - When kb is provided (for training), it will include negative training examples by using the candidate generator, - and it will only keep positive training examples that can be found in the KB. - When kb=None (for testing), it will include all positive examples only.""" - entityfile_loc = training_dir / ENTITY_FILE - data = [] - - # assume the data is written sequentially, so we can reuse the article docs - current_article_id = None - current_doc = None - ents_by_offset = dict() - skip_articles = set() - total_entities = 0 - - with entityfile_loc.open("r", encoding="utf8") as file: - for line in file: - if not limit or len(data) < limit: - fields = line.replace("\n", "").split(sep="|") - article_id = fields[0] - alias = fields[1] - wd_id = fields[2] - start = fields[3] - end = fields[4] - - if ( - dev == is_dev(article_id) - and article_id != "article_id" - and article_id not in skip_articles - ): - if not current_doc or (current_article_id != article_id): - # parse the new article text - file_name = article_id + ".txt" - try: - training_file = training_dir / file_name - with training_file.open("r", encoding="utf8") as f: - text = f.read() - # threshold for convenience / speed of processing - if len(text) < 30000: - current_doc = nlp(text) - current_article_id = article_id - ents_by_offset = dict() - for ent in current_doc.ents: - sent_length = len(ent.sent) - # custom filtering to avoid too long or too short sentences - if 5 < sent_length < 100: - offset = "{}_{}".format( - ent.start_char, ent.end_char - ) - ents_by_offset[offset] = ent - else: - skip_articles.add(article_id) - current_doc = None - except Exception as e: - print("Problem parsing article", article_id, e) - skip_articles.add(article_id) - - # repeat checking this condition in case an exception was thrown - if current_doc and (current_article_id == article_id): - offset = "{}_{}".format(start, end) - found_ent = ents_by_offset.get(offset, None) - if found_ent: - if found_ent.text != alias: - skip_articles.add(article_id) - current_doc = None - else: - sent = found_ent.sent.as_doc() - - gold_start = int(start) - found_ent.sent.start_char - gold_end = int(end) - found_ent.sent.start_char - - gold_entities = {} - found_useful = False - for ent in sent.ents: - entry = (ent.start_char, ent.end_char) - gold_entry = (gold_start, gold_end) - if entry == gold_entry: - # add both pos and neg examples (in random order) - # this will exclude examples not in the KB - if kb: - value_by_id = {} - candidates = kb.get_candidates(alias) - candidate_ids = [ - c.entity_ for c in candidates - ] - random.shuffle(candidate_ids) - for kb_id in candidate_ids: - found_useful = True - if kb_id != wd_id: - value_by_id[kb_id] = 0.0 - else: - value_by_id[kb_id] = 1.0 - gold_entities[entry] = value_by_id - # if no KB, keep all positive examples - else: - found_useful = True - value_by_id = {wd_id: 1.0} - - gold_entities[entry] = value_by_id - # currently feeding the gold data one entity per sentence at a time - # setting all other entities to empty gold dictionary - else: - gold_entities[entry] = {} - if found_useful: - gold = GoldParse(doc=sent, links=gold_entities) - data.append((sent, gold)) - total_entities += 1 - if len(data) % 2500 == 0: - print(" -read", total_entities, "entities") - - print(" -read", total_entities, "entities") - return data diff 
--git a/bin/wiki_entity_linking/wikidata_pretrain_kb.py b/bin/wiki_entity_linking/wikidata_pretrain_kb.py index c5261cada..56107f3a2 100644 --- a/bin/wiki_entity_linking/wikidata_pretrain_kb.py +++ b/bin/wiki_entity_linking/wikidata_pretrain_kb.py @@ -13,27 +13,25 @@ from https://dumps.wikimedia.org/enwiki/latest/ """ from __future__ import unicode_literals -import datetime +import logging from pathlib import Path import plac -from bin.wiki_entity_linking import wikipedia_processor as wp +from bin.wiki_entity_linking import wikipedia_processor as wp, wikidata_processor as wd from bin.wiki_entity_linking import kb_creator - +from bin.wiki_entity_linking import training_set_creator +from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_FILE, ENTITY_DESCR_PATH, KB_MODEL_DIR, LOG_FORMAT +from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH import spacy -from spacy import Errors - - -def now(): - return datetime.datetime.now() +logger = logging.getLogger(__name__) @plac.annotations( wd_json=("Path to the downloaded WikiData JSON dump.", "positional", None, Path), wp_xml=("Path to the downloaded Wikipedia XML dump.", "positional", None, Path), output_dir=("Output directory", "positional", None, Path), - model=("Model name, should include pretrained vectors.", "positional", None, str), + model=("Model name or path, should include pretrained vectors.", "positional", None, str), max_per_alias=("Max. # entities per alias (default 10)", "option", "a", int), min_freq=("Min. count of an entity in the corpus (default 20)", "option", "f", int), min_pair=("Min. count of entity-alias pairs (default 5)", "option", "c", int), @@ -41,7 +39,9 @@ def now(): loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path), loc_entity_defs=("Location to file with entity definitions", "option", "d", Path), loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path), + descriptions_from_wikipedia=("Flag for using wp descriptions not wd", "flag", "wp"), limit=("Optional threshold to limit lines read from dumps", "option", "l", int), + lang=("Optional language for which to get wikidata titles. 
Defaults to 'en'", "option", "la", str), ) def main( wd_json, @@ -55,20 +55,29 @@ def main( loc_prior_prob=None, loc_entity_defs=None, loc_entity_desc=None, + descriptions_from_wikipedia=False, limit=None, + lang="en", ): - print(now(), "Creating KB with Wikipedia and WikiData") - print() + + entity_defs_path = loc_entity_defs if loc_entity_defs else output_dir / ENTITY_DEFS_PATH + entity_descr_path = loc_entity_desc if loc_entity_desc else output_dir / ENTITY_DESCR_PATH + entity_freq_path = output_dir / ENTITY_FREQ_PATH + prior_prob_path = loc_prior_prob if loc_prior_prob else output_dir / PRIOR_PROB_PATH + training_entities_path = output_dir / TRAINING_DATA_FILE + kb_path = output_dir / KB_FILE + + logger.info("Creating KB with Wikipedia and WikiData") if limit is not None: - print("Warning: reading only", limit, "lines of Wikipedia/Wikidata dumps.") + logger.warning("Warning: reading only {} lines of Wikipedia/Wikidata dumps.".format(limit)) # STEP 0: set up IO if not output_dir.exists(): - output_dir.mkdir() + output_dir.mkdir(parents=True) # STEP 1: create the NLP object - print(now(), "STEP 1: loaded model", model) + logger.info("STEP 1: Loading model {}".format(model)) nlp = spacy.load(model) # check the length of the nlp vectors @@ -79,64 +88,68 @@ def main( ) # STEP 2: create prior probabilities from WP - print() - if loc_prior_prob: - print(now(), "STEP 2: reading prior probabilities from", loc_prior_prob) - else: + if not prior_prob_path.exists(): # It takes about 2h to process 1000M lines of Wikipedia XML dump - loc_prior_prob = output_dir / "prior_prob.csv" - print(now(), "STEP 2: writing prior probabilities at", loc_prior_prob) - wp.read_prior_probs(wp_xml, loc_prior_prob, limit=limit) + logger.info("STEP 2: writing prior probabilities to {}".format(prior_prob_path)) + wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit) + logger.info("STEP 2: reading prior probabilities from {}".format(prior_prob_path)) # STEP 3: deduce entity frequencies from WP (takes only a few minutes) - print() - print(now(), "STEP 3: calculating entity frequencies") - loc_entity_freq = output_dir / "entity_freq.csv" - wp.write_entity_counts(loc_prior_prob, loc_entity_freq, to_print=False) + logger.info("STEP 3: calculating entity frequencies") + wp.write_entity_counts(prior_prob_path, entity_freq_path, to_print=False) - loc_kb = output_dir / "kb" - - # STEP 4: reading entity descriptions and definitions from WikiData or from file - print() - if loc_entity_defs and loc_entity_desc: - read_raw = False - print(now(), "STEP 4a: reading entity definitions from", loc_entity_defs) - print(now(), "STEP 4b: reading entity descriptions from", loc_entity_desc) - else: + # STEP 4: reading definitions and (possibly) descriptions from WikiData or from file + message = " and descriptions" if not descriptions_from_wikipedia else "" + if (not entity_defs_path.exists()) or (not descriptions_from_wikipedia and not entity_descr_path.exists()): # It takes about 10h to process 55M lines of Wikidata JSON dump - read_raw = True - loc_entity_defs = output_dir / "entity_defs.csv" - loc_entity_desc = output_dir / "entity_descriptions.csv" - print(now(), "STEP 4: parsing wikidata for entity definitions and descriptions") + logger.info("STEP 4: parsing wikidata for entity definitions" + message) + title_to_id, id_to_descr = wd.read_wikidata_entities_json( + wd_json, + limit, + to_print=False, + lang=lang, + parse_descriptions=(not descriptions_from_wikipedia), + ) + wd.write_entity_files(entity_defs_path, title_to_id) + if 
not descriptions_from_wikipedia: + wd.write_entity_description_files(entity_descr_path, id_to_descr) + logger.info("STEP 4: read entity definitions" + message) - # STEP 5: creating the actual KB + # STEP 5: Getting gold entities from wikipedia + message = " and descriptions" if descriptions_from_wikipedia else "" + if (not training_entities_path.exists()) or (descriptions_from_wikipedia and not entity_descr_path.exists()): + logger.info("STEP 5: parsing wikipedia for gold entities" + message) + training_set_creator.create_training_examples_and_descriptions( + wp_xml, + entity_defs_path, + entity_descr_path, + training_entities_path, + parse_descriptions=descriptions_from_wikipedia, + limit=limit, + ) + logger.info("STEP 5: read gold entities" + message) + + # STEP 6: creating the actual KB # It takes ca. 30 minutes to pretrain the entity embeddings - print() - print(now(), "STEP 5: creating the KB at", loc_kb) + logger.info("STEP 6: creating the KB at {}".format(kb_path)) kb = kb_creator.create_kb( nlp=nlp, max_entities_per_alias=max_per_alias, min_entity_freq=min_freq, min_occ=min_pair, - entity_def_output=loc_entity_defs, - entity_descr_output=loc_entity_desc, - count_input=loc_entity_freq, - prior_prob_input=loc_prior_prob, - wikidata_input=wd_json, + entity_def_input=entity_defs_path, + entity_descr_path=entity_descr_path, + count_input=entity_freq_path, + prior_prob_input=prior_prob_path, entity_vector_length=entity_vector_length, - limit=limit, - read_raw_data=read_raw, ) - if read_raw: - print(" - wrote entity definitions to", loc_entity_defs) - print(" - wrote writing entity descriptions to", loc_entity_desc) - kb.dump(loc_kb) - nlp.to_disk(output_dir / "nlp") + kb.dump(kb_path) + nlp.to_disk(output_dir / KB_MODEL_DIR) - print() - print(now(), "Done!") + logger.info("Done!") if __name__ == "__main__": + logging.basicConfig(level=logging.INFO, format=LOG_FORMAT) plac.call(main) diff --git a/bin/wiki_entity_linking/wikidata_processor.py b/bin/wiki_entity_linking/wikidata_processor.py index 660eab28e..b4034cb1a 100644 --- a/bin/wiki_entity_linking/wikidata_processor.py +++ b/bin/wiki_entity_linking/wikidata_processor.py @@ -1,17 +1,19 @@ # coding: utf-8 from __future__ import unicode_literals -import bz2 +import gzip import json +import logging import datetime +logger = logging.getLogger(__name__) -def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False): + +def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descriptions=True): # Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines. 
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/ - lang = "en" - site_filter = "enwiki" + site_filter = '{}wiki'.format(lang) # properties filter (currently disabled to get ALL data) prop_filter = dict() @@ -24,18 +26,15 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False): parse_properties = False parse_sitelinks = True parse_labels = False - parse_descriptions = True parse_aliases = False parse_claims = False - with bz2.open(wikidata_file, mode="rb") as file: - line = file.readline() - cnt = 0 - while line and (not limit or cnt < limit): - if cnt % 1000000 == 0: - print( - datetime.datetime.now(), "processed", cnt, "lines of WikiData JSON dump" - ) + with gzip.open(wikidata_file, mode='rb') as file: + for cnt, line in enumerate(file): + if limit and cnt >= limit: + break + if cnt % 500000 == 0: + logger.info("processed {} lines of WikiData dump".format(cnt)) clean_line = line.strip() if clean_line.endswith(b","): clean_line = clean_line[:-1] @@ -134,8 +133,19 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False): if to_print: print() - line = file.readline() - cnt += 1 - print(datetime.datetime.now(), "processed", cnt, "lines of WikiData JSON dump") return title_to_id, id_to_descr + + +def write_entity_files(entity_def_output, title_to_id): + with entity_def_output.open("w", encoding="utf8") as id_file: + id_file.write("WP_title" + "|" + "WD_id" + "\n") + for title, qid in title_to_id.items(): + id_file.write(title + "|" + str(qid) + "\n") + + +def write_entity_description_files(entity_descr_output, id_to_descr): + with entity_descr_output.open("w", encoding="utf8") as descr_file: + descr_file.write("WD_id" + "|" + "description" + "\n") + for qid, descr in id_to_descr.items(): + descr_file.write(str(qid) + "|" + descr + "\n") diff --git a/bin/wiki_entity_linking/wikidata_train_entity_linker.py b/bin/wiki_entity_linking/wikidata_train_entity_linker.py index 770919112..d9ed641d6 100644 --- a/bin/wiki_entity_linking/wikidata_train_entity_linker.py +++ b/bin/wiki_entity_linking/wikidata_train_entity_linker.py @@ -11,124 +11,84 @@ from https://dumps.wikimedia.org/enwiki/latest/ from __future__ import unicode_literals import random -import datetime +import logging from pathlib import Path import plac from bin.wiki_entity_linking import training_set_creator +from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR +from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance, measure_baselines +from bin.wiki_entity_linking.kb_creator import read_nlp_kb -import spacy -from spacy.kb import KnowledgeBase from spacy.util import minibatch, compounding - -def now(): - return datetime.datetime.now() +logger = logging.getLogger(__name__) @plac.annotations( dir_kb=("Directory with KB, NLP and related files", "positional", None, Path), output_dir=("Output directory", "option", "o", Path), loc_training=("Location to training data", "option", "k", Path), - wp_xml=("Path to the downloaded Wikipedia XML dump.", "option", "w", Path), epochs=("Number of training iterations (default 10)", "option", "e", int), dropout=("Dropout to prevent overfitting (default 0.5)", "option", "p", float), lr=("Learning rate (default 0.005)", "option", "n", float), l2=("L2 regularization", "option", "r", float), train_inst=("# training instances (default 90% of all)", "option", "t", int), dev_inst=("# test instances (default 10% of all)", "option", "d", int), - limit=("Optional 
threshold to limit lines read from WP dump", "option", "l", int), ) def main( dir_kb, output_dir=None, loc_training=None, - wp_xml=None, epochs=10, dropout=0.5, lr=0.005, l2=1e-6, train_inst=None, dev_inst=None, - limit=None, ): - print(now(), "Creating Entity Linker with Wikipedia and WikiData") - print() + logger.info("Creating Entity Linker with Wikipedia and WikiData") + + output_dir = Path(output_dir) if output_dir else dir_kb + training_path = loc_training if loc_training else output_dir / TRAINING_DATA_FILE + nlp_dir = dir_kb / KB_MODEL_DIR + kb_path = output_dir / KB_FILE + nlp_output_dir = output_dir / OUTPUT_MODEL_DIR # STEP 0: set up IO - if output_dir and not output_dir.exists(): + if not output_dir.exists(): output_dir.mkdir() # STEP 1 : load the NLP object - nlp_dir = dir_kb / "nlp" - print(now(), "STEP 1: loading model from", nlp_dir) - nlp = spacy.load(nlp_dir) + logger.info("STEP 1: loading model from {}".format(nlp_dir)) + nlp, kb = read_nlp_kb(nlp_dir, kb_path) # check that there is a NER component in the pipeline if "ner" not in nlp.pipe_names: raise ValueError("The `nlp` object should have a pre-trained `ner` component.") - # STEP 2 : read the KB - print() - print(now(), "STEP 2: reading the KB from", dir_kb / "kb") - kb = KnowledgeBase(vocab=nlp.vocab) - kb.load_bulk(dir_kb / "kb") + # STEP 2: create a training dataset from WP + logger.info("STEP 2: reading training dataset from {}".format(training_path)) - # STEP 3: create a training dataset from WP - print() - if loc_training: - print(now(), "STEP 3: reading training dataset from", loc_training) - else: - if not wp_xml: - raise ValueError( - "Either provide a path to a preprocessed training directory, " - "or to the original Wikipedia XML dump." - ) - - if output_dir: - loc_training = output_dir / "training_data" - else: - loc_training = dir_kb / "training_data" - if not loc_training.exists(): - loc_training.mkdir() - print(now(), "STEP 3: creating training dataset at", loc_training) - - if limit is not None: - print("Warning: reading only", limit, "lines of Wikipedia dump.") - - loc_entity_defs = dir_kb / "entity_defs.csv" - training_set_creator.create_training( - wikipedia_input=wp_xml, - entity_def_input=loc_entity_defs, - training_output=loc_training, - limit=limit, - ) - - # STEP 4: parse the training data - print() - print(now(), "STEP 4: parse the training & evaluation data") - - # for training, get pos & neg instances that correspond to entries in the kb - print("Parsing training data, limit =", train_inst) train_data = training_set_creator.read_training( - nlp=nlp, training_dir=loc_training, dev=False, limit=train_inst, kb=kb + nlp=nlp, + entity_file_path=training_path, + dev=False, + limit=train_inst, + kb=kb, ) - print("Training on", len(train_data), "articles") - print() - - print("Parsing dev testing data, limit =", dev_inst) # for testing, get all pos instances, whether or not they are in the kb dev_data = training_set_creator.read_training( - nlp=nlp, training_dir=loc_training, dev=True, limit=dev_inst, kb=None + nlp=nlp, + entity_file_path=training_path, + dev=True, + limit=dev_inst, + kb=kb, ) - print("Dev testing on", len(dev_data), "articles") - print() - - # STEP 5: create and train the entity linking pipe - print() - print(now(), "STEP 5: training Entity Linking pipe") + # STEP 3: create and train the entity linking pipe + logger.info("STEP 3: training Entity Linking pipe") el_pipe = nlp.create_pipe( name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name} @@ -142,275 +102,70 
@@ def main( optimizer.learn_rate = lr optimizer.L2 = l2 - if not train_data: - print("Did not find any training data") - else: - for itn in range(epochs): - random.shuffle(train_data) - losses = {} - batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001)) - batchnr = 0 + logger.info("Training on {} articles".format(len(train_data))) + logger.info("Dev testing on {} articles".format(len(dev_data))) - with nlp.disable_pipes(*other_pipes): - for batch in batches: - try: - docs, golds = zip(*batch) - nlp.update( - docs=docs, - golds=golds, - sgd=optimizer, - drop=dropout, - losses=losses, - ) - batchnr += 1 - except Exception as e: - print("Error updating batch:", e) - - if batchnr > 0: - el_pipe.cfg["incl_context"] = True - el_pipe.cfg["incl_prior"] = True - dev_acc_context, _ = _measure_acc(dev_data, el_pipe) - losses["entity_linker"] = losses["entity_linker"] / batchnr - print( - "Epoch, train loss", - itn, - round(losses["entity_linker"], 2), - " / dev accuracy avg", - round(dev_acc_context, 3), - ) - - # STEP 6: measure the performance of our trained pipe on an independent dev set - print() - if len(dev_data): - print() - print(now(), "STEP 6: performance measurement of Entity Linking pipe") - print() - - counts, acc_r, acc_r_d, acc_p, acc_p_d, acc_o, acc_o_d = _measure_baselines( - dev_data, kb - ) - print("dev counts:", sorted(counts.items(), key=lambda x: x[0])) - - oracle_by_label = [(x, round(y, 3)) for x, y in acc_o_d.items()] - print("dev accuracy oracle:", round(acc_o, 3), oracle_by_label) - - random_by_label = [(x, round(y, 3)) for x, y in acc_r_d.items()] - print("dev accuracy random:", round(acc_r, 3), random_by_label) - - prior_by_label = [(x, round(y, 3)) for x, y in acc_p_d.items()] - print("dev accuracy prior:", round(acc_p, 3), prior_by_label) - - # using only context - el_pipe.cfg["incl_context"] = True - el_pipe.cfg["incl_prior"] = False - dev_acc_context, dev_acc_cont_d = _measure_acc(dev_data, el_pipe) - context_by_label = [(x, round(y, 3)) for x, y in dev_acc_cont_d.items()] - print("dev accuracy context:", round(dev_acc_context, 3), context_by_label) - - # measuring combined accuracy (prior + context) - el_pipe.cfg["incl_context"] = True - el_pipe.cfg["incl_prior"] = True - dev_acc_combo, dev_acc_combo_d = _measure_acc(dev_data, el_pipe) - combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_d.items()] - print("dev accuracy prior+context:", round(dev_acc_combo, 3), combo_by_label) - - # STEP 7: apply the EL pipe on a toy example - print() - print(now(), "STEP 7: applying Entity Linking to toy example") - print() - run_el_toy_example(nlp=nlp) - - # STEP 8: write the NLP pipeline (including entity linker) to file - if output_dir: - print() - nlp_loc = output_dir / "nlp" - print(now(), "STEP 8: Writing trained NLP to", nlp_loc) - nlp.to_disk(nlp_loc) - print() - - print() - print(now(), "Done!") - - -def _measure_acc(data, el_pipe=None, error_analysis=False): - # If the docs in the data require further processing with an entity linker, set el_pipe - correct_by_label = dict() - incorrect_by_label = dict() - - docs = [d for d, g in data if len(d) > 0] - if el_pipe is not None: - docs = list(el_pipe.pipe(docs)) - golds = [g for d, g in data if len(d) > 0] - - for doc, gold in zip(docs, golds): - try: - correct_entries_per_article = dict() - for entity, kb_dict in gold.links.items(): - start, end = entity - # only evaluating on positive examples - for gold_kb, value in kb_dict.items(): - if value: - offset = _offset(start, end) - 
correct_entries_per_article[offset] = gold_kb - - for ent in doc.ents: - ent_label = ent.label_ - pred_entity = ent.kb_id_ - start = ent.start_char - end = ent.end_char - offset = _offset(start, end) - gold_entity = correct_entries_per_article.get(offset, None) - # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong' - if gold_entity is not None: - if gold_entity == pred_entity: - correct = correct_by_label.get(ent_label, 0) - correct_by_label[ent_label] = correct + 1 - else: - incorrect = incorrect_by_label.get(ent_label, 0) - incorrect_by_label[ent_label] = incorrect + 1 - if error_analysis: - print(ent.text, "in", doc) - print( - "Predicted", - pred_entity, - "should have been", - gold_entity, - ) - print() - - except Exception as e: - print("Error assessing accuracy", e) - - acc, acc_by_label = calculate_acc(correct_by_label, incorrect_by_label) - return acc, acc_by_label - - -def _measure_baselines(data, kb): - # Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound - counts_d = dict() - - random_correct_d = dict() - random_incorrect_d = dict() - - oracle_correct_d = dict() - oracle_incorrect_d = dict() - - prior_correct_d = dict() - prior_incorrect_d = dict() - - docs = [d for d, g in data if len(d) > 0] - golds = [g for d, g in data if len(d) > 0] - - for doc, gold in zip(docs, golds): - try: - correct_entries_per_article = dict() - for entity, kb_dict in gold.links.items(): - start, end = entity - for gold_kb, value in kb_dict.items(): - # only evaluating on positive examples - if value: - offset = _offset(start, end) - correct_entries_per_article[offset] = gold_kb - - for ent in doc.ents: - label = ent.label_ - start = ent.start_char - end = ent.end_char - offset = _offset(start, end) - gold_entity = correct_entries_per_article.get(offset, None) - - # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong' - if gold_entity is not None: - counts_d[label] = counts_d.get(label, 0) + 1 - candidates = kb.get_candidates(ent.text) - oracle_candidate = "" - best_candidate = "" - random_candidate = "" - if candidates: - scores = [] - - for c in candidates: - scores.append(c.prior_prob) - if c.entity_ == gold_entity: - oracle_candidate = c.entity_ - - best_index = scores.index(max(scores)) - best_candidate = candidates[best_index].entity_ - random_candidate = random.choice(candidates).entity_ - - if gold_entity == best_candidate: - prior_correct_d[label] = prior_correct_d.get(label, 0) + 1 - else: - prior_incorrect_d[label] = prior_incorrect_d.get(label, 0) + 1 - - if gold_entity == random_candidate: - random_correct_d[label] = random_correct_d.get(label, 0) + 1 - else: - random_incorrect_d[label] = random_incorrect_d.get(label, 0) + 1 - - if gold_entity == oracle_candidate: - oracle_correct_d[label] = oracle_correct_d.get(label, 0) + 1 - else: - oracle_incorrect_d[label] = oracle_incorrect_d.get(label, 0) + 1 - - except Exception as e: - print("Error assessing accuracy", e) - - acc_prior, acc_prior_d = calculate_acc(prior_correct_d, prior_incorrect_d) - acc_rand, acc_rand_d = calculate_acc(random_correct_d, random_incorrect_d) - acc_oracle, acc_oracle_d = calculate_acc(oracle_correct_d, oracle_incorrect_d) - - return ( - counts_d, - acc_rand, - acc_rand_d, - acc_prior, - acc_prior_d, - acc_oracle, - acc_oracle_d, + dev_baseline_accuracies = measure_baselines( + dev_data, kb ) + logger.info("Dev Baseline Accuracies:") + 
logger.info(dev_baseline_accuracies.report_accuracy("random"))
+    logger.info(dev_baseline_accuracies.report_accuracy("prior"))
+    logger.info(dev_baseline_accuracies.report_accuracy("oracle"))
 
-def _offset(start, end):
-    return "{}_{}".format(start, end)
+    for itn in range(epochs):
+        random.shuffle(train_data)
+        losses = {}
+        batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
+        batchnr = 0
 
+        with nlp.disable_pipes(*other_pipes):
+            for batch in batches:
+                try:
+                    docs, golds = zip(*batch)
+                    nlp.update(
+                        docs=docs,
+                        golds=golds,
+                        sgd=optimizer,
+                        drop=dropout,
+                        losses=losses,
+                    )
+                    batchnr += 1
+                except Exception as e:
+                    logger.error("Error updating batch: " + str(e))
+        if batchnr > 0:
+            logger.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2)))
+            measure_performance(dev_data, kb, el_pipe)
 
-def calculate_acc(correct_by_label, incorrect_by_label):
-    acc_by_label = dict()
-    total_correct = 0
-    total_incorrect = 0
-    all_keys = set()
-    all_keys.update(correct_by_label.keys())
-    all_keys.update(incorrect_by_label.keys())
-    for label in sorted(all_keys):
-        correct = correct_by_label.get(label, 0)
-        incorrect = incorrect_by_label.get(label, 0)
-        total_correct += correct
-        total_incorrect += incorrect
-        if correct == incorrect == 0:
-            acc_by_label[label] = 0
-        else:
-            acc_by_label[label] = correct / (correct + incorrect)
-    acc = 0
-    if not (total_correct == total_incorrect == 0):
-        acc = total_correct / (total_correct + total_incorrect)
-    return acc, acc_by_label
+    # STEP 4: measure the performance of our trained pipe on an independent dev set
+    logger.info("STEP 4: performance measurement of Entity Linking pipe")
+    measure_performance(dev_data, kb, el_pipe)
+
+    # STEP 5: apply the EL pipe on a toy example
+    logger.info("STEP 5: applying Entity Linking to toy example")
+    run_el_toy_example(nlp=nlp)
+
+    if output_dir:
+        # STEP 6: write the NLP pipeline (including entity linker) to file
+        logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir))
+        nlp.to_disk(nlp_output_dir)
+
+    logger.info("Done!")
 
 
 def check_kb(kb):
     for mention in ("Bush", "Douglas Adams", "Homer", "Brazil", "China"):
         candidates = kb.get_candidates(mention)
 
-        print("generating candidates for " + mention + " :")
+        logger.info("generating candidates for " + mention + " :")
         for c in candidates:
-            print(
-                " ",
-                c.prior_prob,
+            logger.info(" ".join([
+                str(c.prior_prob),
                 c.alias_,
                 "-->",
-                c.entity_ + " (freq=" + str(c.entity_freq) + ")",
-            )
-        print()
+                c.entity_ + " (freq=" + str(c.entity_freq) + ")"
+            ]))
 
 
 def run_el_toy_example(nlp):
@@ -421,11 +176,11 @@ def run_el_toy_example(nlp):
         "but Dougledydoug doesn't write about George Washington or Homer Simpson."
     )
     doc = nlp(text)
-    print(text)
+    logger.info(text)
     for ent in doc.ents:
-        print(" ent", ent.text, ent.label_, ent.kb_id_)
-    print()
+        logger.info(" ".join(["ent", ent.text, ent.label_, ent.kb_id_]))
 
 
 if __name__ == "__main__":
+    logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
     plac.call(main)
diff --git a/bin/wiki_entity_linking/wikipedia_processor.py b/bin/wiki_entity_linking/wikipedia_processor.py
index fca600368..8f928723e 100644
--- a/bin/wiki_entity_linking/wikipedia_processor.py
+++ b/bin/wiki_entity_linking/wikipedia_processor.py
@@ -5,6 +5,9 @@ import re
 import bz2
 import csv
 import datetime
+import logging
+
+from bin.wiki_entity_linking import LOG_FORMAT
 
 """
 Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions.
@@ -13,6 +16,9 @@ Write these results to file for downstream KB and training data generation. map_alias_to_link = dict() +logger = logging.getLogger(__name__) + + # these will/should be matched ignoring case wiki_namespaces = [ "b", @@ -116,10 +122,6 @@ for ns in wiki_namespaces: ns_regex = re.compile(ns_regex, re.IGNORECASE) -def now(): - return datetime.datetime.now() - - def read_prior_probs(wikipedia_input, prior_prob_output, limit=None): """ Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities. @@ -131,7 +133,7 @@ def read_prior_probs(wikipedia_input, prior_prob_output, limit=None): cnt = 0 while line and (not limit or cnt < limit): if cnt % 25000000 == 0: - print(now(), "processed", cnt, "lines of Wikipedia XML dump") + logger.info("processed {} lines of Wikipedia XML dump".format(cnt)) clean_line = line.strip().decode("utf-8") aliases, entities, normalizations = get_wp_links(clean_line) @@ -141,7 +143,7 @@ def read_prior_probs(wikipedia_input, prior_prob_output, limit=None): line = file.readline() cnt += 1 - print(now(), "processed", cnt, "lines of Wikipedia XML dump") + logger.info("processed {} lines of Wikipedia XML dump".format(cnt)) # write all aliases and their entities and count occurrences to file with prior_prob_output.open("w", encoding="utf8") as outputfile: From bee79619278a5e055ac744798a6ea0001951c94b Mon Sep 17 00:00:00 2001 From: adrianeboyd Date: Sat, 14 Sep 2019 14:23:06 +0200 Subject: [PATCH 11/15] Add Kannada, Tamil, and Telugu unicode blocks (#4288) Add Kannada, Tamil, and Telugu unicode blocks to uncased character classes so that period is recognized as a suffix during tokenization. (I'm sure a few symbols in the code blocks should not be ALPHA, but this is mainly relevant for suffix detection and seems to be an improvement in practice.) --- spacy/lang/char_classes.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/spacy/lang/char_classes.py b/spacy/lang/char_classes.py index 9f6c3266e..131bdcd51 100644 --- a/spacy/lang/char_classes.py +++ b/spacy/lang/char_classes.py @@ -11,6 +11,12 @@ _hebrew = r"\u0591-\u05F4\uFB1D-\uFB4F" _hindi = r"\u0900-\u097F" +_kannada = r"\u0C80-\u0CFF" + +_tamil = r"\u0B80-\u0BFF" + +_telugu = r"\u0C00-\u0C7F" + # Latin standard _latin_u_standard = r"A-Z" _latin_l_standard = r"a-z" @@ -195,7 +201,7 @@ _ukrainian = r"а-щюяіїєґА-ЩЮЯІЇЄҐ" _upper = LATIN_UPPER + _russian_upper + _tatar_upper + _greek_upper + _ukrainian_upper _lower = LATIN_LOWER + _russian_lower + _tatar_lower + _greek_lower + _ukrainian_lower -_uncased = _bengali + _hebrew + _persian + _sinhala + _hindi +_uncased = _bengali + _hebrew + _persian + _sinhala + _hindi + _kannada + _tamil + _telugu ALPHA = group_chars(LATIN + _russian + _tatar + _greek + _ukrainian + _uncased) ALPHA_LOWER = group_chars(_lower + _uncased) From 6942a6a69b5a50f6864427661bcd59403acfbd72 Mon Sep 17 00:00:00 2001 From: adrianeboyd Date: Sat, 14 Sep 2019 15:25:48 +0200 Subject: [PATCH 12/15] Extend default punct for sentencizer (#4290) Most of these characters are for languages / writing systems that aren't supported by spacy, but I don't think it causes problems to include them. In the UD evals, Hindi and Urdu improve a lot as expected (from 0-10% to 70-80%) and Persian improves a little (90% to 96%). Tamil improves in combination with #4288. The punctuation list is converted to a set internally because of its increased length. 
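For illustration, a minimal usage sketch (not part of the patch; it assumes a spaCy v2.1.x install with this change applied) of the extended defaults and the set-based `punct_chars`:

```python
import spacy
from spacy.pipeline import Sentencizer

# The extended defaults include the Devanagari danda ('।'), so rule-based
# sentence segmentation now works for Hindi text out of the box.
nlp = spacy.blank("hi")
nlp.add_pipe(Sentencizer())
doc = nlp("यह पहला वाक्य है। यह दूसरा वाक्य है।")
print([sent.text for sent in doc.sents])

# Custom punct_chars are stored as a set internally and serialized as a list.
sentencizer = Sentencizer(punct_chars=[".", "~", "+"])
assert sentencizer.punct_chars == {".", "~", "+"}
assert Sentencizer().from_bytes(sentencizer.to_bytes()).punct_chars == {".", "~", "+"}
```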
Sentence final punctuation generated with: ``` unichars -gas '[\p{Sentence_Break=STerm}\p{Sentence_Break=ATerm}]' '\p{Terminal_Punctuation}' ``` See: https://stackoverflow.com/a/9508766/461847 Fixes #4269. --- spacy/pipeline/pipes.pyx | 24 ++++++++++++++++++------ spacy/tests/pipeline/test_sentencizer.py | 4 ++-- 2 files changed, 20 insertions(+), 8 deletions(-) diff --git a/spacy/pipeline/pipes.pyx b/spacy/pipeline/pipes.pyx index 412433565..190116a2e 100644 --- a/spacy/pipeline/pipes.pyx +++ b/spacy/pipeline/pipes.pyx @@ -1371,7 +1371,16 @@ class Sentencizer(object): """ name = "sentencizer" - default_punct_chars = [".", "!", "?"] + default_punct_chars = ['!', '.', '?', '։', '؟', '۔', '܀', '܁', '܂', '߹', + '।', '॥', '၊', '။', '።', '፧', '፨', '᙮', '᜵', '᜶', '᠃', '᠉', '᥄', + '᥅', '᪨', '᪩', '᪪', '᪫', '᭚', '᭛', '᭞', '᭟', '᰻', '᰼', '᱾', '᱿', + '‼', '‽', '⁇', '⁈', '⁉', '⸮', '⸼', '꓿', '꘎', '꘏', '꛳', '꛷', '꡶', + '꡷', '꣎', '꣏', '꤯', '꧈', '꧉', '꩝', '꩞', '꩟', '꫰', '꫱', '꯫', '﹒', + '﹖', '﹗', '!', '.', '?', '𐩖', '𐩗', '𑁇', '𑁈', '𑂾', '𑂿', '𑃀', + '𑃁', '𑅁', '𑅂', '𑅃', '𑇅', '𑇆', '𑇍', '𑇞', '𑇟', '𑈸', '𑈹', '𑈻', '𑈼', + '𑊩', '𑑋', '𑑌', '𑗂', '𑗃', '𑗉', '𑗊', '𑗋', '𑗌', '𑗍', '𑗎', '𑗏', '𑗐', + '𑗑', '𑗒', '𑗓', '𑗔', '𑗕', '𑗖', '𑗗', '𑙁', '𑙂', '𑜼', '𑜽', '𑜾', '𑩂', + '𑩃', '𑪛', '𑪜', '𑱁', '𑱂', '𖩮', '𖩯', '𖫵', '𖬷', '𖬸', '𖭄', '𛲟', '𝪈'] def __init__(self, punct_chars=None, **kwargs): """Initialize the sentencizer. @@ -1382,7 +1391,10 @@ class Sentencizer(object): DOCS: https://spacy.io/api/sentencizer#init """ - self.punct_chars = punct_chars or self.default_punct_chars + if punct_chars: + self.punct_chars = set(punct_chars) + else: + self.punct_chars = set(self.default_punct_chars) def __call__(self, doc): """Apply the sentencizer to a Doc and set Token.is_sent_start. @@ -1414,7 +1426,7 @@ class Sentencizer(object): DOCS: https://spacy.io/api/sentencizer#to_bytes """ - return srsly.msgpack_dumps({"punct_chars": self.punct_chars}) + return srsly.msgpack_dumps({"punct_chars": list(self.punct_chars)}) def from_bytes(self, bytes_data, **kwargs): """Load the sentencizer from a bytestring. 
@@ -1425,7 +1437,7 @@ class Sentencizer(object): DOCS: https://spacy.io/api/sentencizer#from_bytes """ cfg = srsly.msgpack_loads(bytes_data) - self.punct_chars = cfg.get("punct_chars", self.default_punct_chars) + self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars)) return self def to_disk(self, path, exclude=tuple(), **kwargs): @@ -1435,7 +1447,7 @@ class Sentencizer(object): """ path = util.ensure_path(path) path = path.with_suffix(".json") - srsly.write_json(path, {"punct_chars": self.punct_chars}) + srsly.write_json(path, {"punct_chars": list(self.punct_chars)}) def from_disk(self, path, exclude=tuple(), **kwargs): @@ -1446,7 +1458,7 @@ class Sentencizer(object): path = util.ensure_path(path) path = path.with_suffix(".json") cfg = srsly.read_json(path) - self.punct_chars = cfg.get("punct_chars", self.default_punct_chars) + self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars)) return self diff --git a/spacy/tests/pipeline/test_sentencizer.py b/spacy/tests/pipeline/test_sentencizer.py index c1b3eba45..1e03dc743 100644 --- a/spacy/tests/pipeline/test_sentencizer.py +++ b/spacy/tests/pipeline/test_sentencizer.py @@ -81,7 +81,7 @@ def test_sentencizer_custom_punct(en_vocab, punct_chars, words, sent_starts, n_s def test_sentencizer_serialize_bytes(en_vocab): punct_chars = [".", "~", "+"] sentencizer = Sentencizer(punct_chars=punct_chars) - assert sentencizer.punct_chars == punct_chars + assert sentencizer.punct_chars == set(punct_chars) bytes_data = sentencizer.to_bytes() new_sentencizer = Sentencizer().from_bytes(bytes_data) - assert new_sentencizer.punct_chars == punct_chars + assert new_sentencizer.punct_chars == set(punct_chars) From 76d26a3d5e2cf604d4a2247fb0bb75f5ea110333 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Sat, 14 Sep 2019 16:32:24 +0200 Subject: [PATCH 13/15] Update site.json [ci skip] --- website/meta/site.json | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/website/meta/site.json b/website/meta/site.json index 2b02ef953..edb60ab0c 100644 --- a/website/meta/site.json +++ b/website/meta/site.json @@ -10,10 +10,7 @@ "modelsRepo": "explosion/spacy-models", "social": { "twitter": "spacy_io", - "github": "explosion", - "reddit": "spacynlp", - "codepen": "explosion", - "gitter": "explosion/spaCy" + "github": "explosion" }, "theme": "#09a3d5", "analytics": "UA-58931649-1", @@ -69,6 +66,7 @@ "items": [ { "text": "Twitter", "url": "https://twitter.com/spacy_io" }, { "text": "GitHub", "url": "https://github.com/explosion/spaCy" }, + { "text": "YouTube", "url": "https://youtube.com/c/ExplosionAI" }, { "text": "Blog", "url": "https://explosion.ai/blog" } ] } From 04d36d2471bc48548abbae6c8913b2371b68d3bf Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Sat, 14 Sep 2019 16:41:19 +0200 Subject: [PATCH 14/15] Remove unused link [ci skip] --- website/docs/usage/v2.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/website/docs/usage/v2.md b/website/docs/usage/v2.md index 9e54106c7..a412eeba4 100644 --- a/website/docs/usage/v2.md +++ b/website/docs/usage/v2.md @@ -107,7 +107,7 @@ process. 
-**Usage:** [Models directory](/models) [Benchmarks](#benchmarks) +**Usage:** [Models directory](/models) From 65854188049339634e868d52645b98bfa7c641ef Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Sun, 15 Sep 2019 20:42:53 +0200 Subject: [PATCH 15/15] Update UD bin scripts * Update imports for `bin/` * Add all currently supported languages * Update subtok merger for new Matcher validation * Modify blinded check to look at tokens instead of lemmas (for corpora with tokens but not lemmas like Telugu) --- bin/ud/run_eval.py | 36 +++++++++++++++++++++--------------- bin/ud/ud_run_test.py | 8 +++----- bin/ud/ud_train.py | 12 +++++++----- 3 files changed, 31 insertions(+), 25 deletions(-) diff --git a/bin/ud/run_eval.py b/bin/ud/run_eval.py index 171687980..2da476721 100644 --- a/bin/ud/run_eval.py +++ b/bin/ud/run_eval.py @@ -7,14 +7,16 @@ import datetime from pathlib import Path import xml.etree.ElementTree as ET -from spacy.cli.ud import conll17_ud_eval -from spacy.cli.ud.ud_train import write_conllu +import conll17_ud_eval +from ud_train import write_conllu from spacy.lang.lex_attrs import word_shape from spacy.util import get_lang_class # All languages in spaCy - in UD format (note that Norwegian is 'no' instead of 'nb') -ALL_LANGUAGES = "ar, ca, da, de, el, en, es, fa, fi, fr, ga, he, hi, hr, hu, id, " \ - "it, ja, no, nl, pl, pt, ro, ru, sv, tr, ur, vi, zh" +ALL_LANGUAGES = ("af, ar, bg, bn, ca, cs, da, de, el, en, es, et, fa, fi, fr," + "ga, he, hi, hr, hu, id, is, it, ja, kn, ko, lt, lv, mr, no," + "nl, pl, pt, ro, ru, si, sk, sl, sq, sr, sv, ta, te, th, tl," + "tr, tt, uk, ur, vi, zh") # Non-parsing tasks that will be evaluated (works for default models) EVAL_NO_PARSE = ['Tokens', 'Words', 'Lemmas', 'Sentences', 'Feats'] @@ -73,10 +75,10 @@ def _contains_blinded_text(stats_xml): tree = ET.parse(stats_xml) root = tree.getroot() total_tokens = int(root.find('size/total/tokens').text) - unique_lemmas = int(root.find('lemmas').get('unique')) + unique_forms = int(root.find('forms').get('unique')) # assume the corpus is largely blinded when there are less than 1% unique tokens - return (unique_lemmas / total_tokens) < 0.01 + return (unique_forms / total_tokens) < 0.01 def fetch_all_treebanks(ud_dir, languages, corpus, best_per_language): @@ -262,22 +264,26 @@ def main(out_path, ud_dir, check_parse=False, langs=ALL_LANGUAGES, exclude_train if not exclude_trained_models: if 'de' in models: models['de'].append(load_model('de_core_news_sm')) - if 'es' in models: - models['es'].append(load_model('es_core_news_sm')) - models['es'].append(load_model('es_core_news_md')) - if 'pt' in models: - models['pt'].append(load_model('pt_core_news_sm')) - if 'it' in models: - models['it'].append(load_model('it_core_news_sm')) - if 'nl' in models: - models['nl'].append(load_model('nl_core_news_sm')) + models['de'].append(load_model('de_core_news_md')) + if 'el' in models: + models['el'].append(load_model('el_core_news_sm')) + models['el'].append(load_model('el_core_news_md')) if 'en' in models: models['en'].append(load_model('en_core_web_sm')) models['en'].append(load_model('en_core_web_md')) models['en'].append(load_model('en_core_web_lg')) + if 'es' in models: + models['es'].append(load_model('es_core_news_sm')) + models['es'].append(load_model('es_core_news_md')) if 'fr' in models: models['fr'].append(load_model('fr_core_news_sm')) models['fr'].append(load_model('fr_core_news_md')) + if 'it' in models: + models['it'].append(load_model('it_core_news_sm')) + if 'nl' in models: + 
models['nl'].append(load_model('nl_core_news_sm')) + if 'pt' in models: + models['pt'].append(load_model('pt_core_news_sm')) with out_path.open(mode='w', encoding='utf-8') as out_file: run_all_evals(models, treebanks, out_file, check_parse, print_freq_tasks) diff --git a/bin/ud/ud_run_test.py b/bin/ud/ud_run_test.py index 1c529c831..de01cf350 100644 --- a/bin/ud/ud_run_test.py +++ b/bin/ud/ud_run_test.py @@ -109,15 +109,13 @@ def write_conllu(docs, file_): merger = Matcher(docs[0].vocab) merger.add("SUBTOK", None, [{"DEP": "subtok", "op": "+"}]) for i, doc in enumerate(docs): - matches = merger(doc) + matches = [] + if doc.is_parsed: + matches = merger(doc) spans = [doc[start : end + 1] for _, start, end in matches] with doc.retokenize() as retokenizer: for span in spans: retokenizer.merge(span) - # TODO: This shouldn't be necessary? Should be handled in merge - for word in doc: - if word.i == word.head.i: - word.dep_ = "ROOT" file_.write("# newdoc id = {i}\n".format(i=i)) for j, sent in enumerate(doc.sents): file_.write("# sent_id = {i}.{j}\n".format(i=i, j=j)) diff --git a/bin/ud/ud_train.py b/bin/ud/ud_train.py index 8f699db4f..e7fcdf871 100644 --- a/bin/ud/ud_train.py +++ b/bin/ud/ud_train.py @@ -25,7 +25,7 @@ import itertools import random import numpy.random -from . import conll17_ud_eval +import conll17_ud_eval from spacy import lang from spacy.lang import zh @@ -214,7 +214,9 @@ def write_conllu(docs, file_): merger = Matcher(docs[0].vocab) merger.add("SUBTOK", None, [{"DEP": "subtok", "op": "+"}]) for i, doc in enumerate(docs): - matches = merger(doc) + matches = [] + if doc.is_parsed: + matches = merger(doc) spans = [doc[start : end + 1] for _, start, end in matches] with doc.retokenize() as retokenizer: for span in spans: @@ -298,9 +300,9 @@ def get_token_conllu(token, i): return "\n".join(lines) -Token.set_extension("get_conllu_lines", method=get_token_conllu) -Token.set_extension("begins_fused", default=False) -Token.set_extension("inside_fused", default=False) +Token.set_extension("get_conllu_lines", method=get_token_conllu, force=True) +Token.set_extension("begins_fused", default=False, force=True) +Token.set_extension("inside_fused", default=False, force=True) ##################
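As an aside on the `force=True` flag added to the module-level `Token.set_extension` calls in `ud_train.py` above: re-registering an existing extension otherwise raises an error. A standalone sketch (hypothetical, not part of the patches) of that behaviour:

```python
from spacy.tokens import Token

# Registering an extension that already exists raises an error by default,
# which breaks scripts that call set_extension at module level and are
# imported or run more than once in the same process.
Token.set_extension("begins_fused", default=False)
try:
    Token.set_extension("begins_fused", default=False)
except ValueError:
    print("extension already registered")

# With force=True the existing registration is simply overwritten.
Token.set_extension("begins_fused", default=False, force=True)
```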