Merge branch 'master' into spacy.io

Ines Montani 2019-09-14 16:42:18 +02:00
commit 8c43cfc754
70 changed files with 1570280 additions and 996 deletions

@@ -6,9 +6,5 @@ exclude =
.env,
.git,
__pycache__,
lemmatizer.py,
lookup.py,
_tokenizer_exceptions_list.py,
spacy/lang/fr/lemmatizer,
spacy/lang/nb/lemmatizer
spacy/__init__.py

.github/contributors/mihaigliga21.md

@@ -0,0 +1,106 @@
# spaCy contributor agreement
This spaCy Contributor Agreement (**"SCA"**) is based on the
[Oracle Contributor Agreement](http://www.oracle.com/technetwork/oca-405177.pdf).
The SCA applies to any contribution that you make to any product or project
managed by us (the **"project"**), and sets out the intellectual property rights
you grant to us in the contributed materials. The term **"us"** shall mean
[ExplosionAI UG (haftungsbeschränkt)](https://explosion.ai/legal). The term
**"you"** shall mean the person or entity identified below.
If you agree to be bound by these terms, fill in the information requested
below and include the filled-in version with your first pull request, under the
folder [`.github/contributors/`](/.github/contributors/). The name of the file
should be your GitHub username, with the extension `.md`. For example, the user
example_user would create the file `.github/contributors/example_user.md`.
Read this agreement carefully before signing. These terms and conditions
constitute a binding legal agreement.
## Contributor Agreement
1. The term "contribution" or "contributed materials" means any source code,
object code, patch, tool, sample, graphic, specification, manual,
documentation, or any other material posted or submitted by you to the project.
2. With respect to any worldwide copyrights, or copyright applications and
registrations, in your contribution:
* you hereby assign to us joint ownership, and to the extent that such
assignment is or becomes invalid, ineffective or unenforceable, you hereby
grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge,
royalty-free, unrestricted license to exercise all rights under those
copyrights. This includes, at our option, the right to sublicense these same
rights to third parties through multiple levels of sublicensees or other
licensing arrangements;
* you agree that each of us can do all things in relation to your
contribution as if each of us were the sole owners, and if one of us makes
a derivative work of your contribution, the one who makes the derivative
work (or has it made) will be the sole owner of that derivative work;
* you agree that you will not assert any moral rights in your contribution
against us, our licensees or transferees;
* you agree that we may register a copyright in your contribution and
exercise all ownership rights associated with it; and
* you agree that neither of us has any duty to consult with, obtain the
consent of, pay or render an accounting to the other for any use or
distribution of your contribution.
3. With respect to any patents you own, or that you can license without payment
to any third party, you hereby grant to us a perpetual, irrevocable,
non-exclusive, worldwide, no-charge, royalty-free license to:
* make, have made, use, sell, offer to sell, import, and otherwise transfer
your contribution in whole or in part, alone or in combination with or
included in any product, work or materials arising out of the project to
which your contribution was submitted, and
* at our option, to sublicense these same rights to third parties through
multiple levels of sublicensees or other licensing arrangements.
4. Except as set out above, you keep all right, title, and interest in your
contribution. The rights that you grant to us under these terms are effective
on the date you first submitted a contribution to us, even if your submission
took place before the date you sign these terms.
5. You covenant, represent, warrant and agree that:
* Each contribution that you submit is and shall be an original work of
authorship and you can legally grant the rights set out in this SCA;
* to the best of your knowledge, each contribution will not violate any
third party's copyrights, trademarks, patents, or other intellectual
property rights; and
* each contribution shall be in compliance with U.S. export control laws and
other applicable export and import laws. You agree to notify us if you
become aware of any circumstance which would make any of the foregoing
representations inaccurate in any respect. We may publicly disclose your
participation in the project, including the fact that you have signed the SCA.
6. This SCA is governed by the laws of the State of California and applicable
U.S. Federal law. Any choice of law rules will not apply.
7. Please place an “x” on one of the applicable statements below. Please do NOT
mark both statements:
* [x] I am signing on behalf of myself as an individual and no other person
or entity, including my employer, has or will have rights with respect to my
contributions.
* [x] I am signing on behalf of my employer or a legal entity and I have the
actual authority to contractually bind that entity.
## Contributor Details
| Field | Entry |
|------------------------------- | -------------------------- |
| Name | Mihai Gliga |
| Company name (if applicable) | |
| Title or role (if applicable) | |
| Date | September 9, 2019 |
| GitHub username | mihaigliga21 |
| Website (optional) | |

.github/contributors/tamuhey.md

@@ -0,0 +1,106 @@
# spaCy contributor agreement
This spaCy Contributor Agreement (**"SCA"**) is based on the
[Oracle Contributor Agreement](http://www.oracle.com/technetwork/oca-405177.pdf).
The SCA applies to any contribution that you make to any product or project
managed by us (the **"project"**), and sets out the intellectual property rights
you grant to us in the contributed materials. The term **"us"** shall mean
[ExplosionAI GmbH](https://explosion.ai/legal). The term
**"you"** shall mean the person or entity identified below.
If you agree to be bound by these terms, fill in the information requested
below and include the filled-in version with your first pull request, under the
folder [`.github/contributors/`](/.github/contributors/). The name of the file
should be your GitHub username, with the extension `.md`. For example, the user
example_user would create the file `.github/contributors/example_user.md`.
Read this agreement carefully before signing. These terms and conditions
constitute a binding legal agreement.
## Contributor Agreement
1. The term "contribution" or "contributed materials" means any source code,
object code, patch, tool, sample, graphic, specification, manual,
documentation, or any other material posted or submitted by you to the project.
2. With respect to any worldwide copyrights, or copyright applications and
registrations, in your contribution:
* you hereby assign to us joint ownership, and to the extent that such
assignment is or becomes invalid, ineffective or unenforceable, you hereby
grant to us a perpetual, irrevocable, non-exclusive, worldwide, no-charge,
royalty-free, unrestricted license to exercise all rights under those
copyrights. This includes, at our option, the right to sublicense these same
rights to third parties through multiple levels of sublicensees or other
licensing arrangements;
* you agree that each of us can do all things in relation to your
contribution as if each of us were the sole owners, and if one of us makes
a derivative work of your contribution, the one who makes the derivative
work (or has it made) will be the sole owner of that derivative work;
* you agree that you will not assert any moral rights in your contribution
against us, our licensees or transferees;
* you agree that we may register a copyright in your contribution and
exercise all ownership rights associated with it; and
* you agree that neither of us has any duty to consult with, obtain the
consent of, pay or render an accounting to the other for any use or
distribution of your contribution.
3. With respect to any patents you own, or that you can license without payment
to any third party, you hereby grant to us a perpetual, irrevocable,
non-exclusive, worldwide, no-charge, royalty-free license to:
* make, have made, use, sell, offer to sell, import, and otherwise transfer
your contribution in whole or in part, alone or in combination with or
included in any product, work or materials arising out of the project to
which your contribution was submitted, and
* at our option, to sublicense these same rights to third parties through
multiple levels of sublicensees or other licensing arrangements.
4. Except as set out above, you keep all right, title, and interest in your
contribution. The rights that you grant to us under these terms are effective
on the date you first submitted a contribution to us, even if your submission
took place before the date you sign these terms.
5. You covenant, represent, warrant and agree that:
* Each contribution that you submit is and shall be an original work of
authorship and you can legally grant the rights set out in this SCA;
* to the best of your knowledge, each contribution will not violate any
third party's copyrights, trademarks, patents, or other intellectual
property rights; and
* each contribution shall be in compliance with U.S. export control laws and
other applicable export and import laws. You agree to notify us if you
become aware of any circumstance which would make any of the foregoing
representations inaccurate in any respect. We may publicly disclose your
participation in the project, including the fact that you have signed the SCA.
6. This SCA is governed by the laws of the State of California and applicable
U.S. Federal law. Any choice of law rules will not apply.
7. Please place an “x” on one of the applicable statements below. Please do NOT
mark both statements:
* [x] I am signing on behalf of myself as an individual and no other person
or entity, including my employer, has or will have rights with respect to my
contributions.
* [ ] I am signing on behalf of my employer or a legal entity and I have the
actual authority to contractually bind that entity.
## Contributor Details
| Field | Entry |
|------------------------------- | -------------------- |
| Name | Yohei Tamura |
| Company name (if applicable) | PKSHA |
| Title or role (if applicable) | |
| Date | 2019/9/12 |
| GitHub username | tamuhey |
| Website (optional) | |

@@ -5,7 +5,6 @@
from __future__ import unicode_literals
import plac
import tqdm
from pathlib import Path
import re
import sys

@@ -5,7 +5,6 @@
from __future__ import unicode_literals
import plac
import tqdm
from pathlib import Path
import re
import sys
@@ -462,6 +461,9 @@ def main(
vectors_dir=None,
use_oracle_segments=False,
):
# temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
import tqdm
spacy.util.fix_random_seed()
lang.zh.Chinese.Defaults.use_jieba = False
lang.ja.Japanese.Defaults.use_janome = False

@@ -0,0 +1,11 @@
TRAINING_DATA_FILE = "gold_entities.jsonl"
KB_FILE = "kb"
KB_MODEL_DIR = "nlp_kb"
OUTPUT_MODEL_DIR = "nlp"
PRIOR_PROB_PATH = "prior_prob.csv"
ENTITY_DEFS_PATH = "entity_defs.csv"
ENTITY_FREQ_PATH = "entity_freq.csv"
ENTITY_DESCR_PATH = "entity_descriptions.csv"
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(name)s - %(message)s'

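The constants above are shared by every script in this directory. A minimal sketch (the output directory is a placeholder) of how they are typically joined onto an output path and used to configure logging, following the usage in the pretraining script further down:

```python
import logging
from pathlib import Path

from bin.wiki_entity_linking import KB_FILE, TRAINING_DATA_FILE, LOG_FORMAT

# Same logging setup the CLI scripts below apply in __main__.
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)

output_dir = Path("kb_output")  # placeholder
kb_path = output_dir / KB_FILE  # kb_output/kb
training_entities_path = output_dir / TRAINING_DATA_FILE  # kb_output/gold_entities.jsonl
```
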
@@ -0,0 +1,200 @@
import logging
import random
from collections import defaultdict
logger = logging.getLogger(__name__)
class Metrics(object):
true_pos = 0
false_pos = 0
false_neg = 0
def update_results(self, true_entity, candidate):
candidate_is_correct = true_entity == candidate
# Assume that we have no labeled negatives in the data (i.e. cases where true_entity is "NIL")
# Therefore, if candidate_is_correct then we have a true positive and never a true negative
self.true_pos += candidate_is_correct
self.false_neg += not candidate_is_correct
if candidate not in {"", "NIL"}:
self.false_pos += not candidate_is_correct
def calculate_precision(self):
if self.true_pos == 0:
return 0.0
else:
return self.true_pos / (self.true_pos + self.false_pos)
def calculate_recall(self):
if self.true_pos == 0:
return 0.0
else:
return self.true_pos / (self.true_pos + self.false_neg)
class EvaluationResults(object):
def __init__(self):
self.metrics = Metrics()
self.metrics_by_label = defaultdict(Metrics)
def update_metrics(self, ent_label, true_entity, candidate):
self.metrics.update_results(true_entity, candidate)
self.metrics_by_label[ent_label].update_results(true_entity, candidate)
def increment_false_negatives(self):
self.metrics.false_neg += 1
def report_metrics(self, model_name):
model_str = model_name.title()
recall = self.metrics.calculate_recall()
precision = self.metrics.calculate_precision()
return ("{}: ".format(model_str) +
"Recall = {} | ".format(round(recall, 3)) +
"Precision = {} | ".format(round(precision, 3)) +
"Precision by label = {}".format({k: v.calculate_precision()
for k, v in self.metrics_by_label.items()}))
class BaselineResults(object):
def __init__(self):
self.random = EvaluationResults()
self.prior = EvaluationResults()
self.oracle = EvaluationResults()
def report_accuracy(self, model):
results = getattr(self, model)
return results.report_metrics(model)
def update_baselines(self, true_entity, ent_label, random_candidate, prior_candidate, oracle_candidate):
self.oracle.update_metrics(ent_label, true_entity, oracle_candidate)
self.prior.update_metrics(ent_label, true_entity, prior_candidate)
self.random.update_metrics(ent_label, true_entity, random_candidate)
def measure_performance(dev_data, kb, el_pipe):
baseline_accuracies = measure_baselines(
dev_data, kb
)
logger.info(baseline_accuracies.report_accuracy("random"))
logger.info(baseline_accuracies.report_accuracy("prior"))
logger.info(baseline_accuracies.report_accuracy("oracle"))
# using only context
el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = False
results = get_eval_results(dev_data, el_pipe)
logger.info(results.report_metrics("context only"))
# measuring combined accuracy (prior + context)
el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = True
results = get_eval_results(dev_data, el_pipe)
logger.info(results.report_metrics("context and prior"))
def get_eval_results(data, el_pipe=None):
# If the docs in the data require further processing with an entity linker, set el_pipe
from tqdm import tqdm
docs = []
golds = []
for d, g in tqdm(data, leave=False):
if len(d) > 0:
golds.append(g)
if el_pipe is not None:
docs.append(el_pipe(d))
else:
docs.append(d)
results = EvaluationResults()
for doc, gold in zip(docs, golds):
tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
try:
correct_entries_per_article = dict()
for entity, kb_dict in gold.links.items():
start, end = entity
# only evaluating on positive examples
for gold_kb, value in kb_dict.items():
if value:
offset = _offset(start, end)
correct_entries_per_article[offset] = gold_kb
if offset not in tagged_entries_per_article:
results.increment_false_negatives()
for ent in doc.ents:
ent_label = ent.label_
pred_entity = ent.kb_id_
start = ent.start_char
end = ent.end_char
offset = _offset(start, end)
gold_entity = correct_entries_per_article.get(offset, None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None:
results.update_metrics(ent_label, gold_entity, pred_entity)
except Exception as e:
logging.error("Error assessing accuracy " + str(e))
return results
def measure_baselines(data, kb):
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
counts_d = dict()
baseline_results = BaselineResults()
docs = [d for d, g in data if len(d) > 0]
golds = [g for d, g in data if len(d) > 0]
for doc, gold in zip(docs, golds):
correct_entries_per_article = dict()
tagged_entries_per_article = {_offset(ent.start_char, ent.end_char): ent for ent in doc.ents}
for entity, kb_dict in gold.links.items():
start, end = entity
for gold_kb, value in kb_dict.items():
# only evaluating on positive examples
if value:
offset = _offset(start, end)
correct_entries_per_article[offset] = gold_kb
if offset not in tagged_entries_per_article:
baseline_results.random.increment_false_negatives()
baseline_results.oracle.increment_false_negatives()
baseline_results.prior.increment_false_negatives()
for ent in doc.ents:
ent_label = ent.label_
start = ent.start_char
end = ent.end_char
offset = _offset(start, end)
gold_entity = correct_entries_per_article.get(offset, None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None:
candidates = kb.get_candidates(ent.text)
oracle_candidate = ""
best_candidate = ""
random_candidate = ""
if candidates:
scores = []
for c in candidates:
scores.append(c.prior_prob)
if c.entity_ == gold_entity:
oracle_candidate = c.entity_
best_index = scores.index(max(scores))
best_candidate = candidates[best_index].entity_
random_candidate = random.choice(candidates).entity_
baseline_results.update_baselines(gold_entity, ent_label,
random_candidate, best_candidate, oracle_candidate)
return baseline_results
def _offset(start, end):
return "{}_{}".format(start, end)

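A worked example of the Metrics bookkeeping above (a minimal sketch; the module path follows the import used in the training script further down). Note that a "NIL" or empty candidate only counts against recall, never against precision:

```python
from bin.wiki_entity_linking.entity_linker_evaluation import Metrics

m = Metrics()
m.update_results(true_entity="Q312", candidate="Q312")  # true positive
m.update_results(true_entity="Q312", candidate="Q95")   # false positive and false negative
m.update_results(true_entity="Q312", candidate="NIL")   # false negative only

assert (m.true_pos, m.false_pos, m.false_neg) == (1, 1, 2)
print(m.calculate_precision())  # 1 / (1 + 1) = 0.5
print(m.calculate_recall())     # 1 / (1 + 2) ≈ 0.333
```
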
@@ -1,12 +1,20 @@
# coding: utf-8
from __future__ import unicode_literals
from bin.wiki_entity_linking.train_descriptions import EntityEncoder
from bin.wiki_entity_linking import wikidata_processor as wd, wikipedia_processor as wp
import csv
import logging
import spacy
import sys
from spacy.kb import KnowledgeBase
import csv
import datetime
from bin.wiki_entity_linking import wikipedia_processor as wp
from bin.wiki_entity_linking.train_descriptions import EntityEncoder
csv.field_size_limit(sys.maxsize)
logger = logging.getLogger(__name__)
def create_kb(
@@ -14,52 +22,73 @@ def create_kb(
max_entities_per_alias,
min_entity_freq,
min_occ,
entity_def_output,
entity_descr_output,
entity_def_input,
entity_descr_path,
count_input,
prior_prob_input,
wikidata_input,
entity_vector_length,
limit=None,
read_raw_data=True,
):
# Create the knowledge base from Wikidata entries
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=entity_vector_length)
# read the mappings from file
title_to_id = get_entity_to_id(entity_def_input)
id_to_descr = get_id_to_description(entity_descr_path)
# check the length of the nlp vectors
if "vectors" in nlp.meta and nlp.vocab.vectors.size:
input_dim = nlp.vocab.vectors_length
print("Loaded pre-trained vectors of size %s" % input_dim)
logger.info("Loaded pre-trained vectors of size %s" % input_dim)
else:
raise ValueError(
"The `nlp` object should have access to pre-trained word vectors, "
" cf. https://spacy.io/usage/models#languages."
)
# disable this part of the pipeline when rerunning the KB generation from preprocessed files
if read_raw_data:
print()
print(now(), " * read wikidata entities:")
title_to_id, id_to_descr = wd.read_wikidata_entities_json(
wikidata_input, limit=limit
)
# write the title-ID and ID-description mappings to file
_write_entity_files(
entity_def_output, entity_descr_output, title_to_id, id_to_descr
)
else:
# read the mappings from file
title_to_id = get_entity_to_id(entity_def_output)
id_to_descr = get_id_to_description(entity_descr_output)
print()
print(now(), " * get entity frequencies:")
print()
logger.info("Get entity frequencies")
entity_frequencies = wp.get_all_frequencies(count_input=count_input)
logger.info("Filtering entities with fewer than {} mentions".format(min_entity_freq))
# filter the entities in the KB by frequency, because there's just too much data (8M entities) otherwise
filtered_title_to_id, entity_list, description_list, frequency_list = get_filtered_entities(
title_to_id,
id_to_descr,
entity_frequencies,
min_entity_freq
)
logger.info("Left with {} entities".format(len(description_list)))
logger.info("Train entity encoder")
encoder = EntityEncoder(nlp, input_dim, entity_vector_length)
encoder.train(description_list=description_list, to_print=True)
logger.info("Get entity embeddings:")
embeddings = encoder.apply_encoder(description_list)
logger.info("Adding {} entities".format(len(entity_list)))
kb.set_entities(
entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
)
logger.info("Adding aliases")
_add_aliases(
kb,
title_to_id=filtered_title_to_id,
max_entities_per_alias=max_entities_per_alias,
min_occ=min_occ,
prior_prob_input=prior_prob_input,
)
logger.info("KB size: {} entities, {} aliases".format(
kb.get_size_entities(),
kb.get_size_aliases()))
logger.info("Done with kb")
return kb
def get_filtered_entities(title_to_id, id_to_descr, entity_frequencies,
min_entity_freq: int = 10):
filtered_title_to_id = dict()
entity_list = []
description_list = []
@@ -72,58 +101,7 @@ def create_kb(
description_list.append(desc)
frequency_list.append(freq)
filtered_title_to_id[title] = entity
print(len(title_to_id.keys()), "original titles")
kept_nr = len(filtered_title_to_id.keys())
print("kept", kept_nr, "entities with min. frequency", min_entity_freq)
print()
print(now(), " * train entity encoder:")
print()
encoder = EntityEncoder(nlp, input_dim, entity_vector_length)
encoder.train(description_list=description_list, to_print=True)
print()
print(now(), " * get entity embeddings:")
print()
embeddings = encoder.apply_encoder(description_list)
print(now(), " * adding", len(entity_list), "entities")
kb.set_entities(
entity_list=entity_list, freq_list=frequency_list, vector_list=embeddings
)
alias_cnt = _add_aliases(
kb,
title_to_id=filtered_title_to_id,
max_entities_per_alias=max_entities_per_alias,
min_occ=min_occ,
prior_prob_input=prior_prob_input,
)
print()
print(now(), " * adding", alias_cnt, "aliases")
print()
print()
print("# of entities in kb:", kb.get_size_entities())
print("# of aliases in kb:", kb.get_size_aliases())
print(now(), "Done with kb")
return kb
def _write_entity_files(
entity_def_output, entity_descr_output, title_to_id, id_to_descr
):
with entity_def_output.open("w", encoding="utf8") as id_file:
id_file.write("WP_title" + "|" + "WD_id" + "\n")
for title, qid in title_to_id.items():
id_file.write(title + "|" + str(qid) + "\n")
with entity_descr_output.open("w", encoding="utf8") as descr_file:
descr_file.write("WD_id" + "|" + "description" + "\n")
for qid, descr in id_to_descr.items():
descr_file.write(str(qid) + "|" + descr + "\n")
return filtered_title_to_id, entity_list, description_list, frequency_list
def get_entity_to_id(entity_def_output):
@@ -137,9 +115,9 @@ def get_entity_to_id(entity_def_output):
return entity_to_id
def get_id_to_description(entity_descr_output):
def get_id_to_description(entity_descr_path):
id_to_desc = dict()
with entity_descr_output.open("r", encoding="utf8") as csvfile:
with entity_descr_path.open("r", encoding="utf8") as csvfile:
csvreader = csv.reader(csvfile, delimiter="|")
# skip header
next(csvreader)
@@ -150,7 +128,6 @@ def get_id_to_description(entity_descr_output):
def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_input):
wp_titles = title_to_id.keys()
cnt = 0
# adding aliases with prior probabilities
# we can read this file sequentially, it's sorted by alias, and then by count
@@ -187,9 +164,8 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
entities=selected_entities,
probabilities=prior_probs,
)
cnt += 1
except ValueError as e:
print(e)
logger.error(e)
total_count = 0
counts = []
entities = []
@@ -202,8 +178,12 @@ def _add_aliases(kb, title_to_id, max_entities_per_alias, min_occ, prior_prob_in
previous_alias = new_alias
line = prior_file.readline()
return cnt
def now():
return datetime.datetime.now()
def read_nlp_kb(model_dir, kb_file):
nlp = spacy.load(model_dir)
kb = KnowledgeBase(vocab=nlp.vocab)
kb.load_bulk(kb_file)
logger.info("kb entities: {}".format(kb.get_size_entities()))
logger.info("kb aliases: {}".format(kb.get_size_aliases()))
return nlp, kb

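A toy-scale sketch of the KnowledgeBase calls used in create_kb above (spaCy v2 API; the entities, frequencies, and vectors are made-up stand-ins for the filtered Wikidata entities and the trained encoder output):

```python
import spacy
from spacy.kb import KnowledgeBase

nlp = spacy.blank("en")
kb = KnowledgeBase(vocab=nlp.vocab, entity_vector_length=3)

# Toy values; normally these lists come from Wikidata and the EntityEncoder.
kb.set_entities(
    entity_list=["Q312", "Q95"],
    freq_list=[120, 80],
    vector_list=[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]],
)
kb.add_alias(alias="Apple", entities=["Q312", "Q95"], probabilities=[0.7, 0.2])

print(kb.get_size_entities(), kb.get_size_aliases())    # 2 1
print([c.entity_ for c in kb.get_candidates("Apple")])  # candidate QIDs for the alias
```
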
@@ -1,6 +1,7 @@
# coding: utf-8
from random import shuffle
import logging
import numpy as np
from spacy._ml import zero_init, create_default_optimizer
@@ -10,6 +11,8 @@ from thinc.v2v import Model
from thinc.api import chain
from thinc.neural._classes.affine import Affine
logger = logging.getLogger(__name__)
class EntityEncoder:
"""
@@ -50,21 +53,19 @@ class EntityEncoder:
start = start + batch_size
stop = min(stop + batch_size, len(description_list))
print("encoded:", stop, "entities")
logger.info("encoded: {} entities".format(stop))
return encodings
def train(self, description_list, to_print=False):
processed, loss = self._train_model(description_list)
if to_print:
print(
"Trained entity descriptions on",
processed,
"(non-unique) entities across",
self.epochs,
"epochs",
logger.info(
"Trained entity descriptions on {} ".format(processed) +
"(non-unique) entities across {} ".format(self.epochs) +
"epochs"
)
print("Final loss:", loss)
logger.info("Final loss: {}".format(loss))
def _train_model(self, description_list):
best_loss = 1.0
@@ -93,7 +94,7 @@ class EntityEncoder:
loss = self._update(batch)
if batch_nr % 25 == 0:
print("loss:", loss)
logger.info("loss: {} ".format(loss))
processed += len(batch)
# in general, continue training if we haven't reached our ideal min yet

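A hedged usage sketch of EntityEncoder, mirroring its call sites in kb_creator above; "en_core_web_lg" stands in for any model shipping pretrained vectors, and the two toy descriptions stand in for the real Wikidata/Wikipedia ones:

```python
import spacy
from bin.wiki_entity_linking.train_descriptions import EntityEncoder

nlp = spacy.load("en_core_web_lg")     # needs pretrained word vectors
input_dim = nlp.vocab.vectors_length   # width of the word vectors
entity_vector_length = 64              # width of the target KB entity vectors

encoder = EntityEncoder(nlp, input_dim, entity_vector_length)
encoder.train(description_list=["English author.", "US tech company."], to_print=True)
embeddings = encoder.apply_encoder(["English author.", "US tech company."])
```
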
@@ -1,10 +1,13 @@
# coding: utf-8
from __future__ import unicode_literals
import logging
import random
import re
import bz2
import datetime
import json
from functools import partial
from spacy.gold import GoldParse
from bin.wiki_entity_linking import kb_creator
@@ -15,18 +18,30 @@ Gold-standard entities are stored in one file in standoff format (by character o
"""
ENTITY_FILE = "gold_entities.csv"
logger = logging.getLogger(__name__)
def now():
return datetime.datetime.now()
def create_training(wikipedia_input, entity_def_input, training_output, limit=None):
def create_training_examples_and_descriptions(wikipedia_input,
entity_def_input,
description_output,
training_output,
parse_descriptions,
limit=None):
wp_to_id = kb_creator.get_entity_to_id(entity_def_input)
_process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=limit)
_process_wikipedia_texts(wikipedia_input,
wp_to_id,
description_output,
training_output,
parse_descriptions,
limit)
def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=None):
def _process_wikipedia_texts(wikipedia_input,
wp_to_id,
output,
training_output,
parse_descriptions,
limit=None):
"""
Read the XML wikipedia data to parse out training data:
raw text data + positive instances
@@ -35,29 +50,21 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
id_regex = re.compile(r"(?<=<id>)\d*(?=</id>)")
read_ids = set()
entityfile_loc = training_output / ENTITY_FILE
with entityfile_loc.open("w", encoding="utf8") as entityfile:
# write entity training header file
_write_training_entity(
outputfile=entityfile,
article_id="article_id",
alias="alias",
entity="WD_id",
start="start",
end="end",
)
with output.open("a", encoding="utf8") as descr_file, training_output.open("w", encoding="utf8") as entity_file:
if parse_descriptions:
_write_training_description(descr_file, "WD_id", "description")
with bz2.open(wikipedia_input, mode="rb") as file:
line = file.readline()
cnt = 0
article_count = 0
article_text = ""
article_title = None
article_id = None
reading_text = False
reading_revision = False
while line and (not limit or cnt < limit):
if cnt % 1000000 == 0:
print(now(), "processed", cnt, "lines of Wikipedia dump")
logger.info("Processed {} articles".format(article_count))
for line in file:
clean_line = line.strip().decode("utf-8")
if clean_line == "<revision>":
@@ -70,28 +77,32 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
article_text = ""
article_title = None
article_id = None
# finished reading this page
elif clean_line == "</page>":
if article_id:
try:
_process_wp_text(
wp_to_id,
entityfile,
article_id,
article_title,
article_text.strip(),
training_output,
)
except Exception as e:
print(
"Error processing article", article_id, article_title, e
)
else:
print(
"Done processing a page, but couldn't find an article_id ?",
clean_text, entities = _process_wp_text(
article_title,
article_text,
wp_to_id
)
if clean_text is not None and entities is not None:
_write_training_entities(entity_file,
article_id,
clean_text,
entities)
if article_title in wp_to_id and parse_descriptions:
description = " ".join(clean_text[:1000].split(" ")[:-1])
_write_training_description(
descr_file,
wp_to_id[article_title],
description
)
article_count += 1
if article_count % 10000 == 0:
logger.info("Processed {} articles".format(article_count))
if limit and article_count >= limit:
break
article_text = ""
article_title = None
article_id = None
@@ -115,7 +126,7 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
if ids:
article_id = ids[0]
if article_id in read_ids:
print(
logger.info(
"Found duplicate article ID", article_id, clean_line
) # This should never happen ...
read_ids.add(article_id)
@@ -125,115 +136,10 @@ def _process_wikipedia_texts(wikipedia_input, wp_to_id, training_output, limit=N
titles = title_regex.search(clean_line)
if titles:
article_title = titles[0].strip()
line = file.readline()
cnt += 1
print(now(), "processed", cnt, "lines of Wikipedia dump")
logger.info("Finished. Processed {} articles".format(article_count))
text_regex = re.compile(r"(?<=<text xml:space=\"preserve\">).*(?=</text)")
def _process_wp_text(
wp_to_id, entityfile, article_id, article_title, article_text, training_output
):
found_entities = False
# ignore meta Wikipedia pages
if article_title.startswith("Wikipedia:"):
return
# remove the text tags
text = text_regex.search(article_text).group(0)
# stop processing if this is a redirect page
if text.startswith("#REDIRECT"):
return
# get the raw text without markup etc, keeping only interwiki links
clean_text = _get_clean_wp_text(text)
# read the text char by char to get the right offsets for the interwiki links
final_text = ""
open_read = 0
reading_text = True
reading_entity = False
reading_mention = False
reading_special_case = False
entity_buffer = ""
mention_buffer = ""
for index, letter in enumerate(clean_text):
if letter == "[":
open_read += 1
elif letter == "]":
open_read -= 1
elif letter == "|":
if reading_text:
final_text += letter
# switch from reading entity to mention in the [[entity|mention]] pattern
elif reading_entity:
reading_text = False
reading_entity = False
reading_mention = True
else:
reading_special_case = True
else:
if reading_entity:
entity_buffer += letter
elif reading_mention:
mention_buffer += letter
elif reading_text:
final_text += letter
else:
raise ValueError("Not sure at point", clean_text[index - 2 : index + 2])
if open_read > 2:
reading_special_case = True
if open_read == 2 and reading_text:
reading_text = False
reading_entity = True
reading_mention = False
# we just finished reading an entity
if open_read == 0 and not reading_text:
if "#" in entity_buffer or entity_buffer.startswith(":"):
reading_special_case = True
# Ignore cases with nested structures like File: handles etc
if not reading_special_case:
if not mention_buffer:
mention_buffer = entity_buffer
start = len(final_text)
end = start + len(mention_buffer)
qid = wp_to_id.get(entity_buffer, None)
if qid:
_write_training_entity(
outputfile=entityfile,
article_id=article_id,
alias=mention_buffer,
entity=qid,
start=start,
end=end,
)
found_entities = True
final_text += mention_buffer
entity_buffer = ""
mention_buffer = ""
reading_text = True
reading_entity = False
reading_mention = False
reading_special_case = False
if found_entities:
_write_training_article(
article_id=article_id,
clean_text=final_text,
training_output=training_output,
)
info_regex = re.compile(r"{[^{]*?}")
htlm_regex = re.compile(r"&lt;!--[^-]*--&gt;")
category_regex = re.compile(r"\[\[Category:[^\[]*]]")
@@ -242,6 +148,29 @@ ref_regex = re.compile(r"&lt;ref.*?&gt;") # non-greedy
ref_2_regex = re.compile(r"&lt;/ref.*?&gt;") # non-greedy
def _process_wp_text(article_title, article_text, wp_to_id):
# ignore meta Wikipedia pages
if (
article_title.startswith("Wikipedia:") or
article_title.startswith("Kategori:")
):
return None, None
# remove the text tags
text_search = text_regex.search(article_text)
if text_search is None:
return None, None
text = text_search.group(0)
# stop processing if this is a redirect page
if text.startswith("#REDIRECT"):
return None, None
# get the raw text without markup etc, keeping only interwiki links
clean_text, entities = _remove_links(_get_clean_wp_text(text), wp_to_id)
return clean_text, entities
def _get_clean_wp_text(article_text):
clean_text = article_text.strip()
@@ -300,130 +229,167 @@ def _get_clean_wp_text(article_text):
return clean_text.strip()
def _write_training_article(article_id, clean_text, training_output):
file_loc = training_output / "{}.txt".format(article_id)
with file_loc.open("w", encoding="utf8") as outputfile:
outputfile.write(clean_text)
def _remove_links(clean_text, wp_to_id):
# read the text char by char to get the right offsets for the interwiki links
entities = []
final_text = ""
open_read = 0
reading_text = True
reading_entity = False
reading_mention = False
reading_special_case = False
entity_buffer = ""
mention_buffer = ""
for index, letter in enumerate(clean_text):
if letter == "[":
open_read += 1
elif letter == "]":
open_read -= 1
elif letter == "|":
if reading_text:
final_text += letter
# switch from reading entity to mention in the [[entity|mention]] pattern
elif reading_entity:
reading_text = False
reading_entity = False
reading_mention = True
else:
reading_special_case = True
else:
if reading_entity:
entity_buffer += letter
elif reading_mention:
mention_buffer += letter
elif reading_text:
final_text += letter
else:
raise ValueError("Not sure at point", clean_text[index - 2: index + 2])
if open_read > 2:
reading_special_case = True
if open_read == 2 and reading_text:
reading_text = False
reading_entity = True
reading_mention = False
# we just finished reading an entity
if open_read == 0 and not reading_text:
if "#" in entity_buffer or entity_buffer.startswith(":"):
reading_special_case = True
# Ignore cases with nested structures like File: handles etc
if not reading_special_case:
if not mention_buffer:
mention_buffer = entity_buffer
start = len(final_text)
end = start + len(mention_buffer)
qid = wp_to_id.get(entity_buffer, None)
if qid:
entities.append((mention_buffer, qid, start, end))
final_text += mention_buffer
entity_buffer = ""
mention_buffer = ""
reading_text = True
reading_entity = False
reading_mention = False
reading_special_case = False
return final_text, entities
def _write_training_entity(outputfile, article_id, alias, entity, start, end):
line = "{}|{}|{}|{}|{}\n".format(article_id, alias, entity, start, end)
def _write_training_description(outputfile, qid, description):
if description is not None:
line = str(qid) + "|" + description + "\n"
outputfile.write(line)
def _write_training_entities(outputfile, article_id, clean_text, entities):
entities_data = [{"alias": ent[0], "entity": ent[1], "start": ent[2], "end": ent[3]} for ent in entities]
line = json.dumps(
{
"article_id": article_id,
"clean_text": clean_text,
"entities": entities_data
},
ensure_ascii=False) + "\n"
outputfile.write(line)
def read_training(nlp, entity_file_path, dev, limit, kb):
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
For training, it will include negative training examples by using the candidate generator,
and it will only keep positive training examples that can be found by using the candidate generator.
For testing, it will include all positive examples only."""
from tqdm import tqdm
data = []
num_entities = 0
get_gold_parse = partial(_get_gold_parse, dev=dev, kb=kb)
logger.info("Reading {} data with limit {}".format('dev' if dev else 'train', limit))
with entity_file_path.open("r", encoding="utf8") as file:
with tqdm(total=limit, leave=False) as pbar:
for i, line in enumerate(file):
example = json.loads(line)
article_id = example["article_id"]
clean_text = example["clean_text"]
entities = example["entities"]
if dev != is_dev(article_id) or len(clean_text) >= 30000:
continue
doc = nlp(clean_text)
gold = get_gold_parse(doc, entities)
if gold and len(gold.links) > 0:
data.append((doc, gold))
num_entities += len(gold.links)
pbar.update(len(gold.links))
if limit and num_entities >= limit:
break
logger.info("Read {} entities in {} articles".format(num_entities, len(data)))
return data
def _get_gold_parse(doc, entities, dev, kb):
gold_entities = {}
tagged_ent_positions = set(
[(ent.start_char, ent.end_char) for ent in doc.ents]
)
for entity in entities:
entity_id = entity["entity"]
alias = entity["alias"]
start = entity["start"]
end = entity["end"]
candidates = kb.get_candidates(alias)
candidate_ids = [
c.entity_ for c in candidates
]
should_add_ent = (
dev or
(
(start, end) in tagged_ent_positions and
entity_id in candidate_ids and
len(candidates) > 1
)
)
if should_add_ent:
value_by_id = {entity_id: 1.0}
if not dev:
random.shuffle(candidate_ids)
value_by_id.update({
kb_id: 0.0
for kb_id in candidate_ids
if kb_id != entity_id
})
gold_entities[(start, end)] = value_by_id
return GoldParse(doc, links=gold_entities)
def is_dev(article_id):
return article_id.endswith("3")
def read_training(nlp, training_dir, dev, limit, kb=None):
""" This method provides training examples that correspond to the entity annotations found by the nlp object.
When kb is provided (for training), it will include negative training examples by using the candidate generator,
and it will only keep positive training examples that can be found in the KB.
When kb=None (for testing), it will include all positive examples only."""
entityfile_loc = training_dir / ENTITY_FILE
data = []
# assume the data is written sequentially, so we can reuse the article docs
current_article_id = None
current_doc = None
ents_by_offset = dict()
skip_articles = set()
total_entities = 0
with entityfile_loc.open("r", encoding="utf8") as file:
for line in file:
if not limit or len(data) < limit:
fields = line.replace("\n", "").split(sep="|")
article_id = fields[0]
alias = fields[1]
wd_id = fields[2]
start = fields[3]
end = fields[4]
if (
dev == is_dev(article_id)
and article_id != "article_id"
and article_id not in skip_articles
):
if not current_doc or (current_article_id != article_id):
# parse the new article text
file_name = article_id + ".txt"
try:
training_file = training_dir / file_name
with training_file.open("r", encoding="utf8") as f:
text = f.read()
# threshold for convenience / speed of processing
if len(text) < 30000:
current_doc = nlp(text)
current_article_id = article_id
ents_by_offset = dict()
for ent in current_doc.ents:
sent_length = len(ent.sent)
# custom filtering to avoid too long or too short sentences
if 5 < sent_length < 100:
offset = "{}_{}".format(
ent.start_char, ent.end_char
)
ents_by_offset[offset] = ent
else:
skip_articles.add(article_id)
current_doc = None
except Exception as e:
print("Problem parsing article", article_id, e)
skip_articles.add(article_id)
# repeat checking this condition in case an exception was thrown
if current_doc and (current_article_id == article_id):
offset = "{}_{}".format(start, end)
found_ent = ents_by_offset.get(offset, None)
if found_ent:
if found_ent.text != alias:
skip_articles.add(article_id)
current_doc = None
else:
sent = found_ent.sent.as_doc()
gold_start = int(start) - found_ent.sent.start_char
gold_end = int(end) - found_ent.sent.start_char
gold_entities = {}
found_useful = False
for ent in sent.ents:
entry = (ent.start_char, ent.end_char)
gold_entry = (gold_start, gold_end)
if entry == gold_entry:
# add both pos and neg examples (in random order)
# this will exclude examples not in the KB
if kb:
value_by_id = {}
candidates = kb.get_candidates(alias)
candidate_ids = [
c.entity_ for c in candidates
]
random.shuffle(candidate_ids)
for kb_id in candidate_ids:
found_useful = True
if kb_id != wd_id:
value_by_id[kb_id] = 0.0
else:
value_by_id[kb_id] = 1.0
gold_entities[entry] = value_by_id
# if no KB, keep all positive examples
else:
found_useful = True
value_by_id = {wd_id: 1.0}
gold_entities[entry] = value_by_id
# currently feeding the gold data one entity per sentence at a time
# setting all other entities to empty gold dictionary
else:
gold_entities[entry] = {}
if found_useful:
gold = GoldParse(doc=sent, links=gold_entities)
data.append((sent, gold))
total_entities += 1
if len(data) % 2500 == 0:
print(" -read", total_entities, "entities")
print(" -read", total_entities, "entities")
return data

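For reference, one line of the gold_entities.jsonl file written by _write_training_entities and parsed back in read_training looks like this (a hand-made sketch; start and end are character offsets into clean_text):

```python
import json

line = json.dumps(
    {
        "article_id": "12",
        "clean_text": "Douglas Adams was an English author.",
        "entities": [{"alias": "Douglas Adams", "entity": "Q42", "start": 0, "end": 13}],
    },
    ensure_ascii=False,
)

example = json.loads(line)
for ent in example["entities"]:
    # each alias must line up with its character span in the clean text
    assert example["clean_text"][ent["start"]:ent["end"]] == ent["alias"]
```
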
@@ -13,27 +13,25 @@ from https://dumps.wikimedia.org/enwiki/latest/
"""
from __future__ import unicode_literals
import datetime
import logging
from pathlib import Path
import plac
from bin.wiki_entity_linking import wikipedia_processor as wp
from bin.wiki_entity_linking import wikipedia_processor as wp, wikidata_processor as wd
from bin.wiki_entity_linking import kb_creator
from bin.wiki_entity_linking import training_set_creator
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_FILE, ENTITY_DESCR_PATH, KB_MODEL_DIR, LOG_FORMAT
from bin.wiki_entity_linking import ENTITY_FREQ_PATH, PRIOR_PROB_PATH, ENTITY_DEFS_PATH
import spacy
from spacy import Errors
def now():
return datetime.datetime.now()
logger = logging.getLogger(__name__)
@plac.annotations(
wd_json=("Path to the downloaded WikiData JSON dump.", "positional", None, Path),
wp_xml=("Path to the downloaded Wikipedia XML dump.", "positional", None, Path),
output_dir=("Output directory", "positional", None, Path),
model=("Model name, should include pretrained vectors.", "positional", None, str),
model=("Model name or path, should include pretrained vectors.", "positional", None, str),
max_per_alias=("Max. # entities per alias (default 10)", "option", "a", int),
min_freq=("Min. count of an entity in the corpus (default 20)", "option", "f", int),
min_pair=("Min. count of entity-alias pairs (default 5)", "option", "c", int),
@@ -41,7 +39,9 @@ def now():
loc_prior_prob=("Location to file with prior probabilities", "option", "p", Path),
loc_entity_defs=("Location to file with entity definitions", "option", "d", Path),
loc_entity_desc=("Location to file with entity descriptions", "option", "s", Path),
descriptions_from_wikipedia=("Flag for using wp descriptions not wd", "flag", "wp"),
limit=("Optional threshold to limit lines read from dumps", "option", "l", int),
lang=("Optional language for which to get wikidata titles. Defaults to 'en'", "option", "la", str),
)
def main(
wd_json,
@@ -55,20 +55,29 @@ def main(
loc_prior_prob=None,
loc_entity_defs=None,
loc_entity_desc=None,
descriptions_from_wikipedia=False,
limit=None,
lang="en",
):
print(now(), "Creating KB with Wikipedia and WikiData")
print()
entity_defs_path = loc_entity_defs if loc_entity_defs else output_dir / ENTITY_DEFS_PATH
entity_descr_path = loc_entity_desc if loc_entity_desc else output_dir / ENTITY_DESCR_PATH
entity_freq_path = output_dir / ENTITY_FREQ_PATH
prior_prob_path = loc_prior_prob if loc_prior_prob else output_dir / PRIOR_PROB_PATH
training_entities_path = output_dir / TRAINING_DATA_FILE
kb_path = output_dir / KB_FILE
logger.info("Creating KB with Wikipedia and WikiData")
if limit is not None:
print("Warning: reading only", limit, "lines of Wikipedia/Wikidata dumps.")
logger.warning("Warning: reading only {} lines of Wikipedia/Wikidata dumps.".format(limit))
# STEP 0: set up IO
if not output_dir.exists():
output_dir.mkdir()
output_dir.mkdir(parents=True)
# STEP 1: create the NLP object
print(now(), "STEP 1: loaded model", model)
logger.info("STEP 1: Loading model {}".format(model))
nlp = spacy.load(model)
# check the length of the nlp vectors
@@ -79,64 +88,68 @@ def main(
)
# STEP 2: create prior probabilities from WP
print()
if loc_prior_prob:
print(now(), "STEP 2: reading prior probabilities from", loc_prior_prob)
else:
if not prior_prob_path.exists():
# It takes about 2h to process 1000M lines of Wikipedia XML dump
loc_prior_prob = output_dir / "prior_prob.csv"
print(now(), "STEP 2: writing prior probabilities at", loc_prior_prob)
wp.read_prior_probs(wp_xml, loc_prior_prob, limit=limit)
logger.info("STEP 2: writing prior probabilities to {}".format(prior_prob_path))
wp.read_prior_probs(wp_xml, prior_prob_path, limit=limit)
logger.info("STEP 2: reading prior probabilities from {}".format(prior_prob_path))
# STEP 3: deduce entity frequencies from WP (takes only a few minutes)
print()
print(now(), "STEP 3: calculating entity frequencies")
loc_entity_freq = output_dir / "entity_freq.csv"
wp.write_entity_counts(loc_prior_prob, loc_entity_freq, to_print=False)
logger.info("STEP 3: calculating entity frequencies")
wp.write_entity_counts(prior_prob_path, entity_freq_path, to_print=False)
loc_kb = output_dir / "kb"
# STEP 4: reading entity descriptions and definitions from WikiData or from file
print()
if loc_entity_defs and loc_entity_desc:
read_raw = False
print(now(), "STEP 4a: reading entity definitions from", loc_entity_defs)
print(now(), "STEP 4b: reading entity descriptions from", loc_entity_desc)
else:
# STEP 4: reading definitions and (possibly) descriptions from WikiData or from file
message = " and descriptions" if not descriptions_from_wikipedia else ""
if (not entity_defs_path.exists()) or (not descriptions_from_wikipedia and not entity_descr_path.exists()):
# It takes about 10h to process 55M lines of Wikidata JSON dump
read_raw = True
loc_entity_defs = output_dir / "entity_defs.csv"
loc_entity_desc = output_dir / "entity_descriptions.csv"
print(now(), "STEP 4: parsing wikidata for entity definitions and descriptions")
logger.info("STEP 4: parsing wikidata for entity definitions" + message)
title_to_id, id_to_descr = wd.read_wikidata_entities_json(
wd_json,
limit,
to_print=False,
lang=lang,
parse_descriptions=(not descriptions_from_wikipedia),
)
wd.write_entity_files(entity_defs_path, title_to_id)
if not descriptions_from_wikipedia:
wd.write_entity_description_files(entity_descr_path, id_to_descr)
logger.info("STEP 4: read entity definitions" + message)
# STEP 5: creating the actual KB
# STEP 5: Getting gold entities from wikipedia
message = " and descriptions" if descriptions_from_wikipedia else ""
if (not training_entities_path.exists()) or (descriptions_from_wikipedia and not entity_descr_path.exists()):
logger.info("STEP 5: parsing wikipedia for gold entities" + message)
training_set_creator.create_training_examples_and_descriptions(
wp_xml,
entity_defs_path,
entity_descr_path,
training_entities_path,
parse_descriptions=descriptions_from_wikipedia,
limit=limit,
)
logger.info("STEP 5: read gold entities" + message)
# STEP 6: creating the actual KB
# It takes ca. 30 minutes to pretrain the entity embeddings
print()
print(now(), "STEP 5: creating the KB at", loc_kb)
logger.info("STEP 6: creating the KB at {}".format(kb_path))
kb = kb_creator.create_kb(
nlp=nlp,
max_entities_per_alias=max_per_alias,
min_entity_freq=min_freq,
min_occ=min_pair,
entity_def_output=loc_entity_defs,
entity_descr_output=loc_entity_desc,
count_input=loc_entity_freq,
prior_prob_input=loc_prior_prob,
wikidata_input=wd_json,
entity_def_input=entity_defs_path,
entity_descr_path=entity_descr_path,
count_input=entity_freq_path,
prior_prob_input=prior_prob_path,
entity_vector_length=entity_vector_length,
limit=limit,
read_raw_data=read_raw,
)
if read_raw:
print(" - wrote entity definitions to", loc_entity_defs)
print(" - wrote writing entity descriptions to", loc_entity_desc)
kb.dump(loc_kb)
nlp.to_disk(output_dir / "nlp")
kb.dump(kb_path)
nlp.to_disk(output_dir / KB_MODEL_DIR)
print()
print(now(), "Done!")
logger.info("Done!")
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
plac.call(main)

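A minimal sketch of driving the pipeline programmatically, reusing the main() defined above instead of going through the plac CLI. All paths and the model name are placeholders; per the processors above, the Wikidata dump is read with gzip and the Wikipedia dump with bz2:

```python
from pathlib import Path

main(
    Path("wikidata-latest-all.json.gz"),           # wd_json (placeholder)
    Path("enwiki-latest-pages-articles.xml.bz2"),  # wp_xml (placeholder)
    Path("kb_output"),                             # output_dir
    "en_core_web_lg",                              # model with pretrained vectors
    limit=100000,  # only read the first 100k lines of each dump while testing
)
```
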
@@ -1,17 +1,19 @@
# coding: utf-8
from __future__ import unicode_literals
import bz2
import gzip
import json
import logging
import datetime
logger = logging.getLogger(__name__)
def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False):
def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False, lang="en", parse_descriptions=True):
# Read the JSON wiki data and parse out the entities. Takes about 7.5 hours to parse 55M lines.
# get latest-all.json.bz2 from https://dumps.wikimedia.org/wikidatawiki/entities/
lang = "en"
site_filter = "enwiki"
site_filter = '{}wiki'.format(lang)
# properties filter (currently disabled to get ALL data)
prop_filter = dict()
@@ -24,18 +26,15 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False):
parse_properties = False
parse_sitelinks = True
parse_labels = False
parse_descriptions = True
parse_aliases = False
parse_claims = False
with bz2.open(wikidata_file, mode="rb") as file:
line = file.readline()
cnt = 0
while line and (not limit or cnt < limit):
if cnt % 1000000 == 0:
print(
datetime.datetime.now(), "processed", cnt, "lines of WikiData JSON dump"
)
with gzip.open(wikidata_file, mode='rb') as file:
for cnt, line in enumerate(file):
if limit and cnt >= limit:
break
if cnt % 500000 == 0:
logger.info("processed {} lines of WikiData dump".format(cnt))
clean_line = line.strip()
if clean_line.endswith(b","):
clean_line = clean_line[:-1]
@@ -134,8 +133,19 @@ def read_wikidata_entities_json(wikidata_file, limit=None, to_print=False):
if to_print:
print()
line = file.readline()
cnt += 1
print(datetime.datetime.now(), "processed", cnt, "lines of WikiData JSON dump")
return title_to_id, id_to_descr
def write_entity_files(entity_def_output, title_to_id):
with entity_def_output.open("w", encoding="utf8") as id_file:
id_file.write("WP_title" + "|" + "WD_id" + "\n")
for title, qid in title_to_id.items():
id_file.write(title + "|" + str(qid) + "\n")
def write_entity_description_files(entity_descr_output, id_to_descr):
with entity_descr_output.open("w", encoding="utf8") as descr_file:
descr_file.write("WD_id" + "|" + "description" + "\n")
for qid, descr in id_to_descr.items():
descr_file.write(str(qid) + "|" + descr + "\n")

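The entity definition file written above is plain pipe-delimited text with a header row; a small round-trip sketch (the file name is a placeholder), matching how get_entity_to_id reads it back:

```python
from pathlib import Path

path = Path("entity_defs.csv")
path.write_text("WP_title|WD_id\nDouglas Adams|Q42\n", encoding="utf8")

entity_to_id = {}
with path.open("r", encoding="utf8") as f:
    next(f)  # skip the header row
    for row in f:
        title, qid = row.rstrip("\n").split("|")
        entity_to_id[title] = qid

print(entity_to_id)  # {'Douglas Adams': 'Q42'}
```
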
@@ -11,124 +11,84 @@ from https://dumps.wikimedia.org/enwiki/latest/
from __future__ import unicode_literals
import random
import datetime
import logging
from pathlib import Path
import plac
from bin.wiki_entity_linking import training_set_creator
from bin.wiki_entity_linking import TRAINING_DATA_FILE, KB_MODEL_DIR, KB_FILE, LOG_FORMAT, OUTPUT_MODEL_DIR
from bin.wiki_entity_linking.entity_linker_evaluation import measure_performance, measure_baselines
from bin.wiki_entity_linking.kb_creator import read_nlp_kb
import spacy
from spacy.kb import KnowledgeBase
from spacy.util import minibatch, compounding
def now():
return datetime.datetime.now()
logger = logging.getLogger(__name__)
@plac.annotations(
dir_kb=("Directory with KB, NLP and related files", "positional", None, Path),
output_dir=("Output directory", "option", "o", Path),
loc_training=("Location to training data", "option", "k", Path),
wp_xml=("Path to the downloaded Wikipedia XML dump.", "option", "w", Path),
epochs=("Number of training iterations (default 10)", "option", "e", int),
dropout=("Dropout to prevent overfitting (default 0.5)", "option", "p", float),
lr=("Learning rate (default 0.005)", "option", "n", float),
l2=("L2 regularization", "option", "r", float),
train_inst=("# training instances (default 90% of all)", "option", "t", int),
dev_inst=("# test instances (default 10% of all)", "option", "d", int),
limit=("Optional threshold to limit lines read from WP dump", "option", "l", int),
)
def main(
dir_kb,
output_dir=None,
loc_training=None,
wp_xml=None,
epochs=10,
dropout=0.5,
lr=0.005,
l2=1e-6,
train_inst=None,
dev_inst=None,
limit=None,
):
print(now(), "Creating Entity Linker with Wikipedia and WikiData")
print()
logger.info("Creating Entity Linker with Wikipedia and WikiData")
output_dir = Path(output_dir) if output_dir else dir_kb
training_path = loc_training if loc_training else output_dir / TRAINING_DATA_FILE
nlp_dir = dir_kb / KB_MODEL_DIR
kb_path = output_dir / KB_FILE
nlp_output_dir = output_dir / OUTPUT_MODEL_DIR
# STEP 0: set up IO
if output_dir and not output_dir.exists():
if not output_dir.exists():
output_dir.mkdir()
# STEP 1 : load the NLP object
nlp_dir = dir_kb / "nlp"
print(now(), "STEP 1: loading model from", nlp_dir)
nlp = spacy.load(nlp_dir)
logger.info("STEP 1: loading model from {}".format(nlp_dir))
nlp, kb = read_nlp_kb(nlp_dir, kb_path)
# check that there is a NER component in the pipeline
if "ner" not in nlp.pipe_names:
raise ValueError("The `nlp` object should have a pre-trained `ner` component.")
# STEP 2 : read the KB
print()
print(now(), "STEP 2: reading the KB from", dir_kb / "kb")
kb = KnowledgeBase(vocab=nlp.vocab)
kb.load_bulk(dir_kb / "kb")
# STEP 2: create a training dataset from WP
logger.info("STEP 2: reading training dataset from {}".format(training_path))
# STEP 3: create a training dataset from WP
print()
if loc_training:
print(now(), "STEP 3: reading training dataset from", loc_training)
else:
if not wp_xml:
raise ValueError(
"Either provide a path to a preprocessed training directory, "
"or to the original Wikipedia XML dump."
)
if output_dir:
loc_training = output_dir / "training_data"
else:
loc_training = dir_kb / "training_data"
if not loc_training.exists():
loc_training.mkdir()
print(now(), "STEP 3: creating training dataset at", loc_training)
if limit is not None:
print("Warning: reading only", limit, "lines of Wikipedia dump.")
loc_entity_defs = dir_kb / "entity_defs.csv"
training_set_creator.create_training(
wikipedia_input=wp_xml,
entity_def_input=loc_entity_defs,
training_output=loc_training,
limit=limit,
)
# STEP 4: parse the training data
print()
print(now(), "STEP 4: parse the training & evaluation data")
# for training, get pos & neg instances that correspond to entries in the kb
print("Parsing training data, limit =", train_inst)
train_data = training_set_creator.read_training(
nlp=nlp, training_dir=loc_training, dev=False, limit=train_inst, kb=kb
nlp=nlp,
entity_file_path=training_path,
dev=False,
limit=train_inst,
kb=kb,
)
print("Training on", len(train_data), "articles")
print()
print("Parsing dev testing data, limit =", dev_inst)
# for testing, get all pos instances, whether or not they are in the kb
dev_data = training_set_creator.read_training(
nlp=nlp, training_dir=loc_training, dev=True, limit=dev_inst, kb=None
nlp=nlp,
entity_file_path=training_path,
dev=True,
limit=dev_inst,
kb=kb,
)
print("Dev testing on", len(dev_data), "articles")
print()
# STEP 5: create and train the entity linking pipe
print()
print(now(), "STEP 5: training Entity Linking pipe")
# STEP 3: create and train the entity linking pipe
logger.info("STEP 3: training Entity Linking pipe")
el_pipe = nlp.create_pipe(
name="entity_linker", config={"pretrained_vectors": nlp.vocab.vectors.name}
@@ -142,275 +102,70 @@ def main(
optimizer.learn_rate = lr
optimizer.L2 = l2
if not train_data:
print("Did not find any training data")
else:
for itn in range(epochs):
random.shuffle(train_data)
losses = {}
batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
batchnr = 0
logger.info("Training on {} articles".format(len(train_data)))
logger.info("Dev testing on {} articles".format(len(dev_data)))
with nlp.disable_pipes(*other_pipes):
for batch in batches:
try:
docs, golds = zip(*batch)
nlp.update(
docs=docs,
golds=golds,
sgd=optimizer,
drop=dropout,
losses=losses,
)
batchnr += 1
except Exception as e:
print("Error updating batch:", e)
if batchnr > 0:
el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = True
dev_acc_context, _ = _measure_acc(dev_data, el_pipe)
losses["entity_linker"] = losses["entity_linker"] / batchnr
print(
"Epoch, train loss",
itn,
round(losses["entity_linker"], 2),
" / dev accuracy avg",
round(dev_acc_context, 3),
)
# STEP 6: measure the performance of our trained pipe on an independent dev set
print()
if len(dev_data):
print()
print(now(), "STEP 6: performance measurement of Entity Linking pipe")
print()
counts, acc_r, acc_r_d, acc_p, acc_p_d, acc_o, acc_o_d = _measure_baselines(
dev_data, kb
)
print("dev counts:", sorted(counts.items(), key=lambda x: x[0]))
oracle_by_label = [(x, round(y, 3)) for x, y in acc_o_d.items()]
print("dev accuracy oracle:", round(acc_o, 3), oracle_by_label)
random_by_label = [(x, round(y, 3)) for x, y in acc_r_d.items()]
print("dev accuracy random:", round(acc_r, 3), random_by_label)
prior_by_label = [(x, round(y, 3)) for x, y in acc_p_d.items()]
print("dev accuracy prior:", round(acc_p, 3), prior_by_label)
# using only context
el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = False
dev_acc_context, dev_acc_cont_d = _measure_acc(dev_data, el_pipe)
context_by_label = [(x, round(y, 3)) for x, y in dev_acc_cont_d.items()]
print("dev accuracy context:", round(dev_acc_context, 3), context_by_label)
# measuring combined accuracy (prior + context)
el_pipe.cfg["incl_context"] = True
el_pipe.cfg["incl_prior"] = True
dev_acc_combo, dev_acc_combo_d = _measure_acc(dev_data, el_pipe)
combo_by_label = [(x, round(y, 3)) for x, y in dev_acc_combo_d.items()]
print("dev accuracy prior+context:", round(dev_acc_combo, 3), combo_by_label)
# STEP 7: apply the EL pipe on a toy example
print()
print(now(), "STEP 7: applying Entity Linking to toy example")
print()
run_el_toy_example(nlp=nlp)
# STEP 8: write the NLP pipeline (including entity linker) to file
if output_dir:
print()
nlp_loc = output_dir / "nlp"
print(now(), "STEP 8: Writing trained NLP to", nlp_loc)
nlp.to_disk(nlp_loc)
print()
print()
print(now(), "Done!")
def _measure_acc(data, el_pipe=None, error_analysis=False):
# If the docs in the data require further processing with an entity linker, set el_pipe
correct_by_label = dict()
incorrect_by_label = dict()
docs = [d for d, g in data if len(d) > 0]
if el_pipe is not None:
docs = list(el_pipe.pipe(docs))
golds = [g for d, g in data if len(d) > 0]
for doc, gold in zip(docs, golds):
try:
correct_entries_per_article = dict()
for entity, kb_dict in gold.links.items():
start, end = entity
# only evaluating on positive examples
for gold_kb, value in kb_dict.items():
if value:
offset = _offset(start, end)
correct_entries_per_article[offset] = gold_kb
for ent in doc.ents:
ent_label = ent.label_
pred_entity = ent.kb_id_
start = ent.start_char
end = ent.end_char
offset = _offset(start, end)
gold_entity = correct_entries_per_article.get(offset, None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None:
if gold_entity == pred_entity:
correct = correct_by_label.get(ent_label, 0)
correct_by_label[ent_label] = correct + 1
else:
incorrect = incorrect_by_label.get(ent_label, 0)
incorrect_by_label[ent_label] = incorrect + 1
if error_analysis:
print(ent.text, "in", doc)
print(
"Predicted",
pred_entity,
"should have been",
gold_entity,
)
print()
except Exception as e:
print("Error assessing accuracy", e)
acc, acc_by_label = calculate_acc(correct_by_label, incorrect_by_label)
return acc, acc_by_label
def _measure_baselines(data, kb):
# Measure 3 performance baselines: random selection, prior probabilities, and 'oracle' prediction for upper bound
counts_d = dict()
random_correct_d = dict()
random_incorrect_d = dict()
oracle_correct_d = dict()
oracle_incorrect_d = dict()
prior_correct_d = dict()
prior_incorrect_d = dict()
docs = [d for d, g in data if len(d) > 0]
golds = [g for d, g in data if len(d) > 0]
for doc, gold in zip(docs, golds):
try:
correct_entries_per_article = dict()
for entity, kb_dict in gold.links.items():
start, end = entity
for gold_kb, value in kb_dict.items():
# only evaluating on positive examples
if value:
offset = _offset(start, end)
correct_entries_per_article[offset] = gold_kb
for ent in doc.ents:
label = ent.label_
start = ent.start_char
end = ent.end_char
offset = _offset(start, end)
gold_entity = correct_entries_per_article.get(offset, None)
# the gold annotations are not complete so we can't evaluate missing annotations as 'wrong'
if gold_entity is not None:
counts_d[label] = counts_d.get(label, 0) + 1
candidates = kb.get_candidates(ent.text)
oracle_candidate = ""
best_candidate = ""
random_candidate = ""
if candidates:
scores = []
for c in candidates:
scores.append(c.prior_prob)
if c.entity_ == gold_entity:
oracle_candidate = c.entity_
best_index = scores.index(max(scores))
best_candidate = candidates[best_index].entity_
random_candidate = random.choice(candidates).entity_
if gold_entity == best_candidate:
prior_correct_d[label] = prior_correct_d.get(label, 0) + 1
else:
prior_incorrect_d[label] = prior_incorrect_d.get(label, 0) + 1
if gold_entity == random_candidate:
random_correct_d[label] = random_correct_d.get(label, 0) + 1
else:
random_incorrect_d[label] = random_incorrect_d.get(label, 0) + 1
if gold_entity == oracle_candidate:
oracle_correct_d[label] = oracle_correct_d.get(label, 0) + 1
else:
oracle_incorrect_d[label] = oracle_incorrect_d.get(label, 0) + 1
except Exception as e:
print("Error assessing accuracy", e)
acc_prior, acc_prior_d = calculate_acc(prior_correct_d, prior_incorrect_d)
acc_rand, acc_rand_d = calculate_acc(random_correct_d, random_incorrect_d)
acc_oracle, acc_oracle_d = calculate_acc(oracle_correct_d, oracle_incorrect_d)
return (
counts_d,
acc_rand,
acc_rand_d,
acc_prior,
acc_prior_d,
acc_oracle,
acc_oracle_d,
)
dev_baseline_accuracies = measure_baselines(dev_data, kb)
logger.info("Dev Baseline Accuracies:")
logger.info(dev_baseline_accuracies.report_accuracy("random"))
logger.info(dev_baseline_accuracies.report_accuracy("prior"))
logger.info(dev_baseline_accuracies.report_accuracy("oracle"))
def _offset(start, end):
return "{}_{}".format(start, end)
for itn in range(epochs):
random.shuffle(train_data)
losses = {}
batches = minibatch(train_data, size=compounding(4.0, 128.0, 1.001))
batchnr = 0
with nlp.disable_pipes(*other_pipes):
for batch in batches:
try:
docs, golds = zip(*batch)
nlp.update(
docs=docs,
golds=golds,
sgd=optimizer,
drop=dropout,
losses=losses,
)
batchnr += 1
except Exception as e:
logger.error("Error updating batch:" + str(e))
if batchnr > 0:
logger.info("Epoch {}, train loss {}".format(itn, round(losses["entity_linker"] / batchnr, 2)))
measure_performance(dev_data, kb, el_pipe)
def calculate_acc(correct_by_label, incorrect_by_label):
acc_by_label = dict()
total_correct = 0
total_incorrect = 0
all_keys = set()
all_keys.update(correct_by_label.keys())
all_keys.update(incorrect_by_label.keys())
for label in sorted(all_keys):
correct = correct_by_label.get(label, 0)
incorrect = incorrect_by_label.get(label, 0)
total_correct += correct
total_incorrect += incorrect
if correct == incorrect == 0:
acc_by_label[label] = 0
else:
acc_by_label[label] = correct / (correct + incorrect)
acc = 0
if not (total_correct == total_incorrect == 0):
acc = total_correct / (total_correct + total_incorrect)
return acc, acc_by_label
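For a concrete feel of the helper above, with hypothetical per-label counts:
correct = {"PERSON": 8, "ORG": 2}
incorrect = {"PERSON": 2, "GPE": 1}
acc, acc_by_label = calculate_acc(correct, incorrect)
# acc == 10 / 13 (~0.77)
# acc_by_label == {"GPE": 0.0, "ORG": 1.0, "PERSON": 0.8}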
# STEP 4: measure the performance of our trained pipe on an independent dev set
logger.info("STEP 4: performance measurement of Entity Linking pipe")
measure_performance(dev_data, kb, el_pipe)
# STEP 5: apply the EL pipe on a toy example
logger.info("STEP 5: applying Entity Linking to toy example")
run_el_toy_example(nlp=nlp)
if output_dir:
# STEP 6: write the NLP pipeline (including entity linker) to file
logger.info("STEP 6: Writing trained NLP to {}".format(nlp_output_dir))
nlp.to_disk(nlp_output_dir)
logger.info("Done!")
def check_kb(kb):
for mention in ("Bush", "Douglas Adams", "Homer", "Brazil", "China"):
candidates = kb.get_candidates(mention)
print("generating candidates for " + mention + " :")
logger.info("generating candidates for " + mention + " :")
for c in candidates:
print(
" ",
c.prior_prob,
logger.info(" ".join[
str(c.prior_prob),
c.alias_,
"-->",
c.entity_ + " (freq=" + str(c.entity_freq) + ")",
)
print()
c.entity_ + " (freq=" + str(c.entity_freq) + ")"
]))
def run_el_toy_example(nlp):
@ -421,11 +176,11 @@ def run_el_toy_example(nlp):
"but Dougledydoug doesn't write about George Washington or Homer Simpson."
)
doc = nlp(text)
print(text)
logger.info(text)
for ent in doc.ents:
print(" ent", ent.text, ent.label_, ent.kb_id_)
print()
logger.info(" ".join(["ent", ent.text, ent.label_, ent.kb_id_]))
if __name__ == "__main__":
logging.basicConfig(level=logging.INFO, format=LOG_FORMAT)
plac.call(main)


@ -5,6 +5,9 @@ import re
import bz2
import csv
import datetime
import logging
from bin.wiki_entity_linking import LOG_FORMAT
"""
Process a Wikipedia dump to calculate entity frequencies and prior probabilities in combination with certain mentions.
@ -13,6 +16,9 @@ Write these results to file for downstream KB and training data generation.
map_alias_to_link = dict()
logger = logging.getLogger(__name__)
# these will/should be matched ignoring case
wiki_namespaces = [
"b",
@ -116,10 +122,6 @@ for ns in wiki_namespaces:
ns_regex = re.compile(ns_regex, re.IGNORECASE)
def now():
return datetime.datetime.now()
def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
"""
Read the XML wikipedia data and parse out intra-wiki links to estimate prior probabilities.
@ -131,7 +133,7 @@ def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
cnt = 0
while line and (not limit or cnt < limit):
if cnt % 25000000 == 0:
print(now(), "processed", cnt, "lines of Wikipedia XML dump")
logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
clean_line = line.strip().decode("utf-8")
aliases, entities, normalizations = get_wp_links(clean_line)
@ -141,7 +143,7 @@ def read_prior_probs(wikipedia_input, prior_prob_output, limit=None):
line = file.readline()
cnt += 1
print(now(), "processed", cnt, "lines of Wikipedia XML dump")
logger.info("processed {} lines of Wikipedia XML dump".format(cnt))
# write all aliases and their entities and count occurrences to file
with prior_prob_output.open("w", encoding="utf8") as outputfile:
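To make the downstream use of these counts concrete, here is a minimal sketch, with a hypothetical helper name and data shape, of turning per-alias entity counts into prior probabilities:
def compute_priors(alias_counts):
    # alias_counts: {alias: {entity: count}}, tallied from intra-wiki links
    priors = {}
    for alias, entity_counts in alias_counts.items():
        total = sum(entity_counts.values())
        priors[alias] = {ent: cnt / total for ent, cnt in entity_counts.items()}
    return priors
# compute_priors({"Bush": {"George_W._Bush": 90, "Bush_(band)": 10}})
# -> {"Bush": {"George_W._Bush": 0.9, "Bush_(band)": 0.1}}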


@ -3,11 +3,9 @@
"""
from __future__ import unicode_literals
import plac
import tqdm
import attr
from pathlib import Path
import re
import sys
import json
import spacy
@ -23,7 +21,7 @@ import itertools
import random
import numpy.random
import conll17_ud_eval
from bin.ud import conll17_ud_eval
import spacy.lang.zh
import spacy.lang.ja
@ -394,6 +392,9 @@ class TreebankPaths(object):
limit=("Size limit", "option", "n", int),
)
def main(ud_dir, parses_dir, config, corpus, limit=0):
# temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
import tqdm
paths = TreebankPaths(ud_dir, corpus)
if not (parses_dir / corpus).exists():
(parses_dir / corpus).mkdir()
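The "temp fix" comment above, which recurs in several files below, refers to moving the tqdm import from module scope into the function body, so the import only runs when the function is called; a generic sketch of the pattern:
def process(items):
    # Deferred import: importing this module no longer imports tqdm,
    # which avoids the import-time issue tracked in spaCy issue #4200.
    import tqdm
    for item in tqdm.tqdm(items):
        pass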


@ -18,7 +18,6 @@ import random
import spacy
import thinc.extra.datasets
from spacy.util import minibatch, use_gpu, compounding
import tqdm
from spacy._ml import Tok2Vec
from spacy.pipeline import TextCategorizer
import numpy
@ -107,6 +106,9 @@ def create_pipeline(width, embed_size, vectors_model):
def train_tensorizer(nlp, texts, dropout, n_iter):
# temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
import tqdm
tensorizer = nlp.create_pipe("tensorizer")
nlp.add_pipe(tensorizer)
optimizer = nlp.begin_training()
@ -120,6 +122,9 @@ def train_tensorizer(nlp, texts, dropout, n_iter):
def train_textcat(nlp, n_texts, n_iter=10):
# temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
import tqdm
textcat = nlp.get_pipe("textcat")
tok2vec_weights = textcat.model.tok2vec.to_bytes()
(train_texts, train_cats), (dev_texts, dev_cats) = load_textcat_data(limit=n_texts)


@ -13,7 +13,6 @@ import numpy
import plac
import spacy
import tensorflow as tf
import tqdm
from tensorflow.contrib.tensorboard.plugins.projector import (
visualize_embeddings,
ProjectorConfig,
@ -36,6 +35,9 @@ from tensorflow.contrib.tensorboard.plugins.projector import (
),
)
def main(vectors_loc, out_loc, name="spaCy_vectors"):
# temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
import tqdm
meta_file = "{}.tsv".format(name)
out_meta_file = path.join(out_loc, meta_file)


@ -140,7 +140,7 @@ def gzip_language_data(root, source):
base = Path(root) / source
for jsonfile in base.glob("**/*.json"):
outfile = jsonfile.with_suffix(jsonfile.suffix + ".gz")
if outfile.is_file() and outfile.stat().st_ctime > jsonfile.stat().st_ctime:
if outfile.is_file() and outfile.stat().st_mtime > jsonfile.stat().st_mtime:
# If the gz is newer it doesn't need updating
print("Skipping {}, already compressed".format(jsonfile))
continue
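The switch matters because st_ctime on Unix is the inode-change time (updated by chmod or renames), while st_mtime tracks content modification; a minimal sketch of the corrected freshness check, with illustrative names:
from pathlib import Path

def is_up_to_date(outfile: Path, source: Path) -> bool:
    # The .gz is current only if it was written after the source JSON changed.
    return outfile.is_file() and outfile.stat().st_mtime > source.stat().st_mtime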


@ -6,7 +6,7 @@ import re
from ...gold import iob_to_biluo
def conllu2json(input_data, n_sents=10, use_morphology=False, lang=None):
def conllu2json(input_data, n_sents=10, use_morphology=False, lang=None, **_):
"""
Convert conllu files into JSON format for use with train cli.
use_morphology parameter enables appending morphology to tags, which is


@ -1,7 +1,6 @@
# coding: utf8
from __future__ import unicode_literals
import re
from wasabi import Printer
from ...gold import iob_to_biluo


@ -3,7 +3,6 @@ from __future__ import unicode_literals
import plac
import math
from tqdm import tqdm
import numpy
from ast import literal_eval
from pathlib import Path
@ -109,6 +108,9 @@ def open_file(loc):
def read_attrs_from_deprecated(freqs_loc, clusters_loc):
# temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
from tqdm import tqdm
if freqs_loc is not None:
with msg.loading("Counting frequencies..."):
probs, _ = read_freqs(freqs_loc)
@ -186,6 +188,9 @@ def add_vectors(nlp, vectors_loc, prune_vectors):
def read_vectors(vectors_loc):
# temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
from tqdm import tqdm
f = open_file(vectors_loc)
shape = tuple(int(size) for size in next(f).split())
vectors_data = numpy.zeros(shape=shape, dtype="f")
@ -202,6 +207,9 @@ def read_vectors(vectors_loc):
def read_freqs(freqs_loc, max_length=100, min_doc_freq=5, min_freq=50):
# temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
from tqdm import tqdm
counts = PreshCounter()
total = 0
with freqs_loc.open() as f:
@ -231,6 +239,9 @@ def read_freqs(freqs_loc, max_length=100, min_doc_freq=5, min_freq=50):
def read_clusters(clusters_loc):
# temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
from tqdm import tqdm
clusters = {}
if ftfy is None:
user_warning(Warnings.W004)


@ -7,7 +7,6 @@ import srsly
import cProfile
import pstats
import sys
import tqdm
import itertools
import thinc.extra.datasets
from wasabi import Printer
@ -48,6 +47,9 @@ def profile(model, inputs=None, n_texts=10000):
def parse_texts(nlp, texts):
# temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
import tqdm
for doc in nlp.pipe(tqdm.tqdm(texts), batch_size=16):
pass


@ -4,7 +4,6 @@ from __future__ import unicode_literals, division, print_function
import plac
import os
from pathlib import Path
import tqdm
from thinc.neural._classes.model import Model
from timeit import default_timer as timer
import shutil
@ -101,6 +100,10 @@ def train(
JSON format. To convert data from other formats, use the `spacy convert`
command.
"""
# temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
import tqdm
msg = Printer()
util.fix_random_seed()
util.set_env_log(verbose)
@ -390,6 +393,9 @@ def _score_for_model(meta):
@contextlib.contextmanager
def _create_progress_bar(total):
# temp fix to avoid import issues cf https://github.com/explosion/spaCy/issues/4200
import tqdm
if int(os.environ.get("LOG_FRIENDLY", 0)):
yield
else:


@ -452,6 +452,11 @@ class Errors(object):
"Make sure that you're passing in absolute token indices, not "
"relative token offsets.\nstart: {start}, end: {end}, label: "
"{label}, direction: {dir}")
E158 = ("Can't add table '{name}' to lookups because it already exists.")
E159 = ("Can't find table '{name}' in lookups. Available tables: {tables}")
E160 = ("Can't find language data file: {path}")
E161 = ("Found an internal inconsistency when predicting entity links. "
"This is likely a bug in spaCy, so feel free to open an issue.")
@add_codes


@ -11,6 +11,12 @@ _hebrew = r"\u0591-\u05F4\uFB1D-\uFB4F"
_hindi = r"\u0900-\u097F"
_kannada = r"\u0C80-\u0CFF"
_tamil = r"\u0B80-\u0BFF"
_telugu = r"\u0C00-\u0C7F"
# Latin standard
_latin_u_standard = r"A-Z"
_latin_l_standard = r"a-z"
@ -195,7 +201,7 @@ _ukrainian = r"а-щюяіїєґА-ЩЮЯІЇЄҐ"
_upper = LATIN_UPPER + _russian_upper + _tatar_upper + _greek_upper + _ukrainian_upper
_lower = LATIN_LOWER + _russian_lower + _tatar_lower + _greek_lower + _ukrainian_lower
_uncased = _bengali + _hebrew + _persian + _sinhala + _hindi
_uncased = _bengali + _hebrew + _persian + _sinhala + _hindi + _kannada + _tamil + _telugu
ALPHA = group_chars(LATIN + _russian + _tatar + _greek + _ukrainian + _uncased)
ALPHA_LOWER = group_chars(_lower + _uncased)
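Kannada, Tamil and Telugu are caseless scripts, so their ranges belong in _uncased, which feeds into ALPHA; a quick check of the effect (character choices illustrative):
import re
from spacy.lang.char_classes import ALPHA

assert re.match("[{a}]".format(a=ALPHA), "ಕ")  # Kannada letter KA
assert re.match("[{a}]".format(a=ALPHA), "த")  # Tamil letter TA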


@ -46,6 +46,11 @@ class GreekLemmatizer(object):
)
return lemmas
def lookup(self, string):
if string in self.lookup_table:
return self.lookup_table[string]
return string
def lemmatize(string, index, exceptions, rules):
string = string.lower()


@ -18,6 +18,7 @@ class CroatianDefaults(Language.Defaults):
)
tokenizer_exceptions = update_exc(BASE_EXCEPTIONS)
stop_words = STOP_WORDS
resources = {"lemma_lookup": "lemma_lookup.json"}
class Croatian(Language):

spacy/lang/hr/lemma_lookup.json (new file, 1313609 lines): diff suppressed because it is too large.

@ -0,0 +1,15 @@
The list of Croatian lemmas was extracted from the reldi-tagger repository (https://github.com/clarinsi/reldi-tagger).
Reldi-tagger is licensed under the Apache 2.0 licence.
@InProceedings{ljubesic16-new,
author = {Nikola Ljubešić and Filip Klubička and Željko Agić and Ivo-Pavao Jazbec},
title = {New Inflectional Lexicons and Training Corpora for Improved Morphosyntactic Annotation of Croatian and Serbian},
booktitle = {Proceedings of the Tenth International Conference on Language Resources and Evaluation (LREC 2016)},
year = {2016},
date = {23-28},
location = {Portorož, Slovenia},
editor = {Nicoletta Calzolari (Conference Chair) and Khalid Choukri and Thierry Declerck and Sara Goggi and Marko Grobelnik and Bente Maegaard and Joseph Mariani and Helene Mazo and Asuncion Moreno and Jan Odijk and Stelios Piperidis},
publisher = {European Language Resources Association (ELRA)},
address = {Paris, France},
isbn = {978-2-9517408-9-1}
}


@ -1,8 +1,6 @@
# coding: utf8
from __future__ import unicode_literals
from pathlib import Path
from .tokenizer_exceptions import TOKENIZER_EXCEPTIONS
from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS


@ -3,7 +3,7 @@ from __future__ import unicode_literals
from .char_classes import LIST_PUNCT, LIST_ELLIPSES, LIST_QUOTES, LIST_CURRENCY
from .char_classes import LIST_ICONS, HYPHENS, CURRENCY, UNITS
from .char_classes import CONCAT_QUOTES, ALPHA_LOWER, ALPHA_UPPER, ALPHA
from .char_classes import CONCAT_QUOTES, ALPHA_LOWER, ALPHA_UPPER, ALPHA, PUNCT
_prefixes = (
@ -27,8 +27,8 @@ _suffixes = (
r"(?<=°[FfCcKk])\.",
r"(?<=[0-9])(?:{c})".format(c=CURRENCY),
r"(?<=[0-9])(?:{u})".format(u=UNITS),
r"(?<=[0-9{al}{e}(?:{q})])\.".format(
al=ALPHA_LOWER, e=r"%²\-\+", q=CONCAT_QUOTES
r"(?<=[0-9{al}{e}{p}(?:{q})])\.".format(
al=ALPHA_LOWER, e=r"%²\-\+", q=CONCAT_QUOTES, p=PUNCT
),
r"(?<=[{au}][{au}])\.".format(au=ALPHA_UPPER),
]


@ -9,6 +9,7 @@ from ..norm_exceptions import BASE_NORMS
from ...language import Language
from ...attrs import LANG, NORM
from ...util import update_exc, add_lookups
from .tag_map import TAG_MAP
# Lemma data note:
# Original pairs downloaded from http://www.lexiconista.com/datasets/lemmatization/
@ -24,6 +25,7 @@ class RomanianDefaults(Language.Defaults):
tokenizer_exceptions = update_exc(BASE_EXCEPTIONS, TOKENIZER_EXCEPTIONS)
stop_words = STOP_WORDS
resources = {"lemma_lookup": "lemma_lookup.json"}
tag_map = TAG_MAP
class Romanian(Language):

spacy/lang/ro/tag_map.py (new file, 1654 lines): diff suppressed because it is too large.

@ -21,6 +21,7 @@ class SerbianDefaults(Language.Defaults):
)
tokenizer_exceptions = update_exc(BASE_EXCEPTIONS, TOKENIZER_EXCEPTIONS)
stop_words = STOP_WORDS
resources = {"lemma_lookup": "lemma_lookup.json"}
class Serbian(Language):


@ -12,13 +12,14 @@ Example sentences to test spaCy and its language models.
sentences = [
# Translations from English
"Apple планира куповину америчког стартапа за $1 милијарду."
"Apple планира куповину америчког стартапа за $1 милијарду.",
"Беспилотни аутомобили пребацују одговорност осигурања на произвођаче.",
"Лондон је велики град у Уједињеном Краљевству.",
"Где си ти?",
"Ко је председник Француске?",
# Serbian common and slang
"Moj ћале је инжењер!",
"Новак Ђоковић је најбољи тенисер света." "У Пироту има добрих кафана!",
"Новак Ђоковић је најбољи тенисер света.",
"У Пироту има добрих кафана!",
"Музеј Николе Тесле се налази у Београду.",
]

spacy/lang/sr/lemma_lookup.json (new executable file, 253316 lines): diff suppressed because it is too large.

@ -0,0 +1,32 @@
Copyright
@InProceedings{ljubesic16-new,
author = {Nikola Ljubešić and Filip Klubička and Željko Agić and Ivo-Pavao Jazbec},
title = {New Inflectional Lexicons and Training Corpora for Improved Morphosyntactic Annotation of Croatian and Serbian},
booktitle = {Proceedings of the Tenth International Conference on Language Resources and Evaluation (LREC 2016)},
year = {2016},
date = {23-28},
location = {Portorož, Slovenia},
editor = {Nicoletta Calzolari (Conference Chair) and Khalid Choukri and Thierry Declerck and Sara Goggi and Marko Grobelnik and Bente Maegaard and Joseph Mariani and Helene Mazo and Asuncion Moreno and Jan Odijk and Stelios Piperidis},
publisher = {European Language Resources Association (ELRA)},
address = {Paris, France},
isbn = {978-2-9517408-9-1}
}
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
The licence of the Serbian lemmas was adopted from the Serbian lexicon:
- sr.lexicon (https://github.com/clarinsi/reldi-tagger/blob/master/sr.lexicon)
Changelog:
- Lexicon is transliterated into Cyrillic
- Word order is sorted


@ -24,7 +24,7 @@ class UkrainianDefaults(Language.Defaults):
stop_words = STOP_WORDS
@classmethod
def create_lemmatizer(cls, nlp=None):
def create_lemmatizer(cls, nlp=None, **kwargs):
return UkrainianLemmatizer()


@ -208,6 +208,7 @@ class Language(object):
"name": self.vocab.vectors.name,
}
self._meta["pipeline"] = self.pipe_names
self._meta["labels"] = self.pipe_labels
return self._meta
@meta.setter
@ -248,6 +249,18 @@ class Language(object):
"""
return [pipe_name for pipe_name, _ in self.pipeline]
@property
def pipe_labels(self):
"""Get the labels set by the pipeline components, if available.
RETURNS (dict): Labels keyed by component name.
"""
labels = OrderedDict()
for name, pipe in self.pipeline:
if hasattr(pipe, "labels"):
labels[name] = list(pipe.labels)
return labels
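A minimal usage sketch of the new property, using the v2-style API seen elsewhere in this diff (label choice illustrative):
from spacy.lang.en import English

nlp = English()
ner = nlp.create_pipe("ner")
ner.add_label("PERSON")
nlp.add_pipe(ner)
print(nlp.pipe_labels)  # e.g. {'ner': ['PERSON']}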
def get_pipe(self, name):
"""Get a pipeline component for a given component name.


@ -1,52 +1,157 @@
# coding: utf8
from __future__ import unicode_literals
from .util import SimpleFrozenDict
import srsly
from collections import OrderedDict
from .errors import Errors
from .util import SimpleFrozenDict, ensure_path
class Lookups(object):
"""Container for large lookup tables and dictionaries, e.g. lemmatization
data or tokenizer exception lists. Lookups are available via vocab.lookups,
so they can be accessed before the pipeline components are applied (e.g.
in the tokenizer and lemmatizer), as well as within the pipeline components
via doc.vocab.lookups.
Important note: At the moment, this class only performs a very basic
dictionary lookup. We're planning to replace this with a more efficient
implementation. See #3971 for details.
"""
def __init__(self):
self._tables = {}
"""Initialize the Lookups object.
RETURNS (Lookups): The newly created object.
"""
self._tables = OrderedDict()
def __contains__(self, name):
"""Check if the lookups contain a table of a given name. Delegates to
Lookups.has_table.
name (unicode): Name of the table.
RETURNS (bool): Whether a table of that name exists.
"""
return self.has_table(name)
def __len__(self):
"""RETURNS (int): The number of tables in the lookups."""
return len(self._tables)
@property
def tables(self):
"""RETURNS (list): Names of all tables in the lookups."""
return list(self._tables.keys())
def add_table(self, name, data=SimpleFrozenDict()):
"""Add a new table to the lookups. Raises an error if the table exists.
name (unicode): Unique name of table.
data (dict): Optional data to add to the table.
RETURNS (Table): The newly added table.
"""
if name in self.tables:
raise ValueError("Table '{}' already exists".format(name))
raise ValueError(Errors.E158.format(name=name))
table = Table(name=name)
table.update(data)
self._tables[name] = table
return table
def get_table(self, name):
"""Get a table. Raises an error if the table doesn't exist.
name (unicode): Name of the table.
RETURNS (Table): The table.
"""
if name not in self._tables:
raise KeyError("Can't find table '{}'".format(name))
raise KeyError(Errors.E159.format(name=name, tables=self.tables))
return self._tables[name]
def remove_table(self, name):
"""Remove a table. Raises an error if the table doesn't exist.
name (unicode): The name to remove.
RETURNS (Table): The removed table.
"""
if name not in self._tables:
raise KeyError(Errors.E159.format(name=name, tables=self.tables))
return self._tables.pop(name)
def has_table(self, name):
"""Check if the lookups contain a table of a given name.
name (unicode): Name of the table.
RETURNS (bool): Whether a table of that name exists.
"""
return name in self._tables
def to_bytes(self, exclude=tuple(), **kwargs):
raise NotImplementedError
"""Serialize the lookups to a bytestring.
exclude (list): String names of serialization fields to exclude.
RETURNS (bytes): The serialized Lookups.
"""
return srsly.msgpack_dumps(self._tables)
def from_bytes(self, bytes_data, exclude=tuple(), **kwargs):
raise NotImplementedError
"""Load the lookups from a bytestring.
def to_disk(self, path, exclude=tuple(), **kwargs):
raise NotImplementedError
exclude (list): String names of serialization fields to exclude.
RETURNS (Lookups): The loaded Lookups.
"""
self._tables = OrderedDict()
msg = srsly.msgpack_loads(bytes_data)
for key, value in msg.items():
self._tables[key] = Table.from_dict(value)
return self
def from_disk(self, path, exclude=tuple(), **kwargs):
raise NotImplementedError
def to_disk(self, path, **kwargs):
"""Save the lookups to a directory as lookups.bin.
path (unicode / Path): The file path.
"""
if len(self._tables):
path = ensure_path(path)
filepath = path / "lookups.bin"
with filepath.open("wb") as file_:
file_.write(self.to_bytes())
def from_disk(self, path, **kwargs):
"""Load lookups from a directory containing a lookups.bin.
path (unicode / Path): The file path.
RETURNS (Lookups): The loaded lookups.
"""
path = ensure_path(path)
filepath = path / "lookups.bin"
if filepath.exists():
with filepath.open("rb") as file_:
data = file_.read()
return self.from_bytes(data)
return self
class Table(dict):
class Table(OrderedDict):
"""A table in the lookups. Subclass of builtin dict that implements a
slightly more consistent and unified API.
"""
@classmethod
def from_dict(cls, data, name=None):
self = cls(name=name)
self.update(data)
return self
def __init__(self, name=None):
"""Initialize a new table.
name (unicode): Optional table name for reference.
RETURNS (Table): The newly created object.
"""
OrderedDict.__init__(self)
self.name = name
def set(self, key, value):
"""Set new key/value pair. Same as table[key] = value."""
self[key] = value
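A minimal round-trip sketch of the Lookups API defined above (note that the tests later in this diff mark the bytes/disk round-trips as xfail on Python 3.5):
from spacy.lookups import Lookups

lookups = Lookups()
table = lookups.add_table("lemma_lookup", {"going": "go"})
table.set("was", "be")
restored = Lookups().from_bytes(lookups.to_bytes())
assert "lemma_lookup" in restored
assert restored.get_table("lemma_lookup").get("was") == "be"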


@ -67,7 +67,7 @@ class Pipe(object):
"""
self.require_model()
predictions = self.predict([doc])
if isinstance(predictions, tuple) and len(tuple) == 2:
if isinstance(predictions, tuple) and len(predictions) == 2:
scores, tensors = predictions
self.set_annotations([doc], scores, tensor=tensors)
else:
@ -1062,8 +1062,15 @@ cdef class DependencyParser(Parser):
@property
def labels(self):
labels = set()
# Get the labels from the model by looking at the available moves
return tuple(set(move.split("-")[1] for move in self.move_names))
for move in self.move_names:
if "-" in move:
label = move.split("-")[1]
if "||" in label:
label = label.split("||")[1]
labels.add(label)
return tuple(sorted(labels))
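The extra split handles composite move names such as "R-dobj||conj" produced by the pseudo-projective preprocessing; a standalone sketch of the extraction logic (example move names illustrative):
def extract_labels(move_names):
    # "L-advmod" -> "advmod"; "R-dobj||conj" -> "conj"
    labels = set()
    for move in move_names:
        if "-" in move:
            label = move.split("-")[1]
            if "||" in label:
                label = label.split("||")[1]
            labels.add(label)
    return tuple(sorted(labels))

# extract_labels(["L-advmod", "R-dobj||conj", "S"]) -> ("advmod", "conj")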
cdef class EntityRecognizer(Parser):
@ -1098,8 +1105,9 @@ cdef class EntityRecognizer(Parser):
def labels(self):
# Get the labels from the model by looking at the available moves, e.g.
# B-PERSON, I-PERSON, L-PERSON, U-PERSON
return tuple(set(move.split("-")[1] for move in self.move_names
if move[0] in ("B", "I", "L", "U")))
labels = set(move.split("-")[1] for move in self.move_names
if move[0] in ("B", "I", "L", "U"))
return tuple(sorted(labels))
class EntityLinker(Pipe):
@ -1275,7 +1283,7 @@ class EntityLinker(Pipe):
# this will set all prior probabilities to 0 if they should be excluded from the model
prior_probs = xp.asarray([c.prior_prob for c in candidates])
if not self.cfg.get("incl_prior", True):
prior_probs = xp.asarray([[0.0] for c in candidates])
prior_probs = xp.asarray([0.0 for c in candidates])
scores = prior_probs
# add in similarity from the context
@ -1288,6 +1296,8 @@ class EntityLinker(Pipe):
# cosine similarity
sims = xp.dot(entity_encodings, context_enc_t) / (norm_1 * norm_2)
if sims.shape != prior_probs.shape:
raise ValueError(Errors.E161)
scores = prior_probs + sims - (prior_probs*sims)
# TODO: thresholding
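The combination p + s - p*s is the probabilistic union of the two signals: either a strong prior or a strong context similarity alone can raise the score, and the result stays within [0, 1]. For example:
p, s = 0.6, 0.5
score = p + s - p * s  # 0.8: above either signal alone, never above 1.0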
@ -1361,7 +1371,16 @@ class Sentencizer(object):
"""
name = "sentencizer"
default_punct_chars = [".", "!", "?"]
default_punct_chars = ['!', '.', '?', '։', '؟', '۔', '܀', '܁', '܂', '߹',
'', '', '', '', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '', '', '', '', '', '', '', '᱿',
'', '', '', '', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '', '', '', '', '', '', '', '',
'', '', '', '', '', '𐩖', '𐩗', '𑁇', '𑁈', '𑂾', '𑂿', '𑃀',
'𑃁', '𑅁', '𑅂', '𑅃', '𑇅', '𑇆', '𑇍', '𑇞', '𑇟', '𑈸', '𑈹', '𑈻', '𑈼',
'𑊩', '𑑋', '𑑌', '𑗂', '𑗃', '𑗉', '𑗊', '𑗋', '𑗌', '𑗍', '𑗎', '𑗏', '𑗐',
'𑗑', '𑗒', '𑗓', '𑗔', '𑗕', '𑗖', '𑗗', '𑙁', '𑙂', '𑜼', '𑜽', '𑜾', '𑩂',
'𑩃', '𑪛', '𑪜', '𑱁', '𑱂', '𖩮', '𖩯', '𖫵', '𖬷', '𖬸', '𖭄', '𛲟', '𝪈']
def __init__(self, punct_chars=None, **kwargs):
"""Initialize the sentencizer.
@ -1372,7 +1391,10 @@ class Sentencizer(object):
DOCS: https://spacy.io/api/sentencizer#init
"""
self.punct_chars = punct_chars or self.default_punct_chars
if punct_chars:
self.punct_chars = set(punct_chars)
else:
self.punct_chars = set(self.default_punct_chars)
def __call__(self, doc):
"""Apply the sentencizer to a Doc and set Token.is_sent_start.
@ -1404,7 +1426,7 @@ class Sentencizer(object):
DOCS: https://spacy.io/api/sentencizer#to_bytes
"""
return srsly.msgpack_dumps({"punct_chars": self.punct_chars})
return srsly.msgpack_dumps({"punct_chars": list(self.punct_chars)})
def from_bytes(self, bytes_data, **kwargs):
"""Load the sentencizer from a bytestring.
@ -1415,7 +1437,7 @@ class Sentencizer(object):
DOCS: https://spacy.io/api/sentencizer#from_bytes
"""
cfg = srsly.msgpack_loads(bytes_data)
self.punct_chars = cfg.get("punct_chars", self.default_punct_chars)
self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars))
return self
def to_disk(self, path, exclude=tuple(), **kwargs):
@ -1425,7 +1447,7 @@ class Sentencizer(object):
"""
path = util.ensure_path(path)
path = path.with_suffix(".json")
srsly.write_json(path, {"punct_chars": self.punct_chars})
srsly.write_json(path, {"punct_chars": list(self.punct_chars)})
def from_disk(self, path, exclude=tuple(), **kwargs):
@ -1436,7 +1458,7 @@ class Sentencizer(object):
path = util.ensure_path(path)
path = path.with_suffix(".json")
cfg = srsly.read_json(path)
self.punct_chars = cfg.get("punct_chars", self.default_punct_chars)
self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars))
return self
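With punct_chars normalized to a set, custom characters compare order-independently and survive the serialization round-trip; a minimal sketch:
from spacy.pipeline import Sentencizer

sentencizer = Sentencizer(punct_chars=["。", "!", "?"])
assert sentencizer.punct_chars == {"。", "!", "?"}
restored = Sentencizer().from_bytes(sentencizer.to_bytes())
assert restored.punct_chars == sentencizer.punct_chars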


@ -103,6 +103,11 @@ def he_tokenizer():
return get_lang_class("he").Defaults.create_tokenizer()
@pytest.fixture(scope="session")
def hr_tokenizer():
return get_lang_class("hr").Defaults.create_tokenizer()
@pytest.fixture
def hu_tokenizer():
return get_lang_class("hu").Defaults.create_tokenizer()


@ -99,6 +99,53 @@ def test_doc_retokenize_spans_merge_tokens(en_tokenizer):
assert doc[0].ent_type_ == "GPE"
def test_doc_retokenize_spans_merge_tokens_default_attrs(en_tokenizer):
text = "The players start."
heads = [1, 1, 0, -1]
tokens = en_tokenizer(text)
doc = get_doc(
tokens.vocab,
words=[t.text for t in tokens],
tags=["DT", "NN", "VBZ", "."],
pos=["DET", "NOUN", "VERB", "PUNCT"],
heads=heads,
)
assert len(doc) == 4
assert doc[0].text == "The"
assert doc[0].tag_ == "DT"
assert doc[0].pos_ == "DET"
with doc.retokenize() as retokenizer:
retokenizer.merge(doc[0:2])
assert len(doc) == 3
assert doc[0].text == "The players"
assert doc[0].tag_ == "NN"
assert doc[0].pos_ == "NOUN"
assert doc[0].lemma_ == "The players"
doc = get_doc(
tokens.vocab,
words=[t.text for t in tokens],
tags=["DT", "NN", "VBZ", "."],
pos=["DET", "NOUN", "VERB", "PUNCT"],
heads=heads,
)
assert len(doc) == 4
assert doc[0].text == "The"
assert doc[0].tag_ == "DT"
assert doc[0].pos_ == "DET"
with doc.retokenize() as retokenizer:
retokenizer.merge(doc[0:2])
retokenizer.merge(doc[2:4])
assert len(doc) == 2
assert doc[0].text == "The players"
assert doc[0].tag_ == "NN"
assert doc[0].pos_ == "NOUN"
assert doc[0].lemma_ == "The players"
assert doc[1].text == "start ."
assert doc[1].tag_ == "VBZ"
assert doc[1].pos_ == "VERB"
assert doc[1].lemma_ == "start ."
def test_doc_retokenize_spans_merge_heads(en_tokenizer):
text = "I found a pilates class near work."
heads = [1, 0, 2, 1, -3, -1, -1, -6]
@ -182,7 +229,7 @@ def test_doc_retokenize_spans_entity_merge(en_tokenizer):
assert len(doc) == 15
def test_doc_retokenize_spans_entity_merge_iob():
def test_doc_retokenize_spans_entity_merge_iob(en_vocab):
# Test entity IOB stays consistent after merging
words = ["a", "b", "c", "d", "e"]
doc = Doc(Vocab(), words=words)
@ -195,10 +242,23 @@ def test_doc_retokenize_spans_entity_merge_iob():
assert doc[2].ent_iob_ == "I"
assert doc[3].ent_iob_ == "B"
with doc.retokenize() as retokenizer:
retokenizer.merge(doc[0:1])
retokenizer.merge(doc[0:2])
assert len(doc) == len(words) - 1
assert doc[0].ent_iob_ == "B"
assert doc[1].ent_iob_ == "I"
# Test that IOB stays consistent with provided IOB
words = ["a", "b", "c", "d", "e"]
doc = Doc(Vocab(), words=words)
with doc.retokenize() as retokenizer:
attrs = {"ent_type": "ent-abc", "ent_iob": 1}
retokenizer.merge(doc[0:3], attrs=attrs)
retokenizer.merge(doc[3:5], attrs=attrs)
assert doc[0].ent_iob_ == "B"
assert doc[1].ent_iob_ == "I"
# if no parse/heads, the first word in the span is the root and provides
# default values
words = ["a", "b", "c", "d", "e", "f", "g", "h", "i"]
doc = Doc(Vocab(), words=words)
doc.ents = [
@ -215,7 +275,47 @@ def test_doc_retokenize_spans_entity_merge_iob():
retokenizer.merge(doc[7:9])
assert len(doc) == 6
assert doc[3].ent_iob_ == "B"
assert doc[4].ent_iob_ == "I"
assert doc[3].ent_type_ == "ent-de"
assert doc[4].ent_iob_ == "B"
assert doc[4].ent_type_ == "ent-fg"
# if there is a parse, span.root provides default values
words = ["a", "b", "c", "d", "e", "f", "g", "h", "i"]
heads = [0, -1, 1, -3, -4, -5, -1, -7, -8]
ents = [(3, 5, "ent-de"), (5, 7, "ent-fg")]
deps = ["dep"] * len(words)
en_vocab.strings.add("ent-de")
en_vocab.strings.add("ent-fg")
en_vocab.strings.add("dep")
doc = get_doc(en_vocab, words=words, heads=heads, deps=deps, ents=ents)
assert doc[2:4].root == doc[3] # root of 'c d' is d
assert doc[4:6].root == doc[4] # root of 'e f' is e
with doc.retokenize() as retokenizer:
retokenizer.merge(doc[2:4])
retokenizer.merge(doc[4:6])
retokenizer.merge(doc[7:9])
assert len(doc) == 6
assert doc[2].ent_iob_ == "B"
assert doc[2].ent_type_ == "ent-de"
assert doc[3].ent_iob_ == "I"
assert doc[3].ent_type_ == "ent-de"
assert doc[4].ent_iob_ == "B"
assert doc[4].ent_type_ == "ent-fg"
# check that B is preserved if span[start] is B
words = ["a", "b", "c", "d", "e", "f", "g", "h", "i"]
heads = [0, -1, 1, 1, -4, -5, -1, -7, -8]
ents = [(3, 5, "ent-de"), (5, 7, "ent-de")]
deps = ["dep"] * len(words)
doc = get_doc(en_vocab, words=words, heads=heads, deps=deps, ents=ents)
with doc.retokenize() as retokenizer:
retokenizer.merge(doc[3:5])
retokenizer.merge(doc[5:7])
assert len(doc) == 7
assert doc[3].ent_iob_ == "B"
assert doc[3].ent_type_ == "ent-de"
assert doc[4].ent_iob_ == "B"
assert doc[4].ent_type_ == "ent-de"
def test_doc_retokenize_spans_sentence_update_after_merge(en_tokenizer):


@ -173,6 +173,21 @@ def test_span_as_doc(doc):
assert span_doc[0].idx == 0
def test_span_as_doc_user_data(doc):
"""Test that the user_data can be preserved (but not by default). """
my_key = "my_info"
my_value = 342
doc.user_data[my_key] = my_value
span = doc[4:10]
span_doc_with = span.as_doc(copy_user_data=True)
span_doc_without = span.as_doc()
assert doc.user_data.get(my_key, None) is my_value
assert span_doc_with.user_data.get(my_key, None) is my_value
assert span_doc_without.user_data.get(my_key, None) is None
def test_span_string_label_kb_id(doc):
span = Span(doc, 0, 1, label="hello", kb_id="Q342")
assert span.label_ == "hello"


@ -133,3 +133,9 @@ def test_en_tokenizer_splits_em_dash_infix(en_tokenizer):
assert tokens[6].text == "Puddleton"
assert tokens[7].text == "?"
assert tokens[8].text == "\u2014"
@pytest.mark.parametrize("text,length", [("_MATH_", 3), ("_MATH_.", 4)])
def test_final_period(en_tokenizer, text, length):
tokens = en_tokenizer(text)
assert len(tokens) == length


@ -0,0 +1,20 @@
# coding: utf-8
from __future__ import unicode_literals
import pytest
@pytest.mark.parametrize(
"string,lemma",
[
("trčao", "trčati"),
("adekvatnim", "adekvatan"),
("dekontaminacijama", "dekontaminacija"),
("filologovih", "filologov"),
("je", "biti"),
("se", "sebe"),
],
)
def test_hr_lemmatizer_lookup_assigns(hr_tokenizer, string, lemma):
tokens = hr_tokenizer(string)
assert tokens[0].lemma_ == lemma


@ -0,0 +1,20 @@
# coding: utf-8
from __future__ import unicode_literals
import pytest
@pytest.mark.parametrize(
"string,lemma",
[
("најадекватнији", "адекватан"),
("матурирао", "матурирати"),
("планираћемо", "планирати"),
("певају", "певати"),
("нама", "ми"),
("се", "себе"),
],
)
def test_sr_lemmatizer_lookup_assigns(sr_tokenizer, string, lemma):
tokens = sr_tokenizer(string)
assert tokens[0].lemma_ == lemma


@ -68,3 +68,20 @@ def test_add_label_deserializes_correctly():
assert ner1.moves.n_moves == ner2.moves.n_moves
for i in range(ner1.moves.n_moves):
assert ner1.moves.get_class_name(i) == ner2.moves.get_class_name(i)
@pytest.mark.parametrize(
"pipe_cls,n_moves", [(DependencyParser, 5), (EntityRecognizer, 4)]
)
def test_add_label_get_label(pipe_cls, n_moves):
"""Test that added labels are returned correctly. This test was added to
test for a bug in DependencyParser.labels that'd cause it to fail when
splitting the move names.
"""
labels = ["A", "B", "C"]
pipe = pipe_cls(Vocab())
for label in labels:
pipe.add_label(label)
assert len(pipe.move_names) == len(labels) * n_moves
pipe_labels = sorted(list(pipe.labels))
assert pipe_labels == labels


@ -128,3 +128,19 @@ def test_pipe_base_class_add_label(nlp, component):
assert label in pipe.labels
else:
assert pipe.labels == (label,)
def test_pipe_labels(nlp):
input_labels = {
"ner": ["PERSON", "ORG", "GPE"],
"textcat": ["POSITIVE", "NEGATIVE"],
}
for name, labels in input_labels.items():
pipe = nlp.create_pipe(name)
for label in labels:
pipe.add_label(label)
assert len(pipe.labels) == len(labels)
nlp.add_pipe(pipe)
assert len(nlp.pipe_labels) == len(input_labels)
for name, labels in nlp.pipe_labels.items():
assert sorted(input_labels[name]) == sorted(labels)


@ -81,7 +81,7 @@ def test_sentencizer_custom_punct(en_vocab, punct_chars, words, sent_starts, n_s
def test_sentencizer_serialize_bytes(en_vocab):
punct_chars = [".", "~", "+"]
sentencizer = Sentencizer(punct_chars=punct_chars)
assert sentencizer.punct_chars == punct_chars
assert sentencizer.punct_chars == set(punct_chars)
bytes_data = sentencizer.to_bytes()
new_sentencizer = Sentencizer().from_bytes(bytes_data)
assert new_sentencizer.punct_chars == punct_chars
assert new_sentencizer.punct_chars == set(punct_chars)


@ -13,26 +13,25 @@ from spacy.lemmatizer import Lemmatizer
from spacy.symbols import ORTH, LEMMA, POS, VERB, VerbForm_part
@pytest.mark.xfail
def test_issue1061():
'''Test special-case works after tokenizing. Was caching problem.'''
text = 'I like _MATH_ even _MATH_ when _MATH_, except when _MATH_ is _MATH_! but not _MATH_.'
"""Test special-case works after tokenizing. Was caching problem."""
text = "I like _MATH_ even _MATH_ when _MATH_, except when _MATH_ is _MATH_! but not _MATH_."
tokenizer = English.Defaults.create_tokenizer()
doc = tokenizer(text)
assert 'MATH' in [w.text for w in doc]
assert '_MATH_' not in [w.text for w in doc]
assert "MATH" in [w.text for w in doc]
assert "_MATH_" not in [w.text for w in doc]
tokenizer.add_special_case('_MATH_', [{ORTH: '_MATH_'}])
tokenizer.add_special_case("_MATH_", [{ORTH: "_MATH_"}])
doc = tokenizer(text)
assert '_MATH_' in [w.text for w in doc]
assert 'MATH' not in [w.text for w in doc]
assert "_MATH_" in [w.text for w in doc]
assert "MATH" not in [w.text for w in doc]
# For sanity, check it works when pipeline is clean.
tokenizer = English.Defaults.create_tokenizer()
tokenizer.add_special_case('_MATH_', [{ORTH: '_MATH_'}])
tokenizer.add_special_case("_MATH_", [{ORTH: "_MATH_"}])
doc = tokenizer(text)
assert '_MATH_' in [w.text for w in doc]
assert 'MATH' not in [w.text for w in doc]
assert "_MATH_" in [w.text for w in doc]
assert "MATH" not in [w.text for w in doc]
@pytest.mark.xfail(


@ -1,7 +1,6 @@
# coding: utf8
from __future__ import unicode_literals
import pytest
from spacy.matcher import Matcher
from spacy.tokens import Doc


@ -1,7 +1,6 @@
# coding: utf8
from __future__ import unicode_literals
import pytest
from spacy.matcher import Matcher
from spacy.tokens import Doc


@ -1,7 +1,6 @@
# coding: utf8
from __future__ import unicode_literals
import pytest
from spacy.matcher import PhraseMatcher
from spacy.tokens import Doc


@ -1,7 +1,6 @@
# coding: utf8
from __future__ import unicode_literals
import pytest
from spacy.matcher import Matcher
from spacy.tokens import Doc


@ -2,44 +2,37 @@
from __future__ import unicode_literals
from spacy.lang.en import English
import spacy
from spacy.tokenizer import Tokenizer
from spacy import util
from spacy.tests.util import make_tempdir
from ..util import make_tempdir
def test_issue4190():
test_string = "Test c."
# Load default language
nlp_1 = English()
doc_1a = nlp_1(test_string)
result_1a = [token.text for token in doc_1a]
result_1a = [token.text for token in doc_1a] # noqa: F841
# Modify tokenizer
customize_tokenizer(nlp_1)
doc_1b = nlp_1(test_string)
result_1b = [token.text for token in doc_1b]
# Save and Reload
with make_tempdir() as model_dir:
nlp_1.to_disk(model_dir)
nlp_2 = spacy.load(model_dir)
nlp_2 = util.load_model(model_dir)
# This should be the modified tokenizer
doc_2 = nlp_2(test_string)
result_2 = [token.text for token in doc_2]
assert result_1b == result_2
def customize_tokenizer(nlp):
prefix_re = spacy.util.compile_prefix_regex(nlp.Defaults.prefixes)
suffix_re = spacy.util.compile_suffix_regex(nlp.Defaults.suffixes)
infix_re = spacy.util.compile_infix_regex(nlp.Defaults.infixes)
# remove all exceptions where a single letter is followed by a period (e.g. 'h.')
prefix_re = util.compile_prefix_regex(nlp.Defaults.prefixes)
suffix_re = util.compile_suffix_regex(nlp.Defaults.suffixes)
infix_re = util.compile_infix_regex(nlp.Defaults.infixes)
# Remove all exceptions where a single letter is followed by a period (e.g. 'h.')
exceptions = {
k: v
for k, v in dict(nlp.Defaults.tokenizer_exceptions).items()
@ -53,5 +46,4 @@ def customize_tokenizer(nlp):
infix_finditer=infix_re.finditer,
token_match=nlp.tokenizer.token_match,
)
nlp.tokenizer = new_tokenizer


@ -0,0 +1,12 @@
# coding: utf8
from __future__ import unicode_literals
from spacy.lang.el import Greek
def test_issue4272():
"""Test that lookup table can be accessed from Token.lemma if no POS tags
are available."""
nlp = Greek()
doc = nlp("Χθες")
assert doc[0].lemma_


@ -0,0 +1,28 @@
# coding: utf8
from __future__ import unicode_literals
import pytest
from spacy.language import Language
from spacy.pipeline import Pipe
class DummyPipe(Pipe):
def __init__(self):
self.model = "dummy_model"
def predict(self, docs):
return ([1, 2, 3], [4, 5, 6])
def set_annotations(self, docs, scores, tensor=None):
return docs
@pytest.fixture
def nlp():
return Language()
def test_multiple_predictions(nlp):
doc = nlp.make_doc("foo")
dummy_pipe = DummyPipe()
dummy_pipe(doc)


@ -41,8 +41,8 @@ def test_serialize_parser_roundtrip_bytes(en_vocab, Parser):
parser.model, _ = parser.Model(10)
new_parser = Parser(en_vocab)
new_parser.model, _ = new_parser.Model(10)
new_parser = new_parser.from_bytes(parser.to_bytes())
assert new_parser.to_bytes() == parser.to_bytes()
new_parser = new_parser.from_bytes(parser.to_bytes(exclude=["vocab"]))
assert new_parser.to_bytes(exclude=["vocab"]) == parser.to_bytes(exclude=["vocab"])
@pytest.mark.parametrize("Parser", test_parsers)
@ -55,8 +55,8 @@ def test_serialize_parser_roundtrip_disk(en_vocab, Parser):
parser_d = Parser(en_vocab)
parser_d.model, _ = parser_d.Model(0)
parser_d = parser_d.from_disk(file_path)
parser_bytes = parser.to_bytes(exclude=["model"])
parser_d_bytes = parser_d.to_bytes(exclude=["model"])
parser_bytes = parser.to_bytes(exclude=["model", "vocab"])
parser_d_bytes = parser_d.to_bytes(exclude=["model", "vocab"])
assert parser_bytes == parser_d_bytes
@ -64,7 +64,7 @@ def test_to_from_bytes(parser, blank_parser):
assert parser.model is not True
assert blank_parser.model is True
assert blank_parser.moves.n_moves != parser.moves.n_moves
bytes_data = parser.to_bytes()
bytes_data = parser.to_bytes(exclude=["vocab"])
blank_parser.from_bytes(bytes_data)
assert blank_parser.model is not True
assert blank_parser.moves.n_moves == parser.moves.n_moves
@ -97,9 +97,9 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
def test_serialize_tensorizer_roundtrip_bytes(en_vocab):
tensorizer = Tensorizer(en_vocab)
tensorizer.model = tensorizer.Model()
tensorizer_b = tensorizer.to_bytes()
tensorizer_b = tensorizer.to_bytes(exclude=["vocab"])
new_tensorizer = Tensorizer(en_vocab).from_bytes(tensorizer_b)
assert new_tensorizer.to_bytes() == tensorizer_b
assert new_tensorizer.to_bytes(exclude=["vocab"]) == tensorizer_b
def test_serialize_tensorizer_roundtrip_disk(en_vocab):
@ -109,13 +109,15 @@ def test_serialize_tensorizer_roundtrip_disk(en_vocab):
file_path = d / "tensorizer"
tensorizer.to_disk(file_path)
tensorizer_d = Tensorizer(en_vocab).from_disk(file_path)
assert tensorizer.to_bytes() == tensorizer_d.to_bytes()
assert tensorizer.to_bytes(exclude=["vocab"]) == tensorizer_d.to_bytes(
exclude=["vocab"]
)
def test_serialize_textcat_empty(en_vocab):
# See issue #1105
textcat = TextCategorizer(en_vocab, labels=["ENTITY", "ACTION", "MODIFIER"])
textcat.to_bytes()
textcat.to_bytes(exclude=["vocab"])
@pytest.mark.parametrize("Parser", test_parsers)
@ -128,13 +130,17 @@ def test_serialize_pipe_exclude(en_vocab, Parser):
parser = Parser(en_vocab)
parser.model, _ = parser.Model(0)
parser.cfg["foo"] = "bar"
new_parser = get_new_parser().from_bytes(parser.to_bytes())
new_parser = get_new_parser().from_bytes(parser.to_bytes(exclude=["vocab"]))
assert "foo" in new_parser.cfg
new_parser = get_new_parser().from_bytes(parser.to_bytes(), exclude=["cfg"])
new_parser = get_new_parser().from_bytes(
parser.to_bytes(exclude=["vocab"]), exclude=["cfg"]
)
assert "foo" not in new_parser.cfg
new_parser = get_new_parser().from_bytes(parser.to_bytes(exclude=["cfg"]))
new_parser = get_new_parser().from_bytes(
parser.to_bytes(exclude=["cfg"]), exclude=["vocab"]
)
assert "foo" not in new_parser.cfg
with pytest.raises(ValueError):
parser.to_bytes(cfg=False)
parser.to_bytes(cfg=False, exclude=["vocab"])
with pytest.raises(ValueError):
get_new_parser().from_bytes(parser.to_bytes(), cfg=False)
get_new_parser().from_bytes(parser.to_bytes(exclude=["vocab"]), cfg=False)


@ -12,12 +12,14 @@ test_strings = [([], []), (["rats", "are", "cute"], ["i", "like", "rats"])]
test_strings_attrs = [(["rats", "are", "cute"], "Hello")]
@pytest.mark.xfail
@pytest.mark.parametrize("text", ["rat"])
def test_serialize_vocab(en_vocab, text):
text_hash = en_vocab.strings.add(text)
vocab_bytes = en_vocab.to_bytes()
vocab_bytes = en_vocab.to_bytes(exclude=["lookups"])
new_vocab = Vocab().from_bytes(vocab_bytes)
assert new_vocab.strings[text_hash] == text
assert new_vocab.to_bytes(exclude=["lookups"]) == vocab_bytes
@pytest.mark.parametrize("strings1,strings2", test_strings)


@ -3,6 +3,9 @@ from __future__ import unicode_literals
import pytest
from spacy.lookups import Lookups
from spacy.vocab import Vocab
from ..util import make_tempdir
def test_lookups_api():
@ -10,6 +13,7 @@ def test_lookups_api():
data = {"foo": "bar", "hello": "world"}
lookups = Lookups()
lookups.add_table(table_name, data)
assert len(lookups) == 1
assert table_name in lookups
assert lookups.has_table(table_name)
table = lookups.get_table(table_name)
@ -22,5 +26,91 @@ def test_lookups_api():
assert len(table) == 3
with pytest.raises(KeyError):
lookups.get_table("xyz")
# with pytest.raises(ValueError):
# lookups.add_table(table_name)
with pytest.raises(ValueError):
lookups.add_table(table_name)
table = lookups.remove_table(table_name)
assert table.name == table_name
assert len(lookups) == 0
assert table_name not in lookups
with pytest.raises(KeyError):
lookups.get_table(table_name)
# This fails on Python 3.5
@pytest.mark.xfail
def test_lookups_to_from_bytes():
lookups = Lookups()
lookups.add_table("table1", {"foo": "bar", "hello": "world"})
lookups.add_table("table2", {"a": 1, "b": 2, "c": 3})
lookups_bytes = lookups.to_bytes()
new_lookups = Lookups()
new_lookups.from_bytes(lookups_bytes)
assert len(new_lookups) == 2
assert "table1" in new_lookups
assert "table2" in new_lookups
table1 = new_lookups.get_table("table1")
assert len(table1) == 2
assert table1.get("foo") == "bar"
table2 = new_lookups.get_table("table2")
assert len(table2) == 3
assert table2.get("b") == 2
assert new_lookups.to_bytes() == lookups_bytes
# This fails on Python 3.5
@pytest.mark.xfail
def test_lookups_to_from_disk():
lookups = Lookups()
lookups.add_table("table1", {"foo": "bar", "hello": "world"})
lookups.add_table("table2", {"a": 1, "b": 2, "c": 3})
with make_tempdir() as tmpdir:
lookups.to_disk(tmpdir)
new_lookups = Lookups()
new_lookups.from_disk(tmpdir)
assert len(new_lookups) == 2
assert "table1" in new_lookups
assert "table2" in new_lookups
table1 = new_lookups.get_table("table1")
assert len(table1) == 2
assert table1.get("foo") == "bar"
table2 = new_lookups.get_table("table2")
assert len(table2) == 3
assert table2.get("b") == 2
# This fails on Python 3.5
@pytest.mark.xfail
def test_lookups_to_from_bytes_via_vocab():
table_name = "test"
vocab = Vocab()
vocab.lookups.add_table(table_name, {"foo": "bar", "hello": "world"})
assert len(vocab.lookups) == 1
assert table_name in vocab.lookups
vocab_bytes = vocab.to_bytes()
new_vocab = Vocab()
new_vocab.from_bytes(vocab_bytes)
assert len(new_vocab.lookups) == 1
assert table_name in new_vocab.lookups
table = new_vocab.lookups.get_table(table_name)
assert len(table) == 2
assert table.get("hello") == "world"
assert new_vocab.to_bytes() == vocab_bytes
# This fails on Python 3.5
@pytest.mark.xfail
def test_lookups_to_from_disk_via_vocab():
table_name = "test"
vocab = Vocab()
vocab.lookups.add_table(table_name, {"foo": "bar", "hello": "world"})
assert len(vocab.lookups) == 1
assert table_name in vocab.lookups
with make_tempdir() as tmpdir:
vocab.to_disk(tmpdir)
new_vocab = Vocab()
new_vocab.from_disk(tmpdir)
assert len(new_vocab.lookups) == 1
assert table_name in new_vocab.lookups
table = new_vocab.lookups.get_table(table_name)
assert len(table) == 2
assert table.get("hello") == "world"


@ -16,10 +16,10 @@ cdef class Tokenizer:
cdef PreshMap _specials
cpdef readonly Vocab vocab
cdef public object token_match
cdef public object prefix_search
cdef public object suffix_search
cdef public object infix_finditer
cdef object _token_match
cdef object _prefix_search
cdef object _suffix_search
cdef object _infix_finditer
cdef object _rules
cpdef Doc tokens_from_list(self, list strings)


@ -61,6 +61,38 @@ cdef class Tokenizer:
for chunk, substrings in sorted(rules.items()):
self.add_special_case(chunk, substrings)
property token_match:
def __get__(self):
return self._token_match
def __set__(self, token_match):
self._token_match = token_match
self._flush_cache()
property prefix_search:
def __get__(self):
return self._prefix_search
def __set__(self, prefix_search):
self._prefix_search = prefix_search
self._flush_cache()
property suffix_search:
def __get__(self):
return self._suffix_search
def __set__(self, suffix_search):
self._suffix_search = suffix_search
self._flush_cache()
property infix_finditer:
def __get__(self):
return self._infix_finditer
def __set__(self, infix_finditer):
self._infix_finditer = infix_finditer
self._flush_cache()
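Because each setter flushes the cache, rebinding a rule on a live tokenizer now affects strings that were already tokenized; a sketch (suffix regex illustrative):
import re
from spacy.lang.en import English

nlp = English()
nlp("Hello...")  # populates the tokenizer cache
nlp.tokenizer.suffix_search = re.compile(r"\.\.\.$").search  # setter flushes the cache
doc = nlp("Hello...")  # re-tokenized under the new suffix rule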
def __reduce__(self):
args = (self.vocab,
self._rules,
@ -141,9 +173,23 @@ cdef class Tokenizer:
for text in texts:
yield self(text)
def _flush_cache(self):
self._reset_cache([key for key in self._cache if not key in self._specials])
def _reset_cache(self, keys):
for k in keys:
del self._cache[k]
if not k in self._specials:
cached = <_Cached*>self._cache.get(k)
if cached is not NULL:
self.mem.free(cached)
def _reset_specials(self):
for k in self._specials:
cached = <_Cached*>self._specials.get(k)
del self._specials[k]
if cached is not NULL:
self.mem.free(cached)
cdef int _try_cache(self, hash_t key, Doc tokens) except -1:
cached = <_Cached*>self._cache.get(key)
@ -183,6 +229,9 @@ cdef class Tokenizer:
while string and len(string) != last_size:
if self.token_match and self.token_match(string):
break
if self._specials.get(hash_string(string)) != NULL:
has_special[0] = 1
break
last_size = len(string)
pre_len = self.find_prefix(string)
if pre_len != 0:
@ -360,8 +409,15 @@ cdef class Tokenizer:
cached.is_lex = False
cached.data.tokens = self.vocab.make_fused_token(substrings)
key = hash_string(string)
stale_special = <_Cached*>self._specials.get(key)
stale_cached = <_Cached*>self._cache.get(key)
self._flush_cache()
self._specials.set(key, cached)
self._cache.set(key, cached)
if stale_special is not NULL:
self.mem.free(stale_special)
if stale_special != stale_cached and stale_cached is not NULL:
self.mem.free(stale_cached)
self._rules[string] = substrings
def to_disk(self, path, **kwargs):
@ -444,7 +500,10 @@ cdef class Tokenizer:
if data.get("rules"):
# make sure to hard reset the cache to remove data from the default exceptions
self._rules = {}
self._reset_cache([key for key in self._cache])
self._reset_specials()
self._cache = PreshMap()
self._specials = PreshMap()
for string, substrings in data.get("rules", {}).items():
self.add_special_case(string, substrings)


@ -109,13 +109,8 @@ cdef class Retokenizer:
def __exit__(self, *args):
# Do the actual merging here
if len(self.merges) > 1:
_bulk_merge(self.doc, self.merges)
elif len(self.merges) == 1:
(span, attrs) = self.merges[0]
start = span.start
end = span.end
_merge(self.doc, start, end, attrs)
if len(self.merges) >= 1:
_merge(self.doc, self.merges)
# Iterate in order, to keep things simple.
for start_char, orths, heads, attrs in sorted(self.splits):
# Resolve token index
@ -140,95 +135,7 @@ cdef class Retokenizer:
_split(self.doc, token_index, orths, head_indices, attrs)
def _merge(Doc doc, int start, int end, attributes):
"""Retokenize the document, such that the span at
`doc.text[start_idx : end_idx]` is merged into a single token. If
`start_idx` and `end_idx `do not mark start and end token boundaries,
the document remains unchanged.
start_idx (int): Character index of the start of the slice to merge.
end_idx (int): Character index after the end of the slice to merge.
**attributes: Attributes to assign to the merged token. By default,
attributes are inherited from the syntactic root of the span.
RETURNS (Token): The newly merged token, or `None` if the start and end
indices did not fall at token boundaries.
"""
cdef Span span = doc[start:end]
cdef int start_char = span.start_char
cdef int end_char = span.end_char
# Resize the doc.tensor, if it's set. Let the last row for each token stand
# for the merged region. To do this, we create a boolean array indicating
# whether the row is to be deleted, then use numpy.delete
if doc.tensor is not None and doc.tensor.size != 0:
doc.tensor = _resize_tensor(doc.tensor, [(start, end)])
# Get LexemeC for newly merged token
new_orth = ''.join([t.text_with_ws for t in span])
if span[-1].whitespace_:
new_orth = new_orth[:-len(span[-1].whitespace_)]
cdef const LexemeC* lex = doc.vocab.get(doc.mem, new_orth)
# House the new merged token where it starts
cdef TokenC* token = &doc.c[start]
token.spacy = doc.c[end-1].spacy
for attr_name, attr_value in attributes.items():
if attr_name == "_": # Set extension attributes
for ext_attr_key, ext_attr_value in attr_value.items():
doc[start]._.set(ext_attr_key, ext_attr_value)
elif attr_name == TAG:
doc.vocab.morphology.assign_tag(token, attr_value)
else:
# Set attributes on both token and lexeme to take care of token
# attribute vs. lexical attribute without having to enumerate them.
# If an attribute name is not valid, set_struct_attr will ignore it.
Token.set_struct_attr(token, attr_name, attr_value)
Lexeme.set_struct_attr(<LexemeC*>lex, attr_name, attr_value)
# Make sure ent_iob remains consistent
if doc.c[end].ent_iob == 1 and token.ent_iob in (0, 2):
if token.ent_type == doc.c[end].ent_type:
token.ent_iob = 3
else:
# If they're not the same entity type, let them be two entities
doc.c[end].ent_iob = 3
# Begin by setting all the head indices to absolute token positions
# This is easier to work with for now than the offsets
# Before thinking of something simpler, beware the case where a
# dependency bridges over the entity. Here the alignment of the
# tokens changes.
span_root = span.root.i
token.dep = span.root.dep
# We update token.lex after keeping span root and dep, since
# setting token.lex will change span.start and span.end properties
# as it modifies the character offsets in the doc
token.lex = lex
for i in range(doc.length):
doc.c[i].head += i
# Set the head of the merged token, and its dep relation, from the Span
token.head = doc.c[span_root].head
# Adjust deps before shrinking tokens
# Tokens which point into the merged token should now point to it
# Subtract the offset from all tokens which point to >= end
offset = (end - start) - 1
for i in range(doc.length):
head_idx = doc.c[i].head
if start <= head_idx < end:
doc.c[i].head = start
elif head_idx >= end:
doc.c[i].head -= offset
# Now compress the token array
for i in range(end, doc.length):
doc.c[i - offset] = doc.c[i]
for i in range(doc.length - offset, doc.length):
memset(&doc.c[i], 0, sizeof(TokenC))
doc.c[i].lex = &EMPTY_LEXEME
doc.length -= offset
for i in range(doc.length):
# ...And, set heads back to a relative position
doc.c[i].head -= i
# Set the left/right children, left/right edges
set_children_from_heads(doc.c, doc.length)
# Return the merged Python object
return doc[start]
def _bulk_merge(Doc doc, merges):
def _merge(Doc doc, merges):
"""Retokenize the document, such that the spans described in 'merges'
are merged into a single token. This method assumes that the merges
are in the same order in which they appear in the doc, and that merges
@@ -256,6 +163,26 @@ def _bulk_merge(Doc doc, merges):
spans.append(span)
# House the new merged token where it starts
token = &doc.c[start]
# Initially set attributes to attributes of span root
token.tag = doc.c[span.root.i].tag
token.pos = doc.c[span.root.i].pos
token.morph = doc.c[span.root.i].morph
token.ent_iob = doc.c[span.root.i].ent_iob
token.ent_type = doc.c[span.root.i].ent_type
merged_iob = token.ent_iob
# If span root is part of an entity, merged token is B-ENT
if token.ent_iob in (1, 3):
merged_iob = 3
# If start token is I-ENT and previous token is of the same
# type, then I-ENT (could check I-ENT from start to span root)
if doc.c[start].ent_iob == 1 and start > 0 \
and doc.c[start].ent_type == token.ent_type \
and doc.c[start - 1].ent_type == token.ent_type:
merged_iob = 1
token.ent_iob = merged_iob
# Unset attributes that don't match new token
token.lemma = 0
token.norm = 0
tokens[merge_index] = token
# Resize the doc.tensor, if it's set. Let the last row for each token stand
# for the merged region. To do this, we create a boolean array indicating
@@ -351,17 +278,7 @@ def _bulk_merge(Doc doc, merges):
# Set the left/right children, left/right edges
set_children_from_heads(doc.c, doc.length)
# Make sure ent_iob remains consistent
for (span, _) in merges:
if span.end < len(offsets):
# If it's not the last span
token_after_span_position = offsets[span.end]
if doc.c[token_after_span_position].ent_iob == 1\
and doc.c[token_after_span_position - 1].ent_iob in (0, 2):
if doc.c[token_after_span_position - 1].ent_type == doc.c[token_after_span_position].ent_type:
doc.c[token_after_span_position - 1].ent_iob = 3
else:
# If they're not the same entity type, let them be two entities
doc.c[token_after_span_position].ent_iob = 3
make_iob_consistent(doc.c, doc.length)
# Return the merged Python object
return doc[spans[0].start]
@@ -480,3 +397,12 @@ def _validate_extensions(extensions):
raise ValueError(Errors.E118.format(attr=key))
if not is_writable_attr(extension):
raise ValueError(Errors.E119.format(attr=key))
cdef make_iob_consistent(TokenC* tokens, int length):
cdef int i
if tokens[0].ent_iob == 1:
tokens[0].ent_iob = 3
for i in range(1, length):
if tokens[i].ent_iob == 1 and tokens[i - 1].ent_type != tokens[i].ent_type:
tokens[i].ent_iob = 3
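A pure-Python restatement of the rule, using spaCy's internal coding (1 = I, 3 = B): an I token that does not continue an entity of the same type is promoted to B:

```python
def make_iob_consistent(ent_iobs, ent_types):
    # 1 = I(nside), 3 = B(egin) in spaCy's internal coding.
    if ent_iobs and ent_iobs[0] == 1:
        ent_iobs[0] = 3  # a document cannot open mid-entity
    for i in range(1, len(ent_iobs)):
        if ent_iobs[i] == 1 and ent_types[i - 1] != ent_types[i]:
            ent_iobs[i] = 3  # a type switch starts a new entity
    return ent_iobs

assert make_iob_consistent([1, 1, 1], ["GPE", "GPE", "ORG"]) == [3, 1, 3]
```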

View File

@@ -200,13 +200,15 @@ cdef class Span:
return Underscore(Underscore.span_extensions, self,
start=self.start_char, end=self.end_char)
def as_doc(self):
def as_doc(self, bint copy_user_data=False):
"""Create a `Doc` object with a copy of the `Span`'s data.
copy_user_data (bool): Whether or not to copy the original doc's user data.
RETURNS (Doc): The `Doc` copy of the span.
DOCS: https://spacy.io/api/span#as_doc
"""
# TODO: make copy_user_data a keyword-only argument (Python 3 only)
words = [t.text for t in self]
spaces = [bool(t.whitespace_) for t in self]
cdef Doc doc = Doc(self.doc.vocab, words=words, spaces=spaces)
@@ -235,6 +237,8 @@ cdef class Span:
cat_start, cat_end, cat_label = key
if cat_start == self.start_char and cat_end == self.end_char:
doc.cats[cat_label] = value
if copy_user_data:
doc.user_data = self.doc.user_data
return doc
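Note that the user data is copied by reference, not deep-copied. Usage sketch (the key is illustrative):

```python
import spacy

nlp = spacy.blank("en")
doc = nlp("I like New York in Autumn")
doc.user_data["my_key"] = "my_value"  # illustrative entry

span_doc = doc[2:4].as_doc(copy_user_data=True)
assert [t.text for t in span_doc] == ["New", "York"]
assert span_doc.user_data["my_key"] == "my_value"
```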
def _fix_dep_copy(self, attrs, array):

View File

@@ -131,8 +131,7 @@ def load_language_data(path):
path = path.with_suffix(path.suffix + ".gz")
if path.exists():
return srsly.read_gzip_json(path)
# TODO: move to spacy.errors
raise ValueError("Can't find language data file: {}".format(path2str(path)))
raise ValueError(Errors.E160.format(path=path2str(path)))
def get_module_path(module):
@@ -458,6 +457,14 @@ def expand_exc(excs, search, replace):
def get_lemma_tables(lookups):
"""Load lemmatizer data from lookups table. Mostly used via
Language.Defaults.create_lemmatizer, but available as a helper so it can be
reused in language classes that implement custom lemmatizers.
lookups (Lookups): The lookups table.
RETURNS (tuple): A (lemma_rules, lemma_index, lemma_exc, lemma_lookup)
tuple that can be used to initialize a Lemmatizer.
"""
lemma_rules = {}
lemma_index = {}
lemma_exc = {}
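A hedged sketch of how the returned tuple feeds a `Lemmatizer` (assuming the spaCy 2.x `Lookups` and `Lemmatizer` signatures; `lemma_lookup` is one of the table names this helper reads, and the data is a toy entry):

```python
from spacy.lemmatizer import Lemmatizer
from spacy.lookups import Lookups
from spacy.util import get_lemma_tables

lookups = Lookups()
lookups.add_table("lemma_lookup", {"going": "go"})

# Tuple order follows the docstring above.
rules, index, exc, lookup = get_lemma_tables(lookups)
lemmatizer = Lemmatizer(index=index, exceptions=exc, rules=rules, lookup=lookup)
assert lemmatizer.lookup("going") == "go"
```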

View File

@@ -43,6 +43,7 @@ cdef class Vocab:
lemmatizer (object): A lemmatizer. Defaults to `None`.
strings (StringStore): StringStore that maps strings to integers, and
vice versa.
lookups (Lookups): Container for large lookup tables and dictionaries.
RETURNS (Vocab): The newly constructed object.
"""
lex_attr_getters = lex_attr_getters if lex_attr_getters is not None else {}
@@ -433,6 +434,8 @@ cdef class Vocab:
file_.write(self.lexemes_to_bytes())
if "vectors" not in "exclude" and self.vectors is not None:
self.vectors.to_disk(path)
if "lookups" not in "exclude" and self.lookups is not None:
self.lookups.to_disk(path)
def from_disk(self, path, exclude=tuple(), **kwargs):
"""Loads state from a directory. Modifies the object in place and
@@ -457,6 +460,8 @@ cdef class Vocab:
self.vectors.from_disk(path, exclude=["strings"])
if self.vectors.name is not None:
link_vectors_to_models(self)
if "lookups" not in exclude:
self.lookups.from_disk(path)
return self
def to_bytes(self, exclude=tuple(), **kwargs):
@@ -476,7 +481,8 @@ cdef class Vocab:
getters = OrderedDict((
("strings", lambda: self.strings.to_bytes()),
("lexemes", lambda: self.lexemes_to_bytes()),
("vectors", deserialize_vectors)
("vectors", deserialize_vectors),
("lookups", lambda: self.lookups.to_bytes())
))
exclude = util.get_serialization_exclude(getters, exclude, kwargs)
return util.to_bytes(getters, exclude)
@@ -499,7 +505,8 @@ cdef class Vocab:
setters = OrderedDict((
("strings", lambda b: self.strings.from_bytes(b)),
("lexemes", lambda b: self.lexemes_from_bytes(b)),
("vectors", lambda b: serialize_vectors(b))
("vectors", lambda b: serialize_vectors(b)),
("lookups", lambda b: self.lookups.from_bytes(b))
))
exclude = util.get_serialization_exclude(setters, exclude, kwargs)
util.from_bytes(bytes_data, setters, exclude)
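With these entries the lookup tables now round-trip with the rest of the vocab, and the standard `exclude` mechanism can opt out. A sketch:

```python
import spacy
from spacy.vocab import Vocab

nlp = spacy.blank("en")
vocab_bytes = nlp.vocab.to_bytes()        # now carries a "lookups" entry
vocab2 = Vocab().from_bytes(vocab_bytes)

# Skip the new table explicitly if it isn't needed:
slim_bytes = nlp.vocab.to_bytes(exclude=["lookups"])
```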

View File

@@ -292,9 +292,10 @@ Create a new `Doc` object corresponding to the `Span`, with a copy of the data.
> assert doc2.text == u"New York"
> ```
| Name | Type | Description |
| ----------- | ----- | --------------------------------------- |
| **RETURNS** | `Doc` | A `Doc` object of the `Span`'s content. |
| Name | Type | Description |
| ----------------- | ----- | ---------------------------------------------------- |
| `copy_user_data` | bool | Whether or not to copy the original doc's user data. |
| **RETURNS** | `Doc` | A `Doc` object of the `Span`'s content. |
## Span.root {#root tag="property" model="parser"}

View File

@@ -107,7 +107,7 @@ process.
<Infobox>
**Usage:** [Models directory](/models) [Benchmarks](#benchmarks)
**Usage:** [Models directory](/models)
</Infobox>

View File

@@ -10,10 +10,7 @@
"modelsRepo": "explosion/spacy-models",
"social": {
"twitter": "spacy_io",
"github": "explosion",
"reddit": "spacynlp",
"codepen": "explosion",
"gitter": "explosion/spaCy"
"github": "explosion"
},
"theme": "#09a3d5",
"analytics": "UA-58931649-1",
@@ -69,6 +66,7 @@
"items": [
{ "text": "Twitter", "url": "https://twitter.com/spacy_io" },
{ "text": "GitHub", "url": "https://github.com/explosion/spaCy" },
{ "text": "YouTube", "url": "https://youtube.com/c/ExplosionAI" },
{ "text": "Blog", "url": "https://explosion.ai/blog" }
]
}