Mirror of https://github.com/explosion/spaCy.git (synced 2025-01-08 16:26:37 +03:00)

Commit 8c43cfc754: Merge branch 'master' into spacy.io
.flake8 (4 lines changed)
@@ -6,9 +6,5 @@ exclude =
    .env,
    .git,
    __pycache__,
    lemmatizer.py,
    lookup.py,
    _tokenizer_exceptions_list.py,
    spacy/lang/fr/lemmatizer,
    spacy/lang/nb/lemmatizer
    spacy/__init__.py
.github/contributors/mihaigliga21.md (new file, +106, vendored)
@@ -0,0 +1,106 @@
# spaCy contributor agreement

This spaCy Contributor Agreement (**"SCA"**) is based on the
[Oracle Contributor Agreement](http://www.oracle.com/technetwork/oca-405177.pdf).
The SCA applies to any contribution that you make to any product or project
managed by us (the **"project"**), and sets out the intellectual property rights
you grant to us in the contributed materials. The term **"us"** shall mean
[ExplosionAI UG (haftungsbeschränkt)](https://explosion.ai/legal). The term
**"you"** shall mean the person or entity identified below.

If you agree to be bound by these terms, fill in the information requested
below and include the filled-in version with your first pull request, under the
folder [`.github/contributors/`](/.github/contributors/). The name of the file
should be your GitHub username, with the extension `.md`. For example, the user
example_user would create the file `.github/contributors/example_user.md`.

Read this agreement carefully before signing. These terms and conditions
constitute a binding legal agreement.

## Contributor Agreement

1. The term "contribution" or "contributed materials" means any source code,
object code, patch, tool, sample, graphic, specification, manual,
documentation, or any other material posted or submitted by you to the project.

2. With respect to any worldwide copyrights, or copyright applications and
registrations, in your contribution:

    * you hereby assign to us joint ownership, and to the extent that such
    assignment is or becomes invalid, ineffective or unenforceable, you hereby
    grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge,
    royalty-free, unrestricted license to exercise all rights under those
    copyrights. This includes, at our option, the right to sublicense these same
    rights to third parties through multiple levels of sublicensees or other
    licensing arrangements;

    * you agree that each of us can do all things in relation to your
    contribution as if each of us were the sole owners, and if one of us makes
    a derivative work of your contribution, the one who makes the derivative
    work (or has it made) will be the sole owner of that derivative work;

    * you agree that you will not assert any moral rights in your contribution
    against us, our licensees or transferees;

    * you agree that we may register a copyright in your contribution and
    exercise all ownership rights associated with it; and

    * you agree that neither of us has any duty to consult with, obtain the
    consent of, pay or render an accounting to the other for any use or
    distribution of your contribution.

3. With respect to any patents you own, or that you can license without payment
to any third party, you hereby grant to us a perpetual, irrevocable,
non-exclusive, worldwide, no-charge, royalty-free license to:

    * make, have made, use, sell, offer to sell, import, and otherwise transfer
    your contribution in whole or in part, alone or in combination with or
    included in any product, work or materials arising out of the project to
    which your contribution was submitted, and

    * at our option, to sublicense these same rights to third parties through
    multiple levels of sublicensees or other licensing arrangements.

4. Except as set out above, you keep all right, title, and interest in your
contribution. The rights that you grant to us under these terms are effective
on the date you first submitted a contribution to us, even if your submission
took place before the date you sign these terms.

5. You covenant, represent, warrant and agree that:

    * Each contribution that you submit is and shall be an original work of
    authorship and you can legally grant the rights set out in this SCA;

    * to the best of your knowledge, each contribution will not violate any
    third party's copyrights, trademarks, patents, or other intellectual
    property rights; and

    * each contribution shall be in compliance with U.S. export control laws and
    other applicable export and import laws. You agree to notify us if you
    become aware of any circumstance which would make any of the foregoing
    representations inaccurate in any respect. We may publicly disclose your
    participation in the project, including the fact that you have signed the SCA.

6. This SCA is governed by the laws of the State of California and applicable
U.S. Federal law. Any choice of law rules will not apply.

7. Please place an "x" on one of the applicable statements below. Please do NOT
mark both statements:

    * [x] I am signing on behalf of myself as an individual and no other person
    or entity, including my employer, has or will have rights with respect to my
    contributions.

    * [x] I am signing on behalf of my employer or a legal entity and I have the
    actual authority to contractually bind that entity.

## Contributor Details

| Field                          | Entry                |
|------------------------------- | -------------------- |
| Name                           | Mihai Gliga          |
| Company name (if applicable)   |                      |
| Title or role (if applicable)  |                      |
| Date                           | September 9, 2019    |
| GitHub username                | mihaigliga21         |
| Website (optional)             |                      |
.github/contributors/tamuhey.md (new file, +106, vendored)
@@ -0,0 +1,106 @@
# spaCy contributor agreement

This spaCy Contributor Agreement (**"SCA"**) is based on the
[Oracle Contributor Agreement](http://www.oracle.com/technetwork/oca-405177.pdf).
The SCA applies to any contribution that you make to any product or project
managed by us (the **"project"**), and sets out the intellectual property rights
you grant to us in the contributed materials. The term **"us"** shall mean
[ExplosionAI GmbH](https://explosion.ai/legal). The term
**"you"** shall mean the person or entity identified below.

If you agree to be bound by these terms, fill in the information requested
below and include the filled-in version with your first pull request, under the
folder [`.github/contributors/`](/.github/contributors/). The name of the file
should be your GitHub username, with the extension `.md`. For example, the user
example_user would create the file `.github/contributors/example_user.md`.

Read this agreement carefully before signing. These terms and conditions
constitute a binding legal agreement.

## Contributor Agreement

1. The term "contribution" or "contributed materials" means any source code,
object code, patch, tool, sample, graphic, specification, manual,
documentation, or any other material posted or submitted by you to the project.

2. With respect to any worldwide copyrights, or copyright applications and
registrations, in your contribution:

    * you hereby assign to us joint ownership, and to the extent that such
    assignment is or becomes invalid, ineffective or unenforceable, you hereby
    grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge,
    royalty-free, unrestricted license to exercise all rights under those
    copyrights. This includes, at our option, the right to sublicense these same
    rights to third parties through multiple levels of sublicensees or other
    licensing arrangements;

    * you agree that each of us can do all things in relation to your
    contribution as if each of us were the sole owners, and if one of us makes
    a derivative work of your contribution, the one who makes the derivative
    work (or has it made) will be the sole owner of that derivative work;

    * you agree that you will not assert any moral rights in your contribution
    against us, our licensees or transferees;

    * you agree that we may register a copyright in your contribution and
    exercise all ownership rights associated with it; and

    * you agree that neither of us has any duty to consult with, obtain the
    consent of, pay or render an accounting to the other for any use or
    distribution of your contribution.

3. With respect to any patents you own, or that you can license without payment
to any third party, you hereby grant to us a perpetual, irrevocable,
non-exclusive, worldwide, no-charge, royalty-free license to:

    * make, have made, use, sell, offer to sell, import, and otherwise transfer
    your contribution in whole or in part, alone or in combination with or
    included in any product, work or materials arising out of the project to
    which your contribution was submitted, and

    * at our option, to sublicense these same rights to third parties through
    multiple levels of sublicensees or other licensing arrangements.

4. Except as set out above, you keep all right, title, and interest in your
contribution. The rights that you grant to us under these terms are effective
on the date you first submitted a contribution to us, even if your submission
took place before the date you sign these terms.

5. You covenant, represent, warrant and agree that:

    * Each contribution that you submit is and shall be an original work of
    authorship and you can legally grant the rights set out in this SCA;

    * to the best of your knowledge, each contribution will not violate any
    third party's copyrights, trademarks, patents, or other intellectual
    property rights; and

    * each contribution shall be in compliance with U.S. export control laws and
    other applicable export and import laws. You agree to notify us if you
    become aware of any circumstance which would make any of the foregoing
    representations inaccurate in any respect. We may publicly disclose your
    participation in the project, including the fact that you have signed the SCA.

6. This SCA is governed by the laws of the State of California and applicable
U.S. Federal law. Any choice of law rules will not apply.

7. Please place an "x" on one of the applicable statements below. Please do NOT
mark both statements:

    * [x] I am signing on behalf of myself as an individual and no other person
    or entity, including my employer, has or will have rights with respect to my
    contributions.

    * [ ] I am signing on behalf of my employer or a legal entity and I have the
    actual authority to contractually bind that entity.

## Contributor Details

| Field                          | Entry                |
|------------------------------- | -------------------- |
| Name                           | Yohei Tamura         |
| Company name (if applicable)   | PKSHA                |
| Title or role (if applicable)  |                      |
| Date                           | 2019/9/12            |
| GitHub username                | tamuhey              |
| Website (optional)             |                      |
@@ -5,7 +5,6 @@
from __future__ import unicode_literals

import plac
import tqdm
from pathlib import Path
import re
import sys
@@ -5,7 +5,6 @@
from __future__ import unicode_literals

import plac
import tqdm
from pathlib import Path
import re
import sys

@@ -462,6 +461,9 @@ def main(
    vectors_dir=None,
    use_oracle_segments=False,
):
    # temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
    import tqdm

    spacy.util.fix_random_seed()
    lang.zh.Chinese.Defaults.use_jieba = False
    lang.ja.Japanese.Defaults.use_janome = False
@@ -0,0 +1,11 @@
TRAINING_DATA_FILE = "gold_entities.jsonl"
KB_FILE = "kb"
KB_MODEL_DIR = "nlp_kb"
OUTPUT_MODEL_DIR = "nlp"

PRIOR_PROB_PATH = "prior_prob.csv"
ENTITY_DEFS_PATH = "entity_defs.csv"
ENTITY_FREQ_PATH = "entity_freq.csv"
ENTITY_DESCR_PATH = "entity_descriptions.csv"

LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
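These module-level constants are shared by the KB-creation and training scripts further down in this diff. A minimal sketch of how they are typically resolved into concrete paths and wired into logging (the `output` directory name is a hypothetical example; the import path follows the scripts below):

```python
import logging
from pathlib import Path

from bin.wiki_entity_linking import KB_FILE, TRAINING_DATA_FILE, LOG_FORMAT

logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)

output_dir = Path("output")                      # hypothetical output directory
kb_path = output_dir / KB_FILE                   # -> output/kb
training_path = output_dir / TRAINING_DATA_FILE  # -> output/gold_entities.jsonl
```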
bin/wiki_entity_linking/entity_linker_evaluation.py (new file, +200)
@@ -0,0 +1,200 @@
import logging
import random

from collections import defaultdict

logger = logging.getLogger(__name__)


class Metrics(object):
    true_pos = 0
    false_pos = 0
    false_neg = 0

    def update_results(self, true_entity, candidate):
        candidate_is_correct = true_entity == candidate

        # Assume that we have no labeled negatives in the data (i.e. cases where true_entity is "NIL")
        # Therefore, if candidate_is_correct then we have a true positive and never a true negative
        self.true_pos += candidate_is_correct
        self.false_neg += not candidate_is_correct
        if candidate not in {"", "NIL"}:
            self.false_pos += not candidate_is_correct

    def calculate_precision(self):
        if self.true_pos == 0:
            return 0.0
        else:
            return self.true_pos / (self.true_pos + self.false_pos)

    def calculate_recall(self):
        if self.true_pos == 0:
            return 0.0
        else:
            return self.true_pos / (self.true_pos + self.false_neg)


class EvaluationResults(object):
    def __init__(self):
        self.metrics = Metrics()
        self.metrics_by_label = defaultdict(Metrics)

    def update_metrics(self, ent_label, true_entity, candidate):
        self.metrics.update_results(true_entity, candidate)
        self.metrics_by_label[ent_label].update_results(true_entity, candidate)

    def increment_false_negatives(self):
        self.metrics.false_neg += 1

    def report_metrics(self, model_name):
        model_str = model_name.title()
        recall = self.metrics.calculate_recall()
        precision = self.metrics.calculate_precision()
        return ("{}: ".format(model_str) +
                "Recall = {} | ".format(round(recall, 3)) +
                "Precision = {} | ".format(round(precision, 3)) +
                "Precision by label = {}".format({k: v.calculate_precision()
                                                  for k, v in self.metrics_by_label.items()}))


class BaselineResults(object):
    def __init__(self):
        self.random = EvaluationResults()
        self.prior = EvaluationResults()
        self.oracle = EvaluationResults()

    def report_accuracy(self, model):
        results = getattr(self, model)
        return results.report_metrics(model)

    def update_baselines(self, true_entity, ent_label, random_candidate, prior_candidate, oracle_candidate):
        self.oracle.update_metrics(ent_label, true_entity, oracle_candidate)
        self.prior.update_metrics(ent_label, true_entity, prior_candidate)
        self.random.update_metrics(ent_label, true_entity, random_candidate)


def measure_performance(dev_data, kb, el_pipe):
    baseline_accuracies = measure_baselines(
        dev_data, kb
    )

    logger.info(baseline_accuracies.report_accuracy("random"))
    logger.info(baseline_accuracies.report_accuracy("prior"))
    logger.info(baseline_accuracies.report_accuracy("oracle"))

    # using only context
    el_pipe.cfg["incl_context"] = True
    el_pipe.cfg["incl_prior"] = False
    results = get_eval_results(dev_data, el_pipe)
    logger.info(results.report_metrics("context only"))

    # measuring combined accuracy (prior + context)
    el_pipe.cfg["incl_context"] = True
    el_pipe.cfg["incl_prior"] = True
    results = get_eval_results(dev_data, el_pipe)
    logger.info(results.report_metrics("context and prior"))


def get_eval_results(data, el_pipe=None):
    # If the docs in the data require further processing with an entity linker, set el_pipe
    from tqdm import tqdm

    docs = []
    golds = []
    for d, g in tqdm(data, leave=False):
        if len(d) > 0:
            golds.append(g)
            if el_pipe is not None:
                docs.append(el_pipe(d))
            else:
                docs.append(d)

    results = EvaluationResults()
    for doc, gold in zip(docs, golds):
        tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
        try:
            correct_entries_per_article = dict()
            for entity, kb_dict in gold.links.items():
                start, end = entity
                # only evaluating on positive examples
                for gold_kb, value in kb_dict.items():
                    if value:
                        offset = _offset(start, end)
                        correct_entries_per_article[offset] = gold_kb
                        if offset not in tagged_entries_per_article:
                            results.increment_false_negatives()

            for ent in doc.ents:
                ent_label = ent.label_
                pred_entity = ent.kb_id_
                start = ent.start_char
                end = ent.end_char
                offset = _offset(start, end)
                gold_entity = correct_entries_per_article.get(offset, None)
                # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
                if gold_entity is not None:
                    results.update_metrics(ent_label, gold_entity, pred_entity)

        except Exception as e:
            logging.error("Error assessing accuracy " + str(e))

    return results


def measure_baselines(data, kb):
    # Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
    counts_d = dict()

    baseline_results = BaselineResults()

    docs = [d for d, g in data if len(d) > 0]
    golds = [g for d, g in data if len(d) > 0]

    for doc, gold in zip(docs, golds):
        correct_entries_per_article = dict()
        tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
        for entity, kb_dict in gold.links.items():
            start, end = entity
            for gold_kb, value in kb_dict.items():
                # only evaluating on positive examples
                if value:
                    offset = _offset(start, end)
                    correct_entries_per_article[offset] = gold_kb
                    if offset not in tagged_entries_per_article:
                        baseline_results.random.increment_false_negatives()
                        baseline_results.oracle.increment_false_negatives()
                        baseline_results.prior.increment_false_negatives()

        for ent in doc.ents:
            ent_label = ent.label_
            start = ent.start_char
            end = ent.end_char
            offset = _offset(start, end)
            gold_entity = correct_entries_per_article.get(offset, None)

            # the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
            if gold_entity is not None:
                candidates = kb.get_candidates(ent.text)
                oracle_candidate = ""
                best_candidate = ""
                random_candidate = ""
                if candidates:
                    scores = []

                    for c in candidates:
                        scores.append(c.prior_prob)
                        if c.entity_ == gold_entity:
                            oracle_candidate = c.entity_

                    best_index = scores.index(max(scores))
                    best_candidate = candidates[best_index].entity_
                    random_candidate = random.choice(candidates).entity_

                baseline_results.update_baselines(gold_entity, ent_label,
                                                  random_candidate, best_candidate, oracle_candidate)

    return baseline_results


def _offset(start, end):
    return "{}_{}".format(start, end)
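As a quick sanity check on the scoring logic above, here is a small worked example (the entity IDs are made up) that exercises the `Metrics` class from this new file:

```python
m = Metrics()
m.update_results(true_entity="Q1", candidate="Q1")    # correct link: true positive
m.update_results(true_entity="Q1", candidate="Q2")    # wrong link: false positive and false negative
m.update_results(true_entity="Q1", candidate="NIL")   # no prediction: false negative only

print(m.calculate_precision())  # 1 / (1 + 1) = 0.5
print(m.calculate_recall())     # 1 / (1 + 2) = 0.333...
```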
@@ -1,12 +1,20 @@
# coding: utf-8
from __future__ import unicode_literals

from bin.wiki_entity_linking.train_descriptions import EntityEncoder
from bin.wiki_entity_linking import wikidata_processor as wd, wikipedia_processor as wp
import csv
import logging
import spacy
import sys

from spacy.kb import KnowledgeBase

import csv
import datetime
from bin.wiki_entity_linking import wikipedia_processor as wp
from bin.wiki_entity_linking.train_descriptions import EntityEncoder

csv.field_size_limit(sys.maxsize)


logger = logging.getLogger(__name__)


def create_kb(

@@ -14,52 +22,73 @@ def create_kb(
    max_entities_per_alias,
    min_entity_freq,
    min_occ,
    entity_def_output,
    entity_descr_output,
    entity_def_input,
    entity_descr_path,
    count_input,
    prior_prob_input,
    wikidata_input,
    entity_vector_length,
    limit=None,
    read_raw_data=True,
):
    # Create the knowledge base from Wikidata entries
    kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=entity_vector_length)

    # read the mappings from file
    title_to_id = get_entity_to_id(entity_def_input)
    id_to_descr = get_id_to_description(entity_descr_path)

    # check the length of the nlp vectors
    if "vectors" in nlp.meta and nlp.vocab.vectors.size:
        input_dim = nlp.vocab.vectors_length
        print("Loaded pre-trained vectors of size %s" % input_dim)
        logger.info("Loaded pre-trained vectors of size %s" % input_dim)
    else:
        raise ValueError(
            "The `nlp` object should have access to pre-trained word vectors, "
            " cf. https://spacy.io/usage/models#languages."
        )

    # disable this part of the pipeline when rerunning the KB generation from preprocessed files
    if read_raw_data:
        print()
        print(now(), " * read wikidata entities:")
        title_to_id, id_to_descr = wd.read_wikidata_entities_json(
            wikidata_input, limit=limit
        )

        # write the title-ID and ID-description mappings to file
        _write_entity_files(
            entity_def_output, entity_descr_output, title_to_id, id_to_descr
        )

    else:
        # read the mappings from file
        title_to_id = get_entity_to_id(entity_def_output)
        id_to_descr = get_id_to_description(entity_descr_output)

    print()
    print(now(), " * get entity frequencies:")
    print()
    logger.info("Get entity frequencies")
    entity_frequencies = wp.get_all_frequencies(count_input=count_input)

    logger.info("Filtering entities with fewer than {} mentions".format(min_entity_freq))
    # filter the entities in the KB by frequency, because there's just too much data (8M entities) otherwise
    filtered_title_to_id, entity_list, description_list, frequency_list = get_filtered_entities(
        title_to_id,
        id_to_descr,
        entity_frequencies,
        min_entity_freq
    )
    logger.info("Left with {} entities".format(len(description_list)))

    logger.info("Train entity encoder")
    encoder = EntityEncoder(nlp, input_dim, entity_vector_length)
    encoder.train(description_list=description_list, to_print=True)

    logger.info("Get entity embeddings:")
    embeddings = encoder.apply_encoder(description_list)

    logger.info("Adding {} entities".format(len(entity_list)))
    kb.set_entities(
        entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
    )

    logger.info("Adding aliases")
    _add_aliases(
        kb,
        title_to_id=filtered_title_to_id,
        max_entities_per_alias=max_entities_per_alias,
        min_occ=min_occ,
        prior_prob_input=prior_prob_input,
    )

    logger.info("KB size: {} entities, {} aliases".format(
        kb.get_size_entities(),
        kb.get_size_aliases()))

    logger.info("Done with kb")
    return kb


def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies,
                          min_entity_freq: int = 10):
    filtered_title_to_id = dict()
    entity_list = []
    description_list = []

@@ -72,58 +101,7 @@ def create_kb(
            description_list.append(desc)
            frequency_list.append(freq)
            filtered_title_to_id[title] = entity

    print(len(title_to_id.keys()), "original titles")
    kept_nr = len(filtered_title_to_id.keys())
    print("kept", kept_nr, "entities with min. frequency", min_entity_freq)

    print()
    print(now(), " * train entity encoder:")
    print()
    encoder = EntityEncoder(nlp, input_dim, entity_vector_length)
    encoder.train(description_list=description_list, to_print=True)

    print()
    print(now(), " * get entity embeddings:")
    print()
    embeddings = encoder.apply_encoder(description_list)

    print(now(), " * adding", len(entity_list), "entities")
    kb.set_entities(
        entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
    )

    alias_cnt = _add_aliases(
        kb,
        title_to_id=filtered_title_to_id,
        max_entities_per_alias=max_entities_per_alias,
        min_occ=min_occ,
        prior_prob_input=prior_prob_input,
    )
    print()
    print(now(), " * adding", alias_cnt, "aliases")
    print()

    print()
    print("# of entities in kb:", kb.get_size_entities())
    print("# of aliases in kb:", kb.get_size_aliases())

    print(now(), "Done with kb")
    return kb


def _write_entity_files(
    entity_def_output, entity_descr_output, title_to_id, id_to_descr
):
    with entity_def_output.open("w", encoding="utf8") as id_file:
        id_file.write("WP_title" + "|" + "WD_id" + "\n")
        for title, qid in title_to_id.items():
            id_file.write(title + "|" + str(qid) + "\n")

    with entity_descr_output.open("w", encoding="utf8") as descr_file:
        descr_file.write("WD_id" + "|" + "description" + "\n")
        for qid, descr in id_to_descr.items():
            descr_file.write(str(qid) + "|" + descr + "\n")
    return filtered_title_to_id, entity_list, description_list, frequency_list


def get_entity_to_id(entity_def_output):

@@ -137,9 +115,9 @@ def get_entity_to_id(entity_def_output):
    return entity_to_id


def get_id_to_description(entity_descr_output):
def get_id_to_description(entity_descr_path):
    id_to_desc = dict()
    with entity_descr_output.open("r", encoding="utf8") as csvfile:
    with entity_descr_path.open("r", encoding="utf8") as csvfile:
        csvreader = csv.reader(csvfile, delimiter="|")
        # skip header
        next(csvreader)

@@ -150,7 +128,6 @@ def get_id_to_description(entity_descr_output):

def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input):
    wp_titles = title_to_id.keys()
    cnt = 0

    # adding aliases with prior probabilities
    # we can read this file sequentially, it's sorted by alias, and then by count

@@ -187,9 +164,8 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input):
                        entities=selected_entities,
                        probabilities=prior_probs,
                    )
                    cnt += 1
                except ValueError as e:
                    print(e)
                    logger.error(e)
                total_count = 0
                counts = []
                entities = []

@@ -202,8 +178,12 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input):
                previous_alias = new_alias

            line = prior_file.readline()
    return cnt


def now():
    return datetime.datetime.now()
def read_nlp_kb(model_dir, kb_file):
    nlp = spacy.load(model_dir)
    kb = KnowledgeBase(vocab=nlp.vocab)
    kb.load_bulk(kb_file)
    logger.info("kb entities: {}".format(kb.get_size_entities()))
    logger.info("kb aliases: {}".format(kb.get_size_aliases()))
    return nlp, kb
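For orientation, the entity definition and description files written by `_write_entity_files()` and read back by `get_entity_to_id()` / `get_id_to_description()` are plain pipe-delimited CSVs. A hypothetical two-row excerpt (titles, QIDs and the `output/` path are only for illustration):

```python
# entity_defs.csv                  entity_descriptions.csv
# WP_title|WD_id                   WD_id|description
# Douglas Adams|Q42                Q42|English writer and humorist
# London|Q84                       Q84|capital of the United Kingdom

from pathlib import Path

title_to_id = get_entity_to_id(Path("output/entity_defs.csv"))
# e.g. {"Douglas Adams": "Q42", "London": "Q84"}
```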
@@ -1,6 +1,7 @@
# coding: utf-8
from random import shuffle

import logging
import numpy as np

from spacy._ml import zero_init, create_default_optimizer

@@ -10,6 +11,8 @@ from thinc.v2v import Model
from thinc.api import chain
from thinc.neural._classes.affine import Affine

logger = logging.getLogger(__name__)


class EntityEncoder:
    """

@@ -50,21 +53,19 @@ class EntityEncoder:

            start = start + batch_size
            stop = min(stop + batch_size, len(description_list))
            print("encoded:", stop, "entities")
            logger.info("encoded: {} entities".format(stop))

        return encodings

    def train(self, description_list, to_print=False):
        processed, loss = self._train_model(description_list)
        if to_print:
            print(
                "Trained entity descriptions on",
                processed,
                "(non-unique) entities across",
                self.epochs,
                "epochs",
            logger.info(
                "Trained entity descriptions on {} ".format(processed) +
                "(non-unique) entities across {} ".format(self.epochs) +
                "epochs"
            )
            print("Final loss:", loss)
            logger.info("Final loss: {}".format(loss))

    def _train_model(self, description_list):
        best_loss = 1.0

@@ -93,7 +94,7 @@ class EntityEncoder:

            loss = self._update(batch)
            if batch_nr % 25 == 0:
                print("loss:", loss)
                logger.info("loss: {} ".format(loss))
            processed += len(batch)

        # in general, continue training if we haven't reached our ideal min yet
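The way `create_kb()` above drives this encoder shows the intended usage; a condensed sketch (the variables `nlp`, `input_dim`, `entity_vector_length` and `description_list` are placeholders taken from that function):

```python
# Mirrors the calls made in create_kb(): input_dim comes from the loaded word
# vectors and entity_vector_length is the desired entity embedding size.
encoder = EntityEncoder(nlp, input_dim, entity_vector_length)
encoder.train(description_list=description_list, to_print=True)
embeddings = encoder.apply_encoder(description_list)  # one vector per description, fed to kb.set_entities()
```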
@@ -1,10 +1,13 @@
# coding: utf-8
from __future__ import unicode_literals

import logging
import random
import re
import bz2
import datetime
import json

from functools import partial

from spacy.gold import GoldParse
from bin.wiki_entity_linking import kb_creator

@@ -15,18 +18,30 @@ Gold-standard entities are stored in one file in standoff format (by character offset)
"""

ENTITY_FILE = "gold_entities.csv"
logger = logging.getLogger(__name__)


def now():
    return datetime.datetime.now()


def create_training(wikipedia_input, entity_def_input, training_output, limit=None):
def create_training_examples_and_descriptions(wikipedia_input,
                                              entity_def_input,
                                              description_output,
                                              training_output,
                                              parse_descriptions,
                                              limit=None):
    wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
    _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=limit)
    _process_wikipedia_texts(wikipedia_input,
                             wp_to_id,
                             description_output,
                             training_output,
                             parse_descriptions,
                             limit)


def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None):
def _process_wikipedia_texts(wikipedia_input,
                             wp_to_id,
                             output,
                             training_output,
                             parse_descriptions,
                             limit=None):
    """
    Read the XML wikipedia data to parse out training data:
    raw text data + positive instances
@@ -35,29 +50,21 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None):
    id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")

    read_ids = set()
    entityfile_loc = training_output / ENTITY_FILE
    with entityfile_loc.open("w", encoding="utf8") as entityfile:
        # write entity training header file
        _write_training_entity(
            outputfile=entityfile,
            article_id="article_id",
            alias="alias",
            entity="WD_id",
            start="start",
            end="end",
        )

    with output.open("a", encoding="utf8") as descr_file, training_output.open("w", encoding="utf8") as entity_file:
        if parse_descriptions:
            _write_training_description(descr_file, "WD_id", "description")
        with bz2.open(wikipedia_input, mode="rb") as file:
            line = file.readline()
            cnt = 0
            article_count = 0
            article_text = ""
            article_title = None
            article_id = None
            reading_text = False
            reading_revision = False
            while line and (not limit or cnt < limit):
                if cnt % 1000000 == 0:
                    print(now(), "processed", cnt, "lines of Wikipedia dump")

            logger.info("Processed {} articles".format(article_count))

            for line in file:
                clean_line = line.strip().decode("utf-8")

                if clean_line == "<revision>":
@@ -70,28 +77,32 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None):
                    article_text = ""
                    article_title = None
                    article_id = None

                # finished reading this page
                elif clean_line == "</page>":
                    if article_id:
                        try:
                            _process_wp_text(
                                wp_to_id,
                                entityfile,
                                article_id,
                                article_title,
                                article_text.strip(),
                                training_output,
                            )
                        except Exception as e:
                            print(
                                "Error processing article", article_id, article_title, e
                            )
                    else:
                        print(
                            "Done processing a page, but couldn't find an article_id ?",
                        clean_text, entities = _process_wp_text(
                            article_title,
                            article_text,
                            wp_to_id
                        )
                        if clean_text is not None and entities is not None:
                            _write_training_entities(entity_file,
                                                     article_id,
                                                     clean_text,
                                                     entities)

                            if article_title in wp_to_id and parse_descriptions:
                                description = " ".join(clean_text[:1000].split(" ")[:-1])
                                _write_training_description(
                                    descr_file,
                                    wp_to_id[article_title],
                                    description
                                )
                        article_count += 1
                        if article_count % 10000 == 0:
                            logger.info("Processed {} articles".format(article_count))
                        if limit and article_count >= limit:
                            break
                    article_text = ""
                    article_title = None
                    article_id = None
@@ -115,7 +126,7 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None):
                    if ids:
                        article_id = ids[0]
                        if article_id in read_ids:
                            print(
                            logger.info(
                                "Found duplicate article ID", article_id, clean_line
                            )  # This should never happen ...
                        read_ids.add(article_id)
@@ -125,115 +136,10 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None):
                    titles = title_regex.search(clean_line)
                    if titles:
                        article_title = titles[0].strip()

            line = file.readline()
            cnt += 1
        print(now(), "processed", cnt, "lines of Wikipedia dump")
    logger.info("Finished. Processed {} articles".format(article_count))


text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")


def _process_wp_text(
    wp_to_id, entityfile, article_id, article_title, article_text, training_output
):
    found_entities = False

    # ignore meta Wikipedia pages
    if article_title.startswith("Wikipedia:"):
        return

    # remove the text tags
    text = text_regex.search(article_text).group(0)

    # stop processing if this is a redirect page
    if text.startswith("#REDIRECT"):
        return

    # get the raw text without markup etc, keeping only interwiki links
    clean_text = _get_clean_wp_text(text)

    # read the text char by char to get the right offsets for the interwiki links
    final_text = ""
    open_read = 0
    reading_text = True
    reading_entity = False
    reading_mention = False
    reading_special_case = False
    entity_buffer = ""
    mention_buffer = ""
    for index, letter in enumerate(clean_text):
        if letter == "[":
            open_read += 1
        elif letter == "]":
            open_read -= 1
        elif letter == "|":
            if reading_text:
                final_text += letter
            # switch from reading entity to mention in the [[entity|mention]] pattern
            elif reading_entity:
                reading_text = False
                reading_entity = False
                reading_mention = True
            else:
                reading_special_case = True
        else:
            if reading_entity:
                entity_buffer += letter
            elif reading_mention:
                mention_buffer += letter
            elif reading_text:
                final_text += letter
            else:
                raise ValueError("Not sure at point", clean_text[index - 2 : index + 2])

        if open_read > 2:
            reading_special_case = True

        if open_read == 2 and reading_text:
            reading_text = False
            reading_entity = True
            reading_mention = False

        # we just finished reading an entity
        if open_read == 0 and not reading_text:
            if "#" in entity_buffer or entity_buffer.startswith(":"):
                reading_special_case = True
            # Ignore cases with nested structures like File: handles etc
            if not reading_special_case:
                if not mention_buffer:
                    mention_buffer = entity_buffer
                start = len(final_text)
                end = start + len(mention_buffer)
                qid = wp_to_id.get(entity_buffer, None)
                if qid:
                    _write_training_entity(
                        outputfile=entityfile,
                        article_id=article_id,
                        alias=mention_buffer,
                        entity=qid,
                        start=start,
                        end=end,
                    )
                    found_entities = True
                final_text += mention_buffer

            entity_buffer = ""
            mention_buffer = ""

            reading_text = True
            reading_entity = False
            reading_mention = False
            reading_special_case = False

    if found_entities:
        _write_training_article(
            article_id=article_id,
            clean_text=final_text,
            training_output=training_output,
        )


info_regex = re.compile(r"{[^{]*?}")
htlm_regex = re.compile(r"<!--[^-]*-->")
category_regex = re.compile(r"\[\[Category:[^\[]*]]")
@@ -242,6 +148,29 @@ ref_regex = re.compile(r"<ref.*?>")  # non-greedy
ref_2_regex = re.compile(r"</ref.*?>")  # non-greedy


def _process_wp_text(article_title, article_text, wp_to_id):
    # ignore meta Wikipedia pages
    if (
        article_title.startswith("Wikipedia:") or
        article_title.startswith("Kategori:")
    ):
        return None, None

    # remove the text tags
    text_search = text_regex.search(article_text)
    if text_search is None:
        return None, None
    text = text_search.group(0)

    # stop processing if this is a redirect page
    if text.startswith("#REDIRECT"):
        return None, None

    # get the raw text without markup etc, keeping only interwiki links
    clean_text, entities = _remove_links(_get_clean_wp_text(text), wp_to_id)
    return clean_text, entities


def _get_clean_wp_text(article_text):
    clean_text = article_text.strip()
@@ -300,130 +229,167 @@ def _get_clean_wp_text(article_text):
    return clean_text.strip()


def _write_training_article(article_id, clean_text, training_output):
    file_loc = training_output / "{}.txt".format(article_id)
    with file_loc.open("w", encoding="utf8") as outputfile:
        outputfile.write(clean_text)
def _remove_links(clean_text, wp_to_id):
    # read the text char by char to get the right offsets for the interwiki links
    entities = []
    final_text = ""
    open_read = 0
    reading_text = True
    reading_entity = False
    reading_mention = False
    reading_special_case = False
    entity_buffer = ""
    mention_buffer = ""
    for index, letter in enumerate(clean_text):
        if letter == "[":
            open_read += 1
        elif letter == "]":
            open_read -= 1
        elif letter == "|":
            if reading_text:
                final_text += letter
            # switch from reading entity to mention in the [[entity|mention]] pattern
            elif reading_entity:
                reading_text = False
                reading_entity = False
                reading_mention = True
            else:
                reading_special_case = True
        else:
            if reading_entity:
                entity_buffer += letter
            elif reading_mention:
                mention_buffer += letter
            elif reading_text:
                final_text += letter
            else:
                raise ValueError("Not sure at point", clean_text[index - 2: index + 2])

        if open_read > 2:
            reading_special_case = True

        if open_read == 2 and reading_text:
            reading_text = False
            reading_entity = True
            reading_mention = False

        # we just finished reading an entity
        if open_read == 0 and not reading_text:
            if "#" in entity_buffer or entity_buffer.startswith(":"):
                reading_special_case = True
            # Ignore cases with nested structures like File: handles etc
            if not reading_special_case:
                if not mention_buffer:
                    mention_buffer = entity_buffer
                start = len(final_text)
                end = start + len(mention_buffer)
                qid = wp_to_id.get(entity_buffer, None)
                if qid:
                    entities.append((mention_buffer, qid, start, end))
                final_text += mention_buffer

            entity_buffer = ""
            mention_buffer = ""

            reading_text = True
            reading_entity = False
            reading_mention = False
            reading_special_case = False
    return final_text, entities

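To make the bracket-parsing state machine above concrete, here is roughly what `_remove_links()` produces for a single interwiki link (the one-entry QID mapping is made up; in the pipeline it comes from `kb_creator.get_entity_to_id()`):

```python
wp_to_id = {"United Kingdom": "Q145"}   # hypothetical title-to-QID mapping
text, entities = _remove_links("In [[United Kingdom|the UK]] it rained.", wp_to_id)

print(text)      # "In the UK it rained."
print(entities)  # [("the UK", "Q145", 3, 9)] -> (alias, QID, start offset, end offset)
```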
def _write_training_entity(outputfile, article_id, alias, entity, start, end):
    line = "{}|{}|{}|{}|{}\n".format(article_id, alias, entity, start, end)
def _write_training_description(outputfile, qid, description):
    if description is not None:
        line = str(qid) + "|" + description + "\n"
        outputfile.write(line)


def _write_training_entities(outputfile, article_id, clean_text, entities):
    entities_data = [{"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]} for ent in entities]
    line = json.dumps(
        {
            "article_id": article_id,
            "clean_text": clean_text,
            "entities": entities_data
        },
        ensure_ascii=False) + "\n"
    outputfile.write(line)

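Each call to `_write_training_entities()` therefore appends one JSON record per article to `gold_entities.jsonl`; a hypothetical line (continuing the example above) looks like this:

```python
# {"article_id": "12", "clean_text": "In the UK it rained.",
#  "entities": [{"alias": "the UK", "entity": "Q145", "start": 3, "end": 9}]}
```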
def read_training(nlp, entity_file_path, dev, limit, kb):
    """ This method provides training examples that correspond to the entity annotations found by the nlp object.
    For training, it will include negative training examples by using the candidate generator,
    and it will only keep positive training examples that can be found by using the candidate generator.
    For testing, it will include all positive examples only."""

    from tqdm import tqdm
    data = []
    num_entities = 0
    get_gold_parse = partial(_get_gold_parse, dev=dev, kb=kb)

    logger.info("Reading {} data with limit {}".format('dev' if dev else 'train', limit))
    with entity_file_path.open("r", encoding="utf8") as file:
        with tqdm(total=limit, leave=False) as pbar:
            for i, line in enumerate(file):
                example = json.loads(line)
                article_id = example["article_id"]
                clean_text = example["clean_text"]
                entities = example["entities"]

                if dev != is_dev(article_id) or len(clean_text) >= 30000:
                    continue

                doc = nlp(clean_text)
                gold = get_gold_parse(doc, entities)
                if gold and len(gold.links) > 0:
                    data.append((doc, gold))
                    num_entities += len(gold.links)
                    pbar.update(len(gold.links))
                if limit and num_entities >= limit:
                    break
    logger.info("Read {} entities in {} articles".format(num_entities, len(data)))
    return data


def _get_gold_parse(doc, entities, dev, kb):
    gold_entities = {}
    tagged_ent_positions = set(
        [(ent.start_char, ent.end_char) for ent in doc.ents]
    )

    for entity in entities:
        entity_id = entity["entity"]
        alias = entity["alias"]
        start = entity["start"]
        end = entity["end"]

        candidates = kb.get_candidates(alias)
        candidate_ids = [
            c.entity_ for c in candidates
        ]

        should_add_ent = (
            dev or
            (
                (start, end) in tagged_ent_positions and
                entity_id in candidate_ids and
                len(candidates) > 1
            )
        )

        if should_add_ent:
            value_by_id = {entity_id: 1.0}
            if not dev:
                random.shuffle(candidate_ids)
                value_by_id.update({
                    kb_id: 0.0
                    for kb_id in candidate_ids
                    if kb_id != entity_id
                })
            gold_entities[(start, end)] = value_by_id

    return GoldParse(doc, links=gold_entities)


def is_dev(article_id):
    return article_id.endswith("3")

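The `links` dictionary handed to `GoldParse` above maps a character span to candidate QIDs labelled 1.0 (gold) or 0.0 (negative candidate); schematically, with invented IDs:

```python
# For a mention at characters 3-9 whose gold entity is Q145 and whose other KB
# candidates are Q100 and Q200, _get_gold_parse() builds:
links = {(3, 9): {"Q145": 1.0, "Q100": 0.0, "Q200": 0.0}}
```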
def read_training(nlp, training_dir, dev, limit, kb=None):
    """ This method provides training examples that correspond to the entity annotations found by the nlp object.
    When kb is provided (for training), it will include negative training examples by using the candidate generator,
    and it will only keep positive training examples that can be found in the KB.
    When kb=None (for testing), it will include all positive examples only."""
    entityfile_loc = training_dir / ENTITY_FILE
    data = []

    # assume the data is written sequentially, so we can reuse the article docs
    current_article_id = None
    current_doc = None
    ents_by_offset = dict()
    skip_articles = set()
    total_entities = 0

    with entityfile_loc.open("r", encoding="utf8") as file:
        for line in file:
            if not limit or len(data) < limit:
                fields = line.replace("\n", "").split(sep="|")
                article_id = fields[0]
                alias = fields[1]
                wd_id = fields[2]
                start = fields[3]
                end = fields[4]

                if (
                    dev == is_dev(article_id)
                    and article_id != "article_id"
                    and article_id not in skip_articles
                ):
                    if not current_doc or (current_article_id != article_id):
                        # parse the new article text
                        file_name = article_id + ".txt"
                        try:
                            training_file = training_dir / file_name
                            with training_file.open("r", encoding="utf8") as f:
                                text = f.read()
                                # threshold for convenience / speed of processing
                                if len(text) < 30000:
                                    current_doc = nlp(text)
                                    current_article_id = article_id
                                    ents_by_offset = dict()
                                    for ent in current_doc.ents:
                                        sent_length = len(ent.sent)
                                        # custom filtering to avoid too long or too short sentences
                                        if 5 < sent_length < 100:
                                            offset = "{}_{}".format(
                                                ent.start_char, ent.end_char
                                            )
                                            ents_by_offset[offset] = ent
                                else:
                                    skip_articles.add(article_id)
                                    current_doc = None
                        except Exception as e:
                            print("Problem parsing article", article_id, e)
                            skip_articles.add(article_id)

                    # repeat checking this condition in case an exception was thrown
                    if current_doc and (current_article_id == article_id):
                        offset = "{}_{}".format(start, end)
                        found_ent = ents_by_offset.get(offset, None)
                        if found_ent:
                            if found_ent.text != alias:
                                skip_articles.add(article_id)
                                current_doc = None
                            else:
                                sent = found_ent.sent.as_doc()

                                gold_start = int(start) - found_ent.sent.start_char
                                gold_end = int(end) - found_ent.sent.start_char

                                gold_entities = {}
                                found_useful = False
                                for ent in sent.ents:
                                    entry = (ent.start_char, ent.end_char)
                                    gold_entry = (gold_start, gold_end)
                                    if entry == gold_entry:
                                        # add both pos and neg examples (in random order)
                                        # this will exclude examples not in the KB
                                        if kb:
                                            value_by_id = {}
                                            candidates = kb.get_candidates(alias)
                                            candidate_ids = [
                                                c.entity_ for c in candidates
                                            ]
                                            random.shuffle(candidate_ids)
                                            for kb_id in candidate_ids:
                                                found_useful = True
                                                if kb_id != wd_id:
                                                    value_by_id[kb_id] = 0.0
                                                else:
                                                    value_by_id[kb_id] = 1.0
                                            gold_entities[entry] = value_by_id
                                        # if no KB, keep all positive examples
                                        else:
                                            found_useful = True
                                            value_by_id = {wd_id: 1.0}

                                            gold_entities[entry] = value_by_id
                                    # currently feeding the gold data one entity per sentence at a time
                                    # setting all other entities to empty gold dictionary
                                    else:
                                        gold_entities[entry] = {}
                                if found_useful:
                                    gold = GoldParse(doc=sent, links=gold_entities)
                                    data.append((sent, gold))
                                    total_entities += 1
                                    if len(data) % 2500 == 0:
                                        print(" -read", total_entities, "entities")

    print(" -read", total_entities, "entities")
    return data
@ -13,27 +13,25 @@ from https://dumps.wikimedia.org/enwiki/latest/
|
|||
"""
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import datetime
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import plac
|
||||
|
||||
from bin.wiki_entity_linking import wikipedia_processor as wp
|
||||
from bin.wiki_entity_linking import wikipedia_processor as wp, wikidata_processor as wd
|
||||
from bin.wiki_entity_linking import kb_creator
|
||||
|
||||
from bin.wiki_entity_linking import training_set_creator
|
||||
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_FILE, ENTITY_DESCR_PATH, KB_MODEL_DIR, LOG_FORMAT
|
||||
from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH
|
||||
import spacy
|
||||
|
||||
from spacy import Errors
|
||||
|
||||
|
||||
def now():
|
||||
return datetime.datetime.now()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@plac.annotations(
|
||||
wd_json=("Path to the downloaded WikiData JSON dump.", "positional", None, Path),
|
||||
wp_xml=("Path to the downloaded Wikipedia XML dump.", "positional", None, Path),
|
||||
output_dir=("Output directory", "positional", None, Path),
|
||||
model=("Model name, should include pretrained vectors.", "positional", None, str),
|
||||
model=("Model name or path, should include pretrained vectors.", "positional", None, str),
|
||||
max_per_alias=("Max. # entities per alias (default 10)", "option", "a", int),
|
||||
min_freq=("Min. count of an entity in the corpus (default 20)", "option", "f", int),
|
||||
min_pair=("Min. count of entity-alias pairs (default 5)", "option", "c", int),
|
||||
|
@ -41,7 +39,9 @@ def now():
|
|||
loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path),
|
||||
loc_entity_defs=("Location to file with entity definitions", "option", "d", Path),
|
||||
loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path),
|
||||
descriptions_from_wikipedia=("Flag for using wp descriptions not wd", "flag", "wp"),
|
||||
limit=("Optional threshold to limit lines read from dumps", "option", "l", int),
|
||||
lang=("Optional language for which to get wikidata titles. Defaults to 'en'", "option", "la", str),
|
||||
)
|
||||
def main(
|
||||
wd_json,
|
||||
|
@ -55,20 +55,29 @@ def main(
|
|||
loc_prior_prob=None,
|
||||
loc_entity_defs=None,
|
||||
loc_entity_desc=None,
|
||||
descriptions_from_wikipedia=False,
|
||||
limit=None,
|
||||
lang="en",
|
||||
):
|
||||
print(now(), "Creating KB with Wikipedia and WikiData")
|
||||
print()
|
||||
|
||||
entity_defs_path = loc_entity_defs if loc_entity_defs else output_dir / ENTITY_DEFS_PATH
|
||||
entity_descr_path = loc_entity_desc if loc_entity_desc else output_dir / ENTITY_DESCR_PATH
|
||||
entity_freq_path = output_dir / ENTITY_FREQ_PATH
|
||||
prior_prob_path = loc_prior_prob if loc_prior_prob else output_dir / PRIOR_PROB_PATH
|
||||
training_entities_path = output_dir / TRAINING_DATA_FILE
|
||||
kb_path = output_dir / KB_FILE
|
||||
|
||||
logger.info("Creating KB with Wikipedia and WikiData")
|
||||
|
||||
if limit is not None:
|
||||
print("Warning: reading only", limit, "lines of Wikipedia/Wikidata dumps.")
|
||||
logger.warning("Warning: reading only {} lines of Wikipedia/Wikidata dumps.".format(limit))
|
||||
|
||||
# STEP 0: set up IO
|
||||
if not output_dir.exists():
|
||||
output_dir.mkdir()
|
||||
output_dir.mkdir(parents=True)
|
||||
|
||||
# STEP 1: create the NLP object
|
||||
print(now(), "STEP 1: loaded model", model)
|
||||
logger.info("STEP 1: Loading model {}".format(model))
|
||||
nlp = spacy.load(model)
|
||||
|
||||
# check the length of the nlp vectors
|
||||
|
@ -79,64 +88,68 @@ def main(
|
|||
)
|
||||
|
||||
# STEP 2: create prior probabilities from WP
|
||||
print()
|
||||
if loc_prior_prob:
|
||||
print(now(), "STEP 2: reading prior probabilities from", loc_prior_prob)
|
||||
else:
|
||||
if not prior_prob_path.exists():
|
||||
# It takes about 2h to process 1000M lines of Wikipedia XML dump
|
||||
loc_prior_prob = output_dir / "prior_prob.csv"
|
||||
print(now(), "STEP 2: writing prior probabilities at", loc_prior_prob)
|
||||
wp.read_prior_probs(wp_xml, loc_prior_prob, limit=limit)
|
||||
logger.info("STEP 2: writing prior probabilities to {}".format(prior_prob_path))
|
||||
wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit)
|
||||
logger.info("STEP 2: reading prior probabilities from {}".format(prior_prob_path))
|
||||
|
||||
# STEP 3: deduce entity frequencies from WP (takes only a few minutes)
|
||||
print()
|
||||
print(now(), "STEP 3: calculating entity frequencies")
|
||||
loc_entity_freq = output_dir / "entity_freq.csv"
|
||||
wp.write_entity_counts(loc_prior_prob, loc_entity_freq, to_print=False)
|
||||
logger.info("STEP 3: calculating entity frequencies")
|
||||
wp.write_entity_counts(prior_prob_path, entity_freq_path, to_print=False)
|
||||
|
||||
loc_kb = output_dir / "kb"
|
||||
|
||||
# STEP 4: reading entity descriptions and definitions from WikiData or from file
|
||||
print()
|
||||
if loc_entity_defs and loc_entity_desc:
|
||||
read_raw = False
|
||||
print(now(), "STEP 4a: reading entity definitions from", loc_entity_defs)
|
||||
print(now(), "STEP 4b: reading entity descriptions from", loc_entity_desc)
|
||||
else:
|
||||
# STEP 4: reading definitions and (possibly) descriptions from WikiData or from file
|
||||
message = " and descriptions" if not descriptions_from_wikipedia else ""
|
||||
if (not entity_defs_path.exists()) or (not descriptions_from_wikipedia and not entity_descr_path.exists()):
|
||||
# It takes about 10h to process 55M lines of Wikidata JSON dump
|
||||
read_raw = True
|
||||
loc_entity_defs = output_dir / "entity_defs.csv"
|
||||
loc_entity_desc = output_dir / "entity_descriptions.csv"
|
||||
print(now(), "STEP 4: parsing wikidata for entity definitions and descriptions")
|
||||
logger.info("STEP 4: parsing wikidata for entity definitions" + message)
|
||||
title_to_id, id_to_descr = wd.read_wikidata_entities_json(
|
||||
wd_json,
|
||||
limit,
|
||||
to_print=False,
|
||||
lang=lang,
|
||||
parse_descriptions=(not descriptions_from_wikipedia),
|
||||
)
|
||||
wd.write_entity_files(entity_defs_path, title_to_id)
|
||||
if not descriptions_from_wikipedia:
|
||||
wd.write_entity_description_files(entity_descr_path, id_to_descr)
|
||||
logger.info("STEP 4: read entity definitions" + message)
|
||||
|
||||
# STEP 5: creating the actual KB
|
||||
# STEP 5: Getting gold entities from wikipedia
|
||||
message = " and descriptions" if descriptions_from_wikipedia else ""
|
||||
if (not training_entities_path.exists()) or (descriptions_from_wikipedia and not entity_descr_path.exists()):
|
||||
logger.info("STEP 5: parsing wikipedia for gold entities" + message)
|
||||
training_set_creator.create_training_examples_and_descriptions(
|
||||
wp_xml,
|
||||
entity_defs_path,
|
||||
entity_descr_path,
|
||||
training_entities_path,
|
||||
parse_descriptions=descriptions_from_wikipedia,
|
||||
limit=limit,
|
||||
)
|
||||
logger.info("STEP 5: read gold entities" + message)
|
||||
|
||||
# STEP 6: creating the actual KB
|
||||
# It takes ca. 30 minutes to pretrain the entity embeddings
|
||||
print()
|
||||
print(now(), "STEP 5: creating the KB at", loc_kb)
|
||||
logger.info("STEP 6: creating the KB at {}".format(kb_path))
|
||||
kb = kb_creator.create_kb(
|
||||
nlp=nlp,
|
||||
max_entities_per_alias=max_per_alias,
|
||||
min_entity_freq=min_freq,
|
||||
min_occ=min_pair,
|
||||
entity_def_output=loc_entity_defs,
|
||||
entity_descr_output=loc_entity_desc,
|
||||
count_input=loc_entity_freq,
|
||||
prior_prob_input=loc_prior_prob,
|
||||
wikidata_input=wd_json,
|
||||
entity_def_input=entity_defs_path,
|
||||
entity_descr_path=entity_descr_path,
|
||||
count_input=entity_freq_path,
|
||||
prior_prob_input=prior_prob_path,
|
||||
entity_vector_length=entity_vector_length,
|
||||
limit=limit,
|
||||
read_raw_data=read_raw,
|
||||
)
|
||||
if read_raw:
|
||||
print(" - wrote entity definitions to", loc_entity_defs)
|
||||
print(" - wrote writing entity descriptions to", loc_entity_desc)
|
||||
|
||||
kb.dump(loc_kb)
|
||||
nlp.to_disk(output_dir / "nlp")
|
||||
kb.dump(kb_path)
|
||||
nlp.to_disk(output_dir / KB_MODEL_DIR)
|
||||
|
||||
print()
|
||||
print(now(), "Done!")
|
||||
logger.info("Done!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
|
||||
plac.call(main)
|
||||
|
|
|
@ -1,17 +1,19 @@
# coding: utf-8
from __future__ import unicode_literals

import bz2
import gzip
import json
import logging
import datetime

logger = logging.getLogger(__name__)


def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False):
def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descriptions=True):
# Read the JSON wiki data and parse out the entities. Takes about 7u30 to parse 55M lines.
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/

lang = "en"
site_filter = "enwiki"
site_filter = '{}wiki'.format(lang)

# properties filter (currently disabled to get ALL data)
prop_filter = dict()
|
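For reference, a hypothetical call using the new keyword arguments of read_wikidata_entities_json; the dump path, language and limit below are assumptions for illustration, not values taken from this patch.

title_to_id, id_to_descr = read_wikidata_entities_json(
    wikidata_file="wikidata-latest-all.json.bz2",  # assumed local path to the dump
    limit=100000,                                  # stop early while testing
    lang="de",                                     # builds the "dewiki" site filter shown above
    parse_descriptions=False,                      # skip descriptions if they come from Wikipedia instead
)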
@ -24,18 +26,15 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False):
|
|||
parse_properties = False
|
||||
parse_sitelinks = True
|
||||
parse_labels = False
|
||||
parse_descriptions = True
|
||||
parse_aliases = False
|
||||
parse_claims = False
|
||||
|
||||
with bz2.open(wikidata_file, mode="rb") as file:
|
||||
line = file.readline()
|
||||
cnt = 0
|
||||
while line and (not limit or cnt < limit):
|
||||
if cnt % 1000000 == 0:
|
||||
print(
|
||||
datetime.datetime.now(), "processed", cnt, "lines of WikiData JSON dump"
|
||||
)
|
||||
with gzip.open(wikidata_file, mode='rb') as file:
|
||||
for cnt, line in enumerate(file):
|
||||
if limit and cnt >= limit:
|
||||
break
|
||||
if cnt % 500000 == 0:
|
||||
logger.info("processed {} lines of WikiData dump".format(cnt))
|
||||
clean_line = line.strip()
|
||||
if clean_line.endswith(b","):
|
||||
clean_line = clean_line[:-1]
|
||||
|
@ -134,8 +133,19 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False):
|
|||
|
||||
if to_print:
|
||||
print()
|
||||
line = file.readline()
|
||||
cnt += 1
|
||||
print(datetime.datetime.now(), "processed", cnt, "lines of WikiData JSON dump")
|
||||
|
||||
return title_to_id, id_to_descr
|
||||
|
||||
|
||||
def write_entity_files(entity_def_output, title_to_id):
|
||||
with entity_def_output.open("w", encoding="utf8") as id_file:
|
||||
id_file.write("WP_title" + "|" + "WD_id" + "\n")
|
||||
for title, qid in title_to_id.items():
|
||||
id_file.write(title + "|" + str(qid) + "\n")
|
||||
|
||||
|
||||
def write_entity_description_files(entity_descr_output, id_to_descr):
|
||||
with entity_descr_output.open("w", encoding="utf8") as descr_file:
|
||||
descr_file.write("WD_id" + "|" + "description" + "\n")
|
||||
for qid, descr in id_to_descr.items():
|
||||
descr_file.write(str(qid) + "|" + descr + "\n")
|
||||
|
|
|
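A small sketch of the pipe-delimited files the two writers above produce; the entity rows are invented, and entity_defs_path / entity_descr_path are assumed to be pathlib.Path objects.

title_to_id = {"Douglas Adams": "Q42", "Berlin": "Q64"}
id_to_descr = {"Q42": "English writer and humorist", "Q64": "capital of Germany"}
write_entity_files(entity_defs_path, title_to_id)
write_entity_description_files(entity_descr_path, id_to_descr)
# entity_defs_path then contains:      entity_descr_path then contains:
#   WP_title|WD_id                       WD_id|description
#   Douglas Adams|Q42                    Q42|English writer and humorist
#   Berlin|Q64                           Q64|capital of Germany

Because the separator is a literal pipe, a title or description containing "|" would corrupt these files; the simple format relies on that character not appearing in the data.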
@ -11,124 +11,84 @@ from https://dumps.wikimedia.org/enwiki/latest/
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import random
|
||||
import datetime
|
||||
import logging
|
||||
from pathlib import Path
|
||||
import plac
|
||||
|
||||
from bin.wiki_entity_linking import training_set_creator
|
||||
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR
|
||||
from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance, measure_baselines
|
||||
from bin.wiki_entity_linking.kb_creator import read_nlp_kb
|
||||
|
||||
import spacy
|
||||
from spacy.kb import KnowledgeBase
|
||||
from spacy.util import minibatch, compounding
|
||||
|
||||
|
||||
def now():
|
||||
return datetime.datetime.now()
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@plac.annotations(
|
||||
dir_kb=("Directory with KB, NLP and related files", "positional", None, Path),
|
||||
output_dir=("Output directory", "option", "o", Path),
|
||||
loc_training=("Location to training data", "option", "k", Path),
|
||||
wp_xml=("Path to the downloaded Wikipedia XML dump.", "option", "w", Path),
|
||||
epochs=("Number of training iterations (default 10)", "option", "e", int),
|
||||
dropout=("Dropout to prevent overfitting (default 0.5)", "option", "p", float),
|
||||
lr=("Learning rate (default 0.005)", "option", "n", float),
|
||||
l2=("L2 regularization", "option", "r", float),
|
||||
train_inst=("# training instances (default 90% of all)", "option", "t", int),
|
||||
dev_inst=("# test instances (default 10% of all)", "option", "d", int),
|
||||
limit=("Optional threshold to limit lines read from WP dump", "option", "l", int),
|
||||
)
|
||||
def main(
|
||||
dir_kb,
|
||||
output_dir=None,
|
||||
loc_training=None,
|
||||
wp_xml=None,
|
||||
epochs=10,
|
||||
dropout=0.5,
|
||||
lr=0.005,
|
||||
l2=1e-6,
|
||||
train_inst=None,
|
||||
dev_inst=None,
|
||||
limit=None,
|
||||
):
|
||||
print(now(), "Creating Entity Linker with Wikipedia and WikiData")
|
||||
print()
|
||||
logger.info("Creating Entity Linker with Wikipedia and WikiData")
|
||||
|
||||
output_dir = Path(output_dir) if output_dir else dir_kb
|
||||
training_path = loc_training if loc_training else output_dir / TRAINING_DATA_FILE
|
||||
nlp_dir = dir_kb / KB_MODEL_DIR
|
||||
kb_path = output_dir / KB_FILE
|
||||
nlp_output_dir = output_dir / OUTPUT_MODEL_DIR
|
||||
|
||||
# STEP 0: set up IO
|
||||
if output_dir and not output_dir.exists():
|
||||
if not output_dir.exists():
|
||||
output_dir.mkdir()
|
||||
|
||||
# STEP 1 : load the NLP object
|
||||
nlp_dir = dir_kb / "nlp"
|
||||
print(now(), "STEP 1: loading model from", nlp_dir)
|
||||
nlp = spacy.load(nlp_dir)
|
||||
logger.info("STEP 1: loading model from {}".format(nlp_dir))
|
||||
nlp, kb = read_nlp_kb(nlp_dir, kb_path)
|
||||
|
||||
# check that there is a NER component in the pipeline
|
||||
if "ner" not in nlp.pipe_names:
|
||||
raise ValueError("The `nlp` object should have a pre-trained `ner` component.")
|
||||
|
||||
# STEP 2 : read the KB
|
||||
print()
|
||||
print(now(), "STEP 2: reading the KB from", dir_kb / "kb")
|
||||
kb = KnowledgeBase(vocab=nlp.vocab)
|
||||
kb.load_bulk(dir_kb / "kb")
|
||||
# STEP 2: create a training dataset from WP
|
||||
logger.info("STEP 2: reading training dataset from {}".format(training_path))
|
||||
|
||||
# STEP 3: create a training dataset from WP
|
||||
print()
|
||||
if loc_training:
|
||||
print(now(), "STEP 3: reading training dataset from", loc_training)
|
||||
else:
|
||||
if not wp_xml:
|
||||
raise ValueError(
|
||||
"Either provide a path to a preprocessed training directory, "
|
||||
"or to the original Wikipedia XML dump."
|
||||
)
|
||||
|
||||
if output_dir:
|
||||
loc_training = output_dir / "training_data"
|
||||
else:
|
||||
loc_training = dir_kb / "training_data"
|
||||
if not loc_training.exists():
|
||||
loc_training.mkdir()
|
||||
print(now(), "STEP 3: creating training dataset at", loc_training)
|
||||
|
||||
if limit is not None:
|
||||
print("Warning: reading only", limit, "lines of Wikipedia dump.")
|
||||
|
||||
loc_entity_defs = dir_kb / "entity_defs.csv"
|
||||
training_set_creator.create_training(
|
||||
wikipedia_input=wp_xml,
|
||||
entity_def_input=loc_entity_defs,
|
||||
training_output=loc_training,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
# STEP 4: parse the training data
|
||||
print()
|
||||
print(now(), "STEP 4: parse the training & evaluation data")
|
||||
|
||||
# for training, get pos & neg instances that correspond to entries in the kb
|
||||
print("Parsing training data, limit =", train_inst)
|
||||
train_data = training_set_creator.read_training(
|
||||
nlp=nlp, training_dir=loc_training, dev=False, limit=train_inst, kb=kb
|
||||
nlp=nlp,
|
||||
entity_file_path=training_path,
|
||||
dev=False,
|
||||
limit=train_inst,
|
||||
kb=kb,
|
||||
)
|
||||
|
||||
print("Training on", len(train_data), "articles")
|
||||
print()
|
||||
|
||||
print("Parsing dev testing data, limit =", dev_inst)
|
||||
# for testing, get all pos instances, whether or not they are in the kb
|
||||
dev_data = training_set_creator.read_training(
|
||||
nlp=nlp, training_dir=loc_training, dev=True, limit=dev_inst, kb=None
|
||||
nlp=nlp,
|
||||
entity_file_path=training_path,
|
||||
dev=True,
|
||||
limit=dev_inst,
|
||||
kb=kb,
|
||||
)
|
||||
|
||||
print("Dev testing on", len(dev_data), "articles")
|
||||
print()
|
||||
|
||||
# STEP 5: create and train the entity linking pipe
|
||||
print()
|
||||
print(now(), "STEP 5: training Entity Linking pipe")
|
||||
# STEP 3: create and train the entity linking pipe
|
||||
logger.info("STEP 3: training Entity Linking pipe")
|
||||
|
||||
el_pipe = nlp.create_pipe(
|
||||
name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name}
|
||||
|
@ -142,275 +102,70 @@ def main(
|
|||
optimizer.learn_rate = lr
|
||||
optimizer.L2 = l2
|
||||
|
||||
if not train_data:
|
||||
print("Did not find any training data")
|
||||
else:
|
||||
for itn in range(epochs):
|
||||
random.shuffle(train_data)
|
||||
losses = {}
|
||||
batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
|
||||
batchnr = 0
|
||||
logger.info("Training on {} articles".format(len(train_data)))
|
||||
logger.info("Dev testing on {} articles".format(len(dev_data)))
|
||||
|
||||
with nlp.disable_pipes(*other_pipes):
|
||||
for batch in batches:
|
||||
try:
|
||||
docs, golds = zip(*batch)
|
||||
nlp.update(
|
||||
docs=docs,
|
||||
golds=golds,
|
||||
sgd=optimizer,
|
||||
drop=dropout,
|
||||
losses=losses,
|
||||
)
|
||||
batchnr += 1
|
||||
except Exception as e:
|
||||
print("Error updating batch:", e)
|
||||
|
||||
if batchnr > 0:
|
||||
el_pipe.cfg["incl_context"] = True
|
||||
el_pipe.cfg["incl_prior"] = True
|
||||
dev_acc_context, _ = _measure_acc(dev_data, el_pipe)
|
||||
losses["entity_linker"] = losses["entity_linker"] / batchnr
|
||||
print(
|
||||
"Epoch, train loss",
|
||||
itn,
|
||||
round(losses["entity_linker"], 2),
|
||||
" / dev accuracy avg",
|
||||
round(dev_acc_context, 3),
|
||||
)
|
||||
|
||||
# STEP 6: measure the performance of our trained pipe on an independent dev set
|
||||
print()
|
||||
if len(dev_data):
|
||||
print()
|
||||
print(now(), "STEP 6: performance measurement of Entity Linking pipe")
|
||||
print()
|
||||
|
||||
counts, acc_r, acc_r_d, acc_p, acc_p_d, acc_o, acc_o_d = _measure_baselines(
|
||||
dev_data, kb
|
||||
)
|
||||
print("dev counts:", sorted(counts.items(), key=lambda x: x[0]))
|
||||
|
||||
oracle_by_label = [(x, round(y, 3)) for x, y in acc_o_d.items()]
|
||||
print("dev accuracy oracle:", round(acc_o, 3), oracle_by_label)
|
||||
|
||||
random_by_label = [(x, round(y, 3)) for x, y in acc_r_d.items()]
|
||||
print("dev accuracy random:", round(acc_r, 3), random_by_label)
|
||||
|
||||
prior_by_label = [(x, round(y, 3)) for x, y in acc_p_d.items()]
|
||||
print("dev accuracy prior:", round(acc_p, 3), prior_by_label)
|
||||
|
||||
# using only context
|
||||
el_pipe.cfg["incl_context"] = True
|
||||
el_pipe.cfg["incl_prior"] = False
|
||||
dev_acc_context, dev_acc_cont_d = _measure_acc(dev_data, el_pipe)
|
||||
context_by_label = [(x, round(y, 3)) for x, y in dev_acc_cont_d.items()]
|
||||
print("dev accuracy context:", round(dev_acc_context, 3), context_by_label)
|
||||
|
||||
# measuring combined accuracy (prior + context)
|
||||
el_pipe.cfg["incl_context"] = True
|
||||
el_pipe.cfg["incl_prior"] = True
|
||||
dev_acc_combo, dev_acc_combo_d = _measure_acc(dev_data, el_pipe)
|
||||
combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_d.items()]
|
||||
print("dev accuracy prior+context:", round(dev_acc_combo, 3), combo_by_label)
|
||||
|
||||
# STEP 7: apply the EL pipe on a toy example
|
||||
print()
|
||||
print(now(), "STEP 7: applying Entity Linking to toy example")
|
||||
print()
|
||||
run_el_toy_example(nlp=nlp)
|
||||
|
||||
# STEP 8: write the NLP pipeline (including entity linker) to file
|
||||
if output_dir:
|
||||
print()
|
||||
nlp_loc = output_dir / "nlp"
|
||||
print(now(), "STEP 8: Writing trained NLP to", nlp_loc)
|
||||
nlp.to_disk(nlp_loc)
|
||||
print()
|
||||
|
||||
print()
|
||||
print(now(), "Done!")
|
||||
|
||||
|
||||
def _measure_acc(data, el_pipe=None, error_analysis=False):
|
||||
# If the docs in the data require further processing with an entity linker, set el_pipe
|
||||
correct_by_label = dict()
|
||||
incorrect_by_label = dict()
|
||||
|
||||
docs = [d for d, g in data if len(d) > 0]
|
||||
if el_pipe is not None:
|
||||
docs = list(el_pipe.pipe(docs))
|
||||
golds = [g for d, g in data if len(d) > 0]
|
||||
|
||||
for doc, gold in zip(docs, golds):
|
||||
try:
|
||||
correct_entries_per_article = dict()
|
||||
for entity, kb_dict in gold.links.items():
|
||||
start, end = entity
|
||||
# only evaluating on positive examples
|
||||
for gold_kb, value in kb_dict.items():
|
||||
if value:
|
||||
offset = _offset(start, end)
|
||||
correct_entries_per_article[offset] = gold_kb
|
||||
|
||||
for ent in doc.ents:
|
||||
ent_label = ent.label_
|
||||
pred_entity = ent.kb_id_
|
||||
start = ent.start_char
|
||||
end = ent.end_char
|
||||
offset = _offset(start, end)
|
||||
gold_entity = correct_entries_per_article.get(offset, None)
|
||||
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
||||
if gold_entity is not None:
|
||||
if gold_entity == pred_entity:
|
||||
correct = correct_by_label.get(ent_label, 0)
|
||||
correct_by_label[ent_label] = correct + 1
|
||||
else:
|
||||
incorrect = incorrect_by_label.get(ent_label, 0)
|
||||
incorrect_by_label[ent_label] = incorrect + 1
|
||||
if error_analysis:
|
||||
print(ent.text, "in", doc)
|
||||
print(
|
||||
"Predicted",
|
||||
pred_entity,
|
||||
"should have been",
|
||||
gold_entity,
|
||||
)
|
||||
print()
|
||||
|
||||
except Exception as e:
|
||||
print("Error assessing accuracy", e)
|
||||
|
||||
acc, acc_by_label = calculate_acc(correct_by_label, incorrect_by_label)
|
||||
return acc, acc_by_label
|
||||
|
||||
|
||||
def _measure_baselines(data, kb):
|
||||
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
|
||||
counts_d = dict()
|
||||
|
||||
random_correct_d = dict()
|
||||
random_incorrect_d = dict()
|
||||
|
||||
oracle_correct_d = dict()
|
||||
oracle_incorrect_d = dict()
|
||||
|
||||
prior_correct_d = dict()
|
||||
prior_incorrect_d = dict()
|
||||
|
||||
docs = [d for d, g in data if len(d) > 0]
|
||||
golds = [g for d, g in data if len(d) > 0]
|
||||
|
||||
for doc, gold in zip(docs, golds):
|
||||
try:
|
||||
correct_entries_per_article = dict()
|
||||
for entity, kb_dict in gold.links.items():
|
||||
start, end = entity
|
||||
for gold_kb, value in kb_dict.items():
|
||||
# only evaluating on positive examples
|
||||
if value:
|
||||
offset = _offset(start, end)
|
||||
correct_entries_per_article[offset] = gold_kb
|
||||
|
||||
for ent in doc.ents:
|
||||
label = ent.label_
|
||||
start = ent.start_char
|
||||
end = ent.end_char
|
||||
offset = _offset(start, end)
|
||||
gold_entity = correct_entries_per_article.get(offset, None)
|
||||
|
||||
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
|
||||
if gold_entity is not None:
|
||||
counts_d[label] = counts_d.get(label, 0) + 1
|
||||
candidates = kb.get_candidates(ent.text)
|
||||
oracle_candidate = ""
|
||||
best_candidate = ""
|
||||
random_candidate = ""
|
||||
if candidates:
|
||||
scores = []
|
||||
|
||||
for c in candidates:
|
||||
scores.append(c.prior_prob)
|
||||
if c.entity_ == gold_entity:
|
||||
oracle_candidate = c.entity_
|
||||
|
||||
best_index = scores.index(max(scores))
|
||||
best_candidate = candidates[best_index].entity_
|
||||
random_candidate = random.choice(candidates).entity_
|
||||
|
||||
if gold_entity == best_candidate:
|
||||
prior_correct_d[label] = prior_correct_d.get(label, 0) + 1
|
||||
else:
|
||||
prior_incorrect_d[label] = prior_incorrect_d.get(label, 0) + 1
|
||||
|
||||
if gold_entity == random_candidate:
|
||||
random_correct_d[label] = random_correct_d.get(label, 0) + 1
|
||||
else:
|
||||
random_incorrect_d[label] = random_incorrect_d.get(label, 0) + 1
|
||||
|
||||
if gold_entity == oracle_candidate:
|
||||
oracle_correct_d[label] = oracle_correct_d.get(label, 0) + 1
|
||||
else:
|
||||
oracle_incorrect_d[label] = oracle_incorrect_d.get(label, 0) + 1
|
||||
|
||||
except Exception as e:
|
||||
print("Error assessing accuracy", e)
|
||||
|
||||
acc_prior, acc_prior_d = calculate_acc(prior_correct_d, prior_incorrect_d)
|
||||
acc_rand, acc_rand_d = calculate_acc(random_correct_d, random_incorrect_d)
|
||||
acc_oracle, acc_oracle_d = calculate_acc(oracle_correct_d, oracle_incorrect_d)
|
||||
|
||||
return (
|
||||
counts_d,
|
||||
acc_rand,
|
||||
acc_rand_d,
|
||||
acc_prior,
|
||||
acc_prior_d,
|
||||
acc_oracle,
|
||||
acc_oracle_d,
|
||||
dev_baseline_accuracies = measure_baselines(
|
||||
dev_data, kb
|
||||
)
|
||||
|
||||
logger.info("Dev Baseline Accuracies:")
|
||||
logger.info(dev_baseline_accuracies.report_accuracy("random"))
|
||||
logger.info(dev_baseline_accuracies.report_accuracy("prior"))
|
||||
logger.info(dev_baseline_accuracies.report_accuracy("oracle"))
|
||||
|
||||
def _offset(start, end):
|
||||
return "{}_{}".format(start, end)
|
||||
for itn in range(epochs):
|
||||
random.shuffle(train_data)
|
||||
losses = {}
|
||||
batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
|
||||
batchnr = 0
|
||||
|
||||
with nlp.disable_pipes(*other_pipes):
|
||||
for batch in batches:
|
||||
try:
|
||||
docs, golds = zip(*batch)
|
||||
nlp.update(
|
||||
docs=docs,
|
||||
golds=golds,
|
||||
sgd=optimizer,
|
||||
drop=dropout,
|
||||
losses=losses,
|
||||
)
|
||||
batchnr += 1
|
||||
except Exception as e:
|
||||
logger.error("Error updating batch:" + str(e))
|
||||
if batchnr > 0:
|
||||
logging.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2)))
|
||||
measure_performance(dev_data, kb, el_pipe)
|
||||
|
||||
def calculate_acc(correct_by_label, incorrect_by_label):
|
||||
acc_by_label = dict()
|
||||
total_correct = 0
|
||||
total_incorrect = 0
|
||||
all_keys = set()
|
||||
all_keys.update(correct_by_label.keys())
|
||||
all_keys.update(incorrect_by_label.keys())
|
||||
for label in sorted(all_keys):
|
||||
correct = correct_by_label.get(label, 0)
|
||||
incorrect = incorrect_by_label.get(label, 0)
|
||||
total_correct += correct
|
||||
total_incorrect += incorrect
|
||||
if correct == incorrect == 0:
|
||||
acc_by_label[label] = 0
|
||||
else:
|
||||
acc_by_label[label] = correct / (correct + incorrect)
|
||||
acc = 0
|
||||
if not (total_correct == total_incorrect == 0):
|
||||
acc = total_correct / (total_correct + total_incorrect)
|
||||
return acc, acc_by_label
|
||||
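As a quick sanity check on calculate_acc above, a tiny worked example with invented counts:

correct_by_label = {"PERSON": 8, "GPE": 1}
incorrect_by_label = {"PERSON": 2, "LOC": 3}
acc, acc_by_label = calculate_acc(correct_by_label, incorrect_by_label)
# per label: PERSON 8/(8+2) = 0.8, GPE 1/1 = 1.0, LOC 0/3 = 0.0
# overall:   (8 + 1) / (8 + 1 + 2 + 3) = 9/14 ≈ 0.643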
# STEP 4: measure the performance of our trained pipe on an independent dev set
|
||||
logger.info("STEP 4: performance measurement of Entity Linking pipe")
|
||||
measure_performance(dev_data, kb, el_pipe)
|
||||
|
||||
# STEP 5: apply the EL pipe on a toy example
|
||||
logger.info("STEP 5: applying Entity Linking to toy example")
|
||||
run_el_toy_example(nlp=nlp)
|
||||
|
||||
if output_dir:
|
||||
# STEP 6: write the NLP pipeline (including entity linker) to file
|
||||
logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir))
|
||||
nlp.to_disk(nlp_output_dir)
|
||||
|
||||
logger.info("Done!")
|
||||
|
||||
|
||||
def check_kb(kb):
|
||||
for mention in ("Bush", "Douglas Adams", "Homer", "Brazil", "China"):
|
||||
candidates = kb.get_candidates(mention)
|
||||
|
||||
print("generating candidates for " + mention + " :")
|
||||
logger.info("generating candidates for " + mention + " :")
|
||||
for c in candidates:
|
||||
print(
|
||||
" ",
|
||||
c.prior_prob,
|
||||
logger.info(" ".join[
|
||||
str(c.prior_prob),
|
||||
c.alias_,
|
||||
"-->",
|
||||
c.entity_ + " (freq=" + str(c.entity_freq) + ")",
|
||||
)
|
||||
print()
|
||||
c.entity_ + " (freq=" + str(c.entity_freq) + ")"
|
||||
]))
|
||||
|
||||
|
||||
def run_el_toy_example(nlp):
|
||||
|
@ -421,11 +176,11 @@ def run_el_toy_example(nlp):
|
|||
"but Dougledydoug doesn't write about George Washington or Homer Simpson."
|
||||
)
|
||||
doc = nlp(text)
|
||||
print(text)
|
||||
logger.info(text)
|
||||
for ent in doc.ents:
|
||||
print(" ent", ent.text, ent.label_, ent.kb_id_)
|
||||
print()
|
||||
logger.info(" ".join(["ent", ent.text, ent.label_, ent.kb_id_]))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
|
||||
plac.call(main)
|
||||
|
|
|
@ -5,6 +5,9 @@ import re
|
|||
import bz2
|
||||
import csv
|
||||
import datetime
|
||||
import logging
|
||||
|
||||
from bin.wiki_entity_linking import LOG_FORMAT
|
||||
|
||||
"""
|
||||
Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions.
|
||||
|
@ -13,6 +16,9 @@ Write these results to file for downstream KB and training data generation.
|
|||
|
||||
map_alias_to_link = dict()
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# these will/should be matched ignoring case
|
||||
wiki_namespaces = [
|
||||
"b",
|
||||
|
@ -116,10 +122,6 @@ for ns in wiki_namespaces:
|
|||
ns_regex = re.compile(ns_regex, re.IGNORECASE)
|
||||
|
||||
|
||||
def now():
|
||||
return datetime.datetime.now()
|
||||
|
||||
|
||||
def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
|
||||
"""
|
||||
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
|
||||
|
@ -131,7 +133,7 @@ def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
|
|||
cnt = 0
|
||||
while line and (not limit or cnt < limit):
|
||||
if cnt % 25000000 == 0:
|
||||
print(now(), "processed", cnt, "lines of Wikipedia XML dump")
|
||||
logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
|
||||
clean_line = line.strip().decode("utf-8")
|
||||
|
||||
aliases, entities, normalizations = get_wp_links(clean_line)
|
||||
|
@ -141,7 +143,7 @@ def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
|
|||
|
||||
line = file.readline()
|
||||
cnt += 1
|
||||
print(now(), "processed", cnt, "lines of Wikipedia XML dump")
|
||||
logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
|
||||
|
||||
# write all aliases and their entities and count occurrences to file
|
||||
with prior_prob_output.open("w", encoding="utf8") as outputfile:
|
||||
|
|
|
@ -3,11 +3,9 @@
|
|||
"""
|
||||
from __future__ import unicode_literals
|
||||
import plac
|
||||
import tqdm
|
||||
import attr
|
||||
from pathlib import Path
|
||||
import re
|
||||
import sys
|
||||
import json
|
||||
|
||||
import spacy
|
||||
|
@ -23,7 +21,7 @@ import itertools
|
|||
import random
|
||||
import numpy.random
|
||||
|
||||
import conll17_ud_eval
|
||||
from bin.ud import conll17_ud_eval
|
||||
|
||||
import spacy.lang.zh
|
||||
import spacy.lang.ja
|
||||
|
@ -394,6 +392,9 @@ class TreebankPaths(object):
|
|||
limit=("Size limit", "option", "n", int),
|
||||
)
|
||||
def main(ud_dir, parses_dir, config, corpus, limit=0):
|
||||
# temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
|
||||
import tqdm
|
||||
|
||||
paths = TreebankPaths(ud_dir, corpus)
|
||||
if not (parses_dir / corpus).exists():
|
||||
(parses_dir / corpus).mkdir()
|
||||
|
|
|
@ -18,7 +18,6 @@ import random
|
|||
import spacy
|
||||
import thinc.extra.datasets
|
||||
from spacy.util import minibatch, use_gpu, compounding
|
||||
import tqdm
|
||||
from spacy._ml import Tok2Vec
|
||||
from spacy.pipeline import TextCategorizer
|
||||
import numpy
|
||||
|
@ -107,6 +106,9 @@ def create_pipeline(width, embed_size, vectors_model):
|
|||
|
||||
|
||||
def train_tensorizer(nlp, texts, dropout, n_iter):
|
||||
# temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
|
||||
import tqdm
|
||||
|
||||
tensorizer = nlp.create_pipe("tensorizer")
|
||||
nlp.add_pipe(tensorizer)
|
||||
optimizer = nlp.begin_training()
|
||||
|
@ -120,6 +122,9 @@ def train_tensorizer(nlp, texts, dropout, n_iter):
|
|||
|
||||
|
||||
def train_textcat(nlp, n_texts, n_iter=10):
|
||||
# temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
|
||||
import tqdm
|
||||
|
||||
textcat = nlp.get_pipe("textcat")
|
||||
tok2vec_weights = textcat.model.tok2vec.to_bytes()
|
||||
(train_texts, train_cats), (dev_texts, dev_cats) = load_textcat_data(limit=n_texts)
|
||||
|
|
|
@ -13,7 +13,6 @@ import numpy
|
|||
import plac
|
||||
import spacy
|
||||
import tensorflow as tf
|
||||
import tqdm
|
||||
from tensorflow.contrib.tensorboard.plugins.projector import (
|
||||
visualize_embeddings,
|
||||
ProjectorConfig,
|
||||
|
@ -36,6 +35,9 @@ from tensorflow.contrib.tensorboard.plugins.projector import (
|
|||
),
|
||||
)
|
||||
def main(vectors_loc, out_loc, name="spaCy_vectors"):
|
||||
# temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
|
||||
import tqdm
|
||||
|
||||
meta_file = "{}.tsv".format(name)
|
||||
out_meta_file = path.join(out_loc, meta_file)
|
||||
|
||||
|
|
2
setup.py
|
@ -140,7 +140,7 @@ def gzip_language_data(root, source):
|
|||
base = Path(root) / source
|
||||
for jsonfile in base.glob("**/*.json"):
|
||||
outfile = jsonfile.with_suffix(jsonfile.suffix + ".gz")
|
||||
if outfile.is_file() and outfile.stat().st_ctime > jsonfile.stat().st_ctime:
|
||||
if outfile.is_file() and outfile.stat().st_mtime > jsonfile.stat().st_mtime:
|
||||
# If the gz is newer it doesn't need updating
|
||||
print("Skipping {}, already compressed".format(jsonfile))
|
||||
continue
|
||||
|
|
|
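The switch from st_ctime to st_mtime matters because ctime also changes on metadata updates (permissions, ownership), while mtime only changes when the file contents do, which is what this freshness check cares about. A minimal sketch of the same comparison:

from pathlib import Path

def gz_is_fresh(jsonfile: Path, outfile: Path) -> bool:
    # True if the compressed copy is at least as new as the source contents
    return outfile.is_file() and outfile.stat().st_mtime > jsonfile.stat().st_mtime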
@ -6,7 +6,7 @@ import re
|
|||
from ...gold import iob_to_biluo
|
||||
|
||||
|
||||
def conllu2json(input_data, n_sents=10, use_morphology=False, lang=None):
|
||||
def conllu2json(input_data, n_sents=10, use_morphology=False, lang=None, **_):
|
||||
"""
|
||||
Convert conllu files into JSON format for use with train cli.
|
||||
use_morphology parameter enables appending morphology to tags, which is
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import re
|
||||
from wasabi import Printer
|
||||
|
||||
from ...gold import iob_to_biluo
|
||||
|
|
|
@ -3,7 +3,6 @@ from __future__ import unicode_literals
|
|||
|
||||
import plac
|
||||
import math
|
||||
from tqdm import tqdm
|
||||
import numpy
|
||||
from ast import literal_eval
|
||||
from pathlib import Path
|
||||
|
@ -109,6 +108,9 @@ def open_file(loc):
|
|||
|
||||
|
||||
def read_attrs_from_deprecated(freqs_loc, clusters_loc):
|
||||
# temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
|
||||
from tqdm import tqdm
|
||||
|
||||
if freqs_loc is not None:
|
||||
with msg.loading("Counting frequencies..."):
|
||||
probs, _ = read_freqs(freqs_loc)
|
||||
|
@ -186,6 +188,9 @@ def add_vectors(nlp, vectors_loc, prune_vectors):
|
|||
|
||||
|
||||
def read_vectors(vectors_loc):
|
||||
# temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
|
||||
from tqdm import tqdm
|
||||
|
||||
f = open_file(vectors_loc)
|
||||
shape = tuple(int(size) for size in next(f).split())
|
||||
vectors_data = numpy.zeros(shape=shape, dtype="f")
|
||||
|
@ -202,6 +207,9 @@ def read_vectors(vectors_loc):
|
|||
|
||||
|
||||
def read_freqs(freqs_loc, max_length=100, min_doc_freq=5, min_freq=50):
|
||||
# temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
|
||||
from tqdm import tqdm
|
||||
|
||||
counts = PreshCounter()
|
||||
total = 0
|
||||
with freqs_loc.open() as f:
|
||||
|
@ -231,6 +239,9 @@ def read_freqs(freqs_loc, max_length=100, min_doc_freq=5, min_freq=50):
|
|||
|
||||
|
||||
def read_clusters(clusters_loc):
|
||||
# temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
|
||||
from tqdm import tqdm
|
||||
|
||||
clusters = {}
|
||||
if ftfy is None:
|
||||
user_warning(Warnings.W004)
|
||||
|
|
|
@ -7,7 +7,6 @@ import srsly
|
|||
import cProfile
|
||||
import pstats
|
||||
import sys
|
||||
import tqdm
|
||||
import itertools
|
||||
import thinc.extra.datasets
|
||||
from wasabi import Printer
|
||||
|
@ -48,6 +47,9 @@ def profile(model, inputs=None, n_texts=10000):
|
|||
|
||||
|
||||
def parse_texts(nlp, texts):
|
||||
# temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
|
||||
import tqdm
|
||||
|
||||
for doc in nlp.pipe(tqdm.tqdm(texts), batch_size=16):
|
||||
pass
|
||||
|
||||
|
|
|
@ -4,7 +4,6 @@ from __future__ import unicode_literals, division, print_function
|
|||
import plac
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tqdm
|
||||
from thinc.neural._classes.model import Model
|
||||
from timeit import default_timer as timer
|
||||
import shutil
|
||||
|
@ -101,6 +100,10 @@ def train(
|
|||
JSON format. To convert data from other formats, use the `spacy convert`
|
||||
command.
|
||||
"""
|
||||
|
||||
# temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
|
||||
import tqdm
|
||||
|
||||
msg = Printer()
|
||||
util.fix_random_seed()
|
||||
util.set_env_log(verbose)
|
||||
|
@ -390,6 +393,9 @@ def _score_for_model(meta):
|
|||
|
||||
@contextlib.contextmanager
|
||||
def _create_progress_bar(total):
|
||||
# temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
|
||||
import tqdm
|
||||
|
||||
if int(os.environ.get("LOG_FRIENDLY", 0)):
|
||||
yield
|
||||
else:
|
||||
|
|
|
@ -452,6 +452,11 @@ class Errors(object):
|
|||
"Make sure that you're passing in absolute token indices, not "
|
||||
"relative token offsets.\nstart: {start}, end: {end}, label: "
|
||||
"{label}, direction: {dir}")
|
||||
E158 = ("Can't add table '{name}' to lookups because it already exists.")
|
||||
E159 = ("Can't find table '{name}' in lookups. Available tables: {tables}")
|
||||
E160 = ("Can't find language data file: {path}")
|
||||
E161 = ("Found an internal inconsistency when predicting entity links. "
|
||||
"This is likely a bug in spaCy, so feel free to open an issue.")
|
||||
|
||||
|
||||
@add_codes
|
||||
|
|
|
@ -11,6 +11,12 @@ _hebrew = r"\u0591-\u05F4\uFB1D-\uFB4F"
|
|||
|
||||
_hindi = r"\u0900-\u097F"
|
||||
|
||||
_kannada = r"\u0C80-\u0CFF"
|
||||
|
||||
_tamil = r"\u0B80-\u0BFF"
|
||||
|
||||
_telugu = r"\u0C00-\u0C7F"
|
||||
|
||||
# Latin standard
|
||||
_latin_u_standard = r"A-Z"
|
||||
_latin_l_standard = r"a-z"
|
||||
|
@ -195,7 +201,7 @@ _ukrainian = r"а-щюяіїєґА-ЩЮЯІЇЄҐ"
|
|||
_upper = LATIN_UPPER + _russian_upper + _tatar_upper + _greek_upper + _ukrainian_upper
|
||||
_lower = LATIN_LOWER + _russian_lower + _tatar_lower + _greek_lower + _ukrainian_lower
|
||||
|
||||
_uncased = _bengali + _hebrew + _persian + _sinhala + _hindi
|
||||
_uncased = _bengali + _hebrew + _persian + _sinhala + _hindi + _kannada + _tamil + _telugu
|
||||
|
||||
ALPHA = group_chars(LATIN + _russian + _tatar + _greek + _ukrainian + _uncased)
|
||||
ALPHA_LOWER = group_chars(_lower + _uncased)
|
||||
|
|
|
@ -46,6 +46,11 @@ class GreekLemmatizer(object):
|
|||
)
|
||||
return lemmas
|
||||
|
||||
def lookup(self, string):
|
||||
if string in self.lookup_table:
|
||||
return self.lookup_table[string]
|
||||
return string
|
||||
|
||||
|
||||
def lemmatize(string, index, exceptions, rules):
|
||||
string = string.lower()
|
||||
|
|
|
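The new GreekLemmatizer.lookup is a plain table lookup with a pass-through fallback; a standalone sketch with an invented table:

lookup_table = {"παιδιά": "παιδί"}

def lookup(string):
    # return the table entry if present, otherwise the surface form unchanged
    if string in lookup_table:
        return lookup_table[string]
    return string

assert lookup("παιδιά") == "παιδί"   # known form maps to its lemma
assert lookup("σπίτι") == "σπίτι"    # unknown form is returned as-is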
@ -18,6 +18,7 @@ class CroatianDefaults(Language.Defaults):
|
|||
)
|
||||
tokenizer_exceptions = update_exc(BASE_EXCEPTIONS)
|
||||
stop_words = STOP_WORDS
|
||||
resources = {"lemma_lookup": "lemma_lookup.json"}
|
||||
|
||||
|
||||
class Croatian(Language):
|
||||
|
|
1313609
spacy/lang/hr/lemma_lookup.json
Normal file
File diff suppressed because it is too large
15
spacy/lang/hr/lemma_lookup_license.txt
Normal file
|
@ -0,0 +1,15 @@
The list of Croatian lemmas was extracted from the reldi-tagger repository (https://github.com/clarinsi/reldi-tagger).
Reldi-tagger is licensed under the Apache 2.0 licence.

@InProceedings{ljubesic16-new,
author = {Nikola Ljubešić and Filip Klubička and Željko Agić and Ivo-Pavao Jazbec},
title = {New Inflectional Lexicons and Training Corpora for Improved Morphosyntactic Annotation of Croatian and Serbian},
booktitle = {Proceedings of the Tenth International Conference on Language Resources and Evaluation (LREC 2016)},
year = {2016},
date = {23-28},
location = {Portorož, Slovenia},
editor = {Nicoletta Calzolari (Conference Chair) and Khalid Choukri and Thierry Declerck and Sara Goggi and Marko Grobelnik and Bente Maegaard and Joseph Mariani and Helene Mazo and Asuncion Moreno and Jan Odijk and Stelios Piperidis},
publisher = {European Language Resources Association (ELRA)},
address = {Paris, France},
isbn = {978-2-9517408-9-1}
}
|
@ -1,8 +1,6 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
|
||||
from .stop_words import STOP_WORDS
|
||||
from .lex_attrs import LEX_ATTRS
|
||||
|
|
|
@ -3,7 +3,7 @@ from __future__ import unicode_literals
|
|||
|
||||
from .char_classes import LIST_PUNCT, LIST_ELLIPSES, LIST_QUOTES, LIST_CURRENCY
|
||||
from .char_classes import LIST_ICONS, HYPHENS, CURRENCY, UNITS
|
||||
from .char_classes import CONCAT_QUOTES, ALPHA_LOWER, ALPHA_UPPER, ALPHA
|
||||
from .char_classes import CONCAT_QUOTES, ALPHA_LOWER, ALPHA_UPPER, ALPHA, PUNCT
|
||||
|
||||
|
||||
_prefixes = (
|
||||
|
@ -27,8 +27,8 @@ _suffixes = (
|
|||
r"(?<=°[FfCcKk])\.",
|
||||
r"(?<=[0-9])(?:{c})".format(c=CURRENCY),
|
||||
r"(?<=[0-9])(?:{u})".format(u=UNITS),
|
||||
r"(?<=[0-9{al}{e}(?:{q})])\.".format(
|
||||
al=ALPHA_LOWER, e=r"%²\-\+", q=CONCAT_QUOTES
|
||||
r"(?<=[0-9{al}{e}{p}(?:{q})])\.".format(
|
||||
al=ALPHA_LOWER, e=r"%²\-\+", q=CONCAT_QUOTES, p=PUNCT
|
||||
),
|
||||
r"(?<=[{au}][{au}])\.".format(au=ALPHA_UPPER),
|
||||
]
|
||||
|
|
|
@ -9,6 +9,7 @@ from ..norm_exceptions import BASE_NORMS
|
|||
from ...language import Language
|
||||
from ...attrs import LANG, NORM
|
||||
from ...util import update_exc, add_lookups
|
||||
from .tag_map import TAG_MAP
|
||||
|
||||
# Lemma data note:
|
||||
# Original pairs downloaded from http://www.lexiconista.com/datasets/lemmatization/
|
||||
|
@ -24,6 +25,7 @@ class RomanianDefaults(Language.Defaults):
|
|||
tokenizer_exceptions = update_exc(BASE_EXCEPTIONS, TOKENIZER_EXCEPTIONS)
|
||||
stop_words = STOP_WORDS
|
||||
resources = {"lemma_lookup": "lemma_lookup.json"}
|
||||
tag_map = TAG_MAP
|
||||
|
||||
|
||||
class Romanian(Language):
|
||||
|
|
1654
spacy/lang/ro/tag_map.py
Normal file
File diff suppressed because it is too large
|
@ -21,6 +21,7 @@ class SerbianDefaults(Language.Defaults):
|
|||
)
|
||||
tokenizer_exceptions = update_exc(BASE_EXCEPTIONS, TOKENIZER_EXCEPTIONS)
|
||||
stop_words = STOP_WORDS
|
||||
resources = {"lemma_lookup": "lemma_lookup.json"}
|
||||
|
||||
|
||||
class Serbian(Language):
|
||||
|
|
|
@ -12,13 +12,14 @@ Example sentences to test spaCy and its language models.
|
|||
|
||||
sentences = [
|
||||
# Translations from English
|
||||
"Apple планира куповину америчког стартапа за $1 милијарду."
|
||||
"Apple планира куповину америчког стартапа за $1 милијарду.",
|
||||
"Беспилотни аутомобили пребацују одговорност осигурања на произвођаче.",
|
||||
"Лондон је велики град у Уједињеном Краљевству.",
|
||||
"Где си ти?",
|
||||
"Ко је председник Француске?",
|
||||
# Serbian common and slang
|
||||
"Moj ћале је инжењер!",
|
||||
"Новак Ђоковић је најбољи тенисер света." "У Пироту има добрих кафана!",
|
||||
"Новак Ђоковић је најбољи тенисер света.",
|
||||
"У Пироту има добрих кафана!",
|
||||
"Музеј Николе Тесле се налази у Београду.",
|
||||
]
|
||||
|
|
253316
spacy/lang/sr/lemma_lookup.json
Executable file
File diff suppressed because it is too large
32
spacy/lang/sr/lemma_lookup_licence.txt
Normal file
|
@ -0,0 +1,32 @@
Copyright @InProceedings{ljubesic16-new,
author = {Nikola Ljubešić and Filip Klubička and Željko Agić and Ivo-Pavao Jazbec},
title = {New Inflectional Lexicons and Training Corpora for Improved Morphosyntactic Annotation of Croatian and Serbian},
booktitle = {Proceedings of the Tenth International Conference on Language Resources and Evaluation (LREC 2016)},
year = {2016},
date = {23-28},
location = {Portorož, Slovenia},
editor = {Nicoletta Calzolari (Conference Chair) and Khalid Choukri and Thierry Declerck and Sara Goggi and Marko Grobelnik and Bente Maegaard and Joseph Mariani and Helene Mazo and Asuncion Moreno and Jan Odijk and Stelios Piperidis},
publisher = {European Language Resources Association (ELRA)},
address = {Paris, France},
isbn = {978-2-9517408-9-1}
}

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

The licence of the Serbian lemmas was adopted from the Serbian lexicon:
- sr.lexicon (https://github.com/clarinsi/reldi-tagger/blob/master/sr.lexicon)

Changelog:
- Lexicon is translated into Cyrillic
- Word order is sorted
|
@ -24,7 +24,7 @@ class UkrainianDefaults(Language.Defaults):
|
|||
stop_words = STOP_WORDS
|
||||
|
||||
@classmethod
|
||||
def create_lemmatizer(cls, nlp=None):
|
||||
def create_lemmatizer(cls, nlp=None, **kwargs):
|
||||
return UkrainianLemmatizer()
|
||||
|
||||
|
||||
|
|
|
@ -208,6 +208,7 @@ class Language(object):
|
|||
"name": self.vocab.vectors.name,
|
||||
}
|
||||
self._meta["pipeline"] = self.pipe_names
|
||||
self._meta["labels"] = self.pipe_labels
|
||||
return self._meta
|
||||
|
||||
@meta.setter
|
||||
|
@ -248,6 +249,18 @@ class Language(object):
|
|||
"""
|
||||
return [pipe_name for pipe_name, _ in self.pipeline]
|
||||
|
||||
@property
|
||||
def pipe_labels(self):
|
||||
"""Get the labels set by the pipeline components, if available.
|
||||
|
||||
RETURNS (dict): Labels keyed by component name.
|
||||
"""
|
||||
labels = OrderedDict()
|
||||
for name, pipe in self.pipeline:
|
||||
if hasattr(pipe, "labels"):
|
||||
labels[name] = list(pipe.labels)
|
||||
return labels
|
||||
|
||||
def get_pipe(self, name):
|
||||
"""Get a pipeline component for a given component name.
|
||||
|
||||
|
|
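A short sketch of how the new pipe_labels property can be inspected; the model name and label sets below are assumptions.

import spacy

nlp = spacy.load("en_core_web_sm")
print(nlp.pipe_names)   # e.g. ['tagger', 'parser', 'ner']
for name, labels in nlp.pipe_labels.items():
    print(name, len(labels), "labels")  # labels keyed by component name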
127
spacy/lookups.py
|
@ -1,52 +1,157 @@
|
|||
# coding: utf8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from .util import SimpleFrozenDict
|
||||
import srsly
|
||||
from collections import OrderedDict
|
||||
|
||||
from .errors import Errors
|
||||
from .util import SimpleFrozenDict, ensure_path
|
||||
|
||||
|
||||
class Lookups(object):
|
||||
"""Container for large lookup tables and dictionaries, e.g. lemmatization
|
||||
data or tokenizer exception lists. Lookups are available via vocab.lookups,
|
||||
so they can be accessed before the pipeline components are applied (e.g.
|
||||
in the tokenizer and lemmatizer), as well as within the pipeline components
|
||||
via doc.vocab.lookups.
|
||||
|
||||
Important note: At the moment, this class only performs a very basic
|
||||
dictionary lookup. We're planning to replace this with a more efficient
|
||||
implementation. See #3971 for details.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self._tables = {}
|
||||
"""Initialize the Lookups object.
|
||||
|
||||
RETURNS (Lookups): The newly created object.
|
||||
"""
|
||||
self._tables = OrderedDict()
|
||||
|
||||
def __contains__(self, name):
|
||||
"""Check if the lookups contain a table of a given name. Delegates to
|
||||
Lookups.has_table.
|
||||
|
||||
name (unicode): Name of the table.
|
||||
RETURNS (bool): Whether a table of that name exists.
|
||||
"""
|
||||
return self.has_table(name)
|
||||
|
||||
def __len__(self):
|
||||
"""RETURNS (int): The number of tables in the lookups."""
|
||||
return len(self._tables)
|
||||
|
||||
@property
|
||||
def tables(self):
|
||||
"""RETURNS (list): Names of all tables in the lookups."""
|
||||
return list(self._tables.keys())
|
||||
|
||||
def add_table(self, name, data=SimpleFrozenDict()):
|
||||
"""Add a new table to the lookups. Raises an error if the table exists.
|
||||
|
||||
name (unicode): Unique name of table.
|
||||
data (dict): Optional data to add to the table.
|
||||
RETURNS (Table): The newly added table.
|
||||
"""
|
||||
if name in self.tables:
|
||||
raise ValueError("Table '{}' already exists".format(name))
|
||||
raise ValueError(Errors.E158.format(name=name))
|
||||
table = Table(name=name)
|
||||
table.update(data)
|
||||
self._tables[name] = table
|
||||
return table
|
||||
|
||||
def get_table(self, name):
|
||||
"""Get a table. Raises an error if the table doesn't exist.
|
||||
|
||||
name (unicode): Name of the table.
|
||||
RETURNS (Table): The table.
|
||||
"""
|
||||
if name not in self._tables:
|
||||
raise KeyError("Can't find table '{}'".format(name))
|
||||
raise KeyError(Errors.E159.format(name=name, tables=self.tables))
|
||||
return self._tables[name]
|
||||
|
||||
def remove_table(self, name):
|
||||
"""Remove a table. Raises an error if the table doesn't exist.
|
||||
|
||||
name (unicode): The name to remove.
|
||||
RETURNS (Table): The removed table.
|
||||
"""
|
||||
if name not in self._tables:
|
||||
raise KeyError(Errors.E159.format(name=name, tables=self.tables))
|
||||
return self._tables.pop(name)
|
||||
|
||||
def has_table(self, name):
|
||||
"""Check if the lookups contain a table of a given name.
|
||||
|
||||
name (unicode): Name of the table.
|
||||
RETURNS (bool): Whether a table of that name exists.
|
||||
"""
|
||||
return name in self._tables
|
||||
|
||||
def to_bytes(self, exclude=tuple(), **kwargs):
|
||||
raise NotImplementedError
|
||||
"""Serialize the lookups to a bytestring.
|
||||
|
||||
exclude (list): String names of serialization fields to exclude.
|
||||
RETURNS (bytes): The serialized Lookups.
|
||||
"""
|
||||
return srsly.msgpack_dumps(self._tables)
|
||||
|
||||
def from_bytes(self, bytes_data, exclude=tuple(), **kwargs):
|
||||
raise NotImplementedError
|
||||
"""Load the lookups from a bytestring.
|
||||
|
||||
def to_disk(self, path, exclude=tuple(), **kwargs):
|
||||
raise NotImplementedError
|
||||
exclude (list): String names of serialization fields to exclude.
|
||||
RETURNS (bytes): The loaded Lookups.
|
||||
"""
|
||||
self._tables = OrderedDict()
|
||||
msg = srsly.msgpack_loads(bytes_data)
|
||||
for key, value in msg.items():
|
||||
self._tables[key] = Table.from_dict(value)
|
||||
return self
|
||||
|
||||
def from_disk(self, path, exclude=tuple(), **kwargs):
|
||||
raise NotImplementedError
|
||||
def to_disk(self, path, **kwargs):
|
||||
"""Save the lookups to a directory as lookups.bin.
|
||||
|
||||
path (unicode / Path): The file path.
|
||||
"""
|
||||
if len(self._tables):
|
||||
path = ensure_path(path)
|
||||
filepath = path / "lookups.bin"
|
||||
with filepath.open("wb") as file_:
|
||||
file_.write(self.to_bytes())
|
||||
|
||||
def from_disk(self, path, **kwargs):
|
||||
"""Load lookups from a directory containing a lookups.bin.
|
||||
|
||||
path (unicode / Path): The file path.
|
||||
RETURNS (Lookups): The loaded lookups.
|
||||
"""
|
||||
path = ensure_path(path)
|
||||
filepath = path / "lookups.bin"
|
||||
if filepath.exists():
|
||||
with filepath.open("rb") as file_:
|
||||
data = file_.read()
|
||||
return self.from_bytes(data)
|
||||
return self
|
||||
|
||||
|
||||
class Table(dict):
|
||||
class Table(OrderedDict):
|
||||
"""A table in the lookups. Subclass of builtin dict that implements a
|
||||
slightly more consistent and unified API.
|
||||
"""
|
||||
@classmethod
|
||||
def from_dict(cls, data, name=None):
|
||||
self = cls(name=name)
|
||||
self.update(data)
|
||||
return self
|
||||
|
||||
def __init__(self, name=None):
|
||||
"""Initialize a new table.
|
||||
|
||||
name (unicode): Optional table name for reference.
|
||||
RETURNS (Table): The newly created object.
|
||||
"""
|
||||
OrderedDict.__init__(self)
|
||||
self.name = name
|
||||
|
||||
def set(self, key, value):
|
||||
"""Set new key/value pair. Same as table[key] = value."""
|
||||
self[key] = value
|
||||
|
|
|
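A minimal usage sketch of the Lookups API defined above; the table name and entries are invented.

from spacy.lookups import Lookups

lookups = Lookups()
table = lookups.add_table("lemma_lookup", {"went": "go", "going": "go"})
assert "lemma_lookup" in lookups and len(lookups) == 1
assert lookups.get_table("lemma_lookup")["went"] == "go"

# round-trip through the new (de)serialization methods
restored = Lookups().from_bytes(lookups.to_bytes())
assert restored.get_table("lemma_lookup").get("going") == "go"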
@ -67,7 +67,7 @@ class Pipe(object):
|
|||
"""
|
||||
self.require_model()
|
||||
predictions = self.predict([doc])
|
||||
if isinstance(predictions, tuple) and len(tuple) == 2:
|
||||
if isinstance(predictions, tuple) and len(predictions) == 2:
|
||||
scores, tensors = predictions
|
||||
self.set_annotations([doc], scores, tensor=tensors)
|
||||
else:
|
||||
|
@ -1062,8 +1062,15 @@ cdef class DependencyParser(Parser):
|
|||
|
||||
@property
|
||||
def labels(self):
|
||||
labels = set()
|
||||
# Get the labels from the model by looking at the available moves
|
||||
return tuple(set(move.split("-")[1] for move in self.move_names))
|
||||
for move in self.move_names:
|
||||
if "-" in move:
|
||||
label = move.split("-")[1]
|
||||
if "||" in label:
|
||||
label = label.split("||")[1]
|
||||
labels.add(label)
|
||||
return tuple(sorted(labels))
|
||||
|
||||
|
||||
cdef class EntityRecognizer(Parser):
|
||||
|
@ -1098,8 +1105,9 @@ cdef class EntityRecognizer(Parser):
|
|||
def labels(self):
|
||||
# Get the labels from the model by looking at the available moves, e.g.
|
||||
# B-PERSON, I-PERSON, L-PERSON, U-PERSON
|
||||
return tuple(set(move.split("-")[1] for move in self.move_names
|
||||
if move[0] in ("B", "I", "L", "U")))
|
||||
labels = set(move.split("-")[1] for move in self.move_names
|
||||
if move[0] in ("B", "I", "L", "U"))
|
||||
return tuple(sorted(labels))
|
||||
|
||||
|
||||
class EntityLinker(Pipe):
|
||||
|
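To make the label extraction above concrete, the same logic applied to a handful of assumed transition names:

ner_moves = ["B-PERSON", "I-PERSON", "L-PERSON", "U-GPE", "O"]
ner_labels = tuple(sorted(set(m.split("-")[1] for m in ner_moves if m[0] in ("B", "I", "L", "U"))))
print(ner_labels)  # ('GPE', 'PERSON')

parser_moves = ["S", "L-nsubj", "R-dobj", "R-dobj||prt"]
labels = set()
for move in parser_moves:
    if "-" in move:
        label = move.split("-")[1]
        if "||" in label:
            label = label.split("||")[1]
        labels.add(label)
print(tuple(sorted(labels)))  # ('dobj', 'nsubj', 'prt')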
@ -1275,7 +1283,7 @@ class EntityLinker(Pipe):
|
|||
# this will set all prior probabilities to 0 if they should be excluded from the model
|
||||
prior_probs = xp.asarray([c.prior_prob for c in candidates])
|
||||
if not self.cfg.get("incl_prior", True):
|
||||
prior_probs = xp.asarray([[0.0] for c in candidates])
|
||||
prior_probs = xp.asarray([0.0 for c in candidates])
|
||||
scores = prior_probs
|
||||
|
||||
# add in similarity from the context
|
||||
|
@ -1288,6 +1296,8 @@ class EntityLinker(Pipe):
|
|||
|
||||
# cosine similarity
|
||||
sims = xp.dot(entity_encodings, context_enc_t) / (norm_1 * norm_2)
|
||||
if sims.shape != prior_probs.shape:
|
||||
raise ValueError(Errors.E161)
|
||||
scores = prior_probs + sims - (prior_probs*sims)
|
||||
|
||||
# TODO: thresholding
|
||||
|
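The combined score is prior + sim - prior*sim, i.e. a probabilistic OR of the two signals, so either a strong prior or a strong context similarity alone can produce a high score; with illustrative numbers:

prior_prob = 0.6
sim = 0.5
score = prior_prob + sim - prior_prob * sim  # 0.6 + 0.5 - 0.30 = 0.80

When both inputs lie in [0, 1], the combined score also stays in [0, 1].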
@ -1361,7 +1371,16 @@ class Sentencizer(object):
|
|||
"""
|
||||
|
||||
name = "sentencizer"
|
||||
default_punct_chars = [".", "!", "?"]
|
||||
default_punct_chars = ['!', '.', '?', '։', '؟', '۔', '܀', '܁', '܂', '߹',
|
||||
'।', '॥', '၊', '။', '።', '፧', '፨', '᙮', '᜵', '᜶', '᠃', '᠉', '᥄',
|
||||
'᥅', '᪨', '᪩', '᪪', '᪫', '᭚', '᭛', '᭞', '᭟', '᰻', '᰼', '᱾', '᱿',
|
||||
'‼', '‽', '⁇', '⁈', '⁉', '⸮', '⸼', '꓿', '꘎', '꘏', '꛳', '꛷', '꡶',
|
||||
'꡷', '꣎', '꣏', '꤯', '꧈', '꧉', '꩝', '꩞', '꩟', '꫰', '꫱', '꯫', '﹒',
|
||||
'﹖', '﹗', '!', '.', '?', '𐩖', '𐩗', '𑁇', '𑁈', '𑂾', '𑂿', '𑃀',
|
||||
'𑃁', '𑅁', '𑅂', '𑅃', '𑇅', '𑇆', '𑇍', '𑇞', '𑇟', '𑈸', '𑈹', '𑈻', '𑈼',
|
||||
'𑊩', '𑑋', '𑑌', '𑗂', '𑗃', '𑗉', '𑗊', '𑗋', '𑗌', '𑗍', '𑗎', '𑗏', '𑗐',
|
||||
'𑗑', '𑗒', '𑗓', '𑗔', '𑗕', '𑗖', '𑗗', '𑙁', '𑙂', '𑜼', '𑜽', '𑜾', '𑩂',
|
||||
'𑩃', '𑪛', '𑪜', '𑱁', '𑱂', '𖩮', '𖩯', '𖫵', '𖬷', '𖬸', '𖭄', '𛲟', '𝪈']
|
||||
|
||||
def __init__(self, punct_chars=None, **kwargs):
|
||||
"""Initialize the sentencizer.
|
||||
|
@ -1372,7 +1391,10 @@ class Sentencizer(object):
|
|||
|
||||
DOCS: https://spacy.io/api/sentencizer#init
|
||||
"""
|
||||
self.punct_chars = punct_chars or self.default_punct_chars
|
||||
if punct_chars:
|
||||
self.punct_chars = set(punct_chars)
|
||||
else:
|
||||
self.punct_chars = set(self.default_punct_chars)
|
||||
|
||||
def __call__(self, doc):
|
||||
"""Apply the sentencizer to a Doc and set Token.is_sent_start.
|
||||
|
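A small sketch of the updated Sentencizer in use; storing punct_chars as a set deduplicates the characters and keeps the per-token membership check cheap. The extra character below is an assumption for illustration.

import spacy
from spacy.pipeline import Sentencizer

nlp = spacy.blank("en")
nlp.add_pipe(Sentencizer(punct_chars=[".", "!", "?", "。"]))
doc = nlp("First sentence. Second one! Third?")
print([sent.text for sent in doc.sents])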
@ -1404,7 +1426,7 @@ class Sentencizer(object):
|
|||
|
||||
DOCS: https://spacy.io/api/sentencizer#to_bytes
|
||||
"""
|
||||
return srsly.msgpack_dumps({"punct_chars": self.punct_chars})
|
||||
return srsly.msgpack_dumps({"punct_chars": list(self.punct_chars)})
|
||||
|
||||
def from_bytes(self, bytes_data, **kwargs):
|
||||
"""Load the sentencizer from a bytestring.
|
||||
|
@ -1415,7 +1437,7 @@ class Sentencizer(object):
|
|||
DOCS: https://spacy.io/api/sentencizer#from_bytes
|
||||
"""
|
||||
cfg = srsly.msgpack_loads(bytes_data)
|
||||
self.punct_chars = cfg.get("punct_chars", self.default_punct_chars)
|
||||
self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars))
|
||||
return self
|
||||
|
||||
def to_disk(self, path, exclude=tuple(), **kwargs):
|
||||
|
@ -1425,7 +1447,7 @@ class Sentencizer(object):
|
|||
"""
|
||||
path = util.ensure_path(path)
|
||||
path = path.with_suffix(".json")
|
||||
srsly.write_json(path, {"punct_chars": self.punct_chars})
|
||||
srsly.write_json(path, {"punct_chars": list(self.punct_chars)})
|
||||
|
||||
|
||||
def from_disk(self, path, exclude=tuple(), **kwargs):
|
||||
|
@ -1436,7 +1458,7 @@ class Sentencizer(object):
|
|||
path = util.ensure_path(path)
|
||||
path = path.with_suffix(".json")
|
||||
cfg = srsly.read_json(path)
|
||||
self.punct_chars = cfg.get("punct_chars", self.default_punct_chars)
|
||||
self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars))
|
||||
return self
|
||||
|
||||
|
||||
|
|
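The set-to-list conversions in the (de)serialization methods above are needed because neither msgpack nor JSON has a set type; a round-trip sketch:

import srsly

punct_chars = {"!", ".", "?"}
data = srsly.msgpack_dumps({"punct_chars": list(punct_chars)})
restored = set(srsly.msgpack_loads(data)["punct_chars"])
assert restored == punct_chars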
|
@ -103,6 +103,11 @@ def he_tokenizer():
|
|||
return get_lang_class("he").Defaults.create_tokenizer()
|
||||
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def hr_tokenizer():
|
||||
return get_lang_class("hr").Defaults.create_tokenizer()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def hu_tokenizer():
|
||||
return get_lang_class("hu").Defaults.create_tokenizer()
|
||||
|
|
|
@ -99,6 +99,53 @@ def test_doc_retokenize_spans_merge_tokens(en_tokenizer):
|
|||
assert doc[0].ent_type_ == "GPE"
|
||||
|
||||
|
||||
def test_doc_retokenize_spans_merge_tokens_default_attrs(en_tokenizer):
|
||||
text = "The players start."
|
||||
heads = [1, 1, 0, -1]
|
||||
tokens = en_tokenizer(text)
|
||||
doc = get_doc(
|
||||
tokens.vocab,
|
||||
words=[t.text for t in tokens],
|
||||
tags=["DT", "NN", "VBZ", "."],
|
||||
pos=["DET", "NOUN", "VERB", "PUNCT"],
|
||||
heads=heads,
|
||||
)
|
||||
assert len(doc) == 4
|
||||
assert doc[0].text == "The"
|
||||
assert doc[0].tag_ == "DT"
|
||||
assert doc[0].pos_ == "DET"
|
||||
with doc.retokenize() as retokenizer:
|
||||
retokenizer.merge(doc[0:2])
|
||||
assert len(doc) == 3
|
||||
assert doc[0].text == "The players"
|
||||
assert doc[0].tag_ == "NN"
|
||||
assert doc[0].pos_ == "NOUN"
|
||||
    assert doc[0].lemma_ == "The players"
    doc = get_doc(
        tokens.vocab,
        words=[t.text for t in tokens],
        tags=["DT", "NN", "VBZ", "."],
        pos=["DET", "NOUN", "VERB", "PUNCT"],
        heads=heads,
    )
    assert len(doc) == 4
    assert doc[0].text == "The"
    assert doc[0].tag_ == "DT"
    assert doc[0].pos_ == "DET"
    with doc.retokenize() as retokenizer:
        retokenizer.merge(doc[0:2])
        retokenizer.merge(doc[2:4])
    assert len(doc) == 2
    assert doc[0].text == "The players"
    assert doc[0].tag_ == "NN"
    assert doc[0].pos_ == "NOUN"
    assert doc[0].lemma_ == "The players"
    assert doc[1].text == "start ."
    assert doc[1].tag_ == "VBZ"
    assert doc[1].pos_ == "VERB"
    assert doc[1].lemma_ == "start ."


def test_doc_retokenize_spans_merge_heads(en_tokenizer):
    text = "I found a pilates class near work."
    heads = [1, 0, 2, 1, -3, -1, -1, -6]

@ -182,7 +229,7 @@ def test_doc_retokenize_spans_entity_merge(en_tokenizer):
    assert len(doc) == 15


def test_doc_retokenize_spans_entity_merge_iob():
def test_doc_retokenize_spans_entity_merge_iob(en_vocab):
    # Test entity IOB stays consistent after merging
    words = ["a", "b", "c", "d", "e"]
    doc = Doc(Vocab(), words=words)

@ -195,10 +242,23 @@ def test_doc_retokenize_spans_entity_merge_iob():
    assert doc[2].ent_iob_ == "I"
    assert doc[3].ent_iob_ == "B"
    with doc.retokenize() as retokenizer:
        retokenizer.merge(doc[0:1])
        retokenizer.merge(doc[0:2])
    assert len(doc) == len(words) - 1
    assert doc[0].ent_iob_ == "B"
    assert doc[1].ent_iob_ == "I"

    # Test that IOB stays consistent with provided IOB
    words = ["a", "b", "c", "d", "e"]
    doc = Doc(Vocab(), words=words)
    with doc.retokenize() as retokenizer:
        attrs = {"ent_type": "ent-abc", "ent_iob": 1}
        retokenizer.merge(doc[0:3], attrs=attrs)
        retokenizer.merge(doc[3:5], attrs=attrs)
    assert doc[0].ent_iob_ == "B"
    assert doc[1].ent_iob_ == "I"

    # if no parse/heads, the first word in the span is the root and provides
    # default values
    words = ["a", "b", "c", "d", "e", "f", "g", "h", "i"]
    doc = Doc(Vocab(), words=words)
    doc.ents = [

@ -215,7 +275,47 @@ def test_doc_retokenize_spans_entity_merge_iob():
        retokenizer.merge(doc[7:9])
    assert len(doc) == 6
    assert doc[3].ent_iob_ == "B"
    assert doc[4].ent_iob_ == "I"
    assert doc[3].ent_type_ == "ent-de"
    assert doc[4].ent_iob_ == "B"
    assert doc[4].ent_type_ == "ent-fg"

    # if there is a parse, span.root provides default values
    words = ["a", "b", "c", "d", "e", "f", "g", "h", "i"]
    heads = [0, -1, 1, -3, -4, -5, -1, -7, -8]
    ents = [(3, 5, "ent-de"), (5, 7, "ent-fg")]
    deps = ["dep"] * len(words)
    en_vocab.strings.add("ent-de")
    en_vocab.strings.add("ent-fg")
    en_vocab.strings.add("dep")
    doc = get_doc(en_vocab, words=words, heads=heads, deps=deps, ents=ents)
    assert doc[2:4].root == doc[3]  # root of 'c d' is d
    assert doc[4:6].root == doc[4]  # root of 'e f' is e
    with doc.retokenize() as retokenizer:
        retokenizer.merge(doc[2:4])
        retokenizer.merge(doc[4:6])
        retokenizer.merge(doc[7:9])
    assert len(doc) == 6
    assert doc[2].ent_iob_ == "B"
    assert doc[2].ent_type_ == "ent-de"
    assert doc[3].ent_iob_ == "I"
    assert doc[3].ent_type_ == "ent-de"
    assert doc[4].ent_iob_ == "B"
    assert doc[4].ent_type_ == "ent-fg"

    # check that B is preserved if span[start] is B
    words = ["a", "b", "c", "d", "e", "f", "g", "h", "i"]
    heads = [0, -1, 1, 1, -4, -5, -1, -7, -8]
    ents = [(3, 5, "ent-de"), (5, 7, "ent-de")]
    deps = ["dep"] * len(words)
    doc = get_doc(en_vocab, words=words, heads=heads, deps=deps, ents=ents)
    with doc.retokenize() as retokenizer:
        retokenizer.merge(doc[3:5])
        retokenizer.merge(doc[5:7])
    assert len(doc) == 7
    assert doc[3].ent_iob_ == "B"
    assert doc[3].ent_type_ == "ent-de"
    assert doc[4].ent_iob_ == "B"
    assert doc[4].ent_type_ == "ent-de"


def test_doc_retokenize_spans_sentence_update_after_merge(en_tokenizer):
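To illustrate the retokenizer API these tests exercise, a minimal sketch (assuming a blank English pipeline as on this branch; the sentence and the overridden lemma are made-up values, and the lowercase attribute keys follow the tests above):

from spacy.lang.en import English

nlp = English()
doc = nlp("New York is big")
with doc.retokenize() as retokenizer:
    # Merge the first two tokens into one and override its lemma
    retokenizer.merge(doc[0:2], attrs={"lemma": "new york"})
assert len(doc) == 3
assert doc[0].text == "New York"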
@ -173,6 +173,21 @@ def test_span_as_doc(doc):
    assert span_doc[0].idx == 0


def test_span_as_doc_user_data(doc):
    """Test that the user_data can be preserved (but not by default)."""
    my_key = "my_info"
    my_value = 342
    doc.user_data[my_key] = my_value

    span = doc[4:10]
    span_doc_with = span.as_doc(copy_user_data=True)
    span_doc_without = span.as_doc()

    assert doc.user_data.get(my_key, None) is my_value
    assert span_doc_with.user_data.get(my_key, None) is my_value
    assert span_doc_without.user_data.get(my_key, None) is None


def test_span_string_label_kb_id(doc):
    span = Span(doc, 0, 1, label="hello", kb_id="Q342")
    assert span.label_ == "hello"
@ -133,3 +133,9 @@ def test_en_tokenizer_splits_em_dash_infix(en_tokenizer):
    assert tokens[6].text == "Puddleton"
    assert tokens[7].text == "?"
    assert tokens[8].text == "\u2014"


@pytest.mark.parametrize("text,length", [("_MATH_", 3), ("_MATH_.", 4)])
def test_final_period(en_tokenizer, text, length):
    tokens = en_tokenizer(text)
    assert len(tokens) == length
20  spacy/tests/lang/hr/test_lemma.py  Normal file

@ -0,0 +1,20 @@
# coding: utf-8
from __future__ import unicode_literals

import pytest


@pytest.mark.parametrize(
    "string,lemma",
    [
        ("trčao", "trčati"),
        ("adekvatnim", "adekvatan"),
        ("dekontaminacijama", "dekontaminacija"),
        ("filologovih", "filologov"),
        ("je", "biti"),
        ("se", "sebe"),
    ],
)
def test_hr_lemmatizer_lookup_assigns(hr_tokenizer, string, lemma):
    tokens = hr_tokenizer(string)
    assert tokens[0].lemma_ == lemma
20  spacy/tests/lang/sr/test_lemmatizer.py  Normal file

@ -0,0 +1,20 @@
# coding: utf-8
from __future__ import unicode_literals

import pytest


@pytest.mark.parametrize(
    "string,lemma",
    [
        ("најадекватнији", "адекватан"),
        ("матурирао", "матурирати"),
        ("планираћемо", "планирати"),
        ("певају", "певати"),
        ("нама", "ми"),
        ("се", "себе"),
    ],
)
def test_sr_lemmatizer_lookup_assigns(sr_tokenizer, string, lemma):
    tokens = sr_tokenizer(string)
    assert tokens[0].lemma_ == lemma
@ -68,3 +68,20 @@ def test_add_label_deserializes_correctly():
    assert ner1.moves.n_moves == ner2.moves.n_moves
    for i in range(ner1.moves.n_moves):
        assert ner1.moves.get_class_name(i) == ner2.moves.get_class_name(i)


@pytest.mark.parametrize(
    "pipe_cls,n_moves", [(DependencyParser, 5), (EntityRecognizer, 4)]
)
def test_add_label_get_label(pipe_cls, n_moves):
    """Test that added labels are returned correctly. This test was added to
    test for a bug in DependencyParser.labels that'd cause it to fail when
    splitting the move names.
    """
    labels = ["A", "B", "C"]
    pipe = pipe_cls(Vocab())
    for label in labels:
        pipe.add_label(label)
    assert len(pipe.move_names) == len(labels) * n_moves
    pipe_labels = sorted(list(pipe.labels))
    assert pipe_labels == labels
@ -128,3 +128,19 @@ def test_pipe_base_class_add_label(nlp, component):
        assert label in pipe.labels
    else:
        assert pipe.labels == (label,)


def test_pipe_labels(nlp):
    input_labels = {
        "ner": ["PERSON", "ORG", "GPE"],
        "textcat": ["POSITIVE", "NEGATIVE"],
    }
    for name, labels in input_labels.items():
        pipe = nlp.create_pipe(name)
        for label in labels:
            pipe.add_label(label)
        assert len(pipe.labels) == len(labels)
        nlp.add_pipe(pipe)
    assert len(nlp.pipe_labels) == len(input_labels)
    for name, labels in nlp.pipe_labels.items():
        assert sorted(input_labels[name]) == sorted(labels)
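A rough sketch of the new nlp.pipe_labels property exercised by the test above (assuming a blank English pipeline; the label is illustrative):

from spacy.lang.en import English

nlp = English()
ner = nlp.create_pipe("ner")
ner.add_label("PERSON")
nlp.add_pipe(ner)
# pipe_labels maps each pipeline component name to the labels it defines
assert "PERSON" in nlp.pipe_labels["ner"]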
@ -81,7 +81,7 @@ def test_sentencizer_custom_punct(en_vocab, punct_chars, words, sent_starts, n_s
def test_sentencizer_serialize_bytes(en_vocab):
    punct_chars = [".", "~", "+"]
    sentencizer = Sentencizer(punct_chars=punct_chars)
    assert sentencizer.punct_chars == punct_chars
    assert sentencizer.punct_chars == set(punct_chars)
    bytes_data = sentencizer.to_bytes()
    new_sentencizer = Sentencizer().from_bytes(bytes_data)
    assert new_sentencizer.punct_chars == punct_chars
    assert new_sentencizer.punct_chars == set(punct_chars)
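A minimal usage sketch of the Sentencizer with custom punctuation, which after this change stores punct_chars as a set (the text and punctuation characters below are illustrative):

from spacy.lang.en import English
from spacy.pipeline import Sentencizer

nlp = English()
sentencizer = Sentencizer(punct_chars=[".", "?", "!", "~"])
nlp.add_pipe(sentencizer)
doc = nlp("This is one~ This is another.")
print([sent.text for sent in doc.sents])  # expect two sentences
print(sentencizer.punct_chars)            # stored as a set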
@ -13,26 +13,25 @@ from spacy.lemmatizer import Lemmatizer
from spacy.symbols import ORTH, LEMMA, POS, VERB, VerbForm_part


@pytest.mark.xfail
def test_issue1061():
    '''Test special-case works after tokenizing. Was caching problem.'''
    text = 'I like _MATH_ even _MATH_ when _MATH_, except when _MATH_ is _MATH_! but not _MATH_.'
    """Test special-case works after tokenizing. Was caching problem."""
    text = "I like _MATH_ even _MATH_ when _MATH_, except when _MATH_ is _MATH_! but not _MATH_."
    tokenizer = English.Defaults.create_tokenizer()
    doc = tokenizer(text)
    assert 'MATH' in [w.text for w in doc]
    assert '_MATH_' not in [w.text for w in doc]
    assert "MATH" in [w.text for w in doc]
    assert "_MATH_" not in [w.text for w in doc]

    tokenizer.add_special_case('_MATH_', [{ORTH: '_MATH_'}])
    tokenizer.add_special_case("_MATH_", [{ORTH: "_MATH_"}])
    doc = tokenizer(text)
    assert '_MATH_' in [w.text for w in doc]
    assert 'MATH' not in [w.text for w in doc]
    assert "_MATH_" in [w.text for w in doc]
    assert "MATH" not in [w.text for w in doc]

    # For sanity, check it works when pipeline is clean.
    tokenizer = English.Defaults.create_tokenizer()
    tokenizer.add_special_case('_MATH_', [{ORTH: '_MATH_'}])
    tokenizer.add_special_case("_MATH_", [{ORTH: "_MATH_"}])
    doc = tokenizer(text)
    assert '_MATH_' in [w.text for w in doc]
    assert 'MATH' not in [w.text for w in doc]
    assert "_MATH_" in [w.text for w in doc]
    assert "MATH" not in [w.text for w in doc]


@pytest.mark.xfail(
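For context, a minimal sketch of the special-case mechanism that test_issue1061 exercises, using the "gimme" example from the spaCy docs (the input string is illustrative):

from spacy.lang.en import English
from spacy.symbols import ORTH

nlp = English()
print([t.text for t in nlp("gimme that")])  # ['gimme', 'that']
# The ORTH values of a special case must concatenate back to the original string
nlp.tokenizer.add_special_case("gimme", [{ORTH: "gim"}, {ORTH: "me"}])
print([t.text for t in nlp("gimme that")])  # ['gim', 'me', 'that']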
@ -1,7 +1,6 @@
# coding: utf8
from __future__ import unicode_literals

import pytest
from spacy.matcher import Matcher
from spacy.tokens import Doc

@ -1,7 +1,6 @@
# coding: utf8
from __future__ import unicode_literals

import pytest
from spacy.matcher import Matcher
from spacy.tokens import Doc

@ -1,7 +1,6 @@
# coding: utf8
from __future__ import unicode_literals

import pytest
from spacy.matcher import PhraseMatcher
from spacy.tokens import Doc

@ -1,7 +1,6 @@
# coding: utf8
from __future__ import unicode_literals

import pytest
from spacy.matcher import Matcher
from spacy.tokens import Doc
@ -2,44 +2,37 @@
from __future__ import unicode_literals

from spacy.lang.en import English

import spacy
from spacy.tokenizer import Tokenizer
from spacy import util

from spacy.tests.util import make_tempdir
from ..util import make_tempdir


def test_issue4190():
    test_string = "Test c."

    # Load default language
    nlp_1 = English()
    doc_1a = nlp_1(test_string)
    result_1a = [token.text for token in doc_1a]
    result_1a = [token.text for token in doc_1a]  # noqa: F841

    # Modify tokenizer
    customize_tokenizer(nlp_1)
    doc_1b = nlp_1(test_string)
    result_1b = [token.text for token in doc_1b]

    # Save and Reload
    with make_tempdir() as model_dir:
        nlp_1.to_disk(model_dir)
        nlp_2 = spacy.load(model_dir)
        nlp_2 = util.load_model(model_dir)
    # This should be the modified tokenizer
    doc_2 = nlp_2(test_string)
    result_2 = [token.text for token in doc_2]

    assert result_1b == result_2


def customize_tokenizer(nlp):
    prefix_re = spacy.util.compile_prefix_regex(nlp.Defaults.prefixes)
    suffix_re = spacy.util.compile_suffix_regex(nlp.Defaults.suffixes)
    infix_re = spacy.util.compile_infix_regex(nlp.Defaults.infixes)

    # remove all exceptions where a single letter is followed by a period (e.g. 'h.')
    prefix_re = util.compile_prefix_regex(nlp.Defaults.prefixes)
    suffix_re = util.compile_suffix_regex(nlp.Defaults.suffixes)
    infix_re = util.compile_infix_regex(nlp.Defaults.infixes)
    # Remove all exceptions where a single letter is followed by a period (e.g. 'h.')
    exceptions = {
        k: v
        for k, v in dict(nlp.Defaults.tokenizer_exceptions).items()

@ -53,5 +46,4 @@ def customize_tokenizer(nlp):
        infix_finditer=infix_re.finditer,
        token_match=nlp.tokenizer.token_match,
    )

    nlp.tokenizer = new_tokenizer
12  spacy/tests/regression/test_issue4272.py  Normal file

@ -0,0 +1,12 @@
# coding: utf8
from __future__ import unicode_literals

from spacy.lang.el import Greek


def test_issue4272():
    """Test that lookup table can be accessed from Token.lemma if no POS tags
    are available."""
    nlp = Greek()
    doc = nlp("Χθες")
    assert doc[0].lemma_
28  spacy/tests/regression/test_issue4278.py  Normal file

@ -0,0 +1,28 @@
# coding: utf8
from __future__ import unicode_literals

import pytest
from spacy.language import Language
from spacy.pipeline import Pipe


class DummyPipe(Pipe):
    def __init__(self):
        self.model = "dummy_model"

    def predict(self, docs):
        return ([1, 2, 3], [4, 5, 6])

    def set_annotations(self, docs, scores, tensor=None):
        return docs


@pytest.fixture
def nlp():
    return Language()


def test_multiple_predictions(nlp):
    doc = nlp.make_doc("foo")
    dummy_pipe = DummyPipe()
    dummy_pipe(doc)
@ -41,8 +41,8 @@ def test_serialize_parser_roundtrip_bytes(en_vocab, Parser):
    parser.model, _ = parser.Model(10)
    new_parser = Parser(en_vocab)
    new_parser.model, _ = new_parser.Model(10)
    new_parser = new_parser.from_bytes(parser.to_bytes())
    assert new_parser.to_bytes() == parser.to_bytes()
    new_parser = new_parser.from_bytes(parser.to_bytes(exclude=["vocab"]))
    assert new_parser.to_bytes(exclude=["vocab"]) == parser.to_bytes(exclude=["vocab"])


@pytest.mark.parametrize("Parser", test_parsers)

@ -55,8 +55,8 @@ def test_serialize_parser_roundtrip_disk(en_vocab, Parser):
    parser_d = Parser(en_vocab)
    parser_d.model, _ = parser_d.Model(0)
    parser_d = parser_d.from_disk(file_path)
    parser_bytes = parser.to_bytes(exclude=["model"])
    parser_d_bytes = parser_d.to_bytes(exclude=["model"])
    parser_bytes = parser.to_bytes(exclude=["model", "vocab"])
    parser_d_bytes = parser_d.to_bytes(exclude=["model", "vocab"])
    assert parser_bytes == parser_d_bytes

@ -64,7 +64,7 @@ def test_to_from_bytes(parser, blank_parser):
    assert parser.model is not True
    assert blank_parser.model is True
    assert blank_parser.moves.n_moves != parser.moves.n_moves
    bytes_data = parser.to_bytes()
    bytes_data = parser.to_bytes(exclude=["vocab"])
    blank_parser.from_bytes(bytes_data)
    assert blank_parser.model is not True
    assert blank_parser.moves.n_moves == parser.moves.n_moves

@ -97,9 +97,9 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
def test_serialize_tensorizer_roundtrip_bytes(en_vocab):
    tensorizer = Tensorizer(en_vocab)
    tensorizer.model = tensorizer.Model()
    tensorizer_b = tensorizer.to_bytes()
    tensorizer_b = tensorizer.to_bytes(exclude=["vocab"])
    new_tensorizer = Tensorizer(en_vocab).from_bytes(tensorizer_b)
    assert new_tensorizer.to_bytes() == tensorizer_b
    assert new_tensorizer.to_bytes(exclude=["vocab"]) == tensorizer_b


def test_serialize_tensorizer_roundtrip_disk(en_vocab):

@ -109,13 +109,15 @@ def test_serialize_tensorizer_roundtrip_disk(en_vocab):
    file_path = d / "tensorizer"
    tensorizer.to_disk(file_path)
    tensorizer_d = Tensorizer(en_vocab).from_disk(file_path)
    assert tensorizer.to_bytes() == tensorizer_d.to_bytes()
    assert tensorizer.to_bytes(exclude=["vocab"]) == tensorizer_d.to_bytes(
        exclude=["vocab"]
    )


def test_serialize_textcat_empty(en_vocab):
    # See issue #1105
    textcat = TextCategorizer(en_vocab, labels=["ENTITY", "ACTION", "MODIFIER"])
    textcat.to_bytes()
    textcat.to_bytes(exclude=["vocab"])


@pytest.mark.parametrize("Parser", test_parsers)

@ -128,13 +130,17 @@ def test_serialize_pipe_exclude(en_vocab, Parser):
    parser = Parser(en_vocab)
    parser.model, _ = parser.Model(0)
    parser.cfg["foo"] = "bar"
    new_parser = get_new_parser().from_bytes(parser.to_bytes())
    new_parser = get_new_parser().from_bytes(parser.to_bytes(exclude=["vocab"]))
    assert "foo" in new_parser.cfg
    new_parser = get_new_parser().from_bytes(parser.to_bytes(), exclude=["cfg"])
    new_parser = get_new_parser().from_bytes(
        parser.to_bytes(exclude=["vocab"]), exclude=["cfg"]
    )
    assert "foo" not in new_parser.cfg
    new_parser = get_new_parser().from_bytes(parser.to_bytes(exclude=["cfg"]))
    new_parser = get_new_parser().from_bytes(
        parser.to_bytes(exclude=["cfg"]), exclude=["vocab"]
    )
    assert "foo" not in new_parser.cfg
    with pytest.raises(ValueError):
        parser.to_bytes(cfg=False)
        parser.to_bytes(cfg=False, exclude=["vocab"])
    with pytest.raises(ValueError):
        get_new_parser().from_bytes(parser.to_bytes(), cfg=False)
        get_new_parser().from_bytes(parser.to_bytes(exclude=["vocab"]), cfg=False)
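A rough sketch of the exclude pattern these serialization tests rely on, applied to a bare Vocab (mirroring test_serialize_vocab below; excluding "lookups" is an assumption carried over from that test, since lookups serialization is marked xfail on Python 3.5):

from spacy.vocab import Vocab

vocab = Vocab()
text_hash = vocab.strings.add("serialization")
vocab_bytes = vocab.to_bytes(exclude=["lookups"])
new_vocab = Vocab().from_bytes(vocab_bytes)
assert new_vocab.strings[text_hash] == "serialization"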
@ -12,12 +12,14 @@ test_strings = [([], []), (["rats", "are", "cute"], ["i", "like", "rats"])]
test_strings_attrs = [(["rats", "are", "cute"], "Hello")]


@pytest.mark.xfail
@pytest.mark.parametrize("text", ["rat"])
def test_serialize_vocab(en_vocab, text):
    text_hash = en_vocab.strings.add(text)
    vocab_bytes = en_vocab.to_bytes()
    vocab_bytes = en_vocab.to_bytes(exclude=["lookups"])
    new_vocab = Vocab().from_bytes(vocab_bytes)
    assert new_vocab.strings[text_hash] == text
    assert new_vocab.to_bytes(exclude=["lookups"]) == vocab_bytes


@pytest.mark.parametrize("strings1,strings2", test_strings)
@ -3,6 +3,9 @@ from __future__ import unicode_literals

import pytest
from spacy.lookups import Lookups
from spacy.vocab import Vocab

from ..util import make_tempdir


def test_lookups_api():

@ -10,6 +13,7 @@ def test_lookups_api():
    data = {"foo": "bar", "hello": "world"}
    lookups = Lookups()
    lookups.add_table(table_name, data)
    assert len(lookups) == 1
    assert table_name in lookups
    assert lookups.has_table(table_name)
    table = lookups.get_table(table_name)

@ -22,5 +26,91 @@ def test_lookups_api():
    assert len(table) == 3
    with pytest.raises(KeyError):
        lookups.get_table("xyz")
    # with pytest.raises(ValueError):
    #     lookups.add_table(table_name)
    with pytest.raises(ValueError):
        lookups.add_table(table_name)
    table = lookups.remove_table(table_name)
    assert table.name == table_name
    assert len(lookups) == 0
    assert table_name not in lookups
    with pytest.raises(KeyError):
        lookups.get_table(table_name)


# This fails on Python 3.5
@pytest.mark.xfail
def test_lookups_to_from_bytes():
    lookups = Lookups()
    lookups.add_table("table1", {"foo": "bar", "hello": "world"})
    lookups.add_table("table2", {"a": 1, "b": 2, "c": 3})
    lookups_bytes = lookups.to_bytes()
    new_lookups = Lookups()
    new_lookups.from_bytes(lookups_bytes)
    assert len(new_lookups) == 2
    assert "table1" in new_lookups
    assert "table2" in new_lookups
    table1 = new_lookups.get_table("table1")
    assert len(table1) == 2
    assert table1.get("foo") == "bar"
    table2 = new_lookups.get_table("table2")
    assert len(table2) == 3
    assert table2.get("b") == 2
    assert new_lookups.to_bytes() == lookups_bytes


# This fails on Python 3.5
@pytest.mark.xfail
def test_lookups_to_from_disk():
    lookups = Lookups()
    lookups.add_table("table1", {"foo": "bar", "hello": "world"})
    lookups.add_table("table2", {"a": 1, "b": 2, "c": 3})
    with make_tempdir() as tmpdir:
        lookups.to_disk(tmpdir)
        new_lookups = Lookups()
        new_lookups.from_disk(tmpdir)
    assert len(new_lookups) == 2
    assert "table1" in new_lookups
    assert "table2" in new_lookups
    table1 = new_lookups.get_table("table1")
    assert len(table1) == 2
    assert table1.get("foo") == "bar"
    table2 = new_lookups.get_table("table2")
    assert len(table2) == 3
    assert table2.get("b") == 2


# This fails on Python 3.5
@pytest.mark.xfail
def test_lookups_to_from_bytes_via_vocab():
    table_name = "test"
    vocab = Vocab()
    vocab.lookups.add_table(table_name, {"foo": "bar", "hello": "world"})
    assert len(vocab.lookups) == 1
    assert table_name in vocab.lookups
    vocab_bytes = vocab.to_bytes()
    new_vocab = Vocab()
    new_vocab.from_bytes(vocab_bytes)
    assert len(new_vocab.lookups) == 1
    assert table_name in new_vocab.lookups
    table = new_vocab.lookups.get_table(table_name)
    assert len(table) == 2
    assert table.get("hello") == "world"
    assert new_vocab.to_bytes() == vocab_bytes


# This fails on Python 3.5
@pytest.mark.xfail
def test_lookups_to_from_disk_via_vocab():
    table_name = "test"
    vocab = Vocab()
    vocab.lookups.add_table(table_name, {"foo": "bar", "hello": "world"})
    assert len(vocab.lookups) == 1
    assert table_name in vocab.lookups
    with make_tempdir() as tmpdir:
        vocab.to_disk(tmpdir)
        new_vocab = Vocab()
        new_vocab.from_disk(tmpdir)
    assert len(new_vocab.lookups) == 1
    assert table_name in new_vocab.lookups
    table = new_vocab.lookups.get_table(table_name)
    assert len(table) == 2
    assert table.get("hello") == "world"
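A minimal sketch of the Lookups API exercised above (table name and entries are made up; note that the byte and disk round-trips are marked xfail on Python 3.5 in these tests):

from spacy.lookups import Lookups

lookups = Lookups()
lookups.add_table("plurals", {"mouse": "mice", "goose": "geese"})
table = lookups.get_table("plurals")
assert table.get("mouse") == "mice"
data = lookups.to_bytes()        # serialized alongside the vocab after this change
restored = Lookups()
restored.from_bytes(data)
assert restored.get_table("plurals").get("goose") == "geese"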
@ -16,10 +16,10 @@ cdef class Tokenizer:
    cdef PreshMap _specials
    cpdef readonly Vocab vocab

    cdef public object token_match
    cdef public object prefix_search
    cdef public object suffix_search
    cdef public object infix_finditer
    cdef object _token_match
    cdef object _prefix_search
    cdef object _suffix_search
    cdef object _infix_finditer
    cdef object _rules

    cpdef Doc tokens_from_list(self, list strings)
@ -61,6 +61,38 @@ cdef class Tokenizer:
        for chunk, substrings in sorted(rules.items()):
            self.add_special_case(chunk, substrings)

    property token_match:
        def __get__(self):
            return self._token_match

        def __set__(self, token_match):
            self._token_match = token_match
            self._flush_cache()

    property prefix_search:
        def __get__(self):
            return self._prefix_search

        def __set__(self, prefix_search):
            self._prefix_search = prefix_search
            self._flush_cache()

    property suffix_search:
        def __get__(self):
            return self._suffix_search

        def __set__(self, suffix_search):
            self._suffix_search = suffix_search
            self._flush_cache()

    property infix_finditer:
        def __get__(self):
            return self._infix_finditer

        def __set__(self, infix_finditer):
            self._infix_finditer = infix_finditer
            self._flush_cache()

    def __reduce__(self):
        args = (self.vocab,
                self._rules,

@ -141,9 +173,23 @@ cdef class Tokenizer:
        for text in texts:
            yield self(text)

    def _flush_cache(self):
        self._reset_cache([key for key in self._cache if not key in self._specials])

    def _reset_cache(self, keys):
        for k in keys:
            del self._cache[k]
            if not k in self._specials:
                cached = <_Cached*>self._cache.get(k)
                if cached is not NULL:
                    self.mem.free(cached)

    def _reset_specials(self):
        for k in self._specials:
            cached = <_Cached*>self._specials.get(k)
            del self._specials[k]
            if cached is not NULL:
                self.mem.free(cached)

    cdef int _try_cache(self, hash_t key, Doc tokens) except -1:
        cached = <_Cached*>self._cache.get(key)

@ -183,6 +229,9 @@ cdef class Tokenizer:
        while string and len(string) != last_size:
            if self.token_match and self.token_match(string):
                break
            if self._specials.get(hash_string(string)) != NULL:
                has_special[0] = 1
                break
            last_size = len(string)
            pre_len = self.find_prefix(string)
            if pre_len != 0:

@ -360,8 +409,15 @@ cdef class Tokenizer:
        cached.is_lex = False
        cached.data.tokens = self.vocab.make_fused_token(substrings)
        key = hash_string(string)
        stale_special = <_Cached*>self._specials.get(key)
        stale_cached = <_Cached*>self._cache.get(key)
        self._flush_cache()
        self._specials.set(key, cached)
        self._cache.set(key, cached)
        if stale_special is not NULL:
            self.mem.free(stale_special)
        if stale_special != stale_cached and stale_cached is not NULL:
            self.mem.free(stale_cached)
        self._rules[string] = substrings

    def to_disk(self, path, **kwargs):

@ -444,7 +500,10 @@ cdef class Tokenizer:
        if data.get("rules"):
            # make sure to hard reset the cache to remove data from the default exceptions
            self._rules = {}
            self._reset_cache([key for key in self._cache])
            self._reset_specials()
            self._cache = PreshMap()
            self._specials = PreshMap()
            for string, substrings in data.get("rules", {}).items():
                self.add_special_case(string, substrings)
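Because the new property setters flush the tokenizer cache, the rule callables can be reassigned on a live pipeline and take effect immediately. A minimal sketch (the extra suffix pattern is illustrative, not part of this diff):

import spacy
from spacy.lang.en import English

nlp = English()
suffixes = nlp.Defaults.suffixes + (r"-+$",)       # assumed extra rule: split off trailing hyphens
suffix_regex = spacy.util.compile_suffix_regex(suffixes)
nlp.tokenizer.suffix_search = suffix_regex.search  # goes through the new __set__, flushing the cache
print([t.text for t in nlp("well--")])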
@ -109,13 +109,8 @@ cdef class Retokenizer:

    def __exit__(self, *args):
        # Do the actual merging here
        if len(self.merges) > 1:
            _bulk_merge(self.doc, self.merges)
        elif len(self.merges) == 1:
            (span, attrs) = self.merges[0]
            start = span.start
            end = span.end
            _merge(self.doc, start, end, attrs)
        if len(self.merges) >= 1:
            _merge(self.doc, self.merges)
        # Iterate in order, to keep things simple.
        for start_char, orths, heads, attrs in sorted(self.splits):
            # Resolve token index

@ -140,95 +135,7 @@ cdef class Retokenizer:
            _split(self.doc, token_index, orths, head_indices, attrs)


def _merge(Doc doc, int start, int end, attributes):
    """Retokenize the document, such that the span at
    `doc.text[start_idx : end_idx]` is merged into a single token. If
    `start_idx` and `end_idx` do not mark start and end token boundaries,
    the document remains unchanged.
    start_idx (int): Character index of the start of the slice to merge.
    end_idx (int): Character index after the end of the slice to merge.
    **attributes: Attributes to assign to the merged token. By default,
        attributes are inherited from the syntactic root of the span.
    RETURNS (Token): The newly merged token, or `None` if the start and end
        indices did not fall at token boundaries.
    """
    cdef Span span = doc[start:end]
    cdef int start_char = span.start_char
    cdef int end_char = span.end_char
    # Resize the doc.tensor, if it's set. Let the last row for each token stand
    # for the merged region. To do this, we create a boolean array indicating
    # whether the row is to be deleted, then use numpy.delete
    if doc.tensor is not None and doc.tensor.size != 0:
        doc.tensor = _resize_tensor(doc.tensor, [(start, end)])
    # Get LexemeC for newly merged token
    new_orth = ''.join([t.text_with_ws for t in span])
    if span[-1].whitespace_:
        new_orth = new_orth[:-len(span[-1].whitespace_)]
    cdef const LexemeC* lex = doc.vocab.get(doc.mem, new_orth)
    # House the new merged token where it starts
    cdef TokenC* token = &doc.c[start]
    token.spacy = doc.c[end-1].spacy
    for attr_name, attr_value in attributes.items():
        if attr_name == "_":  # Set extension attributes
            for ext_attr_key, ext_attr_value in attr_value.items():
                doc[start]._.set(ext_attr_key, ext_attr_value)
        elif attr_name == TAG:
            doc.vocab.morphology.assign_tag(token, attr_value)
        else:
            # Set attributes on both token and lexeme to take care of token
            # attribute vs. lexical attribute without having to enumerate them.
            # If an attribute name is not valid, set_struct_attr will ignore it.
            Token.set_struct_attr(token, attr_name, attr_value)
            Lexeme.set_struct_attr(<LexemeC*>lex, attr_name, attr_value)
    # Make sure ent_iob remains consistent
    if doc.c[end].ent_iob == 1 and token.ent_iob in (0, 2):
        if token.ent_type == doc.c[end].ent_type:
            token.ent_iob = 3
        else:
            # If they're not the same entity type, let them be two entities
            doc.c[end].ent_iob = 3
    # Begin by setting all the head indices to absolute token positions
    # This is easier to work with for now than the offsets
    # Before thinking of something simpler, beware the case where a
    # dependency bridges over the entity. Here the alignment of the
    # tokens changes.
    span_root = span.root.i
    token.dep = span.root.dep
    # We update token.lex after keeping span root and dep, since
    # setting token.lex will change span.start and span.end properties
    # as it modifies the character offsets in the doc
    token.lex = lex
    for i in range(doc.length):
        doc.c[i].head += i
    # Set the head of the merged token, and its dep relation, from the Span
    token.head = doc.c[span_root].head
    # Adjust deps before shrinking tokens
    # Tokens which point into the merged token should now point to it
    # Subtract the offset from all tokens which point to >= end
    offset = (end - start) - 1
    for i in range(doc.length):
        head_idx = doc.c[i].head
        if start <= head_idx < end:
            doc.c[i].head = start
        elif head_idx >= end:
            doc.c[i].head -= offset
    # Now compress the token array
    for i in range(end, doc.length):
        doc.c[i - offset] = doc.c[i]
    for i in range(doc.length - offset, doc.length):
        memset(&doc.c[i], 0, sizeof(TokenC))
        doc.c[i].lex = &EMPTY_LEXEME
    doc.length -= offset
    for i in range(doc.length):
        # ...And, set heads back to a relative position
        doc.c[i].head -= i
    # Set the left/right children, left/right edges
    set_children_from_heads(doc.c, doc.length)
    # Return the merged Python object
    return doc[start]


def _bulk_merge(Doc doc, merges):
def _merge(Doc doc, merges):
    """Retokenize the document, such that the spans described in 'merges'
    are merged into a single token. This method assumes that the merges
    are in the same order at which they appear in the doc, and that merges

@ -256,6 +163,26 @@ def _bulk_merge(Doc doc, merges):
        spans.append(span)
        # House the new merged token where it starts
        token = &doc.c[start]
        # Initially set attributes to attributes of span root
        token.tag = doc.c[span.root.i].tag
        token.pos = doc.c[span.root.i].pos
        token.morph = doc.c[span.root.i].morph
        token.ent_iob = doc.c[span.root.i].ent_iob
        token.ent_type = doc.c[span.root.i].ent_type
        merged_iob = token.ent_iob
        # If span root is part of an entity, merged token is B-ENT
        if token.ent_iob in (1, 3):
            merged_iob = 3
            # If start token is I-ENT and previous token is of the same
            # type, then I-ENT (could check I-ENT from start to span root)
            if doc.c[start].ent_iob == 1 and start > 0 \
                    and doc.c[start].ent_type == token.ent_type \
                    and doc.c[start - 1].ent_type == token.ent_type:
                merged_iob = 1
        token.ent_iob = merged_iob
        # Unset attributes that don't match new token
        token.lemma = 0
        token.norm = 0
        tokens[merge_index] = token
    # Resize the doc.tensor, if it's set. Let the last row for each token stand
    # for the merged region. To do this, we create a boolean array indicating

@ -351,17 +278,7 @@ def _bulk_merge(Doc doc, merges):
    # Set the left/right children, left/right edges
    set_children_from_heads(doc.c, doc.length)
    # Make sure ent_iob remains consistent
    for (span, _) in merges:
        if(span.end < len(offsets)):
            # If it's not the last span
            token_after_span_position = offsets[span.end]
            if doc.c[token_after_span_position].ent_iob == 1\
                    and doc.c[token_after_span_position - 1].ent_iob in (0, 2):
                if doc.c[token_after_span_position - 1].ent_type == doc.c[token_after_span_position].ent_type:
                    doc.c[token_after_span_position - 1].ent_iob = 3
                else:
                    # If they're not the same entity type, let them be two entities
                    doc.c[token_after_span_position].ent_iob = 3
    make_iob_consistent(doc.c, doc.length)
    # Return the merged Python object
    return doc[spans[0].start]

@ -480,3 +397,12 @@ def _validate_extensions(extensions):
            raise ValueError(Errors.E118.format(attr=key))
        if not is_writable_attr(extension):
            raise ValueError(Errors.E119.format(attr=key))


cdef make_iob_consistent(TokenC* tokens, int length):
    cdef int i
    if tokens[0].ent_iob == 1:
        tokens[0].ent_iob = 3
    for i in range(1, length):
        if tokens[i].ent_iob == 1 and tokens[i - 1].ent_type != tokens[i].ent_type:
            tokens[i].ent_iob = 3
@ -200,13 +200,15 @@ cdef class Span:
        return Underscore(Underscore.span_extensions, self,
                          start=self.start_char, end=self.end_char)

    def as_doc(self):
    def as_doc(self, bint copy_user_data=False):
        """Create a `Doc` object with a copy of the `Span`'s data.

        copy_user_data (bool): Whether or not to copy the original doc's user data.
        RETURNS (Doc): The `Doc` copy of the span.

        DOCS: https://spacy.io/api/span#as_doc
        """
        # TODO: make copy_user_data a keyword-only argument (Python 3 only)
        words = [t.text for t in self]
        spaces = [bool(t.whitespace_) for t in self]
        cdef Doc doc = Doc(self.doc.vocab, words=words, spaces=spaces)

@ -235,6 +237,8 @@ cdef class Span:
            cat_start, cat_end, cat_label = key
            if cat_start == self.start_char and cat_end == self.end_char:
                doc.cats[cat_label] = value
        if copy_user_data:
            doc.user_data = self.doc.user_data
        return doc

    def _fix_dep_copy(self, attrs, array):
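A short usage sketch of the new copy_user_data flag (the document text and user-data key are illustrative):

from spacy.lang.en import English

nlp = English()
doc = nlp("The Procter & Gamble Company is big")
doc.user_data["source"] = "example.txt"
span_doc = doc[0:5].as_doc(copy_user_data=True)
assert span_doc.user_data.get("source") == "example.txt"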
@ -131,8 +131,7 @@ def load_language_data(path):
    path = path.with_suffix(path.suffix + ".gz")
    if path.exists():
        return srsly.read_gzip_json(path)
    # TODO: move to spacy.errors
    raise ValueError("Can't find language data file: {}".format(path2str(path)))
    raise ValueError(Errors.E160.format(path=path2str(path)))


def get_module_path(module):

@ -458,6 +457,14 @@ def expand_exc(excs, search, replace):


def get_lemma_tables(lookups):
    """Load lemmatizer data from lookups table. Mostly used via
    Language.Defaults.create_lemmatizer, but available as helper so it can be
    reused in language classes that implement custom lemmatizers.

    lookups (Lookups): The lookups table.
    RETURNS (tuple): A (lemma_rules, lemma_index, lemma_exc, lemma_lookup)
        tuple that can be used to initialize a Lemmatizer.
    """
    lemma_rules = {}
    lemma_index = {}
    lemma_exc = {}
@ -43,6 +43,7 @@ cdef class Vocab:
        lemmatizer (object): A lemmatizer. Defaults to `None`.
        strings (StringStore): StringStore that maps strings to integers, and
            vice versa.
        lookups (Lookups): Container for large lookup tables and dictionaries.
        RETURNS (Vocab): The newly constructed object.
        """
        lex_attr_getters = lex_attr_getters if lex_attr_getters is not None else {}

@ -433,6 +434,8 @@ cdef class Vocab:
            file_.write(self.lexemes_to_bytes())
        if "vectors" not in "exclude" and self.vectors is not None:
            self.vectors.to_disk(path)
        if "lookups" not in "exclude" and self.lookups is not None:
            self.lookups.to_disk(path)

    def from_disk(self, path, exclude=tuple(), **kwargs):
        """Loads state from a directory. Modifies the object in place and

@ -457,6 +460,8 @@ cdef class Vocab:
            self.vectors.from_disk(path, exclude=["strings"])
            if self.vectors.name is not None:
                link_vectors_to_models(self)
        if "lookups" not in exclude:
            self.lookups.from_disk(path)
        return self

    def to_bytes(self, exclude=tuple(), **kwargs):

@ -476,7 +481,8 @@ cdef class Vocab:
        getters = OrderedDict((
            ("strings", lambda: self.strings.to_bytes()),
            ("lexemes", lambda: self.lexemes_to_bytes()),
            ("vectors", deserialize_vectors)
            ("vectors", deserialize_vectors),
            ("lookups", lambda: self.lookups.to_bytes())
        ))
        exclude = util.get_serialization_exclude(getters, exclude, kwargs)
        return util.to_bytes(getters, exclude)

@ -499,7 +505,8 @@ cdef class Vocab:
        setters = OrderedDict((
            ("strings", lambda b: self.strings.from_bytes(b)),
            ("lexemes", lambda b: self.lexemes_from_bytes(b)),
            ("vectors", lambda b: serialize_vectors(b))
            ("vectors", lambda b: serialize_vectors(b)),
            ("lookups", lambda b: self.lookups.from_bytes(b))
        ))
        exclude = util.get_serialization_exclude(setters, exclude, kwargs)
        util.from_bytes(bytes_data, setters, exclude)
|
|||
> assert doc2.text == u"New York"
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | ----- | --------------------------------------- |
|
||||
| **RETURNS** | `Doc` | A `Doc` object of the `Span`'s content. |
|
||||
| Name | Type | Description |
|
||||
| ----------------- | ----- | ---------------------------------------------------- |
|
||||
| `copy_user_data` | bool | Whether or not to copy the original doc's user data. |
|
||||
| **RETURNS** | `Doc` | A `Doc` object of the `Span`'s content. |
|
||||
|
||||
## Span.root {#root tag="property" model="parser"}
|
||||
|
||||
|
|
|
@ -107,7 +107,7 @@ process.

<Infobox>

**Usage:** [Models directory](/models) [Benchmarks](#benchmarks)
**Usage:** [Models directory](/models)

</Infobox>
@ -10,10 +10,7 @@
    "modelsRepo": "explosion/spacy-models",
    "social": {
        "twitter": "spacy_io",
        "github": "explosion",
        "reddit": "spacynlp",
        "codepen": "explosion",
        "gitter": "explosion/spaCy"
        "github": "explosion"
    },
    "theme": "#09a3d5",
    "analytics": "UA-58931649-1",

@ -69,6 +66,7 @@
    "items": [
        { "text": "Twitter", "url": "https://twitter.com/spacy_io" },
        { "text": "GitHub", "url": "https://github.com/explosion/spaCy" },
        { "text": "YouTube", "url": "https://youtube.com/c/ExplosionAI" },
        { "text": "Blog", "url": "https://explosion.ai/blog" }
    ]
}