Upd train

Matthew Honnibal 2020-09-28 01:08:30 +02:00
commit a3e1791c9c
32 changed files with 210 additions and 126 deletions

View File

@@ -6,7 +6,7 @@ requires = [
     "cymem>=2.0.2,<2.1.0",
     "preshed>=3.0.2,<3.1.0",
     "murmurhash>=0.28.0,<1.1.0",
-    "thinc>=8.0.0a36,<8.0.0a40",
+    "thinc>=8.0.0a41,<8.0.0a50",
    "blis>=0.4.0,<0.5.0",
    "pytokenizations",
    "pathy"

View File

@@ -1,7 +1,7 @@
 # Our libraries
 cymem>=2.0.2,<2.1.0
 preshed>=3.0.2,<3.1.0
-thinc>=8.0.0a36,<8.0.0a40
+thinc>=8.0.0a41,<8.0.0a50
 blis>=0.4.0,<0.5.0
 ml_datasets==0.2.0a0
 murmurhash>=0.28.0,<1.1.0

View File

@@ -34,13 +34,13 @@ setup_requires =
     cymem>=2.0.2,<2.1.0
     preshed>=3.0.2,<3.1.0
     murmurhash>=0.28.0,<1.1.0
-    thinc>=8.0.0a36,<8.0.0a40
+    thinc>=8.0.0a41,<8.0.0a50
 install_requires =
     # Our libraries
     murmurhash>=0.28.0,<1.1.0
     cymem>=2.0.2,<2.1.0
     preshed>=3.0.2,<3.1.0
-    thinc>=8.0.0a36,<8.0.0a40
+    thinc>=8.0.0a41,<8.0.0a50
     blis>=0.4.0,<0.5.0
     wasabi>=0.8.0,<1.1.0
     srsly>=2.1.0,<3.0.0

View File

@@ -243,6 +243,8 @@ def show_validation_error(
         yield
     except ConfigValidationError as e:
         title = title if title is not None else e.title
+        if e.desc:
+            desc = f"{e.desc}" if not desc else f"{e.desc}\n\n{desc}"
         # Re-generate a new error object with overrides
         err = e.from_error(e, title="", desc=desc, show_config=show_config)
         msg.fail(title)
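The added branch chains the error's own description in front of any caller-supplied one, so neither is lost when the error is re-raised. A standalone sketch of that merging behaviour (function and variable names are illustrative, not from the diff):

    def merge_desc(error_desc, caller_desc):
        # The error's own description goes first; a caller-supplied one,
        # if present, follows after a blank line.
        if not error_desc:
            return caller_desc
        return error_desc if not caller_desc else f"{error_desc}\n\n{caller_desc}"

    assert merge_desc("bad value", None) == "bad value"
    assert merge_desc("bad value", "see docs") == "bad value\n\nsee docs"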

View File

@@ -51,9 +51,10 @@ def debug_config(
     msg.divider("Config validation")
     with show_validation_error(config_path):
         config = util.load_config(config_path, overrides=overrides)
-        nlp, resolved = util.load_model_from_config(config)
+        nlp = util.load_model_from_config(config)
     # Use the resolved config here in case user has one function returning
     # a dict of corpora etc.
+    resolved = util.resolve_training_config(nlp.config)
     check_section_refs(resolved, ["training.dev_corpus", "training.train_corpus"])
     msg.good("Config is valid")
     if show_vars:
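This is the new calling convention used throughout the commit: build the nlp object first, then resolve the training-relevant sections only when a command needs the instantiated objects. A hedged sketch of the pattern (the config path is a placeholder):

    from spacy import util

    config = util.load_config("config.cfg")  # hypothetical path
    # load_model_from_config now returns just the nlp object ...
    nlp = util.load_model_from_config(config)
    # ... and [training]/[corpora] are resolved separately, on demand.
    resolved = util.resolve_training_config(nlp.config)
    batcher = resolved["training"]["batcher"]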

View File

@@ -93,18 +93,19 @@ def debug_data(
         msg.fail("Config file not found", config_path, exists=1)
     with show_validation_error(config_path):
         cfg = util.load_config(config_path, overrides=config_overrides)
-        nlp, config = util.load_model_from_config(cfg)
+        nlp = util.load_model_from_config(cfg)
+        C = util.resolve_training_config(nlp.config)
     # Use original config here, not resolved version
     sourced_components = get_sourced_components(cfg)
-    frozen_components = config["training"]["frozen_components"]
+    frozen_components = C["training"]["frozen_components"]
     resume_components = [p for p in sourced_components if p not in frozen_components]
     pipeline = nlp.pipe_names
     factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names]
-    tag_map_path = util.ensure_path(config["training"]["tag_map"])
+    tag_map_path = util.ensure_path(C["training"]["tag_map"])
     tag_map = {}
     if tag_map_path is not None:
         tag_map = srsly.read_json(tag_map_path)
-    morph_rules_path = util.ensure_path(config["training"]["morph_rules"])
+    morph_rules_path = util.ensure_path(C["training"]["morph_rules"])
     morph_rules = {}
     if morph_rules_path is not None:
         morph_rules = srsly.read_json(morph_rules_path)

@@ -144,10 +145,10 @@ def debug_data(
     train_texts = gold_train_data["texts"]
     dev_texts = gold_dev_data["texts"]
-    frozen_components = config["training"]["frozen_components"]
+    frozen_components = C["training"]["frozen_components"]

     msg.divider("Training stats")
-    msg.text(f"Language: {config['nlp']['lang']}")
+    msg.text(f"Language: {C['nlp']['lang']}")
     msg.text(f"Training pipeline: {', '.join(pipeline)}")
     if resume_components:
         msg.text(f"Components from other pipelines: {', '.join(resume_components)}")

View File

@@ -1,4 +1,3 @@
-import warnings
 from typing import Dict, Any, Optional, Iterable
 from pathlib import Path
@@ -57,14 +56,17 @@ def debug_model_cli(
     }
     config_overrides = parse_config_overrides(ctx.args)
     with show_validation_error(config_path):
-        config = util.load_config(
-            config_path, overrides=config_overrides, interpolate=True
+        raw_config = util.load_config(
+            config_path, overrides=config_overrides, interpolate=False
         )
+    config = raw_config.interpolate()
     allocator = config["training"]["gpu_allocator"]
     if use_gpu >= 0 and allocator:
         set_gpu_allocator(allocator)
-    nlp, config = util.load_model_from_config(config)
-    seed = config["training"]["seed"]
+    with show_validation_error(config_path):
+        nlp = util.load_model_from_config(raw_config)
+    C = util.resolve_training_config(nlp.config)
+    seed = C["training"]["seed"]
     if seed is not None:
         msg.info(f"Fixing random seed: {seed}")
         fix_random_seed(seed)
@@ -75,7 +77,7 @@ def debug_model_cli(
             exits=1,
         )
     model = pipe.model
-    debug_model(config, nlp, model, print_settings=print_settings)
+    debug_model(C, nlp, model, print_settings=print_settings)


 def debug_model(
@@ -108,7 +110,7 @@ def debug_model(
         _set_output_dim(nO=7, model=model)
         nlp.begin_training(lambda: [Example.from_dict(x, {}) for x in X])
         msg.info("Initialized the model with dummy data.")
-    except:
+    except Exception:
         msg.fail(
             "Could not initialize the model: you'll have to provide a valid train_corpus argument in the config file.",
             exits=1,
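The switch from a bare except to except Exception is more than style: a bare clause also swallows KeyboardInterrupt and SystemExit, which inherit only from BaseException. A self-contained demonstration:

    import sys

    try:
        sys.exit(1)  # raises SystemExit
    except Exception:
        print("not reached: SystemExit is not an Exception subclass")
    except BaseException:
        print("reached: SystemExit skips the 'except Exception' clause")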

View File

@@ -88,10 +88,10 @@ def fill_config(
     msg = Printer(no_print=no_print)
     with show_validation_error(hint_fill=False):
         config = util.load_config(base_path)
-        nlp, _ = util.load_model_from_config(config, auto_fill=True, validate=False)
+        nlp = util.load_model_from_config(config, auto_fill=True, validate=False)
         # Load a second time with validation to be extra sure that the produced
         # config result is a valid config
-        nlp, _ = util.load_model_from_config(nlp.config)
+        nlp = util.load_model_from_config(nlp.config)
     filled = nlp.config
     if pretraining:
         validate_config_for_pretrain(filled, msg)

@@ -169,7 +169,7 @@ def init_config(
         msg.text(f"- {label}: {value}")
     with show_validation_error(hint_fill=False):
         config = util.load_config_from_str(base_template)
-        nlp, _ = util.load_model_from_config(config, auto_fill=True)
+        nlp = util.load_model_from_config(config, auto_fill=True)
     config = nlp.config
     if pretraining:
         validate_config_for_pretrain(config, msg)

View File

@@ -69,17 +69,18 @@ def pretrain_cli(
     msg.info(f"Loading config from: {config_path}")
     with show_validation_error(config_path):
-        config = util.load_config(
-            config_path, overrides=config_overrides, interpolate=True
+        raw_config = util.load_config(
+            config_path, overrides=config_overrides, interpolate=False
         )
+    config = raw_config.interpolate()
     if not config.get("pretraining"):
         # TODO: What's the solution here? How do we handle optional blocks?
         msg.fail("The [pretraining] block in your config is empty", exits=1)
     if not output_dir.exists():
         output_dir.mkdir()
         msg.good(f"Created output directory: {output_dir}")
-    config.to_disk(output_dir / "config.cfg")
+    # Save non-interpolated config
+    raw_config.to_disk(output_dir / "config.cfg")
     msg.good("Saved config file in the output directory")

     pretrain(
@@ -103,14 +104,13 @@ def pretrain(
     allocator = config["training"]["gpu_allocator"]
     if use_gpu >= 0 and allocator:
         set_gpu_allocator(allocator)
-    nlp, config = util.load_model_from_config(config)
-    P_cfg = config["pretraining"]
-    corpus = dot_to_object(config, P_cfg["corpus"])
+    nlp = util.load_model_from_config(config)
+    C = util.resolve_training_config(nlp.config)
+    P_cfg = C["pretraining"]
+    corpus = dot_to_object(C, P_cfg["corpus"])
     batcher = P_cfg["batcher"]
-    model = create_pretraining_model(nlp, config["pretraining"])
-    optimizer = config["pretraining"]["optimizer"]
+    model = create_pretraining_model(nlp, C["pretraining"])
+    optimizer = C["pretraining"]["optimizer"]
     # Load in pretrained weights to resume from
     if resume_path is not None:
         _resume_model(model, resume_path, epoch_resume)
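Loading with interpolate=False and writing raw_config back out preserves variable references such as ${paths.train} in the saved file, instead of baking in their current values. A small sketch of the difference, assuming a minimal config with a [paths] section:

    from thinc.api import Config

    # Assumed minimal config text; [paths] feeds the ${paths.train} variable.
    cfg_text = (
        "[paths]\n"
        'train = "corpus/train.spacy"\n'
        "\n"
        "[training]\n"
        "train_path = ${paths.train}\n"
    )
    raw_config = Config().from_str(cfg_text, interpolate=False)
    config = raw_config.interpolate()
    assert config["training"]["train_path"] == "corpus/train.spacy"
    # raw_config still holds the literal "${paths.train}", so the file it
    # writes to disk can be re-run later with different path overrides.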

View File

@@ -58,7 +58,7 @@ def train_cli(
     else:
         msg.info("Using CPU")
     config = util.load_config(
-        config_path, overrides=config_overrides, interpolate=True
+        config_path, overrides=config_overrides, interpolate=False
     )
     if output_path is None:
         nlp = init_pipeline(config)
@@ -75,24 +75,32 @@ def train_cli(
 def train(nlp: Language, output_path: Optional[Path]=None) -> None:
     # Create iterator, which yields out info after each optimization step.
-    config = nlp.config
-    T_cfg = config["training"]
-    score_weights = T_cfg["score_weights"]
-    optimizer = T_cfg["optimizer"]
-    train_corpus = dot_to_object(config, T_cfg["train_corpus"])
-    dev_corpus = dot_to_object(config, T_cfg["dev_corpus"])
-    batcher = T_cfg["batcher"]
+    config = nlp.config.interpolate()
+    T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
+    optimizer = T["optimizer"]
+    score_weights = T["score_weights"]
+    # TODO: This might not be called corpora
+    corpora = registry.resolve(config["corpora"], schema=ConfigSchemaCorpora)
+    train_corpus = dot_to_object({"corpora": corpora}, T["train_corpus"])
+    dev_corpus = dot_to_object({"corpora": corpora}, T["dev_corpus"])
+    batcher = T["batcher"]
+    train_logger = T["logger"]
+    before_to_disk = create_before_to_disk_callback(T["before_to_disk"])
+    # Components that shouldn't be updated during training
+    frozen_components = T["frozen_components"]
+    # Create iterator, which yields out info after each optimization step.
+    msg.info("Start training")
     training_step_iterator = train_while_improving(
         nlp,
         optimizer,
-        create_train_batches(train_corpus(nlp), batcher, T_cfg["max_epochs"]),
+        create_train_batches(train_corpus(nlp), batcher, T["max_epochs"]),
         create_evaluation_callback(nlp, dev_corpus, score_weights),
-        dropout=T_cfg["dropout"],
-        accumulate_gradient=T_cfg["accumulate_gradient"],
-        patience=T_cfg["patience"],
-        max_steps=T_cfg["max_steps"],
-        eval_frequency=T_cfg["eval_frequency"],
+        dropout=T["dropout"],
+        accumulate_gradient=T["accumulate_gradient"],
+        patience=T["patience"],
+        max_steps=T["max_steps"],
+        eval_frequency=T["eval_frequency"],
         raw_text=None,
         exclude=frozen_components,
     )
@@ -101,7 +109,7 @@ def train(nlp: Language, output_path: Optional[Path]=None) -> None:
     print_row, finalize_logger = train_logger(nlp)
     try:
-        progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)
+        progress = tqdm.tqdm(total=T["eval_frequency"], leave=False)
         progress.set_description(f"Epoch 1")
         for batch, info, is_best_checkpoint in training_step_iterator:
             progress.update(1)
@@ -110,11 +118,11 @@ def train(nlp: Language, output_path: Optional[Path]=None) -> None:
                 print_row(info)
             if is_best_checkpoint and output_path is not None:
                 with nlp.select_pipes(disable=frozen_components):
-                    update_meta(T_cfg, nlp, info)
+                    update_meta(T, nlp, info)
                 with nlp.use_params(optimizer.averages):
                     nlp = before_to_disk(nlp)
                     nlp.to_disk(output_path / "model-best")
-                progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)
+                progress = tqdm.tqdm(total=T["eval_frequency"], leave=False)
                 progress.set_description(f"Epoch {info['epoch']}")
     except Exception as e:
         finalize_logger()
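train() now resolves the blocks it needs on its own, turning promise dicts (those with an @-key) into objects at the point of use. A minimal, schema-free sketch of that mechanism; "Adam.v1" is the same registered optimizer used in this commit's new test:

    from spacy.util import registry

    # A dict with an "@" key is a promise; registry.resolve replaces it
    # with the output of the registered function.
    cfg = {"optimizer": {"@optimizers": "Adam.v1"}, "dropout": 0.1}
    resolved = registry.resolve(cfg)
    optimizer = resolved["optimizer"]  # a thinc Optimizer instance
    assert resolved["dropout"] == 0.1  # plain values pass through unchanged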

View File

@@ -12,8 +12,10 @@ from .tag_bigram_map import TAG_BIGRAM_MAP
 from ...compat import copy_reg
 from ...errors import Errors
 from ...language import Language
+from ...scorer import Scorer
 from ...symbols import POS
 from ...tokens import Doc
+from ...training import validate_examples
 from ...util import DummyTokenizer, registry
 from ... import util

@@ -130,6 +132,10 @@ class JapaneseTokenizer(DummyTokenizer):
             )
         return sub_tokens_list

+    def score(self, examples):
+        validate_examples(examples, "JapaneseTokenizer.score")
+        return Scorer.score_tokenization(examples)
+
     def _get_config(self) -> Dict[str, Any]:
         return {"split_mode": self.split_mode}

View File

@@ -7,7 +7,9 @@ from .lex_attrs import LEX_ATTRS
 from ...language import Language
 from ...tokens import Doc
 from ...compat import copy_reg
+from ...scorer import Scorer
 from ...symbols import POS
+from ...training import validate_examples
 from ...util import DummyTokenizer, registry

@@ -62,6 +64,10 @@ class KoreanTokenizer(DummyTokenizer):
             lemma = surface
             yield {"surface": surface, "lemma": lemma, "tag": tag}

+    def score(self, examples):
+        validate_examples(examples, "KoreanTokenizer.score")
+        return Scorer.score_tokenization(examples)
+

 class KoreanDefaults(Language.Defaults):
     config = Config().from_str(DEFAULT_CONFIG)

View File

@@ -8,7 +8,9 @@ from thinc.api import Config
 from ...errors import Warnings, Errors
 from ...language import Language
+from ...scorer import Scorer
 from ...tokens import Doc
+from ...training import validate_examples
 from ...util import DummyTokenizer, registry
 from .lex_attrs import LEX_ATTRS
 from .stop_words import STOP_WORDS

@@ -136,6 +138,10 @@ class ChineseTokenizer(DummyTokenizer):
             warn_msg = Warnings.W104.format(target="pkuseg", current=self.segmenter)
             warnings.warn(warn_msg)

+    def score(self, examples):
+        validate_examples(examples, "ChineseTokenizer.score")
+        return Scorer.score_tokenization(examples)
+
     def _get_config(self) -> Dict[str, Any]:
         return {
             "segmenter": self.segmenter,

View File

@@ -27,7 +27,7 @@ from .lang.punctuation import TOKENIZER_INFIXES
 from .tokens import Doc
 from .tokenizer import Tokenizer
 from .errors import Errors, Warnings
-from .schemas import ConfigSchema
+from .schemas import ConfigSchema, ConfigSchemaNlp
 from .git_info import GIT_VERSION
 from . import util
 from . import about

@@ -166,11 +166,10 @@ class Language:
         self._components = []
         self._disabled = set()
         self.max_length = max_length
-        self.resolved = {}
         # Create the default tokenizer from the default config
         if not create_tokenizer:
             tokenizer_cfg = {"tokenizer": self._config["nlp"]["tokenizer"]}
-            create_tokenizer = registry.make_from_config(tokenizer_cfg)["tokenizer"]
+            create_tokenizer = registry.resolve(tokenizer_cfg)["tokenizer"]
         self.tokenizer = create_tokenizer(self)

     def __init_subclass__(cls, **kwargs):
@@ -467,7 +466,7 @@ class Language:
         if "nlp" not in arg_names or "name" not in arg_names:
             raise ValueError(Errors.E964.format(name=name))
         # Officially register the factory so we can later call
-        # registry.make_from_config and refer to it in the config as
+        # registry.resolve and refer to it in the config as
         # @factories = "spacy.Language.xyz". We use the class name here so
         # different classes can have different factories.
         registry.factories.register(internal_name, func=factory_func)
@@ -650,8 +649,9 @@ class Language:
         cfg = {factory_name: config}
         # We're calling the internal _fill here to avoid constructing the
         # registered functions twice
-        resolved, filled = registry.resolve(cfg, validate=validate)
-        filled = Config(filled[factory_name])
+        resolved = registry.resolve(cfg, validate=validate)
+        filled = registry.fill({"cfg": cfg[factory_name]}, validate=validate)["cfg"]
+        filled = Config(filled)
         filled["factory"] = factory_name
         filled.pop("@factories", None)
         # Remove the extra values we added because we don't want to keep passing
@@ -1518,15 +1518,19 @@ class Language:
         config = util.copy_config(config)
         orig_pipeline = config.pop("components", {})
         config["components"] = {}
-        resolved, filled = registry.resolve(
-            config, validate=validate, schema=ConfigSchema
-        )
+        if auto_fill:
+            filled = registry.fill(config, validate=validate, schema=ConfigSchema)
+        else:
+            filled = config
         filled["components"] = orig_pipeline
         config["components"] = orig_pipeline
-        create_tokenizer = resolved["nlp"]["tokenizer"]
-        before_creation = resolved["nlp"]["before_creation"]
-        after_creation = resolved["nlp"]["after_creation"]
-        after_pipeline_creation = resolved["nlp"]["after_pipeline_creation"]
+        resolved_nlp = registry.resolve(
+            filled["nlp"], validate=validate, schema=ConfigSchemaNlp
+        )
+        create_tokenizer = resolved_nlp["tokenizer"]
+        before_creation = resolved_nlp["before_creation"]
+        after_creation = resolved_nlp["after_creation"]
+        after_pipeline_creation = resolved_nlp["after_pipeline_creation"]
         lang_cls = cls
         if before_creation is not None:
             lang_cls = before_creation(cls)
@@ -1587,7 +1591,6 @@ class Language:
         disabled_pipes = [*config["nlp"]["disabled"], *disable]
         nlp._disabled = set(p for p in disabled_pipes if p not in exclude)
         nlp.config = filled if auto_fill else config
-        nlp.resolved = resolved
         if after_pipeline_creation is not None:
             nlp = after_pipeline_creation(nlp)
         if not isinstance(nlp, cls):
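The refactor relies on thinc keeping its two config operations separate: registry.fill validates and adds default arguments while leaving promises as plain config data, whereas registry.resolve actually calls the registered functions. A minimal sketch of the distinction:

    from spacy.util import registry

    cfg = {"optimizer": {"@optimizers": "Adam.v1"}}
    filled = registry.fill(cfg)       # still config data, defaults added
    resolved = registry.resolve(cfg)  # contains the constructed object
    assert "@optimizers" in filled["optimizer"]
    assert hasattr(resolved["optimizer"], "learn_rate")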

View File

@@ -29,7 +29,8 @@ cdef class Morphology:
     FEATURE_SEP = "|"
     FIELD_SEP = "="
     VALUE_SEP = ","
-    EMPTY_MORPH = "_"  # not an empty string so that the PreshMap key is not 0
+    # not an empty string so that the PreshMap key is not 0
+    EMPTY_MORPH = symbols.NAMES[symbols._]

     def __init__(self, StringStore strings):
         self.mem = Pool()

View File

@@ -4,6 +4,7 @@ from enum import Enum
 from pydantic import BaseModel, Field, ValidationError, validator
 from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
 from pydantic import root_validator
+from thinc.config import Promise
 from collections import defaultdict
 from thinc.api import Optimizer

@@ -16,10 +17,12 @@ if TYPE_CHECKING:
     from .training import Example  # noqa: F401

+# fmt: off
 ItemT = TypeVar("ItemT")
-Batcher = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
-Reader = Callable[["Language", str], Iterable["Example"]]
-Logger = Callable[["Language"], Tuple[Callable[[Dict[str, Any]], None], Callable]]
+Batcher = Union[Callable[[Iterable[ItemT]], Iterable[List[ItemT]]], Promise]
+Reader = Union[Callable[["Language", str], Iterable["Example"]], Promise]
+Logger = Union[Callable[["Language"], Tuple[Callable[[Dict[str, Any]], None], Callable]], Promise]
+# fmt: on


 def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]:
@@ -292,6 +295,16 @@ class ConfigSchema(BaseModel):
         arbitrary_types_allowed = True


+class TrainingSchema(BaseModel):
+    training: ConfigSchemaTraining
+    pretraining: Union[ConfigSchemaPretrain, ConfigSchemaPretrainEmpty] = {}
+    corpora: Dict[str, Reader]
+
+    class Config:
+        extra = "allow"
+        arbitrary_types_allowed = True
+
+
 # Project config Schema
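Because [corpora] and the logger/batcher blocks may now remain unresolved inside nlp.config, the type aliases accept either the final callable or the unresolved promise, i.e. the raw dict with an @-key. A sketch of the two shapes a Batcher value can take (the registered name and its size argument are assumptions about the batchers registry):

    # Resolved form: a callable that groups items into batches.
    def toy_batcher(items):
        yield list(items)

    # Promise form: raw config data that registry.resolve later turns
    # into the callable above.
    batcher_promise = {"@batchers": "spacy.batch_by_words.v1", "size": 1000}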

View File

@@ -466,3 +466,4 @@ cdef enum symbol_t:
     ENT_ID
     IDX
+    _

View File

@@ -465,6 +465,7 @@ IDS = {
     "acl": acl,
     "LAW": LAW,
     "MORPH": MORPH,
+    "_": _,
 }

View File

@@ -24,7 +24,7 @@ def test_doc_add_entities_set_ents_iob(en_vocab):
         "update_with_oracle_cut_size": 100,
     }
     cfg = {"model": DEFAULT_NER_MODEL}
-    model = registry.make_from_config(cfg, validate=True)["model"]
+    model = registry.resolve(cfg, validate=True)["model"]
     ner = EntityRecognizer(en_vocab, model, **config)
     ner.begin_training(lambda: [_ner_example(ner)])
     ner(doc)

@@ -46,7 +46,7 @@ def test_ents_reset(en_vocab):
         "update_with_oracle_cut_size": 100,
     }
     cfg = {"model": DEFAULT_NER_MODEL}
-    model = registry.make_from_config(cfg, validate=True)["model"]
+    model = registry.resolve(cfg, validate=True)["model"]
     ner = EntityRecognizer(en_vocab, model, **config)
     ner.begin_training(lambda: [_ner_example(ner)])
     ner(doc)

View File

@@ -23,7 +23,7 @@ def parser(vocab):
         "update_with_oracle_cut_size": 100,
     }
     cfg = {"model": DEFAULT_PARSER_MODEL}
-    model = registry.make_from_config(cfg, validate=True)["model"]
+    model = registry.resolve(cfg, validate=True)["model"]
     parser = DependencyParser(vocab, model, **config)
     return parser

@@ -82,7 +82,7 @@ def test_add_label_deserializes_correctly():
         "update_with_oracle_cut_size": 100,
     }
     cfg = {"model": DEFAULT_NER_MODEL}
-    model = registry.make_from_config(cfg, validate=True)["model"]
+    model = registry.resolve(cfg, validate=True)["model"]
     ner1 = EntityRecognizer(Vocab(), model, **config)
     ner1.add_label("C")
     ner1.add_label("B")

@@ -111,7 +111,7 @@ def test_add_label_get_label(pipe_cls, n_moves, model_config):
     splitting the move names.
     """
     labels = ["A", "B", "C"]
-    model = registry.make_from_config({"model": model_config}, validate=True)["model"]
+    model = registry.resolve({"model": model_config}, validate=True)["model"]
     config = {
         "learn_tokens": False,
         "min_action_freq": 30,

View File

@@ -127,7 +127,7 @@ def test_get_oracle_actions():
         "update_with_oracle_cut_size": 100,
     }
     cfg = {"model": DEFAULT_PARSER_MODEL}
-    model = registry.make_from_config(cfg, validate=True)["model"]
+    model = registry.resolve(cfg, validate=True)["model"]
     parser = DependencyParser(doc.vocab, model, **config)
     parser.moves.add_action(0, "")
     parser.moves.add_action(1, "")

View File

@@ -25,7 +25,7 @@ def arc_eager(vocab):
 @pytest.fixture
 def tok2vec():
     cfg = {"model": DEFAULT_TOK2VEC_MODEL}
-    tok2vec = registry.make_from_config(cfg, validate=True)["model"]
+    tok2vec = registry.resolve(cfg, validate=True)["model"]
     tok2vec.initialize()
     return tok2vec

@@ -38,14 +38,14 @@ def parser(vocab, arc_eager):
         "update_with_oracle_cut_size": 100,
     }
     cfg = {"model": DEFAULT_PARSER_MODEL}
-    model = registry.make_from_config(cfg, validate=True)["model"]
+    model = registry.resolve(cfg, validate=True)["model"]
     return Parser(vocab, model, moves=arc_eager, **config)


 @pytest.fixture
 def model(arc_eager, tok2vec, vocab):
     cfg = {"model": DEFAULT_PARSER_MODEL}
-    model = registry.make_from_config(cfg, validate=True)["model"]
+    model = registry.resolve(cfg, validate=True)["model"]
     model.attrs["resize_output"](model, arc_eager.n_moves)
     model.initialize()
     return model

@@ -72,7 +72,7 @@ def test_build_model(parser, vocab):
         "update_with_oracle_cut_size": 100,
     }
     cfg = {"model": DEFAULT_PARSER_MODEL}
-    model = registry.make_from_config(cfg, validate=True)["model"]
+    model = registry.resolve(cfg, validate=True)["model"]
     parser.model = Parser(vocab, model=model, moves=parser.moves, **config).model
     assert parser.model is not None

View File

@@ -28,7 +28,7 @@ def parser(vocab):
         "update_with_oracle_cut_size": 100,
     }
     cfg = {"model": DEFAULT_PARSER_MODEL}
-    model = registry.make_from_config(cfg, validate=True)["model"]
+    model = registry.resolve(cfg, validate=True)["model"]
     parser = DependencyParser(vocab, model, **config)
     parser.cfg["token_vector_width"] = 4
     parser.cfg["hidden_width"] = 32

View File

@@ -139,7 +139,7 @@ TRAIN_DATA = [
 def test_tok2vec_listener():
     orig_config = Config().from_str(cfg_string)
-    nlp, config = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
+    nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
     assert nlp.pipe_names == ["tok2vec", "tagger"]
     tagger = nlp.get_pipe("tagger")
     tok2vec = nlp.get_pipe("tok2vec")

@@ -173,7 +173,7 @@ def test_tok2vec_listener():
 def test_tok2vec_listener_callback():
     orig_config = Config().from_str(cfg_string)
-    nlp, config = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
+    nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
     assert nlp.pipe_names == ["tok2vec", "tagger"]
     tagger = nlp.get_pipe("tagger")
     tok2vec = nlp.get_pipe("tok2vec")

View File

@@ -195,7 +195,7 @@ def test_issue3345():
         "update_with_oracle_cut_size": 100,
     }
     cfg = {"model": DEFAULT_NER_MODEL}
-    model = registry.make_from_config(cfg, validate=True)["model"]
+    model = registry.resolve(cfg, validate=True)["model"]
     ner = EntityRecognizer(doc.vocab, model, **config)
     # Add the OUT action. I wouldn't have thought this would be necessary...
     ner.moves.add_action(5, "")

View File

@@ -264,9 +264,7 @@ def test_issue3830_no_subtok():
         "min_action_freq": 30,
         "update_with_oracle_cut_size": 100,
     }
-    model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)[
-        "model"
-    ]
+    model = registry.resolve({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
     parser = DependencyParser(Vocab(), model, **config)
     parser.add_label("nsubj")
     assert "subtok" not in parser.labels

@@ -281,9 +279,7 @@ def test_issue3830_with_subtok():
         "min_action_freq": 30,
         "update_with_oracle_cut_size": 100,
     }
-    model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)[
-        "model"
-    ]
+    model = registry.resolve({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
     parser = DependencyParser(Vocab(), model, **config)
     parser.add_label("nsubj")
     assert "subtok" not in parser.labels

View File

@@ -108,8 +108,8 @@ def my_parser():
 def test_create_nlp_from_config():
     config = Config().from_str(nlp_config_string)
     with pytest.raises(ConfigValidationError):
-        nlp, _ = load_model_from_config(config, auto_fill=False)
+        load_model_from_config(config, auto_fill=False)
-    nlp, resolved = load_model_from_config(config, auto_fill=True)
+    nlp = load_model_from_config(config, auto_fill=True)
     assert nlp.config["training"]["batcher"]["size"] == 666
     assert len(nlp.config["training"]) > 1
     assert nlp.pipe_names == ["tok2vec", "tagger"]

@@ -136,7 +136,7 @@ def test_create_nlp_from_config_multiple_instances():
         "tagger2": config["components"]["tagger"],
     }
     config["nlp"]["pipeline"] = list(config["components"].keys())
-    nlp, _ = load_model_from_config(config, auto_fill=True)
+    nlp = load_model_from_config(config, auto_fill=True)
     assert nlp.pipe_names == ["t2v", "tagger1", "tagger2"]
     assert nlp.get_pipe_meta("t2v").factory == "tok2vec"
     assert nlp.get_pipe_meta("tagger1").factory == "tagger"

@@ -150,7 +150,7 @@ def test_create_nlp_from_config_multiple_instances():
 def test_serialize_nlp():
     """ Create a custom nlp pipeline from config and ensure it serializes it correctly """
     nlp_config = Config().from_str(nlp_config_string)
-    nlp, _ = load_model_from_config(nlp_config, auto_fill=True)
+    nlp = load_model_from_config(nlp_config, auto_fill=True)
     nlp.get_pipe("tagger").add_label("A")
     nlp.begin_training()
     assert "tok2vec" in nlp.pipe_names

@@ -209,7 +209,7 @@ def test_config_nlp_roundtrip():
     nlp = English()
     nlp.add_pipe("entity_ruler")
     nlp.add_pipe("ner")
-    new_nlp, new_config = load_model_from_config(nlp.config, auto_fill=False)
+    new_nlp = load_model_from_config(nlp.config, auto_fill=False)
     assert new_nlp.config == nlp.config
     assert new_nlp.pipe_names == nlp.pipe_names
     assert new_nlp._pipe_configs == nlp._pipe_configs

@@ -280,12 +280,12 @@ def test_config_overrides():
     overrides_dot = {"nlp.lang": "de", "nlp.pipeline": ["tagger"]}
     # load_model from config with overrides passed directly to Config
     config = Config().from_str(nlp_config_string, overrides=overrides_dot)
-    nlp, _ = load_model_from_config(config, auto_fill=True)
+    nlp = load_model_from_config(config, auto_fill=True)
     assert isinstance(nlp, German)
     assert nlp.pipe_names == ["tagger"]
     # Serialized roundtrip with config passed in
     base_config = Config().from_str(nlp_config_string)
-    base_nlp, _ = load_model_from_config(base_config, auto_fill=True)
+    base_nlp = load_model_from_config(base_config, auto_fill=True)
     assert isinstance(base_nlp, English)
     assert base_nlp.pipe_names == ["tok2vec", "tagger"]
     with make_tempdir() as d:

@@ -328,7 +328,7 @@ def test_config_optional_sections():
     config = Config().from_str(nlp_config_string)
     config = DEFAULT_CONFIG.merge(config)
     assert "pretraining" not in config
-    filled = registry.fill_config(config, schema=ConfigSchema, validate=False)
+    filled = registry.fill(config, schema=ConfigSchema, validate=False)
     # Make sure that optional "pretraining" block doesn't default to None,
     # which would (rightly) cause error because it'd result in a top-level
     # key that's not a section (dict). Note that the following roundtrip is

@@ -341,7 +341,7 @@ def test_config_auto_fill_extra_fields():
     config = Config({"nlp": {"lang": "en"}, "training": {}})
     assert load_model_from_config(config, auto_fill=True)
     config = Config({"nlp": {"lang": "en"}, "training": {"extra": "hello"}})
-    nlp, _ = load_model_from_config(config, auto_fill=True, validate=False)
+    nlp = load_model_from_config(config, auto_fill=True, validate=False)
     assert "extra" not in nlp.config["training"]
     # Make sure the config generated is valid
     load_model_from_config(nlp.config)

View File

@@ -23,7 +23,7 @@ def parser(en_vocab):
         "update_with_oracle_cut_size": 100,
     }
     cfg = {"model": DEFAULT_PARSER_MODEL}
-    model = registry.make_from_config(cfg, validate=True)["model"]
+    model = registry.resolve(cfg, validate=True)["model"]
     parser = DependencyParser(en_vocab, model, **config)
     parser.add_label("nsubj")
     return parser

@@ -37,7 +37,7 @@ def blank_parser(en_vocab):
         "update_with_oracle_cut_size": 100,
     }
     cfg = {"model": DEFAULT_PARSER_MODEL}
-    model = registry.make_from_config(cfg, validate=True)["model"]
+    model = registry.resolve(cfg, validate=True)["model"]
     parser = DependencyParser(en_vocab, model, **config)
     return parser

@@ -45,7 +45,7 @@ def blank_parser(en_vocab):
 @pytest.fixture
 def taggers(en_vocab):
     cfg = {"model": DEFAULT_TAGGER_MODEL}
-    model = registry.make_from_config(cfg, validate=True)["model"]
+    model = registry.resolve(cfg, validate=True)["model"]
     tagger1 = Tagger(en_vocab, model)
     tagger2 = Tagger(en_vocab, model)
     return tagger1, tagger2

@@ -59,7 +59,7 @@ def test_serialize_parser_roundtrip_bytes(en_vocab, Parser):
         "update_with_oracle_cut_size": 100,
     }
     cfg = {"model": DEFAULT_PARSER_MODEL}
-    model = registry.make_from_config(cfg, validate=True)["model"]
+    model = registry.resolve(cfg, validate=True)["model"]
     parser = Parser(en_vocab, model, **config)
     new_parser = Parser(en_vocab, model, **config)
     new_parser = new_parser.from_bytes(parser.to_bytes(exclude=["vocab"]))

@@ -77,7 +77,7 @@ def test_serialize_parser_roundtrip_disk(en_vocab, Parser):
         "update_with_oracle_cut_size": 100,
     }
     cfg = {"model": DEFAULT_PARSER_MODEL}
-    model = registry.make_from_config(cfg, validate=True)["model"]
+    model = registry.resolve(cfg, validate=True)["model"]
     parser = Parser(en_vocab, model, **config)
     with make_tempdir() as d:
         file_path = d / "parser"

@@ -111,7 +111,7 @@ def test_serialize_tagger_roundtrip_bytes(en_vocab, taggers):
     tagger1 = tagger1.from_bytes(tagger1_b)
     assert tagger1.to_bytes() == tagger1_b
     cfg = {"model": DEFAULT_TAGGER_MODEL}
-    model = registry.make_from_config(cfg, validate=True)["model"]
+    model = registry.resolve(cfg, validate=True)["model"]
     new_tagger1 = Tagger(en_vocab, model).from_bytes(tagger1_b)
     new_tagger1_b = new_tagger1.to_bytes()
     assert len(new_tagger1_b) == len(tagger1_b)

@@ -126,7 +126,7 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
     tagger1.to_disk(file_path1)
     tagger2.to_disk(file_path2)
     cfg = {"model": DEFAULT_TAGGER_MODEL}
-    model = registry.make_from_config(cfg, validate=True)["model"]
+    model = registry.resolve(cfg, validate=True)["model"]
     tagger1_d = Tagger(en_vocab, model).from_disk(file_path1)
     tagger2_d = Tagger(en_vocab, model).from_disk(file_path2)
     assert tagger1_d.to_bytes() == tagger2_d.to_bytes()

@@ -135,7 +135,7 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
 def test_serialize_textcat_empty(en_vocab):
     # See issue #1105
     cfg = {"model": DEFAULT_TEXTCAT_MODEL}
-    model = registry.make_from_config(cfg, validate=True)["model"]
+    model = registry.resolve(cfg, validate=True)["model"]
     textcat = TextCategorizer(
         en_vocab,
         model,

@@ -149,7 +149,7 @@ def test_serialize_textcat_empty(en_vocab):
 @pytest.mark.parametrize("Parser", test_parsers)
 def test_serialize_pipe_exclude(en_vocab, Parser):
     cfg = {"model": DEFAULT_PARSER_MODEL}
-    model = registry.make_from_config(cfg, validate=True)["model"]
+    model = registry.resolve(cfg, validate=True)["model"]
     config = {
         "learn_tokens": False,
         "min_action_freq": 0,

@@ -176,7 +176,7 @@ def test_serialize_pipe_exclude(en_vocab, Parser):
 def test_serialize_sentencerecognizer(en_vocab):
     cfg = {"model": DEFAULT_SENTER_MODEL}
-    model = registry.make_from_config(cfg, validate=True)["model"]
+    model = registry.resolve(cfg, validate=True)["model"]
     sr = SentenceRecognizer(en_vocab, model)
     sr_b = sr.to_bytes()
     sr_d = SentenceRecognizer(en_vocab, model).from_bytes(sr_b)

View File

@@ -7,6 +7,7 @@ from spacy import util
 from spacy import prefer_gpu, require_gpu
 from spacy.ml._precomputable_affine import PrecomputableAffine
 from spacy.ml._precomputable_affine import _backprop_precomputable_affine_padding
+from thinc.api import Optimizer


 @pytest.fixture

@@ -157,3 +158,16 @@ def test_dot_to_dict(dot_notation, expected):
     result = util.dot_to_dict(dot_notation)
     assert result == expected
     assert util.dict_to_dot(result) == dot_notation
+
+
+def test_resolve_training_config():
+    config = {
+        "nlp": {"lang": "en", "disabled": []},
+        "training": {"dropout": 0.1, "optimizer": {"@optimizers": "Adam.v1"}},
+        "corpora": {},
+    }
+    resolved = util.resolve_training_config(config)
+    assert resolved["training"]["dropout"] == 0.1
+    assert isinstance(resolved["training"]["optimizer"], Optimizer)
+    assert resolved["corpora"] == {}
+    assert "nlp" not in resolved

View File

@@ -82,10 +82,10 @@ def test_util_dot_section():
     no_output_layer = false
     """
     nlp_config = Config().from_str(cfg_string)
-    en_nlp, en_config = util.load_model_from_config(nlp_config, auto_fill=True)
+    en_nlp = util.load_model_from_config(nlp_config, auto_fill=True)
     default_config = Config().from_disk(DEFAULT_CONFIG_PATH)
     default_config["nlp"]["lang"] = "nl"
-    nl_nlp, nl_config = util.load_model_from_config(default_config, auto_fill=True)
+    nl_nlp = util.load_model_from_config(default_config, auto_fill=True)
     # Test that creation went OK
     assert isinstance(en_nlp, English)
     assert isinstance(nl_nlp, Dutch)

@@ -94,14 +94,15 @@ def test_util_dot_section():
     # not exclusive_classes
     assert en_nlp.get_pipe("textcat").model.attrs["multi_label"] is False
     # Test that default values got overwritten
-    assert en_config["nlp"]["pipeline"] == ["textcat"]
-    assert nl_config["nlp"]["pipeline"] == []  # default value []
+    assert en_nlp.config["nlp"]["pipeline"] == ["textcat"]
+    assert nl_nlp.config["nlp"]["pipeline"] == []  # default value []
     # Test proper functioning of 'dot_to_object'
     with pytest.raises(KeyError):
-        dot_to_object(en_config, "nlp.pipeline.tagger")
+        dot_to_object(en_nlp.config, "nlp.pipeline.tagger")
     with pytest.raises(KeyError):
-        dot_to_object(en_config, "nlp.unknownattribute")
+        dot_to_object(en_nlp.config, "nlp.unknownattribute")
-    assert isinstance(dot_to_object(nl_config, "training.optimizer"), Optimizer)
+    resolved = util.resolve_training_config(nl_nlp.config)
+    assert isinstance(dot_to_object(resolved, "training.optimizer"), Optimizer)


 def test_simple_frozen_list():

View File

@@ -3,6 +3,7 @@ import pytest
 from thinc.api import Config
 from spacy import Language
 from spacy.util import load_model_from_config, registry, dot_to_object
+from spacy.util import resolve_training_config
 from spacy.training import Example

@@ -37,8 +38,8 @@ def test_readers():
         return {"train": reader, "dev": reader, "extra": reader, "something": reader}

     config = Config().from_str(config_string)
-    nlp, resolved = load_model_from_config(config, auto_fill=True)
+    nlp = load_model_from_config(config, auto_fill=True)
+    resolved = resolve_training_config(nlp.config)
     train_corpus = dot_to_object(resolved, resolved["training"]["train_corpus"])
     assert isinstance(train_corpus, Callable)
     optimizer = resolved["training"]["optimizer"]

@@ -87,8 +88,8 @@ def test_cat_readers(reader, additional_config):
     config = Config().from_str(nlp_config_string)
     config["corpora"]["@readers"] = reader
     config["corpora"].update(additional_config)
-    nlp, resolved = load_model_from_config(config, auto_fill=True)
+    nlp = load_model_from_config(config, auto_fill=True)
+    resolved = resolve_training_config(nlp.config)
     train_corpus = dot_to_object(resolved, resolved["training"]["train_corpus"])
     optimizer = resolved["training"]["optimizer"]
     # simulate a training loop
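These tests register corpora readers inline; user code follows the same pattern. A hedged sketch of registering a custom reader (the registry name is hypothetical):

    from typing import Callable, Iterable
    from spacy.language import Language
    from spacy.training import Example
    from spacy.util import registry

    @registry.readers("my_toy_reader.v1")  # hypothetical name
    def create_toy_reader(limit: int = 0) -> Callable:
        # The reader returns the corpus callable: nlp -> iterable of Examples.
        def read(nlp: Language) -> Iterable[Example]:
            return []
        return read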

View File

@@ -86,7 +86,7 @@ class registry(thinc.registry):
     # spacy_factories entry point. This registry only exists so we can easily
     # load them via the entry points. The "true" factories are added via the
     # Language.factory decorator (in the spaCy code base and user code) and those
-    # are the factories used to initialize components via registry.make_from_config.
+    # are the factories used to initialize components via registry.resolve.
     _entry_point_factories = catalogue.create("spacy", "factories", entry_points=True)
     factories = catalogue.create("spacy", "internal_factories")
     # This is mostly used to get a list of all installed models in the current
@@ -351,9 +351,7 @@ def load_model_from_path(
     meta = get_model_meta(model_path)
     config_path = model_path / "config.cfg"
     config = load_config(config_path, overrides=dict_to_dot(config))
-    nlp, _ = load_model_from_config(
-        config, vocab=vocab, disable=disable, exclude=exclude
-    )
+    nlp = load_model_from_config(config, vocab=vocab, disable=disable, exclude=exclude)
     return nlp.from_disk(model_path, exclude=exclude)
@@ -365,7 +363,7 @@ def load_model_from_config(
     exclude: Iterable[str] = SimpleFrozenList(),
     auto_fill: bool = False,
     validate: bool = True,
-) -> Tuple["Language", Config]:
+) -> "Language":
     """Create an nlp object from a config. Expects the full config file including
     a section "nlp" containing the settings for the nlp object.
@@ -398,7 +396,30 @@ def load_model_from_config(
         auto_fill=auto_fill,
         validate=validate,
     )
-    return nlp, nlp.resolved
+    return nlp
+
+
+def resolve_training_config(
+    config: Config,
+    exclude: Iterable[str] = ("nlp", "components"),
+    validate: bool = True,
+) -> Dict[str, Any]:
+    """Resolve the config sections relevant for training and create all objects.
+    Mostly used in the CLI to separate the training config (not resolved by
+    default because it's not runtime-relevant: an nlp object should load fine
+    even if its [training] block refers to functions that are not available).
+
+    config (Config): The config to resolve.
+    exclude (Iterable[str]): The config blocks to exclude. Those blocks won't
+        be available in the final resolved config.
+    validate (bool): Whether to validate the config.
+    RETURNS (Dict[str, Any]): The resolved config.
+    """
+    config = config.copy()
+    for key in exclude:
+        if key in config:
+            config.pop(key)
+    return registry.resolve(config, validate=validate)


 def load_model_from_init_py(
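The new helper mirrors what the CLI commands above do. A short usage sketch of the default exclusions; as the new test shows, a plain dict works as well as a Config:

    from spacy import util

    config = {
        "nlp": {"lang": "en"},      # dropped by default
        "components": {},           # dropped by default
        "training": {"optimizer": {"@optimizers": "Adam.v1"}},
    }
    resolved = util.resolve_training_config(config)
    assert "nlp" not in resolved and "components" not in resolved
    optimizer = resolved["training"]["optimizer"]  # a constructed Optimizer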