Merge pull request #6156 from explosion/feature/new-thinc-config-resolution

This commit is contained in:
Ines Montani 2020-09-27 23:57:52 +02:00 committed by GitHub
commit cad4dbddaa
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
25 changed files with 178 additions and 123 deletions

View File

@ -6,7 +6,7 @@ requires = [
"cymem>=2.0.2,<2.1.0", "cymem>=2.0.2,<2.1.0",
"preshed>=3.0.2,<3.1.0", "preshed>=3.0.2,<3.1.0",
"murmurhash>=0.28.0,<1.1.0", "murmurhash>=0.28.0,<1.1.0",
"thinc>=8.0.0a36,<8.0.0a40", "thinc>=8.0.0a41,<8.0.0a50",
"blis>=0.4.0,<0.5.0", "blis>=0.4.0,<0.5.0",
"pytokenizations", "pytokenizations",
"pathy" "pathy"

View File

@ -1,7 +1,7 @@
# Our libraries # Our libraries
cymem>=2.0.2,<2.1.0 cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0 preshed>=3.0.2,<3.1.0
thinc>=8.0.0a36,<8.0.0a40 thinc>=8.0.0a41,<8.0.0a50
blis>=0.4.0,<0.5.0 blis>=0.4.0,<0.5.0
ml_datasets==0.2.0a0 ml_datasets==0.2.0a0
murmurhash>=0.28.0,<1.1.0 murmurhash>=0.28.0,<1.1.0

View File

@ -34,13 +34,13 @@ setup_requires =
cymem>=2.0.2,<2.1.0 cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0 preshed>=3.0.2,<3.1.0
murmurhash>=0.28.0,<1.1.0 murmurhash>=0.28.0,<1.1.0
thinc>=8.0.0a36,<8.0.0a40 thinc>=8.0.0a41,<8.0.0a50
install_requires = install_requires =
# Our libraries # Our libraries
murmurhash>=0.28.0,<1.1.0 murmurhash>=0.28.0,<1.1.0
cymem>=2.0.2,<2.1.0 cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0 preshed>=3.0.2,<3.1.0
thinc>=8.0.0a36,<8.0.0a40 thinc>=8.0.0a41,<8.0.0a50
blis>=0.4.0,<0.5.0 blis>=0.4.0,<0.5.0
wasabi>=0.8.0,<1.1.0 wasabi>=0.8.0,<1.1.0
srsly>=2.1.0,<3.0.0 srsly>=2.1.0,<3.0.0

View File

@ -51,9 +51,10 @@ def debug_config(
msg.divider("Config validation") msg.divider("Config validation")
with show_validation_error(config_path): with show_validation_error(config_path):
config = util.load_config(config_path, overrides=overrides) config = util.load_config(config_path, overrides=overrides)
nlp, resolved = util.load_model_from_config(config) nlp = util.load_model_from_config(config)
# Use the resolved config here in case user has one function returning # Use the resolved config here in case user has one function returning
# a dict of corpora etc. # a dict of corpora etc.
resolved = util.resolve_training_config(nlp.config)
check_section_refs(resolved, ["training.dev_corpus", "training.train_corpus"]) check_section_refs(resolved, ["training.dev_corpus", "training.train_corpus"])
msg.good("Config is valid") msg.good("Config is valid")
if show_vars: if show_vars:

View File

@ -93,18 +93,19 @@ def debug_data(
msg.fail("Config file not found", config_path, exits=1) msg.fail("Config file not found", config_path, exits=1)
with show_validation_error(config_path): with show_validation_error(config_path):
cfg = util.load_config(config_path, overrides=config_overrides) cfg = util.load_config(config_path, overrides=config_overrides)
nlp, config = util.load_model_from_config(cfg) nlp = util.load_model_from_config(cfg)
C = util.resolve_training_config(nlp.config)
# Use original config here, not resolved version # Use original config here, not resolved version
sourced_components = get_sourced_components(cfg) sourced_components = get_sourced_components(cfg)
frozen_components = config["training"]["frozen_components"] frozen_components = C["training"]["frozen_components"]
resume_components = [p for p in sourced_components if p not in frozen_components] resume_components = [p for p in sourced_components if p not in frozen_components]
pipeline = nlp.pipe_names pipeline = nlp.pipe_names
factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names] factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names]
tag_map_path = util.ensure_path(config["training"]["tag_map"]) tag_map_path = util.ensure_path(C["training"]["tag_map"])
tag_map = {} tag_map = {}
if tag_map_path is not None: if tag_map_path is not None:
tag_map = srsly.read_json(tag_map_path) tag_map = srsly.read_json(tag_map_path)
morph_rules_path = util.ensure_path(config["training"]["morph_rules"]) morph_rules_path = util.ensure_path(C["training"]["morph_rules"])
morph_rules = {} morph_rules = {}
if morph_rules_path is not None: if morph_rules_path is not None:
morph_rules = srsly.read_json(morph_rules_path) morph_rules = srsly.read_json(morph_rules_path)
@ -144,10 +145,10 @@ def debug_data(
train_texts = gold_train_data["texts"] train_texts = gold_train_data["texts"]
dev_texts = gold_dev_data["texts"] dev_texts = gold_dev_data["texts"]
frozen_components = config["training"]["frozen_components"] frozen_components = C["training"]["frozen_components"]
msg.divider("Training stats") msg.divider("Training stats")
msg.text(f"Language: {config['nlp']['lang']}") msg.text(f"Language: {C['nlp']['lang']}")
msg.text(f"Training pipeline: {', '.join(pipeline)}") msg.text(f"Training pipeline: {', '.join(pipeline)}")
if resume_components: if resume_components:
msg.text(f"Components from other pipelines: {', '.join(resume_components)}") msg.text(f"Components from other pipelines: {', '.join(resume_components)}")

View File

@ -1,4 +1,3 @@
import warnings
from typing import Dict, Any, Optional, Iterable from typing import Dict, Any, Optional, Iterable
from pathlib import Path from pathlib import Path
@ -57,14 +56,17 @@ def debug_model_cli(
} }
config_overrides = parse_config_overrides(ctx.args) config_overrides = parse_config_overrides(ctx.args)
with show_validation_error(config_path): with show_validation_error(config_path):
config = util.load_config( raw_config = util.load_config(
config_path, overrides=config_overrides, interpolate=True config_path, overrides=config_overrides, interpolate=False
) )
config = raw_config.interpolate()
allocator = config["training"]["gpu_allocator"] allocator = config["training"]["gpu_allocator"]
if use_gpu >= 0 and allocator: if use_gpu >= 0 and allocator:
set_gpu_allocator(allocator) set_gpu_allocator(allocator)
nlp, config = util.load_model_from_config(config) with show_validation_error(config_path):
seed = config["training"]["seed"] nlp = util.load_model_from_config(raw_config)
C = util.resolve_training_config(nlp.config)
seed = C["training"]["seed"]
if seed is not None: if seed is not None:
msg.info(f"Fixing random seed: {seed}") msg.info(f"Fixing random seed: {seed}")
fix_random_seed(seed) fix_random_seed(seed)
@ -75,7 +77,7 @@ def debug_model_cli(
exits=1, exits=1,
) )
model = pipe.model model = pipe.model
debug_model(config, nlp, model, print_settings=print_settings) debug_model(C, nlp, model, print_settings=print_settings)
def debug_model( def debug_model(
@ -108,7 +110,7 @@ def debug_model(
_set_output_dim(nO=7, model=model) _set_output_dim(nO=7, model=model)
nlp.begin_training(lambda: [Example.from_dict(x, {}) for x in X]) nlp.begin_training(lambda: [Example.from_dict(x, {}) for x in X])
msg.info("Initialized the model with dummy data.") msg.info("Initialized the model with dummy data.")
except: except Exception:
msg.fail( msg.fail(
"Could not initialize the model: you'll have to provide a valid train_corpus argument in the config file.", "Could not initialize the model: you'll have to provide a valid train_corpus argument in the config file.",
exits=1, exits=1,

View File

@ -88,10 +88,10 @@ def fill_config(
msg = Printer(no_print=no_print) msg = Printer(no_print=no_print)
with show_validation_error(hint_fill=False): with show_validation_error(hint_fill=False):
config = util.load_config(base_path) config = util.load_config(base_path)
nlp, _ = util.load_model_from_config(config, auto_fill=True, validate=False) nlp = util.load_model_from_config(config, auto_fill=True, validate=False)
# Load a second time with validation to be extra sure that the produced # Load a second time with validation to be extra sure that the produced
# config result is a valid config # config result is a valid config
nlp, _ = util.load_model_from_config(nlp.config) nlp = util.load_model_from_config(nlp.config)
filled = nlp.config filled = nlp.config
if pretraining: if pretraining:
validate_config_for_pretrain(filled, msg) validate_config_for_pretrain(filled, msg)
@ -169,7 +169,7 @@ def init_config(
msg.text(f"- {label}: {value}") msg.text(f"- {label}: {value}")
with show_validation_error(hint_fill=False): with show_validation_error(hint_fill=False):
config = util.load_config_from_str(base_template) config = util.load_config_from_str(base_template)
nlp, _ = util.load_model_from_config(config, auto_fill=True) nlp = util.load_model_from_config(config, auto_fill=True)
config = nlp.config config = nlp.config
if pretraining: if pretraining:
validate_config_for_pretrain(config, msg) validate_config_for_pretrain(config, msg)

View File

@ -69,17 +69,18 @@ def pretrain_cli(
msg.info(f"Loading config from: {config_path}") msg.info(f"Loading config from: {config_path}")
with show_validation_error(config_path): with show_validation_error(config_path):
config = util.load_config( raw_config = util.load_config(
config_path, overrides=config_overrides, interpolate=True config_path, overrides=config_overrides, interpolate=False
) )
config = raw_config.interpolate()
if not config.get("pretraining"): if not config.get("pretraining"):
# TODO: What's the solution here? How do we handle optional blocks? # TODO: What's the solution here? How do we handle optional blocks?
msg.fail("The [pretraining] block in your config is empty", exits=1) msg.fail("The [pretraining] block in your config is empty", exits=1)
if not output_dir.exists(): if not output_dir.exists():
output_dir.mkdir() output_dir.mkdir()
msg.good(f"Created output directory: {output_dir}") msg.good(f"Created output directory: {output_dir}")
# Save non-interpolated config
config.to_disk(output_dir / "config.cfg") raw_config.to_disk(output_dir / "config.cfg")
msg.good("Saved config file in the output directory") msg.good("Saved config file in the output directory")
pretrain( pretrain(
@ -103,14 +104,13 @@ def pretrain(
allocator = config["training"]["gpu_allocator"] allocator = config["training"]["gpu_allocator"]
if use_gpu >= 0 and allocator: if use_gpu >= 0 and allocator:
set_gpu_allocator(allocator) set_gpu_allocator(allocator)
nlp = util.load_model_from_config(config)
nlp, config = util.load_model_from_config(config) C = util.resolve_training_config(nlp.config)
P_cfg = config["pretraining"] P_cfg = C["pretraining"]
corpus = dot_to_object(config, P_cfg["corpus"]) corpus = dot_to_object(C, P_cfg["corpus"])
batcher = P_cfg["batcher"] batcher = P_cfg["batcher"]
model = create_pretraining_model(nlp, config["pretraining"]) model = create_pretraining_model(nlp, C["pretraining"])
optimizer = config["pretraining"]["optimizer"] optimizer = C["pretraining"]["optimizer"]
# Load in pretrained weights to resume from # Load in pretrained weights to resume from
if resume_path is not None: if resume_path is not None:
_resume_model(model, resume_path, epoch_resume) _resume_model(model, resume_path, epoch_resume)

View File

@ -75,12 +75,12 @@ def train(
msg.info("Using CPU") msg.info("Using CPU")
msg.info(f"Loading config and nlp from: {config_path}") msg.info(f"Loading config and nlp from: {config_path}")
with show_validation_error(config_path): with show_validation_error(config_path):
config = util.load_config( # Keep an un-interpolated config so we can preserve variables in
config_path, overrides=config_overrides, interpolate=True
)
# Keep a second un-interpolated config so we can preserve variables in
# the final nlp object we train and serialize # the final nlp object we train and serialize
raw_config = util.load_config(config_path, overrides=config_overrides) raw_config = util.load_config(
config_path, overrides=config_overrides, interpolate=False
)
config = raw_config.interpolate()
if config["training"]["seed"] is not None: if config["training"]["seed"] is not None:
fix_random_seed(config["training"]["seed"]) fix_random_seed(config["training"]["seed"])
allocator = config["training"]["gpu_allocator"] allocator = config["training"]["gpu_allocator"]
@ -89,15 +89,17 @@ def train(
# Use original config here before it's resolved to functions # Use original config here before it's resolved to functions
sourced_components = get_sourced_components(config) sourced_components = get_sourced_components(config)
with show_validation_error(config_path): with show_validation_error(config_path):
nlp, config = util.load_model_from_config(raw_config) nlp = util.load_model_from_config(raw_config)
util.load_vocab_data_into_model(nlp, lookups=config["training"]["lookups"]) # Resolve all training-relevant sections using the filled nlp config
if config["training"]["vectors"] is not None: C = util.resolve_training_config(nlp.config)
add_vectors(nlp, config["training"]["vectors"]) util.load_vocab_data_into_model(nlp, lookups=C["training"]["lookups"])
raw_text, tag_map, morph_rules, weights_data = load_from_paths(config) if C["training"]["vectors"] is not None:
T_cfg = config["training"] add_vectors(nlp, C["training"]["vectors"])
raw_text, tag_map, morph_rules, weights_data = load_from_paths(C)
T_cfg = C["training"]
optimizer = T_cfg["optimizer"] optimizer = T_cfg["optimizer"]
train_corpus = dot_to_object(config, T_cfg["train_corpus"]) train_corpus = dot_to_object(C, T_cfg["train_corpus"])
dev_corpus = dot_to_object(config, T_cfg["dev_corpus"]) dev_corpus = dot_to_object(C, T_cfg["dev_corpus"])
batcher = T_cfg["batcher"] batcher = T_cfg["batcher"]
train_logger = T_cfg["logger"] train_logger = T_cfg["logger"]
before_to_disk = create_before_to_disk_callback(T_cfg["before_to_disk"]) before_to_disk = create_before_to_disk_callback(T_cfg["before_to_disk"])
@ -124,7 +126,7 @@ def train(
# Load pretrained tok2vec weights - cf. CLI command 'pretrain' # Load pretrained tok2vec weights - cf. CLI command 'pretrain'
if weights_data is not None: if weights_data is not None:
tok2vec_component = config["pretraining"]["component"] tok2vec_component = C["pretraining"]["component"]
if tok2vec_component is None: if tok2vec_component is None:
msg.fail( msg.fail(
f"To use pretrained tok2vec weights, [pretraining.component] " f"To use pretrained tok2vec weights, [pretraining.component] "
@ -132,7 +134,7 @@ def train(
exits=1, exits=1,
) )
layer = nlp.get_pipe(tok2vec_component).model layer = nlp.get_pipe(tok2vec_component).model
tok2vec_layer = config["pretraining"]["layer"] tok2vec_layer = C["pretraining"]["layer"]
if tok2vec_layer: if tok2vec_layer:
layer = layer.get_ref(tok2vec_layer) layer = layer.get_ref(tok2vec_layer)
layer.from_bytes(weights_data) layer.from_bytes(weights_data)

View File

@ -27,7 +27,7 @@ from .lang.punctuation import TOKENIZER_INFIXES
from .tokens import Doc from .tokens import Doc
from .tokenizer import Tokenizer from .tokenizer import Tokenizer
from .errors import Errors, Warnings from .errors import Errors, Warnings
from .schemas import ConfigSchema from .schemas import ConfigSchema, ConfigSchemaNlp
from .git_info import GIT_VERSION from .git_info import GIT_VERSION
from . import util from . import util
from . import about from . import about
@ -166,11 +166,10 @@ class Language:
self._components = [] self._components = []
self._disabled = set() self._disabled = set()
self.max_length = max_length self.max_length = max_length
self.resolved = {}
# Create the default tokenizer from the default config # Create the default tokenizer from the default config
if not create_tokenizer: if not create_tokenizer:
tokenizer_cfg = {"tokenizer": self._config["nlp"]["tokenizer"]} tokenizer_cfg = {"tokenizer": self._config["nlp"]["tokenizer"]}
create_tokenizer = registry.make_from_config(tokenizer_cfg)["tokenizer"] create_tokenizer = registry.resolve(tokenizer_cfg)["tokenizer"]
self.tokenizer = create_tokenizer(self) self.tokenizer = create_tokenizer(self)
def __init_subclass__(cls, **kwargs): def __init_subclass__(cls, **kwargs):
@ -467,7 +466,7 @@ class Language:
if "nlp" not in arg_names or "name" not in arg_names: if "nlp" not in arg_names or "name" not in arg_names:
raise ValueError(Errors.E964.format(name=name)) raise ValueError(Errors.E964.format(name=name))
# Officially register the factory so we can later call # Officially register the factory so we can later call
# registry.make_from_config and refer to it in the config as # registry.resolve and refer to it in the config as
# @factories = "spacy.Language.xyz". We use the class name here so # @factories = "spacy.Language.xyz". We use the class name here so
# different classes can have different factories. # different classes can have different factories.
registry.factories.register(internal_name, func=factory_func) registry.factories.register(internal_name, func=factory_func)
@ -650,8 +649,9 @@ class Language:
cfg = {factory_name: config} cfg = {factory_name: config}
# We're calling the internal _fill here to avoid constructing the # We're calling the internal _fill here to avoid constructing the
# registered functions twice # registered functions twice
resolved, filled = registry.resolve(cfg, validate=validate) resolved = registry.resolve(cfg, validate=validate)
filled = Config(filled[factory_name]) filled = registry.fill({"cfg": cfg[factory_name]}, validate=validate)["cfg"]
filled = Config(filled)
filled["factory"] = factory_name filled["factory"] = factory_name
filled.pop("@factories", None) filled.pop("@factories", None)
# Remove the extra values we added because we don't want to keep passing # Remove the extra values we added because we don't want to keep passing
@ -1518,15 +1518,19 @@ class Language:
config = util.copy_config(config) config = util.copy_config(config)
orig_pipeline = config.pop("components", {}) orig_pipeline = config.pop("components", {})
config["components"] = {} config["components"] = {}
resolved, filled = registry.resolve( if auto_fill:
config, validate=validate, schema=ConfigSchema filled = registry.fill(config, validate=validate, schema=ConfigSchema)
) else:
filled = config
filled["components"] = orig_pipeline filled["components"] = orig_pipeline
config["components"] = orig_pipeline config["components"] = orig_pipeline
create_tokenizer = resolved["nlp"]["tokenizer"] resolved_nlp = registry.resolve(
before_creation = resolved["nlp"]["before_creation"] filled["nlp"], validate=validate, schema=ConfigSchemaNlp
after_creation = resolved["nlp"]["after_creation"] )
after_pipeline_creation = resolved["nlp"]["after_pipeline_creation"] create_tokenizer = resolved_nlp["tokenizer"]
before_creation = resolved_nlp["before_creation"]
after_creation = resolved_nlp["after_creation"]
after_pipeline_creation = resolved_nlp["after_pipeline_creation"]
lang_cls = cls lang_cls = cls
if before_creation is not None: if before_creation is not None:
lang_cls = before_creation(cls) lang_cls = before_creation(cls)
@ -1587,7 +1591,6 @@ class Language:
disabled_pipes = [*config["nlp"]["disabled"], *disable] disabled_pipes = [*config["nlp"]["disabled"], *disable]
nlp._disabled = set(p for p in disabled_pipes if p not in exclude) nlp._disabled = set(p for p in disabled_pipes if p not in exclude)
nlp.config = filled if auto_fill else config nlp.config = filled if auto_fill else config
nlp.resolved = resolved
if after_pipeline_creation is not None: if after_pipeline_creation is not None:
nlp = after_pipeline_creation(nlp) nlp = after_pipeline_creation(nlp)
if not isinstance(nlp, cls): if not isinstance(nlp, cls):

View File

@ -4,6 +4,7 @@ from enum import Enum
from pydantic import BaseModel, Field, ValidationError, validator from pydantic import BaseModel, Field, ValidationError, validator
from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
from pydantic import root_validator from pydantic import root_validator
from thinc.config import Promise
from collections import defaultdict from collections import defaultdict
from thinc.api import Optimizer from thinc.api import Optimizer
@ -16,10 +17,12 @@ if TYPE_CHECKING:
from .training import Example # noqa: F401 from .training import Example # noqa: F401
# fmt: off
ItemT = TypeVar("ItemT") ItemT = TypeVar("ItemT")
Batcher = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]] Batcher = Union[Callable[[Iterable[ItemT]], Iterable[List[ItemT]]], Promise]
Reader = Callable[["Language", str], Iterable["Example"]] Reader = Union[Callable[["Language", str], Iterable["Example"]], Promise]
Logger = Callable[["Language"], Tuple[Callable[[Dict[str, Any]], None], Callable]] Logger = Union[Callable[["Language"], Tuple[Callable[[Dict[str, Any]], None], Callable]], Promise]
# fmt: on
def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]: def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]:
@ -292,6 +295,16 @@ class ConfigSchema(BaseModel):
arbitrary_types_allowed = True arbitrary_types_allowed = True
class TrainingSchema(BaseModel):
training: ConfigSchemaTraining
pretraining: Union[ConfigSchemaPretrain, ConfigSchemaPretrainEmpty] = {}
corpora: Dict[str, Reader]
class Config:
extra = "allow"
arbitrary_types_allowed = True
# Project config Schema # Project config Schema

View File

@ -24,7 +24,7 @@ def test_doc_add_entities_set_ents_iob(en_vocab):
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
cfg = {"model": DEFAULT_NER_MODEL} cfg = {"model": DEFAULT_NER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
ner = EntityRecognizer(en_vocab, model, **config) ner = EntityRecognizer(en_vocab, model, **config)
ner.begin_training(lambda: [_ner_example(ner)]) ner.begin_training(lambda: [_ner_example(ner)])
ner(doc) ner(doc)
@ -46,7 +46,7 @@ def test_ents_reset(en_vocab):
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
cfg = {"model": DEFAULT_NER_MODEL} cfg = {"model": DEFAULT_NER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
ner = EntityRecognizer(en_vocab, model, **config) ner = EntityRecognizer(en_vocab, model, **config)
ner.begin_training(lambda: [_ner_example(ner)]) ner.begin_training(lambda: [_ner_example(ner)])
ner(doc) ner(doc)

View File

@ -23,7 +23,7 @@ def parser(vocab):
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
cfg = {"model": DEFAULT_PARSER_MODEL} cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
parser = DependencyParser(vocab, model, **config) parser = DependencyParser(vocab, model, **config)
return parser return parser
@ -82,7 +82,7 @@ def test_add_label_deserializes_correctly():
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
cfg = {"model": DEFAULT_NER_MODEL} cfg = {"model": DEFAULT_NER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
ner1 = EntityRecognizer(Vocab(), model, **config) ner1 = EntityRecognizer(Vocab(), model, **config)
ner1.add_label("C") ner1.add_label("C")
ner1.add_label("B") ner1.add_label("B")
@ -111,7 +111,7 @@ def test_add_label_get_label(pipe_cls, n_moves, model_config):
splitting the move names. splitting the move names.
""" """
labels = ["A", "B", "C"] labels = ["A", "B", "C"]
model = registry.make_from_config({"model": model_config}, validate=True)["model"] model = registry.resolve({"model": model_config}, validate=True)["model"]
config = { config = {
"learn_tokens": False, "learn_tokens": False,
"min_action_freq": 30, "min_action_freq": 30,

View File

@ -127,7 +127,7 @@ def test_get_oracle_actions():
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
cfg = {"model": DEFAULT_PARSER_MODEL} cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
parser = DependencyParser(doc.vocab, model, **config) parser = DependencyParser(doc.vocab, model, **config)
parser.moves.add_action(0, "") parser.moves.add_action(0, "")
parser.moves.add_action(1, "") parser.moves.add_action(1, "")

View File

@ -25,7 +25,7 @@ def arc_eager(vocab):
@pytest.fixture @pytest.fixture
def tok2vec(): def tok2vec():
cfg = {"model": DEFAULT_TOK2VEC_MODEL} cfg = {"model": DEFAULT_TOK2VEC_MODEL}
tok2vec = registry.make_from_config(cfg, validate=True)["model"] tok2vec = registry.resolve(cfg, validate=True)["model"]
tok2vec.initialize() tok2vec.initialize()
return tok2vec return tok2vec
@ -38,14 +38,14 @@ def parser(vocab, arc_eager):
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
cfg = {"model": DEFAULT_PARSER_MODEL} cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
return Parser(vocab, model, moves=arc_eager, **config) return Parser(vocab, model, moves=arc_eager, **config)
@pytest.fixture @pytest.fixture
def model(arc_eager, tok2vec, vocab): def model(arc_eager, tok2vec, vocab):
cfg = {"model": DEFAULT_PARSER_MODEL} cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
model.attrs["resize_output"](model, arc_eager.n_moves) model.attrs["resize_output"](model, arc_eager.n_moves)
model.initialize() model.initialize()
return model return model
@ -72,7 +72,7 @@ def test_build_model(parser, vocab):
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
cfg = {"model": DEFAULT_PARSER_MODEL} cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
parser.model = Parser(vocab, model=model, moves=parser.moves, **config).model parser.model = Parser(vocab, model=model, moves=parser.moves, **config).model
assert parser.model is not None assert parser.model is not None

View File

@ -28,7 +28,7 @@ def parser(vocab):
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
cfg = {"model": DEFAULT_PARSER_MODEL} cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
parser = DependencyParser(vocab, model, **config) parser = DependencyParser(vocab, model, **config)
parser.cfg["token_vector_width"] = 4 parser.cfg["token_vector_width"] = 4
parser.cfg["hidden_width"] = 32 parser.cfg["hidden_width"] = 32

View File

@ -139,7 +139,7 @@ TRAIN_DATA = [
def test_tok2vec_listener(): def test_tok2vec_listener():
orig_config = Config().from_str(cfg_string) orig_config = Config().from_str(cfg_string)
nlp, config = util.load_model_from_config(orig_config, auto_fill=True, validate=True) nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
assert nlp.pipe_names == ["tok2vec", "tagger"] assert nlp.pipe_names == ["tok2vec", "tagger"]
tagger = nlp.get_pipe("tagger") tagger = nlp.get_pipe("tagger")
tok2vec = nlp.get_pipe("tok2vec") tok2vec = nlp.get_pipe("tok2vec")
@ -173,7 +173,7 @@ def test_tok2vec_listener():
def test_tok2vec_listener_callback(): def test_tok2vec_listener_callback():
orig_config = Config().from_str(cfg_string) orig_config = Config().from_str(cfg_string)
nlp, config = util.load_model_from_config(orig_config, auto_fill=True, validate=True) nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
assert nlp.pipe_names == ["tok2vec", "tagger"] assert nlp.pipe_names == ["tok2vec", "tagger"]
tagger = nlp.get_pipe("tagger") tagger = nlp.get_pipe("tagger")
tok2vec = nlp.get_pipe("tok2vec") tok2vec = nlp.get_pipe("tok2vec")

View File

@ -195,7 +195,7 @@ def test_issue3345():
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
cfg = {"model": DEFAULT_NER_MODEL} cfg = {"model": DEFAULT_NER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
ner = EntityRecognizer(doc.vocab, model, **config) ner = EntityRecognizer(doc.vocab, model, **config)
# Add the OUT action. I wouldn't have thought this would be necessary... # Add the OUT action. I wouldn't have thought this would be necessary...
ner.moves.add_action(5, "") ner.moves.add_action(5, "")

View File

@ -264,9 +264,7 @@ def test_issue3830_no_subtok():
"min_action_freq": 30, "min_action_freq": 30,
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)[ model = registry.resolve({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
"model"
]
parser = DependencyParser(Vocab(), model, **config) parser = DependencyParser(Vocab(), model, **config)
parser.add_label("nsubj") parser.add_label("nsubj")
assert "subtok" not in parser.labels assert "subtok" not in parser.labels
@ -281,9 +279,7 @@ def test_issue3830_with_subtok():
"min_action_freq": 30, "min_action_freq": 30,
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)[ model = registry.resolve({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
"model"
]
parser = DependencyParser(Vocab(), model, **config) parser = DependencyParser(Vocab(), model, **config)
parser.add_label("nsubj") parser.add_label("nsubj")
assert "subtok" not in parser.labels assert "subtok" not in parser.labels

View File

@ -108,8 +108,8 @@ def my_parser():
def test_create_nlp_from_config(): def test_create_nlp_from_config():
config = Config().from_str(nlp_config_string) config = Config().from_str(nlp_config_string)
with pytest.raises(ConfigValidationError): with pytest.raises(ConfigValidationError):
nlp, _ = load_model_from_config(config, auto_fill=False) load_model_from_config(config, auto_fill=False)
nlp, resolved = load_model_from_config(config, auto_fill=True) nlp = load_model_from_config(config, auto_fill=True)
assert nlp.config["training"]["batcher"]["size"] == 666 assert nlp.config["training"]["batcher"]["size"] == 666
assert len(nlp.config["training"]) > 1 assert len(nlp.config["training"]) > 1
assert nlp.pipe_names == ["tok2vec", "tagger"] assert nlp.pipe_names == ["tok2vec", "tagger"]
@ -136,7 +136,7 @@ def test_create_nlp_from_config_multiple_instances():
"tagger2": config["components"]["tagger"], "tagger2": config["components"]["tagger"],
} }
config["nlp"]["pipeline"] = list(config["components"].keys()) config["nlp"]["pipeline"] = list(config["components"].keys())
nlp, _ = load_model_from_config(config, auto_fill=True) nlp = load_model_from_config(config, auto_fill=True)
assert nlp.pipe_names == ["t2v", "tagger1", "tagger2"] assert nlp.pipe_names == ["t2v", "tagger1", "tagger2"]
assert nlp.get_pipe_meta("t2v").factory == "tok2vec" assert nlp.get_pipe_meta("t2v").factory == "tok2vec"
assert nlp.get_pipe_meta("tagger1").factory == "tagger" assert nlp.get_pipe_meta("tagger1").factory == "tagger"
@ -150,7 +150,7 @@ def test_create_nlp_from_config_multiple_instances():
def test_serialize_nlp(): def test_serialize_nlp():
""" Create a custom nlp pipeline from config and ensure it serializes it correctly """ """ Create a custom nlp pipeline from config and ensure it serializes it correctly """
nlp_config = Config().from_str(nlp_config_string) nlp_config = Config().from_str(nlp_config_string)
nlp, _ = load_model_from_config(nlp_config, auto_fill=True) nlp = load_model_from_config(nlp_config, auto_fill=True)
nlp.get_pipe("tagger").add_label("A") nlp.get_pipe("tagger").add_label("A")
nlp.begin_training() nlp.begin_training()
assert "tok2vec" in nlp.pipe_names assert "tok2vec" in nlp.pipe_names
@ -209,7 +209,7 @@ def test_config_nlp_roundtrip():
nlp = English() nlp = English()
nlp.add_pipe("entity_ruler") nlp.add_pipe("entity_ruler")
nlp.add_pipe("ner") nlp.add_pipe("ner")
new_nlp, new_config = load_model_from_config(nlp.config, auto_fill=False) new_nlp = load_model_from_config(nlp.config, auto_fill=False)
assert new_nlp.config == nlp.config assert new_nlp.config == nlp.config
assert new_nlp.pipe_names == nlp.pipe_names assert new_nlp.pipe_names == nlp.pipe_names
assert new_nlp._pipe_configs == nlp._pipe_configs assert new_nlp._pipe_configs == nlp._pipe_configs
@ -280,12 +280,12 @@ def test_config_overrides():
overrides_dot = {"nlp.lang": "de", "nlp.pipeline": ["tagger"]} overrides_dot = {"nlp.lang": "de", "nlp.pipeline": ["tagger"]}
# load_model from config with overrides passed directly to Config # load_model from config with overrides passed directly to Config
config = Config().from_str(nlp_config_string, overrides=overrides_dot) config = Config().from_str(nlp_config_string, overrides=overrides_dot)
nlp, _ = load_model_from_config(config, auto_fill=True) nlp = load_model_from_config(config, auto_fill=True)
assert isinstance(nlp, German) assert isinstance(nlp, German)
assert nlp.pipe_names == ["tagger"] assert nlp.pipe_names == ["tagger"]
# Serialized roundtrip with config passed in # Serialized roundtrip with config passed in
base_config = Config().from_str(nlp_config_string) base_config = Config().from_str(nlp_config_string)
base_nlp, _ = load_model_from_config(base_config, auto_fill=True) base_nlp = load_model_from_config(base_config, auto_fill=True)
assert isinstance(base_nlp, English) assert isinstance(base_nlp, English)
assert base_nlp.pipe_names == ["tok2vec", "tagger"] assert base_nlp.pipe_names == ["tok2vec", "tagger"]
with make_tempdir() as d: with make_tempdir() as d:
@ -328,7 +328,7 @@ def test_config_optional_sections():
config = Config().from_str(nlp_config_string) config = Config().from_str(nlp_config_string)
config = DEFAULT_CONFIG.merge(config) config = DEFAULT_CONFIG.merge(config)
assert "pretraining" not in config assert "pretraining" not in config
filled = registry.fill_config(config, schema=ConfigSchema, validate=False) filled = registry.fill(config, schema=ConfigSchema, validate=False)
# Make sure that optional "pretraining" block doesn't default to None, # Make sure that optional "pretraining" block doesn't default to None,
# which would (rightly) cause error because it'd result in a top-level # which would (rightly) cause error because it'd result in a top-level
# key that's not a section (dict). Note that the following roundtrip is # key that's not a section (dict). Note that the following roundtrip is
@ -341,7 +341,7 @@ def test_config_auto_fill_extra_fields():
config = Config({"nlp": {"lang": "en"}, "training": {}}) config = Config({"nlp": {"lang": "en"}, "training": {}})
assert load_model_from_config(config, auto_fill=True) assert load_model_from_config(config, auto_fill=True)
config = Config({"nlp": {"lang": "en"}, "training": {"extra": "hello"}}) config = Config({"nlp": {"lang": "en"}, "training": {"extra": "hello"}})
nlp, _ = load_model_from_config(config, auto_fill=True, validate=False) nlp = load_model_from_config(config, auto_fill=True, validate=False)
assert "extra" not in nlp.config["training"] assert "extra" not in nlp.config["training"]
# Make sure the config generated is valid # Make sure the config generated is valid
load_model_from_config(nlp.config) load_model_from_config(nlp.config)

View File

@ -23,7 +23,7 @@ def parser(en_vocab):
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
cfg = {"model": DEFAULT_PARSER_MODEL} cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
parser = DependencyParser(en_vocab, model, **config) parser = DependencyParser(en_vocab, model, **config)
parser.add_label("nsubj") parser.add_label("nsubj")
return parser return parser
@ -37,7 +37,7 @@ def blank_parser(en_vocab):
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
cfg = {"model": DEFAULT_PARSER_MODEL} cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
parser = DependencyParser(en_vocab, model, **config) parser = DependencyParser(en_vocab, model, **config)
return parser return parser
@ -45,7 +45,7 @@ def blank_parser(en_vocab):
@pytest.fixture @pytest.fixture
def taggers(en_vocab): def taggers(en_vocab):
cfg = {"model": DEFAULT_TAGGER_MODEL} cfg = {"model": DEFAULT_TAGGER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
tagger1 = Tagger(en_vocab, model) tagger1 = Tagger(en_vocab, model)
tagger2 = Tagger(en_vocab, model) tagger2 = Tagger(en_vocab, model)
return tagger1, tagger2 return tagger1, tagger2
@ -59,7 +59,7 @@ def test_serialize_parser_roundtrip_bytes(en_vocab, Parser):
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
cfg = {"model": DEFAULT_PARSER_MODEL} cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
parser = Parser(en_vocab, model, **config) parser = Parser(en_vocab, model, **config)
new_parser = Parser(en_vocab, model, **config) new_parser = Parser(en_vocab, model, **config)
new_parser = new_parser.from_bytes(parser.to_bytes(exclude=["vocab"])) new_parser = new_parser.from_bytes(parser.to_bytes(exclude=["vocab"]))
@ -77,7 +77,7 @@ def test_serialize_parser_roundtrip_disk(en_vocab, Parser):
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
} }
cfg = {"model": DEFAULT_PARSER_MODEL} cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
parser = Parser(en_vocab, model, **config) parser = Parser(en_vocab, model, **config)
with make_tempdir() as d: with make_tempdir() as d:
file_path = d / "parser" file_path = d / "parser"
@ -111,7 +111,7 @@ def test_serialize_tagger_roundtrip_bytes(en_vocab, taggers):
tagger1 = tagger1.from_bytes(tagger1_b) tagger1 = tagger1.from_bytes(tagger1_b)
assert tagger1.to_bytes() == tagger1_b assert tagger1.to_bytes() == tagger1_b
cfg = {"model": DEFAULT_TAGGER_MODEL} cfg = {"model": DEFAULT_TAGGER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
new_tagger1 = Tagger(en_vocab, model).from_bytes(tagger1_b) new_tagger1 = Tagger(en_vocab, model).from_bytes(tagger1_b)
new_tagger1_b = new_tagger1.to_bytes() new_tagger1_b = new_tagger1.to_bytes()
assert len(new_tagger1_b) == len(tagger1_b) assert len(new_tagger1_b) == len(tagger1_b)
@ -126,7 +126,7 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
tagger1.to_disk(file_path1) tagger1.to_disk(file_path1)
tagger2.to_disk(file_path2) tagger2.to_disk(file_path2)
cfg = {"model": DEFAULT_TAGGER_MODEL} cfg = {"model": DEFAULT_TAGGER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
tagger1_d = Tagger(en_vocab, model).from_disk(file_path1) tagger1_d = Tagger(en_vocab, model).from_disk(file_path1)
tagger2_d = Tagger(en_vocab, model).from_disk(file_path2) tagger2_d = Tagger(en_vocab, model).from_disk(file_path2)
assert tagger1_d.to_bytes() == tagger2_d.to_bytes() assert tagger1_d.to_bytes() == tagger2_d.to_bytes()
@ -135,7 +135,7 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
def test_serialize_textcat_empty(en_vocab): def test_serialize_textcat_empty(en_vocab):
# See issue #1105 # See issue #1105
cfg = {"model": DEFAULT_TEXTCAT_MODEL} cfg = {"model": DEFAULT_TEXTCAT_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
textcat = TextCategorizer( textcat = TextCategorizer(
en_vocab, en_vocab,
model, model,
@ -149,7 +149,7 @@ def test_serialize_textcat_empty(en_vocab):
@pytest.mark.parametrize("Parser", test_parsers) @pytest.mark.parametrize("Parser", test_parsers)
def test_serialize_pipe_exclude(en_vocab, Parser): def test_serialize_pipe_exclude(en_vocab, Parser):
cfg = {"model": DEFAULT_PARSER_MODEL} cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
config = { config = {
"learn_tokens": False, "learn_tokens": False,
"min_action_freq": 0, "min_action_freq": 0,
@ -176,7 +176,7 @@ def test_serialize_pipe_exclude(en_vocab, Parser):
def test_serialize_sentencerecognizer(en_vocab): def test_serialize_sentencerecognizer(en_vocab):
cfg = {"model": DEFAULT_SENTER_MODEL} cfg = {"model": DEFAULT_SENTER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"] model = registry.resolve(cfg, validate=True)["model"]
sr = SentenceRecognizer(en_vocab, model) sr = SentenceRecognizer(en_vocab, model)
sr_b = sr.to_bytes() sr_b = sr.to_bytes()
sr_d = SentenceRecognizer(en_vocab, model).from_bytes(sr_b) sr_d = SentenceRecognizer(en_vocab, model).from_bytes(sr_b)

View File

@ -7,6 +7,7 @@ from spacy import util
from spacy import prefer_gpu, require_gpu from spacy import prefer_gpu, require_gpu
from spacy.ml._precomputable_affine import PrecomputableAffine from spacy.ml._precomputable_affine import PrecomputableAffine
from spacy.ml._precomputable_affine import _backprop_precomputable_affine_padding from spacy.ml._precomputable_affine import _backprop_precomputable_affine_padding
from thinc.api import Optimizer
@pytest.fixture @pytest.fixture
@ -157,3 +158,16 @@ def test_dot_to_dict(dot_notation, expected):
result = util.dot_to_dict(dot_notation) result = util.dot_to_dict(dot_notation)
assert result == expected assert result == expected
assert util.dict_to_dot(result) == dot_notation assert util.dict_to_dot(result) == dot_notation
def test_resolve_training_config():
    """Check that resolving a training config builds the registered objects
    and leaves the [nlp] block out of the result."""
    raw_config = {
        "nlp": {"lang": "en", "disabled": []},
        "training": {"dropout": 0.1, "optimizer": {"@optimizers": "Adam.v1"}},
        "corpora": {},
    }
    resolved = util.resolve_training_config(raw_config)
    training = resolved["training"]
    # Plain values are passed through unchanged
    assert training["dropout"] == 0.1
    # The "@optimizers" reference has to be replaced by an actual object
    assert isinstance(training["optimizer"], Optimizer)
    assert resolved["corpora"] == {}
    # The [nlp] block is expected to be missing from the resolved config
    assert "nlp" not in resolved

View File

@ -82,10 +82,10 @@ def test_util_dot_section():
no_output_layer = false no_output_layer = false
""" """
nlp_config = Config().from_str(cfg_string) nlp_config = Config().from_str(cfg_string)
en_nlp, en_config = util.load_model_from_config(nlp_config, auto_fill=True) en_nlp = util.load_model_from_config(nlp_config, auto_fill=True)
default_config = Config().from_disk(DEFAULT_CONFIG_PATH) default_config = Config().from_disk(DEFAULT_CONFIG_PATH)
default_config["nlp"]["lang"] = "nl" default_config["nlp"]["lang"] = "nl"
nl_nlp, nl_config = util.load_model_from_config(default_config, auto_fill=True) nl_nlp = util.load_model_from_config(default_config, auto_fill=True)
# Test that creation went OK # Test that creation went OK
assert isinstance(en_nlp, English) assert isinstance(en_nlp, English)
assert isinstance(nl_nlp, Dutch) assert isinstance(nl_nlp, Dutch)
@ -94,14 +94,15 @@ def test_util_dot_section():
# not exclusive_classes # not exclusive_classes
assert en_nlp.get_pipe("textcat").model.attrs["multi_label"] is False assert en_nlp.get_pipe("textcat").model.attrs["multi_label"] is False
# Test that default values got overwritten # Test that default values got overwritten
assert en_config["nlp"]["pipeline"] == ["textcat"] assert en_nlp.config["nlp"]["pipeline"] == ["textcat"]
assert nl_config["nlp"]["pipeline"] == [] # default value [] assert nl_nlp.config["nlp"]["pipeline"] == [] # default value []
# Test proper functioning of 'dot_to_object' # Test proper functioning of 'dot_to_object'
with pytest.raises(KeyError): with pytest.raises(KeyError):
dot_to_object(en_config, "nlp.pipeline.tagger") dot_to_object(en_nlp.config, "nlp.pipeline.tagger")
with pytest.raises(KeyError): with pytest.raises(KeyError):
dot_to_object(en_config, "nlp.unknownattribute") dot_to_object(en_nlp.config, "nlp.unknownattribute")
assert isinstance(dot_to_object(nl_config, "training.optimizer"), Optimizer) resolved = util.resolve_training_config(nl_nlp.config)
assert isinstance(dot_to_object(resolved, "training.optimizer"), Optimizer)
def test_simple_frozen_list(): def test_simple_frozen_list():

View File

@ -3,6 +3,7 @@ import pytest
from thinc.api import Config from thinc.api import Config
from spacy import Language from spacy import Language
from spacy.util import load_model_from_config, registry, dot_to_object from spacy.util import load_model_from_config, registry, dot_to_object
from spacy.util import resolve_training_config
from spacy.training import Example from spacy.training import Example
@ -37,8 +38,8 @@ def test_readers():
return {"train": reader, "dev": reader, "extra": reader, "something": reader} return {"train": reader, "dev": reader, "extra": reader, "something": reader}
config = Config().from_str(config_string) config = Config().from_str(config_string)
nlp, resolved = load_model_from_config(config, auto_fill=True) nlp = load_model_from_config(config, auto_fill=True)
resolved = resolve_training_config(nlp.config)
train_corpus = dot_to_object(resolved, resolved["training"]["train_corpus"]) train_corpus = dot_to_object(resolved, resolved["training"]["train_corpus"])
assert isinstance(train_corpus, Callable) assert isinstance(train_corpus, Callable)
optimizer = resolved["training"]["optimizer"] optimizer = resolved["training"]["optimizer"]
@ -87,8 +88,8 @@ def test_cat_readers(reader, additional_config):
config = Config().from_str(nlp_config_string) config = Config().from_str(nlp_config_string)
config["corpora"]["@readers"] = reader config["corpora"]["@readers"] = reader
config["corpora"].update(additional_config) config["corpora"].update(additional_config)
nlp, resolved = load_model_from_config(config, auto_fill=True) nlp = load_model_from_config(config, auto_fill=True)
resolved = resolve_training_config(nlp.config)
train_corpus = dot_to_object(resolved, resolved["training"]["train_corpus"]) train_corpus = dot_to_object(resolved, resolved["training"]["train_corpus"])
optimizer = resolved["training"]["optimizer"] optimizer = resolved["training"]["optimizer"]
# simulate a training loop # simulate a training loop

View File

@ -86,7 +86,7 @@ class registry(thinc.registry):
# spacy_factories entry point. This registry only exists so we can easily # spacy_factories entry point. This registry only exists so we can easily
# load them via the entry points. The "true" factories are added via the # load them via the entry points. The "true" factories are added via the
# Language.factory decorator (in the spaCy code base and user code) and those # Language.factory decorator (in the spaCy code base and user code) and those
# are the factories used to initialize components via registry.make_from_config. # are the factories used to initialize components via registry.resolve.
_entry_point_factories = catalogue.create("spacy", "factories", entry_points=True) _entry_point_factories = catalogue.create("spacy", "factories", entry_points=True)
factories = catalogue.create("spacy", "internal_factories") factories = catalogue.create("spacy", "internal_factories")
# This is mostly used to get a list of all installed models in the current # This is mostly used to get a list of all installed models in the current
@ -351,9 +351,7 @@ def load_model_from_path(
meta = get_model_meta(model_path) meta = get_model_meta(model_path)
config_path = model_path / "config.cfg" config_path = model_path / "config.cfg"
config = load_config(config_path, overrides=dict_to_dot(config)) config = load_config(config_path, overrides=dict_to_dot(config))
nlp, _ = load_model_from_config( nlp = load_model_from_config(config, vocab=vocab, disable=disable, exclude=exclude)
config, vocab=vocab, disable=disable, exclude=exclude
)
return nlp.from_disk(model_path, exclude=exclude) return nlp.from_disk(model_path, exclude=exclude)
@ -365,7 +363,7 @@ def load_model_from_config(
exclude: Iterable[str] = SimpleFrozenList(), exclude: Iterable[str] = SimpleFrozenList(),
auto_fill: bool = False, auto_fill: bool = False,
validate: bool = True, validate: bool = True,
) -> Tuple["Language", Config]: ) -> "Language":
"""Create an nlp object from a config. Expects the full config file including """Create an nlp object from a config. Expects the full config file including
a section "nlp" containing the settings for the nlp object. a section "nlp" containing the settings for the nlp object.
@ -398,7 +396,30 @@ def load_model_from_config(
auto_fill=auto_fill, auto_fill=auto_fill,
validate=validate, validate=validate,
) )
return nlp, nlp.resolved return nlp
def resolve_training_config(
    config: Config,
    exclude: Iterable[str] = ("nlp", "components"),
    validate: bool = True,
) -> Dict[str, Any]:
    """Resolve the config sections relevant for training and create all objects.
    Mostly used in the CLI to separate training config (not resolved by default
    because not runtime-relevant - an nlp object should load fine even if its
    [training] block refers to functions that are not available etc.).

    config (Config): The config to resolve.
    exclude (Iterable[str]): The config blocks to exclude. Those blocks won't
        be available in the final resolved config.
    validate (bool): Whether to validate the config.
    RETURNS (Dict[str, Any]): The resolved config.
    """
    # Operate on a copy so the top-level sections of the caller's config
    # are not removed as a side effect.
    trimmed = config.copy()
    for section in exclude:
        # Excluded sections that are not present are silently skipped.
        trimmed.pop(section, None)
    return registry.resolve(trimmed, validate=validate)
def load_model_from_init_py( def load_model_from_init_py(