Default settings to configurations (#4995)

* fix grad_clip naming

* cleaning pretrained_vectors out of cfg

* further refactoring Model inits

* move Model building out of pipes

* further refactor to require a model config when creating a pipe

* small fixes

* making cfg in nn_parser more consistent

* fixing nr_class for parser

* fixing nn_parser's nO

* fix printing of loss

* architectures in own file per type, consistent naming

* convenience methods default_tagger_config and default_tok2vec_config

* let create_pipe access default config if available for that component

* default_parser_config

* move defaults to separate folder

* allow reading nlp from package or dir with argument 'name'

* architecture spacy.VocabVectors.v1 to read static vectors from file

* cleanup

* default configs for nel, textcat, morphologizer, tensorizer

* fix imports

* fixing unit tests

* fixes and clean up

* fixing defaults, nO, fix unit tests

* restore parser IO

* fix IO

* 'fix' serialization test

* add *.cfg to manifest

* fix example configs with additional arguments

* replace Morphologizer with Tagger

* add IO bit when testing overfitting of tagger (currently failing)

* fix IO - don't initialize when reading from disk

* expand overfitting tests to also check IO goes OK

* remove dropout from HashEmbed to fix Tagger performance

* add defaults for sentrec

* update thinc

* always pass a Model instance to a Pipe

* fix pipes_added statement

* remove obsolete W029

* remove obsolete errors

* restore byte checking tests (work again)

* clean up test

* further test cleanup

* convert from config to Model in create_pipe

* bring back error when component is not initialized

* cleanup

* remove calls for nlp2.begin_training

* use thinc.api in imports

* allow setting charembed's nM and nC

* fix for hardcoded nM/nC + unit test

* formatting fixes

* trigger build
Sofie Van Landeghem 2020-02-27 18:42:27 +01:00 committed by GitHub
parent f39ddda193
commit 06f0a8daa0
64 changed files with 1511 additions and 1213 deletions
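
In short: built-in pipes no longer construct their Models internally. Each component ships a default config with a [model] block, that block is resolved through the architectures registry, and the resulting thinc Model is handed to the pipe. A minimal sketch of the new flow, assembled from the diffs below (the calls are illustrative, not part of this commit):

from spacy.lang.en import English
from spacy.ml.models.defaults import default_tagger_config

nlp = English()

# Each component ships a <name>_defaults.cfg; loaded with create_objects=False it is a plain
# dict whose "model" block refers to a registered architecture such as spacy.Tagger.v1.
cfg = default_tagger_config()

# create_pipe resolves the "model" block through the architectures registry and passes the
# resulting thinc Model to the component factory; if no "model" is given at all, the
# component's default config is used instead (warning W098).
tagger = nlp.create_pipe("tagger", config=cfg)
nlp.add_pipe(tagger)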

View File

@ -1,5 +1,5 @@
recursive-include include *.h
recursive-include spacy *.txt *.pyx *.pxd
recursive-include spacy *.pyx *.pxd *.txt *.cfg
include LICENSE
include README.md
include bin/spacy

View File

@ -386,8 +386,8 @@ def _load_pretrained_tok2vec(nlp, loc):
weights_data = file_.read()
loaded = []
for name, component in nlp.pipeline:
if hasattr(component, "model") and hasattr(component.model, "tok2vec"):
component.tok2vec.from_bytes(weights_data)
if hasattr(component, "model") and component.model.has_ref("tok2vec"):
component.get_ref("tok2vec").from_bytes(weights_data)
loaded.append(name)
return loaded

View File

@ -1,13 +1,9 @@
# coding: utf-8
from random import shuffle
import logging
import numpy as np
from thinc.model import Model
from thinc.api import chain
from thinc.loss import CosineDistance
from thinc.layers import Linear
from thinc.api import Model, chain, CosineDistance, Linear
from spacy.util import create_default_optimizer

View File

@ -39,25 +39,27 @@ factory = "tagger"
factory = "parser"
[nlp.pipeline.tagger.model]
@architectures = "tagger_model.v1"
@architectures = "spacy.Tagger.v1"
[nlp.pipeline.tagger.model.tok2vec]
@architectures = "tok2vec_tensors.v1"
@architectures = "spacy.Tok2VecTensors.v1"
width = ${nlp.pipeline.tok2vec.model:width}
[nlp.pipeline.parser.model]
@architectures = "transition_based_parser.v1"
@architectures = "spacy.TransitionBasedParser.v1"
nr_feature_tokens = 8
hidden_width = 64
maxout_pieces = 3
[nlp.pipeline.parser.model.tok2vec]
@architectures = "tok2vec_tensors.v1"
@architectures = "spacy.Tok2VecTensors.v1"
width = ${nlp.pipeline.tok2vec.model:width}
[nlp.pipeline.tok2vec.model]
@architectures = "hash_embed_bilstm.v1"
@architectures = "spacy.HashEmbedBiLSTM.v1"
pretrained_vectors = ${nlp:vectors}
width = 96
depth = 4
embed_size = 2000
subword_features = true
char_embed = false

View File

@ -39,27 +39,28 @@ factory = "tagger"
factory = "parser"
[nlp.pipeline.tagger.model]
@architectures = "tagger_model.v1"
@architectures = "spacy.Tagger.v1"
[nlp.pipeline.tagger.model.tok2vec]
@architectures = "tok2vec_tensors.v1"
@architectures = "spacy.Tok2VecTensors.v1"
width = ${nlp.pipeline.tok2vec.model:width}
[nlp.pipeline.parser.model]
@architectures = "transition_based_parser.v1"
@architectures = "spacy.TransitionBasedParser.v1"
nr_feature_tokens = 8
hidden_width = 64
maxout_pieces = 3
[nlp.pipeline.parser.model.tok2vec]
@architectures = "tok2vec_tensors.v1"
@architectures = "spacy.Tok2VecTensors.v1"
width = ${nlp.pipeline.tok2vec.model:width}
[nlp.pipeline.tok2vec.model]
@architectures = "hash_embed_cnn.v1"
@architectures = "spacy.HashEmbedCNN.v1"
pretrained_vectors = ${nlp:vectors}
width = 96
depth = 4
window_size = 1
embed_size = 2000
maxout_pieces = 3
subword_features = true

View File

@ -20,9 +20,9 @@ import random
import ml_datasets
import spacy
from spacy.util import minibatch, use_gpu, compounding
from spacy.util import minibatch
from spacy.pipeline import TextCategorizer
from spacy.ml.tok2vec import Tok2Vec
from spacy.ml.models.tok2vec import build_Tok2Vec_model
import numpy
@ -65,9 +65,7 @@ def prefer_gpu():
def build_textcat_model(tok2vec, nr_class, width):
from thinc.model import Model
from thinc.layers import Softmax, chain, reduce_mean
from thinc.layers import list2ragged
from thinc.api import Model, Softmax, chain, reduce_mean, list2ragged
with Model.define_operators({">>": chain}):
model = (
@ -76,7 +74,7 @@ def build_textcat_model(tok2vec, nr_class, width):
>> reduce_mean()
>> Softmax(nr_class, width)
)
model.tok2vec = tok2vec
model.set_ref("tok2vec", tok2vec)
return model
@ -97,8 +95,9 @@ def create_pipeline(width, embed_size, vectors_model):
textcat = TextCategorizer(
nlp.vocab,
labels=["POSITIVE", "NEGATIVE"],
# TODO: replace with config version
model=build_textcat_model(
Tok2Vec(width=width, embed_size=embed_size), 2, width
build_Tok2Vec_model(width=width, embed_size=embed_size), 2, width
),
)
@ -121,7 +120,7 @@ def train_tensorizer(nlp, texts, dropout, n_iter):
def train_textcat(nlp, n_texts, n_iter=10):
textcat = nlp.get_pipe("textcat")
tok2vec_weights = textcat.model.tok2vec.to_bytes()
tok2vec_weights = textcat.model.get_ref("tok2vec").to_bytes()
(train_texts, train_cats), (dev_texts, dev_cats) = load_textcat_data(limit=n_texts)
print(
"Using {} examples ({} training, {} evaluation)".format(
@ -135,7 +134,7 @@ def train_textcat(nlp, n_texts, n_iter=10):
other_pipes = [pipe for pipe in nlp.pipe_names if pipe not in pipe_exceptions]
with nlp.disable_pipes(*other_pipes): # only train textcat
optimizer = nlp.begin_training()
textcat.model.tok2vec.from_bytes(tok2vec_weights)
textcat.model.get_ref("tok2vec").from_bytes(tok2vec_weights)
print("Training the model...")
print("{:^5}\t{:^5}\t{:^5}\t{:^5}".format("LOSS", "P", "R", "F"))
for i in range(n_iter):
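
Throughout the diff, attribute access like model.tok2vec is replaced by thinc's named-reference API. The pattern in isolation (a toy two-layer model, assuming thinc 8):

from thinc.api import Linear, chain

tok2vec = Linear(nO=4, nI=4)
model = chain(tok2vec, Linear(nO=2, nI=4))
model.initialize()

model.set_ref("tok2vec", tok2vec)              # wire up the sub-layer when the model is built
assert model.has_ref("tok2vec")                # check for the reference before relying on it
weights = model.get_ref("tok2vec").to_bytes()  # fetch it back, e.g. to serialize its weights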

View File

@ -74,7 +74,7 @@ def main(model=None, output_dir=None, n_iter=20, n_texts=2000, init_tok2vec=None
optimizer = nlp.begin_training()
if init_tok2vec is not None:
with init_tok2vec.open("rb") as file_:
textcat.model.tok2vec.from_bytes(file_.read())
textcat.model.get_ref("tok2vec").from_bytes(file_.read())
print("Training the model...")
print("{:^5}\t{:^5}\t{:^5}\t{:^5}".format("LOSS", "P", "R", "F"))
batch_sizes = compounding(4.0, 32.0, 1.001)

View File

@ -6,7 +6,7 @@ requires = [
"cymem>=2.0.2,<2.1.0",
"preshed>=3.0.2,<3.1.0",
"murmurhash>=0.28.0,<1.1.0",
"thinc==8.0.0a0",
"thinc==8.0.0a1",
"blis>=0.4.0,<0.5.0"
]
build-backend = "setuptools.build_meta"

View File

@ -1,7 +1,7 @@
# Our libraries
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
thinc==8.0.0a0
thinc==8.0.0a1
blis>=0.4.0,<0.5.0
ml_datasets>=0.1.1
murmurhash>=0.28.0,<1.1.0

View File

@ -36,13 +36,13 @@ setup_requires =
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
murmurhash>=0.28.0,<1.1.0
thinc==8.0.0a0
thinc==8.0.0a1
install_requires =
# Our libraries
murmurhash>=0.28.0,<1.1.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
thinc==8.0.0a0
thinc==8.0.0a1
blis>=0.4.0,<0.5.0
wasabi>=0.4.0,<1.1.0
srsly>=2.0.0,<3.0.0

View File

@ -11,10 +11,10 @@ import srsly
from ..gold import Example
from ..errors import Errors
from ..ml.models.multi_task import build_masked_language_model
from ..tokens import Doc
from ..attrs import ID, HEAD
from ..ml.component_models import Tok2Vec
from ..ml.component_models import masked_language_model
from ..ml.models.tok2vec import build_Tok2Vec_model
from .. import util
from ..util import create_default_optimizer
from .train import _load_pretrained_tok2vec
@ -108,14 +108,19 @@ def pretrain(
pretrained_vectors = None if not use_vectors else nlp.vocab.vectors
model = create_pretraining_model(
nlp,
Tok2Vec(
# TODO: replace with config
build_Tok2Vec_model(
width,
embed_rows,
conv_depth=conv_depth,
pretrained_vectors=pretrained_vectors,
bilstm_depth=bilstm_depth, # Requires PyTorch. Experimental.
subword_features=not use_chars, # Set to False for Chinese etc
cnn_maxout_pieces=cnn_pieces, # If set to 1, use Mish activation.
maxout_pieces=cnn_pieces, # If set to 1, use Mish activation.
window_size=1,
char_embed=False,
nM=64,
nC=8
),
)
# Load in pretrained weights
@ -152,7 +157,7 @@ def pretrain(
is_temp_str = ".temp" if is_temp else ""
with model.use_params(optimizer.averages):
with (output_dir / f"model{epoch}{is_temp_str}.bin").open("wb") as file_:
file_.write(model.tok2vec.to_bytes())
file_.write(model.get_ref("tok2vec").to_bytes())
log = {
"nr_word": tracker.nr_word,
"loss": tracker.loss,
@ -284,7 +289,7 @@ def create_pretraining_model(nlp, tok2vec):
# "tok2vec" has to be the same set of processes as what the components do.
tok2vec = chain(tok2vec, list2array())
model = chain(tok2vec, output_layer)
model = masked_language_model(nlp.vocab, model)
model = build_masked_language_model(nlp.vocab, model)
model.set_ref("tok2vec", tok2vec)
model.set_ref("output_layer", output_layer)
model.initialize(X=[nlp.make_doc("Give it a doc to infer shapes")])

View File

@ -9,7 +9,7 @@ from wasabi import msg
import contextlib
import random
from ..util import create_default_optimizer
from ..util import create_default_optimizer, registry
from ..util import use_gpu as set_gpu
from ..attrs import PROB, IS_OOV, CLUSTER, LANG
from ..gold import GoldCorpus
@ -111,6 +111,8 @@ def train(
eval_beam_widths.sort()
has_beam_widths = eval_beam_widths != [1]
default_dir = Path(__file__).parent.parent / "ml" / "models" / "defaults"
# Set up the base model and pipeline. If a base model is specified, load
# the model and make sure the pipeline matches the pipeline setting. If
# training starts from a blank model, initialize the language class.
@ -118,7 +120,6 @@ def train(
msg.text(f"Training pipeline: {pipeline}")
disabled_pipes = None
pipes_added = False
msg.text(f"Training pipeline: {pipeline}")
if use_gpu >= 0:
activated_gpu = None
try:
@ -140,16 +141,36 @@ def train(
f"specified as `lang` argument ('{lang}') ",
exits=1,
)
if vectors:
msg.text(f"Loading vectors from model '{vectors}'")
nlp.disable_pipes([p for p in nlp.pipe_names if p not in pipeline])
for pipe in pipeline:
pipe_cfg = {}
# first, create the model.
# Bit of a hack after the refactor to get the vectors into a default config
# use train-from-config instead :-)
if pipe == "parser":
pipe_cfg = {"learn_tokens": learn_tokens}
config_loc = default_dir / "parser_defaults.cfg"
elif pipe == "tagger":
config_loc = default_dir / "tagger_defaults.cfg"
elif pipe == "ner":
config_loc = default_dir / "ner_defaults.cfg"
elif pipe == "textcat":
pipe_cfg = {
"exclusive_classes": not textcat_multilabel,
"architecture": textcat_arch,
"positive_label": textcat_positive_label,
}
config_loc = default_dir / "textcat_defaults.cfg"
else:
raise ValueError(f"Component {pipe} currently not supported.")
pipe_cfg = util.load_config(config_loc, create_objects=False)
if vectors:
pretrained_config = {'@architectures': 'spacy.VocabVectors.v1', 'name': vectors}
pipe_cfg["model"]["tok2vec"]["pretrained_vectors"] = pretrained_config
if pipe == "parser":
pipe_cfg["learn_tokens"] = learn_tokens
elif pipe == "textcat":
pipe_cfg["exclusive_classes"] = not textcat_multilabel
pipe_cfg["architecture"] = textcat_arch
pipe_cfg["positive_label"] = textcat_positive_label
if pipe not in nlp.pipe_names:
msg.text(f"Adding component to base model '{pipe}'")
nlp.add_pipe(nlp.create_pipe(pipe, config=pipe_cfg))
@ -181,26 +202,42 @@ def train(
msg.text(f"Starting with blank model '{lang}'")
lang_cls = util.get_lang_class(lang)
nlp = lang_cls()
if vectors:
msg.text(f"Loading vectors from model '{vectors}'")
for pipe in pipeline:
# first, create the model.
# Bit of a hack after the refactor to get the vectors into a default config
# use train-from-config instead :-)
if pipe == "parser":
pipe_cfg = {"learn_tokens": learn_tokens}
config_loc = default_dir / "parser_defaults.cfg"
elif pipe == "tagger":
config_loc = default_dir / "tagger_defaults.cfg"
elif pipe == "ner":
config_loc = default_dir / "ner_defaults.cfg"
elif pipe == "textcat":
pipe_cfg = {
"exclusive_classes": not textcat_multilabel,
"architecture": textcat_arch,
"positive_label": textcat_positive_label,
}
config_loc = default_dir / "textcat_defaults.cfg"
else:
pipe_cfg = {}
nlp.add_pipe(nlp.create_pipe(pipe, config=pipe_cfg))
raise ValueError(f"Component {pipe} currently not supported.")
pipe_cfg = util.load_config(config_loc, create_objects=False)
if vectors:
pretrained_config = {'@architectures': 'spacy.VocabVectors.v1', 'name': vectors}
pipe_cfg["model"]["tok2vec"]["pretrained_vectors"] = pretrained_config
if pipe == "parser":
pipe_cfg["learn_tokens"] = learn_tokens
elif pipe == "textcat":
pipe_cfg["exclusive_classes"] = not textcat_multilabel
pipe_cfg["architecture"] = textcat_arch
pipe_cfg["positive_label"] = textcat_positive_label
pipe = nlp.create_pipe(pipe, config=pipe_cfg)
nlp.add_pipe(pipe)
# Update tag map with provided mapping
nlp.vocab.morphology.tag_map.update(tag_map)
if vectors:
msg.text(f"Loading vector from model '{vectors}'")
_load_vectors(nlp, vectors)
# Multitask objectives
multitask_options = [("parser", parser_multitasks), ("ner", entity_multitasks)]
for pipe_name, multitasks in multitask_options:
@ -228,7 +265,7 @@ def train(
optimizer = nlp.begin_training(lambda: corpus.train_examples, **cfg)
nlp._optimizer = None
# Load in pretrained weights
# Load in pretrained weights (TODO: this may be broken in the config rewrite)
if init_tok2vec is not None:
components = _load_pretrained_tok2vec(nlp, init_tok2vec)
msg.text(f"Loaded pretrained tok2vec for: {components}")
@ -531,7 +568,7 @@ def _create_progress_bar(total):
def _load_vectors(nlp, vectors):
util.load_model(vectors, vocab=nlp.vocab)
loaded_model = util.load_model(vectors, vocab=nlp.vocab)
for lex in nlp.vocab:
values = {}
for attr, func in nlp.vocab.lex_attr_getters.items():
@ -541,6 +578,7 @@ def _load_vectors(nlp, vectors):
values[lex.vocab.strings[attr]] = func(lex.orth_)
lex.set_attrs(**values)
lex.is_oov = False
return loaded_model
def _load_pretrained_tok2vec(nlp, loc):
@ -551,8 +589,8 @@ def _load_pretrained_tok2vec(nlp, loc):
weights_data = file_.read()
loaded = []
for name, component in nlp.pipeline:
if hasattr(component, "model") and hasattr(component.model, "tok2vec"):
component.tok2vec.from_bytes(weights_data)
if hasattr(component, "model") and component.model.has_ref("tok2vec"):
component.get_ref("tok2vec").from_bytes(weights_data)
loaded.append(name)
return loaded
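
To make the override above concrete: for the parser, pipe_cfg after loading parser_defaults.cfg and injecting the vectors ends up roughly like the dict below (a sketch reconstructed from the default config shown further down; learn_tokens and the vectors name are example values):

pipe_cfg = {
    "learn_tokens": False,  # set from the CLI flag
    "model": {
        "@architectures": "spacy.TransitionBasedParser.v1",
        "nr_feature_tokens": 6,
        "hidden_width": 64,
        "maxout_pieces": 2,
        "tok2vec": {
            "@architectures": "spacy.HashEmbedCNN.v1",
            # injected from --vectors; "en_vectors_web_lg" is only an example value
            "pretrained_vectors": {"@architectures": "spacy.VocabVectors.v1", "name": "en_vectors_web_lg"},
            "width": 96,
            "depth": 4,
            "embed_size": 2000,
            "window_size": 1,
            "maxout_pieces": 3,
            "subword_features": True,
        },
    },
}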

View File

@ -1,19 +1,17 @@
from typing import Optional, Dict, List, Union, Sequence
from pydantic import BaseModel, FilePath, StrictInt
import plac
from wasabi import msg
import tqdm
from pathlib import Path
from wasabi import msg
import thinc
import thinc.schedules
from thinc.api import Model
from pydantic import BaseModel, FilePath, StrictInt
import tqdm
# TODO: relative imports?
import spacy
from spacy.gold import GoldCorpus
from spacy.pipeline.tok2vec import Tok2VecListener
from spacy.ml import component_models
from spacy import util
from ..gold import GoldCorpus
from .. import util
registry = util.registry
@ -57,23 +55,24 @@ factory = "tok2vec"
factory = "ner"
[nlp.pipeline.ner.model]
@architectures = "transition_based_ner.v1"
@architectures = "spacy.TransitionBasedParser.v1"
nr_feature_tokens = 3
hidden_width = 64
maxout_pieces = 3
[nlp.pipeline.ner.model.tok2vec]
@architectures = "tok2vec_tensors.v1"
@architectures = "spacy.Tok2VecTensors.v1"
width = ${nlp.pipeline.tok2vec.model:width}
[nlp.pipeline.tok2vec.model]
@architectures = "hash_embed_cnn.v1"
@architectures = "spacy.HashEmbedCNN.v1"
pretrained_vectors = ${nlp:vectors}
width = 128
depth = 4
window_size = 1
embed_size = 10000
maxout_pieces = 3
subword_features = true
"""
@ -113,65 +112,6 @@ class ConfigSchema(BaseModel):
extra = "allow"
# Of course, these would normally decorate the functions where they're defined.
# But for now...
@registry.architectures.register("hash_embed_cnn.v1")
def hash_embed_cnn(
pretrained_vectors, width, depth, embed_size, maxout_pieces, window_size
):
return component_models.Tok2Vec(
width=width,
embed_size=embed_size,
pretrained_vectors=pretrained_vectors,
conv_depth=depth,
cnn_maxout_pieces=maxout_pieces,
bilstm_depth=0,
window_size=window_size,
)
@registry.architectures.register("hash_embed_bilstm.v1")
def hash_embed_bilstm_v1(pretrained_vectors, width, depth, embed_size):
return component_models.Tok2Vec(
width=width,
embed_size=embed_size,
pretrained_vectors=pretrained_vectors,
bilstm_depth=depth,
conv_depth=0,
cnn_maxout_pieces=0,
)
@registry.architectures.register("tagger_model.v1")
def build_tagger_model_v1(tok2vec):
return component_models.build_tagger_model(nr_class=None, tok2vec=tok2vec)
@registry.architectures.register("transition_based_parser.v1")
def create_tb_parser_model(
tok2vec: Model,
nr_feature_tokens: StrictInt = 3,
hidden_width: StrictInt = 64,
maxout_pieces: StrictInt = 3,
):
from thinc.api import Linear, chain, list2array, use_ops, zero_init
from spacy.ml._layers import PrecomputableAffine
from spacy.syntax._parser_model import ParserModel
token_vector_width = tok2vec.get_dim("nO")
tok2vec = chain(tok2vec, list2array())
tok2vec.set_dim("nO", token_vector_width)
lower = PrecomputableAffine(
hidden_width, nF=nr_feature_tokens, nI=tok2vec.get_dim("nO"), nP=maxout_pieces
)
lower.set_dim("nP", maxout_pieces)
with use_ops("numpy"):
# Initialize weights at zero, as it's a classification layer.
upper = Linear(init_W=zero_init)
return ParserModel(tok2vec, lower, upper)
@plac.annotations(
# fmt: off
train_path=("Location of JSON-formatted training data", "positional", None, Path),
@ -224,23 +164,25 @@ def train_from_config(
config_path, data_paths, raw_text=None, meta_path=None, output_path=None,
):
msg.info(f"Loading config from: {config_path}")
config = util.load_from_config(config_path, create_objects=True)
config = util.load_config(config_path, create_objects=True)
use_gpu = config["training"]["use_gpu"]
if use_gpu >= 0:
msg.info("Using GPU")
else:
msg.info("Using CPU")
msg.info("Creating nlp from config")
nlp = create_nlp_from_config(**config["nlp"])
nlp_config = util.load_config(config_path, create_objects=False)["nlp"]
nlp = util.load_model_from_config(nlp_config)
optimizer = config["optimizer"]
limit = config["training"]["limit"]
training = config["training"]
limit = training["limit"]
msg.info("Loading training corpus")
corpus = GoldCorpus(data_paths["train"], data_paths["dev"], limit=limit)
msg.info("Initializing the nlp pipeline")
nlp.begin_training(lambda: corpus.train_examples, device=use_gpu)
train_batches = create_train_batches(nlp, corpus, config["training"])
evaluate = create_evaluation_callback(nlp, optimizer, corpus, config["training"])
train_batches = create_train_batches(nlp, corpus, training)
evaluate = create_evaluation_callback(nlp, optimizer, corpus, training)
# Create iterator, which yields out info after each optimization step.
msg.info("Start training")
@ -249,16 +191,16 @@ def train_from_config(
optimizer,
train_batches,
evaluate,
config["training"]["dropout"],
config["training"]["patience"],
config["training"]["eval_frequency"],
training["dropout"],
training["patience"],
training["eval_frequency"],
)
msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
print_row = setup_printer(config)
print_row = setup_printer(training, nlp)
try:
progress = tqdm.tqdm(total=config["training"]["eval_frequency"], leave=False)
progress = tqdm.tqdm(total=training["eval_frequency"], leave=False)
for batch, info, is_best_checkpoint in training_step_iterator:
progress.update(1)
if is_best_checkpoint is not None:
@ -266,9 +208,7 @@ def train_from_config(
print_row(info)
if is_best_checkpoint and output_path is not None:
nlp.to_disk(output_path)
progress = tqdm.tqdm(
total=config["training"]["eval_frequency"], leave=False
)
progress = tqdm.tqdm(total=training["eval_frequency"], leave=False)
finally:
if output_path is not None:
with nlp.use_params(optimizer.averages):
@ -280,18 +220,6 @@ def train_from_config(
# msg.good("Created best model", best_model_path)
def create_nlp_from_config(lang, vectors, pipeline):
lang_class = spacy.util.get_lang_class(lang)
nlp = lang_class()
if vectors is not None:
spacy.cli.train._load_vectors(nlp, vectors)
for name, component_cfg in pipeline.items():
factory = component_cfg.pop("factory")
component = nlp.create_pipe(factory, config=component_cfg)
nlp.add_pipe(component, name=name)
return nlp
def create_train_batches(nlp, corpus, cfg):
while True:
train_examples = corpus.train_dataset(
@ -405,10 +333,10 @@ def subdivide_batch(batch):
return [batch]
def setup_printer(config):
score_cols = config["training"]["scores"]
def setup_printer(training, nlp):
score_cols = training["scores"]
score_widths = [max(len(col), 6) for col in score_cols]
loss_cols = [f"Loss {pipe}" for pipe in config["nlp"]["pipeline"]]
loss_cols = [f"Loss {pipe}" for pipe in nlp.pipe_names]
loss_widths = [max(len(col), 8) for col in loss_cols]
table_header = ["#"] + loss_cols + score_cols + ["Score"]
table_header = [col.upper() for col in table_header]
@ -420,20 +348,13 @@ def setup_printer(config):
def print_row(info):
losses = [
"{0:.2f}".format(info["losses"].get(col, 0.0))
for col in config["nlp"]["pipeline"]
"{0:.2f}".format(info["losses"].get(pipe_name, 0.0))
for pipe_name in nlp.pipe_names
]
scores = [
"{0:.2f}".format(info["other_scores"].get(col, 0.0))
for col in config["training"]["scores"]
"{0:.2f}".format(info["other_scores"].get(col, 0.0)) for col in score_cols
]
data = [info["step"]] + losses + scores + ["{0:.2f}".format(info["score"])]
msg.row(data, widths=table_widths, aligns=table_aligns)
return print_row
@registry.architectures.register("tok2vec_tensors.v1")
def tok2vec_tensors_v1(width):
tok2vec = Tok2VecListener("tok2vec", width=width)
return tok2vec

View File

@ -106,6 +106,12 @@ class Warnings(object):
"Provide features as a dict {{\"Field1\": \"Value1,Value2\"}} or "
"string \"Field1=Value1,Value2|Field2=Value3\".")
# TODO: fix numbering after merging develop into master
W098 = ("No Model config was provided to create the '{name}' component, "
"so a default configuration was used.")
W099 = ("Expected 'dict' type for the 'model' argument of pipe '{pipe}', "
"but got '{type}' instead, so ignoring it.")
@add_codes
class Errors(object):
@ -227,7 +233,7 @@ class Errors(object):
E050 = ("Can't find model '{name}'. It doesn't seem to be a Python "
"package or a valid path to a data directory.")
E052 = ("Can't find model directory: {path}")
E053 = ("Could not read meta.json from {path}")
E053 = ("Could not read {name} from {path}")
E054 = ("No valid '{setting}' setting found in model meta.json.")
E055 = ("Invalid ORTH value in exception:\nKey: {key}\nOrths: {orths}")
E056 = ("Invalid tokenizer exception: ORTH values combined don't match "
@ -345,8 +351,8 @@ class Errors(object):
E108 = ("As of spaCy v2.1, the pipe name `sbd` has been deprecated "
"in favor of the pipe name `sentencizer`, which does the same "
"thing. For example, use `nlp.create_pipeline('sentencizer')`")
E109 = ("Model for component '{name}' not initialized. Did you forget to "
"load a model, or forget to call begin_training()?")
E109 = ("Component '{name}' could not be run. Did you forget to "
"call begin_training()?")
E110 = ("Invalid displaCy render wrapper. Expected callable, got: {obj}")
E111 = ("Pickling a token is not supported, because tokens are only views "
"of the parent Doc and can't exist on their own. A pickled token "
@ -532,6 +538,9 @@ class Errors(object):
"make sure the gold EL data refers to valid results of the "
"named entity recognizer in the `nlp` pipeline.")
# TODO: fix numbering after merging develop into master
E993 = ("The config for 'nlp' should include either a key 'name' to "
"refer to an existing model by name or path, or a key 'lang' "
"to create a new blank model.")
E996 = ("Could not parse {file}: {msg}")
E997 = ("Tokenizer special cases are not allowed to modify the text. "
"This would map '{chunk}' to '{orth}' given token attributes "

View File

@ -4,7 +4,9 @@ import weakref
import functools
from contextlib import contextmanager
from copy import copy, deepcopy
from thinc.api import get_current_ops
from pathlib import Path
from thinc.api import get_current_ops, Config
import srsly
import multiprocessing as mp
from itertools import chain, cycle
@ -16,7 +18,7 @@ from .lookups import Lookups
from .analysis import analyze_pipes, analyze_all_pipes, validate_attrs
from .gold import Example
from .scorer import Scorer
from .util import link_vectors_to_models, create_default_optimizer
from .util import link_vectors_to_models, create_default_optimizer, registry
from .attrs import IS_STOP, LANG
from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES
from .lang.punctuation import TOKENIZER_INFIXES
@ -24,7 +26,7 @@ from .lang.tokenizer_exceptions import TOKEN_MATCH
from .lang.tag_map import TAG_MAP
from .tokens import Doc
from .lang.lex_attrs import LEX_ATTRS, is_stop
from .errors import Errors, Warnings, deprecation_warning
from .errors import Errors, Warnings, deprecation_warning, user_warning
from . import util
from . import about
@ -128,7 +130,7 @@ class Language(object):
factories = {"tokenizer": lambda nlp: nlp.Defaults.create_tokenizer(nlp)}
def __init__(
self, vocab=True, make_doc=True, max_length=10 ** 6, meta={}, **kwargs
self, vocab=True, make_doc=True, max_length=10 ** 6, meta={}, config=None, **kwargs
):
"""Initialise a Language object.
@ -138,6 +140,7 @@ class Language(object):
object. Usually a `Tokenizer`.
meta (dict): Custom meta data for the Language class. Is written to by
models to add model meta data.
config (Config): Configuration data for creating the pipeline components.
max_length (int) :
Maximum number of characters in a single text. The current v2 models
may run out memory on extremely long texts, due to large internal
@ -152,6 +155,9 @@ class Language(object):
user_factories = util.registry.factories.get_all()
self.factories.update(user_factories)
self._meta = dict(meta)
self._config = config
if not self._config:
self._config = Config()
self._path = None
if vocab is True:
factory = self.Defaults.create_vocab
@ -170,6 +176,21 @@ class Language(object):
self.max_length = max_length
self._optimizer = None
from .ml.models.defaults import default_tagger_config, default_parser_config, default_ner_config, \
default_textcat_config, default_nel_config, default_morphologizer_config, default_sentrec_config, \
default_tensorizer_config, default_tok2vec_config
self.defaults = {"tagger": default_tagger_config(),
"parser": default_parser_config(),
"ner": default_ner_config(),
"textcat": default_textcat_config(),
"entity_linker": default_nel_config(),
"morphologizer": default_morphologizer_config(),
"sentrec": default_sentrec_config(),
"tensorizer": default_tensorizer_config(),
"tok2vec": default_tok2vec_config(),
}
@property
def path(self):
return self._path
@ -203,6 +224,10 @@ class Language(object):
def meta(self, value):
self._meta = value
@property
def config(self):
return self._config
# Conveniences to access pipeline components
# Shouldn't be used anymore!
@property
@ -293,7 +318,24 @@ class Language(object):
else:
raise KeyError(Errors.E002.format(name=name))
factory = self.factories[name]
return factory(self, **config)
default_config = self.defaults.get(name, None)
# transform the model's config to an actual Model
model_cfg = None
if "model" in config:
model_cfg = config["model"]
if not isinstance(model_cfg, dict):
user_warning(Warnings.W099.format(type=type(model_cfg), pipe=name))
model_cfg = None
del config["model"]
if model_cfg is None and default_config is not None:
user_warning(Warnings.W098)
model_cfg = default_config["model"]
model = None
if model_cfg is not None:
self.config[name] = {"model": model_cfg}
model = registry.make_from_config({"model": model_cfg}, validate=True)["model"]
return factory(self, model, **config)
def add_pipe(
self, component, name=None, before=None, after=None, first=None, last=None
@ -430,7 +472,10 @@ class Language(object):
continue
if not hasattr(proc, "__call__"):
raise ValueError(Errors.E003.format(component=type(proc), name=name))
doc = proc(doc, **component_cfg.get(name, {}))
try:
doc = proc(doc, **component_cfg.get(name, {}))
except KeyError:
raise ValueError(Errors.E109.format(name=name))
if doc is None:
raise ValueError(Errors.E005.format(name=name))
return doc
@ -578,9 +623,6 @@ class Language(object):
ops = get_current_ops()
self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
link_vectors_to_models(self.vocab)
if self.vocab.vectors.data.shape[1]:
cfg["pretrained_vectors"] = self.vocab.vectors.name
cfg["pretrained_dims"] = self.vocab.vectors.data.shape[1]
if sgd is None:
sgd = create_default_optimizer()
self._optimizer = sgd
@ -611,8 +653,6 @@ class Language(object):
if self.vocab.vectors.data.shape[1] >= 1:
self.vocab.vectors.data = ops.asarray(self.vocab.vectors.data)
link_vectors_to_models(self.vocab)
if self.vocab.vectors.data.shape[1]:
cfg["pretrained_vectors"] = self.vocab.vectors
if sgd is None:
sgd = create_default_optimizer()
self._optimizer = sgd
@ -868,6 +908,7 @@ class Language(object):
serializers["meta.json"] = lambda p: p.open("w").write(
srsly.json_dumps(self.meta)
)
serializers["config.cfg"] = lambda p: self.config.to_disk(p)
for name, proc in self.pipeline:
if not hasattr(proc, "name"):
continue
@ -895,6 +936,8 @@ class Language(object):
exclude = disable
path = util.ensure_path(path)
deserializers = {}
if Path(path / "config.cfg").exists():
deserializers["config.cfg"] = lambda p: self.config.from_disk(p)
deserializers["meta.json"] = lambda p: self.meta.update(srsly.read_json(p))
deserializers["vocab"] = lambda p: self.vocab.from_disk(
p
@ -933,6 +976,7 @@ class Language(object):
serializers["vocab"] = lambda: self.vocab.to_bytes()
serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"])
serializers["meta.json"] = lambda: srsly.json_dumps(self.meta)
serializers["config.cfg"] = lambda: self.config.to_bytes()
for name, proc in self.pipeline:
if name in exclude:
continue
@ -955,6 +999,7 @@ class Language(object):
deprecation_warning(Warnings.W014)
exclude = disable
deserializers = {}
deserializers["config.cfg"] = lambda b: self.config.from_bytes(b)
deserializers["meta.json"] = lambda b: self.meta.update(srsly.json_loads(b))
deserializers["vocab"] = lambda b: self.vocab.from_bytes(
b
@ -981,8 +1026,8 @@ class component(object):
and class components and will automatically register components in the
Language.factories. If the component is a class and needs access to the
nlp object or config parameters, it can expose a from_nlp classmethod
that takes the nlp object and **cfg arguments and returns the initialized
component.
that takes the nlp & model objects and **cfg arguments, and returns the
initialized component.
"""
# NB: This decorator needs to live here, because it needs to write to
@ -1011,9 +1056,9 @@ class component(object):
obj.requires = self.requires
obj.retokenizes = self.retokenizes
def factory(nlp, **cfg):
def factory(nlp, model, **cfg):
if hasattr(obj, "from_nlp"):
return obj.from_nlp(nlp, **cfg)
return obj.from_nlp(nlp, model, **cfg)
elif isinstance(obj, type):
return obj()
return obj
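
The updated component docstring implies the following contract for class components (a hypothetical component, not part of this commit):

class MyComponent:
    name = "my_component"

    def __init__(self, model=None, some_setting=True):
        self.model = model
        self.some_setting = some_setting

    @classmethod
    def from_nlp(cls, nlp, model, **cfg):
        # The factory now always passes the nlp object and the Model built from the
        # "model" section of the config (or from the component's defaults).
        return cls(model=model, **cfg)

    def __call__(self, doc):
        return doc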

View File

@ -1,227 +0,0 @@
from spacy import util
from spacy.ml.extract_ngrams import extract_ngrams
from ..attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE
from ..errors import Errors
from ._character_embed import CharacterEmbed
from thinc.api import Model, Maxout, Linear, residual, reduce_mean, list2ragged
from thinc.api import PyTorchLSTM, add, MultiSoftmax, HashEmbed, StaticVectors
from thinc.api import expand_window, FeatureExtractor, SparseLinear, chain
from thinc.api import clone, concatenate, with_array, Softmax, Logistic, uniqued
from thinc.api import zero_init
def build_text_classifier(arch, config):
if arch == "cnn":
return build_simple_cnn_text_classifier(**config)
elif arch == "bow":
return build_bow_text_classifier(**config)
else:
raise ValueError("Unexpected textcat arch")
def build_simple_cnn_text_classifier(tok2vec, nr_class, exclusive_classes, **cfg):
"""
Build a simple CNN text classifier, given a token-to-vector model as inputs.
If exclusive_classes=True, a softmax non-linearity is applied, so that the
outputs sum to 1. If exclusive_classes=False, a logistic non-linearity
is applied instead, so that outputs are in the range [0, 1].
"""
with Model.define_operators({">>": chain}):
if exclusive_classes:
output_layer = Softmax(nO=nr_class, nI=tok2vec.get_dim("nO"))
else:
# TODO: experiment with init_w=zero_init
output_layer = Linear(nO=nr_class, nI=tok2vec.get_dim("nO")) >> Logistic()
model = tok2vec >> list2ragged() >> reduce_mean() >> output_layer
model.set_ref("tok2vec", tok2vec)
model.set_dim("nO", nr_class)
return model
def build_bow_text_classifier(
nr_class, exclusive_classes, ngram_size=1, no_output_layer=False, **cfg
):
with Model.define_operators({">>": chain}):
model = extract_ngrams(ngram_size, attr=ORTH) >> SparseLinear(nr_class)
model.to_cpu()
if not no_output_layer:
output_layer = (
Softmax(nO=nr_class) if exclusive_classes else Logistic(nO=nr_class)
)
output_layer.to_cpu()
model = model >> output_layer
model.set_dim("nO", nr_class)
return model
def build_nel_encoder(embed_width, hidden_width, ner_types, **cfg):
if "entity_width" not in cfg:
raise ValueError(Errors.E144.format(param="entity_width"))
conv_depth = cfg.get("conv_depth", 2)
cnn_maxout_pieces = cfg.get("cnn_maxout_pieces", 3)
pretrained_vectors = cfg.get("pretrained_vectors", None)
context_width = cfg.get("entity_width")
with Model.define_operators({">>": chain, "**": clone}):
nel_tok2vec = Tok2Vec(
width=hidden_width,
embed_size=embed_width,
pretrained_vectors=pretrained_vectors,
cnn_maxout_pieces=cnn_maxout_pieces,
subword_features=True,
conv_depth=conv_depth,
bilstm_depth=0,
)
model = (
nel_tok2vec
>> list2ragged()
>> reduce_mean()
>> residual(Maxout(nO=hidden_width, nI=hidden_width, nP=2, dropout=0.0))
>> Linear(nO=context_width, nI=hidden_width)
)
model.initialize()
model.set_ref("tok2vec", nel_tok2vec)
model.set_dim("nO", context_width)
return model
def masked_language_model(*args, **kwargs):
raise NotImplementedError
def build_tagger_model(nr_class, tok2vec):
token_vector_width = tok2vec.get_dim("nO")
# TODO: glorot_uniform_init seems to work a bit better than zero_init here?!
softmax = with_array(Softmax(nO=nr_class, nI=token_vector_width, init_W=zero_init))
model = chain(tok2vec, softmax)
model.set_ref("tok2vec", tok2vec)
model.set_ref("softmax", softmax)
return model
def build_morphologizer_model(class_nums, **cfg):
embed_size = util.env_opt("embed_size", 7000)
if "token_vector_width" in cfg:
token_vector_width = cfg["token_vector_width"]
else:
token_vector_width = util.env_opt("token_vector_width", 128)
pretrained_vectors = cfg.get("pretrained_vectors")
char_embed = cfg.get("char_embed", True)
with Model.define_operators({">>": chain, "+": add, "**": clone}):
if "tok2vec" in cfg:
tok2vec = cfg["tok2vec"]
else:
tok2vec = Tok2Vec(
token_vector_width,
embed_size,
char_embed=char_embed,
pretrained_vectors=pretrained_vectors,
)
softmax = with_array(MultiSoftmax(nOs=class_nums, nI=token_vector_width))
model = tok2vec >> softmax
model.set_ref("tok2vec", tok2vec)
model.set_ref("softmax", softmax)
return model
def Tok2Vec(
width,
embed_size,
pretrained_vectors=None,
window_size=1,
cnn_maxout_pieces=3,
subword_features=True,
char_embed=False,
conv_depth=4,
bilstm_depth=0,
):
if char_embed:
subword_features = False
cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
norm = HashEmbed(nO=width, nV=embed_size, column=cols.index(NORM), dropout=0.0)
if subword_features:
prefix = HashEmbed(
nO=width, nV=embed_size // 2, column=cols.index(PREFIX), dropout=0.0
)
suffix = HashEmbed(
nO=width, nV=embed_size // 2, column=cols.index(SUFFIX), dropout=0.0
)
shape = HashEmbed(
nO=width, nV=embed_size // 2, column=cols.index(SHAPE), dropout=0.0
)
else:
prefix, suffix, shape = (None, None, None)
if pretrained_vectors is not None:
glove = StaticVectors(
vectors=pretrained_vectors, nO=width, column=cols.index(ID), dropout=0.0
)
if subword_features:
embed = uniqued(
(glove | norm | prefix | suffix | shape)
>> Maxout(
nO=width, nI=width * 5, nP=3, dropout=0.0, normalize=True
),
column=cols.index(ORTH),
)
else:
embed = uniqued(
(glove | norm)
>> Maxout(
nO=width, nI=width * 2, nP=3, dropout=0.0, normalize=True
),
column=cols.index(ORTH),
)
elif subword_features:
embed = uniqued(
concatenate(norm, prefix, suffix, shape)
>> Maxout(nO=width, nI=width * 4, nP=3, dropout=0.0, normalize=True),
column=cols.index(ORTH),
)
elif char_embed:
embed = CharacterEmbed(nM=64, nC=8) | FeatureExtractor(cols) >> with_array(
norm
)
reduce_dimensions = Maxout(
nO=width,
nI=64 * 8 + width,
nP=cnn_maxout_pieces,
dropout=0.0,
normalize=True,
)
else:
embed = norm
convolution = residual(
expand_window(window_size=window_size)
>> Maxout(
nO=width,
nI=width * 3,
nP=cnn_maxout_pieces,
dropout=0.0,
normalize=True,
)
)
if char_embed:
tok2vec = embed >> with_array(
reduce_dimensions >> convolution ** conv_depth, pad=conv_depth
)
else:
tok2vec = FeatureExtractor(cols) >> with_array(
embed >> convolution ** conv_depth, pad=conv_depth
)
if bilstm_depth >= 1:
tok2vec = tok2vec >> PyTorchLSTM(
nO=width, nI=width, depth=bilstm_depth, bi=True
)
# Work around thinc API limitations :(. TODO: Revise in Thinc 7
tok2vec.set_dim("nO", width)
tok2vec.set_ref("embed", embed)
return tok2vec

View File

@ -0,0 +1,6 @@
from .entity_linker import *
from .parser import *
from .tagger import *
from .tensorizer import *
from .textcat import *
from .tok2vec import *

View File

@ -0,0 +1,93 @@
from pathlib import Path
from .... import util
def default_nel_config():
loc = Path(__file__).parent / "entity_linker_defaults.cfg"
return util.load_config(loc, create_objects=False)
def default_nel():
loc = Path(__file__).parent / "entity_linker_defaults.cfg"
return util.load_config(loc, create_objects=True)["model"]
def default_morphologizer_config():
loc = Path(__file__).parent / "morphologizer_defaults.cfg"
return util.load_config(loc, create_objects=False)
def default_morphologizer():
loc = Path(__file__).parent / "morphologizer_defaults.cfg"
return util.load_config(loc, create_objects=True)["model"]
def default_parser_config():
loc = Path(__file__).parent / "parser_defaults.cfg"
return util.load_config(loc, create_objects=False)
def default_parser():
loc = Path(__file__).parent / "parser_defaults.cfg"
return util.load_config(loc, create_objects=True)["model"]
def default_ner_config():
loc = Path(__file__).parent / "ner_defaults.cfg"
return util.load_config(loc, create_objects=False)
def default_ner():
loc = Path(__file__).parent / "ner_defaults.cfg"
return util.load_config(loc, create_objects=True)["model"]
def default_sentrec_config():
loc = Path(__file__).parent / "sentrec_defaults.cfg"
return util.load_config(loc, create_objects=False)
def default_sentrec():
loc = Path(__file__).parent / "sentrec_defaults.cfg"
return util.load_config(loc, create_objects=True)["model"]
def default_tagger_config():
loc = Path(__file__).parent / "tagger_defaults.cfg"
return util.load_config(loc, create_objects=False)
def default_tagger():
loc = Path(__file__).parent / "tagger_defaults.cfg"
return util.load_config(loc, create_objects=True)["model"]
def default_tensorizer_config():
loc = Path(__file__).parent / "tensorizer_defaults.cfg"
return util.load_config(loc, create_objects=False)
def default_tensorizer():
loc = Path(__file__).parent / "tensorizer_defaults.cfg"
return util.load_config(loc, create_objects=True)["model"]
def default_textcat_config():
loc = Path(__file__).parent / "textcat_defaults.cfg"
return util.load_config(loc, create_objects=False)
def default_textcat():
loc = Path(__file__).parent / "textcat_defaults.cfg"
return util.load_config(loc, create_objects=True)["model"]
def default_tok2vec_config():
loc = Path(__file__).parent / "tok2vec_defaults.cfg"
return util.load_config(loc, create_objects=False)
def default_tok2vec():
loc = Path(__file__).parent / "tok2vec_defaults.cfg"
return util.load_config(loc, create_objects=True)["model"]
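
Each component gets two helpers: the *_config() variant returns the parsed config as a plain dict (create_objects=False), while the bare variant builds and returns the thinc Model (create_objects=True). A usage sketch:

from spacy.ml.models.defaults import default_ner, default_ner_config

cfg = default_ner_config()   # {"model": {"@architectures": "spacy.TransitionBasedParser.v1", ...}}
model = default_ner()        # fully constructed thinc Model, ready to hand to the pipe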

View File

@ -0,0 +1,12 @@
[model]
@architectures = "spacy.EntityLinker.v1"
[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v1"
pretrained_vectors = null
width = 96
depth = 2
embed_size = 300
window_size = 1
maxout_pieces = 3
subword_features = true

View File

@ -0,0 +1,14 @@
[model]
@architectures = "spacy.Tagger.v1"
[model.tok2vec]
@architectures = "spacy.HashCharEmbedCNN.v1"
pretrained_vectors = null
width = 128
depth = 4
embed_size = 7000
window_size = 1
maxout_pieces = 3
subword_features = true
nM = 64
nC = 8

View File

@ -0,0 +1,15 @@
[model]
@architectures = "spacy.TransitionBasedParser.v1"
nr_feature_tokens = 6
hidden_width = 64
maxout_pieces = 2
[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v1"
pretrained_vectors = null
width = 96
depth = 4
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true

View File

@ -0,0 +1,15 @@
[model]
@architectures = "spacy.TransitionBasedParser.v1"
nr_feature_tokens = 8
hidden_width = 64
maxout_pieces = 2
[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v1"
pretrained_vectors = null
width = 96
depth = 4
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true

View File

@ -0,0 +1,14 @@
[model]
@architectures = "spacy.Tagger.v1"
[model.tok2vec]
@architectures = "spacy.HashCharEmbedCNN.v1"
pretrained_vectors = null
width = 12
depth = 1
embed_size = 2000
window_size = 1
maxout_pieces = 2
subword_features = true
nM = 64
nC = 8

View File

@ -0,0 +1,12 @@
[model]
@architectures = "spacy.Tagger.v1"
[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v1"
pretrained_vectors = null
width = 96
depth = 4
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true

View File

@ -0,0 +1,4 @@
[model]
@architectures = "spacy.Tensorizer.v1"
input_size=96
output_size=300

View File

@ -0,0 +1,13 @@
[model]
@architectures = "spacy.TextCatCNN.v1"
exclusive_classes = false
[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v1"
pretrained_vectors = null
width = 96
depth = 4
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true

View File

@ -0,0 +1,9 @@
[model]
@architectures = "spacy.HashEmbedCNN.v1"
pretrained_vectors = null
width = 96
depth = 4
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true

View File

@ -0,0 +1,23 @@
from pathlib import Path
from thinc.api import chain, clone, list2ragged, reduce_mean, residual
from thinc.api import Model, Maxout, Linear
from spacy.util import registry
@registry.architectures.register("spacy.EntityLinker.v1")
def build_nel_encoder(tok2vec, nO=None):
with Model.define_operators({">>": chain, "**": clone}):
token_width = tok2vec.get_dim("nO")
output_layer = Linear(nO=nO, nI=token_width)
model = (
tok2vec
>> list2ragged()
>> reduce_mean()
>> residual(Maxout(nO=token_width, nI=token_width, nP=2, dropout=0.0))
>> output_layer
)
model.set_ref("output_layer", output_layer)
model.set_ref("tok2vec", tok2vec)
return model

View File

@ -0,0 +1,29 @@
from thinc.api import chain, Maxout, LayerNorm, Softmax, Linear, zero_init
def build_multi_task_model(n_tags, tok2vec=None, token_vector_width=96):
model = chain(
tok2vec,
Maxout(nO=token_vector_width * 2, nI=token_vector_width, nP=3, dropout=0.0),
LayerNorm(token_vector_width * 2),
Softmax(nO=n_tags, nI=token_vector_width * 2),
)
return model
def build_cloze_multi_task_model(vocab, tok2vec):
output_size = vocab.vectors.data.shape[1]
output_layer = chain(
Maxout(
nO=output_size, nI=tok2vec.get_dim("nO"), nP=3, normalize=True, dropout=0.0
),
Linear(nO=output_size, nI=output_size, init_W=zero_init),
)
model = chain(tok2vec, output_layer)
model = build_masked_language_model(vocab, model)
return model
def build_masked_language_model(*args, **kwargs):
# TODO cf https://github.com/explosion/spaCy/blob/2c107f02a4d60bda2440db0aad1a88cbbf4fb52d/spacy/_ml.py#L828
raise NotImplementedError

spacy/ml/models/parser.py (new file, 33 lines)
View File

@ -0,0 +1,33 @@
from pydantic import StrictInt
from spacy.util import registry
from spacy.ml._layers import PrecomputableAffine
from spacy.syntax._parser_model import ParserModel
from thinc.api import Model, chain, list2array, Linear, zero_init, use_ops
@registry.architectures.register("spacy.TransitionBasedParser.v1")
def build_tb_parser_model(
tok2vec: Model,
nr_feature_tokens: StrictInt,
hidden_width: StrictInt,
maxout_pieces: StrictInt,
nO=None,
):
token_vector_width = tok2vec.get_dim("nO")
tok2vec = chain(tok2vec, list2array())
tok2vec.set_dim("nO", token_vector_width)
lower = PrecomputableAffine(
nO=hidden_width,
nF=nr_feature_tokens,
nI=tok2vec.get_dim("nO"),
nP=maxout_pieces,
)
lower.set_dim("nP", maxout_pieces)
with use_ops("numpy"):
# Initialize weights at zero, as it's a classification layer.
upper = Linear(nO=nO, init_W=zero_init)
model = ParserModel(tok2vec, lower, upper)
return model

spacy/ml/models/tagger.py (new file, 16 lines)
View File

@ -0,0 +1,16 @@
from thinc.api import zero_init, with_array, Softmax, chain, Model
from spacy.util import registry
@registry.architectures.register("spacy.Tagger.v1")
def build_tagger_model(tok2vec, nO=None) -> Model:
token_vector_width = tok2vec.get_dim("nO")
# TODO: glorot_uniform_init seems to work a bit better than zero_init here?!
output_layer = Softmax(nO, nI=token_vector_width, init_W=zero_init)
softmax = with_array(output_layer)
model = chain(tok2vec, softmax)
model.set_ref("tok2vec", tok2vec)
model.set_ref("softmax", softmax)
model.set_ref("output_layer", output_layer)
return model

View File

@ -0,0 +1,10 @@
from thinc.api import Linear, zero_init
from ... import util
from ...util import registry
@registry.architectures.register("spacy.Tensorizer.v1")
def build_tensorizer(input_size, output_size):
input_size = util.env_opt("token_vector_width", input_size)
return Linear(output_size, input_size, init_W=zero_init)

View File

@ -0,0 +1,42 @@
from spacy.attrs import ORTH
from spacy.util import registry
from spacy.ml.extract_ngrams import extract_ngrams
from thinc.api import Model, chain, reduce_mean, Linear, list2ragged, Logistic, SparseLinear, Softmax
@registry.architectures.register("spacy.TextCatCNN.v1")
def build_simple_cnn_text_classifier(tok2vec, exclusive_classes, nO=None):
"""
Build a simple CNN text classifier, given a token-to-vector model as inputs.
If exclusive_classes=True, a softmax non-linearity is applied, so that the
outputs sum to 1. If exclusive_classes=False, a logistic non-linearity
is applied instead, so that outputs are in the range [0, 1].
"""
with Model.define_operators({">>": chain}):
if exclusive_classes:
output_layer = Softmax(nO=nO, nI=tok2vec.get_dim("nO"))
model = tok2vec >> list2ragged() >> reduce_mean() >> output_layer
model.set_ref("output_layer", output_layer)
else:
# TODO: experiment with init_w=zero_init
linear_layer = Linear(nO=nO, nI=tok2vec.get_dim("nO"))
model = tok2vec >> list2ragged() >> reduce_mean() >> linear_layer >> Logistic()
model.set_ref("output_layer", linear_layer)
model.set_ref("tok2vec", tok2vec)
model.set_dim("nO", nO)
return model
@registry.architectures.register("spacy.TextCatBOW.v1")
def build_bow_text_classifier(exclusive_classes, ngram_size, no_output_layer, nO=None):
# Note: original defaults were ngram_size=1 and no_output_layer=False
with Model.define_operators({">>": chain}):
model = extract_ngrams(ngram_size, attr=ORTH) >> SparseLinear(nO)
model.to_cpu()
if not no_output_layer:
output_layer = Softmax(nO) if exclusive_classes else Logistic(nO)
output_layer.to_cpu()
model = model >> output_layer
model.set_ref("output_layer", output_layer)
return model

spacy/ml/models/tok2vec.py (new file, 390 lines)
View File

@ -0,0 +1,390 @@
from thinc.api import chain, clone, concatenate, with_array, uniqued
from thinc.api import Model, noop, with_padded, Maxout, expand_window
from thinc.api import HashEmbed, StaticVectors, PyTorchLSTM
from thinc.api import residual, LayerNorm, FeatureExtractor, Mish
from ... import util
from ...util import registry, make_layer
from ...ml import _character_embed
from ...pipeline.tok2vec import Tok2VecListener
from ...attrs import ID, ORTH, NORM, PREFIX, SUFFIX, SHAPE
@registry.architectures.register("spacy.Tok2VecTensors.v1")
def tok2vec_tensors_v1(width):
tok2vec = Tok2VecListener("tok2vec", width=width)
return tok2vec
@registry.architectures.register("spacy.VocabVectors.v1")
def get_vocab_vectors(name):
nlp = util.load_model(name)
return nlp.vocab.vectors
@registry.architectures.register("spacy.Tok2Vec.v1")
def Tok2Vec(config):
doc2feats = make_layer(config["@doc2feats"])
embed = make_layer(config["@embed"])
encode = make_layer(config["@encode"])
field_size = 0
if encode.has_attr("receptive_field"):
field_size = encode.attrs["receptive_field"]
tok2vec = chain(doc2feats, with_array(chain(embed, encode), pad=field_size))
tok2vec.attrs["cfg"] = config
tok2vec.set_dim("nO", encode.get_dim("nO"))
tok2vec.set_ref("embed", embed)
tok2vec.set_ref("encode", encode)
return tok2vec
@registry.architectures.register("spacy.Doc2Feats.v1")
def Doc2Feats(config):
columns = config["columns"]
return FeatureExtractor(columns)
@registry.architectures.register("spacy.HashEmbedCNN.v1")
def hash_embed_cnn(
pretrained_vectors,
width,
depth,
embed_size,
maxout_pieces,
window_size,
subword_features,
):
# Does not use character embeddings: set to False by default
return build_Tok2Vec_model(
width=width,
embed_size=embed_size,
pretrained_vectors=pretrained_vectors,
conv_depth=depth,
bilstm_depth=0,
maxout_pieces=maxout_pieces,
window_size=window_size,
subword_features=subword_features,
char_embed=False,
nM=0,
nC=0,
)
@registry.architectures.register("spacy.HashCharEmbedCNN.v1")
def hash_charembed_cnn(
pretrained_vectors,
width,
depth,
embed_size,
maxout_pieces,
window_size,
subword_features,
nM=0,
nC=0,
):
# Allows using character embeddings by setting nC, nM and char_embed=True
return build_Tok2Vec_model(
width=width,
embed_size=embed_size,
pretrained_vectors=pretrained_vectors,
conv_depth=depth,
bilstm_depth=0,
maxout_pieces=maxout_pieces,
window_size=window_size,
subword_features=subword_features,
char_embed=True,
nM=nM,
nC=nC,
)
@registry.architectures.register("spacy.HashEmbedBiLSTM.v1")
def hash_embed_bilstm_v1(
pretrained_vectors, width, depth, embed_size, subword_features
):
# Does not use character embeddings: set to False by default
return build_Tok2Vec_model(
width=width,
embed_size=embed_size,
pretrained_vectors=pretrained_vectors,
bilstm_depth=depth,
conv_depth=0,
maxout_pieces=0,
window_size=1,
subword_features=subword_features,
char_embed=False,
nM=0,
nC=0,
)
@registry.architectures.register("spacy.HashCharEmbedBiLSTM.v1")
def hash_embed_bilstm_v1(
pretrained_vectors, width, depth, embed_size, subword_features, nM=0, nC=0
):
# Allows using character embeddings by setting nC, nM and char_embed=True
return build_Tok2Vec_model(
width=width,
embed_size=embed_size,
pretrained_vectors=pretrained_vectors,
bilstm_depth=depth,
conv_depth=0,
maxout_pieces=0,
window_size=1,
subword_features=subword_features,
char_embed=True,
nM=nM,
nC=nC,
)
@registry.architectures.register("spacy.MultiHashEmbed.v1")
def MultiHashEmbed(config):
# For backwards compatibility with models before the architecture registry,
# we have to be careful to get exactly the same model structure. One subtle
# trick is that when we define concatenation with the operator, the operator
# is actually binary associative. So when we write (a | b | c), we're actually
# getting concatenate(concatenate(a, b), c). That's why the implementation
# is a bit ugly here.
cols = config["columns"]
width = config["width"]
rows = config["rows"]
norm = HashEmbed(width, rows, column=cols.index("NORM"))
if config["use_subwords"]:
prefix = HashEmbed(width, rows // 2, column=cols.index("PREFIX"))
suffix = HashEmbed(width, rows // 2, column=cols.index("SUFFIX"))
shape = HashEmbed(width, rows // 2, column=cols.index("SHAPE"))
if config.get("@pretrained_vectors"):
glove = make_layer(config["@pretrained_vectors"])
mix = make_layer(config["@mix"])
with Model.define_operators({">>": chain, "|": concatenate}):
if config["use_subwords"] and config["@pretrained_vectors"]:
mix._layers[0].set_dim("nI", width * 5)
layer = uniqued(
(glove | norm | prefix | suffix | shape) >> mix,
column=cols.index("ORTH"),
)
elif config["use_subwords"]:
mix._layers[0].set_dim("nI", width * 4)
layer = uniqued(
(norm | prefix | suffix | shape) >> mix, column=cols.index("ORTH")
)
elif config["@pretrained_vectors"]:
mix._layers[0].set_dim("nI", width * 2)
layer = uniqued((glove | norm) >> mix, column=cols.index("ORTH"))
else:
layer = norm
layer.attrs["cfg"] = config
return layer
@registry.architectures.register("spacy.CharacterEmbed.v1")
def CharacterEmbed(config):
width = config["width"]
chars = config["chars"]
chr_embed = _character_embed.CharacterEmbed(nM=width, nC=chars)
other_tables = make_layer(config["@embed_features"])
mix = make_layer(config["@mix"])
model = chain(concatenate(chr_embed, other_tables), mix)
model.attrs["cfg"] = config
return model
@registry.architectures.register("spacy.MaxoutWindowEncoder.v1")
def MaxoutWindowEncoder(config):
nO = config["width"]
nW = config["window_size"]
nP = config["pieces"]
depth = config["depth"]
cnn = (
expand_window(window_size=nW),
Maxout(nO=nO, nI=nO * ((nW * 2) + 1), nP=nP, dropout=0.0, normalize=True),
)
model = clone(residual(cnn), depth)
model.set_dim("nO", nO)
model.attrs["receptive_field"] = nW * depth
return model
@registry.architectures.register("spacy.MishWindowEncoder.v1")
def MishWindowEncoder(config):
nO = config["width"]
nW = config["window_size"]
depth = config["depth"]
cnn = chain(
expand_window(window_size=nW),
Mish(nO=nO, nI=nO * ((nW * 2) + 1)),
LayerNorm(nO),
)
model = clone(residual(cnn), depth)
model.set_dim("nO", nO)
return model
@registry.architectures.register("spacy.TorchBiLSTMEncoder.v1")
def TorchBiLSTMEncoder(config):
import torch.nn
# TODO FIX
from thinc.api import PyTorchRNNWrapper
width = config["width"]
depth = config["depth"]
if depth == 0:
return noop()
return with_padded(
PyTorchRNNWrapper(torch.nn.LSTM(width, width // 2, depth, bidirectional=True))
)
# TODO: update
_EXAMPLE_CONFIG = {
"@doc2feats": {
"arch": "Doc2Feats",
"config": {"columns": ["ID", "NORM", "PREFIX", "SUFFIX", "SHAPE", "ORTH"]},
},
"@embed": {
"arch": "spacy.MultiHashEmbed.v1",
"config": {
"width": 96,
"rows": 2000,
"columns": ["ID", "NORM", "PREFIX", "SUFFIX", "SHAPE", "ORTH"],
"use_subwords": True,
"@pretrained_vectors": {
"arch": "TransformedStaticVectors",
"config": {
"vectors_name": "en_vectors_web_lg.vectors",
"width": 96,
"column": 0,
},
},
"@mix": {
"arch": "LayerNormalizedMaxout",
"config": {"width": 96, "pieces": 3},
},
},
},
"@encode": {
"arch": "MaxoutWindowEncode",
"config": {"width": 96, "window_size": 1, "depth": 4, "pieces": 3},
},
}
def build_Tok2Vec_model(
width,
embed_size,
pretrained_vectors,
window_size,
maxout_pieces,
subword_features,
char_embed,
nM,
nC,
conv_depth,
bilstm_depth,
) -> Model:
if char_embed:
subword_features = False
cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
with Model.define_operators({">>": chain, "|": concatenate, "**": clone}):
norm = HashEmbed(nO=width, nV=embed_size, column=cols.index(NORM))
if subword_features:
prefix = HashEmbed(nO=width, nV=embed_size // 2, column=cols.index(PREFIX))
suffix = HashEmbed(nO=width, nV=embed_size // 2, column=cols.index(SUFFIX))
shape = HashEmbed(nO=width, nV=embed_size // 2, column=cols.index(SHAPE))
else:
prefix, suffix, shape = (None, None, None)
if pretrained_vectors is not None:
glove = StaticVectors(
vectors=pretrained_vectors.data,
nO=width,
column=cols.index(ID),
dropout=0.0,
)
if subword_features:
columns = 5
embed = uniqued(
(glove | norm | prefix | suffix | shape)
>> Maxout(
nO=width,
nI=width * columns,
nP=maxout_pieces,
dropout=0.0,
normalize=True,
),
column=cols.index(ORTH),
)
else:
columns = 2
embed = uniqued(
(glove | norm)
>> Maxout(
nO=width,
nI=width * columns,
nP=maxout_pieces,
dropout=0.0,
normalize=True,
),
column=cols.index(ORTH),
)
elif subword_features:
columns = 4
embed = uniqued(
concatenate(norm, prefix, suffix, shape)
>> Maxout(
nO=width,
nI=width * columns,
nP=maxout_pieces,
dropout=0.0,
normalize=True,
),
column=cols.index(ORTH),
)
elif char_embed:
embed = _character_embed.CharacterEmbed(nM=nM, nC=nC) | FeatureExtractor(
cols
) >> with_array(norm)
reduce_dimensions = Maxout(
nO=width,
nI=nM * nC + width,
nP=maxout_pieces,
dropout=0.0,
normalize=True,
)
else:
embed = norm
convolution = residual(
expand_window(window_size=window_size)
>> Maxout(
nO=width,
nI=width * ((window_size * 2) + 1),
nP=maxout_pieces,
dropout=0.0,
normalize=True,
)
)
if char_embed:
tok2vec = embed >> with_array(
reduce_dimensions >> convolution ** conv_depth, pad=conv_depth
)
else:
tok2vec = FeatureExtractor(cols) >> with_array(
embed >> convolution ** conv_depth, pad=conv_depth
)
if bilstm_depth >= 1:
tok2vec = tok2vec >> PyTorchLSTM(
nO=width, nI=width, depth=bilstm_depth, bi=True
)
tok2vec.set_dim("nO", width)
tok2vec.set_ref("embed", embed)
return tok2vec
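# Rough usage sketch (not part of the diff). The argument values below are
# illustrative only; nM/nC are ignored unless char_embed is True:
#
#     tok2vec = build_Tok2Vec_model(
#         width=96,
#         embed_size=2000,
#         pretrained_vectors=None,
#         window_size=1,
#         maxout_pieces=3,
#         subword_features=True,
#         char_embed=False,
#         nM=64,
#         nC=8,
#         conv_depth=4,
#         bilstm_depth=0,
#     )
#     # the returned thinc Model has nO == width and an "embed" ref set above
#     embedding_layer = tok2vec.get_ref("embed")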

View File

@ -1,178 +0,0 @@
from thinc.api import Model, chain, clone, concatenate, with_array, uniqued, noop
from thinc.api import with_padded, Maxout, expand_window, HashEmbed, StaticVectors
from thinc.api import residual, LayerNorm, FeatureExtractor
from ..ml import _character_embed
from ..util import make_layer, registry
@registry.architectures.register("spacy.Tok2Vec.v1")
def Tok2Vec(config):
doc2feats = make_layer(config["@doc2feats"])
embed = make_layer(config["@embed"])
encode = make_layer(config["@encode"])
field_size = 0
if encode.has_attr("receptive_field"):
field_size = encode.attrs["receptive_field"]
tok2vec = chain(doc2feats, with_array(chain(embed, encode), pad=field_size))
tok2vec.attrs["cfg"] = config
tok2vec.set_dim("nO", encode.get_dim("nO"))
tok2vec.set_ref("embed", embed)
tok2vec.set_ref("encode", encode)
return tok2vec
@registry.architectures.register("spacy.Doc2Feats.v1")
def Doc2Feats(config):
columns = config["columns"]
return FeatureExtractor(columns)
@registry.architectures.register("spacy.MultiHashEmbed.v1")
def MultiHashEmbed(config):
# For backwards compatibility with models before the architecture registry,
# we have to be careful to get exactly the same model structure. One subtle
# trick is that when we define concatenation with the operator, the operator
# is binary and groups left to right. So when we write (a | b | c), we're actually
# getting concatenate(concatenate(a, b), c). That's why the implementation
# is a bit ugly here.
cols = config["columns"]
width = config["width"]
rows = config["rows"]
norm = HashEmbed(width, rows, column=cols.index("NORM"), dropout=0.0)
if config["use_subwords"]:
prefix = HashEmbed(width, rows // 2, column=cols.index("PREFIX"), dropout=0.0)
suffix = HashEmbed(width, rows // 2, column=cols.index("SUFFIX"), dropout=0.0)
shape = HashEmbed(width, rows // 2, column=cols.index("SHAPE"), dropout=0.0)
if config.get("@pretrained_vectors"):
glove = make_layer(config["@pretrained_vectors"])
mix = make_layer(config["@mix"])
with Model.define_operators({">>": chain, "|": concatenate}):
if config["use_subwords"] and config["@pretrained_vectors"]:
mix._layers[0].set_dim("nI", width * 5)
layer = uniqued(
(glove | norm | prefix | suffix | shape) >> mix,
column=cols.index("ORTH"),
)
elif config["use_subwords"]:
mix._layers[0].set_dim("nI", width * 4)
layer = uniqued(
(norm | prefix | suffix | shape) >> mix, column=cols.index("ORTH")
)
elif config["@pretrained_vectors"]:
mix._layers[0].set_dim("nI", width * 2)
layer = uniqued((glove | norm) >> mix, column=cols.index("ORTH"),)
else:
layer = norm
layer.attrs["cfg"] = config
return layer
@registry.architectures.register("spacy.CharacterEmbed.v1")
def CharacterEmbed(config):
width = config["width"]
chars = config["chars"]
chr_embed = _character_embed.CharacterEmbed(nM=width, nC=chars)
other_tables = make_layer(config["@embed_features"])
mix = make_layer(config["@mix"])
model = chain(concatenate(chr_embed, other_tables), mix)
model.attrs["cfg"] = config
return model
@registry.architectures.register("spacy.MaxoutWindowEncoder.v1")
def MaxoutWindowEncoder(config):
nO = config["width"]
nW = config["window_size"]
nP = config["pieces"]
depth = config["depth"]
cnn = (
expand_window(window_size=nW),
Maxout(nO=nO, nI=nO * ((nW * 2) + 1), nP=nP, dropout=0.0, normalize=True),
)
model = clone(residual(cnn), depth)
model.set_dim("nO", nO)
model.attrs["receptive_field"] = nW * depth
return model
@registry.architectures.register("spacy.MishWindowEncoder.v1")
def MishWindowEncoder(config):
from thinc.api import Mish
nO = config["width"]
nW = config["window_size"]
depth = config["depth"]
cnn = chain(
expand_window(window_size=nW),
Mish(nO=nO, nI=nO * ((nW * 2) + 1)),
LayerNorm(nO),
)
model = clone(residual(cnn), depth)
model.set_dim("nO", nO)
return model
@registry.architectures.register("spacy.PretrainedVectors.v1")
def PretrainedVectors(config):
# TODO: actual vectors instead of name
return StaticVectors(
vectors=config["vectors_name"],
nO=config["width"],
column=config["column"],
dropout=0.0,
)
@registry.architectures.register("spacy.TorchBiLSTMEncoder.v1")
def TorchBiLSTMEncoder(config):
import torch.nn
# TODO: FIX
from thinc.api import PyTorchRNNWrapper
width = config["width"]
depth = config["depth"]
if depth == 0:
return noop()
return with_padded(
PyTorchRNNWrapper(torch.nn.LSTM(width, width // 2, depth, bidirectional=True))
)
# TODO: update
_EXAMPLE_CONFIG = {
"@doc2feats": {
"arch": "Doc2Feats",
"config": {"columns": ["ID", "NORM", "PREFIX", "SUFFIX", "SHAPE", "ORTH"]},
},
"@embed": {
"arch": "spacy.MultiHashEmbed.v1",
"config": {
"width": 96,
"rows": 2000,
"columns": ["ID", "NORM", "PREFIX", "SUFFIX", "SHAPE", "ORTH"],
"use_subwords": True,
"@pretrained_vectors": {
"arch": "TransformedStaticVectors",
"config": {
"vectors_name": "en_vectors_web_lg.vectors",
"width": 96,
"column": 0,
},
},
"@mix": {
"arch": "LayerNormalizedMaxout",
"config": {"width": 96, "pieces": 3},
},
},
},
"@encode": {
"arch": "MaxoutWindowEncode",
"config": {"width": 96, "window_size": 1, "depth": 4, "pieces": 3},
},
}

View File

@ -66,7 +66,7 @@ class EntityRuler(object):
self.add_patterns(patterns)
@classmethod
def from_nlp(cls, nlp, **cfg):
def from_nlp(cls, nlp, model=None, **cfg):
return cls(nlp, **cfg)
def __len__(self):

View File

@ -76,11 +76,9 @@ class SimilarityHook(Pipe):
yield self(doc)
def predict(self, doc1, doc2):
self.require_model()
return self.model.predict([(doc1, doc2)])
def update(self, doc1_doc2, golds, sgd=None, drop=0.0):
self.require_model()
sims, bp_sims = self.model.begin_update(doc1_doc2)
def begin_training(self, _=tuple(), pipeline=None, sgd=None, **kwargs):

View File

@ -15,25 +15,15 @@ from ..tokens.doc cimport Doc
from ..vocab cimport Vocab
from ..morphology cimport Morphology
from ..ml.component_models import build_morphologizer_model
@component("morphologizer", assigns=["token.morph", "token.pos"])
class Morphologizer(Pipe):
@classmethod
def Model(cls, **cfg):
if cfg.get('pretrained_dims') and not cfg.get('pretrained_vectors'):
raise ValueError(TempErrors.T008)
class_map = Morphology.create_class_map()
return build_morphologizer_model(class_map.field_sizes, **cfg)
def __init__(self, vocab, model=True, **cfg):
def __init__(self, vocab, model, **cfg):
self.vocab = vocab
self.model = model
self.cfg = dict(sorted(cfg.items()))
self.cfg.setdefault('cnn_maxout_pieces', 2)
self._class_map = self.vocab.morphology.create_class_map()
self._class_map = self.vocab.morphology.create_class_map() # Morphology.create_class_map() ?
@property
def labels(self):
@ -58,6 +48,14 @@ class Morphologizer(Pipe):
self.set_annotations(docs, features, tensors=tokvecs)
yield from docs
def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None,
**kwargs):
self.set_output(len(self.labels))
self.model.initialize()
if sgd is None:
sgd = self.create_optimizer()
return sgd
def predict(self, docs):
if not any(len(doc) for doc in docs):
# Handle case where there are no tokens in any docs.
@ -65,8 +63,8 @@ class Morphologizer(Pipe):
guesses = [self.model.ops.alloc((0, n_labels)) for doc in docs]
tokvecs = self.model.ops.alloc((0, self.model.get_ref("tok2vec").get_dim("nO")))
return guesses, tokvecs
tokvecs = self.model.tok2vec(docs)
scores = self.model.softmax(tokvecs)
tokvecs = self.model.get_ref("tok2vec")(docs)
scores = self.model.get_ref("softmax")(tokvecs)
return scores, tokvecs
def set_annotations(self, docs, batch_scores, tensors=None):

View File

@ -3,8 +3,7 @@
import numpy
import srsly
import random
from thinc.api import chain, Linear, Maxout, Softmax, LayerNorm, list2array
from thinc.api import zero_init, CosineDistance, to_categorical, get_array_module
from thinc.api import CosineDistance, to_categorical, get_array_module
from thinc.api import set_dropout_rate
from ..tokens.doc cimport Doc
@ -22,11 +21,6 @@ from ..attrs import POS, ID
from ..util import link_vectors_to_models, create_default_optimizer
from ..parts_of_speech import X
from ..kb import KnowledgeBase
from ..ml.component_models import Tok2Vec, build_tagger_model
from ..ml.component_models import build_text_classifier
from ..ml.component_models import build_simple_cnn_text_classifier
from ..ml.component_models import build_bow_text_classifier, build_nel_encoder
from ..ml.component_models import masked_language_model
from ..errors import Errors, TempErrors, user_warning, Warnings
from .. import util
@ -47,13 +41,8 @@ class Pipe(object):
name = None
@classmethod
def Model(cls, *shape, **kwargs):
"""Initialize a model for the pipe."""
raise NotImplementedError
@classmethod
def from_nlp(cls, nlp, **cfg):
return cls(nlp.vocab, **cfg)
def from_nlp(cls, nlp, model, **cfg):
return cls(nlp.vocab, model, **cfg)
def _get_doc(self, example):
""" Use this method if the `example` can be both a Doc or an Example """
@ -61,7 +50,7 @@ class Pipe(object):
return example
return example.doc
def __init__(self, vocab, model=True, **cfg):
def __init__(self, vocab, model, **cfg):
"""Create a new pipe instance."""
raise NotImplementedError
@ -72,7 +61,6 @@ class Pipe(object):
Both __call__ and pipe should delegate to the `predict()`
and `set_annotations()` methods.
"""
self.require_model()
doc = self._get_doc(example)
predictions = self.predict([doc])
if isinstance(predictions, tuple) and len(predictions) == 2:
@ -85,11 +73,6 @@ class Pipe(object):
return example
return doc
def require_model(self):
"""Raise an error if the component's model is not initialized."""
if getattr(self, "model", None) in (None, True, False):
raise ValueError(Errors.E109.format(name=self.name))
def pipe(self, stream, batch_size=128, n_threads=-1, as_example=False):
"""Apply the pipe to a stream of documents.
@ -116,7 +99,6 @@ class Pipe(object):
"""Apply the pipeline's model to a batch of docs, without
modifying them.
"""
self.require_model()
raise NotImplementedError
def set_annotations(self, docs, scores, tensors=None):
@ -158,22 +140,23 @@ class Pipe(object):
):
"""Initialize the pipe for training, using data exampes if available.
If no model has been initialized yet, the model is added."""
if self.model is True:
self.model = self.Model(**self.cfg)
self.model.initialize()
if hasattr(self, "vocab"):
link_vectors_to_models(self.vocab)
self.model.initialize()
if sgd is None:
sgd = self.create_optimizer()
return sgd
def set_output(self, nO):
self.model.set_dim("nO", nO)
if self.model.has_ref("output_layer"):
self.model.get_ref("output_layer").set_dim("nO", nO)
def get_gradients(self):
"""Get non-zero gradients of the model's parameters, as a dictionary
keyed by the parameter ID. The values are (weights, gradients) tuples.
"""
gradients = {}
if self.model in (None, True, False):
return gradients
queue = [self.model]
seen = set()
for node in queue:
@ -199,8 +182,7 @@ class Pipe(object):
"""
serialize = {}
serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
if self.model not in (True, False, None):
serialize["model"] = self.model.to_bytes
serialize["model"] = self.model.to_bytes
if hasattr(self, "vocab"):
serialize["vocab"] = self.vocab.to_bytes
exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
@ -210,20 +192,15 @@ class Pipe(object):
"""Load the pipe from a bytestring."""
def load_model(b):
# TODO: Remove this once we don't have to handle previous models
if self.cfg.get("pretrained_dims") and "pretrained_vectors" not in self.cfg:
self.cfg["pretrained_vectors"] = self.vocab.vectors
if self.model is True:
self.model = self.Model(**self.cfg)
try:
self.model.from_bytes(b)
except AttributeError:
raise ValueError(Errors.E149)
deserialize = {}
deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
if hasattr(self, "vocab"):
deserialize["vocab"] = lambda b: self.vocab.from_bytes(b)
deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
deserialize["model"] = load_model
exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
util.from_bytes(bytes_data, deserialize, exclude)
@ -234,8 +211,7 @@ class Pipe(object):
serialize = {}
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
serialize["vocab"] = lambda p: self.vocab.to_disk(p)
if self.model not in (None, True, False):
serialize["model"] = lambda p: p.open("wb").write(self.model.to_bytes())
serialize["model"] = lambda p: p.open("wb").write(self.model.to_bytes())
exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
util.to_disk(path, serialize, exclude)
@ -243,19 +219,14 @@ class Pipe(object):
"""Load the pipe from disk."""
def load_model(p):
# TODO: Remove this once we don't have to handle previous models
if self.cfg.get("pretrained_dims") and "pretrained_vectors" not in self.cfg:
self.cfg["pretrained_vectors"] = self.vocab.vectors
if self.model is True:
self.model = self.Model(**self.cfg)
try:
self.model.from_bytes(p.open("rb").read())
except AttributeError:
raise ValueError(Errors.E149)
deserialize = {}
deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p))
deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p))
deserialize["model"] = load_model
exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
util.from_disk(path, deserialize, exclude)
@ -266,31 +237,13 @@ class Pipe(object):
class Tensorizer(Pipe):
"""Pre-train position-sensitive vectors for tokens."""
@classmethod
def Model(cls, output_size=300, **cfg):
"""Create a new statistical model for the class.
width (int): Output size of the model.
embed_size (int): Number of vectors in the embedding table.
**cfg: Config parameters.
RETURNS (Model): A `thinc.model.Model` or similar instance.
"""
input_size = util.env_opt("token_vector_width", cfg.get("input_size", 96))
return Linear(output_size, input_size, init_W=zero_init)
def __init__(self, vocab, model=True, **cfg):
def __init__(self, vocab, model, **cfg):
"""Construct a new statistical model. Weights are not allocated on
initialisation.
vocab (Vocab): A `Vocab` instance. The model must share the same
`Vocab` instance with the `Doc` objects it will process.
model (Model): A `Model` instance or `True` to allocate one later.
**cfg: Config parameters.
EXAMPLE:
>>> from spacy.pipeline import TokenVectorEncoder
>>> tok2vec = TokenVectorEncoder(nlp.vocab)
>>> tok2vec.model = tok2vec.Model(128, 5000)
"""
self.vocab = vocab
self.model = model
@ -337,7 +290,6 @@ class Tensorizer(Pipe):
docs (iterable): A sequence of `Doc` objects.
RETURNS (object): Vector representations for each token in the docs.
"""
self.require_model()
inputs = self.model.ops.flatten([doc.tensor for doc in docs])
outputs = self.model(inputs)
return self.model.ops.unflatten(outputs, [len(d) for d in docs])
@ -362,7 +314,6 @@ class Tensorizer(Pipe):
sgd (callable): An optimizer.
RETURNS (dict): Results from the update.
"""
self.require_model()
examples = Example.to_example_objects(examples)
inputs = []
bp_inputs = []
@ -405,10 +356,8 @@ class Tensorizer(Pipe):
"""
if pipeline is not None:
for name, model in pipeline:
if getattr(model, "tok2vec", None):
self.input_models.append(model.tok2vec)
if self.model is True:
self.model = self.Model(**self.cfg)
if model.has_ref("tok2vec"):
self.input_models.append(model.get_ref("tok2vec"))
self.model.initialize()
link_vectors_to_models(self.vocab)
if sgd is None:
@ -423,7 +372,7 @@ class Tagger(Pipe):
DOCS: https://spacy.io/api/tagger
"""
def __init__(self, vocab, model=True, **cfg):
def __init__(self, vocab, model, **cfg):
self.vocab = vocab
self.model = model
self._rehearsal_model = None
@ -433,13 +382,6 @@ class Tagger(Pipe):
def labels(self):
return tuple(self.vocab.morphology.tag_names)
@property
def tok2vec(self):
if self.model in (None, True, False):
return None
else:
return chain(self.model.get_ref("tok2vec"), list2array())
def __call__(self, example):
doc = self._get_doc(example)
tags = self.predict([doc])
@ -465,7 +407,6 @@ class Tagger(Pipe):
yield from docs
def predict(self, docs):
self.require_model()
if not any(len(doc) for doc in docs):
# Handle cases where there are no tokens in any docs.
n_labels = len(self.labels)
@ -513,7 +454,6 @@ class Tagger(Pipe):
doc.is_tagged = True
def update(self, examples, drop=0., sgd=None, losses=None, set_annotations=False):
self.require_model()
examples = Example.to_example_objects(examples)
if losses is not None and self.name not in losses:
losses[self.name] = 0.
@ -600,52 +540,21 @@ class Tagger(Pipe):
vocab.morphology = Morphology(vocab.strings, new_tag_map,
vocab.morphology.lemmatizer,
exc=vocab.morphology.exc)
self.cfg["pretrained_vectors"] = kwargs.get("pretrained_vectors")
if self.model is True:
for hp in ["token_vector_width", "conv_depth"]:
if hp in kwargs:
self.cfg[hp] = kwargs[hp]
self.model = self.Model(self.vocab.morphology.n_tags, **self.cfg)
self.set_output(len(self.labels))
self.model.initialize()
# Get batch of example docs, example outputs to call begin_training().
# This lets the model infer shapes.
n_tags = self.vocab.morphology.n_tags
for node in self.model.walk():
# TODO: softmax hack ?
if node.name == "softmax" and node.has_dim("nO") is None:
node.set_dim("nO", n_tags)
link_vectors_to_models(self.vocab)
self.model.initialize()
if sgd is None:
sgd = self.create_optimizer()
return sgd
@classmethod
def Model(cls, n_tags=None, **cfg):
if cfg.get("pretrained_dims") and not cfg.get("pretrained_vectors"):
raise ValueError(TempErrors.T008)
if "tok2vec" in cfg:
tok2vec = cfg["tok2vec"]
else:
config = {
"width": cfg.get("token_vector_width", 96),
"embed_size": cfg.get("embed_size", 2000),
"pretrained_vectors": cfg.get("pretrained_vectors", None),
"window_size": cfg.get("window_size", 1),
"cnn_maxout_pieces": cfg.get("cnn_maxout_pieces", 3),
"subword_features": cfg.get("subword_features", True),
"char_embed": cfg.get("char_embed", False),
"conv_depth": cfg.get("conv_depth", 4),
"bilstm_depth": cfg.get("bilstm_depth", 0),
}
tok2vec = Tok2Vec(**config)
return build_tagger_model(n_tags, tok2vec)
def add_label(self, label, values=None):
if not isinstance(label, str):
raise ValueError(Errors.E187)
if label in self.labels:
return 0
if self.model not in (True, False, None):
if self.model.has_dim("nO"):
# Here's how the model resizing will work, once the
# neuron-to-tag mapping is no longer controlled by
# the Morphology class, which sorts the tag names.
@ -672,8 +581,7 @@ class Tagger(Pipe):
def to_bytes(self, exclude=tuple(), **kwargs):
serialize = {}
if self.model not in (None, True, False):
serialize["model"] = self.model.to_bytes
serialize["model"] = self.model.to_bytes
serialize["vocab"] = self.vocab.to_bytes
serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
tag_map = dict(sorted(self.vocab.morphology.tag_map.items()))
@ -683,14 +591,6 @@ class Tagger(Pipe):
def from_bytes(self, bytes_data, exclude=tuple(), **kwargs):
def load_model(b):
# TODO: Remove this once we don't have to handle previous models
if self.cfg.get("pretrained_dims") and "pretrained_vectors" not in self.cfg:
self.cfg["pretrained_vectors"] = self.vocab.vectors
if self.model is True:
token_vector_width = util.env_opt(
"token_vector_width",
self.cfg.get("token_vector_width", 96))
self.model = self.Model(**self.cfg)
try:
self.model.from_bytes(b)
except AttributeError:
@ -719,18 +619,13 @@ class Tagger(Pipe):
"vocab": lambda p: self.vocab.to_disk(p),
"tag_map": lambda p: srsly.write_msgpack(p, tag_map),
"model": lambda p: p.open("wb").write(self.model.to_bytes()),
"cfg": lambda p: srsly.write_json(p, self.cfg)
"cfg": lambda p: srsly.write_json(p, self.cfg),
}
exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
util.to_disk(path, serialize, exclude)
def from_disk(self, path, exclude=tuple(), **kwargs):
def load_model(p):
# TODO: Remove this once we don't have to handle previous models
if self.cfg.get("pretrained_dims") and "pretrained_vectors" not in self.cfg:
self.cfg["pretrained_vectors"] = self.vocab.vectors
if self.model is True:
self.model = self.Model(**self.cfg)
with p.open("rb") as file_:
try:
self.model.from_bytes(file_.read())
@ -745,8 +640,8 @@ class Tagger(Pipe):
exc=self.vocab.morphology.exc)
deserialize = {
"cfg": lambda p: self.cfg.update(_load_cfg(p)),
"vocab": lambda p: self.vocab.from_disk(p),
"cfg": lambda p: self.cfg.update(_load_cfg(p)),
"tag_map": load_tag_map,
"model": load_model,
}
@ -762,16 +657,11 @@ class SentenceRecognizer(Tagger):
DOCS: https://spacy.io/api/sentencerecognizer
"""
def __init__(self, vocab, model=True, **cfg):
def __init__(self, vocab, model, **cfg):
self.vocab = vocab
self.model = model
self._rehearsal_model = None
self.cfg = dict(sorted(cfg.items()))
self.cfg.setdefault("cnn_maxout_pieces", 2)
self.cfg.setdefault("subword_features", True)
self.cfg.setdefault("token_vector_width", 12)
self.cfg.setdefault("conv_depth", 1)
self.cfg.setdefault("pretrained_vectors", None)
@property
def labels(self):
@ -797,7 +687,6 @@ class SentenceRecognizer(Tagger):
doc.c[j].sent_start = -1
def update(self, examples, drop=0., sgd=None, losses=None):
self.require_model()
examples = Example.to_example_objects(examples)
if losses is not None and self.name not in losses:
losses[self.name] = 0.
@ -844,20 +733,12 @@ class SentenceRecognizer(Tagger):
def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None,
**kwargs):
cdef Vocab vocab = self.vocab
if self.model is True:
for hp in ["token_vector_width", "conv_depth"]:
if hp in kwargs:
self.cfg[hp] = kwargs[hp]
self.model = self.Model(len(self.labels), **self.cfg)
self.set_output(len(self.labels))
self.model.initialize()
if sgd is None:
sgd = self.create_optimizer()
self.model.initialize()
return sgd
@classmethod
def Model(cls, n_tags, **cfg):
return build_tagger_model(n_tags, **cfg)
def add_label(self, label, values=None):
raise NotImplementedError
@ -867,8 +748,7 @@ class SentenceRecognizer(Tagger):
def to_bytes(self, exclude=tuple(), **kwargs):
serialize = {}
if self.model not in (None, True, False):
serialize["model"] = self.model.to_bytes
serialize["model"] = self.model.to_bytes
serialize["vocab"] = self.vocab.to_bytes
serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
@ -876,8 +756,6 @@ class SentenceRecognizer(Tagger):
def from_bytes(self, bytes_data, exclude=tuple(), **kwargs):
def load_model(b):
if self.model is True:
self.model = self.Model(len(self.labels), **self.cfg)
try:
self.model.from_bytes(b)
except AttributeError:
@ -896,15 +774,13 @@ class SentenceRecognizer(Tagger):
serialize = {
"vocab": lambda p: self.vocab.to_disk(p),
"model": lambda p: p.open("wb").write(self.model.to_bytes()),
"cfg": lambda p: srsly.write_json(p, self.cfg)
"cfg": lambda p: srsly.write_json(p, self.cfg),
}
exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
util.to_disk(path, serialize, exclude)
def from_disk(self, path, exclude=tuple(), **kwargs):
def load_model(p):
if self.model is True:
self.model = self.Model(len(self.labels), **self.cfg)
with p.open("rb") as file_:
try:
self.model.from_bytes(file_.read())
@ -912,8 +788,8 @@ class SentenceRecognizer(Tagger):
raise ValueError(Errors.E149)
deserialize = {
"cfg": lambda p: self.cfg.update(_load_cfg(p)),
"vocab": lambda p: self.vocab.from_disk(p),
"cfg": lambda p: self.cfg.update(_load_cfg(p)),
"model": load_model,
}
exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
@ -927,7 +803,7 @@ class MultitaskObjective(Tagger):
side-objective.
"""
def __init__(self, vocab, model=True, target='dep_tag_offset', **cfg):
def __init__(self, vocab, model, target='dep_tag_offset', **cfg):
self.vocab = vocab
self.model = model
if target == "dep":
@ -947,7 +823,8 @@ class MultitaskObjective(Tagger):
else:
raise ValueError(Errors.E016)
self.cfg = dict(cfg)
self.cfg.setdefault("cnn_maxout_pieces", 2)
# TODO: remove - put in config
self.cfg.setdefault("maxout_pieces", 2)
@property
def labels(self):
@ -969,30 +846,15 @@ class MultitaskObjective(Tagger):
label = self.make_label(i, example.token_annotation)
if label is not None and label not in self.labels:
self.labels[label] = len(self.labels)
if self.model is True:
token_vector_width = util.env_opt("token_vector_width")
self.model = self.Model(len(self.labels), tok2vec=tok2vec)
link_vectors_to_models(self.vocab)
self.model.initialize()
link_vectors_to_models(self.vocab)
if sgd is None:
sgd = self.create_optimizer()
return sgd
@classmethod
def Model(cls, n_tags, tok2vec=None, **cfg):
token_vector_width = util.env_opt("token_vector_width", 96)
model = chain(
tok2vec,
Maxout(nO=token_vector_width*2, nI=token_vector_width, nP=3, dropout=0.0),
LayerNorm(token_vector_width*2),
Softmax(nO=n_tags, nI=token_vector_width*2)
)
return model
def predict(self, docs):
self.require_model()
tokvecs = self.model.tok2vec(docs)
scores = self.model.softmax(tokvecs)
tokvecs = self.model.get_ref("tok2vec")(docs)
scores = self.model.get_ref("softmax")(tokvecs)
return tokvecs, scores
def get_loss(self, examples, scores):
@ -1097,18 +959,7 @@ class MultitaskObjective(Tagger):
class ClozeMultitask(Pipe):
@classmethod
def Model(cls, vocab, tok2vec, **cfg):
output_size = vocab.vectors.data.shape[1]
output_layer = chain(
Maxout(nO=output_size, nI=tok2vec.get_dim("nO"), nP=3, normalize=True, dropout=0.0),
Linear(nO=output_size, nI=output_size, init_W=zero_init)
)
model = chain(tok2vec, output_layer)
model = masked_language_model(vocab, model)
return model
def __init__(self, vocab, model=True, **cfg):
def __init__(self, vocab, model, **cfg):
self.vocab = vocab
self.model = model
self.cfg = cfg
@ -1120,19 +971,16 @@ class ClozeMultitask(Pipe):
def begin_training(self, get_examples=lambda: [], pipeline=None,
tok2vec=None, sgd=None, **kwargs):
link_vectors_to_models(self.vocab)
if self.model is True:
self.model = self.Model(self.vocab, tok2vec)
X = self.model.ops.alloc((5, self.model.get_ref("tok2vec").get_dim("nO")))
self.model.initialize()
X = self.model.ops.alloc((5, self.model.get_ref("tok2vec").get_dim("nO")))
self.model.output_layer.begin_training(X)
if sgd is None:
sgd = self.create_optimizer()
return sgd
def predict(self, docs):
self.require_model()
tokvecs = self.model.tok2vec(docs)
vectors = self.model.output_layer(tokvecs)
tokvecs = self.model.get_ref("tok2vec")(docs)
vectors = self.model.get_ref("output_layer")(tokvecs)
return tokvecs, vectors
def get_loss(self, examples, vectors, prediction):
@ -1150,7 +998,6 @@ class ClozeMultitask(Pipe):
pass
def rehearse(self, examples, drop=0., sgd=None, losses=None):
self.require_model()
examples = Example.to_example_objects(examples)
if losses is not None and self.name not in losses:
losses[self.name] = 0.
@ -1171,62 +1018,11 @@ class TextCategorizer(Pipe):
DOCS: https://spacy.io/api/textcategorizer
"""
@classmethod
def Model(cls, nr_class=1, exclusive_classes=None, **cfg):
if nr_class == 1:
exclusive_classes = False
if exclusive_classes is None:
raise ValueError(
"TextCategorizer Model must specify 'exclusive_classes'. "
"This setting determines whether the model will output "
"scores that sum to 1 for each example. If only one class "
"is true for each example, you should set exclusive_classes=True. "
"For 'multi_label' classification, set exclusive_classes=False."
)
if "embed_size" not in cfg:
cfg["embed_size"] = util.env_opt("embed_size", 2000)
if "token_vector_width" not in cfg:
cfg["token_vector_width"] = util.env_opt("token_vector_width", 96)
if cfg.get("architecture") == "bow":
return build_bow_text_classifier(nr_class, exclusive_classes, **cfg)
else:
if "tok2vec" in cfg:
tok2vec = cfg["tok2vec"]
else:
config = {
"width": cfg.get("token_vector_width", 96),
"embed_size": cfg.get("embed_size", 2000),
"pretrained_vectors": cfg.get("pretrained_vectors", None),
"window_size": cfg.get("window_size", 1),
"cnn_maxout_pieces": cfg.get("cnn_maxout_pieces", 3),
"subword_features": cfg.get("subword_features", True),
"char_embed": cfg.get("char_embed", False),
"conv_depth": cfg.get("conv_depth", 4),
"bilstm_depth": cfg.get("bilstm_depth", 0),
}
tok2vec = Tok2Vec(**config)
return build_simple_cnn_text_classifier(
tok2vec,
nr_class,
exclusive_classes,
**cfg
)
@property
def tok2vec(self):
if self.model in (None, True, False):
return None
else:
return self.model.tok2vec
def __init__(self, vocab, model=True, **cfg):
def __init__(self, vocab, model, **cfg):
self.vocab = vocab
self.model = model
self._rehearsal_model = None
self.cfg = dict(cfg)
if "exclusive_classes" not in cfg:
self.cfg["exclusive_classes"] = True
@property
def labels(self):
@ -1255,7 +1051,6 @@ class TextCategorizer(Pipe):
yield from docs
def predict(self, docs):
self.require_model()
tensors = [doc.tensor for doc in docs]
if not any(len(doc) for doc in docs):
@ -1274,7 +1069,6 @@ class TextCategorizer(Pipe):
doc.cats[label] = float(scores[i, j])
def update(self, examples, state=None, drop=0., set_annotations=False, sgd=None, losses=None):
self.require_model()
examples = Example.to_example_objects(examples)
if not any(len(ex.doc) if ex.doc else 0 for ex in examples):
# Handle cases where there are no tokens in any docs.
@ -1311,7 +1105,7 @@ class TextCategorizer(Pipe):
losses.setdefault(self.name, 0.0)
losses[self.name] += (gradient**2).sum()
def get_loss(self, examples, scores):
def _examples_to_truth(self, examples):
golds = [ex.gold for ex in examples]
truths = numpy.zeros((len(golds), len(self.labels)), dtype="f")
not_missing = numpy.ones((len(golds), len(self.labels)), dtype="f")
@ -1322,6 +1116,10 @@ class TextCategorizer(Pipe):
else:
not_missing[i, j] = 0.
truths = self.model.ops.asarray(truths)
return truths, not_missing
def get_loss(self, examples, scores):
truths, not_missing = self._examples_to_truth(examples)
not_missing = self.model.ops.asarray(not_missing)
d_scores = (scores-truths) / scores.shape[0]
d_scores *= not_missing
@ -1333,7 +1131,7 @@ class TextCategorizer(Pipe):
raise ValueError(Errors.E187)
if label in self.labels:
return 0
if self.model not in (None, True, False):
if self.model.has_dim("nO"):
# This functionality was available previously, but was broken.
# The problem is that we resize the last layer, but the last layer
# is actually just an ensemble. We're not resizing the child layers
@ -1348,19 +1146,18 @@ class TextCategorizer(Pipe):
return 1
def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None, **kwargs):
for example in get_examples():
# TODO: begin_training is not guaranteed to see all data / labels ?
examples = list(get_examples())
for example in examples:
for cat in example.doc_annotation.cats:
self.add_label(cat)
if self.model is True:
self.cfg.update(kwargs)
self.require_labels()
self.model = self.Model(len(self.labels), **self.cfg)
link_vectors_to_models(self.vocab)
self.require_labels()
docs = [Doc(Vocab(), words=["hello"])]
truths, _ = self._examples_to_truth(examples)
self.set_output(len(self.labels))
self.model.initialize(X=docs, Y=truths)
if sgd is None:
sgd = self.create_optimizer()
# TODO: use get_examples instead
docs = [Doc(Vocab(), words=["hello"])]
self.model.initialize(X=docs)
return sgd
@ -1393,7 +1190,7 @@ cdef class DependencyParser(Parser):
def init_multitask_objectives(self, get_examples, pipeline, sgd=None, **cfg):
for labeller in self._multitasks:
tok2vec = self.model.tok2vec
tok2vec = self.model.get_ref("tok2vec")
labeller.begin_training(get_examples, pipeline=pipeline,
tok2vec=tok2vec, sgd=sgd)
@ -1423,7 +1220,6 @@ cdef class EntityRecognizer(Parser):
assigns = ["doc.ents", "token.ent_iob", "token.ent_type"]
requires = []
TransitionSystem = BiluoPushDown
nr_feature = 6
def add_multitask_objective(self, target):
if target == "cloze":
@ -1435,7 +1231,7 @@ cdef class EntityRecognizer(Parser):
def init_multitask_objectives(self, get_examples, pipeline, sgd=None, **cfg):
for labeller in self._multitasks:
tok2vec = self.model.tok2vec
tok2vec = self.model.get_ref("tok2vec")
labeller.begin_training(get_examples, pipeline=pipeline,
tok2vec=tok2vec)
@ -1464,18 +1260,9 @@ class EntityLinker(Pipe):
"""
NIL = "NIL" # string used to refer to a non-existing link
@classmethod
def Model(cls, **cfg):
embed_width = cfg.get("embed_width", 300)
hidden_width = cfg.get("hidden_width", 128)
type_to_int = cfg.get("type_to_int", dict())
model = build_nel_encoder(embed_width=embed_width, hidden_width=hidden_width, ner_types=len(type_to_int), **cfg)
return model
def __init__(self, vocab, **cfg):
def __init__(self, vocab, model, **cfg):
self.vocab = vocab
self.model = True
self.model = model
self.kb = None
self.cfg = dict(cfg)
self.distance = CosineDistance(normalize=False)
@ -1483,11 +1270,6 @@ class EntityLinker(Pipe):
def set_kb(self, kb):
self.kb = kb
def require_model(self):
# Raise an error if the component's model is not initialized.
if getattr(self, "model", None) in (None, True, False):
raise ValueError(Errors.E109.format(name=self.name))
def require_kb(self):
# Raise an error if the knowledge base is not initialized.
if getattr(self, "kb", None) in (None, True, False):
@ -1495,16 +1277,14 @@ class EntityLinker(Pipe):
def begin_training(self, get_examples=lambda: [], pipeline=None, sgd=None, **kwargs):
self.require_kb()
self.cfg["entity_width"] = self.kb.entity_vector_length
if self.model is True:
self.model = self.Model(**self.cfg)
nO = self.kb.entity_vector_length
self.set_output(nO)
self.model.initialize()
if sgd is None:
sgd = self.create_optimizer()
return sgd
def update(self, examples, state=None, set_annotations=False, drop=0.0, sgd=None, losses=None):
self.require_model()
self.require_kb()
if losses is not None:
losses.setdefault(self.name, 0.0)
@ -1614,7 +1394,6 @@ class EntityLinker(Pipe):
def predict(self, docs):
""" Return the KB IDs for each entity in each doc, including NIL if there is no prediction """
self.require_model()
self.require_kb()
entity_count = 0
@ -1714,15 +1493,12 @@ class EntityLinker(Pipe):
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
serialize["vocab"] = lambda p: self.vocab.to_disk(p)
serialize["kb"] = lambda p: self.kb.dump(p)
if self.model not in (None, True, False):
serialize["model"] = lambda p: p.open("wb").write(self.model.to_bytes())
serialize["model"] = lambda p: p.open("wb").write(self.model.to_bytes())
exclude = util.get_serialization_exclude(serialize, exclude, kwargs)
util.to_disk(path, serialize, exclude)
def from_disk(self, path, exclude=tuple(), **kwargs):
def load_model(p):
if self.model is True:
self.model = self.Model(**self.cfg)
try:
self.model.from_bytes(p.open("rb").read())
except AttributeError:
@ -1734,8 +1510,8 @@ class EntityLinker(Pipe):
self.set_kb(kb)
deserialize = {}
deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p))
deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
deserialize["cfg"] = lambda p: self.cfg.update(_load_cfg(p))
deserialize["kb"] = load_kb
deserialize["model"] = load_model
exclude = util.get_serialization_exclude(deserialize, exclude, kwargs)
@ -1782,7 +1558,7 @@ class Sentencizer(Pipe):
self.punct_chars = set(self.default_punct_chars)
@classmethod
def from_nlp(cls, nlp, **cfg):
def from_nlp(cls, nlp, model=None, **cfg):
return cls(**cfg)
def __call__(self, example):
@ -1915,8 +1691,8 @@ class Sentencizer(Pipe):
# Cython classes can't be decorated, so we need to add the factories here
Language.factories["parser"] = lambda nlp, **cfg: DependencyParser.from_nlp(nlp, **cfg)
Language.factories["ner"] = lambda nlp, **cfg: EntityRecognizer.from_nlp(nlp, **cfg)
Language.factories["parser"] = lambda nlp, model, **cfg: DependencyParser.from_nlp(nlp, model, **cfg)
Language.factories["ner"] = lambda nlp, model, **cfg: EntityRecognizer.from_nlp(nlp, model, **cfg)
__all__ = ["Tagger", "DependencyParser", "EntityRecognizer", "Tensorizer", "TextCategorizer", "EntityLinker", "Sentencizer", "SentenceRecognizer"]

View File

@ -5,32 +5,21 @@ from ..gold import Example
from ..tokens import Doc
from ..vocab import Vocab
from ..language import component
from ..util import link_vectors_to_models, minibatch, registry, eg2doc
from ..util import link_vectors_to_models, minibatch, eg2doc
@component("tok2vec", assigns=["doc.tensor"])
class Tok2Vec(Pipe):
@classmethod
def from_nlp(cls, nlp, **cfg):
return cls(nlp.vocab, **cfg)
@classmethod
def Model(cls, architecture, **cfg):
"""Create a new statistical model for the class.
def from_nlp(cls, nlp, model, **cfg):
return cls(nlp.vocab, model, **cfg)
architecture (str): The registered model architecture to use.
**cfg: Config parameters.
RETURNS (Model): A `thinc.model.Model` or similar instance.
"""
model = registry.architectures.get(architecture)
return model(**cfg)
def __init__(self, vocab, model=True, **cfg):
def __init__(self, vocab, model, **cfg):
"""Construct a new statistical model. Weights are not allocated on
initialisation.
vocab (Vocab): A `Vocab` instance. The model must share the same `Vocab`
instance with the `Doc` objects it will process.
model (Model): A `Model` instance or `True` to allocate one later.
**cfg: Config parameters.
"""
self.vocab = vocab
@ -143,8 +132,6 @@ class Tok2Vec(Pipe):
get_examples (function): Function returning example training data.
pipeline (list): The pipeline the model is part of.
"""
if self.model is True:
self.model = self.Model(**self.cfg)
# TODO: use examples instead ?
docs = [Doc(Vocab(), words=["hello"])]
self.model.initialize(X=docs)

View File

@ -221,7 +221,10 @@ cdef int arg_max_if_valid(const weight_t* scores, const int* is_valid, int n) no
class ParserModel(Model):
def __init__(self, tok2vec, lower_model, upper_model, unseen_classes=None):
Model.__init__(self, name="parser_model", forward=forward)
# don't define nO for this object, because we can't dynamically change it
Model.__init__(self, name="parser_model", forward=forward, dims={"nI": None})
if tok2vec.has_dim("nI"):
self.set_dim("nI", tok2vec.get_dim("nI"))
self._layers = [tok2vec, lower_model]
if upper_model is not None:
self._layers.append(upper_model)
@ -229,6 +232,7 @@ class ParserModel(Model):
if unseen_classes:
for class_ in unseen_classes:
self.unseen_classes.add(class_)
self.set_ref("tok2vec", tok2vec)
def predict(self, docs):
step_model = ParserStepModel(docs, self._layers,
@ -238,25 +242,32 @@ class ParserModel(Model):
def resize_output(self, new_nO):
if len(self._layers) == 2:
return
if new_nO == self.upper.get_dim("nO"):
if self.upper.has_dim("nO") and (new_nO == self.upper.get_dim("nO")):
return
smaller = self.upper
nI = smaller.get_dim("nI")
nI = None
if smaller.has_dim("nI"):
nI = smaller.get_dim("nI")
with use_ops('numpy'):
larger = Linear(new_nO, nI)
larger_W = larger.ops.alloc2f(new_nO, nI)
larger_b = larger.ops.alloc1f(new_nO)
smaller_W = smaller.get_param("W")
smaller_b = smaller.get_param("b")
# Weights are stored in (nr_out, nr_in) format, so we're basically
# just adding rows here.
larger_W[:smaller.get_dim("nO")] = smaller_W
larger_b[:smaller.get_dim("nO")] = smaller_b
larger.set_param("W", larger_W)
larger.set_param("b", larger_b)
larger = Linear(nO=new_nO, nI=nI)
larger._init = smaller._init
# it could be that the model is not initialized yet, then skip this bit
if nI:
larger_W = larger.ops.alloc2f(new_nO, nI)
larger_b = larger.ops.alloc1f(new_nO)
smaller_W = smaller.get_param("W")
smaller_b = smaller.get_param("b")
# Weights are stored in (nr_out, nr_in) format, so we're basically
# just adding rows here.
if smaller.has_dim("nO"):
larger_W[:smaller.get_dim("nO")] = smaller_W
larger_b[:smaller.get_dim("nO")] = smaller_b
for i in range(smaller.get_dim("nO"), new_nO):
self.unseen_classes.add(i)
larger.set_param("W", larger_W)
larger.set_param("b", larger_b)
self._layers[-1] = larger
for i in range(smaller.get_dim("nO"), new_nO):
self.unseen_classes.add(i)
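# Sketch of the row-copy above (not part of the diff), assuming a smaller
# layer with nO=3, nI=2 that is grown to new_nO=5:
#
#     import numpy
#     smaller_W = numpy.ones((3, 2), dtype="f")   # weights stored as (nr_out, nr_in)
#     larger_W = numpy.zeros((5, 2), dtype="f")
#     larger_b = numpy.zeros((5,), dtype="f")
#     larger_W[:3] = smaller_W                    # existing classes keep their weights
#     # rows 3 and 4 stay zero-initialised; those class indices are added to
#     # unseen_classes so the parser knows they have never been trained.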
def initialize(self, X=None, Y=None):
self.tok2vec.initialize()
@ -412,7 +423,7 @@ cdef class precompute_hiddens:
we can do all our hard maths up front, packed into large multiplications,
and do the hard-to-program parsing on the CPU.
"""
cdef readonly int nF, nO, nP # TODO: make these more like the dimensions in thinc
cdef readonly int nF, nO, nP
cdef bint _is_synchronized
cdef public object ops
cdef np.ndarray _features
@ -458,6 +469,16 @@ cdef class precompute_hiddens:
self._is_synchronized = True
return <float*>self._cached.data
def has_dim(self, name):
if name == "nF":
return self.nF if self.nF is not None else True
elif name == "nP":
return self.nP if self.nP is not None else True
elif name == "nO":
return self.nO if self.nO is not None else True
else:
return False
def get_dim(self, name):
if name == "nF":
return self.nF
@ -468,6 +489,16 @@ cdef class precompute_hiddens:
else:
raise ValueError(f"Dimension {name} invalid -- only nO, nF, nP")
def set_dim(self, name, value):
if name == "nF":
self.nF = value
elif name == "nP":
self.nP = value
elif name == "nO":
self.nO = value
else:
raise ValueError(f"Dimension {name} invalid -- only nO, nF, nP")
def __call__(self, X, bint is_train):
if is_train:
return self.begin_update(X)

View File

@ -27,11 +27,11 @@ from ._parser_model cimport predict_states, arg_max_if_valid
from ._parser_model cimport WeightsC, ActivationsC, SizesC, cpu_log_loss
from ._parser_model cimport get_c_weights, get_c_sizes
from ._parser_model import ParserModel
from ..util import link_vectors_to_models, create_default_optimizer
from ..util import link_vectors_to_models, create_default_optimizer, registry
from ..compat import copy_array
from ..tokens.doc cimport Doc
from ..gold cimport GoldParse
from ..errors import Errors, TempErrors
from ..errors import Errors, user_warning, Warnings
from .. import util
from .stateclass cimport StateClass
from ._state cimport StateC
@ -41,114 +41,42 @@ from . import _beam_utils
from . import nonproj
from ..ml._layers import PrecomputableAffine
from ..ml.component_models import Tok2Vec
cdef class Parser:
"""
Base class of the DependencyParser and EntityRecognizer.
"""
@classmethod
def Model(cls, nr_class, **cfg):
depth = util.env_opt('parser_hidden_depth', cfg.get('hidden_depth', 1))
subword_features = util.env_opt('subword_features',
cfg.get('subword_features', True))
conv_depth = util.env_opt('conv_depth', cfg.get('conv_depth', 4))
conv_window = util.env_opt('conv_window', cfg.get('conv_window', 1))
t2v_pieces = util.env_opt('cnn_maxout_pieces', cfg.get('cnn_maxout_pieces', 3))
bilstm_depth = util.env_opt('bilstm_depth', cfg.get('bilstm_depth', 0))
self_attn_depth = util.env_opt('self_attn_depth', cfg.get('self_attn_depth', 0))
nr_feature_tokens = cfg.get("nr_feature_tokens", cls.nr_feature)
if depth not in (0, 1):
raise ValueError(TempErrors.T004.format(value=depth))
parser_maxout_pieces = util.env_opt('parser_maxout_pieces',
cfg.get('maxout_pieces', 2))
token_vector_width = util.env_opt('token_vector_width',
cfg.get('token_vector_width', 96))
hidden_width = util.env_opt('hidden_width', cfg.get('hidden_width', 64))
if depth == 0:
hidden_width = nr_class
parser_maxout_pieces = 1
embed_size = util.env_opt('embed_size', cfg.get('embed_size', 2000))
pretrained_vectors = cfg.get('pretrained_vectors', None)
tok2vec = Tok2Vec(width=token_vector_width,
embed_size=embed_size,
conv_depth=conv_depth,
window_size=conv_window,
cnn_maxout_pieces=t2v_pieces,
subword_features=subword_features,
pretrained_vectors=pretrained_vectors,
bilstm_depth=bilstm_depth)
tok2vec = chain(tok2vec, list2array())
tok2vec.set_dim("nO", token_vector_width)
lower = PrecomputableAffine(hidden_width,
nF=nr_feature_tokens, nI=token_vector_width,
nP=parser_maxout_pieces)
lower.set_dim("nP", parser_maxout_pieces)
if depth == 1:
with use_ops('numpy'):
upper = Linear(nr_class, hidden_width, init_W=zero_init)
else:
upper = None
cfg = {
'nr_class': nr_class,
'nr_feature_tokens': nr_feature_tokens,
'hidden_depth': depth,
'token_vector_width': token_vector_width,
'hidden_width': hidden_width,
'maxout_pieces': parser_maxout_pieces,
'pretrained_vectors': pretrained_vectors,
'bilstm_depth': bilstm_depth,
'self_attn_depth': self_attn_depth,
'conv_depth': conv_depth,
'window_size': conv_window,
'embed_size': embed_size,
'cnn_maxout_pieces': t2v_pieces
}
model = ParserModel(tok2vec, lower, upper)
model.initialize()
return model, cfg
name = 'base_parser'
def __init__(self, Vocab vocab, moves=True, model=True, **cfg):
def __init__(self, Vocab vocab, model, **cfg):
"""Create a Parser.
vocab (Vocab): The vocabulary object. Must be shared with documents
to be processed. The value is set to the `.vocab` attribute.
moves (TransitionSystem): Defines how the parse-state is created,
updated and evaluated. The value is set to the .moves attribute
unless True (default), in which case a new instance is created with
`Parser.Moves()`.
model (object): Defines how the parse-state is created, updated and
evaluated. The value is set to the .model attribute. If set to True
(default), a new instance will be created with `Parser.Model()`
in parser.begin_training(), parser.from_disk() or parser.from_bytes().
**cfg: Arbitrary configuration parameters. Set to the `.cfg` attribute
**cfg: Configuration parameters. Set to the `.cfg` attribute.
If it doesn't include a value for 'moves', a new instance is
created with `self.TransitionSystem()`. This defines how the
parse-state is created, updated and evaluated.
"""
self.vocab = vocab
if moves is True:
self.moves = self.TransitionSystem(self.vocab.strings)
else:
self.moves = moves
if 'beam_width' not in cfg:
cfg['beam_width'] = util.env_opt('beam_width', 1)
if 'beam_density' not in cfg:
cfg['beam_density'] = util.env_opt('beam_density', 0.0)
if 'beam_update_prob' not in cfg:
cfg['beam_update_prob'] = util.env_opt('beam_update_prob', 1.0)
cfg.setdefault('cnn_maxout_pieces', 3)
cfg.setdefault("nr_feature_tokens", self.nr_feature)
self.cfg = cfg
moves = cfg.get("moves", None)
if moves is None:
# defined by EntityRecognizer as a BiluoPushDown
moves = self.TransitionSystem(self.vocab.strings)
self.moves = moves
cfg.setdefault('min_action_freq', 30)
cfg.setdefault('learn_tokens', False)
cfg.setdefault('beam_width', 1)
cfg.setdefault('beam_update_prob', 1.0) # or 0.5 (both defaults were previously used)
self.model = model
self.set_output(self.moves.n_moves)
self.cfg = cfg
self._multitasks = []
self._rehearsal_model = None
@classmethod
def from_nlp(cls, nlp, **cfg):
return cls(nlp.vocab, **cfg)
def from_nlp(cls, nlp, model, **cfg):
return cls(nlp.vocab, model, **cfg)
def __reduce__(self):
return (Parser, (self.vocab, self.moves, self.model), None, None)
@ -163,8 +91,6 @@ cdef class Parser:
names.append(name)
return names
nr_feature = 8
@property
def labels(self):
class_names = [self.moves.get_class_name(i) for i in range(self.moves.n_moves)]
@ -173,7 +99,7 @@ cdef class Parser:
@property
def tok2vec(self):
'''Return the embedding and convolutional layer of the model.'''
return None if self.model in (None, True, False) else self.model.tok2vec
return self.model.tok2vec
@property
def postprocesses(self):
@ -190,10 +116,7 @@ cdef class Parser:
self._resize()
def _resize(self):
if "nr_class" in self.cfg:
self.cfg["nr_class"] = self.moves.n_moves
if self.model not in (True, False, None):
self.model.resize_output(self.moves.n_moves)
self.model.resize_output(self.moves.n_moves)
if self._rehearsal_model not in (True, False, None):
self._rehearsal_model.resize_output(self.moves.n_moves)
@ -227,7 +150,7 @@ cdef class Parser:
doc (Doc): The document to be processed.
"""
if beam_width is None:
beam_width = self.cfg.get('beam_width', 1)
beam_width = self.cfg['beam_width']
beam_density = self.cfg.get('beam_density', 0.)
states = self.predict([doc], beam_width=beam_width,
beam_density=beam_density)
@ -243,7 +166,7 @@ cdef class Parser:
YIELDS (Doc): Documents, in order.
"""
if beam_width is None:
beam_width = self.cfg.get('beam_width', 1)
beam_width = self.cfg['beam_width']
beam_density = self.cfg.get('beam_density', 0.)
cdef Doc doc
for batch in util.minibatch(docs, size=batch_size):
@ -264,13 +187,7 @@ cdef class Parser:
else:
yield from batch_in_order
def require_model(self):
"""Raise an error if the component's model is not initialized."""
if getattr(self, 'model', None) in (None, True, False):
raise ValueError(Errors.E109.format(name=self.name))
def predict(self, docs, beam_width=1, beam_density=0.0, drop=0.):
self.require_model()
if isinstance(docs, Doc):
docs = [docs]
if not any(len(doc) for doc in docs):
@ -313,11 +230,11 @@ cdef class Parser:
# if labels are missing. We therefore have to check whether we need to
# expand our model output.
self._resize()
cdef int nr_feature = self.model.lower.get_dim("nF")
model = self.model.predict(docs)
token_ids = numpy.zeros((len(docs) * beam_width, self.nr_feature),
token_ids = numpy.zeros((len(docs) * beam_width, nr_feature),
dtype='i', order='C')
cdef int* c_ids
cdef int nr_feature = self.cfg["nr_feature_tokens"]
cdef int n_states
model = self.model.predict(docs)
todo = [beam for beam in beams if not beam.is_done]
@ -430,7 +347,6 @@ cdef class Parser:
return [b for b in beams if not b.is_done]
def update(self, examples, drop=0., set_annotations=False, sgd=None, losses=None):
self.require_model()
examples = Example.to_example_objects(examples)
if losses is None:
@ -440,9 +356,9 @@ cdef class Parser:
multitask.update(examples, drop=drop, sgd=sgd)
# The probability we use beam update, instead of falling back to
# a greedy update
beam_update_prob = self.cfg.get('beam_update_prob', 0.5)
if self.cfg.get('beam_width', 1) >= 2 and numpy.random.random() < beam_update_prob:
return self.update_beam(examples, self.cfg.get('beam_width', 1),
beam_update_prob = self.cfg['beam_update_prob']
if self.cfg['beam_width'] >= 2 and numpy.random.random() < beam_update_prob:
return self.update_beam(examples, self.cfg['beam_width'],
drop=drop, sgd=sgd, losses=losses, set_annotations=set_annotations,
beam_density=self.cfg.get('beam_density', 0.001))
@ -533,7 +449,7 @@ cdef class Parser:
set_dropout_rate(self.model, drop)
model, backprop_tok2vec = self.model.begin_update(docs)
states_d_scores, backprops, beams = _beam_utils.update_beam(
self.moves, self.cfg["nr_feature_tokens"], 10000, states, golds,
self.moves, self.model.lower.get_dim("nF"), 10000, states, golds,
model.state2vec, model.vec2scores, width, losses=losses,
beam_density=beam_density)
for i, d_scores in enumerate(states_d_scores):
@ -562,8 +478,6 @@ cdef class Parser:
keyed by the parameter ID. The values are (weights, gradients) tuples.
"""
gradients = {}
if self.model in (None, True, False):
return gradients
queue = [self.model]
seen = set()
for node in queue:
@ -647,45 +561,40 @@ cdef class Parser:
def create_optimizer(self):
return create_default_optimizer()
def begin_training(self, get_examples, pipeline=None, sgd=None, **cfg):
if 'model' in cfg:
self.model = cfg['model']
def set_output(self, nO):
if self.model.upper.has_dim("nO") is None:
self.model.upper.set_dim("nO", nO)
def begin_training(self, get_examples, pipeline=None, sgd=None, **kwargs):
self.cfg.update(kwargs)
if not hasattr(get_examples, '__call__'):
gold_tuples = get_examples
get_examples = lambda: gold_tuples
cfg.setdefault('min_action_freq', 30)
actions = self.moves.get_actions(gold_parses=get_examples(),
min_freq=cfg.get('min_action_freq', 30),
learn_tokens=self.cfg.get("learn_tokens", False))
min_freq=self.cfg['min_action_freq'],
learn_tokens=self.cfg["learn_tokens"])
for action, labels in self.moves.labels.items():
actions.setdefault(action, {})
for label, freq in labels.items():
if label not in actions[action]:
actions[action][label] = freq
self.moves.initialize_actions(actions)
cfg.setdefault('token_vector_width', 96)
if self.model is True:
self.model, cfg = self.Model(self.moves.n_moves, **cfg)
if sgd is None:
sgd = self.create_optimizer()
doc_sample = []
gold_sample = []
for example in islice(get_examples(), 1000):
parses = example.get_gold_parses(merge=False, vocab=self.vocab)
for doc, gold in parses:
doc_sample.append(doc)
gold_sample.append(gold)
self.model.initialize(doc_sample, gold_sample)
if pipeline is not None:
self.init_multitask_objectives(get_examples, pipeline, sgd=sgd, **cfg)
link_vectors_to_models(self.vocab)
else:
if sgd is None:
sgd = self.create_optimizer()
if self.model.upper.has_dim("nO") is None:
self.model.upper.set_dim("nO", self.moves.n_moves)
self.model.initialize()
self.cfg.update(cfg)
# make sure we resize so we have an appropriate upper layer
self._resize()
if sgd is None:
sgd = self.create_optimizer()
doc_sample = []
gold_sample = []
for example in islice(get_examples(), 1000):
parses = example.get_gold_parses(merge=False, vocab=self.vocab)
for doc, gold in parses:
doc_sample.append(doc)
gold_sample.append(gold)
self.model.initialize(doc_sample, gold_sample)
if pipeline is not None:
self.init_multitask_objectives(get_examples, pipeline, sgd=sgd, **self.cfg)
link_vectors_to_models(self.vocab)
return sgd
def _get_doc(self, example):
@ -709,28 +618,24 @@ cdef class Parser:
'vocab': lambda p: self.vocab.from_disk(p),
'moves': lambda p: self.moves.from_disk(p, exclude=["strings"]),
'cfg': lambda p: self.cfg.update(srsly.read_json(p)),
'model': lambda p: None
'model': lambda p: None,
}
exclude = util.get_serialization_exclude(deserializers, exclude, kwargs)
util.from_disk(path, deserializers, exclude)
if 'model' not in exclude:
path = util.ensure_path(path)
if self.model is True:
self.model, cfg = self.Model(**self.cfg)
else:
cfg = {}
with (path / 'model').open('rb') as file_:
bytes_data = file_.read()
try:
self._resize()
self.model.from_bytes(bytes_data)
except AttributeError:
raise ValueError(Errors.E149)
self.cfg.update(cfg)
return self
def to_bytes(self, exclude=tuple(), **kwargs):
serializers = {
"model": lambda: (self.model.to_bytes() if self.model is not True else True),
"model": lambda: (self.model.to_bytes()),
"vocab": lambda: self.vocab.to_bytes(),
"moves": lambda: self.moves.to_bytes(exclude=["strings"]),
"cfg": lambda: srsly.json_dumps(self.cfg, indent=2, sort_keys=True)
@ -743,22 +648,14 @@ cdef class Parser:
"vocab": lambda b: self.vocab.from_bytes(b),
"moves": lambda b: self.moves.from_bytes(b, exclude=["strings"]),
"cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
"model": lambda b: None
"model": lambda b: None,
}
exclude = util.get_serialization_exclude(deserializers, exclude, kwargs)
msg = util.from_bytes(bytes_data, deserializers, exclude)
if 'model' not in exclude:
# TODO: Remove this once we don't have to handle previous models
if self.cfg.get('pretrained_dims') and 'pretrained_vectors' not in self.cfg:
self.cfg['pretrained_vectors'] = self.vocab.vectors
if self.model is True:
self.model, cfg = self.Model(**self.cfg)
else:
cfg = {}
if 'model' in msg:
try:
self.model.from_bytes(msg['model'])
except AttributeError:
raise ValueError(Errors.E149)
self.cfg.update(cfg)
return self

View File

@ -3,12 +3,13 @@ from spacy.tokens import Span
import pytest
from ..util import get_doc
from ...ml.models.defaults import default_ner
def test_doc_add_entities_set_ents_iob(en_vocab):
text = ["This", "is", "a", "lion"]
doc = get_doc(en_vocab, text)
ner = EntityRecognizer(en_vocab)
ner = EntityRecognizer(en_vocab, default_ner())
ner.begin_training([])
ner(doc)
assert len(list(doc.ents)) == 0
@ -24,7 +25,7 @@ def test_doc_add_entities_set_ents_iob(en_vocab):
def test_ents_reset(en_vocab):
text = ["This", "is", "a", "lion"]
doc = get_doc(en_vocab, text)
ner = EntityRecognizer(en_vocab)
ner = EntityRecognizer(en_vocab, default_ner())
ner.begin_training([])
ner(doc)
assert [t.ent_iob_ for t in doc] == (["O"] * len(doc))

View File

@ -3,6 +3,8 @@ from thinc.api import Adam, NumpyOps
from spacy.attrs import NORM
from spacy.gold import GoldParse
from spacy.vocab import Vocab
from spacy.ml.models.defaults import default_parser, default_ner
from spacy.tokens import Doc
from spacy.pipeline import DependencyParser, EntityRecognizer
from spacy.util import fix_random_seed
@ -15,7 +17,7 @@ def vocab():
@pytest.fixture
def parser(vocab):
parser = DependencyParser(vocab)
parser = DependencyParser(vocab, default_parser())
return parser
@ -55,27 +57,31 @@ def test_add_label(parser):
def test_add_label_deserializes_correctly():
ner1 = EntityRecognizer(Vocab())
ner1 = EntityRecognizer(Vocab(), default_ner())
ner1.add_label("C")
ner1.add_label("B")
ner1.add_label("A")
ner1.begin_training([])
ner2 = EntityRecognizer(Vocab()).from_bytes(ner1.to_bytes())
ner2 = EntityRecognizer(Vocab(), default_ner())
# the second model needs to be resized before we can call from_bytes
ner2.model.resize_output(ner1.moves.n_moves)
ner2.from_bytes(ner1.to_bytes())
assert ner1.moves.n_moves == ner2.moves.n_moves
for i in range(ner1.moves.n_moves):
assert ner1.moves.get_class_name(i) == ner2.moves.get_class_name(i)
@pytest.mark.parametrize(
"pipe_cls,n_moves", [(DependencyParser, 5), (EntityRecognizer, 4)]
"pipe_cls,n_moves,model", [(DependencyParser, 5, default_parser()), (EntityRecognizer, 4, default_ner())]
)
def test_add_label_get_label(pipe_cls, n_moves):
def test_add_label_get_label(pipe_cls, n_moves, model):
"""Test that added labels are returned correctly. This test was added to
test for a bug in DependencyParser.labels that'd cause it to fail when
splitting the move names.
"""
labels = ["A", "B", "C"]
pipe = pipe_cls(Vocab())
pipe = pipe_cls(Vocab(), model)
for label in labels:
pipe.add_label(label)
assert len(pipe.move_names) == len(labels) * n_moves


@ -1,5 +1,7 @@
import pytest
from spacy.vocab import Vocab
from spacy.ml.models.defaults import default_parser
from spacy.pipeline import DependencyParser
from spacy.tokens import Doc
from spacy.gold import GoldParse
@ -136,7 +138,7 @@ def test_get_oracle_actions():
deps.append(dep)
ents.append(ent)
doc = Doc(Vocab(), words=[t[1] for t in annot_tuples])
parser = DependencyParser(doc.vocab)
parser = DependencyParser(doc.vocab, default_parser())
parser.moves.add_action(0, "")
parser.moves.add_action(1, "")
parser.moves.add_action(1, "")


@ -1,10 +1,15 @@
import pytest
from spacy import util
from spacy.lang.en import English
from spacy.ml.models.defaults import default_ner
from spacy.pipeline import EntityRecognizer, EntityRuler
from spacy.vocab import Vocab
from spacy.syntax.ner import BiluoPushDown
from spacy.gold import GoldParse
from spacy.tests.util import make_tempdir
from spacy.tokens import Doc
TRAIN_DATA = [
@ -134,7 +139,7 @@ def test_accept_blocked_token():
# 1. test normal behaviour
nlp1 = English()
doc1 = nlp1("I live in New York")
ner1 = EntityRecognizer(doc1.vocab)
ner1 = EntityRecognizer(doc1.vocab, default_ner())
assert [token.ent_iob_ for token in doc1] == ["", "", "", "", ""]
assert [token.ent_type_ for token in doc1] == ["", "", "", "", ""]
@ -152,7 +157,7 @@ def test_accept_blocked_token():
# 2. test blocking behaviour
nlp2 = English()
doc2 = nlp2("I live in New York")
ner2 = EntityRecognizer(doc2.vocab)
ner2 = EntityRecognizer(doc2.vocab, default_ner())
# set "New York" to a blocked entity
doc2.ents = [(0, 3, 5)]
@ -188,7 +193,7 @@ def test_overwrite_token():
assert [token.ent_type_ for token in doc] == ["", "", "", "", ""]
# Check that a new ner can overwrite O
ner2 = EntityRecognizer(doc.vocab)
ner2 = EntityRecognizer(doc.vocab, default_ner())
ner2.moves.add_action(5, "")
ner2.add_label("GPE")
state = ner2.moves.init_batch([doc])[0]
@ -199,6 +204,17 @@ def test_overwrite_token():
assert ner2.moves.is_valid(state, "L-GPE")
def test_empty_ner():
nlp = English()
ner = nlp.create_pipe("ner")
ner.add_label("MY_LABEL")
nlp.add_pipe(ner)
nlp.begin_training()
doc = nlp("John is watching the news about Croatia's elections")
# if this goes wrong, the initialization of the parser's upper layer is probably broken
assert [token.ent_iob_ for token in doc] == ['O', 'O', 'O', 'O', 'O', 'O', 'O', 'O', 'O']
def test_ruler_before_ner():
""" Test that an NER works after an entity_ruler: the second can add annotations """
nlp = English()
@ -214,7 +230,6 @@ def test_ruler_before_ner():
untrained_ner.add_label("MY_LABEL")
nlp.add_pipe(untrained_ner)
nlp.begin_training()
doc = nlp("This is Antti Korhonen speaking in Finland")
expected_iobs = ["B", "O", "O", "O", "O", "O", "O"]
expected_types = ["THING", "", "", "", "", "", ""]
@ -261,28 +276,7 @@ def test_block_ner():
assert [token.ent_type_ for token in doc] == expected_types
def test_change_number_features():
# Test the default number features
nlp = English()
ner = nlp.create_pipe("ner")
nlp.add_pipe(ner)
ner.add_label("PERSON")
nlp.begin_training()
assert ner.model.lower.get_dim("nF") == ner.nr_feature
# Test we can change it
nlp = English()
ner = nlp.create_pipe("ner")
nlp.add_pipe(ner)
ner.add_label("PERSON")
nlp.begin_training(
component_cfg={"ner": {"nr_feature_tokens": 3, "token_vector_width": 128}}
)
assert ner.model.lower.get_dim("nF") == 3
# Test the model runs
nlp("hello world")
def test_overfitting():
def test_overfitting_IO():
# Simple test to try and quickly overfit the NER component - ensuring the ML models work correctly
nlp = English()
ner = nlp.create_pipe("ner")
@ -301,11 +295,20 @@ def test_overfitting():
test_text = "I like London."
doc = nlp(test_text)
ents = doc.ents
assert len(ents) == 1
assert ents[0].text == "London"
assert ents[0].label_ == "LOC"
# Also test the results are still the same after IO
with make_tempdir() as tmp_dir:
nlp.to_disk(tmp_dir)
nlp2 = util.load_model_from_path(tmp_dir)
doc2 = nlp2(test_text)
ents2 = doc2.ents
assert len(ents2) == 1
assert ents2[0].text == "London"
assert ents2[0].label_ == "LOC"
class BlockerComponent1(object):
name = "my_blocker"


@ -1,8 +1,9 @@
import pytest
from spacy.ml.component_models import Tok2Vec
from spacy.ml.models.defaults import default_parser, default_tok2vec
from spacy.vocab import Vocab
from spacy.syntax.arc_eager import ArcEager
from spacy.syntax.nn_parser import Parser
from spacy.syntax._parser_model import ParserModel
from spacy.tokens.doc import Doc
from spacy.gold import GoldParse
@ -20,19 +21,22 @@ def arc_eager(vocab):
@pytest.fixture
def tok2vec():
tok2vec = Tok2Vec(8, 100)
tok2vec = default_tok2vec()
tok2vec.initialize()
return tok2vec
@pytest.fixture
def parser(vocab, arc_eager):
return Parser(vocab, moves=arc_eager, model=None)
return Parser(vocab, model=default_parser(), moves=arc_eager)
@pytest.fixture
def model(arc_eager, tok2vec):
return Parser.Model(arc_eager.n_moves, token_vector_width=tok2vec.get_dim("nO"))[0]
def model(arc_eager, tok2vec, vocab):
model = default_parser()
model.resize_output(arc_eager.n_moves)
model.initialize()
return model
@pytest.fixture
@ -46,11 +50,11 @@ def gold(doc):
def test_can_init_nn_parser(parser):
assert parser.model is None
assert isinstance(parser.model, ParserModel)
def test_build_model(parser):
parser.model = Parser.Model(parser.moves.n_moves, hist_size=0)[0]
def test_build_model(parser, vocab):
parser.model = Parser(vocab, model=default_parser(), moves=parser.moves).model
assert parser.model is not None


@ -2,6 +2,7 @@ import pytest
import numpy
from spacy.vocab import Vocab
from spacy.language import Language
from spacy.ml.models.defaults import default_parser
from spacy.pipeline import DependencyParser
from spacy.syntax.arc_eager import ArcEager
from spacy.tokens import Doc
@ -93,7 +94,7 @@ def test_beam_advance_too_few_scores(beam, scores):
def test_beam_parse():
nlp = Language()
nlp.add_pipe(DependencyParser(nlp.vocab), name="parser")
nlp.add_pipe(DependencyParser(nlp.vocab, default_parser()), name="parser")
nlp.parser.add_label("nsubj")
nlp.parser.begin_training([], token_vector_width=8, hidden_width=8)
doc = nlp.make_doc("Australia is a country")


@ -1,7 +1,8 @@
import pytest
from spacy.lang.en import English
from ..util import get_doc, apply_transition_sequence
from ..util import get_doc, apply_transition_sequence, make_tempdir
from ... import util
TRAIN_DATA = [
(
@ -182,7 +183,7 @@ def test_parser_set_sent_starts(en_vocab):
assert token.head in sent
def test_overfitting():
def test_overfitting_IO():
# Simple test to try and quickly overfit the dependency parser - ensuring the ML models work correctly
nlp = English()
parser = nlp.create_pipe("parser")
@ -200,7 +201,15 @@ def test_overfitting():
# test the trained model
test_text = "I like securities."
doc = nlp(test_text)
assert doc[0].dep_ is "nsubj"
assert doc[2].dep_ is "dobj"
assert doc[3].dep_ is "punct"
# Also test the results are still the same after IO
with make_tempdir() as tmp_dir:
nlp.to_disk(tmp_dir)
nlp2 = util.load_model_from_path(tmp_dir)
doc2 = nlp2(test_text)
assert doc2[0].dep_ is "nsubj"
assert doc2[2].dep_ is "dobj"
assert doc2[3].dep_ is "punct"


@ -3,6 +3,8 @@ from thinc.api import Adam
from spacy.attrs import NORM
from spacy.gold import GoldParse
from spacy.vocab import Vocab
from spacy.ml.models.defaults import default_parser
from spacy.tokens import Doc
from spacy.pipeline import DependencyParser
@ -14,7 +16,7 @@ def vocab():
@pytest.fixture
def parser(vocab):
parser = DependencyParser(vocab)
parser = DependencyParser(vocab, default_parser())
parser.cfg["token_vector_width"] = 4
parser.cfg["hidden_width"] = 32
# parser.add_label('right')


@ -111,7 +111,8 @@ def test_component_factories_from_nlp():
nlp.add_pipe(pipe)
assert nlp("hello world")
# The first argument here is the class itself, so we're accepting any here
mock.assert_called_once_with(ANY, nlp, foo="bar")
# The model will be initialized to None by the factory
mock.assert_called_once_with(ANY, nlp, None, foo="bar")
def test_analysis_validate_attrs_valid():


@ -1,5 +1,9 @@
import pytest
from spacy import util
from spacy.lang.en import English
from spacy.language import Language
from spacy.tests.util import make_tempdir
def test_label_types():
@ -18,9 +22,9 @@ TRAIN_DATA = [
]
def test_overfitting():
def test_overfitting_IO():
# Simple test to try and quickly overfit the tagger - ensuring the ML models work correctly
nlp = Language()
nlp = English()
tagger = nlp.create_pipe("tagger")
for tag, values in TAG_MAP.items():
tagger.add_label(tag, values)
@ -35,8 +39,17 @@ def test_overfitting():
# test the trained model
test_text = "I like blue eggs"
doc = nlp(test_text)
assert doc[0].tag_ is "N"
assert doc[1].tag_ is "V"
assert doc[2].tag_ is "J"
assert doc[3].tag_ is "N"
# Also test the results are still the same after IO
with make_tempdir() as tmp_dir:
nlp.to_disk(tmp_dir)
nlp2 = util.load_model_from_path(tmp_dir)
doc2 = nlp2(test_text)
assert doc2[0].tag_ is "N"
assert doc2[1].tag_ is "V"
assert doc2[2].tag_ is "J"
assert doc2[3].tag_ is "N"


@ -1,8 +1,12 @@
import pytest
import random
import numpy.random
from spacy import util
from spacy.lang.en import English
from spacy.language import Language
from spacy.pipeline import TextCategorizer
from spacy.tests.util import make_tempdir
from spacy.tokens import Doc
from spacy.gold import GoldParse
@ -74,9 +78,9 @@ def test_label_types():
nlp.get_pipe("textcat").add_label(9)
def test_overfitting():
def test_overfitting_IO():
# Simple test to try and quickly overfit the textcat component - ensuring the ML models work correctly
nlp = Language()
nlp = English()
textcat = nlp.create_pipe("textcat")
for _, annotations in TRAIN_DATA:
for label, value in annotations.get("cats").items():
@ -87,11 +91,21 @@ def test_overfitting():
for i in range(50):
losses = {}
nlp.update(TRAIN_DATA, sgd=optimizer, losses=losses)
assert losses["textcat"] < 0.00001
assert losses["textcat"] < 0.01
# test the trained model
test_text = "I am happy."
doc = nlp(test_text)
cats = doc.cats
# note that by default, exclusive_classes = false so we need a bigger error margin
assert cats["POSITIVE"] > 0.9
assert cats["POSITIVE"] + cats["NEGATIVE"] == pytest.approx(1.0, 0.001)
assert cats["POSITIVE"] + cats["NEGATIVE"] == pytest.approx(1.0, 0.1)
# Also test the results are still the same after IO
with make_tempdir() as tmp_dir:
nlp.to_disk(tmp_dir)
nlp2 = util.load_model_from_path(tmp_dir)
doc2 = nlp2(test_text)
cats2 = doc2.cats
assert cats2["POSITIVE"] > 0.9
assert cats2["POSITIVE"] + cats2["NEGATIVE"] == pytest.approx(1.0, 0.1)


@ -10,6 +10,7 @@ from spacy.lang.lex_attrs import is_stop
from spacy.vectors import Vectors
from spacy.vocab import Vocab
from spacy.language import Language
from spacy.ml.models.defaults import default_ner, default_tagger
from spacy.tokens import Doc, Span, Token
from spacy.pipeline import Tagger, EntityRecognizer
from spacy.attrs import HEAD, DEP
@ -123,7 +124,7 @@ def test_issue1727():
correctly after vectors are added."""
data = numpy.ones((3, 300), dtype="f")
vectors = Vectors(data=data, keys=["I", "am", "Matt"])
tagger = Tagger(Vocab())
tagger = Tagger(Vocab(), default_tagger())
tagger.add_label("PRP")
with pytest.warns(UserWarning):
tagger.begin_training()
@ -131,7 +132,7 @@ def test_issue1727():
tagger.vocab.vectors = vectors
with make_tempdir() as path:
tagger.to_disk(path)
tagger = Tagger(Vocab()).from_disk(path)
tagger = Tagger(Vocab(), default_tagger()).from_disk(path)
assert tagger.cfg.get("pretrained_dims", 0) == 0
@ -236,6 +237,7 @@ def test_issue1889(word):
assert is_stop(word, STOP_WORDS) == is_stop(word.upper(), STOP_WORDS)
@pytest.mark.skip(reason="This test has become obsolete with the config refactor of v.3")
def test_issue1915():
cfg = {"hidden_depth": 2} # should error out
nlp = Language()
@ -268,7 +270,7 @@ def test_issue1963(en_tokenizer):
@pytest.mark.parametrize("label", ["U-JOB-NAME"])
def test_issue1967(label):
ner = EntityRecognizer(Vocab())
ner = EntityRecognizer(Vocab(), default_ner())
example = Example(doc=None)
example.set_token_annotation(
ids=[0], words=["word"], tags=["tag"], heads=[0], deps=["dep"], entities=[label]


@ -32,6 +32,9 @@ def test_issue2179():
nlp.begin_training()
nlp2 = Italian()
nlp2.add_pipe(nlp2.create_pipe("ner"))
assert len(nlp2.get_pipe("ner").labels) == 0
nlp2.get_pipe("ner").model.resize_output(nlp.get_pipe("ner").moves.n_moves)
nlp2.from_bytes(nlp.to_bytes())
assert "extra_labels" not in nlp2.get_pipe("ner").cfg
assert nlp2.get_pipe("ner").labels == ("CITIZENSHIP",)


@ -1,6 +1,7 @@
import pytest
from spacy.lang.en import English
from spacy.lang.de import German
from spacy.ml.models.defaults import default_ner
from spacy.pipeline import EntityRuler, EntityRecognizer
from spacy.matcher import Matcher, PhraseMatcher
from spacy.tokens import Doc
@ -103,6 +104,7 @@ def test_issue3209():
assert ner.move_names == move_names
nlp2 = English()
nlp2.add_pipe(nlp2.create_pipe("ner"))
nlp2.get_pipe("ner").model.resize_output(ner.moves.n_moves)
nlp2.from_bytes(nlp.to_bytes())
assert nlp2.get_pipe("ner").move_names == move_names
@ -193,7 +195,7 @@ def test_issue3345():
doc = Doc(nlp.vocab, words=["I", "live", "in", "New", "York"])
doc[4].is_sent_start = True
ruler = EntityRuler(nlp, patterns=[{"label": "GPE", "pattern": "New York"}])
ner = EntityRecognizer(doc.vocab)
ner = EntityRecognizer(doc.vocab, default_ner())
# Add the OUT action. I wouldn't have thought this would be necessary...
ner.moves.add_action(5, "")
ner.add_label("GPE")


@ -1,10 +1,12 @@
from spacy.pipeline.pipes import DependencyParser
from spacy.vocab import Vocab
from spacy.ml.models.defaults import default_parser
def test_issue3830_no_subtok():
"""Test that the parser doesn't have subtok label if not learn_tokens"""
parser = DependencyParser(Vocab())
parser = DependencyParser(Vocab(), default_parser())
parser.add_label("nsubj")
assert "subtok" not in parser.labels
parser.begin_training(lambda: [])
@ -13,7 +15,7 @@ def test_issue3830_no_subtok():
def test_issue3830_with_subtok():
"""Test that the parser does have subtok label if learn_tokens=True."""
parser = DependencyParser(Vocab(), learn_tokens=True)
parser = DependencyParser(Vocab(), default_parser(), learn_tokens=True)
parser.add_label("nsubj")
assert "subtok" not in parser.labels
parser.begin_training(lambda: [])


@ -3,6 +3,7 @@ from spacy.pipeline import EntityRecognizer, EntityRuler
from spacy.lang.en import English
from spacy.tokens import Span
from spacy.util import ensure_path
from spacy.ml.models.defaults import default_ner
from ..util import make_tempdir
@ -73,6 +74,6 @@ def test_issue4042_bug2():
output_dir.mkdir()
ner1.to_disk(output_dir)
ner2 = EntityRecognizer(vocab)
ner2 = EntityRecognizer(vocab, default_ner())
ner2.from_disk(output_dir)
assert len(ner2.labels) == 2


@ -1,5 +1,6 @@
from collections import defaultdict
from spacy.ml.models.defaults import default_ner
from spacy.pipeline import EntityRecognizer
from spacy.lang.en import English
@ -11,7 +12,7 @@ def test_issue4313():
beam_width = 16
beam_density = 0.0001
nlp = English()
ner = EntityRecognizer(nlp.vocab)
ner = EntityRecognizer(nlp.vocab, default_ner())
ner.add_label("SOME_LABEL")
ner.begin_training([])
nlp.add_pipe(ner)


@ -0,0 +1,126 @@
from thinc.api import Config
import spacy
from spacy import util
from spacy.lang.en import English
from spacy.util import registry
from ..util import make_tempdir
from ...ml.models import build_Tok2Vec_model, build_tb_parser_model
nlp_config_string = """
[nlp]
lang = "en"
[nlp.pipeline.tok2vec]
factory = "tok2vec"
[nlp.pipeline.tok2vec.model]
@architectures = "spacy.HashEmbedCNN.v1"
pretrained_vectors = null
width = 342
depth = 4
window_size = 1
embed_size = 2000
maxout_pieces = 3
subword_features = true
[nlp.pipeline.tagger]
factory = "tagger"
[nlp.pipeline.tagger.model]
@architectures = "spacy.Tagger.v1"
[nlp.pipeline.tagger.model.tok2vec]
@architectures = "spacy.Tok2VecTensors.v1"
width = ${nlp.pipeline.tok2vec.model:width}
"""
parser_config_string = """
[model]
@architectures = "spacy.TransitionBasedParser.v1"
nr_feature_tokens = 99
hidden_width = 66
maxout_pieces = 2
[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v1"
pretrained_vectors = null
width = 333
depth = 4
embed_size = 5555
window_size = 1
maxout_pieces = 7
subword_features = false
"""
@registry.architectures.register("my_test_parser")
def my_parser():
tok2vec = build_Tok2Vec_model(width=321, embed_size=5432, pretrained_vectors=None, window_size=3,
maxout_pieces=4, subword_features=True, char_embed=True, nM=64, nC=8,
conv_depth=2, bilstm_depth=0)
parser = build_tb_parser_model(tok2vec=tok2vec, nr_feature_tokens=7, hidden_width=65, maxout_pieces=5)
return parser
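# Illustrative sketch: once registered, the architecture can be referenced by
# name from a pipe config, which is what test_serialize_custom_nlp below does
# when it builds the parser pipe. The variable name here is hypothetical.
example_parser_cfg = {"model": {"@architectures": "my_test_parser"}}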
def test_serialize_nlp():
""" Create a custom nlp pipeline from config and ensure it serializes it correctly """
nlp_config = Config().from_str(nlp_config_string)
nlp = util.load_model_from_config(nlp_config["nlp"])
nlp.begin_training()
assert "tok2vec" in nlp.pipe_names
assert "tagger" in nlp.pipe_names
assert "parser" not in nlp.pipe_names
assert nlp.get_pipe("tagger").model.get_ref("tok2vec").get_dim("nO") == 342
with make_tempdir() as d:
nlp.to_disk(d)
nlp2 = spacy.load(d)
assert "tok2vec" in nlp2.pipe_names
assert "tagger" in nlp2.pipe_names
assert "parser" not in nlp2.pipe_names
assert nlp2.get_pipe("tagger").model.get_ref("tok2vec").get_dim("nO") == 342
def test_serialize_custom_nlp():
""" Create a custom nlp pipeline and ensure it serializes it correctly"""
nlp = English()
parser_cfg = dict()
parser_cfg["model"] = {'@architectures': "my_test_parser"}
parser = nlp.create_pipe("parser", parser_cfg)
nlp.add_pipe(parser)
nlp.begin_training()
with make_tempdir() as d:
nlp.to_disk(d)
nlp2 = spacy.load(d)
model = nlp2.get_pipe("parser").model
tok2vec = model.get_ref("tok2vec")
upper = model.upper
# check that we have the correct settings, not the default ones
assert tok2vec.get_dim("nO") == 321
assert upper.get_dim("nI") == 65
def test_serialize_parser():
""" Create a non-default parser config to check nlp serializes it correctly """
nlp = English()
model_config = Config().from_str(parser_config_string)
parser = nlp.create_pipe("parser", config=model_config)
parser.add_label("nsubj")
nlp.add_pipe(parser)
nlp.begin_training()
with make_tempdir() as d:
nlp.to_disk(d)
nlp2 = spacy.load(d)
model = nlp2.get_pipe("parser").model
tok2vec = model.get_ref("tok2vec")
upper = model.upper
# check that we have the correct settings, not the default ones
assert upper.get_dim("nI") == 66
assert tok2vec.get_dim("nO") == 333


@ -1,5 +1,6 @@
import pytest
import re
from spacy.language import Language
from spacy.tokenizer import Tokenizer
@ -56,7 +57,7 @@ def test_serialize_language_exclude(meta_data):
nlp = Language(meta=meta_data)
assert nlp.meta["name"] == name
new_nlp = Language().from_bytes(nlp.to_bytes())
assert nlp.meta["name"] == name
assert new_nlp.meta["name"] == name
new_nlp = Language().from_bytes(nlp.to_bytes(), exclude=["meta"])
assert not new_nlp.meta["name"] == name
new_nlp = Language().from_bytes(nlp.to_bytes(exclude=["meta"]))


@ -1,6 +1,7 @@
import pytest
from spacy.pipeline import Tagger, DependencyParser, EntityRecognizer
from spacy.pipeline import Tensorizer, TextCategorizer, SentenceRecognizer
from spacy.ml.models.defaults import default_parser, default_tensorizer, default_tagger, default_textcat, default_sentrec
from ..util import make_tempdir
@ -10,58 +11,58 @@ test_parsers = [DependencyParser, EntityRecognizer]
@pytest.fixture
def parser(en_vocab):
parser = DependencyParser(en_vocab)
parser = DependencyParser(en_vocab, default_parser())
parser.add_label("nsubj")
parser.model, cfg = parser.Model(parser.moves.n_moves)
parser.cfg.update(cfg)
return parser
@pytest.fixture
def blank_parser(en_vocab):
parser = DependencyParser(en_vocab)
parser = DependencyParser(en_vocab, default_parser())
return parser
@pytest.fixture
def taggers(en_vocab):
tagger1 = Tagger(en_vocab)
tagger2 = Tagger(en_vocab)
tagger1.model = tagger1.Model(8)
tagger2.model = tagger1.model
return (tagger1, tagger2)
model = default_tagger()
tagger1 = Tagger(en_vocab, model)
tagger2 = Tagger(en_vocab, model)
return tagger1, tagger2
@pytest.mark.parametrize("Parser", test_parsers)
def test_serialize_parser_roundtrip_bytes(en_vocab, Parser):
parser = Parser(en_vocab)
parser.model, _ = parser.Model(10)
new_parser = Parser(en_vocab)
new_parser.model, _ = new_parser.Model(10)
parser = Parser(en_vocab, default_parser())
new_parser = Parser(en_vocab, default_parser())
new_parser = new_parser.from_bytes(parser.to_bytes(exclude=["vocab"]))
assert new_parser.to_bytes(exclude=["vocab"]) == parser.to_bytes(exclude=["vocab"])
bytes_2 = new_parser.to_bytes(exclude=["vocab"])
bytes_3 = parser.to_bytes(exclude=["vocab"])
assert len(bytes_2) == len(bytes_3)
assert bytes_2 == bytes_3
@pytest.mark.parametrize("Parser", test_parsers)
def test_serialize_parser_roundtrip_disk(en_vocab, Parser):
parser = Parser(en_vocab)
parser.model, _ = parser.Model(0)
parser = Parser(en_vocab, default_parser())
with make_tempdir() as d:
file_path = d / "parser"
parser.to_disk(file_path)
parser_d = Parser(en_vocab)
parser_d.model, _ = parser_d.Model(0)
parser_d = Parser(en_vocab, default_parser())
parser_d = parser_d.from_disk(file_path)
parser_bytes = parser.to_bytes(exclude=["model", "vocab"])
parser_d_bytes = parser_d.to_bytes(exclude=["model", "vocab"])
assert len(parser_bytes) == len(parser_d_bytes)
assert parser_bytes == parser_d_bytes
def test_to_from_bytes(parser, blank_parser):
assert parser.model is not True
assert blank_parser.model is True
assert blank_parser.model is not True
assert blank_parser.moves.n_moves != parser.moves.n_moves
bytes_data = parser.to_bytes(exclude=["vocab"])
# the blank parser needs to be resized before we can call from_bytes
blank_parser.model.resize_output(parser.moves.n_moves)
blank_parser.from_bytes(bytes_data)
assert blank_parser.model is not True
assert blank_parser.moves.n_moves == parser.moves.n_moves
@ -75,8 +76,10 @@ def test_serialize_tagger_roundtrip_bytes(en_vocab, taggers):
tagger1_b = tagger1.to_bytes()
tagger1 = tagger1.from_bytes(tagger1_b)
assert tagger1.to_bytes() == tagger1_b
new_tagger1 = Tagger(en_vocab).from_bytes(tagger1_b)
assert new_tagger1.to_bytes() == tagger1_b
new_tagger1 = Tagger(en_vocab, default_tagger()).from_bytes(tagger1_b)
new_tagger1_b = new_tagger1.to_bytes()
assert len(new_tagger1_b) == len(tagger1_b)
assert new_tagger1_b == tagger1_b
def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
@ -86,26 +89,24 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
file_path2 = d / "tagger2"
tagger1.to_disk(file_path1)
tagger2.to_disk(file_path2)
tagger1_d = Tagger(en_vocab).from_disk(file_path1)
tagger2_d = Tagger(en_vocab).from_disk(file_path2)
tagger1_d = Tagger(en_vocab, default_tagger()).from_disk(file_path1)
tagger2_d = Tagger(en_vocab, default_tagger()).from_disk(file_path2)
assert tagger1_d.to_bytes() == tagger2_d.to_bytes()
def test_serialize_tensorizer_roundtrip_bytes(en_vocab):
tensorizer = Tensorizer(en_vocab)
tensorizer.model = tensorizer.Model()
tensorizer = Tensorizer(en_vocab, default_tensorizer())
tensorizer_b = tensorizer.to_bytes(exclude=["vocab"])
new_tensorizer = Tensorizer(en_vocab).from_bytes(tensorizer_b)
new_tensorizer = Tensorizer(en_vocab, default_tensorizer()).from_bytes(tensorizer_b)
assert new_tensorizer.to_bytes(exclude=["vocab"]) == tensorizer_b
def test_serialize_tensorizer_roundtrip_disk(en_vocab):
tensorizer = Tensorizer(en_vocab)
tensorizer.model = tensorizer.Model()
tensorizer = Tensorizer(en_vocab, default_tensorizer())
with make_tempdir() as d:
file_path = d / "tensorizer"
tensorizer.to_disk(file_path)
tensorizer_d = Tensorizer(en_vocab).from_disk(file_path)
tensorizer_d = Tensorizer(en_vocab, default_tensorizer()).from_disk(file_path)
assert tensorizer.to_bytes(exclude=["vocab"]) == tensorizer_d.to_bytes(
exclude=["vocab"]
)
@ -113,19 +114,17 @@ def test_serialize_tensorizer_roundtrip_disk(en_vocab):
def test_serialize_textcat_empty(en_vocab):
# See issue #1105
textcat = TextCategorizer(en_vocab, labels=["ENTITY", "ACTION", "MODIFIER"])
textcat = TextCategorizer(en_vocab, default_textcat(), labels=["ENTITY", "ACTION", "MODIFIER"])
textcat.to_bytes(exclude=["vocab"])
@pytest.mark.parametrize("Parser", test_parsers)
def test_serialize_pipe_exclude(en_vocab, Parser):
def get_new_parser():
new_parser = Parser(en_vocab)
new_parser.model, _ = new_parser.Model(0)
new_parser = Parser(en_vocab, default_parser())
return new_parser
parser = Parser(en_vocab)
parser.model, _ = parser.Model(0)
parser = Parser(en_vocab, default_parser())
parser.cfg["foo"] = "bar"
new_parser = get_new_parser().from_bytes(parser.to_bytes(exclude=["vocab"]))
assert "foo" in new_parser.cfg
@ -144,7 +143,7 @@ def test_serialize_pipe_exclude(en_vocab, Parser):
def test_serialize_sentencerecognizer(en_vocab):
sr = SentenceRecognizer(en_vocab)
sr = SentenceRecognizer(en_vocab, default_sentrec())
sr_b = sr.to_bytes()
sr_d = SentenceRecognizer(en_vocab).from_bytes(sr_b)
sr_d = SentenceRecognizer(en_vocab, default_sentrec()).from_bytes(sr_b)
assert sr.to_bytes() == sr_d.to_bytes()


@ -1,6 +1,6 @@
import pytest
from spacy.ml.component_models import Tok2Vec
from spacy.ml.models.tok2vec import build_Tok2Vec_model
from spacy.vocab import Vocab
from spacy.tokens import Doc
@ -25,7 +25,8 @@ def test_empty_doc():
embed_size = 2000
vocab = Vocab()
doc = Doc(vocab, words=[])
tok2vec = Tok2Vec(width, embed_size)
# TODO: fix tok2vec arguments
tok2vec = build_Tok2Vec_model(width, embed_size)
vectors, backprop = tok2vec.begin_update([doc])
assert len(vectors) == 1
assert vectors[0].shape == (0, width)
@ -36,7 +37,19 @@ def test_empty_doc():
)
def test_tok2vec_batch_sizes(batch_size, width, embed_size):
batch = get_batch(batch_size)
tok2vec = Tok2Vec(width, embed_size)
tok2vec = build_Tok2Vec_model(
width,
embed_size,
pretrained_vectors=None,
conv_depth=4,
bilstm_depth=0,
window_size=1,
maxout_pieces=3,
subword_features=True,
char_embed=False,
nM=64,
nC=8,
)
tok2vec.initialize()
vectors, backprop = tok2vec.begin_update(batch)
assert len(vectors) == len(batch)
@ -44,19 +57,24 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
assert doc_vec.shape == (len(doc), width)
# fmt: off
@pytest.mark.parametrize(
"tok2vec_config",
[
{"width": 8, "embed_size": 100, "char_embed": False},
{"width": 8, "embed_size": 100, "char_embed": True},
{"width": 8, "embed_size": 100, "conv_depth": 6},
{"width": 8, "embed_size": 100, "conv_depth": 6},
{"width": 8, "embed_size": 100, "subword_features": False},
{"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True},
{"width": 8, "embed_size": 100, "char_embed": True, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True},
{"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 6, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True},
{"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 6, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": True},
{"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 1, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False},
{"width": 8, "embed_size": 100, "char_embed": False, "nM": 64, "nC": 8, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False},
{"width": 8, "embed_size": 100, "char_embed": True, "nM": 81, "nC": 8, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False},
{"width": 8, "embed_size": 100, "char_embed": True, "nM": 81, "nC": 9, "pretrained_vectors": None, "window_size": 3, "conv_depth": 2, "bilstm_depth": 0, "maxout_pieces": 3, "subword_features": False},
],
)
# fmt: on
def test_tok2vec_configs(tok2vec_config):
docs = get_batch(3)
tok2vec = Tok2Vec(**tok2vec_config)
tok2vec = build_Tok2Vec_model(**tok2vec_config)
tok2vec.initialize()
vectors, backprop = tok2vec.begin_update(docs)
assert len(vectors) == len(docs)


@ -6,8 +6,7 @@ from pathlib import Path
import random
from typing import List
import thinc
import thinc.config
from thinc.api import NumpyOps, get_current_ops, Adam, require_gpu
from thinc.api import NumpyOps, get_current_ops, Adam, require_gpu, Config
import functools
import itertools
import numpy.random
@ -146,6 +145,10 @@ def load_model_from_path(model_path, meta=False, **overrides):
pipeline from meta.json and then calls from_disk() with path."""
if not meta:
meta = get_model_meta(model_path)
nlp_config = get_model_config(model_path)
if nlp_config.get("nlp", None):
return load_model_from_config(nlp_config["nlp"])
# Support language factories registered via entry points (e.g. custom
# language subclass) while keeping top-level language identifier "lang"
lang = meta.get("lang_factory", meta["lang"])
@ -162,11 +165,30 @@ def load_model_from_path(model_path, meta=False, **overrides):
if name not in disable:
config = meta.get("pipeline_args", {}).get(name, {})
factory = factories.get(name, name)
if nlp_config.get(name, None):
model_config = nlp_config[name]["model"]
config["model"] = model_config
component = nlp.create_pipe(factory, config=config)
nlp.add_pipe(component, name=name)
return nlp.from_disk(model_path, exclude=disable)
def load_model_from_config(nlp_config):
if "name" in nlp_config:
nlp = load_model(**nlp_config)
elif "lang" in nlp_config:
lang_class = get_lang_class(nlp_config["lang"])
nlp = lang_class()
else:
raise ValueError(Errors.E993)
if "pipeline" in nlp_config:
for name, component_cfg in nlp_config["pipeline"].items():
factory = component_cfg.pop("factory")
component = nlp.create_pipe(factory, config=component_cfg)
nlp.add_pipe(component, name=name)
return nlp
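# Minimal usage sketch for load_model_from_config, assuming a config string in
# the same format as the example configs elsewhere in this commit (the values
# and variable names below are illustrative):
_example_cfg_str = """
[nlp]
lang = "en"
[nlp.pipeline.tok2vec]
factory = "tok2vec"
[nlp.pipeline.tok2vec.model]
@architectures = "spacy.HashEmbedCNN.v1"
pretrained_vectors = null
width = 96
depth = 4
window_size = 1
embed_size = 2000
maxout_pieces = 3
subword_features = true
"""
_example_nlp = load_model_from_config(Config().from_str(_example_cfg_str)["nlp"])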
def load_model_from_init_py(init_file, **overrides):
"""Helper function to use in the `load()` method of a model package's
__init__.py.
@ -184,7 +206,7 @@ def load_model_from_init_py(init_file, **overrides):
return load_model_from_path(data_path, meta, **overrides)
def load_from_config(path, create_objects=False):
def load_config(path, create_objects=False):
"""Load a Thinc-formatted config file, optionally filling in objects where
the config references registry entries. See "Thinc config files" for details.
@ -212,7 +234,7 @@ def get_model_meta(path):
raise IOError(Errors.E052.format(path=model_path))
meta_path = model_path / "meta.json"
if not meta_path.is_file():
raise IOError(Errors.E053.format(path=meta_path))
raise IOError(Errors.E053.format(path=meta_path, name="meta.json"))
meta = srsly.read_json(meta_path)
for setting in ["lang", "name", "version"]:
if setting not in meta or not meta[setting]:
@ -220,6 +242,23 @@ def get_model_meta(path):
return meta
def get_model_config(path):
"""Get the model's config from a directory path.
path (unicode or Path): Path to model directory.
RETURNS (Config): The model's config data.
"""
model_path = ensure_path(path)
if not model_path.exists():
raise IOError(Errors.E052.format(path=model_path))
config_path = model_path / "config.cfg"
# model directories are allowed not to have config files ?
if not config_path.is_file():
return Config({})
# raise IOError(Errors.E053.format(path=config_path, name="config.cfg"))
return Config().from_disk(config_path)
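# Usage sketch (the directory path and variable names are placeholders): read a
# model directory's config.cfg and, when it defines an [nlp] block, rebuild the
# pipeline from it. This mirrors what load_model_from_path does above.
_model_cfg = get_model_config("/path/to/model")
if _model_cfg.get("nlp"):
    _nlp = load_model_from_config(_model_cfg["nlp"])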
def is_package(name):
"""Check if string maps to a package installed via pip.