mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 00:46:28 +03:00
Merge branch 'develop' into feature/scorer-adjustments
This commit is contained in:
commit
9d79916792
|
@ -16,7 +16,7 @@ from bin.ud import conll17_ud_eval
|
|||
from spacy.tokens import Token, Doc
|
||||
from spacy.gold import Example
|
||||
from spacy.util import compounding, minibatch, minibatch_by_words
|
||||
from spacy.syntax.nonproj import projectivize
|
||||
from spacy.pipeline._parser_internals.nonproj import projectivize
|
||||
from spacy.matcher import Matcher
|
||||
from spacy import displacy
|
||||
from collections import defaultdict
|
||||
|
|
|
@ -20,20 +20,20 @@ seed = 0
|
|||
accumulate_gradient = 1
|
||||
use_pytorch_for_gpu_memory = false
|
||||
# Control how scores are printed and checkpoints are evaluated.
|
||||
scores = ["speed", "tags_acc", "uas", "las", "ents_f"]
|
||||
score_weights = {"las": 0.4, "ents_f": 0.4, "tags_acc": 0.2}
|
||||
# These settings are invalid for the transformer models.
|
||||
eval_batch_size = 128
|
||||
score_weights = {"dep_las": 0.4, "ents_f": 0.4, "tag_acc": 0.2}
|
||||
init_tok2vec = null
|
||||
discard_oversize = false
|
||||
omit_extra_lookups = false
|
||||
batch_by = "words"
|
||||
use_gpu = -1
|
||||
raw_text = null
|
||||
tag_map = null
|
||||
vectors = null
|
||||
base_model = null
|
||||
morph_rules = null
|
||||
|
||||
[training.batch_size]
|
||||
@schedules = "compounding.v1"
|
||||
start = 1000
|
||||
start = 100
|
||||
stop = 1000
|
||||
compound = 1.001
|
||||
|
||||
|
@ -46,74 +46,79 @@ L2 = 0.01
|
|||
grad_clip = 1.0
|
||||
use_averages = false
|
||||
eps = 1e-8
|
||||
#learn_rate = 0.001
|
||||
|
||||
[training.optimizer.learn_rate]
|
||||
@schedules = "warmup_linear.v1"
|
||||
warmup_steps = 250
|
||||
total_steps = 20000
|
||||
initial_rate = 0.001
|
||||
learn_rate = 0.001
|
||||
|
||||
[nlp]
|
||||
lang = "en"
|
||||
base_model = null
|
||||
vectors = null
|
||||
load_vocab_data = false
|
||||
pipeline = ["tok2vec", "ner", "tagger", "parser"]
|
||||
|
||||
[nlp.pipeline]
|
||||
[nlp.tokenizer]
|
||||
@tokenizers = "spacy.Tokenizer.v1"
|
||||
|
||||
[nlp.pipeline.tok2vec]
|
||||
[nlp.lemmatizer]
|
||||
@lemmatizers = "spacy.Lemmatizer.v1"
|
||||
|
||||
[components]
|
||||
|
||||
[components.tok2vec]
|
||||
factory = "tok2vec"
|
||||
|
||||
|
||||
[nlp.pipeline.ner]
|
||||
[components.ner]
|
||||
factory = "ner"
|
||||
learn_tokens = false
|
||||
min_action_freq = 1
|
||||
|
||||
[nlp.pipeline.tagger]
|
||||
[components.tagger]
|
||||
factory = "tagger"
|
||||
|
||||
[nlp.pipeline.parser]
|
||||
[components.parser]
|
||||
factory = "parser"
|
||||
learn_tokens = false
|
||||
min_action_freq = 30
|
||||
|
||||
[nlp.pipeline.tagger.model]
|
||||
[components.tagger.model]
|
||||
@architectures = "spacy.Tagger.v1"
|
||||
|
||||
[nlp.pipeline.tagger.model.tok2vec]
|
||||
@architectures = "spacy.Tok2VecTensors.v1"
|
||||
width = ${nlp.pipeline.tok2vec.model:width}
|
||||
[components.tagger.model.tok2vec]
|
||||
@architectures = "spacy.Tok2VecListener.v1"
|
||||
width = ${components.tok2vec.model.encode:width}
|
||||
|
||||
[nlp.pipeline.parser.model]
|
||||
[components.parser.model]
|
||||
@architectures = "spacy.TransitionBasedParser.v1"
|
||||
nr_feature_tokens = 8
|
||||
hidden_width = 128
|
||||
maxout_pieces = 2
|
||||
use_upper = true
|
||||
|
||||
[nlp.pipeline.parser.model.tok2vec]
|
||||
@architectures = "spacy.Tok2VecTensors.v1"
|
||||
width = ${nlp.pipeline.tok2vec.model:width}
|
||||
[components.parser.model.tok2vec]
|
||||
@architectures = "spacy.Tok2VecListener.v1"
|
||||
width = ${components.tok2vec.model.encode:width}
|
||||
|
||||
[nlp.pipeline.ner.model]
|
||||
[components.ner.model]
|
||||
@architectures = "spacy.TransitionBasedParser.v1"
|
||||
nr_feature_tokens = 3
|
||||
hidden_width = 128
|
||||
maxout_pieces = 2
|
||||
use_upper = true
|
||||
|
||||
[nlp.pipeline.ner.model.tok2vec]
|
||||
@architectures = "spacy.Tok2VecTensors.v1"
|
||||
width = ${nlp.pipeline.tok2vec.model:width}
|
||||
[components.ner.model.tok2vec]
|
||||
@architectures = "spacy.Tok2VecListener.v1"
|
||||
width = ${components.tok2vec.model.encode:width}
|
||||
|
||||
[nlp.pipeline.tok2vec.model]
|
||||
@architectures = "spacy.HashEmbedCNN.v1"
|
||||
pretrained_vectors = ${nlp:vectors}
|
||||
width = 128
|
||||
[components.tok2vec.model]
|
||||
@architectures = "spacy.Tok2Vec.v1"
|
||||
|
||||
[components.tok2vec.model.embed]
|
||||
@architectures = "spacy.MultiHashEmbed.v1"
|
||||
width = ${components.tok2vec.model.encode:width}
|
||||
rows = 2000
|
||||
also_embed_subwords = true
|
||||
also_use_static_vectors = false
|
||||
|
||||
[components.tok2vec.model.encode]
|
||||
@architectures = "spacy.MaxoutWindowEncoder.v1"
|
||||
width = 96
|
||||
depth = 4
|
||||
window_size = 1
|
||||
embed_size = 7000
|
||||
maxout_pieces = 3
|
||||
subword_features = true
|
||||
dropout = ${training:dropout}
|
||||
|
|
|
@ -9,11 +9,11 @@ max_epochs = 100
|
|||
orth_variant_level = 0.0
|
||||
gold_preproc = true
|
||||
max_length = 0
|
||||
scores = ["tag_acc", "dep_uas", "dep_las"]
|
||||
scores = ["tag_acc", "dep_uas", "dep_las", "speed"]
|
||||
score_weights = {"dep_las": 0.8, "tag_acc": 0.2}
|
||||
limit = 0
|
||||
seed = 0
|
||||
accumulate_gradient = 2
|
||||
accumulate_gradient = 1
|
||||
discard_oversize = false
|
||||
raw_text = null
|
||||
tag_map = null
|
||||
|
@ -22,7 +22,7 @@ base_model = null
|
|||
|
||||
eval_batch_size = 128
|
||||
use_pytorch_for_gpu_memory = false
|
||||
batch_by = "padded"
|
||||
batch_by = "words"
|
||||
|
||||
[training.batch_size]
|
||||
@schedules = "compounding.v1"
|
||||
|
@ -64,8 +64,8 @@ min_action_freq = 1
|
|||
@architectures = "spacy.Tagger.v1"
|
||||
|
||||
[components.tagger.model.tok2vec]
|
||||
@architectures = "spacy.Tok2VecTensors.v1"
|
||||
width = ${components.tok2vec.model:width}
|
||||
@architectures = "spacy.Tok2VecListener.v1"
|
||||
width = ${components.tok2vec.model.encode:width}
|
||||
|
||||
[components.parser.model]
|
||||
@architectures = "spacy.TransitionBasedParser.v1"
|
||||
|
@ -74,16 +74,22 @@ hidden_width = 64
|
|||
maxout_pieces = 3
|
||||
|
||||
[components.parser.model.tok2vec]
|
||||
@architectures = "spacy.Tok2VecTensors.v1"
|
||||
width = ${components.tok2vec.model:width}
|
||||
@architectures = "spacy.Tok2VecListener.v1"
|
||||
width = ${components.tok2vec.model.encode:width}
|
||||
|
||||
[components.tok2vec.model]
|
||||
@architectures = "spacy.HashEmbedCNN.v1"
|
||||
pretrained_vectors = ${training:vectors}
|
||||
@architectures = "spacy.Tok2Vec.v1"
|
||||
|
||||
[components.tok2vec.model.embed]
|
||||
@architectures = "spacy.MultiHashEmbed.v1"
|
||||
width = ${components.tok2vec.model.encode:width}
|
||||
rows = 2000
|
||||
also_embed_subwords = true
|
||||
also_use_static_vectors = false
|
||||
|
||||
[components.tok2vec.model.encode]
|
||||
@architectures = "spacy.MaxoutWindowEncoder.v1"
|
||||
width = 96
|
||||
depth = 4
|
||||
window_size = 1
|
||||
embed_size = 2000
|
||||
maxout_pieces = 3
|
||||
subword_features = true
|
||||
dropout = null
|
||||
|
|
|
@ -13,7 +13,7 @@ import spacy
|
|||
import spacy.util
|
||||
from spacy.tokens import Token, Doc
|
||||
from spacy.gold import Example
|
||||
from spacy.syntax.nonproj import projectivize
|
||||
from spacy.pipeline._parser_internals.nonproj import projectivize
|
||||
from collections import defaultdict
|
||||
from spacy.matcher import Matcher
|
||||
|
||||
|
|
16
setup.py
16
setup.py
|
@ -31,6 +31,7 @@ MOD_NAMES = [
|
|||
"spacy.vocab",
|
||||
"spacy.attrs",
|
||||
"spacy.kb",
|
||||
"spacy.ml.parser_model",
|
||||
"spacy.morphology",
|
||||
"spacy.pipeline.dep_parser",
|
||||
"spacy.pipeline.morphologizer",
|
||||
|
@ -40,14 +41,14 @@ MOD_NAMES = [
|
|||
"spacy.pipeline.sentencizer",
|
||||
"spacy.pipeline.senter",
|
||||
"spacy.pipeline.tagger",
|
||||
"spacy.syntax.stateclass",
|
||||
"spacy.syntax._state",
|
||||
"spacy.pipeline.transition_parser",
|
||||
"spacy.pipeline._parser_internals.arc_eager",
|
||||
"spacy.pipeline._parser_internals.ner",
|
||||
"spacy.pipeline._parser_internals.nonproj",
|
||||
"spacy.pipeline._parser_internals._state",
|
||||
"spacy.pipeline._parser_internals.stateclass",
|
||||
"spacy.pipeline._parser_internals.transition_system",
|
||||
"spacy.tokenizer",
|
||||
"spacy.syntax.nn_parser",
|
||||
"spacy.syntax._parser_model",
|
||||
"spacy.syntax.nonproj",
|
||||
"spacy.syntax.transition_system",
|
||||
"spacy.syntax.arc_eager",
|
||||
"spacy.gold.gold_io",
|
||||
"spacy.tokens.doc",
|
||||
"spacy.tokens.span",
|
||||
|
@ -57,7 +58,6 @@ MOD_NAMES = [
|
|||
"spacy.matcher.matcher",
|
||||
"spacy.matcher.phrasematcher",
|
||||
"spacy.matcher.dependencymatcher",
|
||||
"spacy.syntax.ner",
|
||||
"spacy.symbols",
|
||||
"spacy.vectors",
|
||||
]
|
||||
|
|
|
@ -10,7 +10,7 @@ from thinc.api import Config
|
|||
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
|
||||
from ._util import import_code, debug_cli
|
||||
from ..gold import Corpus, Example
|
||||
from ..syntax import nonproj
|
||||
from ..pipeline._parser_internals import nonproj
|
||||
from ..language import Language
|
||||
from .. import util
|
||||
|
||||
|
|
|
@ -67,10 +67,7 @@ def evaluate(
|
|||
corpus = Corpus(data_path, data_path)
|
||||
nlp = util.load_model(model)
|
||||
dev_dataset = list(corpus.dev_dataset(nlp, gold_preproc=gold_preproc))
|
||||
begin = timer()
|
||||
scores = nlp.evaluate(dev_dataset, verbose=False)
|
||||
end = timer()
|
||||
nwords = sum(len(ex.predicted) for ex in dev_dataset)
|
||||
metrics = {
|
||||
"TOK": "token_acc",
|
||||
"TAG": "tag_acc",
|
||||
|
@ -82,17 +79,21 @@ def evaluate(
|
|||
"NER P": "ents_p",
|
||||
"NER R": "ents_r",
|
||||
"NER F": "ents_f",
|
||||
"Textcat": "cats_score",
|
||||
"Sent P": "sents_p",
|
||||
"Sent R": "sents_r",
|
||||
"Sent F": "sents_f",
|
||||
"TEXTCAT": "cats_score",
|
||||
"SENT P": "sents_p",
|
||||
"SENT R": "sents_r",
|
||||
"SENT F": "sents_f",
|
||||
"SPEED": "speed",
|
||||
}
|
||||
results = {}
|
||||
for metric, key in metrics.items():
|
||||
if key in scores:
|
||||
if key == "cats_score":
|
||||
metric = metric + " (" + scores.get("cats_score_desc", "unk") + ")"
|
||||
results[metric] = f"{scores[key]*100:.2f}"
|
||||
if key == "speed":
|
||||
results[metric] = f"{scores[key]:.0f}"
|
||||
else:
|
||||
results[metric] = f"{scores[key]*100:.2f}"
|
||||
data = {re.sub(r"[\s/]", "_", k.lower()): v for k, v in results.items()}
|
||||
|
||||
msg.table(results, title="Results")
|
||||
|
|
|
@ -11,7 +11,6 @@ from ...util import ensure_path, working_dir
|
|||
from .._util import project_cli, Arg, PROJECT_FILE, load_project_config, get_checksum
|
||||
|
||||
|
||||
|
||||
# TODO: find a solution for caches
|
||||
# CACHES = [
|
||||
# Path.home() / ".torch",
|
||||
|
|
|
@ -1,5 +1,4 @@
|
|||
from typing import Optional, Dict, Any, Tuple, Union, Callable, List
|
||||
from timeit import default_timer as timer
|
||||
import srsly
|
||||
import tqdm
|
||||
from pathlib import Path
|
||||
|
@ -81,16 +80,20 @@ def train(
|
|||
msg.info("Using CPU")
|
||||
msg.info(f"Loading config and nlp from: {config_path}")
|
||||
config = Config().from_disk(config_path)
|
||||
if config.get("training", {}).get("seed") is not None:
|
||||
fix_random_seed(config["training"]["seed"])
|
||||
with show_validation_error():
|
||||
nlp, config = util.load_model_from_config(config, overrides=config_overrides)
|
||||
if config["training"]["base_model"]:
|
||||
base_nlp = util.load_model(config["training"]["base_model"])
|
||||
# TODO: do something to check base_nlp against regular nlp described in config?
|
||||
nlp = base_nlp
|
||||
# If everything matches it will look something like:
|
||||
# base_nlp = util.load_model(config["training"]["base_model"])
|
||||
# nlp = base_nlp
|
||||
raise NotImplementedError("base_model not supported yet.")
|
||||
if config["training"]["vectors"] is not None:
|
||||
util.load_vectors_into_model(nlp, config["training"]["vectors"])
|
||||
verify_config(nlp)
|
||||
raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
|
||||
if config["training"]["seed"] is not None:
|
||||
fix_random_seed(config["training"]["seed"])
|
||||
if config["training"]["use_pytorch_for_gpu_memory"]:
|
||||
# It feels kind of weird to not have a default for this.
|
||||
use_pytorch_for_gpu_memory()
|
||||
|
@ -243,19 +246,16 @@ def create_evaluation_callback(
|
|||
) -> Callable[[], Tuple[float, Dict[str, float]]]:
|
||||
def evaluate() -> Tuple[float, Dict[str, float]]:
|
||||
dev_examples = corpus.dev_dataset(
|
||||
nlp, gold_preproc=cfg["gold_preproc"], ignore_misaligned=True
|
||||
nlp, gold_preproc=cfg["gold_preproc"]
|
||||
)
|
||||
dev_examples = list(dev_examples)
|
||||
n_words = sum(len(ex.predicted) for ex in dev_examples)
|
||||
batch_size = cfg["eval_batch_size"]
|
||||
start_time = timer()
|
||||
if optimizer.averages:
|
||||
with nlp.use_params(optimizer.averages):
|
||||
scores = nlp.evaluate(dev_examples, batch_size=batch_size)
|
||||
else:
|
||||
scores = nlp.evaluate(dev_examples, batch_size=batch_size)
|
||||
end_time = timer()
|
||||
wps = n_words / (end_time - start_time)
|
||||
# Calculate a weighted sum based on score_weights for the main score
|
||||
weights = cfg["score_weights"]
|
||||
try:
|
||||
|
@ -264,7 +264,6 @@ def create_evaluation_callback(
|
|||
keys = list(scores.keys())
|
||||
err = Errors.E983.format(dict="score_weights", key=str(e), keys=keys)
|
||||
raise KeyError(err)
|
||||
scores["speed"] = wps
|
||||
return weighted_score, scores
|
||||
|
||||
return evaluate
|
||||
|
@ -446,7 +445,7 @@ def update_meta(
|
|||
training: Union[Dict[str, Any], Config], nlp: Language, info: Dict[str, Any]
|
||||
) -> None:
|
||||
nlp.meta["performance"] = {}
|
||||
for metric in training["scores_weights"]:
|
||||
for metric in training["score_weights"]:
|
||||
nlp.meta["performance"][metric] = info["other_scores"][metric]
|
||||
for pipe_name in nlp.pipe_names:
|
||||
nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name]
|
||||
|
|
|
@ -432,12 +432,12 @@ class Errors:
|
|||
"Current DocBin: {current}\nOther DocBin: {other}")
|
||||
E169 = ("Can't find module: {module}")
|
||||
E170 = ("Cannot apply transition {name}: invalid for the current state.")
|
||||
E171 = ("Matcher.add received invalid on_match callback argument: expected "
|
||||
E171 = ("Matcher.add received invalid 'on_match' callback argument: expected "
|
||||
"callable or None, but got: {arg_type}")
|
||||
E175 = ("Can't remove rule for unknown match pattern ID: {key}")
|
||||
E176 = ("Alias '{alias}' is not defined in the Knowledge Base.")
|
||||
E177 = ("Ill-formed IOB input detected: {tag}")
|
||||
E178 = ("Invalid pattern. Expected list of dicts but got: {pat}. Maybe you "
|
||||
E178 = ("Each pattern should be a list of dicts, but got: {pat}. Maybe you "
|
||||
"accidentally passed a single pattern to Matcher.add instead of a "
|
||||
"list of patterns? If you only want to add one pattern, make sure "
|
||||
"to wrap it in a list. For example: matcher.add('{key}', [pattern])")
|
||||
|
@ -483,6 +483,10 @@ class Errors:
|
|||
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
|
||||
|
||||
# TODO: fix numbering after merging develop into master
|
||||
E947 = ("Matcher.add received invalid 'greedy' argument: expected "
|
||||
"a string value from {expected} but got: '{arg}'")
|
||||
E948 = ("Matcher.add received invalid 'patterns' argument: expected "
|
||||
"a List, but got: {arg_type}")
|
||||
E952 = ("The section '{name}' is not a valid section in the provided config.")
|
||||
E953 = ("Mismatched IDs received by the Tok2Vec listener: {id1} vs. {id2}")
|
||||
E954 = ("The Tok2Vec listener did not receive a valid input.")
|
||||
|
|
|
@ -1,7 +1,15 @@
|
|||
from typing import Union, List, Iterable, Iterator, TYPE_CHECKING
|
||||
from pathlib import Path
|
||||
import random
|
||||
|
||||
from .. import util
|
||||
from .example import Example
|
||||
from ..tokens import DocBin, Doc
|
||||
from ..vocab import Vocab
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# This lets us add type hints for mypy etc. without causing circular imports
|
||||
from ..language import Language # noqa: F401
|
||||
|
||||
|
||||
class Corpus:
|
||||
|
@ -11,20 +19,23 @@ class Corpus:
|
|||
DOCS: https://spacy.io/api/corpus
|
||||
"""
|
||||
|
||||
def __init__(self, train_loc, dev_loc, limit=0):
|
||||
def __init__(
|
||||
self, train_loc: Union[str, Path], dev_loc: Union[str, Path], limit: int = 0
|
||||
) -> None:
|
||||
"""Create a Corpus.
|
||||
|
||||
train (str / Path): File or directory of training data.
|
||||
dev (str / Path): File or directory of development data.
|
||||
limit (int): Max. number of examples returned
|
||||
RETURNS (Corpus): The newly created object.
|
||||
limit (int): Max. number of examples returned.
|
||||
|
||||
DOCS: https://spacy.io/api/corpus#init
|
||||
"""
|
||||
self.train_loc = train_loc
|
||||
self.dev_loc = dev_loc
|
||||
self.limit = limit
|
||||
|
||||
@staticmethod
|
||||
def walk_corpus(path):
|
||||
def walk_corpus(path: Union[str, Path]) -> List[Path]:
|
||||
path = util.ensure_path(path)
|
||||
if not path.is_dir():
|
||||
return [path]
|
||||
|
@ -43,7 +54,9 @@ class Corpus:
|
|||
locs.append(path)
|
||||
return locs
|
||||
|
||||
def _make_example(self, nlp, reference, gold_preproc):
|
||||
def _make_example(
|
||||
self, nlp: "Language", reference: Doc, gold_preproc: bool
|
||||
) -> Example:
|
||||
if gold_preproc or reference.has_unknown_spaces:
|
||||
return Example(
|
||||
Doc(
|
||||
|
@ -56,7 +69,9 @@ class Corpus:
|
|||
else:
|
||||
return Example(nlp.make_doc(reference.text), reference)
|
||||
|
||||
def make_examples(self, nlp, reference_docs, max_length=0):
|
||||
def make_examples(
|
||||
self, nlp: "Language", reference_docs: Iterable[Doc], max_length: int = 0
|
||||
) -> Iterator[Example]:
|
||||
for reference in reference_docs:
|
||||
if len(reference) == 0:
|
||||
continue
|
||||
|
@ -69,7 +84,9 @@ class Corpus:
|
|||
elif max_length == 0 or len(ref_sent) < max_length:
|
||||
yield self._make_example(nlp, ref_sent.as_doc(), False)
|
||||
|
||||
def make_examples_gold_preproc(self, nlp, reference_docs):
|
||||
def make_examples_gold_preproc(
|
||||
self, nlp: "Language", reference_docs: Iterable[Doc]
|
||||
) -> Iterator[Example]:
|
||||
for reference in reference_docs:
|
||||
if reference.is_sentenced:
|
||||
ref_sents = [sent.as_doc() for sent in reference.sents]
|
||||
|
@ -80,7 +97,9 @@ class Corpus:
|
|||
if len(eg.x):
|
||||
yield eg
|
||||
|
||||
def read_docbin(self, vocab, locs):
|
||||
def read_docbin(
|
||||
self, vocab: Vocab, locs: Iterable[Union[str, Path]]
|
||||
) -> Iterator[Doc]:
|
||||
""" Yield training examples as example dicts """
|
||||
i = 0
|
||||
for loc in locs:
|
||||
|
@ -96,8 +115,14 @@ class Corpus:
|
|||
if self.limit >= 1 and i >= self.limit:
|
||||
break
|
||||
|
||||
def count_train(self, nlp):
|
||||
"""Returns count of words in train examples"""
|
||||
def count_train(self, nlp: "Language") -> int:
|
||||
"""Returns count of words in train examples.
|
||||
|
||||
nlp (Language): The current nlp. object.
|
||||
RETURNS (int): The word count.
|
||||
|
||||
DOCS: https://spacy.io/api/corpus#count_train
|
||||
"""
|
||||
n = 0
|
||||
i = 0
|
||||
for example in self.train_dataset(nlp):
|
||||
|
@ -108,8 +133,25 @@ class Corpus:
|
|||
return n
|
||||
|
||||
def train_dataset(
|
||||
self, nlp, *, shuffle=True, gold_preproc=False, max_length=0, **kwargs
|
||||
):
|
||||
self,
|
||||
nlp: "Language",
|
||||
*,
|
||||
shuffle: bool = True,
|
||||
gold_preproc: bool = False,
|
||||
max_length: int = 0
|
||||
) -> Iterator[Example]:
|
||||
"""Yield examples from the training data.
|
||||
|
||||
nlp (Language): The current nlp object.
|
||||
shuffle (bool): Whether to shuffle the examples.
|
||||
gold_preproc (bool): Whether to train on gold-standard sentences and tokens.
|
||||
max_length (int): Maximum document length. Longer documents will be
|
||||
split into sentences, if sentence boundaries are available. 0 for
|
||||
no limit.
|
||||
YIELDS (Example): The examples.
|
||||
|
||||
DOCS: https://spacy.io/api/corpus#train_dataset
|
||||
"""
|
||||
ref_docs = self.read_docbin(nlp.vocab, self.walk_corpus(self.train_loc))
|
||||
if gold_preproc:
|
||||
examples = self.make_examples_gold_preproc(nlp, ref_docs)
|
||||
|
@ -120,7 +162,17 @@ class Corpus:
|
|||
random.shuffle(examples)
|
||||
yield from examples
|
||||
|
||||
def dev_dataset(self, nlp, *, gold_preproc=False, **kwargs):
|
||||
def dev_dataset(
|
||||
self, nlp: "Language", *, gold_preproc: bool = False
|
||||
) -> Iterator[Example]:
|
||||
"""Yield examples from the development data.
|
||||
|
||||
nlp (Language): The current nlp object.
|
||||
gold_preproc (bool): Whether to train on gold-standard sentences and tokens.
|
||||
YIELDS (Example): The examples.
|
||||
|
||||
DOCS: https://spacy.io/api/corpus#dev_dataset
|
||||
"""
|
||||
ref_docs = self.read_docbin(nlp.vocab, self.walk_corpus(self.dev_loc))
|
||||
if gold_preproc:
|
||||
examples = self.make_examples_gold_preproc(nlp, ref_docs)
|
||||
|
|
|
@ -10,7 +10,7 @@ from .align import Alignment
|
|||
from .iob_utils import biluo_to_iob, biluo_tags_from_offsets, biluo_tags_from_doc
|
||||
from .iob_utils import spans_from_biluo_tags
|
||||
from ..errors import Errors, Warnings
|
||||
from ..syntax import nonproj
|
||||
from ..pipeline._parser_internals import nonproj
|
||||
|
||||
|
||||
cpdef Doc annotations2doc(vocab, tok_annot, doc_annot):
|
||||
|
|
|
@ -14,13 +14,14 @@ from thinc.api import get_current_ops, Config, require_gpu, Optimizer
|
|||
import srsly
|
||||
import multiprocessing as mp
|
||||
from itertools import chain, cycle
|
||||
from timeit import default_timer as timer
|
||||
|
||||
from .tokens.underscore import Underscore
|
||||
from .vocab import Vocab, create_vocab
|
||||
from .pipe_analysis import analyze_pipes, analyze_all_pipes, validate_attrs
|
||||
from .gold import Example
|
||||
from .scorer import Scorer
|
||||
from .util import link_vectors_to_models, create_default_optimizer, registry
|
||||
from .util import create_default_optimizer, registry
|
||||
from .util import SimpleFrozenDict, combine_score_weights
|
||||
from .lang.tokenizer_exceptions import URL_MATCH, BASE_EXCEPTIONS
|
||||
from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES
|
||||
|
@ -36,6 +37,7 @@ from . import util
|
|||
from . import about
|
||||
|
||||
|
||||
# TODO: integrate pipeline analyis
|
||||
ENABLE_PIPELINE_ANALYSIS = False
|
||||
# This is the base config will all settings (training etc.)
|
||||
DEFAULT_CONFIG_PATH = Path(__file__).parent / "default_config.cfg"
|
||||
|
@ -43,6 +45,11 @@ DEFAULT_CONFIG = Config().from_disk(DEFAULT_CONFIG_PATH)
|
|||
|
||||
|
||||
class BaseDefaults:
|
||||
"""Language data defaults, available via Language.Defaults. Can be
|
||||
overwritten by language subclasses by defining their own subclasses of
|
||||
Language.Defaults.
|
||||
"""
|
||||
|
||||
config: Config = Config()
|
||||
tokenizer_exceptions: Dict[str, List[dict]] = BASE_EXCEPTIONS
|
||||
prefixes: Optional[List[Union[str, Pattern]]] = TOKENIZER_PREFIXES
|
||||
|
@ -58,6 +65,10 @@ class BaseDefaults:
|
|||
|
||||
@registry.tokenizers("spacy.Tokenizer.v1")
|
||||
def create_tokenizer() -> Callable[["Language"], Tokenizer]:
|
||||
"""Registered function to create a tokenizer. Returns a factory that takes
|
||||
the nlp object and returns a Tokenizer instance using the language detaults.
|
||||
"""
|
||||
|
||||
def tokenizer_factory(nlp: "Language") -> Tokenizer:
|
||||
prefixes = nlp.Defaults.prefixes
|
||||
suffixes = nlp.Defaults.suffixes
|
||||
|
@ -80,6 +91,11 @@ def create_tokenizer() -> Callable[["Language"], Tokenizer]:
|
|||
|
||||
@registry.lemmatizers("spacy.Lemmatizer.v1")
|
||||
def create_lemmatizer() -> Callable[["Language"], "Lemmatizer"]:
|
||||
"""Registered function to create a lemmatizer. Returns a factory that takes
|
||||
the nlp object and returns a Lemmatizer instance with data loaded in from
|
||||
spacy-lookups-data, if the package is installed.
|
||||
"""
|
||||
# TODO: Will be replaced when the lemmatizer becomes a pipeline component
|
||||
tables = ["lemma_lookup", "lemma_rules", "lemma_exc", "lemma_index"]
|
||||
|
||||
def lemmatizer_factory(nlp: "Language") -> "Lemmatizer":
|
||||
|
@ -116,7 +132,7 @@ class Language:
|
|||
create_tokenizer: Optional[Callable[["Language"], Callable[[str], Doc]]] = None,
|
||||
create_lemmatizer: Optional[Callable[["Language"], Callable]] = None,
|
||||
**kwargs,
|
||||
):
|
||||
) -> None:
|
||||
"""Initialise a Language object.
|
||||
|
||||
vocab (Vocab): A `Vocab` object. If `True`, a vocab is created.
|
||||
|
@ -134,7 +150,8 @@ class Language:
|
|||
returns a tokenizer.
|
||||
create_lemmatizer (Callable): Function that takes the nlp object and
|
||||
returns a lemmatizer.
|
||||
RETURNS (Language): The newly constructed object.
|
||||
|
||||
DOCS: https://spacy.io/api/language#init
|
||||
"""
|
||||
# We're only calling this to import all factories provided via entry
|
||||
# points. The factory decorator applied to these functions takes care
|
||||
|
@ -189,6 +206,13 @@ class Language:
|
|||
|
||||
@property
|
||||
def meta(self) -> Dict[str, Any]:
|
||||
"""Custom meta data of the language class. If a model is loaded, this
|
||||
includes details from the model's meta.json.
|
||||
|
||||
RETURNS (Dict[str, Any]): The meta.
|
||||
|
||||
DOCS: https://spacy.io/api/language#meta
|
||||
"""
|
||||
spacy_version = util.get_model_version_range(about.__version__)
|
||||
if self.vocab.lang:
|
||||
self._meta.setdefault("lang", self.vocab.lang)
|
||||
|
@ -221,6 +245,13 @@ class Language:
|
|||
|
||||
@property
|
||||
def config(self) -> Config:
|
||||
"""Trainable config for the current language instance. Includes the
|
||||
current pipeline components, as well as default training config.
|
||||
|
||||
RETURNS (thinc.api.Config): The config.
|
||||
|
||||
DOCS: https://spacy.io/api/language#config
|
||||
"""
|
||||
self._config.setdefault("nlp", {})
|
||||
self._config.setdefault("training", {})
|
||||
self._config["nlp"]["lang"] = self.lang
|
||||
|
@ -382,6 +413,8 @@ class Language:
|
|||
select the best model. Weights should sum to 1.0 per component and
|
||||
will be combined and normalized for the whole pipeline.
|
||||
func (Optional[Callable]): Factory function if not used as a decorator.
|
||||
|
||||
DOCS: https://spacy.io/api/language#factory
|
||||
"""
|
||||
if not isinstance(name, str):
|
||||
raise ValueError(Errors.E963.format(decorator="factory"))
|
||||
|
@ -460,6 +493,8 @@ class Language:
|
|||
select the best model. Weights should sum to 1.0 per component and
|
||||
will be combined and normalized for the whole pipeline.
|
||||
func (Optional[Callable]): Factory function if not used as a decorator.
|
||||
|
||||
DOCS: https://spacy.io/api/language#component
|
||||
"""
|
||||
if name is not None and not isinstance(name, str):
|
||||
raise ValueError(Errors.E963.format(decorator="component"))
|
||||
|
@ -504,6 +539,7 @@ class Language:
|
|||
self,
|
||||
factory_name: str,
|
||||
name: Optional[str] = None,
|
||||
*,
|
||||
config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
|
||||
overrides: Optional[Dict[str, Any]] = SimpleFrozenDict(),
|
||||
validate: bool = True,
|
||||
|
@ -521,6 +557,8 @@ class Language:
|
|||
validate (bool): Whether to validate the component config against the
|
||||
arguments and types expected by the factory.
|
||||
RETURNS (Callable[[Doc], Doc]): The pipeline component.
|
||||
|
||||
DOCS: https://spacy.io/api/language#create_pipe
|
||||
"""
|
||||
name = name if name is not None else factory_name
|
||||
if not isinstance(config, dict):
|
||||
|
@ -692,6 +730,7 @@ class Language:
|
|||
self,
|
||||
name: str,
|
||||
factory_name: str,
|
||||
*,
|
||||
config: Dict[str, Any] = SimpleFrozenDict(),
|
||||
validate: bool = True,
|
||||
) -> None:
|
||||
|
@ -761,6 +800,7 @@ class Language:
|
|||
def __call__(
|
||||
self,
|
||||
text: str,
|
||||
*,
|
||||
disable: Iterable[str] = tuple(),
|
||||
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
) -> Doc:
|
||||
|
@ -770,8 +810,8 @@ class Language:
|
|||
|
||||
text (str): The text to be processed.
|
||||
disable (list): Names of the pipeline components to disable.
|
||||
component_cfg (dict): An optional dictionary with extra keyword arguments
|
||||
for specific components.
|
||||
component_cfg (Dict[str, dict]): An optional dictionary with extra
|
||||
keyword arguments for specific components.
|
||||
RETURNS (Doc): A container for accessing the annotations.
|
||||
|
||||
DOCS: https://spacy.io/api/language#call
|
||||
|
@ -811,6 +851,7 @@ class Language:
|
|||
|
||||
def select_pipes(
|
||||
self,
|
||||
*,
|
||||
disable: Optional[Union[str, Iterable[str]]] = None,
|
||||
enable: Optional[Union[str, Iterable[str]]] = None,
|
||||
) -> "DisabledPipes":
|
||||
|
@ -853,7 +894,7 @@ class Language:
|
|||
def update(
|
||||
self,
|
||||
examples: Iterable[Example],
|
||||
dummy: Optional[Any] = None,
|
||||
_: Optional[Any] = None,
|
||||
*,
|
||||
drop: float = 0.0,
|
||||
sgd: Optional[Optimizer] = None,
|
||||
|
@ -863,7 +904,7 @@ class Language:
|
|||
"""Update the models in the pipeline.
|
||||
|
||||
examples (Iterable[Example]): A batch of examples
|
||||
dummy: Should not be set - serves to catch backwards-incompatible scripts.
|
||||
_: Should not be set - serves to catch backwards-incompatible scripts.
|
||||
drop (float): The dropout rate.
|
||||
sgd (Optimizer): An optimizer.
|
||||
losses (Dict[str, float]): Dictionary to update with the loss, keyed by component.
|
||||
|
@ -873,7 +914,7 @@ class Language:
|
|||
|
||||
DOCS: https://spacy.io/api/language#update
|
||||
"""
|
||||
if dummy is not None:
|
||||
if _ is not None:
|
||||
raise ValueError(Errors.E989)
|
||||
if losses is None:
|
||||
losses = {}
|
||||
|
@ -890,12 +931,10 @@ class Language:
|
|||
raise TypeError(
|
||||
Errors.E978.format(name="language", method="update", types=wrong_types)
|
||||
)
|
||||
|
||||
if sgd is None:
|
||||
if self._optimizer is None:
|
||||
self._optimizer = create_default_optimizer()
|
||||
sgd = self._optimizer
|
||||
|
||||
if component_cfg is None:
|
||||
component_cfg = {}
|
||||
for i, (name, proc) in enumerate(self.pipeline):
|
||||
|
@ -915,6 +954,7 @@ class Language:
|
|||
def rehearse(
|
||||
self,
|
||||
examples: Iterable[Example],
|
||||
*,
|
||||
sgd: Optional[Optimizer] = None,
|
||||
losses: Optional[Dict[str, float]] = None,
|
||||
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
|
@ -937,8 +977,9 @@ class Language:
|
|||
>>> nlp.update(labelled_batch)
|
||||
>>> raw_batch = [Example.from_dict(nlp.make_doc(text), {}) for text in next(raw_text_batches)]
|
||||
>>> nlp.rehearse(raw_batch)
|
||||
|
||||
DOCS: https://spacy.io/api/language#rehearse
|
||||
"""
|
||||
# TODO: document
|
||||
if len(examples) == 0:
|
||||
return
|
||||
if not isinstance(examples, IterableInstance):
|
||||
|
@ -983,17 +1024,18 @@ class Language:
|
|||
|
||||
def begin_training(
|
||||
self,
|
||||
get_examples: Optional[Callable] = None,
|
||||
get_examples: Optional[Callable[[], Iterable[Example]]] = None,
|
||||
*,
|
||||
sgd: Optional[Optimizer] = None,
|
||||
device: int = -1,
|
||||
) -> Optimizer:
|
||||
"""Allocate models, pre-process training data and acquire a trainer and
|
||||
optimizer. Used as a contextmanager.
|
||||
"""Initialize the pipe for training, using data examples if available.
|
||||
|
||||
get_examples (function): Function returning example training data.
|
||||
TODO: document format change since 3.0.
|
||||
sgd (Optional[Optimizer]): An optimizer.
|
||||
RETURNS: An optimizer.
|
||||
get_examples (Callable[[], Iterable[Example]]): Optional function that
|
||||
returns gold-standard Example objects.
|
||||
sgd (thinc.api.Optimizer): Optional optimizer. Will be created with
|
||||
create_optimizer if it doesn't exist.
|
||||
RETURNS (thinc.api.Optimizer): The optimizer.
|
||||
|
||||
DOCS: https://spacy.io/api/language#begin_training
|
||||
"""
|
||||
|
@ -1009,7 +1051,6 @@ class Language:
|
|||
if self.vocab.vectors.data.shape[1] >= 1:
|
||||
ops = get_current_ops()
|
||||
self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
|
||||
link_vectors_to_models(self.vocab)
|
||||
if sgd is None:
|
||||
sgd = create_default_optimizer()
|
||||
self._optimizer = sgd
|
||||
|
@ -1022,25 +1063,26 @@ class Language:
|
|||
return self._optimizer
|
||||
|
||||
def resume_training(
|
||||
self, sgd: Optional[Optimizer] = None, device: int = -1
|
||||
self, *, sgd: Optional[Optimizer] = None, device: int = -1
|
||||
) -> Optimizer:
|
||||
"""Continue training a pretrained model.
|
||||
|
||||
Create and return an optimizer, and initialize "rehearsal" for any pipeline
|
||||
component that has a .rehearse() method. Rehearsal is used to prevent
|
||||
models from "forgetting" their initialised "knowledge". To perform
|
||||
models from "forgetting" their initialized "knowledge". To perform
|
||||
rehearsal, collect samples of text you want the models to retain performance
|
||||
on, and call nlp.rehearse() with a batch of Example objects.
|
||||
|
||||
sgd (Optional[Optimizer]): An optimizer.
|
||||
RETURNS (Optimizer): The optimizer.
|
||||
|
||||
DOCS: https://spacy.io/api/language#resume_training
|
||||
"""
|
||||
if device >= 0: # TODO: do we need this here?
|
||||
require_gpu(device)
|
||||
ops = get_current_ops()
|
||||
if self.vocab.vectors.data.shape[1] >= 1:
|
||||
self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
|
||||
link_vectors_to_models(self.vocab)
|
||||
if sgd is None:
|
||||
sgd = create_default_optimizer()
|
||||
self._optimizer = sgd
|
||||
|
@ -1052,11 +1094,12 @@ class Language:
|
|||
def evaluate(
|
||||
self,
|
||||
examples: Iterable[Example],
|
||||
*,
|
||||
verbose: bool = False,
|
||||
batch_size: int = 256,
|
||||
scorer: Optional[Scorer] = None,
|
||||
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
) -> Scorer:
|
||||
) -> Dict[str, Union[float, dict]]:
|
||||
"""Evaluate a model's pipeline components.
|
||||
|
||||
examples (Iterable[Example]): `Example` objects.
|
||||
|
@ -1088,7 +1131,14 @@ class Language:
|
|||
kwargs.setdefault("verbose", verbose)
|
||||
kwargs.setdefault("nlp", self)
|
||||
scorer = Scorer(**kwargs)
|
||||
docs = list(eg.predicted for eg in examples)
|
||||
texts = [eg.reference.text for eg in examples]
|
||||
docs = [eg.predicted for eg in examples]
|
||||
start_time = timer()
|
||||
# tokenize the texts only for timing purposes
|
||||
if not hasattr(self.tokenizer, "pipe"):
|
||||
_ = [self.tokenizer(text) for text in texts]
|
||||
else:
|
||||
_ = list(self.tokenizer.pipe(texts))
|
||||
for name, pipe in self.pipeline:
|
||||
kwargs = component_cfg.get(name, {})
|
||||
kwargs.setdefault("batch_size", batch_size)
|
||||
|
@ -1096,11 +1146,18 @@ class Language:
|
|||
docs = _pipe(docs, pipe, kwargs)
|
||||
else:
|
||||
docs = pipe.pipe(docs, **kwargs)
|
||||
# iterate over the final generator
|
||||
if len(self.pipeline):
|
||||
docs = list(docs)
|
||||
end_time = timer()
|
||||
for i, (doc, eg) in enumerate(zip(docs, examples)):
|
||||
if verbose:
|
||||
print(doc)
|
||||
eg.predicted = doc
|
||||
return scorer.score(examples)
|
||||
results = scorer.score(examples)
|
||||
n_words = sum(len(eg.predicted) for eg in examples)
|
||||
results["speed"] = n_words / (end_time - start_time)
|
||||
return results
|
||||
|
||||
@contextmanager
|
||||
def use_params(self, params: dict):
|
||||
|
@ -1112,7 +1169,9 @@ class Language:
|
|||
|
||||
EXAMPLE:
|
||||
>>> with nlp.use_params(optimizer.averages):
|
||||
>>> nlp.to_disk('/tmp/checkpoint')
|
||||
>>> nlp.to_disk("/tmp/checkpoint")
|
||||
|
||||
DOCS: https://spacy.io/api/language#use_params
|
||||
"""
|
||||
contexts = [
|
||||
pipe.use_params(params)
|
||||
|
@ -1136,6 +1195,7 @@ class Language:
|
|||
def pipe(
|
||||
self,
|
||||
texts: Iterable[str],
|
||||
*,
|
||||
as_tuples: bool = False,
|
||||
batch_size: int = 1000,
|
||||
disable: Iterable[str] = tuple(),
|
||||
|
@ -1305,6 +1365,16 @@ class Language:
|
|||
"""Create the nlp object from a loaded config. Will set up the tokenizer
|
||||
and language data, add pipeline components etc. If no config is provided,
|
||||
the default config of the given language is used.
|
||||
|
||||
config (Dict[str, Any] / Config): The loaded config.
|
||||
disable (Iterable[str]): List of pipeline component names to disable.
|
||||
auto_fill (bool): Automatically fill in missing values in config based
|
||||
on defaults and function argument annotations.
|
||||
validate (bool): Validate the component config and arguments against
|
||||
the types expected by the factory.
|
||||
RETURNS (Language): The initialized Language class.
|
||||
|
||||
DOCS: https://spacy.io/api/language#from_config
|
||||
"""
|
||||
if auto_fill:
|
||||
config = util.deep_merge_configs(config, cls.default_config)
|
||||
|
@ -1338,6 +1408,10 @@ class Language:
|
|||
nlp = cls(
|
||||
create_tokenizer=create_tokenizer, create_lemmatizer=create_lemmatizer,
|
||||
)
|
||||
# Note that we don't load vectors here, instead they get loaded explicitly
|
||||
# inside stuff like the spacy train function. If we loaded them here,
|
||||
# then we would load them twice at runtime: once when we make from config,
|
||||
# and then again when we load from disk.
|
||||
pipeline = config.get("components", {})
|
||||
for pipe_name in config["nlp"]["pipeline"]:
|
||||
if pipe_name not in pipeline:
|
||||
|
@ -1362,7 +1436,9 @@ class Language:
|
|||
nlp.resolved = resolved
|
||||
return nlp
|
||||
|
||||
def to_disk(self, path: Union[str, Path], exclude: Iterable[str] = tuple()) -> None:
|
||||
def to_disk(
|
||||
self, path: Union[str, Path], *, exclude: Iterable[str] = tuple()
|
||||
) -> None:
|
||||
"""Save the current state to a directory. If a model is loaded, this
|
||||
will include the model.
|
||||
|
||||
|
@ -1391,7 +1467,7 @@ class Language:
|
|||
util.to_disk(path, serializers, exclude)
|
||||
|
||||
def from_disk(
|
||||
self, path: Union[str, Path], exclude: Iterable[str] = tuple()
|
||||
self, path: Union[str, Path], *, exclude: Iterable[str] = tuple()
|
||||
) -> "Language":
|
||||
"""Loads state from a directory. Modifies the object in place and
|
||||
returns it. If the saved `Language` object contains a model, the
|
||||
|
@ -1418,7 +1494,6 @@ class Language:
|
|||
_fix_pretrained_vectors_name(self)
|
||||
|
||||
path = util.ensure_path(path)
|
||||
|
||||
deserializers = {}
|
||||
if Path(path / "config.cfg").exists():
|
||||
deserializers["config.cfg"] = lambda p: self.config.from_disk(p)
|
||||
|
@ -1443,7 +1518,7 @@ class Language:
|
|||
self._link_components()
|
||||
return self
|
||||
|
||||
def to_bytes(self, exclude: Iterable[str] = tuple()) -> bytes:
|
||||
def to_bytes(self, *, exclude: Iterable[str] = tuple()) -> bytes:
|
||||
"""Serialize the current state to a binary string.
|
||||
|
||||
exclude (list): Names of components or serialization fields to exclude.
|
||||
|
@ -1465,7 +1540,7 @@ class Language:
|
|||
return util.to_bytes(serializers, exclude)
|
||||
|
||||
def from_bytes(
|
||||
self, bytes_data: bytes, exclude: Iterable[str] = tuple()
|
||||
self, bytes_data: bytes, *, exclude: Iterable[str] = tuple()
|
||||
) -> "Language":
|
||||
"""Load state from a binary string.
|
||||
|
||||
|
@ -1509,6 +1584,12 @@ class Language:
|
|||
|
||||
@dataclass
|
||||
class FactoryMeta:
|
||||
"""Dataclass containing information about a component and its defaults
|
||||
provided by the @Language.component or @Language.factory decorator. It's
|
||||
created whenever a component is defined and stored on the Language class for
|
||||
each component instance and factory instance.
|
||||
"""
|
||||
|
||||
factory: str
|
||||
default_config: Optional[Dict[str, Any]] = None # noqa: E704
|
||||
assigns: Iterable[str] = tuple()
|
||||
|
@ -1539,8 +1620,6 @@ def _fix_pretrained_vectors_name(nlp: Language) -> None:
|
|||
nlp.vocab.vectors.name = vectors_name
|
||||
else:
|
||||
raise ValueError(Errors.E092)
|
||||
if nlp.vocab.vectors.size != 0:
|
||||
link_vectors_to_models(nlp.vocab)
|
||||
for name, proc in nlp.pipeline:
|
||||
if not hasattr(proc, "cfg"):
|
||||
continue
|
||||
|
@ -1551,7 +1630,7 @@ def _fix_pretrained_vectors_name(nlp: Language) -> None:
|
|||
class DisabledPipes(list):
|
||||
"""Manager for temporary pipeline disabling."""
|
||||
|
||||
def __init__(self, nlp: Language, names: List[str]):
|
||||
def __init__(self, nlp: Language, names: List[str]) -> None:
|
||||
self.nlp = nlp
|
||||
self.names = names
|
||||
# Important! Not deep copy -- we just want the container (but we also
|
||||
|
|
|
@ -21,7 +21,6 @@ class Lemmatizer:
|
|||
|
||||
lookups (Lookups): The lookups object containing the (optional) tables
|
||||
"lemma_rules", "lemma_index", "lemma_exc" and "lemma_lookup".
|
||||
RETURNS (Lemmatizer): The newly constructed object.
|
||||
"""
|
||||
self.lookups = lookups if lookups is not None else Lookups()
|
||||
self.is_base_form = is_base_form
|
||||
|
|
|
@ -52,8 +52,6 @@ class Lookups:
|
|||
def __init__(self) -> None:
|
||||
"""Initialize the Lookups object.
|
||||
|
||||
RETURNS (Lookups): The newly created object.
|
||||
|
||||
DOCS: https://spacy.io/api/lookups#init
|
||||
"""
|
||||
self._tables = {}
|
||||
|
@ -202,7 +200,6 @@ class Table(OrderedDict):
|
|||
|
||||
data (dict): The dictionary.
|
||||
name (str): Optional table name for reference.
|
||||
RETURNS (Table): The newly created object.
|
||||
|
||||
DOCS: https://spacy.io/api/lookups#table.from_dict
|
||||
"""
|
||||
|
@ -215,7 +212,6 @@ class Table(OrderedDict):
|
|||
|
||||
name (str): Optional table name for reference.
|
||||
data (dict): Initial data, used to hint Bloom Filter.
|
||||
RETURNS (Table): The newly created object.
|
||||
|
||||
DOCS: https://spacy.io/api/lookups#table.init
|
||||
"""
|
||||
|
|
|
@ -36,7 +36,6 @@ cdef class DependencyMatcher:
|
|||
|
||||
vocab (Vocab): The vocabulary object, which must be shared with the
|
||||
documents the matcher will operate on.
|
||||
RETURNS (DependencyMatcher): The newly constructed object.
|
||||
"""
|
||||
size = 20
|
||||
# TODO: make matcher work with validation
|
||||
|
|
|
@ -66,6 +66,7 @@ cdef class Matcher:
|
|||
cdef public object validate
|
||||
cdef public object _patterns
|
||||
cdef public object _callbacks
|
||||
cdef public object _filter
|
||||
cdef public object _extensions
|
||||
cdef public object _extra_predicates
|
||||
cdef public object _seen_attrs
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
# cython: infer_types=True, cython: profile=True
|
||||
from typing import List
|
||||
|
||||
from libcpp.vector cimport vector
|
||||
from libc.stdint cimport int32_t
|
||||
from libc.string cimport memset, memcmp
|
||||
from cymem.cymem cimport Pool
|
||||
from murmurhash.mrmr cimport hash64
|
||||
|
||||
|
@ -37,11 +40,11 @@ cdef class Matcher:
|
|||
|
||||
vocab (Vocab): The vocabulary object, which must be shared with the
|
||||
documents the matcher will operate on.
|
||||
RETURNS (Matcher): The newly constructed object.
|
||||
"""
|
||||
self._extra_predicates = []
|
||||
self._patterns = {}
|
||||
self._callbacks = {}
|
||||
self._filter = {}
|
||||
self._extensions = {}
|
||||
self._seen_attrs = set()
|
||||
self.vocab = vocab
|
||||
|
@ -69,7 +72,7 @@ cdef class Matcher:
|
|||
"""
|
||||
return self._normalize_key(key) in self._patterns
|
||||
|
||||
def add(self, key, patterns, *_patterns, on_match=None):
|
||||
def add(self, key, patterns, *, on_match=None, greedy: str=None):
|
||||
"""Add a match-rule to the matcher. A match-rule consists of: an ID
|
||||
key, an on_match callback, and one or more patterns.
|
||||
|
||||
|
@ -87,11 +90,10 @@ cdef class Matcher:
|
|||
'+': Require the pattern to match 1 or more times.
|
||||
'*': Allow the pattern to zero or more times.
|
||||
|
||||
The + and * operators are usually interpretted "greedily", i.e. longer
|
||||
matches are returned where possible. However, if you specify two '+'
|
||||
and '*' patterns in a row and their matches overlap, the first
|
||||
operator will behave non-greedily. This quirk in the semantics makes
|
||||
the matcher more efficient, by avoiding the need for back-tracking.
|
||||
The + and * operators return all possible matches (not just the greedy
|
||||
ones). However, the "greedy" argument can filter the final matches
|
||||
by returning a non-overlapping set per key, either taking preference to
|
||||
the first greedy match ("FIRST"), or the longest ("LONGEST").
|
||||
|
||||
As of spaCy v2.2.2, Matcher.add supports the future API, which makes
|
||||
the patterns the second argument and a list (instead of a variable
|
||||
|
@ -101,16 +103,15 @@ cdef class Matcher:
|
|||
key (str): The match ID.
|
||||
patterns (list): The patterns to add for the given key.
|
||||
on_match (callable): Optional callback executed on match.
|
||||
*_patterns (list): For backwards compatibility: list of patterns to add
|
||||
as variable arguments. Will be ignored if a list of patterns is
|
||||
provided as the second argument.
|
||||
greedy (str): Optional filter: "FIRST" or "LONGEST".
|
||||
"""
|
||||
errors = {}
|
||||
if on_match is not None and not hasattr(on_match, "__call__"):
|
||||
raise ValueError(Errors.E171.format(arg_type=type(on_match)))
|
||||
if patterns is None or hasattr(patterns, "__call__"): # old API
|
||||
on_match = patterns
|
||||
patterns = _patterns
|
||||
if patterns is None or not isinstance(patterns, List): # old API
|
||||
raise ValueError(Errors.E948.format(arg_type=type(patterns)))
|
||||
if greedy is not None and greedy not in ["FIRST", "LONGEST"]:
|
||||
raise ValueError(Errors.E947.format(expected=["FIRST", "LONGEST"], arg=greedy))
|
||||
for i, pattern in enumerate(patterns):
|
||||
if len(pattern) == 0:
|
||||
raise ValueError(Errors.E012.format(key=key))
|
||||
|
@ -133,6 +134,7 @@ cdef class Matcher:
|
|||
raise ValueError(Errors.E154.format())
|
||||
self._patterns.setdefault(key, [])
|
||||
self._callbacks[key] = on_match
|
||||
self._filter[key] = greedy
|
||||
self._patterns[key].extend(patterns)
|
||||
|
||||
def remove(self, key):
|
||||
|
@ -218,6 +220,7 @@ cdef class Matcher:
|
|||
length = doclike.end - doclike.start
|
||||
else:
|
||||
raise ValueError(Errors.E195.format(good="Doc or Span", got=type(doclike).__name__))
|
||||
cdef Pool tmp_pool = Pool()
|
||||
if len(set([LEMMA, POS, TAG]) & self._seen_attrs) > 0 \
|
||||
and not doc.is_tagged:
|
||||
raise ValueError(Errors.E155.format())
|
||||
|
@ -225,11 +228,42 @@ cdef class Matcher:
|
|||
raise ValueError(Errors.E156.format())
|
||||
matches = find_matches(&self.patterns[0], self.patterns.size(), doclike, length,
|
||||
extensions=self._extensions, predicates=self._extra_predicates)
|
||||
for i, (key, start, end) in enumerate(matches):
|
||||
final_matches = []
|
||||
pairs_by_id = {}
|
||||
# For each key, either add all matches, or only the filtered, non-overlapping ones
|
||||
for (key, start, end) in matches:
|
||||
span_filter = self._filter.get(key)
|
||||
if span_filter is not None:
|
||||
pairs = pairs_by_id.get(key, [])
|
||||
pairs.append((start,end))
|
||||
pairs_by_id[key] = pairs
|
||||
else:
|
||||
final_matches.append((key, start, end))
|
||||
matched = <char*>tmp_pool.alloc(length, sizeof(char))
|
||||
empty = <char*>tmp_pool.alloc(length, sizeof(char))
|
||||
for key, pairs in pairs_by_id.items():
|
||||
memset(matched, 0, length * sizeof(matched[0]))
|
||||
span_filter = self._filter.get(key)
|
||||
if span_filter == "FIRST":
|
||||
sorted_pairs = sorted(pairs, key=lambda x: (x[0], -x[1]), reverse=False) # sort by start
|
||||
elif span_filter == "LONGEST":
|
||||
sorted_pairs = sorted(pairs, key=lambda x: (x[1]-x[0], -x[0]), reverse=True) # reverse sort by length
|
||||
else:
|
||||
raise ValueError(Errors.E947.format(expected=["FIRST", "LONGEST"], arg=span_filter))
|
||||
for (start, end) in sorted_pairs:
|
||||
assert 0 <= start < end # Defend against segfaults
|
||||
span_len = end-start
|
||||
# If no tokens in the span have matched
|
||||
if memcmp(&matched[start], &empty[start], span_len * sizeof(matched[0])) == 0:
|
||||
final_matches.append((key, start, end))
|
||||
# Mark tokens that have matched
|
||||
memset(&matched[start], 1, span_len * sizeof(matched[0]))
|
||||
# perform the callbacks on the filtered set of results
|
||||
for i, (key, start, end) in enumerate(final_matches):
|
||||
on_match = self._callbacks.get(key, None)
|
||||
if on_match is not None:
|
||||
on_match(self, doc, i, matches)
|
||||
return matches
|
||||
on_match(self, doc, i, final_matches)
|
||||
return final_matches
|
||||
|
||||
def _normalize_key(self, key):
|
||||
if isinstance(key, basestring):
|
||||
|
@ -240,9 +274,9 @@ cdef class Matcher:
|
|||
|
||||
def unpickle_matcher(vocab, patterns, callbacks):
|
||||
matcher = Matcher(vocab)
|
||||
for key, specs in patterns.items():
|
||||
for key, pattern in patterns.items():
|
||||
callback = callbacks.get(key, None)
|
||||
matcher.add(key, callback, *specs)
|
||||
matcher.add(key, pattern, on_match=callback)
|
||||
return matcher
|
||||
|
||||
|
||||
|
|
|
@ -32,7 +32,6 @@ cdef class PhraseMatcher:
|
|||
vocab (Vocab): The shared vocabulary.
|
||||
attr (int / str): Token attribute to match on.
|
||||
validate (bool): Perform additional validation when patterns are added.
|
||||
RETURNS (PhraseMatcher): The newly constructed object.
|
||||
|
||||
DOCS: https://spacy.io/api/phrasematcher#init
|
||||
"""
|
||||
|
|
|
@ -1,16 +1,18 @@
|
|||
from typing import List
|
||||
from thinc.api import Model
|
||||
from thinc.types import Floats2d
|
||||
from ..tokens import Doc
|
||||
|
||||
|
||||
def CharacterEmbed(nM, nC):
|
||||
def CharacterEmbed(nM: int, nC: int) -> Model[List[Doc], List[Floats2d]]:
|
||||
# nM: Number of dimensions per character. nC: Number of characters.
|
||||
nO = nM * nC if (nM is not None and nC is not None) else None
|
||||
return Model(
|
||||
"charembed",
|
||||
forward,
|
||||
init=init,
|
||||
dims={"nM": nM, "nC": nC, "nO": nO, "nV": 256},
|
||||
dims={"nM": nM, "nC": nC, "nO": nM * nC, "nV": 256},
|
||||
params={"E": None},
|
||||
).initialize()
|
||||
)
|
||||
|
||||
|
||||
def init(model, X=None, Y=None):
|
||||
|
|
|
@ -5,11 +5,11 @@ from thinc.api import SparseLinear, Softmax, softmax_activation, Maxout, reduce_
|
|||
from thinc.api import HashEmbed, with_ragged, with_array, with_cpu, uniqued
|
||||
from thinc.api import Relu, residual, expand_window, FeatureExtractor
|
||||
|
||||
from ..spacy_vectors import SpacyVectors
|
||||
from ... import util
|
||||
from ...attrs import ID, ORTH, PREFIX, SUFFIX, SHAPE, LOWER
|
||||
from ...util import registry
|
||||
from ..extract_ngrams import extract_ngrams
|
||||
from ..staticvectors import StaticVectors
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.TextCatCNN.v1")
|
||||
|
@ -102,13 +102,7 @@ def build_text_classifier(
|
|||
)
|
||||
|
||||
if pretrained_vectors:
|
||||
nlp = util.load_model(pretrained_vectors)
|
||||
vectors = nlp.vocab.vectors
|
||||
vector_dim = vectors.data.shape[1]
|
||||
|
||||
static_vectors = SpacyVectors(vectors) >> with_array(
|
||||
Linear(width, vector_dim)
|
||||
)
|
||||
static_vectors = StaticVectors(width)
|
||||
vector_layer = trained_vectors | static_vectors
|
||||
vectors_width = width * 2
|
||||
else:
|
||||
|
@ -159,16 +153,11 @@ def build_text_classifier(
|
|||
|
||||
@registry.architectures.register("spacy.TextCatLowData.v1")
|
||||
def build_text_classifier_lowdata(width, pretrained_vectors, dropout, nO=None):
|
||||
nlp = util.load_model(pretrained_vectors)
|
||||
vectors = nlp.vocab.vectors
|
||||
vector_dim = vectors.data.shape[1]
|
||||
|
||||
# Note, before v.3, this was the default if setting "low_data" and "pretrained_dims"
|
||||
with Model.define_operators({">>": chain, "**": clone}):
|
||||
model = (
|
||||
SpacyVectors(vectors)
|
||||
StaticVectors(width)
|
||||
>> list2ragged()
|
||||
>> with_ragged(0, Linear(width, vector_dim))
|
||||
>> ParametricAttention(width)
|
||||
>> reduce_sum()
|
||||
>> residual(Relu(width, width)) ** 2
|
||||
|
|
|
@ -1,223 +1,140 @@
|
|||
from thinc.api import chain, clone, concatenate, with_array, uniqued
|
||||
from thinc.api import Model, noop, with_padded, Maxout, expand_window
|
||||
from thinc.api import HashEmbed, StaticVectors, PyTorchLSTM
|
||||
from thinc.api import residual, LayerNorm, FeatureExtractor, Mish
|
||||
from typing import Optional, List
|
||||
from thinc.api import chain, clone, concatenate, with_array, with_padded
|
||||
from thinc.api import Model, noop, list2ragged, ragged2list
|
||||
from thinc.api import FeatureExtractor, HashEmbed
|
||||
from thinc.api import expand_window, residual, Maxout, Mish, PyTorchLSTM
|
||||
from thinc.types import Floats2d
|
||||
|
||||
from ...tokens import Doc
|
||||
from ... import util
|
||||
from ...util import registry
|
||||
from ...ml import _character_embed
|
||||
from ..staticvectors import StaticVectors
|
||||
from ...pipeline.tok2vec import Tok2VecListener
|
||||
from ...attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.Tok2VecTensors.v1")
|
||||
def tok2vec_tensors_v1(width, upstream="*"):
|
||||
@registry.architectures.register("spacy.Tok2VecListener.v1")
|
||||
def tok2vec_listener_v1(width, upstream="*"):
|
||||
tok2vec = Tok2VecListener(upstream_name=upstream, width=width)
|
||||
return tok2vec
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.VocabVectors.v1")
|
||||
def get_vocab_vectors(name):
|
||||
nlp = util.load_model(name)
|
||||
return nlp.vocab.vectors
|
||||
|
||||
@registry.architectures.register("spacy.HashEmbedCNN.v1")
|
||||
def build_hash_embed_cnn_tok2vec(
|
||||
*,
|
||||
width: int,
|
||||
depth: int,
|
||||
embed_size: int,
|
||||
window_size: int,
|
||||
maxout_pieces: int,
|
||||
subword_features: bool,
|
||||
dropout: Optional[float],
|
||||
pretrained_vectors: Optional[bool]
|
||||
) -> Model[List[Doc], List[Floats2d]]:
|
||||
"""Build spaCy's 'standard' tok2vec layer, which uses hash embedding
|
||||
with subword features and a CNN with layer-normalized maxout."""
|
||||
return build_Tok2Vec_model(
|
||||
embed=MultiHashEmbed(
|
||||
width=width,
|
||||
rows=embed_size,
|
||||
also_embed_subwords=subword_features,
|
||||
also_use_static_vectors=bool(pretrained_vectors),
|
||||
),
|
||||
encode=MaxoutWindowEncoder(
|
||||
width=width,
|
||||
depth=depth,
|
||||
window_size=window_size,
|
||||
maxout_pieces=maxout_pieces
|
||||
)
|
||||
)
|
||||
|
||||
@registry.architectures.register("spacy.Tok2Vec.v1")
|
||||
def Tok2Vec(extract, embed, encode):
|
||||
field_size = 0
|
||||
if encode.attrs.get("receptive_field", None):
|
||||
field_size = encode.attrs["receptive_field"]
|
||||
with Model.define_operators({">>": chain, "|": concatenate}):
|
||||
tok2vec = extract >> with_array(embed >> encode, pad=field_size)
|
||||
def build_Tok2Vec_model(
|
||||
embed: Model[List[Doc], List[Floats2d]],
|
||||
encode: Model[List[Floats2d], List[Floats2d]],
|
||||
) -> Model[List[Doc], List[Floats2d]]:
|
||||
|
||||
receptive_field = encode.attrs.get("receptive_field", 0)
|
||||
tok2vec = chain(embed, with_array(encode, pad=receptive_field))
|
||||
tok2vec.set_dim("nO", encode.get_dim("nO"))
|
||||
tok2vec.set_ref("embed", embed)
|
||||
tok2vec.set_ref("encode", encode)
|
||||
return tok2vec
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.Doc2Feats.v1")
|
||||
def Doc2Feats(columns):
|
||||
return FeatureExtractor(columns)
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.HashEmbedCNN.v1")
|
||||
def hash_embed_cnn(
|
||||
pretrained_vectors,
|
||||
width,
|
||||
depth,
|
||||
embed_size,
|
||||
maxout_pieces,
|
||||
window_size,
|
||||
subword_features,
|
||||
dropout,
|
||||
):
|
||||
# Does not use character embeddings: set to False by default
|
||||
return build_Tok2Vec_model(
|
||||
width=width,
|
||||
embed_size=embed_size,
|
||||
pretrained_vectors=pretrained_vectors,
|
||||
conv_depth=depth,
|
||||
bilstm_depth=0,
|
||||
maxout_pieces=maxout_pieces,
|
||||
window_size=window_size,
|
||||
subword_features=subword_features,
|
||||
char_embed=False,
|
||||
nM=0,
|
||||
nC=0,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.HashCharEmbedCNN.v1")
|
||||
def hash_charembed_cnn(
|
||||
pretrained_vectors,
|
||||
width,
|
||||
depth,
|
||||
embed_size,
|
||||
maxout_pieces,
|
||||
window_size,
|
||||
nM,
|
||||
nC,
|
||||
dropout,
|
||||
):
|
||||
# Allows using character embeddings by setting nC, nM and char_embed=True
|
||||
return build_Tok2Vec_model(
|
||||
width=width,
|
||||
embed_size=embed_size,
|
||||
pretrained_vectors=pretrained_vectors,
|
||||
conv_depth=depth,
|
||||
bilstm_depth=0,
|
||||
maxout_pieces=maxout_pieces,
|
||||
window_size=window_size,
|
||||
subword_features=False,
|
||||
char_embed=True,
|
||||
nM=nM,
|
||||
nC=nC,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.HashEmbedBiLSTM.v1")
|
||||
def hash_embed_bilstm_v1(
|
||||
pretrained_vectors,
|
||||
width,
|
||||
depth,
|
||||
embed_size,
|
||||
subword_features,
|
||||
maxout_pieces,
|
||||
dropout,
|
||||
):
|
||||
# Does not use character embeddings: set to False by default
|
||||
return build_Tok2Vec_model(
|
||||
width=width,
|
||||
embed_size=embed_size,
|
||||
pretrained_vectors=pretrained_vectors,
|
||||
bilstm_depth=depth,
|
||||
conv_depth=0,
|
||||
maxout_pieces=maxout_pieces,
|
||||
window_size=1,
|
||||
subword_features=subword_features,
|
||||
char_embed=False,
|
||||
nM=0,
|
||||
nC=0,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.HashCharEmbedBiLSTM.v1")
|
||||
def hash_char_embed_bilstm_v1(
|
||||
pretrained_vectors, width, depth, embed_size, maxout_pieces, nM, nC, dropout
|
||||
):
|
||||
# Allows using character embeddings by setting nC, nM and char_embed=True
|
||||
return build_Tok2Vec_model(
|
||||
width=width,
|
||||
embed_size=embed_size,
|
||||
pretrained_vectors=pretrained_vectors,
|
||||
bilstm_depth=depth,
|
||||
conv_depth=0,
|
||||
maxout_pieces=maxout_pieces,
|
||||
window_size=1,
|
||||
subword_features=False,
|
||||
char_embed=True,
|
||||
nM=nM,
|
||||
nC=nC,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.LayerNormalizedMaxout.v1")
|
||||
def LayerNormalizedMaxout(width, maxout_pieces):
|
||||
return Maxout(nO=width, nP=maxout_pieces, dropout=0.0, normalize=True)
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.MultiHashEmbed.v1")
|
||||
def MultiHashEmbed(
|
||||
columns, width, rows, use_subwords, pretrained_vectors, mix, dropout
|
||||
width: int, rows: int, also_embed_subwords: bool, also_use_static_vectors: bool
|
||||
):
|
||||
norm = HashEmbed(
|
||||
nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=6
|
||||
)
|
||||
if use_subwords:
|
||||
prefix = HashEmbed(
|
||||
nO=width,
|
||||
nV=rows // 2,
|
||||
column=columns.index("PREFIX"),
|
||||
dropout=dropout,
|
||||
seed=7,
|
||||
)
|
||||
suffix = HashEmbed(
|
||||
nO=width,
|
||||
nV=rows // 2,
|
||||
column=columns.index("SUFFIX"),
|
||||
dropout=dropout,
|
||||
seed=8,
|
||||
)
|
||||
shape = HashEmbed(
|
||||
nO=width,
|
||||
nV=rows // 2,
|
||||
column=columns.index("SHAPE"),
|
||||
dropout=dropout,
|
||||
seed=9,
|
||||
cols = [NORM, PREFIX, SUFFIX, SHAPE, ORTH]
|
||||
|
||||
seed = 7
|
||||
|
||||
def make_hash_embed(feature):
|
||||
nonlocal seed
|
||||
seed += 1
|
||||
return HashEmbed(
|
||||
width,
|
||||
rows if feature == NORM else rows // 2,
|
||||
column=cols.index(feature),
|
||||
seed=seed,
|
||||
dropout=0.0,
|
||||
)
|
||||
|
||||
if pretrained_vectors:
|
||||
glove = StaticVectors(
|
||||
vectors=pretrained_vectors.data,
|
||||
nO=width,
|
||||
column=columns.index(ID),
|
||||
dropout=dropout,
|
||||
if also_embed_subwords:
|
||||
embeddings = [
|
||||
make_hash_embed(NORM),
|
||||
make_hash_embed(PREFIX),
|
||||
make_hash_embed(SUFFIX),
|
||||
make_hash_embed(SHAPE),
|
||||
]
|
||||
else:
|
||||
embeddings = [make_hash_embed(NORM)]
|
||||
concat_size = width * (len(embeddings) + also_use_static_vectors)
|
||||
if also_use_static_vectors:
|
||||
model = chain(
|
||||
concatenate(
|
||||
chain(
|
||||
FeatureExtractor(cols),
|
||||
list2ragged(),
|
||||
with_array(concatenate(*embeddings)),
|
||||
),
|
||||
StaticVectors(width, dropout=0.0),
|
||||
),
|
||||
with_array(Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)),
|
||||
ragged2list(),
|
||||
)
|
||||
|
||||
with Model.define_operators({">>": chain, "|": concatenate}):
|
||||
if not use_subwords and not pretrained_vectors:
|
||||
embed_layer = norm
|
||||
else:
|
||||
if use_subwords and pretrained_vectors:
|
||||
concat_columns = glove | norm | prefix | suffix | shape
|
||||
elif use_subwords:
|
||||
concat_columns = norm | prefix | suffix | shape
|
||||
else:
|
||||
concat_columns = glove | norm
|
||||
|
||||
embed_layer = uniqued(concat_columns >> mix, column=columns.index("ORTH"))
|
||||
|
||||
return embed_layer
|
||||
else:
|
||||
model = chain(
|
||||
FeatureExtractor(cols),
|
||||
list2ragged(),
|
||||
with_array(concatenate(*embeddings)),
|
||||
with_array(Maxout(width, concat_size, nP=3, dropout=0.0, normalize=True)),
|
||||
ragged2list(),
|
||||
)
|
||||
return model
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.CharacterEmbed.v1")
|
||||
def CharacterEmbed(columns, width, rows, nM, nC, features, dropout):
|
||||
norm = HashEmbed(
|
||||
nO=width, nV=rows, column=columns.index("NORM"), dropout=dropout, seed=5
|
||||
def CharacterEmbed(width: int, rows: int, nM: int, nC: int):
|
||||
model = chain(
|
||||
concatenate(
|
||||
chain(_character_embed.CharacterEmbed(nM=nM, nC=nC), list2ragged()),
|
||||
chain(
|
||||
FeatureExtractor([NORM]),
|
||||
list2ragged(),
|
||||
with_array(HashEmbed(nO=width, nV=rows, column=0, seed=5))
|
||||
)
|
||||
),
|
||||
with_array(Maxout(width, nM * nC + width, nP=3, normalize=True, dropout=0.0)),
|
||||
ragged2list()
|
||||
)
|
||||
chr_embed = _character_embed.CharacterEmbed(nM=nM, nC=nC)
|
||||
with Model.define_operators({">>": chain, "|": concatenate}):
|
||||
embed_layer = chr_embed | features >> with_array(norm)
|
||||
embed_layer.set_dim("nO", nM * nC + width)
|
||||
return embed_layer
|
||||
return model
|
||||
|
||||
|
||||
@registry.architectures.register("spacy.MaxoutWindowEncoder.v1")
|
||||
def MaxoutWindowEncoder(width, window_size, maxout_pieces, depth):
|
||||
def MaxoutWindowEncoder(width: int, window_size: int, maxout_pieces: int, depth: int):
|
||||
cnn = chain(
|
||||
expand_window(window_size=window_size),
|
||||
Maxout(
|
||||
|
@ -238,8 +155,12 @@ def MaxoutWindowEncoder(width, window_size, maxout_pieces, depth):
|
|||
def MishWindowEncoder(width, window_size, depth):
|
||||
cnn = chain(
|
||||
expand_window(window_size=window_size),
|
||||
Mish(nO=width, nI=width * ((window_size * 2) + 1)),
|
||||
LayerNorm(width),
|
||||
Mish(
|
||||
nO=width,
|
||||
nI=width * ((window_size * 2) + 1),
|
||||
dropout=0.0,
|
||||
normalize=True
|
||||
),
|
||||
)
|
||||
model = clone(residual(cnn), depth)
|
||||
model.set_dim("nO", width)
|
||||
|
@ -247,133 +168,7 @@ def MishWindowEncoder(width, window_size, depth):
|
|||
|
||||
|
||||
@registry.architectures.register("spacy.TorchBiLSTMEncoder.v1")
|
||||
def TorchBiLSTMEncoder(width, depth):
|
||||
import torch.nn
|
||||
|
||||
# TODO FIX
|
||||
from thinc.api import PyTorchRNNWrapper
|
||||
|
||||
def BiLSTMEncoder(width, depth, dropout):
|
||||
if depth == 0:
|
||||
return noop()
|
||||
return with_padded(
|
||||
PyTorchRNNWrapper(torch.nn.LSTM(width, width // 2, depth, bidirectional=True))
|
||||
)
|
||||
|
||||
|
||||
def build_Tok2Vec_model(
|
||||
width,
|
||||
embed_size,
|
||||
pretrained_vectors,
|
||||
window_size,
|
||||
maxout_pieces,
|
||||
subword_features,
|
||||
char_embed,
|
||||
nM,
|
||||
nC,
|
||||
conv_depth,
|
||||
bilstm_depth,
|
||||
dropout,
|
||||
) -> Model:
|
||||
if char_embed:
|
||||
subword_features = False
|
||||
cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
|
||||
with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
|
||||
norm = HashEmbed(
|
||||
nO=width, nV=embed_size, column=cols.index(NORM), dropout=None, seed=0
|
||||
)
|
||||
if subword_features:
|
||||
prefix = HashEmbed(
|
||||
nO=width,
|
||||
nV=embed_size // 2,
|
||||
column=cols.index(PREFIX),
|
||||
dropout=None,
|
||||
seed=1,
|
||||
)
|
||||
suffix = HashEmbed(
|
||||
nO=width,
|
||||
nV=embed_size // 2,
|
||||
column=cols.index(SUFFIX),
|
||||
dropout=None,
|
||||
seed=2,
|
||||
)
|
||||
shape = HashEmbed(
|
||||
nO=width,
|
||||
nV=embed_size // 2,
|
||||
column=cols.index(SHAPE),
|
||||
dropout=None,
|
||||
seed=3,
|
||||
)
|
||||
else:
|
||||
prefix, suffix, shape = (None, None, None)
|
||||
if pretrained_vectors is not None:
|
||||
glove = StaticVectors(
|
||||
vectors=pretrained_vectors.data,
|
||||
nO=width,
|
||||
column=cols.index(ID),
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
if subword_features:
|
||||
columns = 5
|
||||
embed = uniqued(
|
||||
(glove | norm | prefix | suffix | shape)
|
||||
>> Maxout(
|
||||
nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True,
|
||||
),
|
||||
column=cols.index(ORTH),
|
||||
)
|
||||
else:
|
||||
columns = 2
|
||||
embed = uniqued(
|
||||
(glove | norm)
|
||||
>> Maxout(
|
||||
nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True,
|
||||
),
|
||||
column=cols.index(ORTH),
|
||||
)
|
||||
elif subword_features:
|
||||
columns = 4
|
||||
embed = uniqued(
|
||||
concatenate(norm, prefix, suffix, shape)
|
||||
>> Maxout(
|
||||
nO=width, nI=width * columns, nP=3, dropout=0.0, normalize=True,
|
||||
),
|
||||
column=cols.index(ORTH),
|
||||
)
|
||||
elif char_embed:
|
||||
embed = _character_embed.CharacterEmbed(nM=nM, nC=nC) | FeatureExtractor(
|
||||
cols
|
||||
) >> with_array(norm)
|
||||
reduce_dimensions = Maxout(
|
||||
nO=width, nI=nM * nC + width, nP=3, dropout=0.0, normalize=True,
|
||||
)
|
||||
else:
|
||||
embed = norm
|
||||
|
||||
convolution = residual(
|
||||
expand_window(window_size=window_size)
|
||||
>> Maxout(
|
||||
nO=width,
|
||||
nI=width * ((window_size * 2) + 1),
|
||||
nP=maxout_pieces,
|
||||
dropout=0.0,
|
||||
normalize=True,
|
||||
)
|
||||
)
|
||||
if char_embed:
|
||||
tok2vec = embed >> with_array(
|
||||
reduce_dimensions >> convolution ** conv_depth, pad=conv_depth
|
||||
)
|
||||
else:
|
||||
tok2vec = FeatureExtractor(cols) >> with_array(
|
||||
embed >> convolution ** conv_depth, pad=conv_depth
|
||||
)
|
||||
|
||||
if bilstm_depth >= 1:
|
||||
tok2vec = tok2vec >> PyTorchLSTM(
|
||||
nO=width, nI=width, depth=bilstm_depth, bi=True
|
||||
)
|
||||
if tok2vec.has_dim("nO") is not False:
|
||||
tok2vec.set_dim("nO", width)
|
||||
tok2vec.set_ref("embed", embed)
|
||||
return tok2vec
|
||||
return with_padded(PyTorchLSTM(width, width, bi=True, depth=depth, dropout=dropout))
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
from libc.string cimport memset, memcpy
|
||||
from libc.stdlib cimport calloc, free, realloc
|
||||
from ..typedefs cimport weight_t, class_t, hash_t
|
||||
|
||||
from ._state cimport StateC
|
||||
from ..typedefs cimport weight_t, hash_t
|
||||
from ..pipeline._parser_internals._state cimport StateC
|
||||
|
||||
|
||||
cdef struct SizesC:
|
|
@ -1,29 +1,18 @@
|
|||
# cython: infer_types=True, cdivision=True, boundscheck=False
|
||||
cimport cython.parallel
|
||||
cimport numpy as np
|
||||
from libc.math cimport exp
|
||||
from libcpp.vector cimport vector
|
||||
from libc.string cimport memset, memcpy
|
||||
from libc.stdlib cimport calloc, free, realloc
|
||||
from cymem.cymem cimport Pool
|
||||
from thinc.extra.search cimport Beam
|
||||
from thinc.backends.linalg cimport Vec, VecVec
|
||||
cimport blis.cy
|
||||
|
||||
import numpy
|
||||
import numpy.random
|
||||
from thinc.api import Linear, Model, CupyOps, NumpyOps, use_ops, noop
|
||||
from thinc.api import Model, CupyOps, NumpyOps
|
||||
|
||||
from ..typedefs cimport weight_t, class_t, hash_t
|
||||
from ..tokens.doc cimport Doc
|
||||
from .stateclass cimport StateClass
|
||||
from .transition_system cimport Transition
|
||||
|
||||
from ..compat import copy_array
|
||||
from ..errors import Errors, TempErrors
|
||||
from ..util import link_vectors_to_models, create_default_optimizer
|
||||
from .. import util
|
||||
from . import nonproj
|
||||
from ..typedefs cimport weight_t, class_t, hash_t
|
||||
from ..pipeline._parser_internals.stateclass cimport StateClass
|
||||
|
||||
|
||||
cdef WeightsC get_c_weights(model) except *:
|
|
@ -1,27 +0,0 @@
|
|||
import numpy
|
||||
from thinc.api import Model, Unserializable
|
||||
|
||||
|
||||
def SpacyVectors(vectors) -> Model:
|
||||
attrs = {"vectors": Unserializable(vectors)}
|
||||
model = Model("spacy_vectors", forward, attrs=attrs)
|
||||
return model
|
||||
|
||||
|
||||
def forward(model, docs, is_train: bool):
|
||||
batch = []
|
||||
vectors = model.attrs["vectors"].obj
|
||||
for doc in docs:
|
||||
indices = numpy.zeros((len(doc),), dtype="i")
|
||||
for i, word in enumerate(doc):
|
||||
if word.orth in vectors.key2row:
|
||||
indices[i] = vectors.key2row[word.orth]
|
||||
else:
|
||||
indices[i] = 0
|
||||
batch_vectors = vectors.data[indices]
|
||||
batch.append(batch_vectors)
|
||||
|
||||
def backprop(dY):
|
||||
return None
|
||||
|
||||
return batch, backprop
|
100
spacy/ml/staticvectors.py
Normal file
100
spacy/ml/staticvectors.py
Normal file
|
@ -0,0 +1,100 @@
|
|||
from typing import List, Tuple, Callable, Optional, cast
|
||||
|
||||
from thinc.initializers import glorot_uniform_init
|
||||
from thinc.util import partial
|
||||
from thinc.types import Ragged, Floats2d, Floats1d
|
||||
from thinc.api import Model, Ops, registry
|
||||
|
||||
from ..tokens import Doc
|
||||
|
||||
|
||||
@registry.layers("spacy.StaticVectors.v1")
|
||||
def StaticVectors(
|
||||
nO: Optional[int] = None,
|
||||
nM: Optional[int] = None,
|
||||
*,
|
||||
dropout: Optional[float] = None,
|
||||
init_W: Callable = glorot_uniform_init,
|
||||
key_attr: str = "ORTH"
|
||||
) -> Model[List[Doc], Ragged]:
|
||||
"""Embed Doc objects with their vocab's vectors table, applying a learned
|
||||
linear projection to control the dimensionality. If a dropout rate is
|
||||
specified, the dropout is applied per dimension over the whole batch.
|
||||
"""
|
||||
return Model(
|
||||
"static_vectors",
|
||||
forward,
|
||||
init=partial(init, init_W),
|
||||
params={"W": None},
|
||||
attrs={"key_attr": key_attr, "dropout_rate": dropout},
|
||||
dims={"nO": nO, "nM": nM},
|
||||
)
|
||||
|
||||
|
||||
def forward(
|
||||
model: Model[List[Doc], Ragged], docs: List[Doc], is_train: bool
|
||||
) -> Tuple[Ragged, Callable]:
|
||||
if not len(docs):
|
||||
return _handle_empty(model.ops, model.get_dim("nO"))
|
||||
key_attr = model.attrs["key_attr"]
|
||||
W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
|
||||
V = cast(Floats2d, docs[0].vocab.vectors.data)
|
||||
mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate"))
|
||||
rows = model.ops.flatten(
|
||||
[doc.vocab.vectors.find(keys=doc.to_array(key_attr)) for doc in docs]
|
||||
)
|
||||
output = Ragged(
|
||||
model.ops.gemm(model.ops.as_contig(V[rows]), W, trans2=True),
|
||||
model.ops.asarray([len(doc) for doc in docs], dtype="i"),
|
||||
)
|
||||
if mask is not None:
|
||||
output.data *= mask
|
||||
|
||||
def backprop(d_output: Ragged) -> List[Doc]:
|
||||
if mask is not None:
|
||||
d_output.data *= mask
|
||||
model.inc_grad(
|
||||
"W",
|
||||
model.ops.gemm(d_output.data, model.ops.as_contig(V[rows]), trans1=True),
|
||||
)
|
||||
return []
|
||||
|
||||
return output, backprop
|
||||
|
||||
|
||||
def init(
|
||||
init_W: Callable,
|
||||
model: Model[List[Doc], Ragged],
|
||||
X: Optional[List[Doc]] = None,
|
||||
Y: Optional[Ragged] = None,
|
||||
) -> Model[List[Doc], Ragged]:
|
||||
nM = model.get_dim("nM") if model.has_dim("nM") else None
|
||||
nO = model.get_dim("nO") if model.has_dim("nO") else None
|
||||
if X is not None and len(X):
|
||||
nM = X[0].vocab.vectors.data.shape[1]
|
||||
if Y is not None:
|
||||
nO = Y.data.shape[1]
|
||||
|
||||
if nM is None:
|
||||
raise ValueError(
|
||||
"Cannot initialize StaticVectors layer: nM dimension unset. "
|
||||
"This dimension refers to the width of the vectors table."
|
||||
)
|
||||
if nO is None:
|
||||
raise ValueError(
|
||||
"Cannot initialize StaticVectors layer: nO dimension unset. "
|
||||
"This dimension refers to the output width, after the linear "
|
||||
"projection has been applied."
|
||||
)
|
||||
model.set_dim("nM", nM)
|
||||
model.set_dim("nO", nO)
|
||||
model.set_param("W", init_W(model.ops, (nO, nM)))
|
||||
return model
|
||||
|
||||
|
||||
def _handle_empty(ops: Ops, nO: int):
|
||||
return Ragged(ops.alloc2f(0, nO), ops.alloc1i(0)), lambda d_ragged: []
|
||||
|
||||
|
||||
def _get_drop_mask(ops: Ops, nO: int, rate: Optional[float]) -> Optional[Floats1d]:
|
||||
return ops.get_dropout_mask((nO,), rate) if rate is not None else None
|
|
@ -1,5 +1,5 @@
|
|||
from thinc.api import Model, noop, use_ops, Linear
|
||||
from ..syntax._parser_model import ParserStepModel
|
||||
from .parser_model import ParserStepModel
|
||||
|
||||
|
||||
def TransitionModel(tok2vec, lower, upper, dropout=0.2, unseen_classes=set()):
|
||||
|
|
|
@ -1,15 +1,14 @@
|
|||
from libc.string cimport memcpy, memset, memmove
|
||||
from libc.stdlib cimport malloc, calloc, free
|
||||
from libc.string cimport memcpy, memset
|
||||
from libc.stdlib cimport calloc, free
|
||||
from libc.stdint cimport uint32_t, uint64_t
|
||||
from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno
|
||||
from murmurhash.mrmr cimport hash64
|
||||
|
||||
from ..vocab cimport EMPTY_LEXEME
|
||||
from ..structs cimport TokenC, SpanC
|
||||
from ..lexeme cimport Lexeme
|
||||
from ..symbols cimport punct
|
||||
from ..attrs cimport IS_SPACE
|
||||
from ..typedefs cimport attr_t
|
||||
from ...vocab cimport EMPTY_LEXEME
|
||||
from ...structs cimport TokenC, SpanC
|
||||
from ...lexeme cimport Lexeme
|
||||
from ...attrs cimport IS_SPACE
|
||||
from ...typedefs cimport attr_t
|
||||
|
||||
|
||||
cdef inline bint is_space_token(const TokenC* token) nogil:
|
|
@ -1,8 +1,6 @@
|
|||
from cymem.cymem cimport Pool
|
||||
|
||||
from .stateclass cimport StateClass
|
||||
from ..typedefs cimport weight_t, attr_t
|
||||
from .transition_system cimport TransitionSystem, Transition
|
||||
from ...typedefs cimport weight_t, attr_t
|
||||
from .transition_system cimport Transition, TransitionSystem
|
||||
|
||||
|
||||
cdef class ArcEager(TransitionSystem):
|
|
@ -1,24 +1,17 @@
|
|||
# cython: profile=True, cdivision=True, infer_types=True
|
||||
from cpython.ref cimport Py_INCREF
|
||||
from cymem.cymem cimport Pool, Address
|
||||
from libc.stdint cimport int32_t
|
||||
|
||||
from collections import defaultdict, Counter
|
||||
import json
|
||||
|
||||
from ..typedefs cimport hash_t, attr_t
|
||||
from ..strings cimport hash_string
|
||||
from ..structs cimport TokenC
|
||||
from ..tokens.doc cimport Doc, set_children_from_heads
|
||||
from ...typedefs cimport hash_t, attr_t
|
||||
from ...strings cimport hash_string
|
||||
from ...structs cimport TokenC
|
||||
from ...tokens.doc cimport Doc, set_children_from_heads
|
||||
from ...gold.example cimport Example
|
||||
from ...errors import Errors
|
||||
from .stateclass cimport StateClass
|
||||
from ._state cimport StateC
|
||||
from .transition_system cimport move_cost_func_t, label_cost_func_t
|
||||
from ..gold.example cimport Example
|
||||
|
||||
from ..errors import Errors
|
||||
from .nonproj import is_nonproj_tree
|
||||
from . import nonproj
|
||||
|
||||
|
||||
# Calculate cost as gold/not gold. We don't use scalar value anyway.
|
||||
cdef int BINARY_COSTS = 1
|
|
@ -1,6 +1,4 @@
|
|||
from .transition_system cimport TransitionSystem
|
||||
from .transition_system cimport Transition
|
||||
from ..typedefs cimport attr_t
|
||||
|
||||
|
||||
cdef class BiluoPushDown(TransitionSystem):
|
|
@ -2,17 +2,14 @@ from collections import Counter
|
|||
from libc.stdint cimport int32_t
|
||||
from cymem.cymem cimport Pool
|
||||
|
||||
from ..typedefs cimport weight_t
|
||||
from ...typedefs cimport weight_t, attr_t
|
||||
from ...lexeme cimport Lexeme
|
||||
from ...attrs cimport IS_SPACE
|
||||
from ...gold.example cimport Example
|
||||
from ...errors import Errors
|
||||
from .stateclass cimport StateClass
|
||||
from ._state cimport StateC
|
||||
from .transition_system cimport Transition
|
||||
from .transition_system cimport do_func_t
|
||||
from ..lexeme cimport Lexeme
|
||||
from ..attrs cimport IS_SPACE
|
||||
from ..gold.iob_utils import biluo_tags_from_offsets
|
||||
from ..gold.example cimport Example
|
||||
|
||||
from ..errors import Errors
|
||||
from .transition_system cimport Transition, do_func_t
|
||||
|
||||
|
||||
cdef enum:
|
|
@ -5,9 +5,9 @@ scheme.
|
|||
"""
|
||||
from copy import copy
|
||||
|
||||
from ..tokens.doc cimport Doc, set_children_from_heads
|
||||
from ...tokens.doc cimport Doc, set_children_from_heads
|
||||
|
||||
from ..errors import Errors
|
||||
from ...errors import Errors
|
||||
|
||||
|
||||
DELIMITER = '||'
|
|
@ -1,12 +1,8 @@
|
|||
from libc.string cimport memcpy, memset
|
||||
|
||||
from cymem.cymem cimport Pool
|
||||
cimport cython
|
||||
|
||||
from ..structs cimport TokenC, SpanC
|
||||
from ..typedefs cimport attr_t
|
||||
from ...structs cimport TokenC, SpanC
|
||||
from ...typedefs cimport attr_t
|
||||
|
||||
from ..vocab cimport EMPTY_LEXEME
|
||||
from ._state cimport StateC
|
||||
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
# cython: infer_types=True
|
||||
import numpy
|
||||
|
||||
from ..tokens.doc cimport Doc
|
||||
from ...tokens.doc cimport Doc
|
||||
|
||||
|
||||
cdef class StateClass:
|
|
@ -1,11 +1,11 @@
|
|||
from cymem.cymem cimport Pool
|
||||
|
||||
from ..typedefs cimport attr_t, weight_t
|
||||
from ..structs cimport TokenC
|
||||
from ..strings cimport StringStore
|
||||
from ...typedefs cimport attr_t, weight_t
|
||||
from ...structs cimport TokenC
|
||||
from ...strings cimport StringStore
|
||||
from ...gold.example cimport Example
|
||||
from .stateclass cimport StateClass
|
||||
from ._state cimport StateC
|
||||
from ..gold.example cimport Example
|
||||
|
||||
|
||||
cdef struct Transition:
|
|
@ -1,19 +1,17 @@
|
|||
# cython: infer_types=True
|
||||
from __future__ import print_function
|
||||
from cpython.ref cimport Py_INCREF
|
||||
from cymem.cymem cimport Pool
|
||||
|
||||
from collections import Counter
|
||||
import srsly
|
||||
|
||||
from ..typedefs cimport weight_t
|
||||
from ..tokens.doc cimport Doc
|
||||
from ..structs cimport TokenC
|
||||
from ...typedefs cimport weight_t, attr_t
|
||||
from ...tokens.doc cimport Doc
|
||||
from ...structs cimport TokenC
|
||||
from .stateclass cimport StateClass
|
||||
from ..typedefs cimport attr_t
|
||||
|
||||
from ..errors import Errors
|
||||
from .. import util
|
||||
from ...errors import Errors
|
||||
from ... import util
|
||||
|
||||
|
||||
cdef weight_t MIN_SCORE = -90000
|
|
@ -1,13 +1,13 @@
|
|||
# cython: infer_types=True, profile=True, binding=True
|
||||
from typing import Optional, Iterable
|
||||
from thinc.api import CosineDistance, to_categorical, get_array_module, Model, Config
|
||||
from thinc.api import Model, Config
|
||||
|
||||
from ..syntax.nn_parser cimport Parser
|
||||
from ..syntax.arc_eager cimport ArcEager
|
||||
from .transition_parser cimport Parser
|
||||
from ._parser_internals.arc_eager cimport ArcEager
|
||||
|
||||
from .functions import merge_subtokens
|
||||
from ..language import Language
|
||||
from ..syntax import nonproj
|
||||
from ._parser_internals import nonproj
|
||||
from ..scorer import Scorer
|
||||
|
||||
|
||||
|
|
|
@ -222,9 +222,9 @@ class EntityLinker(Pipe):
|
|||
set_dropout_rate(self.model, drop)
|
||||
if not sentence_docs:
|
||||
warnings.warn(Warnings.W093.format(name="Entity Linker"))
|
||||
return 0.0
|
||||
return losses
|
||||
sentence_encodings, bp_context = self.model.begin_update(sentence_docs)
|
||||
loss, d_scores = self.get_similarity_loss(
|
||||
loss, d_scores = self.get_loss(
|
||||
sentence_encodings=sentence_encodings, examples=examples
|
||||
)
|
||||
bp_context(d_scores)
|
||||
|
@ -235,7 +235,7 @@ class EntityLinker(Pipe):
|
|||
self.set_annotations(docs, predictions)
|
||||
return losses
|
||||
|
||||
def get_similarity_loss(self, examples: Iterable[Example], sentence_encodings):
|
||||
def get_loss(self, examples: Iterable[Example], sentence_encodings):
|
||||
entity_encodings = []
|
||||
for eg in examples:
|
||||
kb_ids = eg.get_aligned("ENT_KB_ID", as_string=True)
|
||||
|
@ -247,7 +247,7 @@ class EntityLinker(Pipe):
|
|||
entity_encodings = self.model.ops.asarray(entity_encodings, dtype="float32")
|
||||
if sentence_encodings.shape != entity_encodings.shape:
|
||||
err = Errors.E147.format(
|
||||
method="get_similarity_loss", msg="gold entities do not match up"
|
||||
method="get_loss", msg="gold entities do not match up"
|
||||
)
|
||||
raise RuntimeError(err)
|
||||
gradients = self.distance.get_grad(sentence_encodings, entity_encodings)
|
||||
|
@ -337,13 +337,13 @@ class EntityLinker(Pipe):
|
|||
final_kb_ids.append(candidates[0].entity_)
|
||||
else:
|
||||
random.shuffle(candidates)
|
||||
# this will set all prior probabilities to 0 if they should be excluded from the model
|
||||
# set all prior probabilities to 0 if incl_prior=False
|
||||
prior_probs = xp.asarray(
|
||||
[c.prior_prob for c in candidates]
|
||||
)
|
||||
if not self.cfg.get("incl_prior"):
|
||||
prior_probs = xp.asarray(
|
||||
[0.0 for c in candidates]
|
||||
[0.0 for _ in candidates]
|
||||
)
|
||||
scores = prior_probs
|
||||
# add in similarity from the context
|
||||
|
@ -387,7 +387,7 @@ class EntityLinker(Pipe):
|
|||
docs (Iterable[Doc]): The documents to modify.
|
||||
kb_ids (List[str]): The IDs to set, produced by EntityLinker.predict.
|
||||
|
||||
DOCS: https://spacy.io/api/entitylinker#predict
|
||||
DOCS: https://spacy.io/api/entitylinker#set_annotations
|
||||
"""
|
||||
count_ents = len([ent for doc in docs for ent in doc.ents])
|
||||
if count_ents != len(kb_ids):
|
||||
|
@ -400,7 +400,9 @@ class EntityLinker(Pipe):
|
|||
for token in ent:
|
||||
token.ent_kb_id_ = kb_id
|
||||
|
||||
def to_disk(self, path: Union[str, Path], exclude: Iterable[str] = tuple()) -> None:
|
||||
def to_disk(
|
||||
self, path: Union[str, Path], *, exclude: Iterable[str] = tuple()
|
||||
) -> None:
|
||||
"""Serialize the pipe to disk.
|
||||
|
||||
path (str / Path): Path to a directory.
|
||||
|
@ -417,7 +419,7 @@ class EntityLinker(Pipe):
|
|||
util.to_disk(path, serialize, exclude)
|
||||
|
||||
def from_disk(
|
||||
self, path: Union[str, Path], exclude: Iterable[str] = tuple()
|
||||
self, path: Union[str, Path], *, exclude: Iterable[str] = tuple()
|
||||
) -> "EntityLinker":
|
||||
"""Load the pipe from disk. Modifies the object in place and returns it.
|
||||
|
||||
|
|
|
@ -86,7 +86,6 @@ class EntityRuler:
|
|||
overwrite_ents (bool): If existing entities are present, e.g. entities
|
||||
added by the model, overwrite them by matches if necessary.
|
||||
ent_id_sep (str): Separator used internally for entity IDs.
|
||||
RETURNS (EntityRuler): The newly constructed object.
|
||||
|
||||
DOCS: https://spacy.io/api/entityruler#init
|
||||
"""
|
||||
|
@ -316,7 +315,7 @@ class EntityRuler:
|
|||
return Scorer.score_spans(examples, "ents", **kwargs)
|
||||
|
||||
def from_bytes(
|
||||
self, patterns_bytes: bytes, exclude: Iterable[str] = tuple()
|
||||
self, patterns_bytes: bytes, *, exclude: Iterable[str] = tuple()
|
||||
) -> "EntityRuler":
|
||||
"""Load the entity ruler from a bytestring.
|
||||
|
||||
|
@ -340,7 +339,7 @@ class EntityRuler:
|
|||
self.add_patterns(cfg)
|
||||
return self
|
||||
|
||||
def to_bytes(self, exclude: Iterable[str] = tuple()) -> bytes:
|
||||
def to_bytes(self, *, exclude: Iterable[str] = tuple()) -> bytes:
|
||||
"""Serialize the entity ruler patterns to a bytestring.
|
||||
|
||||
RETURNS (bytes): The serialized patterns.
|
||||
|
@ -356,7 +355,7 @@ class EntityRuler:
|
|||
return srsly.msgpack_dumps(serial)
|
||||
|
||||
def from_disk(
|
||||
self, path: Union[str, Path], exclude: Iterable[str] = tuple()
|
||||
self, path: Union[str, Path], *, exclude: Iterable[str] = tuple()
|
||||
) -> "EntityRuler":
|
||||
"""Load the entity ruler from a file. Expects a file containing
|
||||
newline-delimited JSON (JSONL) with one entry per line.
|
||||
|
@ -392,7 +391,9 @@ class EntityRuler:
|
|||
from_disk(path, deserializers_patterns, {})
|
||||
return self
|
||||
|
||||
def to_disk(self, path: Union[str, Path], exclude: Iterable[str] = tuple()) -> None:
|
||||
def to_disk(
|
||||
self, path: Union[str, Path], *, exclude: Iterable[str] = tuple()
|
||||
) -> None:
|
||||
"""Save the entity ruler patterns to a directory. The patterns will be
|
||||
saved as newline-delimited JSON (JSONL).
|
||||
|
||||
|
|
|
@ -58,7 +58,7 @@ def merge_subtokens(doc: Doc, label: str = "subtok") -> Doc:
|
|||
"""
|
||||
# TODO: make stateful component with "label" config
|
||||
merger = Matcher(doc.vocab)
|
||||
merger.add("SUBTOK", None, [{"DEP": label, "op": "+"}])
|
||||
merger.add("SUBTOK", [[{"DEP": label, "op": "+"}]])
|
||||
matches = merger(doc)
|
||||
spans = filter_spans([doc[start : end + 1] for _, start, end in matches])
|
||||
with doc.retokenize() as retokenizer:
|
||||
|
|
|
@ -22,17 +22,23 @@ default_model_config = """
|
|||
@architectures = "spacy.Tagger.v1"
|
||||
|
||||
[model.tok2vec]
|
||||
@architectures = "spacy.HashCharEmbedCNN.v1"
|
||||
pretrained_vectors = null
|
||||
@architectures = "spacy.Tok2Vec.v1"
|
||||
|
||||
[model.tok2vec.embed]
|
||||
@architectures = "spacy.CharacterEmbed.v1"
|
||||
width = 128
|
||||
depth = 4
|
||||
embed_size = 7000
|
||||
window_size = 1
|
||||
maxout_pieces = 3
|
||||
rows = 7000
|
||||
nM = 64
|
||||
nC = 8
|
||||
dropout = null
|
||||
|
||||
[model.tok2vec.encode]
|
||||
@architectures = "spacy.MaxoutWindowEncoder.v1"
|
||||
width = 128
|
||||
depth = 4
|
||||
window_size = 1
|
||||
maxout_pieces = 3
|
||||
"""
|
||||
|
||||
DEFAULT_MORPH_MODEL = Config().from_str(default_model_config)["model"]
|
||||
|
||||
|
||||
|
@ -149,7 +155,6 @@ class Morphologizer(Tagger):
|
|||
self.cfg["labels_pos"][norm_label] = POS_IDS[pos]
|
||||
self.set_output(len(self.labels))
|
||||
self.model.initialize()
|
||||
util.link_vectors_to_models(self.vocab)
|
||||
if sgd is None:
|
||||
sgd = self.create_optimizer()
|
||||
return sgd
|
||||
|
@ -160,7 +165,7 @@ class Morphologizer(Tagger):
|
|||
docs (Iterable[Doc]): The documents to modify.
|
||||
batch_tag_ids: The IDs to set, produced by Morphologizer.predict.
|
||||
|
||||
DOCS: https://spacy.io/api/morphologizer#predict
|
||||
DOCS: https://spacy.io/api/morphologizer#set_annotations
|
||||
"""
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
|
@ -230,7 +235,7 @@ class Morphologizer(Tagger):
|
|||
"morph", **kwargs))
|
||||
return results
|
||||
|
||||
def to_bytes(self, exclude=tuple()):
|
||||
def to_bytes(self, *, exclude=tuple()):
|
||||
"""Serialize the pipe to a bytestring.
|
||||
|
||||
exclude (Iterable[str]): String names of serialization fields to exclude.
|
||||
|
@ -244,7 +249,7 @@ class Morphologizer(Tagger):
|
|||
serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
|
||||
return util.to_bytes(serialize, exclude)
|
||||
|
||||
def from_bytes(self, bytes_data, exclude=tuple()):
|
||||
def from_bytes(self, bytes_data, *, exclude=tuple()):
|
||||
"""Load the pipe from a bytestring.
|
||||
|
||||
bytes_data (bytes): The serialized pipe.
|
||||
|
@ -267,7 +272,7 @@ class Morphologizer(Tagger):
|
|||
util.from_bytes(bytes_data, deserialize, exclude)
|
||||
return self
|
||||
|
||||
def to_disk(self, path, exclude=tuple()):
|
||||
def to_disk(self, path, *, exclude=tuple()):
|
||||
"""Serialize the pipe to disk.
|
||||
|
||||
path (str / Path): Path to a directory.
|
||||
|
@ -282,7 +287,7 @@ class Morphologizer(Tagger):
|
|||
}
|
||||
util.to_disk(path, serialize, exclude)
|
||||
|
||||
def from_disk(self, path, exclude=tuple()):
|
||||
def from_disk(self, path, *, exclude=tuple()):
|
||||
"""Load the pipe from disk. Modifies the object in place and returns it.
|
||||
|
||||
path (str / Path): Path to a directory.
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# cython: infer_types=True, profile=True, binding=True
|
||||
from typing import Optional
|
||||
import numpy
|
||||
from thinc.api import CosineDistance, to_categorical, to_categorical, Model, Config
|
||||
from thinc.api import CosineDistance, to_categorical, Model, Config
|
||||
from thinc.api import set_dropout_rate
|
||||
|
||||
from ..tokens.doc cimport Doc
|
||||
|
@ -9,9 +9,8 @@ from ..tokens.doc cimport Doc
|
|||
from .pipe import Pipe
|
||||
from .tagger import Tagger
|
||||
from ..language import Language
|
||||
from ..syntax import nonproj
|
||||
from ._parser_internals import nonproj
|
||||
from ..attrs import POS, ID
|
||||
from ..util import link_vectors_to_models
|
||||
from ..errors import Errors
|
||||
|
||||
|
||||
|
@ -91,7 +90,6 @@ class MultitaskObjective(Tagger):
|
|||
if label is not None and label not in self.labels:
|
||||
self.labels[label] = len(self.labels)
|
||||
self.model.initialize()
|
||||
link_vectors_to_models(self.vocab)
|
||||
if sgd is None:
|
||||
sgd = self.create_optimizer()
|
||||
return sgd
|
||||
|
@ -179,7 +177,6 @@ class ClozeMultitask(Pipe):
|
|||
pass
|
||||
|
||||
def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None):
|
||||
link_vectors_to_models(self.vocab)
|
||||
self.model.initialize()
|
||||
X = self.model.ops.alloc((5, self.model.get_ref("tok2vec").get_dim("nO")))
|
||||
self.model.output_layer.begin_training(X)
|
||||
|
@ -222,3 +219,6 @@ class ClozeMultitask(Pipe):
|
|||
|
||||
if losses is not None:
|
||||
losses[self.name] += loss
|
||||
|
||||
def add_label(self, label):
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -1,9 +1,9 @@
|
|||
# cython: infer_types=True, profile=True, binding=True
|
||||
from typing import Optional, Iterable
|
||||
from thinc.api import CosineDistance, to_categorical, get_array_module, Model, Config
|
||||
from thinc.api import Model, Config
|
||||
|
||||
from ..syntax.nn_parser cimport Parser
|
||||
from ..syntax.ner cimport BiluoPushDown
|
||||
from .transition_parser cimport Parser
|
||||
from ._parser_internals.ner cimport BiluoPushDown
|
||||
|
||||
from ..language import Language
|
||||
from ..scorer import Scorer
|
||||
|
|
2
spacy/pipeline/pipe.pxd
Normal file
2
spacy/pipeline/pipe.pxd
Normal file
|
@ -0,0 +1,2 @@
|
|||
cdef class Pipe:
|
||||
cdef public str name
|
|
@ -3,12 +3,12 @@ import srsly
|
|||
|
||||
from ..tokens.doc cimport Doc
|
||||
|
||||
from ..util import link_vectors_to_models, create_default_optimizer
|
||||
from ..util import create_default_optimizer
|
||||
from ..errors import Errors
|
||||
from .. import util
|
||||
|
||||
|
||||
class Pipe:
|
||||
cdef class Pipe:
|
||||
"""This class is a base class and not instantiated directly. Trainable
|
||||
pipeline components like the EntityRecognizer or TextCategorizer inherit
|
||||
from it and it defines the interface that components should follow to
|
||||
|
@ -17,8 +17,6 @@ class Pipe:
|
|||
DOCS: https://spacy.io/api/pipe
|
||||
"""
|
||||
|
||||
name = None
|
||||
|
||||
def __init__(self, vocab, model, name, **cfg):
|
||||
"""Initialize a pipeline component.
|
||||
|
||||
|
@ -32,7 +30,9 @@ class Pipe:
|
|||
raise NotImplementedError
|
||||
|
||||
def __call__(self, Doc doc):
|
||||
"""Add context-sensitive embeddings to the Doc.tensor attribute.
|
||||
"""Apply the pipe to one document. The document is modified in place,
|
||||
and returned. This usually happens under the hood when the nlp object
|
||||
is called on a text and all components are applied to the Doc.
|
||||
|
||||
docs (Doc): The Doc to preocess.
|
||||
RETURNS (Doc): The processed Doc.
|
||||
|
@ -74,9 +74,9 @@ class Pipe:
|
|||
"""Modify a batch of documents, using pre-computed scores.
|
||||
|
||||
docs (Iterable[Doc]): The documents to modify.
|
||||
tokvecses: The tensors to set, produced by Pipe.predict.
|
||||
scores: The scores to assign.
|
||||
|
||||
DOCS: https://spacy.io/api/pipe#predict
|
||||
DOCS: https://spacy.io/api/pipe#set_annotations
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
@ -145,8 +145,6 @@ class Pipe:
|
|||
DOCS: https://spacy.io/api/pipe#begin_training
|
||||
"""
|
||||
self.model.initialize()
|
||||
if hasattr(self, "vocab"):
|
||||
link_vectors_to_models(self.vocab)
|
||||
if sgd is None:
|
||||
sgd = self.create_optimizer()
|
||||
return sgd
|
||||
|
@ -178,7 +176,7 @@ class Pipe:
|
|||
"""
|
||||
return {}
|
||||
|
||||
def to_bytes(self, exclude=tuple()):
|
||||
def to_bytes(self, *, exclude=tuple()):
|
||||
"""Serialize the pipe to a bytestring.
|
||||
|
||||
exclude (Iterable[str]): String names of serialization fields to exclude.
|
||||
|
@ -193,7 +191,7 @@ class Pipe:
|
|||
serialize["vocab"] = self.vocab.to_bytes
|
||||
return util.to_bytes(serialize, exclude)
|
||||
|
||||
def from_bytes(self, bytes_data, exclude=tuple()):
|
||||
def from_bytes(self, bytes_data, *, exclude=tuple()):
|
||||
"""Load the pipe from a bytestring.
|
||||
|
||||
exclude (Iterable[str]): String names of serialization fields to exclude.
|
||||
|
@ -216,7 +214,7 @@ class Pipe:
|
|||
util.from_bytes(bytes_data, deserialize, exclude)
|
||||
return self
|
||||
|
||||
def to_disk(self, path, exclude=tuple()):
|
||||
def to_disk(self, path, *, exclude=tuple()):
|
||||
"""Serialize the pipe to disk.
|
||||
|
||||
path (str / Path): Path to a directory.
|
||||
|
@ -230,7 +228,7 @@ class Pipe:
|
|||
serialize["model"] = lambda p: self.model.to_disk(p)
|
||||
util.to_disk(path, serialize, exclude)
|
||||
|
||||
def from_disk(self, path, exclude=tuple()):
|
||||
def from_disk(self, path, *, exclude=tuple()):
|
||||
"""Load the pipe from disk.
|
||||
|
||||
path (str / Path): Path to a directory.
|
||||
|
|
|
@ -162,7 +162,7 @@ class Sentencizer(Pipe):
|
|||
del results["sents_per_type"]
|
||||
return results
|
||||
|
||||
def to_bytes(self, exclude=tuple()):
|
||||
def to_bytes(self, *, exclude=tuple()):
|
||||
"""Serialize the sentencizer to a bytestring.
|
||||
|
||||
RETURNS (bytes): The serialized object.
|
||||
|
@ -171,7 +171,7 @@ class Sentencizer(Pipe):
|
|||
"""
|
||||
return srsly.msgpack_dumps({"punct_chars": list(self.punct_chars)})
|
||||
|
||||
def from_bytes(self, bytes_data, exclude=tuple()):
|
||||
def from_bytes(self, bytes_data, *, exclude=tuple()):
|
||||
"""Load the sentencizer from a bytestring.
|
||||
|
||||
bytes_data (bytes): The data to load.
|
||||
|
@ -183,7 +183,7 @@ class Sentencizer(Pipe):
|
|||
self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars))
|
||||
return self
|
||||
|
||||
def to_disk(self, path, exclude=tuple()):
|
||||
def to_disk(self, path, *, exclude=tuple()):
|
||||
"""Serialize the sentencizer to disk.
|
||||
|
||||
DOCS: https://spacy.io/api/sentencizer#to_disk
|
||||
|
@ -193,7 +193,7 @@ class Sentencizer(Pipe):
|
|||
srsly.write_json(path, {"punct_chars": list(self.punct_chars)})
|
||||
|
||||
|
||||
def from_disk(self, path, exclude=tuple()):
|
||||
def from_disk(self, path, *, exclude=tuple()):
|
||||
"""Load the sentencizer from disk.
|
||||
|
||||
DOCS: https://spacy.io/api/sentencizer#from_disk
|
||||
|
@ -203,3 +203,9 @@ class Sentencizer(Pipe):
|
|||
cfg = srsly.read_json(path)
|
||||
self.punct_chars = set(cfg.get("punct_chars", self.default_punct_chars))
|
||||
return self
|
||||
|
||||
def get_loss(self, examples, scores):
|
||||
raise NotImplementedError
|
||||
|
||||
def add_label(self, label):
|
||||
raise NotImplementedError
|
||||
|
|
|
@ -76,7 +76,7 @@ class SentenceRecognizer(Tagger):
|
|||
docs (Iterable[Doc]): The documents to modify.
|
||||
batch_tag_ids: The IDs to set, produced by SentenceRecognizer.predict.
|
||||
|
||||
DOCS: https://spacy.io/api/sentencerecognizer#predict
|
||||
DOCS: https://spacy.io/api/sentencerecognizer#set_annotations
|
||||
"""
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
|
@ -109,7 +109,7 @@ class SentenceRecognizer(Tagger):
|
|||
for eg in examples:
|
||||
eg_truth = []
|
||||
for x in eg.get_aligned("sent_start"):
|
||||
if x == None:
|
||||
if x is None:
|
||||
eg_truth.append(None)
|
||||
elif x == 1:
|
||||
eg_truth.append(labels[1])
|
||||
|
@ -138,7 +138,6 @@ class SentenceRecognizer(Tagger):
|
|||
"""
|
||||
self.set_output(len(self.labels))
|
||||
self.model.initialize()
|
||||
util.link_vectors_to_models(self.vocab)
|
||||
if sgd is None:
|
||||
sgd = self.create_optimizer()
|
||||
return sgd
|
||||
|
@ -157,7 +156,7 @@ class SentenceRecognizer(Tagger):
|
|||
del results["sents_per_type"]
|
||||
return results
|
||||
|
||||
def to_bytes(self, exclude=tuple()):
|
||||
def to_bytes(self, *, exclude=tuple()):
|
||||
"""Serialize the pipe to a bytestring.
|
||||
|
||||
exclude (Iterable[str]): String names of serialization fields to exclude.
|
||||
|
@ -171,7 +170,7 @@ class SentenceRecognizer(Tagger):
|
|||
serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
|
||||
return util.to_bytes(serialize, exclude)
|
||||
|
||||
def from_bytes(self, bytes_data, exclude=tuple()):
|
||||
def from_bytes(self, bytes_data, *, exclude=tuple()):
|
||||
"""Load the pipe from a bytestring.
|
||||
|
||||
bytes_data (bytes): The serialized pipe.
|
||||
|
@ -194,7 +193,7 @@ class SentenceRecognizer(Tagger):
|
|||
util.from_bytes(bytes_data, deserialize, exclude)
|
||||
return self
|
||||
|
||||
def to_disk(self, path, exclude=tuple()):
|
||||
def to_disk(self, path, *, exclude=tuple()):
|
||||
"""Serialize the pipe to disk.
|
||||
|
||||
path (str / Path): Path to a directory.
|
||||
|
@ -209,7 +208,7 @@ class SentenceRecognizer(Tagger):
|
|||
}
|
||||
util.to_disk(path, serialize, exclude)
|
||||
|
||||
def from_disk(self, path, exclude=tuple()):
|
||||
def from_disk(self, path, *, exclude=tuple()):
|
||||
"""Load the pipe from disk. Modifies the object in place and returns it.
|
||||
|
||||
path (str / Path): Path to a directory.
|
||||
|
|
|
@ -131,8 +131,6 @@ class SimpleNER(Pipe):
|
|||
return losses
|
||||
|
||||
def get_loss(self, examples: List[Example], scores) -> Tuple[List[Floats2d], float]:
|
||||
loss = 0
|
||||
d_scores = []
|
||||
truths = []
|
||||
for eg in examples:
|
||||
tags = eg.get_aligned("TAG", as_string=True)
|
||||
|
@ -159,7 +157,6 @@ class SimpleNER(Pipe):
|
|||
if not hasattr(get_examples, "__call__"):
|
||||
gold_tuples = get_examples
|
||||
get_examples = lambda: gold_tuples
|
||||
labels = _get_labels(get_examples())
|
||||
for label in _get_labels(get_examples()):
|
||||
self.add_label(label)
|
||||
labels = self.labels
|
||||
|
@ -168,7 +165,6 @@ class SimpleNER(Pipe):
|
|||
self.model.initialize()
|
||||
if pipeline is not None:
|
||||
self.init_multitask_objectives(get_examples, pipeline, sgd=sgd, **self.cfg)
|
||||
util.link_vectors_to_models(self.vocab)
|
||||
self.loss_func = SequenceCategoricalCrossentropy(
|
||||
names=self.get_tag_names(), normalize=True, missing_value=None
|
||||
)
|
||||
|
|
|
@ -145,7 +145,7 @@ class Tagger(Pipe):
|
|||
docs (Iterable[Doc]): The documents to modify.
|
||||
batch_tag_ids: The IDs to set, produced by Tagger.predict.
|
||||
|
||||
DOCS: https://spacy.io/api/tagger#predict
|
||||
DOCS: https://spacy.io/api/tagger#set_annotations
|
||||
"""
|
||||
if isinstance(docs, Doc):
|
||||
docs = [docs]
|
||||
|
@ -318,7 +318,6 @@ class Tagger(Pipe):
|
|||
self.model.initialize(X=doc_sample)
|
||||
# Get batch of example docs, example outputs to call begin_training().
|
||||
# This lets the model infer shapes.
|
||||
util.link_vectors_to_models(self.vocab)
|
||||
if sgd is None:
|
||||
sgd = self.create_optimizer()
|
||||
return sgd
|
||||
|
@ -370,7 +369,7 @@ class Tagger(Pipe):
|
|||
scores.update(Scorer.score_token_attr(examples, "lemma", **kwargs))
|
||||
return scores
|
||||
|
||||
def to_bytes(self, exclude=tuple()):
|
||||
def to_bytes(self, *, exclude=tuple()):
|
||||
"""Serialize the pipe to a bytestring.
|
||||
|
||||
exclude (Iterable[str]): String names of serialization fields to exclude.
|
||||
|
@ -388,7 +387,7 @@ class Tagger(Pipe):
|
|||
serialize["morph_rules"] = lambda: srsly.msgpack_dumps(morph_rules)
|
||||
return util.to_bytes(serialize, exclude)
|
||||
|
||||
def from_bytes(self, bytes_data, exclude=tuple()):
|
||||
def from_bytes(self, bytes_data, *, exclude=tuple()):
|
||||
"""Load the pipe from a bytestring.
|
||||
|
||||
bytes_data (bytes): The serialized pipe.
|
||||
|
@ -424,7 +423,7 @@ class Tagger(Pipe):
|
|||
util.from_bytes(bytes_data, deserialize, exclude)
|
||||
return self
|
||||
|
||||
def to_disk(self, path, exclude=tuple()):
|
||||
def to_disk(self, path, *, exclude=tuple()):
|
||||
"""Serialize the pipe to disk.
|
||||
|
||||
path (str / Path): Path to a directory.
|
||||
|
@ -443,7 +442,7 @@ class Tagger(Pipe):
|
|||
}
|
||||
util.to_disk(path, serialize, exclude)
|
||||
|
||||
def from_disk(self, path, exclude=tuple()):
|
||||
def from_disk(self, path, *, exclude=tuple()):
|
||||
"""Load the pipe from disk. Modifies the object in place and returns it.
|
||||
|
||||
path (str / Path): Path to a directory.
|
||||
|
|
|
@ -163,7 +163,7 @@ class TextCategorizer(Pipe):
|
|||
docs (Iterable[Doc]): The documents to modify.
|
||||
scores: The scores to set, produced by TextCategorizer.predict.
|
||||
|
||||
DOCS: https://spacy.io/api/textcategorizer#predict
|
||||
DOCS: https://spacy.io/api/textcategorizer#set_annotations
|
||||
"""
|
||||
for i, doc in enumerate(docs):
|
||||
for j, label in enumerate(self.labels):
|
||||
|
@ -238,8 +238,11 @@ class TextCategorizer(Pipe):
|
|||
|
||||
DOCS: https://spacy.io/api/textcategorizer#rehearse
|
||||
"""
|
||||
|
||||
if losses is not None:
|
||||
losses.setdefault(self.name, 0.0)
|
||||
if self._rehearsal_model is None:
|
||||
return
|
||||
return losses
|
||||
try:
|
||||
docs = [eg.predicted for eg in examples]
|
||||
except AttributeError:
|
||||
|
@ -250,7 +253,7 @@ class TextCategorizer(Pipe):
|
|||
raise TypeError(err)
|
||||
if not any(len(doc) for doc in docs):
|
||||
# Handle cases where there are no tokens in any docs.
|
||||
return
|
||||
return losses
|
||||
set_dropout_rate(self.model, drop)
|
||||
scores, bp_scores = self.model.begin_update(docs)
|
||||
target = self._rehearsal_model(examples)
|
||||
|
@ -259,7 +262,6 @@ class TextCategorizer(Pipe):
|
|||
if sgd is not None:
|
||||
self.model.finish_update(sgd)
|
||||
if losses is not None:
|
||||
losses.setdefault(self.name, 0.0)
|
||||
losses[self.name] += (gradient ** 2).sum()
|
||||
return losses
|
||||
|
||||
|
@ -356,7 +358,6 @@ class TextCategorizer(Pipe):
|
|||
docs = [Doc(Vocab(), words=["hello"])]
|
||||
truths, _ = self._examples_to_truth(examples)
|
||||
self.set_output(len(self.labels))
|
||||
util.link_vectors_to_models(self.vocab)
|
||||
self.model.initialize(X=docs, Y=truths)
|
||||
if sgd is None:
|
||||
sgd = self.create_optimizer()
|
||||
|
|
|
@ -7,7 +7,7 @@ from ..tokens import Doc
|
|||
from ..vocab import Vocab
|
||||
from ..language import Language
|
||||
from ..errors import Errors
|
||||
from ..util import link_vectors_to_models, minibatch
|
||||
from ..util import minibatch
|
||||
|
||||
|
||||
default_model_config = """
|
||||
|
@ -109,7 +109,7 @@ class Tok2Vec(Pipe):
|
|||
docs (Iterable[Doc]): The documents to modify.
|
||||
tokvecses: The tensors to set, produced by Tok2Vec.predict.
|
||||
|
||||
DOCS: https://spacy.io/api/tok2vec#predict
|
||||
DOCS: https://spacy.io/api/tok2vec#set_annotations
|
||||
"""
|
||||
for doc, tokvecs in zip(docs, tokvecses):
|
||||
assert tokvecs.shape[0] == len(doc)
|
||||
|
@ -196,9 +196,11 @@ class Tok2Vec(Pipe):
|
|||
|
||||
DOCS: https://spacy.io/api/tok2vec#begin_training
|
||||
"""
|
||||
docs = [Doc(Vocab(), words=["hello"])]
|
||||
docs = [Doc(self.vocab, words=["hello"])]
|
||||
self.model.initialize(X=docs)
|
||||
link_vectors_to_models(self.vocab)
|
||||
|
||||
def add_label(self, label):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class Tok2VecListener(Model):
|
||||
|
|
|
@ -1,16 +1,15 @@
|
|||
from .stateclass cimport StateClass
|
||||
from .arc_eager cimport TransitionSystem
|
||||
from cymem.cymem cimport Pool
|
||||
|
||||
from ..vocab cimport Vocab
|
||||
from ..tokens.doc cimport Doc
|
||||
from ..structs cimport TokenC
|
||||
from ._state cimport StateC
|
||||
from ._parser_model cimport WeightsC, ActivationsC, SizesC
|
||||
from .pipe cimport Pipe
|
||||
from ._parser_internals.transition_system cimport Transition, TransitionSystem
|
||||
from ._parser_internals._state cimport StateC
|
||||
from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC
|
||||
|
||||
|
||||
cdef class Parser:
|
||||
cdef class Parser(Pipe):
|
||||
cdef readonly Vocab vocab
|
||||
cdef public object model
|
||||
cdef public str name
|
||||
cdef public object _rehearsal_model
|
||||
cdef readonly TransitionSystem moves
|
||||
cdef readonly object cfg
|
|
@ -1,42 +1,32 @@
|
|||
# cython: infer_types=True, cdivision=True, boundscheck=False
|
||||
cimport cython.parallel
|
||||
from __future__ import print_function
|
||||
from cymem.cymem cimport Pool
|
||||
cimport numpy as np
|
||||
from itertools import islice
|
||||
from cpython.ref cimport PyObject, Py_XDECREF
|
||||
from cpython.exc cimport PyErr_CheckSignals, PyErr_SetFromErrno
|
||||
from libc.math cimport exp
|
||||
from libcpp.vector cimport vector
|
||||
from libc.string cimport memset, memcpy
|
||||
from libc.string cimport memset
|
||||
from libc.stdlib cimport calloc, free
|
||||
from cymem.cymem cimport Pool
|
||||
from thinc.backends.linalg cimport Vec, VecVec
|
||||
|
||||
from thinc.api import chain, clone, Linear, list2array, NumpyOps, CupyOps, use_ops
|
||||
from thinc.api import get_array_module, zero_init, set_dropout_rate
|
||||
from itertools import islice
|
||||
import srsly
|
||||
|
||||
from ._parser_internals.stateclass cimport StateClass
|
||||
from ..ml.parser_model cimport alloc_activations, free_activations
|
||||
from ..ml.parser_model cimport predict_states, arg_max_if_valid
|
||||
from ..ml.parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss
|
||||
from ..ml.parser_model cimport get_c_weights, get_c_sizes
|
||||
|
||||
from ..tokens.doc cimport Doc
|
||||
from ..errors import Errors, Warnings
|
||||
from .. import util
|
||||
from ..util import create_default_optimizer
|
||||
|
||||
from thinc.api import set_dropout_rate
|
||||
import numpy.random
|
||||
import numpy
|
||||
import warnings
|
||||
|
||||
from ..tokens.doc cimport Doc
|
||||
from ..typedefs cimport weight_t, class_t, hash_t
|
||||
from ._parser_model cimport alloc_activations, free_activations
|
||||
from ._parser_model cimport predict_states, arg_max_if_valid
|
||||
from ._parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss
|
||||
from ._parser_model cimport get_c_weights, get_c_sizes
|
||||
from .stateclass cimport StateClass
|
||||
from ._state cimport StateC
|
||||
from .transition_system cimport Transition
|
||||
|
||||
from ..util import link_vectors_to_models, create_default_optimizer, registry
|
||||
from ..compat import copy_array
|
||||
from ..errors import Errors, Warnings
|
||||
from .. import util
|
||||
from . import nonproj
|
||||
|
||||
|
||||
cdef class Parser:
|
||||
cdef class Parser(Pipe):
|
||||
"""
|
||||
Base class of the DependencyParser and EntityRecognizer.
|
||||
"""
|
||||
|
@ -107,7 +97,7 @@ cdef class Parser:
|
|||
|
||||
@property
|
||||
def tok2vec(self):
|
||||
'''Return the embedding and convolutional layer of the model.'''
|
||||
"""Return the embedding and convolutional layer of the model."""
|
||||
return self.model.get_ref("tok2vec")
|
||||
|
||||
@property
|
||||
|
@ -138,13 +128,13 @@ cdef class Parser:
|
|||
raise NotImplementedError
|
||||
|
||||
def init_multitask_objectives(self, get_examples, pipeline, **cfg):
|
||||
'''Setup models for secondary objectives, to benefit from multi-task
|
||||
"""Setup models for secondary objectives, to benefit from multi-task
|
||||
learning. This method is intended to be overridden by subclasses.
|
||||
|
||||
For instance, the dependency parser can benefit from sharing
|
||||
an input representation with a label prediction model. These auxiliary
|
||||
models are discarded after training.
|
||||
'''
|
||||
"""
|
||||
pass
|
||||
|
||||
def use_params(self, params):
|
||||
|
@ -456,7 +446,6 @@ cdef class Parser:
|
|||
self.model.initialize()
|
||||
if pipeline is not None:
|
||||
self.init_multitask_objectives(get_examples, pipeline, sgd=sgd, **self.cfg)
|
||||
link_vectors_to_models(self.vocab)
|
||||
return sgd
|
||||
|
||||
def to_disk(self, path, exclude=tuple()):
|
|
@ -171,17 +171,6 @@ class ModelMetaSchema(BaseModel):
|
|||
# fmt: on
|
||||
|
||||
|
||||
# JSON training format
|
||||
|
||||
|
||||
class TrainingSchema(BaseModel):
|
||||
# TODO: write
|
||||
|
||||
class Config:
|
||||
title = "Schema for training data in spaCy's JSON format"
|
||||
extra = "forbid"
|
||||
|
||||
|
||||
# Config schema
|
||||
# We're not setting any defaults here (which is too messy) and are making all
|
||||
# fields required, so we can raise validation errors for missing values. To
|
||||
|
|
|
@ -84,7 +84,6 @@ class Scorer:
|
|||
**cfg,
|
||||
) -> None:
|
||||
"""Initialize the Scorer.
|
||||
RETURNS (Scorer): The newly created object.
|
||||
|
||||
DOCS: https://spacy.io/api/scorer#init
|
||||
"""
|
||||
|
|
|
@ -97,7 +97,6 @@ cdef class StringStore:
|
|||
"""Create the StringStore.
|
||||
|
||||
strings (iterable): A sequence of unicode strings to add to the store.
|
||||
RETURNS (StringStore): The newly constructed object.
|
||||
"""
|
||||
self.mem = Pool()
|
||||
self._map = PreshMap()
|
||||
|
|
|
@ -63,18 +63,11 @@ def test_matcher_len_contains(matcher):
|
|||
assert "TEST2" not in matcher
|
||||
|
||||
|
||||
def test_matcher_add_new_old_api(en_vocab):
|
||||
def test_matcher_add_new_api(en_vocab):
|
||||
doc = Doc(en_vocab, words=["a", "b"])
|
||||
patterns = [[{"TEXT": "a"}], [{"TEXT": "a"}, {"TEXT": "b"}]]
|
||||
matcher = Matcher(en_vocab)
|
||||
matcher.add("OLD_API", None, *patterns)
|
||||
assert len(matcher(doc)) == 2
|
||||
matcher = Matcher(en_vocab)
|
||||
on_match = Mock()
|
||||
matcher.add("OLD_API_CALLBACK", on_match, *patterns)
|
||||
assert len(matcher(doc)) == 2
|
||||
assert on_match.call_count == 2
|
||||
# New API: add(key: str, patterns: List[List[dict]], on_match: Callable)
|
||||
matcher = Matcher(en_vocab)
|
||||
matcher.add("NEW_API", patterns)
|
||||
assert len(matcher(doc)) == 2
|
||||
|
@ -176,7 +169,7 @@ def test_matcher_match_zero_plus(matcher):
|
|||
|
||||
def test_matcher_match_one_plus(matcher):
|
||||
control = Matcher(matcher.vocab)
|
||||
control.add("BasicPhilippe", None, [{"ORTH": "Philippe"}])
|
||||
control.add("BasicPhilippe", [[{"ORTH": "Philippe"}]])
|
||||
doc = Doc(control.vocab, words=["Philippe", "Philippe"])
|
||||
m = control(doc)
|
||||
assert len(m) == 2
|
||||
|
|
|
@ -7,18 +7,10 @@ from spacy.tokens import Doc, Span
|
|||
|
||||
|
||||
pattern1 = [{"ORTH": "A"}, {"ORTH": "A", "OP": "*"}]
|
||||
pattern2 = [{"ORTH": "A"}, {"ORTH": "A"}]
|
||||
pattern2 = [{"ORTH": "A", "OP": "*"}, {"ORTH": "A"}]
|
||||
pattern3 = [{"ORTH": "A"}, {"ORTH": "A"}]
|
||||
pattern4 = [
|
||||
{"ORTH": "B"},
|
||||
{"ORTH": "A", "OP": "*"},
|
||||
{"ORTH": "B"},
|
||||
]
|
||||
pattern5 = [
|
||||
{"ORTH": "B", "OP": "*"},
|
||||
{"ORTH": "A", "OP": "*"},
|
||||
{"ORTH": "B"},
|
||||
]
|
||||
pattern4 = [{"ORTH": "B"}, {"ORTH": "A", "OP": "*"}, {"ORTH": "B"}]
|
||||
pattern5 = [{"ORTH": "B", "OP": "*"}, {"ORTH": "A", "OP": "*"}, {"ORTH": "B"}]
|
||||
|
||||
re_pattern1 = "AA*"
|
||||
re_pattern2 = "A*A"
|
||||
|
@ -26,10 +18,16 @@ re_pattern3 = "AA"
|
|||
re_pattern4 = "BA*B"
|
||||
re_pattern5 = "B*A*B"
|
||||
|
||||
longest1 = "A A A A A"
|
||||
longest2 = "A A A A A"
|
||||
longest3 = "A A"
|
||||
longest4 = "B A A A A A B" # "FIRST" would be "B B"
|
||||
longest5 = "B B A A A A A B"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def text():
|
||||
return "(ABBAAAAAB)."
|
||||
return "(BBAAAAAB)."
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -41,25 +39,63 @@ def doc(en_tokenizer, text):
|
|||
@pytest.mark.parametrize(
|
||||
"pattern,re_pattern",
|
||||
[
|
||||
pytest.param(pattern1, re_pattern1, marks=pytest.mark.xfail()),
|
||||
pytest.param(pattern2, re_pattern2, marks=pytest.mark.xfail()),
|
||||
pytest.param(pattern3, re_pattern3, marks=pytest.mark.xfail()),
|
||||
(pattern1, re_pattern1),
|
||||
(pattern2, re_pattern2),
|
||||
(pattern3, re_pattern3),
|
||||
(pattern4, re_pattern4),
|
||||
pytest.param(pattern5, re_pattern5, marks=pytest.mark.xfail()),
|
||||
(pattern5, re_pattern5),
|
||||
],
|
||||
)
|
||||
def test_greedy_matching(doc, text, pattern, re_pattern):
|
||||
"""Test that the greedy matching behavior of the * op is consistant with
|
||||
def test_greedy_matching_first(doc, text, pattern, re_pattern):
|
||||
"""Test that the greedy matching behavior "FIRST" is consistent with
|
||||
other re implementations."""
|
||||
matcher = Matcher(doc.vocab)
|
||||
matcher.add(re_pattern, [pattern])
|
||||
matcher.add(re_pattern, [pattern], greedy="FIRST")
|
||||
matches = matcher(doc)
|
||||
re_matches = [m.span() for m in re.finditer(re_pattern, text)]
|
||||
for match, re_match in zip(matches, re_matches):
|
||||
assert match[1:] == re_match
|
||||
for (key, m_s, m_e), (re_s, re_e) in zip(matches, re_matches):
|
||||
# matching the string, not the exact position
|
||||
assert doc[m_s:m_e].text == doc[re_s:re_e].text
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"pattern,longest",
|
||||
[
|
||||
(pattern1, longest1),
|
||||
(pattern2, longest2),
|
||||
(pattern3, longest3),
|
||||
(pattern4, longest4),
|
||||
(pattern5, longest5),
|
||||
],
|
||||
)
|
||||
def test_greedy_matching_longest(doc, text, pattern, longest):
|
||||
"""Test the "LONGEST" greedy matching behavior"""
|
||||
matcher = Matcher(doc.vocab)
|
||||
matcher.add("RULE", [pattern], greedy="LONGEST")
|
||||
matches = matcher(doc)
|
||||
for (key, s, e) in matches:
|
||||
assert doc[s:e].text == longest
|
||||
|
||||
|
||||
def test_greedy_matching_longest_first(en_tokenizer):
|
||||
"""Test that "LONGEST" matching prefers the first of two equally long matches"""
|
||||
doc = en_tokenizer(" ".join("CCC"))
|
||||
matcher = Matcher(doc.vocab)
|
||||
pattern = [{"ORTH": "C"}, {"ORTH": "C"}]
|
||||
matcher.add("RULE", [pattern], greedy="LONGEST")
|
||||
matches = matcher(doc)
|
||||
# out of 0-2 and 1-3, the first should be picked
|
||||
assert len(matches) == 1
|
||||
assert matches[0][1] == 0
|
||||
assert matches[0][2] == 2
|
||||
|
||||
|
||||
def test_invalid_greediness(doc, text):
|
||||
matcher = Matcher(doc.vocab)
|
||||
with pytest.raises(ValueError):
|
||||
matcher.add("RULE", [pattern1], greedy="GREEDY")
|
||||
|
||||
|
||||
@pytest.mark.xfail
|
||||
@pytest.mark.parametrize(
|
||||
"pattern,re_pattern",
|
||||
[
|
||||
|
@ -74,7 +110,7 @@ def test_match_consuming(doc, text, pattern, re_pattern):
|
|||
"""Test that matcher.__call__ consumes tokens on a match similar to
|
||||
re.findall."""
|
||||
matcher = Matcher(doc.vocab)
|
||||
matcher.add(re_pattern, [pattern])
|
||||
matcher.add(re_pattern, [pattern], greedy="FIRST")
|
||||
matches = matcher(doc)
|
||||
re_matches = [m.span() for m in re.finditer(re_pattern, text)]
|
||||
assert len(matches) == len(re_matches)
|
||||
|
|
|
@ -4,8 +4,8 @@ from spacy import registry
|
|||
from spacy.gold import Example
|
||||
from spacy.pipeline import DependencyParser
|
||||
from spacy.tokens import Doc
|
||||
from spacy.syntax.nonproj import projectivize
|
||||
from spacy.syntax.arc_eager import ArcEager
|
||||
from spacy.pipeline._parser_internals.nonproj import projectivize
|
||||
from spacy.pipeline._parser_internals.arc_eager import ArcEager
|
||||
from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL
|
||||
|
||||
|
||||
|
|
|
@ -5,7 +5,7 @@ from spacy.lang.en import English
|
|||
|
||||
from spacy.language import Language
|
||||
from spacy.lookups import Lookups
|
||||
from spacy.syntax.ner import BiluoPushDown
|
||||
from spacy.pipeline._parser_internals.ner import BiluoPushDown
|
||||
from spacy.gold import Example
|
||||
from spacy.tokens import Doc
|
||||
from spacy.vocab import Vocab
|
||||
|
|
|
@ -3,8 +3,8 @@ import pytest
|
|||
from spacy import registry
|
||||
from spacy.gold import Example
|
||||
from spacy.vocab import Vocab
|
||||
from spacy.syntax.arc_eager import ArcEager
|
||||
from spacy.syntax.nn_parser import Parser
|
||||
from spacy.pipeline._parser_internals.arc_eager import ArcEager
|
||||
from spacy.pipeline.transition_parser import Parser
|
||||
from spacy.tokens.doc import Doc
|
||||
from thinc.api import Model
|
||||
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
import pytest
|
||||
from spacy.syntax.nonproj import ancestors, contains_cycle, is_nonproj_arc
|
||||
from spacy.syntax.nonproj import is_nonproj_tree
|
||||
from spacy.syntax import nonproj
|
||||
from spacy.pipeline._parser_internals.nonproj import ancestors, contains_cycle, is_nonproj_arc
|
||||
from spacy.pipeline._parser_internals.nonproj import is_nonproj_tree
|
||||
from spacy.pipeline._parser_internals import nonproj
|
||||
|
||||
from ..util import get_doc
|
||||
|
||||
|
|
|
@ -9,7 +9,6 @@ from spacy.matcher import Matcher
|
|||
from spacy.tokens import Doc, Span
|
||||
from spacy.vocab import Vocab
|
||||
from spacy.compat import pickle
|
||||
from spacy.util import link_vectors_to_models
|
||||
import numpy
|
||||
import random
|
||||
|
||||
|
@ -190,7 +189,6 @@ def test_issue2871():
|
|||
_ = vocab[word] # noqa: F841
|
||||
vocab.set_vector(word, vector_data[0])
|
||||
vocab.vectors.name = "dummy_vectors"
|
||||
link_vectors_to_models(vocab)
|
||||
assert vocab["dog"].rank == 0
|
||||
assert vocab["cat"].rank == 1
|
||||
assert vocab["SUFFIX"].rank == 2
|
||||
|
|
|
@ -5,6 +5,7 @@ from spacy.lang.en import English
|
|||
from spacy.language import Language
|
||||
from spacy.util import registry, deep_merge_configs, load_model_from_config
|
||||
from spacy.ml.models import build_Tok2Vec_model, build_tb_parser_model
|
||||
from spacy.ml.models import MultiHashEmbed, MaxoutWindowEncoder
|
||||
|
||||
from ..util import make_tempdir
|
||||
|
||||
|
@ -40,7 +41,7 @@ factory = "tagger"
|
|||
@architectures = "spacy.Tagger.v1"
|
||||
|
||||
[components.tagger.model.tok2vec]
|
||||
@architectures = "spacy.Tok2VecTensors.v1"
|
||||
@architectures = "spacy.Tok2VecListener.v1"
|
||||
width = ${components.tok2vec.model:width}
|
||||
"""
|
||||
|
||||
|
@ -68,18 +69,18 @@ dropout = null
|
|||
@registry.architectures.register("my_test_parser")
|
||||
def my_parser():
|
||||
tok2vec = build_Tok2Vec_model(
|
||||
width=321,
|
||||
embed_size=5432,
|
||||
pretrained_vectors=None,
|
||||
window_size=3,
|
||||
maxout_pieces=4,
|
||||
subword_features=True,
|
||||
char_embed=True,
|
||||
nM=64,
|
||||
nC=8,
|
||||
conv_depth=2,
|
||||
bilstm_depth=0,
|
||||
dropout=None,
|
||||
MultiHashEmbed(
|
||||
width=321,
|
||||
rows=5432,
|
||||
also_embed_subwords=True,
|
||||
also_use_static_vectors=False
|
||||
),
|
||||
MaxoutWindowEncoder(
|
||||
width=321,
|
||||
window_size=3,
|
||||
maxout_pieces=4,
|
||||
depth=2
|
||||
)
|
||||
)
|
||||
parser = build_tb_parser_model(
|
||||
tok2vec=tok2vec, nr_feature_tokens=7, hidden_width=65, maxout_pieces=5
|
||||
|
|
|
@ -5,12 +5,32 @@ from thinc.api import fix_random_seed, Adam, set_dropout_rate
|
|||
from numpy.testing import assert_array_equal
|
||||
import numpy
|
||||
|
||||
from spacy.ml.models import build_Tok2Vec_model
|
||||
from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder
|
||||
from spacy.ml.models import build_text_classifier, build_simple_cnn_text_classifier
|
||||
from spacy.lang.en import English
|
||||
from spacy.lang.en.examples import sentences as EN_SENTENCES
|
||||
|
||||
|
||||
def get_textcat_kwargs():
|
||||
return {
|
||||
"width": 64,
|
||||
"embed_size": 2000,
|
||||
"pretrained_vectors": None,
|
||||
"exclusive_classes": False,
|
||||
"ngram_size": 1,
|
||||
"window_size": 1,
|
||||
"conv_depth": 2,
|
||||
"dropout": None,
|
||||
"nO": 7,
|
||||
}
|
||||
|
||||
def get_textcat_cnn_kwargs():
|
||||
return {
|
||||
"tok2vec": test_tok2vec(),
|
||||
"exclusive_classes": False,
|
||||
"nO": 13,
|
||||
}
|
||||
|
||||
def get_all_params(model):
|
||||
params = []
|
||||
for node in model.walk():
|
||||
|
@ -35,50 +55,34 @@ def get_gradient(model, Y):
|
|||
raise ValueError(f"Could not get gradient for type {type(Y)}")
|
||||
|
||||
|
||||
def get_tok2vec_kwargs():
|
||||
# This actually creates models, so seems best to put it in a function.
|
||||
return {
|
||||
"embed": MultiHashEmbed(
|
||||
width=32,
|
||||
rows=500,
|
||||
also_embed_subwords=True,
|
||||
also_use_static_vectors=False
|
||||
),
|
||||
"encode": MaxoutWindowEncoder(
|
||||
width=32,
|
||||
depth=2,
|
||||
maxout_pieces=2,
|
||||
window_size=1,
|
||||
)
|
||||
}
|
||||
|
||||
|
||||
def test_tok2vec():
|
||||
return build_Tok2Vec_model(**TOK2VEC_KWARGS)
|
||||
|
||||
|
||||
TOK2VEC_KWARGS = {
|
||||
"width": 96,
|
||||
"embed_size": 2000,
|
||||
"subword_features": True,
|
||||
"char_embed": False,
|
||||
"conv_depth": 4,
|
||||
"bilstm_depth": 0,
|
||||
"maxout_pieces": 4,
|
||||
"window_size": 1,
|
||||
"dropout": 0.1,
|
||||
"nM": 0,
|
||||
"nC": 0,
|
||||
"pretrained_vectors": None,
|
||||
}
|
||||
|
||||
TEXTCAT_KWARGS = {
|
||||
"width": 64,
|
||||
"embed_size": 2000,
|
||||
"pretrained_vectors": None,
|
||||
"exclusive_classes": False,
|
||||
"ngram_size": 1,
|
||||
"window_size": 1,
|
||||
"conv_depth": 2,
|
||||
"dropout": None,
|
||||
"nO": 7,
|
||||
}
|
||||
|
||||
TEXTCAT_CNN_KWARGS = {
|
||||
"tok2vec": test_tok2vec(),
|
||||
"exclusive_classes": False,
|
||||
"nO": 13,
|
||||
}
|
||||
return build_Tok2Vec_model(**get_tok2vec_kwargs())
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"seed,model_func,kwargs",
|
||||
[
|
||||
(0, build_Tok2Vec_model, TOK2VEC_KWARGS),
|
||||
(0, build_text_classifier, TEXTCAT_KWARGS),
|
||||
(0, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS),
|
||||
(0, build_Tok2Vec_model, get_tok2vec_kwargs()),
|
||||
(0, build_text_classifier, get_textcat_kwargs()),
|
||||
(0, build_simple_cnn_text_classifier, get_textcat_cnn_kwargs()),
|
||||
],
|
||||
)
|
||||
def test_models_initialize_consistently(seed, model_func, kwargs):
|
||||
|
@ -96,9 +100,9 @@ def test_models_initialize_consistently(seed, model_func, kwargs):
|
|||
@pytest.mark.parametrize(
|
||||
"seed,model_func,kwargs,get_X",
|
||||
[
|
||||
(0, build_Tok2Vec_model, TOK2VEC_KWARGS, get_docs),
|
||||
(0, build_text_classifier, TEXTCAT_KWARGS, get_docs),
|
||||
(0, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS, get_docs),
|
||||
(0, build_Tok2Vec_model, get_tok2vec_kwargs(), get_docs),
|
||||
(0, build_text_classifier, get_textcat_kwargs(), get_docs),
|
||||
(0, build_simple_cnn_text_classifier, get_textcat_cnn_kwargs(), get_docs),
|
||||
],
|
||||
)
|
||||
def test_models_predict_consistently(seed, model_func, kwargs, get_X):
|
||||
|
@ -131,9 +135,9 @@ def test_models_predict_consistently(seed, model_func, kwargs, get_X):
|
|||
@pytest.mark.parametrize(
|
||||
"seed,dropout,model_func,kwargs,get_X",
|
||||
[
|
||||
(0, 0.2, build_Tok2Vec_model, TOK2VEC_KWARGS, get_docs),
|
||||
(0, 0.2, build_text_classifier, TEXTCAT_KWARGS, get_docs),
|
||||
(0, 0.2, build_simple_cnn_text_classifier, TEXTCAT_CNN_KWARGS, get_docs),
|
||||
(0, 0.2, build_Tok2Vec_model, get_tok2vec_kwargs(), get_docs),
|
||||
(0, 0.2, build_text_classifier, get_textcat_kwargs(), get_docs),
|
||||
(0, 0.2, build_simple_cnn_text_classifier, get_textcat_cnn_kwargs(), get_docs),
|
||||
],
|
||||
)
|
||||
def test_models_update_consistently(seed, dropout, model_func, kwargs, get_X):
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
import pytest
|
||||
|
||||
from spacy.ml.models.tok2vec import build_Tok2Vec_model
|
||||
from spacy.ml.models.tok2vec import MultiHashEmbed, CharacterEmbed
|
||||
from spacy.ml.models.tok2vec import MishWindowEncoder, MaxoutWindowEncoder
|
||||
from spacy.vocab import Vocab
|
||||
from spacy.tokens import Doc
|
||||
|
||||
|
@ -13,18 +15,18 @@ def test_empty_doc():
|
|||
vocab = Vocab()
|
||||
doc = Doc(vocab, words=[])
|
||||
tok2vec = build_Tok2Vec_model(
|
||||
width,
|
||||
embed_size,
|
||||
pretrained_vectors=None,
|
||||
conv_depth=4,
|
||||
bilstm_depth=0,
|
||||
window_size=1,
|
||||
maxout_pieces=3,
|
||||
subword_features=True,
|
||||
char_embed=False,
|
||||
nM=64,
|
||||
nC=8,
|
||||
dropout=None,
|
||||
MultiHashEmbed(
|
||||
width=width,
|
||||
rows=embed_size,
|
||||
also_use_static_vectors=False,
|
||||
also_embed_subwords=True
|
||||
),
|
||||
MaxoutWindowEncoder(
|
||||
width=width,
|
||||
depth=4,
|
||||
window_size=1,
|
||||
maxout_pieces=3
|
||||
)
|
||||
)
|
||||
tok2vec.initialize()
|
||||
vectors, backprop = tok2vec.begin_update([doc])
|
||||
|
@ -38,18 +40,18 @@ def test_empty_doc():
|
|||
def test_tok2vec_batch_sizes(batch_size, width, embed_size):
|
||||
batch = get_batch(batch_size)
|
||||
tok2vec = build_Tok2Vec_model(
|
||||
width,
|
||||
embed_size,
|
||||
pretrained_vectors=None,
|
||||
conv_depth=4,
|
||||
bilstm_depth=0,
|
||||
window_size=1,
|
||||
maxout_pieces=3,
|
||||
subword_features=True,
|
||||
char_embed=False,
|
||||
nM=64,
|
||||
nC=8,
|
||||
dropout=None,
|
||||
MultiHashEmbed(
|
||||
width=width,
|
||||
rows=embed_size,
|
||||
also_use_static_vectors=False,
|
||||
also_embed_subwords=True
|
||||
),
|
||||
MaxoutWindowEncoder(
|
||||
width=width,
|
||||
depth=4,
|
||||
window_size=1,
|
||||
maxout_pieces=3,
|
||||
)
|
||||
)
|
||||
tok2vec.initialize()
|
||||
vectors, backprop = tok2vec.begin_update(batch)
|
||||
|
@ -60,24 +62,25 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
|
|||
|
||||
# fmt: off
|
||||
@pytest.mark.parametrize(
|
||||
"tok2vec_config",
|
||||
"width,embed_arch,embed_config,encode_arch,encode_config",
|
||||
[
|
||||
{"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
|
||||
{"width": 8, "embed_size": 100, "char_embed": True, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
|
||||
{"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 6, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
|
||||
{"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 6, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True, "dropout": None},
|
||||
{"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
|
||||
{"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
|
||||
{"width": 8, "embed_size": 100, "char_embed": True, "nM": 81, "nC": 8, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
|
||||
{"width": 8, "embed_size": 100, "char_embed": True, "nM": 81, "nC": 9, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False, "dropout": None},
|
||||
(8, MultiHashEmbed, {"rows": 100, "also_embed_subwords": True, "also_use_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 2}),
|
||||
(8, MultiHashEmbed, {"rows": 100, "also_embed_subwords": True, "also_use_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 6}),
|
||||
(8, CharacterEmbed, {"rows": 100, "nM": 64, "nC": 8}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 3}),
|
||||
(8, CharacterEmbed, {"rows": 100, "nM": 16, "nC": 2}, MishWindowEncoder, {"window_size": 1, "depth": 3}),
|
||||
],
|
||||
)
|
||||
# fmt: on
|
||||
def test_tok2vec_configs(tok2vec_config):
|
||||
def test_tok2vec_configs(width, embed_arch, embed_config, encode_arch, encode_config):
|
||||
embed_config["width"] = width
|
||||
encode_config["width"] = width
|
||||
docs = get_batch(3)
|
||||
tok2vec = build_Tok2Vec_model(**tok2vec_config)
|
||||
tok2vec = build_Tok2Vec_model(
|
||||
embed_arch(**embed_config),
|
||||
encode_arch(**encode_config)
|
||||
)
|
||||
tok2vec.initialize(docs)
|
||||
vectors, backprop = tok2vec.begin_update(docs)
|
||||
assert len(vectors) == len(docs)
|
||||
assert vectors[0].shape == (len(docs[0]), tok2vec_config["width"])
|
||||
assert vectors[0].shape == (len(docs[0]), width)
|
||||
backprop(vectors)
|
||||
|
|
|
@ -50,7 +50,6 @@ cdef class Tokenizer:
|
|||
recognised as tokens.
|
||||
url_match (callable): A boolean function matching strings to be
|
||||
recognised as tokens after considering prefixes and suffixes.
|
||||
RETURNS (Tokenizer): The newly constructed object.
|
||||
|
||||
EXAMPLE:
|
||||
>>> tokenizer = Tokenizer(nlp.vocab)
|
||||
|
@ -729,7 +728,7 @@ cdef class Tokenizer:
|
|||
with path.open("wb") as file_:
|
||||
file_.write(self.to_bytes(**kwargs))
|
||||
|
||||
def from_disk(self, path, **kwargs):
|
||||
def from_disk(self, path, *, exclude=tuple()):
|
||||
"""Loads state from a directory. Modifies the object in place and
|
||||
returns it.
|
||||
|
||||
|
@ -742,10 +741,10 @@ cdef class Tokenizer:
|
|||
path = util.ensure_path(path)
|
||||
with path.open("rb") as file_:
|
||||
bytes_data = file_.read()
|
||||
self.from_bytes(bytes_data, **kwargs)
|
||||
self.from_bytes(bytes_data, exclude=exclude)
|
||||
return self
|
||||
|
||||
def to_bytes(self, exclude=tuple()):
|
||||
def to_bytes(self, *, exclude=tuple()):
|
||||
"""Serialize the current state to a binary string.
|
||||
|
||||
exclude (list): String names of serialization fields to exclude.
|
||||
|
@ -764,7 +763,7 @@ cdef class Tokenizer:
|
|||
}
|
||||
return util.to_bytes(serializers, exclude)
|
||||
|
||||
def from_bytes(self, bytes_data, exclude=tuple()):
|
||||
def from_bytes(self, bytes_data, *, exclude=tuple()):
|
||||
"""Load state from a binary string.
|
||||
|
||||
bytes_data (bytes): The data to load from.
|
||||
|
|
|
@ -312,6 +312,7 @@ def _split(Doc doc, int token_index, orths, heads, attrs):
|
|||
"""Retokenize the document, such that the token at
|
||||
`doc[token_index]` is split into tokens with the orth 'orths'
|
||||
token_index(int): token index of the token to split.
|
||||
|
||||
orths: IDs of the verbatim text content of the tokens to create
|
||||
**attributes: Attributes to assign to each of the newly created tokens. By default,
|
||||
attributes are inherited from the original token.
|
||||
|
|
|
@ -1,10 +1,12 @@
|
|||
from typing import Iterable, Iterator
|
||||
import numpy
|
||||
import zlib
|
||||
import srsly
|
||||
from thinc.api import NumpyOps
|
||||
|
||||
from .doc import Doc
|
||||
from ..vocab import Vocab
|
||||
from ..compat import copy_reg
|
||||
from ..tokens import Doc
|
||||
from ..attrs import SPACY, ORTH, intify_attr
|
||||
from ..errors import Errors
|
||||
|
||||
|
@ -44,13 +46,18 @@ class DocBin:
|
|||
document from the DocBin.
|
||||
"""
|
||||
|
||||
def __init__(self, attrs=ALL_ATTRS, store_user_data=False, docs=[]):
|
||||
def __init__(
|
||||
self,
|
||||
attrs: Iterable[str] = ALL_ATTRS,
|
||||
store_user_data: bool = False,
|
||||
docs: Iterable[Doc] = tuple(),
|
||||
) -> None:
|
||||
"""Create a DocBin object to hold serialized annotations.
|
||||
|
||||
attrs (list): List of attributes to serialize. 'orth' and 'spacy' are
|
||||
always serialized, so they're not required. Defaults to None.
|
||||
attrs (Iterable[str]): List of attributes to serialize. 'orth' and
|
||||
'spacy' are always serialized, so they're not required.
|
||||
store_user_data (bool): Whether to include the `Doc.user_data`.
|
||||
RETURNS (DocBin): The newly constructed object.
|
||||
docs (Iterable[Doc]): Docs to add.
|
||||
|
||||
DOCS: https://spacy.io/api/docbin#init
|
||||
"""
|
||||
|
@ -68,11 +75,11 @@ class DocBin:
|
|||
for doc in docs:
|
||||
self.add(doc)
|
||||
|
||||
def __len__(self):
|
||||
def __len__(self) -> int:
|
||||
"""RETURNS: The number of Doc objects added to the DocBin."""
|
||||
return len(self.tokens)
|
||||
|
||||
def add(self, doc):
|
||||
def add(self, doc: Doc) -> None:
|
||||
"""Add a Doc's annotations to the DocBin for serialization.
|
||||
|
||||
doc (Doc): The Doc object to add.
|
||||
|
@ -100,7 +107,7 @@ class DocBin:
|
|||
if self.store_user_data:
|
||||
self.user_data.append(srsly.msgpack_dumps(doc.user_data))
|
||||
|
||||
def get_docs(self, vocab):
|
||||
def get_docs(self, vocab: Vocab) -> Iterator[Doc]:
|
||||
"""Recover Doc objects from the annotations, using the given vocab.
|
||||
|
||||
vocab (Vocab): The shared vocab.
|
||||
|
@ -125,7 +132,7 @@ class DocBin:
|
|||
doc.user_data.update(user_data)
|
||||
yield doc
|
||||
|
||||
def merge(self, other):
|
||||
def merge(self, other: "DocBin") -> None:
|
||||
"""Extend the annotations of this DocBin with the annotations from
|
||||
another. Will raise an error if the pre-defined attrs of the two
|
||||
DocBins don't match.
|
||||
|
@ -144,7 +151,7 @@ class DocBin:
|
|||
if self.store_user_data:
|
||||
self.user_data.extend(other.user_data)
|
||||
|
||||
def to_bytes(self):
|
||||
def to_bytes(self) -> bytes:
|
||||
"""Serialize the DocBin's annotations to a bytestring.
|
||||
|
||||
RETURNS (bytes): The serialized DocBin.
|
||||
|
@ -156,7 +163,6 @@ class DocBin:
|
|||
lengths = [len(tokens) for tokens in self.tokens]
|
||||
tokens = numpy.vstack(self.tokens) if self.tokens else numpy.asarray([])
|
||||
spaces = numpy.vstack(self.spaces) if self.spaces else numpy.asarray([])
|
||||
|
||||
msg = {
|
||||
"version": self.version,
|
||||
"attrs": self.attrs,
|
||||
|
@ -171,7 +177,7 @@ class DocBin:
|
|||
msg["user_data"] = self.user_data
|
||||
return zlib.compress(srsly.msgpack_dumps(msg))
|
||||
|
||||
def from_bytes(self, bytes_data):
|
||||
def from_bytes(self, bytes_data: bytes) -> "DocBin":
|
||||
"""Deserialize the DocBin's annotations from a bytestring.
|
||||
|
||||
bytes_data (bytes): The data to load from.
|
||||
|
|
|
@ -173,7 +173,6 @@ cdef class Doc:
|
|||
words. True means that the word is followed by a space, False means
|
||||
it is not. If `None`, defaults to `[True]*len(words)`
|
||||
user_data (dict or None): Optional extra data to attach to the Doc.
|
||||
RETURNS (Doc): The newly constructed object.
|
||||
|
||||
DOCS: https://spacy.io/api/doc#init
|
||||
"""
|
||||
|
@ -988,20 +987,20 @@ cdef class Doc:
|
|||
other.c = &tokens[PADDING]
|
||||
return other
|
||||
|
||||
def to_disk(self, path, **kwargs):
|
||||
def to_disk(self, path, *, exclude=tuple()):
|
||||
"""Save the current state to a directory.
|
||||
|
||||
path (str / Path): A path to a directory, which will be created if
|
||||
it doesn't exist. Paths may be either strings or Path-like objects.
|
||||
exclude (list): String names of serialization fields to exclude.
|
||||
exclude (Iterable[str]): String names of serialization fields to exclude.
|
||||
|
||||
DOCS: https://spacy.io/api/doc#to_disk
|
||||
"""
|
||||
path = util.ensure_path(path)
|
||||
with path.open("wb") as file_:
|
||||
file_.write(self.to_bytes(**kwargs))
|
||||
file_.write(self.to_bytes(exclude=exclude))
|
||||
|
||||
def from_disk(self, path, **kwargs):
|
||||
def from_disk(self, path, *, exclude=tuple()):
|
||||
"""Loads state from a directory. Modifies the object in place and
|
||||
returns it.
|
||||
|
||||
|
@ -1015,9 +1014,9 @@ cdef class Doc:
|
|||
path = util.ensure_path(path)
|
||||
with path.open("rb") as file_:
|
||||
bytes_data = file_.read()
|
||||
return self.from_bytes(bytes_data, **kwargs)
|
||||
return self.from_bytes(bytes_data, exclude=exclude)
|
||||
|
||||
def to_bytes(self, exclude=tuple(), **kwargs):
|
||||
def to_bytes(self, *, exclude=tuple()):
|
||||
"""Serialize, i.e. export the document contents to a binary string.
|
||||
|
||||
exclude (list): String names of serialization fields to exclude.
|
||||
|
@ -1026,9 +1025,9 @@ cdef class Doc:
|
|||
|
||||
DOCS: https://spacy.io/api/doc#to_bytes
|
||||
"""
|
||||
return srsly.msgpack_dumps(self.to_dict(exclude=exclude, **kwargs))
|
||||
return srsly.msgpack_dumps(self.to_dict(exclude=exclude))
|
||||
|
||||
def from_bytes(self, bytes_data, exclude=tuple(), **kwargs):
|
||||
def from_bytes(self, bytes_data, *, exclude=tuple()):
|
||||
"""Deserialize, i.e. import the document contents from a binary string.
|
||||
|
||||
data (bytes): The string to load from.
|
||||
|
@ -1037,13 +1036,9 @@ cdef class Doc:
|
|||
|
||||
DOCS: https://spacy.io/api/doc#from_bytes
|
||||
"""
|
||||
return self.from_dict(
|
||||
srsly.msgpack_loads(bytes_data),
|
||||
exclude=exclude,
|
||||
**kwargs
|
||||
)
|
||||
return self.from_dict(srsly.msgpack_loads(bytes_data), exclude=exclude)
|
||||
|
||||
def to_dict(self, exclude=tuple(), **kwargs):
|
||||
def to_dict(self, *, exclude=tuple()):
|
||||
"""Export the document contents to a dictionary for serialization.
|
||||
|
||||
exclude (list): String names of serialization fields to exclude.
|
||||
|
@ -1091,14 +1086,14 @@ cdef class Doc:
|
|||
serializers["user_data_values"] = lambda: srsly.msgpack_dumps(user_data_values)
|
||||
return util.to_dict(serializers, exclude)
|
||||
|
||||
def from_dict(self, msg, exclude=tuple(), **kwargs):
|
||||
def from_dict(self, msg, *, exclude=tuple()):
|
||||
"""Deserialize, i.e. import the document contents from a binary string.
|
||||
|
||||
data (bytes): The string to load from.
|
||||
exclude (list): String names of serialization fields to exclude.
|
||||
RETURNS (Doc): Itself.
|
||||
|
||||
DOCS: https://spacy.io/api/doc#from_bytes
|
||||
DOCS: https://spacy.io/api/doc#from_dict
|
||||
"""
|
||||
if self.length != 0:
|
||||
raise ValueError(Errors.E033.format(length=self.length))
|
||||
|
|
|
@ -94,7 +94,6 @@ cdef class Span:
|
|||
kb_id (uint64): An identifier from a Knowledge Base to capture the meaning of a named entity.
|
||||
vector (ndarray[ndim=1, dtype='float32']): A meaning representation
|
||||
of the span.
|
||||
RETURNS (Span): The newly constructed object.
|
||||
|
||||
DOCS: https://spacy.io/api/span#init
|
||||
"""
|
||||
|
|
|
@ -7,7 +7,7 @@ import importlib.util
|
|||
import re
|
||||
from pathlib import Path
|
||||
import thinc
|
||||
from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer
|
||||
from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer, Model
|
||||
import functools
|
||||
import itertools
|
||||
import numpy.random
|
||||
|
@ -24,6 +24,8 @@ import tempfile
|
|||
import shutil
|
||||
import shlex
|
||||
import inspect
|
||||
from thinc.types import Unserializable
|
||||
|
||||
|
||||
try:
|
||||
import cupy.random
|
||||
|
@ -187,6 +189,20 @@ def get_module_path(module: ModuleType) -> Path:
|
|||
return Path(sys.modules[module.__module__].__file__).parent
|
||||
|
||||
|
||||
def load_vectors_into_model(
|
||||
nlp: "Language", name: Union[str, Path], *, add_strings=True
|
||||
) -> None:
|
||||
"""Load word vectors from an installed model or path into a model instance."""
|
||||
vectors_nlp = load_model(name)
|
||||
nlp.vocab.vectors = vectors_nlp.vocab.vectors
|
||||
if add_strings:
|
||||
# I guess we should add the strings from the vectors_nlp model?
|
||||
# E.g. if someone does a similarity query, they might expect the strings.
|
||||
for key in nlp.vocab.vectors.key2row:
|
||||
if key in vectors_nlp.vocab.strings:
|
||||
nlp.vocab.strings.add(vectors_nlp.vocab.strings[key])
|
||||
|
||||
|
||||
def load_model(
|
||||
name: Union[str, Path],
|
||||
disable: Iterable[str] = tuple(),
|
||||
|
@ -1184,22 +1200,6 @@ class DummyTokenizer:
|
|||
return self
|
||||
|
||||
|
||||
def link_vectors_to_models(vocab: "Vocab") -> None:
|
||||
vectors = vocab.vectors
|
||||
if vectors.name is None:
|
||||
vectors.name = VECTORS_KEY
|
||||
if vectors.data.size != 0:
|
||||
warnings.warn(Warnings.W020.format(shape=vectors.data.shape))
|
||||
for word in vocab:
|
||||
if word.orth in vectors.key2row:
|
||||
word.rank = vectors.key2row[word.orth]
|
||||
else:
|
||||
word.rank = 0
|
||||
|
||||
|
||||
VECTORS_KEY = "spacy_pretrained_vectors"
|
||||
|
||||
|
||||
def create_default_optimizer() -> Optimizer:
|
||||
# TODO: Do we still want to allow env_opt?
|
||||
learn_rate = env_opt("learn_rate", 0.001)
|
||||
|
|
|
@ -58,7 +58,6 @@ cdef class Vectors:
|
|||
data (numpy.ndarray): The vector data.
|
||||
keys (iterable): A sequence of keys, aligned with the data.
|
||||
name (str): A name to identify the vectors table.
|
||||
RETURNS (Vectors): The newly created object.
|
||||
|
||||
DOCS: https://spacy.io/api/vectors#init
|
||||
"""
|
||||
|
|
|
@ -16,7 +16,7 @@ from .errors import Errors
|
|||
from .lemmatizer import Lemmatizer
|
||||
from .attrs import intify_attrs, NORM, IS_STOP
|
||||
from .vectors import Vectors
|
||||
from .util import link_vectors_to_models, registry
|
||||
from .util import registry
|
||||
from .lookups import Lookups, load_lookups
|
||||
from . import util
|
||||
from .lang.norm_exceptions import BASE_NORMS
|
||||
|
@ -74,7 +74,6 @@ cdef class Vocab:
|
|||
lookups (Lookups): Container for large lookup tables and dictionaries.
|
||||
oov_prob (float): Default OOV probability.
|
||||
vectors_name (unicode): Optional name to identify the vectors table.
|
||||
RETURNS (Vocab): The newly constructed object.
|
||||
"""
|
||||
lex_attr_getters = lex_attr_getters if lex_attr_getters is not None else {}
|
||||
if lookups in (None, True, False):
|
||||
|
@ -345,7 +344,6 @@ cdef class Vocab:
|
|||
synonym = self.strings[syn_keys[i][0]]
|
||||
score = scores[i][0]
|
||||
remap[word] = (synonym, score)
|
||||
link_vectors_to_models(self)
|
||||
return remap
|
||||
|
||||
def get_vector(self, orth, minn=None, maxn=None):
|
||||
|
@ -440,7 +438,7 @@ cdef class Vocab:
|
|||
orth = self.strings.add(orth)
|
||||
return orth in self.vectors
|
||||
|
||||
def to_disk(self, path, exclude=tuple()):
|
||||
def to_disk(self, path, *, exclude=tuple()):
|
||||
"""Save the current state to a directory.
|
||||
|
||||
path (unicode or Path): A path to a directory, which will be created if
|
||||
|
@ -460,7 +458,7 @@ cdef class Vocab:
|
|||
if "lookups" not in "exclude" and self.lookups is not None:
|
||||
self.lookups.to_disk(path)
|
||||
|
||||
def from_disk(self, path, exclude=tuple()):
|
||||
def from_disk(self, path, *, exclude=tuple()):
|
||||
"""Loads state from a directory. Modifies the object in place and
|
||||
returns it.
|
||||
|
||||
|
@ -477,8 +475,6 @@ cdef class Vocab:
|
|||
if "vectors" not in exclude:
|
||||
if self.vectors is not None:
|
||||
self.vectors.from_disk(path, exclude=["strings"])
|
||||
if self.vectors.name is not None:
|
||||
link_vectors_to_models(self)
|
||||
if "lookups" not in exclude:
|
||||
self.lookups.from_disk(path)
|
||||
if "lexeme_norm" in self.lookups:
|
||||
|
@ -489,7 +485,7 @@ cdef class Vocab:
|
|||
self._by_orth = PreshMap()
|
||||
return self
|
||||
|
||||
def to_bytes(self, exclude=tuple()):
|
||||
def to_bytes(self, *, exclude=tuple()):
|
||||
"""Serialize the current state to a binary string.
|
||||
|
||||
exclude (list): String names of serialization fields to exclude.
|
||||
|
@ -510,7 +506,7 @@ cdef class Vocab:
|
|||
}
|
||||
return util.to_bytes(getters, exclude)
|
||||
|
||||
def from_bytes(self, bytes_data, exclude=tuple()):
|
||||
def from_bytes(self, bytes_data, *, exclude=tuple()):
|
||||
"""Load state from a binary string.
|
||||
|
||||
bytes_data (bytes): The data to load from.
|
||||
|
@ -538,8 +534,6 @@ cdef class Vocab:
|
|||
)
|
||||
self.length = 0
|
||||
self._by_orth = PreshMap()
|
||||
if self.vectors.name is not None:
|
||||
link_vectors_to_models(self)
|
||||
return self
|
||||
|
||||
def _reset_cache(self, keys, strings):
|
||||
|
|
|
@ -4,6 +4,7 @@ teaser: Pre-defined model architectures included with the core library
|
|||
source: spacy/ml/models
|
||||
menu:
|
||||
- ['Tok2Vec', 'tok2vec']
|
||||
- ['Transformers', 'transformers']
|
||||
- ['Parser & NER', 'parser']
|
||||
- ['Text Classification', 'textcat']
|
||||
- ['Entity Linking', 'entitylinker']
|
||||
|
@ -13,7 +14,7 @@ TODO: intro and how architectures work, link to
|
|||
[`registry`](/api/top-level#registry),
|
||||
[custom models](/usage/training#custom-models) usage etc.
|
||||
|
||||
## Tok2Vec architectures {#tok2vec source="spacy/ml/models/tok2vec.py"}}
|
||||
## Tok2Vec architectures {#tok2vec source="spacy/ml/models/tok2vec.py"}
|
||||
|
||||
### spacy.HashEmbedCNN.v1 {#HashEmbedCNN}
|
||||
|
||||
|
@ -21,12 +22,61 @@ TODO: intro and how architectures work, link to
|
|||
|
||||
### spacy.HashCharEmbedBiLSTM.v1 {#HashCharEmbedBiLSTM}
|
||||
|
||||
## Transformer architectures {#transformers source="github.com/explosion/spacy-transformers/blob/master/spacy_transformers/architectures.py"}
|
||||
|
||||
The following architectures are provided by the package
|
||||
[`spacy-transformers`](https://github.com/explosion/spacy-transformers). See the
|
||||
[usage documentation](/usage/transformers) for how to integrate the
|
||||
architectures into your training config.
|
||||
|
||||
### spacy-transformers.TransformerModel.v1 {#TransformerModel}
|
||||
|
||||
<!-- TODO: description -->
|
||||
|
||||
> #### Example Config
|
||||
>
|
||||
> ```ini
|
||||
> [model]
|
||||
> @architectures = "spacy-transformers.TransformerModel.v1"
|
||||
> name = "roberta-base"
|
||||
> tokenizer_config = {"use_fast": true}
|
||||
>
|
||||
> [model.get_spans]
|
||||
> @span_getters = "strided_spans.v1"
|
||||
> window = 128
|
||||
> stride = 96
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ------------------ | ---------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `name` | str | Any model name that can be loaded by [`transformers.AutoModel`](https://huggingface.co/transformers/model_doc/auto.html#transformers.AutoModel). |
|
||||
| `get_spans` | `Callable` | Function that takes a batch of [`Doc`](/api/doc) object and returns lists of [`Span`](/api) objects to process by the transformer. [See here](/api/transformer#span_getters) for built-in options and examples. |
|
||||
| `tokenizer_config` | `Dict[str, Any]` | Tokenizer settings passed to [`transformers.AutoTokenizer`](https://huggingface.co/transformers/model_doc/auto.html#transformers.AutoTokenizer). |
|
||||
|
||||
### spacy-transformers.Tok2VecListener.v1 {#Tok2VecListener}
|
||||
|
||||
<!-- TODO: description -->
|
||||
|
||||
> #### Example Config
|
||||
>
|
||||
> ```ini
|
||||
> [model]
|
||||
> @architectures = "spacy-transformers.Tok2VecListener.v1"
|
||||
> grad_factor = 1.0
|
||||
>
|
||||
> [model.pooling]
|
||||
> @layers = "reduce_mean.v1"
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ------------- | ------------------------- | ---------------------------------------------------------------------------------------------- |
|
||||
| `grad_factor` | float | Factor for weighting the gradient if multiple components listen to the same transformer model. |
|
||||
| `pooling` | `Model[Ragged, Floats2d]` | Pooling layer to determine how the vector for each spaCy token will be computed. |
|
||||
|
||||
## Parser & NER architectures {#parser source="spacy/ml/models/parser.py"}
|
||||
|
||||
### spacy.TransitionBasedParser.v1 {#TransitionBasedParser}
|
||||
|
||||
<!-- TODO: intro -->
|
||||
|
||||
> #### Example Config
|
||||
>
|
||||
> ```ini
|
||||
|
|
|
@ -13,25 +13,84 @@ datasets in the [DocBin](/api/docbin) (`.spacy`) format.
|
|||
|
||||
Create a `Corpus`. The input data can be a file or a directory of files.
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | ------------ | ---------------------------------------------------------------- |
|
||||
| `train` | str / `Path` | Training data (`.spacy` file or directory of `.spacy` files). |
|
||||
| `dev` | str / `Path` | Development data (`.spacy` file or directory of `.spacy` files). |
|
||||
| `limit` | int | Maximum number of examples returned. |
|
||||
| **RETURNS** | `Corpus` | The newly constructed object. |
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> from spacy.gold import Corpus
|
||||
>
|
||||
> corpus = Corpus("./train.spacy", "./dev.spacy")
|
||||
> ```
|
||||
|
||||
<!-- TODO: document remaining methods / decide which to document -->
|
||||
|
||||
## Corpus.walk_corpus {#walk_corpus tag="staticmethod"}
|
||||
|
||||
## Corpus.make_examples {#make_examples tag="method"}
|
||||
|
||||
## Corpus.make_examples_gold_preproc {#make_examples_gold_preproc tag="method"}
|
||||
|
||||
## Corpus.read_docbin {#read_docbin tag="method"}
|
||||
|
||||
## Corpus.count_train {#count_train tag="method"}
|
||||
| Name | Type | Description |
|
||||
| ------- | ------------ | ---------------------------------------------------------------- |
|
||||
| `train` | str / `Path` | Training data (`.spacy` file or directory of `.spacy` files). |
|
||||
| `dev` | str / `Path` | Development data (`.spacy` file or directory of `.spacy` files). |
|
||||
| `limit` | int | Maximum number of examples returned. `0` for no limit (default). |
|
||||
|
||||
## Corpus.train_dataset {#train_dataset tag="method"}
|
||||
|
||||
Yield examples from the training data.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> from spacy.gold import Corpus
|
||||
> import spacy
|
||||
>
|
||||
> corpus = Corpus("./train.spacy", "./dev.spacy")
|
||||
> nlp = spacy.blank("en")
|
||||
> train_data = corpus.train_dataset(nlp)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| -------------- | ---------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| `nlp` | `Language` | The current `nlp` object. |
|
||||
| _keyword-only_ | | |
|
||||
| `shuffle` | bool | Whether to shuffle the examples. Defaults to `True`. |
|
||||
| `gold_preproc` | bool | Whether to train on gold-standard sentences and tokens. Defaults to `False`. |
|
||||
| `max_length` | int | Maximum document length. Longer documents will be split into sentences, if sentence boundaries are available. `0` for no limit (default). |
|
||||
| **YIELDS** | `Example` | The examples. |
|
||||
|
||||
## Corpus.dev_dataset {#dev_dataset tag="method"}
|
||||
|
||||
Yield examples from the development data.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> from spacy.gold import Corpus
|
||||
> import spacy
|
||||
>
|
||||
> corpus = Corpus("./train.spacy", "./dev.spacy")
|
||||
> nlp = spacy.blank("en")
|
||||
> dev_data = corpus.dev_dataset(nlp)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| -------------- | ---------- | ---------------------------------------------------------------------------- |
|
||||
| `nlp` | `Language` | The current `nlp` object. |
|
||||
| _keyword-only_ | | |
|
||||
| `gold_preproc` | bool | Whether to train on gold-standard sentences and tokens. Defaults to `False`. |
|
||||
| **YIELDS** | `Example` | The examples. |
|
||||
|
||||
## Corpus.count_train {#count_train tag="method"}
|
||||
|
||||
Get the word count of all training examples.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> from spacy.gold import Corpus
|
||||
> import spacy
|
||||
>
|
||||
> corpus = Corpus("./train.spacy", "./dev.spacy")
|
||||
> nlp = spacy.blank("en")
|
||||
> word_count = corpus.count_train(nlp)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | ---------- | ------------------------- |
|
||||
| `nlp` | `Language` | The current `nlp` object. |
|
||||
| **RETURNS** | int | The word count. |
|
||||
|
||||
<!-- TODO: document remaining methods? / decide which to document -->
|
||||
|
|
|
@ -87,13 +87,12 @@ Create a `Token` object from a `TokenC*` pointer.
|
|||
> token = Token.cinit(&doc.c[3], doc, 3)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | --------- | ------------------------------------------------------------ |
|
||||
| `vocab` | `Vocab` | A reference to the shared `Vocab`. |
|
||||
| `c` | `TokenC*` | A pointer to a [`TokenC`](/api/cython-structs#tokenc)struct. |
|
||||
| `offset` | `int` | The offset of the token within the document. |
|
||||
| `doc` | `Doc` | The parent document. |
|
||||
| **RETURNS** | `Token` | The newly constructed object. |
|
||||
| Name | Type | Description |
|
||||
| -------- | --------- | ------------------------------------------------------------ |
|
||||
| `vocab` | `Vocab` | A reference to the shared `Vocab`. |
|
||||
| `c` | `TokenC*` | A pointer to a [`TokenC`](/api/cython-structs#tokenc)struct. |
|
||||
| `offset` | `int` | The offset of the token within the document. |
|
||||
| `doc` | `Doc` | The parent document. |
|
||||
|
||||
## Span {#span tag="cdef class" source="spacy/tokens/span.pxd"}
|
||||
|
||||
|
|
|
@ -121,7 +121,7 @@ applied to the `Doc` in order. Both [`__call__`](/api/dependencyparser#call) and
|
|||
|
||||
## DependencyParser.begin_training {#begin_training tag="method"}
|
||||
|
||||
Initialize the pipe for training, using data examples if available. Return an
|
||||
Initialize the pipe for training, using data examples if available. Returns an
|
||||
[`Optimizer`](https://thinc.ai/docs/api-optimizers) object.
|
||||
|
||||
> #### Example
|
||||
|
@ -290,10 +290,11 @@ Serialize the pipe to disk.
|
|||
> parser.to_disk("/path/to/parser")
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| --------- | --------------- | --------------------------------------------------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------- | --------------------------------------------------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
|
||||
## DependencyParser.from_disk {#from_disk tag="method"}
|
||||
|
||||
|
@ -306,11 +307,12 @@ Load the pipe from disk. Modifies the object in place and returns it.
|
|||
> parser.from_disk("/path/to/parser")
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | ------------------ | -------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `DependencyParser` | The modified `DependencyParser` object. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | ------------------ | -------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `DependencyParser` | The modified `DependencyParser` object. |
|
||||
|
||||
## DependencyParser.to_bytes {#to_bytes tag="method"}
|
||||
|
||||
|
@ -323,10 +325,11 @@ Load the pipe from disk. Modifies the object in place and returns it.
|
|||
|
||||
Serialize the pipe to a bytestring.
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | --------------- | ------------------------------------------------------------------------- |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | bytes | The serialized form of the `DependencyParser` object. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------- | ------------------------------------------------------------------------- |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | bytes | The serialized form of the `DependencyParser` object. |
|
||||
|
||||
## DependencyParser.from_bytes {#from_bytes tag="method"}
|
||||
|
||||
|
@ -340,11 +343,12 @@ Load the pipe from a bytestring. Modifies the object in place and returns it.
|
|||
> parser.from_bytes(parser_bytes)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ------------ | ------------------ | ------------------------------------------------------------------------- |
|
||||
| `bytes_data` | bytes | The data to load from. |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `DependencyParser` | The `DependencyParser` object. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | ------------------ | ------------------------------------------------------------------------- |
|
||||
| `bytes_data` | bytes | The data to load from. |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `DependencyParser` | The `DependencyParser` object. |
|
||||
|
||||
## DependencyParser.labels {#labels tag="property"}
|
||||
|
||||
|
|
|
@ -30,12 +30,11 @@ Construct a `Doc` object. The most common way to get a `Doc` object is via the
|
|||
> doc = Doc(nlp.vocab, words=words, spaces=spaces)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | -------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `vocab` | `Vocab` | A storage container for lexical types. |
|
||||
| `words` | iterable | A list of strings to add to the container. |
|
||||
| `spaces` | iterable | A list of boolean values indicating whether each word has a subsequent space. Must have the same length as `words`, if specified. Defaults to a sequence of `True`. |
|
||||
| **RETURNS** | `Doc` | The newly constructed object. |
|
||||
| Name | Type | Description |
|
||||
| -------- | -------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `vocab` | `Vocab` | A storage container for lexical types. |
|
||||
| `words` | iterable | A list of strings to add to the container. |
|
||||
| `spaces` | iterable | A list of boolean values indicating whether each word has a subsequent space. Must have the same length as `words`, if specified. Defaults to a sequence of `True`. |
|
||||
|
||||
## Doc.\_\_getitem\_\_ {#getitem tag="method"}
|
||||
|
||||
|
@ -386,10 +385,11 @@ Save the current state to a directory.
|
|||
> doc.to_disk("/path/to/doc")
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| --------- | ------------ | --------------------------------------------------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
|
||||
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------- | --------------------------------------------------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
|
||||
## Doc.from_disk {#from_disk tag="method" new="2"}
|
||||
|
||||
|
@ -403,11 +403,12 @@ Loads state from a directory. Modifies the object in place and returns it.
|
|||
> doc = Doc(Vocab()).from_disk("/path/to/doc")
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | ------------ | -------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
|
||||
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `Doc` | The modified `Doc` object. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------- | -------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `Doc` | The modified `Doc` object. |
|
||||
|
||||
## Doc.to_bytes {#to_bytes tag="method"}
|
||||
|
||||
|
@ -420,10 +421,11 @@ Serialize, i.e. export the document contents to a binary string.
|
|||
> doc_bytes = doc.to_bytes()
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | ----- | ------------------------------------------------------------------------- |
|
||||
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | bytes | A losslessly serialized copy of the `Doc`, including all annotations. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------- | ------------------------------------------------------------------------- |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | bytes | A losslessly serialized copy of the `Doc`, including all annotations. |
|
||||
|
||||
## Doc.from_bytes {#from_bytes tag="method"}
|
||||
|
||||
|
@ -439,11 +441,12 @@ Deserialize, i.e. import the document contents from a binary string.
|
|||
> assert doc.text == doc2.text
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | ----- | ------------------------------------------------------------------------- |
|
||||
| `data` | bytes | The string to load from. |
|
||||
| `exclude` | list | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `Doc` | The `Doc` object. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------- | ------------------------------------------------------------------------- |
|
||||
| `data` | bytes | The string to load from. |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `Doc` | The `Doc` object. |
|
||||
|
||||
## Doc.retokenize {#retokenize tag="contextmanager" new="2.1"}
|
||||
|
||||
|
|
|
@ -44,11 +44,11 @@ Create a `DocBin` object to hold serialized annotations.
|
|||
> doc_bin = DocBin(attrs=["ENT_IOB", "ENT_TYPE"])
|
||||
> ```
|
||||
|
||||
| Argument | Type | Description |
|
||||
| ----------------- | -------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| `attrs` | list | List of attributes to serialize. `ORTH` (hash of token text) and `SPACY` (whether the token is followed by whitespace) are always serialized, so they're not required. Defaults to `("ORTH", "TAG", "HEAD", "DEP", "ENT_IOB", "ENT_TYPE", "ENT_KB_ID", "LEMMA", "MORPH", "POS")`. |
|
||||
| `store_user_data` | bool | Whether to include the `Doc.user_data` and the values of custom extension attributes. Defaults to `False`. |
|
||||
| **RETURNS** | `DocBin` | The newly constructed object. |
|
||||
| Argument | Type | Description |
|
||||
| ----------------- | --------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `attrs` | `Iterable[str]` | List of attributes to serialize. `ORTH` (hash of token text) and `SPACY` (whether the token is followed by whitespace) are always serialized, so they're not required. Defaults to `("ORTH", "TAG", "HEAD", "DEP", "ENT_IOB", "ENT_TYPE", "ENT_KB_ID", "LEMMA", "MORPH", "POS")`. |
|
||||
| `store_user_data` | bool | Whether to include the `Doc.user_data` and the values of custom extension attributes. Defaults to `False`. |
|
||||
| `docs` | `Iterable[Doc]` | `Doc` objects to add on initialization. |
|
||||
|
||||
## DocBin.\_\len\_\_ {#len tag="method"}
|
||||
|
||||
|
|
|
@ -125,7 +125,7 @@ applied to the `Doc` in order. Both [`__call__`](/api/entitylinker#call) and
|
|||
|
||||
## EntityLinker.begin_training {#begin_training tag="method"}
|
||||
|
||||
Initialize the pipe for training, using data examples if available. Return an
|
||||
Initialize the pipe for training, using data examples if available. Returns an
|
||||
[`Optimizer`](https://thinc.ai/docs/api-optimizers) object. Before calling this
|
||||
method, a knowledge base should have been defined with
|
||||
[`set_kb`](/api/entitylinker#set_kb).
|
||||
|
@ -265,10 +265,11 @@ Serialize the pipe to disk.
|
|||
> entity_linker.to_disk("/path/to/entity_linker")
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| --------- | --------------- | --------------------------------------------------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------- | --------------------------------------------------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
|
||||
## EntityLinker.from_disk {#from_disk tag="method"}
|
||||
|
||||
|
@ -281,11 +282,12 @@ Load the pipe from disk. Modifies the object in place and returns it.
|
|||
> entity_linker.from_disk("/path/to/entity_linker")
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | --------------- | -------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `EntityLinker` | The modified `EntityLinker` object. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------- | -------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `EntityLinker` | The modified `EntityLinker` object. |
|
||||
|
||||
## Serialization fields {#serialization-fields}
|
||||
|
||||
|
|
|
@ -121,7 +121,7 @@ applied to the `Doc` in order. Both [`__call__`](/api/entityrecognizer#call) and
|
|||
|
||||
## EntityRecognizer.begin_training {#begin_training tag="method"}
|
||||
|
||||
Initialize the pipe for training, using data examples if available. Return an
|
||||
Initialize the pipe for training, using data examples if available. Returns an
|
||||
[`Optimizer`](https://thinc.ai/docs/api-optimizers) object.
|
||||
|
||||
> #### Example
|
||||
|
@ -289,10 +289,11 @@ Serialize the pipe to disk.
|
|||
> ner.to_disk("/path/to/ner")
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| --------- | --------------- | --------------------------------------------------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------- | --------------------------------------------------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
|
||||
## EntityRecognizer.from_disk {#from_disk tag="method"}
|
||||
|
||||
|
@ -305,11 +306,12 @@ Load the pipe from disk. Modifies the object in place and returns it.
|
|||
> ner.from_disk("/path/to/ner")
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | ------------------ | -------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `EntityRecognizer` | The modified `EntityRecognizer` object. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | ------------------ | -------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `EntityRecognizer` | The modified `EntityRecognizer` object. |
|
||||
|
||||
## EntityRecognizer.to_bytes {#to_bytes tag="method"}
|
||||
|
||||
|
@ -322,10 +324,11 @@ Load the pipe from disk. Modifies the object in place and returns it.
|
|||
|
||||
Serialize the pipe to a bytestring.
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | --------------- | ------------------------------------------------------------------------- |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | bytes | The serialized form of the `EntityRecognizer` object. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------- | ------------------------------------------------------------------------- |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | bytes | The serialized form of the `EntityRecognizer` object. |
|
||||
|
||||
## EntityRecognizer.from_bytes {#from_bytes tag="method"}
|
||||
|
||||
|
@ -339,11 +342,12 @@ Load the pipe from a bytestring. Modifies the object in place and returns it.
|
|||
> ner.from_bytes(ner_bytes)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ------------ | ------------------ | ------------------------------------------------------------------------- |
|
||||
| `bytes_data` | bytes | The data to load from. |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `EntityRecognizer` | The `EntityRecognizer` object. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | ------------------ | ------------------------------------------------------------------------- |
|
||||
| `bytes_data` | bytes | The data to load from. |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `EntityRecognizer` | The `EntityRecognizer` object. |
|
||||
|
||||
## EntityRecognizer.labels {#labels tag="property"}
|
||||
|
||||
|
|
|
@ -37,7 +37,6 @@ both documents.
|
|||
| `reference` | `Doc` | The document containing gold-standard annotations. Can not be `None`. |
|
||||
| _keyword-only_ | | |
|
||||
| `alignment` | `Alignment` | An object holding the alignment between the tokens of the `predicted` and `reference` documents. |
|
||||
| **RETURNS** | `Example` | The newly constructed object. |
|
||||
|
||||
## Example.from_dict {#from_dict tag="classmethod"}
|
||||
|
||||
|
|
|
@ -27,11 +27,10 @@ Create the knowledge base.
|
|||
> kb = KnowledgeBase(vocab=vocab, entity_vector_length=64)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ---------------------- | --------------- | ---------------------------------------- |
|
||||
| `vocab` | `Vocab` | A `Vocab` object. |
|
||||
| `entity_vector_length` | int | Length of the fixed-size entity vectors. |
|
||||
| **RETURNS** | `KnowledgeBase` | The newly constructed object. |
|
||||
| Name | Type | Description |
|
||||
| ---------------------- | ------- | ---------------------------------------- |
|
||||
| `vocab` | `Vocab` | A `Vocab` object. |
|
||||
| `entity_vector_length` | int | Length of the fixed-size entity vectors. |
|
||||
|
||||
## KnowledgeBase.entity_vector_length {#entity_vector_length tag="property"}
|
||||
|
||||
|
@ -255,7 +254,6 @@ but instead these objects are returned by the
|
|||
| `entity_freq` | float | The entity frequency as recorded in the KB. |
|
||||
| `alias_hash` | int | The hash of the textual mention or alias. |
|
||||
| `prior_prob` | float | The prior probability of the `alias` referring to the `entity` |
|
||||
| **RETURNS** | `Candidate` | The newly constructed object. |
|
||||
|
||||
## Candidate attributes {#candidate_attributes}
|
||||
|
||||
|
|
|
@ -15,6 +15,58 @@ the tagger or parser that are called on a document in order. You can also add
|
|||
your own processing pipeline components that take a `Doc` object, modify it and
|
||||
return it.
|
||||
|
||||
## Language.\_\_init\_\_ {#init tag="method"}
|
||||
|
||||
Initialize a `Language` object.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> # Construction from subclass
|
||||
> from spacy.lang.en import English
|
||||
> nlp = English()
|
||||
>
|
||||
> # Construction from scratch
|
||||
> from spacy.vocab import Vocab
|
||||
> from spacy.language import Language
|
||||
> nlp = Language(Vocab())
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ------------------ | ----------- | ------------------------------------------------------------------------------------------ |
|
||||
| `vocab` | `Vocab` | A `Vocab` object. If `True`, a vocab is created using the default language data settings. |
|
||||
| _keyword-only_ | | |
|
||||
| `max_length` | int | Maximum number of characters allowed in a single text. Defaults to `10 ** 6`. |
|
||||
| `meta` | dict | Custom meta data for the `Language` class. Is written to by models to add model meta data. |
|
||||
| `create_tokenizer` | `Callable` | Optional function that receives the `nlp` object and returns a tokenizer. |
|
||||
|
||||
## Language.from_config {#from_config tag="classmethod"}
|
||||
|
||||
Create a `Language` object from a loaded config. Will set up the tokenizer and
|
||||
language data, add pipeline components based on the pipeline and components
|
||||
define in the config and validate the results. If no config is provided, the
|
||||
default config of the given language is used. This is also how spaCy loads a
|
||||
model under the hood based on its [`config.cfg`](/api/data-formats#config).
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> from thinc.api import Config
|
||||
> from spacy.language import Language
|
||||
>
|
||||
> config = Config().from_disk("./config.cfg")
|
||||
> nlp = Language.from_config(config)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| -------------- | ---------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `config` | `Dict[str, Any]` / [`Config`](https://thinc.ai/docs/api-config#config) | The loaded config. |
|
||||
| _keyword-only_ | |
|
||||
| `disable` | `Iterable[str]` | List of pipeline component names to disable. |
|
||||
| `auto_fill` | bool | Whether to automatically fill in missing values in the config, based on defaults and function argument annotations. Defaults to `True`. |
|
||||
| `validate` | bool | Whether to validate the component config and arguments against the types expected by the factory. Defaults to `True`. |
|
||||
| **RETURNS** | `Language` | The initialized object. |
|
||||
|
||||
## Language.component {#component tag="classmethod" new="3"}
|
||||
|
||||
Register a custom pipeline component under a given name. This allows
|
||||
|
@ -101,57 +153,6 @@ examples, see the
|
|||
| `default_score_weights` | `Dict[str, float]` | The scores to report during training, and their default weight towards the final score used to select the best model. Weights should sum to `1.0` per component and will be combined and normalized for the whole pipeline. |
|
||||
| `func` | `Optional[Callable]` | Optional function if not used a a decorator. |
|
||||
|
||||
## Language.\_\_init\_\_ {#init tag="method"}
|
||||
|
||||
Initialize a `Language` object.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> from spacy.vocab import Vocab
|
||||
> from spacy.language import Language
|
||||
> nlp = Language(Vocab())
|
||||
>
|
||||
> from spacy.lang.en import English
|
||||
> nlp = English()
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ------------------ | ----------- | ------------------------------------------------------------------------------------------ |
|
||||
| `vocab` | `Vocab` | A `Vocab` object. If `True`, a vocab is created using the default language data settings. |
|
||||
| _keyword-only_ | | |
|
||||
| `max_length` | int | Maximum number of characters allowed in a single text. Defaults to `10 ** 6`. |
|
||||
| `meta` | dict | Custom meta data for the `Language` class. Is written to by models to add model meta data. |
|
||||
| `create_tokenizer` | `Callable` | Optional function that receives the `nlp` object and returns a tokenizer. |
|
||||
| **RETURNS** | `Language` | The newly constructed object. |
|
||||
|
||||
## Language.from_config {#from_config tag="classmethod"}
|
||||
|
||||
Create a `Language` object from a loaded config. Will set up the tokenizer and
|
||||
language data, add pipeline components based on the pipeline and components
|
||||
define in the config and validate the results. If no config is provided, the
|
||||
default config of the given language is used. This is also how spaCy loads a
|
||||
model under the hood based on its [`config.cfg`](/api/data-formats#config).
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> from thinc.api import Config
|
||||
> from spacy.language import Language
|
||||
>
|
||||
> config = Config().from_disk("./config.cfg")
|
||||
> nlp = Language.from_config(config)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| -------------- | ---------------------------------------------------------------------- | --------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `config` | `Dict[str, Any]` / [`Config`](https://thinc.ai/docs/api-config#config) | The loaded config. |
|
||||
| _keyword-only_ | |
|
||||
| `disable` | `Iterable[str]` | List of pipeline component names to disable. |
|
||||
| `auto_fill` | bool | Whether to automatically fill in missing values in the config, based on defaults and function argument annotations. Defaults to `True`. |
|
||||
| `validate` | bool | Whether to validate the component config and arguments against the types expected by the factory. Defaults to `True`. |
|
||||
| **RETURNS** | `Language` | The initialized object. |
|
||||
|
||||
## Language.\_\_call\_\_ {#call tag="method"}
|
||||
|
||||
Apply the pipeline to some text. The text can span multiple sentences, and can
|
||||
|
@ -164,11 +165,13 @@ contain arbitrary whitespace. Alignment into the original string is preserved.
|
|||
> assert (doc[0].text, doc[0].head.tag_) == ("An", "NN")
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | ----------- | --------------------------------------------------------------------------------- |
|
||||
| `text` | str | The text to be processed. |
|
||||
| `disable` | `List[str]` | Names of pipeline components to [disable](/usage/processing-pipelines#disabling). |
|
||||
| **RETURNS** | `Doc` | A container for accessing the annotations. |
|
||||
| Name | Type | Description |
|
||||
| --------------- | ----------------- | ------------------------------------------------------------------------------------------------------ |
|
||||
| `text` | str | The text to be processed. |
|
||||
| _keyword-only_ | | |
|
||||
| `disable` | `List[str]` | Names of pipeline components to [disable](/usage/processing-pipelines#disabling). |
|
||||
| `component_cfg` | `Dict[str, dict]` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. |
|
||||
| **RETURNS** | [`Doc`](/api/doc) | A container for accessing the annotations. |
|
||||
|
||||
## Language.pipe {#pipe tag="method"}
|
||||
|
||||
|
@ -183,15 +186,57 @@ more efficient than processing texts one-by-one.
|
|||
> assert doc.is_parsed
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| -------------------------------------------- | ----------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `texts` | `Iterable[str]` | A sequence of strings. |
|
||||
| `as_tuples` | bool | If set to `True`, inputs should be a sequence of `(text, context)` tuples. Output will then be a sequence of `(doc, context)` tuples. Defaults to `False`. |
|
||||
| `batch_size` | int | The number of texts to buffer. |
|
||||
| `disable` | `List[str]` | Names of pipeline components to [disable](/usage/processing-pipelines#disabling). |
|
||||
| `component_cfg` <Tag variant="new">2.1</Tag> | `Dict[str, Dict]` | Config parameters for specific pipeline components, keyed by component name. |
|
||||
| `n_process` <Tag variant="new">2.2.2</Tag> | int | Number of processors to use, only supported in Python 3. Defaults to `1`. |
|
||||
| **YIELDS** | `Doc` | Documents in the order of the original text. |
|
||||
| Name | Type | Description |
|
||||
| ------------------------------------------ | ----------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `texts` | `Iterable[str]` | A sequence of strings. |
|
||||
| _keyword-only_ | | |
|
||||
| `as_tuples` | bool | If set to `True`, inputs should be a sequence of `(text, context)` tuples. Output will then be a sequence of `(doc, context)` tuples. Defaults to `False`. |
|
||||
| `batch_size` | int | The number of texts to buffer. |
|
||||
| `disable` | `List[str]` | Names of pipeline components to [disable](/usage/processing-pipelines#disabling). |
|
||||
| `cleanup` | bool | If `True`, unneeded strings are freed to control memory use. Experimental. |
|
||||
| `component_cfg` | `Dict[str, dict]` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. |
|
||||
| `n_process` <Tag variant="new">2.2.2</Tag> | int | Number of processors to use, only supported in Python 3. Defaults to `1`. |
|
||||
| **YIELDS** | `Doc` | Documents in the order of the original text. |
|
||||
|
||||
## Language.begin_training {#begin_training tag="method"}
|
||||
|
||||
Initialize the pipe for training, using data examples if available. Returns an
|
||||
[`Optimizer`](https://thinc.ai/docs/api-optimizers) object.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> optimizer = nlp.begin_training(get_examples)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------------------------------------------- | ----------------------------------------------------------------------------------------------------------- |
|
||||
| `get_examples` | `Callable[[], Iterable[Example]]` | Optional function that returns gold-standard annotations in the form of [`Example`](/api/example) objects. |
|
||||
| _keyword-only_ | | |
|
||||
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | An optional optimizer. Will be created via [`create_optimizer`](/api/language#create_optimizer) if not set. |
|
||||
| **RETURNS** | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
|
||||
|
||||
## Language.resume_training {#resume_training tag="method,experimental" new="3"}
|
||||
|
||||
Continue training a pretrained model. Create and return an optimizer, and
|
||||
initialize "rehearsal" for any pipeline component that has a `rehearse` method.
|
||||
Rehearsal is used to prevent models from "forgetting" their initialized
|
||||
"knowledge". To perform rehearsal, collect samples of text you want the models
|
||||
to retain performance on, and call [`nlp.rehearse`](/api/language#rehearse) with
|
||||
a batch of [Example](/api/example) objects.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> optimizer = nlp.resume_training()
|
||||
> nlp.rehearse(examples, sgd=optimizer)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------------------------------------------- | ----------------------------------------------------------------------------------------------------------- |
|
||||
| _keyword-only_ | | |
|
||||
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | An optional optimizer. Will be created via [`create_optimizer`](/api/language#create_optimizer) if not set. |
|
||||
| **RETURNS** | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
|
||||
|
||||
## Language.update {#update tag="method"}
|
||||
|
||||
|
@ -206,15 +251,37 @@ Update the models in the pipeline.
|
|||
> nlp.update([example], sgd=optimizer)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| -------------------------------------------- | ------------------- | ---------------------------------------------------------------------------- |
|
||||
| `examples` | `Iterable[Example]` | A batch of `Example` objects to learn from. |
|
||||
| _keyword-only_ | | |
|
||||
| `drop` | float | The dropout rate. |
|
||||
| `sgd` | `Optimizer` | An [`Optimizer`](https://thinc.ai/docs/api-optimizers) object. |
|
||||
| `losses` | `Dict[str, float]` | Dictionary to update with the loss, keyed by pipeline component. |
|
||||
| `component_cfg` <Tag variant="new">2.1</Tag> | `Dict[str, Dict]` | Config parameters for specific pipeline components, keyed by component name. |
|
||||
| **RETURNS** | `Dict[str, float]` | The updated `losses` dictionary. |
|
||||
| Name | Type | Description |
|
||||
| --------------- | --------------------------------------------------- | ------------------------------------------------------------------------------------------------------ |
|
||||
| `examples` | `Iterable[Example]` | A batch of `Example` objects to learn from. |
|
||||
| _keyword-only_ | | |
|
||||
| `drop` | float | The dropout rate. |
|
||||
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
|
||||
| `losses` | `Dict[str, float]` | Dictionary to update with the loss, keyed by pipeline component. |
|
||||
| `component_cfg` | `Dict[str, dict]` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. |
|
||||
| **RETURNS** | `Dict[str, float]` | The updated `losses` dictionary. |
|
||||
|
||||
## Language.rehearse {#rehearse tag="method,experimental"}
|
||||
|
||||
Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the
|
||||
current model to make predictions similar to an initial model, to try to address
|
||||
the "catastrophic forgetting" problem. This feature is experimental.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> optimizer = nlp.resume_training()
|
||||
> losses = nlp.rehearse(examples, sgd=optimizer)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------------------------------------------- | ----------------------------------------------------------------------------------------- |
|
||||
| `examples` | `Iterable[Example]` | A batch of [`Example`](/api/example) objects to learn from. |
|
||||
| _keyword-only_ | | |
|
||||
| `drop` | float | The dropout rate. |
|
||||
| `sgd` | [`Optimizer`](https://thinc.ai/docs/api-optimizers) | The optimizer. |
|
||||
| `losses` | `Dict[str, float]` | Optional record of the loss during training. Updated using the component name as the key. |
|
||||
| **RETURNS** | `Dict[str, float]` | The updated `losses` dictionary. |
|
||||
|
||||
## Language.evaluate {#evaluate tag="method"}
|
||||
|
||||
|
@ -227,33 +294,15 @@ Evaluate a model's pipeline components.
|
|||
> print(scores)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| -------------------------------------------- | ------------------------------- | ------------------------------------------------------------------------------------- |
|
||||
| `examples` | `Iterable[Example]` | A batch of [`Example`](/api/example) objects to learn from. |
|
||||
| `verbose` | bool | Print debugging information. |
|
||||
| `batch_size` | int | The batch size to use. |
|
||||
| `scorer` | `Scorer` | Optional [`Scorer`](/api/scorer) to use. If not passed in, a new one will be created. |
|
||||
| `component_cfg` <Tag variant="new">2.1</Tag> | `Dict[str, Dict]` | Config parameters for specific pipeline components, keyed by component name. |
|
||||
| **RETURNS** | `Dict[str, Union[float, Dict]]` | A dictionary of evaluation scores. |
|
||||
|
||||
## Language.begin_training {#begin_training tag="method"}
|
||||
|
||||
Allocate models, pre-process training data and acquire an
|
||||
[`Optimizer`](https://thinc.ai/docs/api-optimizers).
|
||||
|
||||
> #### Example
|
||||
>
|
||||
> ```python
|
||||
> optimizer = nlp.begin_training(get_examples)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| -------------------------------------------- | ------------------- | ------------------------------------------------------------------------------------------------------------------ |
|
||||
| `get_examples` | `Iterable[Example]` | Optional gold-standard annotations in the form of [`Example`](/api/example) objects. |
|
||||
| `sgd` | `Optimizer` | An optional [`Optimizer`](https://thinc.ai/docs/api-optimizers) object. If not set, a default one will be created. |
|
||||
| `component_cfg` <Tag variant="new">2.1</Tag> | `Dict[str, Dict]` | Config parameters for specific pipeline components, keyed by component name. |
|
||||
| `**cfg` | - | Config parameters (sent to all components). |
|
||||
| **RETURNS** | `Optimizer` | An optimizer. |
|
||||
| Name | Type | Description |
|
||||
| --------------- | ------------------------------- | ------------------------------------------------------------------------------------------------------ |
|
||||
| `examples` | `Iterable[Example]` | A batch of [`Example`](/api/example) objects to learn from. |
|
||||
| _keyword-only_ | | |
|
||||
| `verbose` | bool | Print debugging information. |
|
||||
| `batch_size` | int | The batch size to use. |
|
||||
| `scorer` | `Scorer` | Optional [`Scorer`](/api/scorer) to use. If not passed in, a new one will be created. |
|
||||
| `component_cfg` | `Dict[str, dict]` | Optional dictionary of keyword arguments for components, keyed by component names. Defaults to `None`. |
|
||||
| **RETURNS** | `Dict[str, Union[float, dict]]` | A dictionary of evaluation scores. |
|
||||
|
||||
## Language.use_params {#use_params tag="contextmanager, method"}
|
||||
|
||||
|
@ -296,6 +345,7 @@ To create a component and add it to the pipeline, you should always use
|
|||
| ------------------------------------- | ---------------- | --------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `factory_name` | str | Name of the registered component factory. |
|
||||
| `name` | str | Optional unique name of pipeline component instance. If not set, the factory name is used. An error is raised if the name already exists in the pipeline. |
|
||||
| _keyword-only_ | | |
|
||||
| `config` <Tag variant="new">3</Tag> | `Dict[str, Any]` | Optional config parameters to use for this component. Will be merged with the `default_config` specified by the component factory. |
|
||||
| `validate` <Tag variant="new">3</Tag> | bool | Whether to validate the component config and arguments against the types expected by the factory. Defaults to `True`. |
|
||||
| **RETURNS** | callable | The pipeline component. |
|
||||
|
@ -418,10 +468,13 @@ Replace a component in the pipeline.
|
|||
> nlp.replace_pipe("parser", my_custom_parser)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | -------- | --------------------------------- |
|
||||
| `name` | str | Name of the component to replace. |
|
||||
| `component` | callable | The pipeline component to insert. |
|
||||
| Name | Type | Description |
|
||||
| ------------------------------------- | ---------------- | ------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `name` | str | Name of the component to replace. |
|
||||
| `component` | callable | The pipeline component to insert. |
|
||||
| _keyword-only_ | | |
|
||||
| `config` <Tag variant="new">3</Tag> | `Dict[str, Any]` | Optional config parameters to use for the new component. Will be merged with the `default_config` specified by the component factory. |
|
||||
| `validate` <Tag variant="new">3</Tag> | bool | Whether to validate the component config and arguments against the types expected by the factory. Defaults to `True`. |
|
||||
|
||||
## Language.rename_pipe {#rename_pipe tag="method" new="2"}
|
||||
|
||||
|
@ -492,11 +545,12 @@ As of spaCy v3.0, the `disable_pipes` method has been renamed to `select_pipes`:
|
|||
|
||||
</Infobox>
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | --------------- | ------------------------------------------------------------------------------------ |
|
||||
| `disable` | str / list | Name(s) of pipeline components to disable. |
|
||||
| `enable` | str / list | Names(s) of pipeline components that will not be disabled. |
|
||||
| **RETURNS** | `DisabledPipes` | The disabled pipes that can be restored by calling the object's `.restore()` method. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------- | ------------------------------------------------------------------------------------ |
|
||||
| _keyword-only_ | | |
|
||||
| `disable` | str / list | Name(s) of pipeline components to disable. |
|
||||
| `enable` | str / list | Names(s) of pipeline components that will not be disabled. |
|
||||
| **RETURNS** | `DisabledPipes` | The disabled pipes that can be restored by calling the object's `.restore()` method. |
|
||||
|
||||
## Language.get_factory_meta {#get_factory_meta tag="classmethod" new="3"}
|
||||
|
||||
|
@ -591,10 +645,11 @@ the model**.
|
|||
> nlp.to_disk("/path/to/models")
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| --------- | ------------ | --------------------------------------------------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
|
||||
| `exclude` | list | Names of pipeline components or [serialization fields](#serialization-fields) to exclude. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------- | --------------------------------------------------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | Names of pipeline components or [serialization fields](#serialization-fields) to exclude. |
|
||||
|
||||
## Language.from_disk {#from_disk tag="method" new="2"}
|
||||
|
||||
|
@ -616,11 +671,12 @@ loaded object.
|
|||
> nlp = English().from_disk("/path/to/en_model")
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | ------------ | ----------------------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
|
||||
| `exclude` | list | Names of pipeline components or [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `Language` | The modified `Language` object. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------- | ----------------------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | Names of pipeline components or [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `Language` | The modified `Language` object. |
|
||||
|
||||
## Language.to_bytes {#to_bytes tag="method"}
|
||||
|
||||
|
@ -632,10 +688,11 @@ Serialize the current state to a binary string.
|
|||
> nlp_bytes = nlp.to_bytes()
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | ----- | ----------------------------------------------------------------------------------------- |
|
||||
| `exclude` | list | Names of pipeline components or [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | bytes | The serialized form of the `Language` object. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------- | ----------------------------------------------------------------------------------------- |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | Names of pipeline components or [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | bytes | The serialized form of the `Language` object. |
|
||||
|
||||
## Language.from_bytes {#from_bytes tag="method"}
|
||||
|
||||
|
@ -653,11 +710,12 @@ available to the loaded object.
|
|||
> nlp2.from_bytes(nlp_bytes)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ------------ | ---------- | ----------------------------------------------------------------------------------------- |
|
||||
| `bytes_data` | bytes | The data to load from. |
|
||||
| `exclude` | list | Names of pipeline components or [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `Language` | The `Language` object. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------- | ----------------------------------------------------------------------------------------- |
|
||||
| `bytes_data` | bytes | The data to load from. |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | Names of pipeline components or [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `Language` | The `Language` object. |
|
||||
|
||||
## Attributes {#attributes}
|
||||
|
||||
|
@ -767,8 +825,8 @@ serialization by passing in the string names via the `exclude` argument.
|
|||
The `FactoryMeta` contains the information about the component and its default
|
||||
provided by the [`@Language.component`](/api/language#component) or
|
||||
[`@Language.factory`](/api/language#factory) decorator. It's created whenever a
|
||||
component is added to the pipeline and stored on the `Language` class for each
|
||||
component instance and factory instance.
|
||||
component is defined and stored on the `Language` class for each component
|
||||
instance and factory instance.
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------------------- | ------------------ | --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
|
|
|
@ -31,7 +31,6 @@ when a `Language` subclass and its `Vocab` is initialized.
|
|||
| Name | Type | Description |
|
||||
| -------------------------------------- | ------------------------- | ------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `lookups` <Tag variant="new">2.2</Tag> | [`Lookups`](/api/lookups) | The lookups object containing the (optional) tables `"lemma_rules"`, `"lemma_index"`, `"lemma_exc"` and `"lemma_lookup"`. |
|
||||
| **RETURNS** | `Lemmatizer` | The newly created object. |
|
||||
|
||||
## Lemmatizer.\_\_call\_\_ {#call tag="method"}
|
||||
|
||||
|
|
|
@ -13,11 +13,10 @@ lemmatization depends on the part-of-speech tag).
|
|||
|
||||
Create a `Lexeme` object.
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | -------- | ----------------------------- |
|
||||
| `vocab` | `Vocab` | The parent vocabulary. |
|
||||
| `orth` | int | The orth id of the lexeme. |
|
||||
| **RETURNS** | `Lexeme` | The newly constructed object. |
|
||||
| Name | Type | Description |
|
||||
| ------- | ------- | -------------------------- |
|
||||
| `vocab` | `Vocab` | The parent vocabulary. |
|
||||
| `orth` | int | The orth id of the lexeme. |
|
||||
|
||||
## Lexeme.set_flag {#set_flag tag="method"}
|
||||
|
||||
|
|
|
@ -236,10 +236,9 @@ Initialize a new table.
|
|||
> assert table["foo"] == "bar"
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | ------- | ---------------------------------- |
|
||||
| `name` | str | Optional table name for reference. |
|
||||
| **RETURNS** | `Table` | The newly constructed object. |
|
||||
| Name | Type | Description |
|
||||
| ------ | ---- | ---------------------------------- |
|
||||
| `name` | str | Optional table name for reference. |
|
||||
|
||||
### Table.from_dict {#table.from_dict tag="classmethod"}
|
||||
|
||||
|
|
|
@ -19,11 +19,10 @@ string where an integer is expected) or unexpected property names.
|
|||
> matcher = Matcher(nlp.vocab)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| --------------------------------------- | --------- | ------------------------------------------------------------------------------------------- |
|
||||
| `vocab` | `Vocab` | The vocabulary object, which must be shared with the documents the matcher will operate on. |
|
||||
| `validate` <Tag variant="new">2.1</Tag> | bool | Validate all patterns added to this matcher. |
|
||||
| **RETURNS** | `Matcher` | The newly constructed object. |
|
||||
| Name | Type | Description |
|
||||
| --------------------------------------- | ------- | ------------------------------------------------------------------------------------------- |
|
||||
| `vocab` | `Vocab` | The vocabulary object, which must be shared with the documents the matcher will operate on. |
|
||||
| `validate` <Tag variant="new">2.1</Tag> | bool | Validate all patterns added to this matcher. |
|
||||
|
||||
## Matcher.\_\_call\_\_ {#call tag="method"}
|
||||
|
||||
|
|
|
@ -6,7 +6,6 @@ source: spacy/tokens/morphanalysis.pyx
|
|||
|
||||
Stores a single morphological analysis.
|
||||
|
||||
|
||||
## MorphAnalysis.\_\_init\_\_ {#init tag="method"}
|
||||
|
||||
Initialize a MorphAnalysis object from a UD FEATS string or a dictionary of
|
||||
|
@ -16,17 +15,15 @@ morphological features.
|
|||
>
|
||||
> ```python
|
||||
> from spacy.tokens import MorphAnalysis
|
||||
>
|
||||
>
|
||||
> feats = "Feat1=Val1|Feat2=Val2"
|
||||
> m = MorphAnalysis(nlp.vocab, feats)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | ------------------ | ----------------------------- |
|
||||
| `vocab` | `Vocab` | The vocab. |
|
||||
| `features` | `Union[Dict, str]` | The morphological features. |
|
||||
| **RETURNS** | `MorphAnalysis` | The newly constructed object. |
|
||||
|
||||
| Name | Type | Description |
|
||||
| ---------- | ------------------ | --------------------------- |
|
||||
| `vocab` | `Vocab` | The vocab. |
|
||||
| `features` | `Union[Dict, str]` | The morphological features. |
|
||||
|
||||
## MorphAnalysis.\_\_contains\_\_ {#contains tag="method"}
|
||||
|
||||
|
@ -44,7 +41,6 @@ Whether a feature/value pair is in the analysis.
|
|||
| ----------- | ----- | ------------------------------------- |
|
||||
| **RETURNS** | `str` | A feature/value pair in the analysis. |
|
||||
|
||||
|
||||
## MorphAnalysis.\_\_iter\_\_ {#iter tag="method"}
|
||||
|
||||
Iterate over the feature/value pairs in the analysis.
|
||||
|
@ -61,7 +57,6 @@ Iterate over the feature/value pairs in the analysis.
|
|||
| ---------- | ----- | ------------------------------------- |
|
||||
| **YIELDS** | `str` | A feature/value pair in the analysis. |
|
||||
|
||||
|
||||
## MorphAnalysis.\_\_len\_\_ {#len tag="method"}
|
||||
|
||||
Returns the number of features in the analysis.
|
||||
|
@ -78,7 +73,6 @@ Returns the number of features in the analysis.
|
|||
| ----------- | ----- | --------------------------------------- |
|
||||
| **RETURNS** | `int` | The number of features in the analysis. |
|
||||
|
||||
|
||||
## MorphAnalysis.\_\_str\_\_ {#str tag="method"}
|
||||
|
||||
Returns the morphological analysis in the UD FEATS string format.
|
||||
|
@ -92,10 +86,9 @@ Returns the morphological analysis in the UD FEATS string format.
|
|||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | ----- | ---------------------------------|
|
||||
| ----------- | ----- | -------------------------------- |
|
||||
| **RETURNS** | `str` | The analysis in UD FEATS format. |
|
||||
|
||||
|
||||
## MorphAnalysis.get {#get tag="method"}
|
||||
|
||||
Retrieve values for a feature by field.
|
||||
|
@ -108,11 +101,10 @@ Retrieve values for a feature by field.
|
|||
> assert morph.get("Feat1") == ["Val1", "Val2"]
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | ------ | ----------------------------------- |
|
||||
| `field` | `str` | The field to retrieve. |
|
||||
| **RETURNS** | `list` | A list of the individual features. |
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | ------ | ---------------------------------- |
|
||||
| `field` | `str` | The field to retrieve. |
|
||||
| **RETURNS** | `list` | A list of the individual features. |
|
||||
|
||||
## MorphAnalysis.to_dict {#to_dict tag="method"}
|
||||
|
||||
|
@ -128,10 +120,9 @@ map.
|
|||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | ------ | -----------------------------------------|
|
||||
| ----------- | ------ | ---------------------------------------- |
|
||||
| **RETURNS** | `dict` | The dict representation of the analysis. |
|
||||
|
||||
|
||||
## MorphAnalysis.from_id {#from_id tag="classmethod"}
|
||||
|
||||
Create a morphological analysis from a given hash ID.
|
||||
|
@ -149,5 +140,3 @@ Create a morphological analysis from a given hash ID.
|
|||
| ------- | ------- | -------------------------------- |
|
||||
| `vocab` | `Vocab` | The vocab. |
|
||||
| `key` | `int` | The hash of the features string. |
|
||||
|
||||
|
||||
|
|
|
@ -121,7 +121,7 @@ applied to the `Doc` in order. Both [`__call__`](/api/morphologizer#call) and
|
|||
|
||||
## Morphologizer.begin_training {#begin_training tag="method"}
|
||||
|
||||
Initialize the pipe for training, using data examples if available. Return an
|
||||
Initialize the pipe for training, using data examples if available. Returns an
|
||||
[`Optimizer`](https://thinc.ai/docs/api-optimizers) object.
|
||||
|
||||
> #### Example
|
||||
|
@ -276,10 +276,11 @@ Serialize the pipe to disk.
|
|||
> morphologizer.to_disk("/path/to/morphologizer")
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| --------- | --------------- | --------------------------------------------------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------- | --------------------------------------------------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
|
||||
## Morphologizer.from_disk {#from_disk tag="method"}
|
||||
|
||||
|
@ -292,11 +293,12 @@ Load the pipe from disk. Modifies the object in place and returns it.
|
|||
> morphologizer.from_disk("/path/to/morphologizer")
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | --------------- | -------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `Morphologizer` | The modified `Morphologizer` object. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------- | -------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `Morphologizer` | The modified `Morphologizer` object. |
|
||||
|
||||
## Morphologizer.to_bytes {#to_bytes tag="method"}
|
||||
|
||||
|
@ -309,10 +311,11 @@ Load the pipe from disk. Modifies the object in place and returns it.
|
|||
|
||||
Serialize the pipe to a bytestring.
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | --------------- | ------------------------------------------------------------------------- |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | bytes | The serialized form of the `Morphologizer` object. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------- | ------------------------------------------------------------------------- |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | bytes | The serialized form of the `Morphologizer` object. |
|
||||
|
||||
## Morphologizer.from_bytes {#from_bytes tag="method"}
|
||||
|
||||
|
@ -326,11 +329,12 @@ Load the pipe from a bytestring. Modifies the object in place and returns it.
|
|||
> morphologizer.from_bytes(morphologizer_bytes)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ------------ | --------------- | ------------------------------------------------------------------------- |
|
||||
| `bytes_data` | bytes | The data to load from. |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `Morphologizer` | The `Morphologizer` object. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------- | ------------------------------------------------------------------------- |
|
||||
| `bytes_data` | bytes | The data to load from. |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `Morphologizer` | The `Morphologizer` object. |
|
||||
|
||||
## Morphologizer.labels {#labels tag="property"}
|
||||
|
||||
|
|
|
@ -4,12 +4,11 @@ tag: class
|
|||
source: spacy/morphology.pyx
|
||||
---
|
||||
|
||||
Store the possible morphological analyses for a language, and index them
|
||||
by hash. To save space on each token, tokens only know the hash of their
|
||||
Store the possible morphological analyses for a language, and index them by
|
||||
hash. To save space on each token, tokens only know the hash of their
|
||||
morphological analysis, so queries of morphological attributes are delegated to
|
||||
this class.
|
||||
|
||||
|
||||
## Morphology.\_\_init\_\_ {#init tag="method"}
|
||||
|
||||
Create a Morphology object using the tag map, lemmatizer and exceptions.
|
||||
|
@ -22,21 +21,18 @@ Create a Morphology object using the tag map, lemmatizer and exceptions.
|
|||
> morphology = Morphology(strings, tag_map, lemmatizer)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | ---------------------------------------- | --------------------------------------------------------------------------------------------------------- |
|
||||
| `strings` | `StringStore` | The string store. |
|
||||
| `tag_map` | `Dict[str, Dict]` | The tag map. |
|
||||
| `lemmatizer`| `Lemmatizer` | The lemmatizer. |
|
||||
| `exc` | `Dict[str, Dict]` | A dictionary of exceptions in the format `{tag: {orth: {"POS": "X", "Feat1": "Val1, "Feat2": "Val2", ...}` |
|
||||
| **RETURNS** | `Morphology` | The newly constructed object. |
|
||||
|
||||
| Name | Type | Description |
|
||||
| ------------ | ----------------- | ---------------------------------------------------------------------------------------------------------- |
|
||||
| `strings` | `StringStore` | The string store. |
|
||||
| `tag_map` | `Dict[str, Dict]` | The tag map. |
|
||||
| `lemmatizer` | `Lemmatizer` | The lemmatizer. |
|
||||
| `exc` | `Dict[str, Dict]` | A dictionary of exceptions in the format `{tag: {orth: {"POS": "X", "Feat1": "Val1, "Feat2": "Val2", ...}` |
|
||||
|
||||
## Morphology.add {#add tag="method"}
|
||||
|
||||
Insert a morphological analysis in the morphology table, if not already
|
||||
present. The morphological analysis may be provided in the UD FEATS format as a
|
||||
string or in the tag map dictionary format. Returns the hash of the new
|
||||
analysis.
|
||||
Insert a morphological analysis in the morphology table, if not already present.
|
||||
The morphological analysis may be provided in the UD FEATS format as a string or
|
||||
in the tag map dictionary format. Returns the hash of the new analysis.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
|
@ -46,10 +42,9 @@ analysis.
|
|||
> assert hash == nlp.vocab.strings[feats]
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | ------------------- | --------------------------- |
|
||||
| `features` | `Union[Dict, str]` | The morphological features. |
|
||||
|
||||
| Name | Type | Description |
|
||||
| ---------- | ------------------ | --------------------------- |
|
||||
| `features` | `Union[Dict, str]` | The morphological features. |
|
||||
|
||||
## Morphology.get {#get tag="method"}
|
||||
|
||||
|
@ -63,33 +58,30 @@ analysis.
|
|||
|
||||
Get the FEATS string for the hash of the morphological analysis.
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | ------ | --------------------------------------- |
|
||||
| `morph` | int | The hash of the morphological analysis. |
|
||||
|
||||
| Name | Type | Description |
|
||||
| ------- | ---- | --------------------------------------- |
|
||||
| `morph` | int | The hash of the morphological analysis. |
|
||||
|
||||
## Morphology.load_tag_map {#load_tag_map tag="method"}
|
||||
|
||||
Replace the current tag map with the provided tag map.
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | ------------------ | ------------ |
|
||||
| `tag_map` | `Dict[str, Dict]` | The tag map. |
|
||||
|
||||
| Name | Type | Description |
|
||||
| --------- | ----------------- | ------------ |
|
||||
| `tag_map` | `Dict[str, Dict]` | The tag map. |
|
||||
|
||||
## Morphology.load_morph_exceptions {#load_morph_exceptions tag="method"}
|
||||
|
||||
Replace the current morphological exceptions with the provided exceptions.
|
||||
|
||||
| Name | Type | Description |
|
||||
| ------------- | ------------------ | ----------------------------- |
|
||||
| `morph_rules` | `Dict[str, Dict]` | The morphological exceptions. |
|
||||
|
||||
| Name | Type | Description |
|
||||
| ------------- | ----------------- | ----------------------------- |
|
||||
| `morph_rules` | `Dict[str, Dict]` | The morphological exceptions. |
|
||||
|
||||
## Morphology.add_special_case {#add_special_case tag="method"}
|
||||
|
||||
Add a special-case rule to the morphological analyzer. Tokens whose tag and
|
||||
orth match the rule will receive the specified properties.
|
||||
Add a special-case rule to the morphological analyzer. Tokens whose tag and orth
|
||||
match the rule will receive the specified properties.
|
||||
|
||||
> #### Example
|
||||
>
|
||||
|
@ -98,27 +90,24 @@ orth match the rule will receive the specified properties.
|
|||
> morphology.add_special_case("DT", "the", attrs)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | ---- | ---------------------------------------------- |
|
||||
| `tag_str` | str | The fine-grained tag. |
|
||||
| `orth_str` | str | The token text. |
|
||||
| `attrs` | dict | The features to assign for this token and tag. |
|
||||
|
||||
| Name | Type | Description |
|
||||
| ---------- | ---- | ---------------------------------------------- |
|
||||
| `tag_str` | str | The fine-grained tag. |
|
||||
| `orth_str` | str | The token text. |
|
||||
| `attrs` | dict | The features to assign for this token and tag. |
|
||||
|
||||
## Morphology.exc {#exc tag="property"}
|
||||
|
||||
The current morphological exceptions.
|
||||
|
||||
| Name | Type | Description |
|
||||
| ---------- | ----- | --------------------------------------------------- |
|
||||
| **YIELDS** | dict | The current dictionary of morphological exceptions. |
|
||||
|
||||
| Name | Type | Description |
|
||||
| ---------- | ---- | --------------------------------------------------- |
|
||||
| **YIELDS** | dict | The current dictionary of morphological exceptions. |
|
||||
|
||||
## Morphology.lemmatize {#lemmatize tag="method"}
|
||||
|
||||
TODO
|
||||
|
||||
|
||||
## Morphology.feats_to_dict {#feats_to_dict tag="staticmethod"}
|
||||
|
||||
Convert a string FEATS representation to a dictionary of features and values in
|
||||
|
@ -132,11 +121,10 @@ the same format as the tag map.
|
|||
> assert d == {"Feat1": "Val1", "Feat2": "Val2"}
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | ---- | ------------------------------------------------------------- |
|
||||
| Name | Type | Description |
|
||||
| ----------- | ---- | ------------------------------------------------------------------ |
|
||||
| `feats` | str | The morphological features in Universal Dependencies FEATS format. |
|
||||
| **RETURNS** | dict | The morphological features as a dictionary. |
|
||||
|
||||
| **RETURNS** | dict | The morphological features as a dictionary. |
|
||||
|
||||
## Morphology.dict_to_feats {#dict_to_feats tag="staticmethod"}
|
||||
|
||||
|
@ -150,12 +138,11 @@ Convert a dictionary of features and values to a string FEATS representation.
|
|||
> assert f == "Feat1=Val1|Feat2=Val2"
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| Name | Type | Description |
|
||||
| ------------ | ----------------- | --------------------------------------------------------------------- |
|
||||
| `feats_dict` | `Dict[str, Dict]` | The morphological features as a dictionary. |
|
||||
| **RETURNS** | str | The morphological features as in Universal Dependencies FEATS format. |
|
||||
|
||||
|
||||
## Attributes {#attributes}
|
||||
|
||||
| Name | Type | Description |
|
||||
|
|
|
@ -35,12 +35,11 @@ be shown.
|
|||
> matcher = PhraseMatcher(nlp.vocab)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| --------------------------------------- | --------------- | ------------------------------------------------------------------------------------------- |
|
||||
| `vocab` | `Vocab` | The vocabulary object, which must be shared with the documents the matcher will operate on. |
|
||||
| `attr` <Tag variant="new">2.1</Tag> | int / str | The token attribute to match on. Defaults to `ORTH`, i.e. the verbatim token text. |
|
||||
| `validate` <Tag variant="new">2.1</Tag> | bool | Validate patterns added to the matcher. |
|
||||
| **RETURNS** | `PhraseMatcher` | The newly constructed object. |
|
||||
| Name | Type | Description |
|
||||
| --------------------------------------- | --------- | ------------------------------------------------------------------------------------------- |
|
||||
| `vocab` | `Vocab` | The vocabulary object, which must be shared with the documents the matcher will operate on. |
|
||||
| `attr` <Tag variant="new">2.1</Tag> | int / str | The token attribute to match on. Defaults to `ORTH`, i.e. the verbatim token text. |
|
||||
| `validate` <Tag variant="new">2.1</Tag> | bool | Validate patterns added to the matcher. |
|
||||
|
||||
## PhraseMatcher.\_\_call\_\_ {#call tag="method"}
|
||||
|
||||
|
|
|
@ -95,7 +95,7 @@ applied to the `Doc` in order. Both [`__call__`](/api/pipe#call) and
|
|||
|
||||
## Pipe.begin_training {#begin_training tag="method"}
|
||||
|
||||
Initialize the pipe for training, using data examples if available. Return an
|
||||
Initialize the pipe for training, using data examples if available. Returns an
|
||||
[`Optimizer`](https://thinc.ai/docs/api-optimizers) object.
|
||||
|
||||
> #### Example
|
||||
|
@ -198,7 +198,7 @@ the "catastrophic forgetting" problem. This feature is experimental.
|
|||
>
|
||||
> ```python
|
||||
> pipe = nlp.add_pipe("your_custom_pipe")
|
||||
> optimizer = nlp.begin_training()
|
||||
> optimizer = nlp.resume_training()
|
||||
> losses = pipe.rehearse(examples, sgd=optimizer)
|
||||
> ```
|
||||
|
||||
|
@ -306,10 +306,11 @@ Serialize the pipe to disk.
|
|||
> pipe.to_disk("/path/to/pipe")
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| --------- | --------------- | --------------------------------------------------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------- | --------------------------------------------------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
|
||||
## Pipe.from_disk {#from_disk tag="method"}
|
||||
|
||||
|
@ -322,11 +323,12 @@ Load the pipe from disk. Modifies the object in place and returns it.
|
|||
> pipe.from_disk("/path/to/pipe")
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | --------------- | -------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `Pipe` | The modified pipe. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------- | -------------------------------------------------------------------------- |
|
||||
| `path` | str / `Path` | A path to a directory. Paths may be either strings or `Path`-like objects. |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `Pipe` | The modified pipe. |
|
||||
|
||||
## Pipe.to_bytes {#to_bytes tag="method"}
|
||||
|
||||
|
@ -339,10 +341,11 @@ Load the pipe from disk. Modifies the object in place and returns it.
|
|||
|
||||
Serialize the pipe to a bytestring.
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | --------------- | ------------------------------------------------------------------------- |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | bytes | The serialized form of the pipe. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------- | ------------------------------------------------------------------------- |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | bytes | The serialized form of the pipe. |
|
||||
|
||||
## Pipe.from_bytes {#from_bytes tag="method"}
|
||||
|
||||
|
@ -356,11 +359,12 @@ Load the pipe from a bytestring. Modifies the object in place and returns it.
|
|||
> pipe.from_bytes(pipe_bytes)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ------------ | --------------- | ------------------------------------------------------------------------- |
|
||||
| `bytes_data` | bytes | The data to load from. |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `Pipe` | The pipe. |
|
||||
| Name | Type | Description |
|
||||
| -------------- | --------------- | ------------------------------------------------------------------------- |
|
||||
| `bytes_data` | bytes | The data to load from. |
|
||||
| _keyword-only_ | | |
|
||||
| `exclude` | `Iterable[str]` | String names of [serialization fields](#serialization-fields) to exclude. |
|
||||
| **RETURNS** | `Pipe` | The pipe. |
|
||||
|
||||
## Serialization fields {#serialization-fields}
|
||||
|
||||
|
|
|
@ -27,10 +27,9 @@ Create a new `Scorer`.
|
|||
> scorer = Scorer(nlp)
|
||||
> ```
|
||||
|
||||
| Name | Type | Description |
|
||||
| ----------- | -------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `nlp` | Language | The pipeline to use for scoring, where each pipeline component may provide a scoring method. If none is provided, then a default pipeline for the multi-language code `xx` is constructed containing: `senter`, `tagger`, `morphologizer`, `parser`, `ner`, `textcat`. |
|
||||
| **RETURNS** | `Scorer` | The newly created object. |
|
||||
| Name | Type | Description |
|
||||
| ----- | -------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| `nlp` | Language | The pipeline to use for scoring, where each pipeline component may provide a scoring method. If none is provided, then a default pipeline for the multi-language code `xx` is constructed containing: `senter`, `tagger`, `morphologizer`, `parser`, `ner`, `textcat`. |
|
||||
|
||||
## Scorer.score {#score tag="method"}
|
||||
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user