Update config resolution to use new Thinc

This commit is contained in:
Ines Montani 2020-09-27 22:21:31 +02:00
parent f29d5b9b89
commit 7e938ed63e
24 changed files with 163 additions and 122 deletions

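The gist of the change, as a minimal sketch (the two-step pattern below replaces the old tuple return; see the diffs that follow):

    # Before this commit: loading a model also resolved the full config
    #     nlp, resolved = util.load_model_from_config(config)
    # After: load the nlp object first, then resolve the training-relevant
    # sections on demand
    #     nlp = util.load_model_from_config(config)
    #     C = util.resolve_training_config(nlp.config)
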
View File

@ -6,7 +6,7 @@ requires = [
"cymem>=2.0.2,<2.1.0",
"preshed>=3.0.2,<3.1.0",
"murmurhash>=0.28.0,<1.1.0",
"thinc>=8.0.0a36,<8.0.0a40",
"thinc>=8.0.0a40,<8.0.0a50",
"blis>=0.4.0,<0.5.0",
"pytokenizations",
"pathy"

View File

@ -1,7 +1,7 @@
# Our libraries
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
thinc>=8.0.0a36,<8.0.0a40
thinc>=8.0.0a40,<8.0.0a50
blis>=0.4.0,<0.5.0
ml_datasets==0.2.0a0
murmurhash>=0.28.0,<1.1.0

View File

@ -34,13 +34,13 @@ setup_requires =
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
murmurhash>=0.28.0,<1.1.0
thinc>=8.0.0a36,<8.0.0a40
thinc>=8.0.0a40,<8.0.0a50
install_requires =
# Our libraries
murmurhash>=0.28.0,<1.1.0
cymem>=2.0.2,<2.1.0
preshed>=3.0.2,<3.1.0
thinc>=8.0.0a36,<8.0.0a40
thinc>=8.0.0a40,<8.0.0a50
blis>=0.4.0,<0.5.0
wasabi>=0.8.0,<1.1.0
srsly>=2.1.0,<3.0.0

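A quick way to verify a local environment against the bumped pins (a sketch; assumes the packaging library is installed):

    from packaging.version import Version
    import thinc

    # the new pins require a Thinc 8.0.0a40+ prerelease, below 8.0.0a50
    assert Version("8.0.0a40") <= Version(thinc.__version__) < Version("8.0.0a50")
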
View File

@ -51,9 +51,10 @@ def debug_config(
msg.divider("Config validation")
with show_validation_error(config_path):
config = util.load_config(config_path, overrides=overrides)
nlp, resolved = util.load_model_from_config(config)
nlp = util.load_model_from_config(config)
# Use the resolved config here in case the user has one function returning
# a dict of corpora etc.
resolved = util.resolve_training_config(nlp.config)
check_section_refs(resolved, ["training.dev_corpus", "training.train_corpus"])
msg.good("Config is valid")
if show_vars:

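check_section_refs takes the resolved config plus a list of dot-paths and fails if a referenced section does not exist. A sketch of the shape it checks (names assumed):

    resolved = {
        "corpora": {"train": object(), "dev": object()},
        "training": {"train_corpus": "corpora.train", "dev_corpus": "corpora.dev"},
    }
    # "training.train_corpus" holds the dot-path "corpora.train", which must
    # point at an existing section of the resolved config
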
View File

@ -93,18 +93,19 @@ def debug_data(
msg.fail("Config file not found", config_path, exists=1)
with show_validation_error(config_path):
cfg = util.load_config(config_path, overrides=config_overrides)
nlp, config = util.load_model_from_config(cfg)
nlp = util.load_model_from_config(cfg)
C = util.resolve_training_config(nlp.config)
# Use original config here, not resolved version
sourced_components = get_sourced_components(cfg)
frozen_components = config["training"]["frozen_components"]
frozen_components = C["training"]["frozen_components"]
resume_components = [p for p in sourced_components if p not in frozen_components]
pipeline = nlp.pipe_names
factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names]
tag_map_path = util.ensure_path(config["training"]["tag_map"])
tag_map_path = util.ensure_path(C["training"]["tag_map"])
tag_map = {}
if tag_map_path is not None:
tag_map = srsly.read_json(tag_map_path)
morph_rules_path = util.ensure_path(config["training"]["morph_rules"])
morph_rules_path = util.ensure_path(C["training"]["morph_rules"])
morph_rules = {}
if morph_rules_path is not None:
morph_rules = srsly.read_json(morph_rules_path)
@ -144,10 +145,10 @@ def debug_data(
train_texts = gold_train_data["texts"]
dev_texts = gold_dev_data["texts"]
frozen_components = config["training"]["frozen_components"]
frozen_components = C["training"]["frozen_components"]
msg.divider("Training stats")
msg.text(f"Language: {config['nlp']['lang']}")
msg.text(f"Language: {C['nlp']['lang']}")
msg.text(f"Training pipeline: {', '.join(pipeline)}")
if resume_components:
msg.text(f"Components from other pipelines: {', '.join(resume_components)}")

View File

@ -1,4 +1,3 @@
import warnings
from typing import Dict, Any, Optional, Iterable
from pathlib import Path
@ -57,14 +56,17 @@ def debug_model_cli(
}
config_overrides = parse_config_overrides(ctx.args)
with show_validation_error(config_path):
config = util.load_config(
config_path, overrides=config_overrides, interpolate=True
raw_config = util.load_config(
config_path, overrides=config_overrides, interpolate=False
)
allocator = config["training"]["gpu_allocator"]
if use_gpu >= 0 and allocator:
set_gpu_allocator(allocator)
nlp, config = util.load_model_from_config(config)
seed = config["training"]["seed"]
config = raw_config.interpolate()
allocator = config["training"]["gpu_allocator"]
if use_gpu >= 0 and allocator:
set_gpu_allocator(allocator)
with show_validation_error(config_path):
nlp = util.load_model_from_config(raw_config)
C = util.resolve_training_config(nlp.config)
seed = C["training"]["seed"]
if seed is not None:
msg.info(f"Fixing random seed: {seed}")
fix_random_seed(seed)
@ -75,7 +77,7 @@ def debug_model_cli(
exits=1,
)
model = pipe.model
debug_model(config, nlp, model, print_settings=print_settings)
debug_model(C, nlp, model, print_settings=print_settings)
def debug_model(
@ -108,7 +110,7 @@ def debug_model(
_set_output_dim(nO=7, model=model)
nlp.begin_training(lambda: [Example.from_dict(x, {}) for x in X])
msg.info("Initialized the model with dummy data.")
except:
except Exception:
msg.fail(
"Could not initialize the model: you'll have to provide a valid train_corpus argument in the config file.",
exits=1,

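Loading with interpolate=False keeps ${section.key} variables intact in the raw config, while interpolating a copy produces the concrete values. A minimal sketch of the Thinc behavior this relies on:

    from thinc.api import Config

    cfg_str = '[paths]\ntrain = "corpus/train.spacy"\n\n[training]\ntrain_path = ${paths.train}'
    raw_config = Config().from_str(cfg_str, interpolate=False)
    config = raw_config.interpolate()
    assert raw_config["training"]["train_path"] == "${paths.train}"
    assert config["training"]["train_path"] == "corpus/train.spacy"
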
View File

@ -88,10 +88,10 @@ def fill_config(
msg = Printer(no_print=no_print)
with show_validation_error(hint_fill=False):
config = util.load_config(base_path)
nlp, _ = util.load_model_from_config(config, auto_fill=True, validate=False)
nlp = util.load_model_from_config(config, auto_fill=True, validate=False)
# Load a second time with validation to be extra sure that the produced
# config result is a valid config
nlp, _ = util.load_model_from_config(nlp.config)
nlp = util.load_model_from_config(nlp.config)
filled = nlp.config
if pretraining:
validate_config_for_pretrain(filled, msg)
@ -169,7 +169,7 @@ def init_config(
msg.text(f"- {label}: {value}")
with show_validation_error(hint_fill=False):
config = util.load_config_from_str(base_template)
nlp, _ = util.load_model_from_config(config, auto_fill=True)
nlp = util.load_model_from_config(config, auto_fill=True)
config = nlp.config
if pretraining:
validate_config_for_pretrain(config, msg)

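The double load is a fill-then-validate pattern: the first pass fills in defaults without validating, the second pass validates the filled result. A sketch with a minimal config:

    from thinc.api import Config
    from spacy.util import load_model_from_config

    config = Config().from_str('[nlp]\nlang = "en"\npipeline = []')
    nlp = load_model_from_config(config, auto_fill=True, validate=False)
    nlp = load_model_from_config(nlp.config)  # validates the auto-filled config
    filled = nlp.config
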
View File

@ -69,17 +69,18 @@ def pretrain_cli(
msg.info(f"Loading config from: {config_path}")
with show_validation_error(config_path):
config = util.load_config(
config_path, overrides=config_overrides, interpolate=True
raw_config = util.load_config(
config_path, overrides=config_overrides, interpolate=False
)
config = raw_config.interpolate()
if not config.get("pretraining"):
# TODO: What's the solution here? How do we handle optional blocks?
msg.fail("The [pretraining] block in your config is empty", exits=1)
if not output_dir.exists():
output_dir.mkdir()
msg.good(f"Created output directory: {output_dir}")
config.to_disk(output_dir / "config.cfg")
# Save non-interpolated config
raw_config.to_disk(output_dir / "config.cfg")
msg.good("Saved config file in the output directory")
pretrain(
@ -103,14 +104,13 @@ def pretrain(
allocator = config["training"]["gpu_allocator"]
if use_gpu >= 0 and allocator:
set_gpu_allocator(allocator)
nlp, config = util.load_model_from_config(config)
P_cfg = config["pretraining"]
corpus = dot_to_object(config, P_cfg["corpus"])
nlp = util.load_model_from_config(config)
C = util.resolve_training_config(nlp.config)
P_cfg = C["pretraining"]
corpus = dot_to_object(C, P_cfg["corpus"])
batcher = P_cfg["batcher"]
model = create_pretraining_model(nlp, config["pretraining"])
optimizer = config["pretraining"]["optimizer"]
model = create_pretraining_model(nlp, C["pretraining"])
optimizer = C["pretraining"]["optimizer"]
# Load in pretrained weights to resume from
if resume_path is not None:
_resume_model(model, resume_path, epoch_resume)

View File

@ -75,12 +75,12 @@ def train(
msg.info("Using CPU")
msg.info(f"Loading config and nlp from: {config_path}")
with show_validation_error(config_path):
config = util.load_config(
config_path, overrides=config_overrides, interpolate=True
)
# Keep a second un-interpolated config so we can preserve variables in
# Keep an un-interpolated config so we can preserve variables in
# the final nlp object we train and serialize
raw_config = util.load_config(config_path, overrides=config_overrides)
raw_config = util.load_config(
config_path, overrides=config_overrides, interpolate=False
)
config = raw_config.interpolate()
if config["training"]["seed"] is not None:
fix_random_seed(config["training"]["seed"])
allocator = config["training"]["gpu_allocator"]
@ -89,15 +89,17 @@ def train(
# Use original config here before it's resolved to functions
sourced_components = get_sourced_components(config)
with show_validation_error(config_path):
nlp, config = util.load_model_from_config(raw_config)
util.load_vocab_data_into_model(nlp, lookups=config["training"]["lookups"])
if config["training"]["vectors"] is not None:
add_vectors(nlp, config["training"]["vectors"])
raw_text, tag_map, morph_rules, weights_data = load_from_paths(config)
T_cfg = config["training"]
nlp = util.load_model_from_config(raw_config)
# Resolve all training-relevant sections using the filled nlp config
C = util.resolve_training_config(nlp.config)
util.load_vocab_data_into_model(nlp, lookups=C["training"]["lookups"])
if C["training"]["vectors"] is not None:
add_vectors(nlp, C["training"]["vectors"])
raw_text, tag_map, morph_rules, weights_data = load_from_paths(C)
T_cfg = C["training"]
optimizer = T_cfg["optimizer"]
train_corpus = dot_to_object(config, T_cfg["train_corpus"])
dev_corpus = dot_to_object(config, T_cfg["dev_corpus"])
train_corpus = dot_to_object(C, T_cfg["train_corpus"])
dev_corpus = dot_to_object(C, T_cfg["dev_corpus"])
batcher = T_cfg["batcher"]
train_logger = T_cfg["logger"]
before_to_disk = create_before_to_disk_callback(T_cfg["before_to_disk"])
@ -124,7 +126,7 @@ def train(
# Load pretrained tok2vec weights - cf. CLI command 'pretrain'
if weights_data is not None:
tok2vec_component = config["pretraining"]["component"]
tok2vec_component = C["pretraining"]["component"]
if tok2vec_component is None:
msg.fail(
f"To use pretrained tok2vec weights, [pretraining.component] "
@ -132,7 +134,7 @@ def train(
exits=1,
)
layer = nlp.get_pipe(tok2vec_component).model
tok2vec_layer = config["pretraining"]["layer"]
tok2vec_layer = C["pretraining"]["layer"]
if tok2vec_layer:
layer = layer.get_ref(tok2vec_layer)
layer.from_bytes(weights_data)

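dot_to_object simply walks the resolved dict by dot-path, which is what lets the [corpora] readers live anywhere in the config. A sketch (the reader is a stand-in):

    from spacy.util import dot_to_object

    def reader(nlp):  # stand-in for a registered corpus reader
        return []

    C = {
        "corpora": {"train": reader},
        "training": {"train_corpus": "corpora.train"},
    }
    assert dot_to_object(C, C["training"]["train_corpus"]) is reader
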
View File

@ -166,11 +166,10 @@ class Language:
self._components = []
self._disabled = set()
self.max_length = max_length
self.resolved = {}
# Create the default tokenizer from the default config
if not create_tokenizer:
tokenizer_cfg = {"tokenizer": self._config["nlp"]["tokenizer"]}
create_tokenizer = registry.make_from_config(tokenizer_cfg)["tokenizer"]
create_tokenizer = registry.resolve(tokenizer_cfg)["tokenizer"]
self.tokenizer = create_tokenizer(self)
def __init_subclass__(cls, **kwargs):
@ -467,7 +466,7 @@ class Language:
if "nlp" not in arg_names or "name" not in arg_names:
raise ValueError(Errors.E964.format(name=name))
# Officially register the factory so we can later call
# registry.make_from_config and refer to it in the config as
# registry.resolve and refer to it in the config as
# @factories = "spacy.Language.xyz". We use the class name here so
# different classes can have different factories.
registry.factories.register(internal_name, func=factory_func)
@ -650,8 +649,9 @@ class Language:
cfg = {factory_name: config}
# Resolve the config to construct the registered functions, and fill it
# separately to get the default values back
resolved, filled = registry.resolve(cfg, validate=validate)
filled = Config(filled[factory_name])
resolved = registry.resolve(cfg, validate=validate)
filled = registry.fill({"cfg": cfg[factory_name]}, validate=validate)["cfg"]
filled = Config(filled)
filled["factory"] = factory_name
filled.pop("@factories", None)
# Remove the extra values we added because we don't want to keep passing
@ -1518,15 +1518,14 @@ class Language:
config = util.copy_config(config)
orig_pipeline = config.pop("components", {})
config["components"] = {}
resolved, filled = registry.resolve(
config, validate=validate, schema=ConfigSchema
)
filled = registry.fill(config, validate=validate, schema=ConfigSchema)
filled["components"] = orig_pipeline
config["components"] = orig_pipeline
create_tokenizer = resolved["nlp"]["tokenizer"]
before_creation = resolved["nlp"]["before_creation"]
after_creation = resolved["nlp"]["after_creation"]
after_pipeline_creation = resolved["nlp"]["after_pipeline_creation"]
resolved_nlp = registry.resolve(filled["nlp"], validate=validate)
create_tokenizer = resolved_nlp["tokenizer"]
before_creation = resolved_nlp["before_creation"]
after_creation = resolved_nlp["after_creation"]
after_pipeline_creation = resolved_nlp["after_pipeline_creation"]
lang_cls = cls
if before_creation is not None:
lang_cls = before_creation(cls)
@ -1587,7 +1586,6 @@ class Language:
disabled_pipes = [*config["nlp"]["disabled"], *disable]
nlp._disabled = set(p for p in disabled_pipes if p not in exclude)
nlp.config = filled if auto_fill else config
nlp.resolved = resolved
if after_pipeline_creation is not None:
nlp = after_pipeline_creation(nlp)
if not isinstance(nlp, cls):

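The underlying Thinc API change driving this diff: registry.resolve used to return a (resolved, filled) tuple; from 8.0.0a40 it returns only the constructed objects, and registry.fill returns the filled config separately. A minimal sketch using a built-in Thinc layer:

    import thinc

    cfg = {"model": {"@layers": "Linear.v1", "nO": 2, "nI": 4}}
    resolved = thinc.registry.resolve(cfg)  # constructs the registered functions
    filled = thinc.registry.fill(cfg)       # fills in defaults without constructing
    model = resolved["model"]               # a constructed Linear layer
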
View File

@ -4,6 +4,7 @@ from enum import Enum
from pydantic import BaseModel, Field, ValidationError, validator
from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
from pydantic import root_validator
from thinc.config import Promise
from collections import defaultdict
from thinc.api import Optimizer
@ -16,10 +17,12 @@ if TYPE_CHECKING:
from .training import Example # noqa: F401
# fmt: off
ItemT = TypeVar("ItemT")
Batcher = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
Reader = Callable[["Language", str], Iterable["Example"]]
Logger = Callable[["Language"], Tuple[Callable[[Dict[str, Any]], None], Callable]]
Batcher = Union[Callable[[Iterable[ItemT]], Iterable[List[ItemT]]], Promise]
Reader = Union[Callable[["Language", str], Iterable["Example"]], Promise]
Logger = Union[Callable[["Language"], Tuple[Callable[[Dict[str, Any]], None], Callable]], Promise]
# fmt: on
def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]:
@ -292,6 +295,20 @@ class ConfigSchema(BaseModel):
arbitrary_types_allowed = True
class NlpSchema(BaseModel):
nlp: ConfigSchemaNlp
class TrainingSchema(BaseModel):
training: ConfigSchemaTraining
pretraining: Union[ConfigSchemaPretrain, ConfigSchemaPretrainEmpty] = {}
corpora: Dict[str, Reader]
class Config:
extra = "allow"
arbitrary_types_allowed = True
# Project config Schema

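The Promise union is needed because, in the un-resolved nlp config, these values are still registered-function blocks rather than callables. Roughly what a batcher looks like before resolution (parameter values assumed):

    from spacy.util import registry

    batcher_cfg = {"@batchers": "spacy.batch_by_words.v1", "size": 2000}
    batcher = registry.resolve({"batcher": batcher_cfg})["batcher"]  # now a callable
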
View File

@ -24,7 +24,7 @@ def test_doc_add_entities_set_ents_iob(en_vocab):
"update_with_oracle_cut_size": 100,
}
cfg = {"model": DEFAULT_NER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
model = registry.resolve(cfg, validate=True)["model"]
ner = EntityRecognizer(en_vocab, model, **config)
ner.begin_training(lambda: [_ner_example(ner)])
ner(doc)
@ -46,7 +46,7 @@ def test_ents_reset(en_vocab):
"update_with_oracle_cut_size": 100,
}
cfg = {"model": DEFAULT_NER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
model = registry.resolve(cfg, validate=True)["model"]
ner = EntityRecognizer(en_vocab, model, **config)
ner.begin_training(lambda: [_ner_example(ner)])
ner(doc)

View File

@ -23,7 +23,7 @@ def parser(vocab):
"update_with_oracle_cut_size": 100,
}
cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
model = registry.resolve(cfg, validate=True)["model"]
parser = DependencyParser(vocab, model, **config)
return parser
@ -82,7 +82,7 @@ def test_add_label_deserializes_correctly():
"update_with_oracle_cut_size": 100,
}
cfg = {"model": DEFAULT_NER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
model = registry.resolve(cfg, validate=True)["model"]
ner1 = EntityRecognizer(Vocab(), model, **config)
ner1.add_label("C")
ner1.add_label("B")
@ -111,7 +111,7 @@ def test_add_label_get_label(pipe_cls, n_moves, model_config):
splitting the move names.
"""
labels = ["A", "B", "C"]
model = registry.make_from_config({"model": model_config}, validate=True)["model"]
model = registry.resolve({"model": model_config}, validate=True)["model"]
config = {
"learn_tokens": False,
"min_action_freq": 30,

View File

@ -127,7 +127,7 @@ def test_get_oracle_actions():
"update_with_oracle_cut_size": 100,
}
cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
model = registry.resolve(cfg, validate=True)["model"]
parser = DependencyParser(doc.vocab, model, **config)
parser.moves.add_action(0, "")
parser.moves.add_action(1, "")

View File

@ -25,7 +25,7 @@ def arc_eager(vocab):
@pytest.fixture
def tok2vec():
cfg = {"model": DEFAULT_TOK2VEC_MODEL}
tok2vec = registry.make_from_config(cfg, validate=True)["model"]
tok2vec = registry.resolve(cfg, validate=True)["model"]
tok2vec.initialize()
return tok2vec
@ -38,14 +38,14 @@ def parser(vocab, arc_eager):
"update_with_oracle_cut_size": 100,
}
cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
model = registry.resolve(cfg, validate=True)["model"]
return Parser(vocab, model, moves=arc_eager, **config)
@pytest.fixture
def model(arc_eager, tok2vec, vocab):
cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
model = registry.resolve(cfg, validate=True)["model"]
model.attrs["resize_output"](model, arc_eager.n_moves)
model.initialize()
return model
@ -72,7 +72,7 @@ def test_build_model(parser, vocab):
"update_with_oracle_cut_size": 100,
}
cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
model = registry.resolve(cfg, validate=True)["model"]
parser.model = Parser(vocab, model=model, moves=parser.moves, **config).model
assert parser.model is not None

View File

@ -28,7 +28,7 @@ def parser(vocab):
"update_with_oracle_cut_size": 100,
}
cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
model = registry.resolve(cfg, validate=True)["model"]
parser = DependencyParser(vocab, model, **config)
parser.cfg["token_vector_width"] = 4
parser.cfg["hidden_width"] = 32

View File

@ -139,7 +139,7 @@ TRAIN_DATA = [
def test_tok2vec_listener():
orig_config = Config().from_str(cfg_string)
nlp, config = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
assert nlp.pipe_names == ["tok2vec", "tagger"]
tagger = nlp.get_pipe("tagger")
tok2vec = nlp.get_pipe("tok2vec")
@ -173,7 +173,7 @@ def test_tok2vec_listener():
def test_tok2vec_listener_callback():
orig_config = Config().from_str(cfg_string)
nlp, config = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
assert nlp.pipe_names == ["tok2vec", "tagger"]
tagger = nlp.get_pipe("tagger")
tok2vec = nlp.get_pipe("tok2vec")

View File

@ -195,7 +195,7 @@ def test_issue3345():
"update_with_oracle_cut_size": 100,
}
cfg = {"model": DEFAULT_NER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
model = registry.resolve(cfg, validate=True)["model"]
ner = EntityRecognizer(doc.vocab, model, **config)
# Add the OUT action. I wouldn't have thought this would be necessary...
ner.moves.add_action(5, "")

View File

@ -264,9 +264,7 @@ def test_issue3830_no_subtok():
"min_action_freq": 30,
"update_with_oracle_cut_size": 100,
}
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)[
"model"
]
model = registry.resolve({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
parser = DependencyParser(Vocab(), model, **config)
parser.add_label("nsubj")
assert "subtok" not in parser.labels
@ -281,9 +279,7 @@ def test_issue3830_with_subtok():
"min_action_freq": 30,
"update_with_oracle_cut_size": 100,
}
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)[
"model"
]
model = registry.resolve({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
parser = DependencyParser(Vocab(), model, **config)
parser.add_label("nsubj")
assert "subtok" not in parser.labels

View File

@ -108,8 +108,8 @@ def my_parser():
def test_create_nlp_from_config():
config = Config().from_str(nlp_config_string)
with pytest.raises(ConfigValidationError):
nlp, _ = load_model_from_config(config, auto_fill=False)
nlp, resolved = load_model_from_config(config, auto_fill=True)
load_model_from_config(config, auto_fill=False)
nlp = load_model_from_config(config, auto_fill=True)
assert nlp.config["training"]["batcher"]["size"] == 666
assert len(nlp.config["training"]) > 1
assert nlp.pipe_names == ["tok2vec", "tagger"]
@ -136,7 +136,7 @@ def test_create_nlp_from_config_multiple_instances():
"tagger2": config["components"]["tagger"],
}
config["nlp"]["pipeline"] = list(config["components"].keys())
nlp, _ = load_model_from_config(config, auto_fill=True)
nlp = load_model_from_config(config, auto_fill=True)
assert nlp.pipe_names == ["t2v", "tagger1", "tagger2"]
assert nlp.get_pipe_meta("t2v").factory == "tok2vec"
assert nlp.get_pipe_meta("tagger1").factory == "tagger"
@ -150,7 +150,7 @@ def test_create_nlp_from_config_multiple_instances():
def test_serialize_nlp():
""" Create a custom nlp pipeline from config and ensure it serializes it correctly """
nlp_config = Config().from_str(nlp_config_string)
nlp, _ = load_model_from_config(nlp_config, auto_fill=True)
nlp = load_model_from_config(nlp_config, auto_fill=True)
nlp.get_pipe("tagger").add_label("A")
nlp.begin_training()
assert "tok2vec" in nlp.pipe_names
@ -209,7 +209,7 @@ def test_config_nlp_roundtrip():
nlp = English()
nlp.add_pipe("entity_ruler")
nlp.add_pipe("ner")
new_nlp, new_config = load_model_from_config(nlp.config, auto_fill=False)
new_nlp = load_model_from_config(nlp.config, auto_fill=False)
assert new_nlp.config == nlp.config
assert new_nlp.pipe_names == nlp.pipe_names
assert new_nlp._pipe_configs == nlp._pipe_configs
@ -280,12 +280,12 @@ def test_config_overrides():
overrides_dot = {"nlp.lang": "de", "nlp.pipeline": ["tagger"]}
# load_model from config with overrides passed directly to Config
config = Config().from_str(nlp_config_string, overrides=overrides_dot)
nlp, _ = load_model_from_config(config, auto_fill=True)
nlp = load_model_from_config(config, auto_fill=True)
assert isinstance(nlp, German)
assert nlp.pipe_names == ["tagger"]
# Serialized roundtrip with config passed in
base_config = Config().from_str(nlp_config_string)
base_nlp, _ = load_model_from_config(base_config, auto_fill=True)
base_nlp = load_model_from_config(base_config, auto_fill=True)
assert isinstance(base_nlp, English)
assert base_nlp.pipe_names == ["tok2vec", "tagger"]
with make_tempdir() as d:
@ -328,7 +328,7 @@ def test_config_optional_sections():
config = Config().from_str(nlp_config_string)
config = DEFAULT_CONFIG.merge(config)
assert "pretraining" not in config
filled = registry.fill_config(config, schema=ConfigSchema, validate=False)
filled = registry.fill(config, schema=ConfigSchema, validate=False)
# Make sure that optional "pretraining" block doesn't default to None,
# which would (rightly) cause error because it'd result in a top-level
# key that's not a section (dict). Note that the following roundtrip is
@ -341,7 +341,7 @@ def test_config_auto_fill_extra_fields():
config = Config({"nlp": {"lang": "en"}, "training": {}})
assert load_model_from_config(config, auto_fill=True)
config = Config({"nlp": {"lang": "en"}, "training": {"extra": "hello"}})
nlp, _ = load_model_from_config(config, auto_fill=True, validate=False)
nlp = load_model_from_config(config, auto_fill=True, validate=False)
assert "extra" not in nlp.config["training"]
# Make sure the config generated is valid
load_model_from_config(nlp.config)

View File

@ -23,7 +23,7 @@ def parser(en_vocab):
"update_with_oracle_cut_size": 100,
}
cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
model = registry.resolve(cfg, validate=True)["model"]
parser = DependencyParser(en_vocab, model, **config)
parser.add_label("nsubj")
return parser
@ -37,7 +37,7 @@ def blank_parser(en_vocab):
"update_with_oracle_cut_size": 100,
}
cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
model = registry.resolve(cfg, validate=True)["model"]
parser = DependencyParser(en_vocab, model, **config)
return parser
@ -45,7 +45,7 @@ def blank_parser(en_vocab):
@pytest.fixture
def taggers(en_vocab):
cfg = {"model": DEFAULT_TAGGER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
model = registry.resolve(cfg, validate=True)["model"]
tagger1 = Tagger(en_vocab, model)
tagger2 = Tagger(en_vocab, model)
return tagger1, tagger2
@ -59,7 +59,7 @@ def test_serialize_parser_roundtrip_bytes(en_vocab, Parser):
"update_with_oracle_cut_size": 100,
}
cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
model = registry.resolve(cfg, validate=True)["model"]
parser = Parser(en_vocab, model, **config)
new_parser = Parser(en_vocab, model, **config)
new_parser = new_parser.from_bytes(parser.to_bytes(exclude=["vocab"]))
@ -77,7 +77,7 @@ def test_serialize_parser_roundtrip_disk(en_vocab, Parser):
"update_with_oracle_cut_size": 100,
}
cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
model = registry.resolve(cfg, validate=True)["model"]
parser = Parser(en_vocab, model, **config)
with make_tempdir() as d:
file_path = d / "parser"
@ -111,7 +111,7 @@ def test_serialize_tagger_roundtrip_bytes(en_vocab, taggers):
tagger1 = tagger1.from_bytes(tagger1_b)
assert tagger1.to_bytes() == tagger1_b
cfg = {"model": DEFAULT_TAGGER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
model = registry.resolve(cfg, validate=True)["model"]
new_tagger1 = Tagger(en_vocab, model).from_bytes(tagger1_b)
new_tagger1_b = new_tagger1.to_bytes()
assert len(new_tagger1_b) == len(tagger1_b)
@ -126,7 +126,7 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
tagger1.to_disk(file_path1)
tagger2.to_disk(file_path2)
cfg = {"model": DEFAULT_TAGGER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
model = registry.resolve(cfg, validate=True)["model"]
tagger1_d = Tagger(en_vocab, model).from_disk(file_path1)
tagger2_d = Tagger(en_vocab, model).from_disk(file_path2)
assert tagger1_d.to_bytes() == tagger2_d.to_bytes()
@ -135,7 +135,7 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
def test_serialize_textcat_empty(en_vocab):
# See issue #1105
cfg = {"model": DEFAULT_TEXTCAT_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
model = registry.resolve(cfg, validate=True)["model"]
textcat = TextCategorizer(
en_vocab,
model,
@ -149,7 +149,7 @@ def test_serialize_textcat_empty(en_vocab):
@pytest.mark.parametrize("Parser", test_parsers)
def test_serialize_pipe_exclude(en_vocab, Parser):
cfg = {"model": DEFAULT_PARSER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
model = registry.resolve(cfg, validate=True)["model"]
config = {
"learn_tokens": False,
"min_action_freq": 0,
@ -176,7 +176,7 @@ def test_serialize_pipe_exclude(en_vocab, Parser):
def test_serialize_sentencerecognizer(en_vocab):
cfg = {"model": DEFAULT_SENTER_MODEL}
model = registry.make_from_config(cfg, validate=True)["model"]
model = registry.resolve(cfg, validate=True)["model"]
sr = SentenceRecognizer(en_vocab, model)
sr_b = sr.to_bytes()
sr_d = SentenceRecognizer(en_vocab, model).from_bytes(sr_b)

View File

@ -82,10 +82,10 @@ def test_util_dot_section():
no_output_layer = false
"""
nlp_config = Config().from_str(cfg_string)
en_nlp, en_config = util.load_model_from_config(nlp_config, auto_fill=True)
en_nlp = util.load_model_from_config(nlp_config, auto_fill=True)
default_config = Config().from_disk(DEFAULT_CONFIG_PATH)
default_config["nlp"]["lang"] = "nl"
nl_nlp, nl_config = util.load_model_from_config(default_config, auto_fill=True)
nl_nlp = util.load_model_from_config(default_config, auto_fill=True)
# Test that creation went OK
assert isinstance(en_nlp, English)
assert isinstance(nl_nlp, Dutch)
@ -94,14 +94,15 @@ def test_util_dot_section():
# not exclusive_classes
assert en_nlp.get_pipe("textcat").model.attrs["multi_label"] is False
# Test that default values got overwritten
assert en_config["nlp"]["pipeline"] == ["textcat"]
assert nl_config["nlp"]["pipeline"] == [] # default value []
assert en_nlp.config["nlp"]["pipeline"] == ["textcat"]
assert nl_nlp.config["nlp"]["pipeline"] == [] # default value []
# Test proper functioning of 'dot_to_object'
with pytest.raises(KeyError):
dot_to_object(en_config, "nlp.pipeline.tagger")
dot_to_object(en_nlp.config, "nlp.pipeline.tagger")
with pytest.raises(KeyError):
dot_to_object(en_config, "nlp.unknownattribute")
assert isinstance(dot_to_object(nl_config, "training.optimizer"), Optimizer)
dot_to_object(en_nlp.config, "nlp.unknownattribute")
resolved = util.resolve_training_config(nl_nlp.config)
assert isinstance(dot_to_object(resolved, "training.optimizer"), Optimizer)
def test_simple_frozen_list():

View File

@ -3,6 +3,7 @@ import pytest
from thinc.api import Config
from spacy import Language
from spacy.util import load_model_from_config, registry, dot_to_object
from spacy.util import resolve_training_config
from spacy.training import Example
@ -37,8 +38,8 @@ def test_readers():
return {"train": reader, "dev": reader, "extra": reader, "something": reader}
config = Config().from_str(config_string)
nlp, resolved = load_model_from_config(config, auto_fill=True)
nlp = load_model_from_config(config, auto_fill=True)
resolved = resolve_training_config(nlp.config)
train_corpus = dot_to_object(resolved, resolved["training"]["train_corpus"])
assert isinstance(train_corpus, Callable)
optimizer = resolved["training"]["optimizer"]
@ -87,8 +88,8 @@ def test_cat_readers(reader, additional_config):
config = Config().from_str(nlp_config_string)
config["corpora"]["@readers"] = reader
config["corpora"].update(additional_config)
nlp, resolved = load_model_from_config(config, auto_fill=True)
nlp = load_model_from_config(config, auto_fill=True)
resolved = resolve_training_config(nlp.config)
train_corpus = dot_to_object(resolved, resolved["training"]["train_corpus"])
optimizer = resolved["training"]["optimizer"]
# simulate a training loop

View File

@ -86,7 +86,7 @@ class registry(thinc.registry):
# spacy_factories entry point. This registry only exists so we can easily
# load them via the entry points. The "true" factories are added via the
# Language.factory decorator (in the spaCy code base and user code) and those
# are the factories used to initialize components via registry.make_from_config.
# are the factories used to initialize components via registry.resolve.
_entry_point_factories = catalogue.create("spacy", "factories", entry_points=True)
factories = catalogue.create("spacy", "internal_factories")
# This is mostly used to get a list of all installed models in the current
@ -351,9 +351,7 @@ def load_model_from_path(
meta = get_model_meta(model_path)
config_path = model_path / "config.cfg"
config = load_config(config_path, overrides=dict_to_dot(config))
nlp, _ = load_model_from_config(
config, vocab=vocab, disable=disable, exclude=exclude
)
nlp = load_model_from_config(config, vocab=vocab, disable=disable, exclude=exclude)
return nlp.from_disk(model_path, exclude=exclude)
@ -365,7 +363,7 @@ def load_model_from_config(
exclude: Iterable[str] = SimpleFrozenList(),
auto_fill: bool = False,
validate: bool = True,
) -> Tuple["Language", Config]:
) -> "Language":
"""Create an nlp object from a config. Expects the full config file including
a section "nlp" containing the settings for the nlp object.
@ -398,7 +396,31 @@ def load_model_from_config(
auto_fill=auto_fill,
validate=validate,
)
return nlp, nlp.resolved
return nlp
def resolve_training_config(
config: Config,
exclude: Iterable[str] = ("nlp", "components"),
validate: bool = True,
) -> Dict[str, Any]:
"""Resolve the config sections relevant for trainig and create all objects.
Mostly used in the CLI to separate training config (not resolved by default
because not runtime-relevant an nlp object should load fine even if it's
[training] block refers to functions that are not available etc.).
config (Config): The config to resolve.
exclude (Iterable[str]): The config blocks to exclude. Those blocks won't
be available in the final resolved config.
validate (bool): Whether to validate the config.
RETURNS (Dict[str, Any]): The resolved config.
"""
config = config.copy()
for key in exclude:
if key in config:
config.pop(key)
return registry.resolve(config, validate=validate)
def load_model_from_init_py(
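Putting the pieces together, downstream usage of the new helpers looks roughly like this (a sketch; the config path is assumed):

    from thinc.api import Config
    from spacy.util import load_model_from_config, resolve_training_config

    config = Config().from_disk("config.cfg", interpolate=False)
    nlp = load_model_from_config(config, auto_fill=True)
    C = resolve_training_config(nlp.config)  # [training], [corpora] etc. as objects
    optimizer = C["training"]["optimizer"]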