Fix small issues, resolve_dot_names and debug model

Ines Montani 2020-09-29 20:38:35 +02:00
parent a4da3120b4
commit 2be80379ec
8 changed files with 51 additions and 56 deletions

View File

@@ -7,6 +7,8 @@ import typer
 from ._util import Arg, Opt, show_validation_error, parse_config_overrides
 from ._util import import_code, debug_cli
+from ..schemas import ConfigSchemaTraining
+from ..util import registry
 from .. import util
@@ -52,8 +54,10 @@ def debug_config(
     with show_validation_error(config_path):
         config = util.load_config(config_path, overrides=overrides)
         nlp = util.load_model_from_config(config)
-        dot_names = ["training.dev_corpus", "training.train_corpus"]
-        util.resolve_dot_names(nlp.config, dot_names)
+        config = nlp.config.interpolate()
+        T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
+        dot_names = [T["train_corpus"], T["dev_corpus"]]
+        util.resolve_dot_names(config, dot_names)
     msg.good("Config is valid")
     if show_vars:
         variables = get_variables(config)

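Aside (not part of the commit): the new debug_config pattern shown in isolation, as a sketch. It assumes a loaded spaCy v3 pipeline whose interpolated config fills in the [corpora] readers that [training.train_corpus] and [training.dev_corpus] point to; the helper name get_corpora is made up for illustration.

# Sketch of the pattern above: resolve the [training] block against its
# schema, then follow the corpus dot names it contains.
from spacy.schemas import ConfigSchemaTraining
from spacy.util import registry, resolve_dot_names


def get_corpora(nlp):
    config = nlp.config.interpolate()
    T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
    dot_names = [T["train_corpus"], T["dev_corpus"]]
    # Returns the resolved reader callables the dot names point at
    return resolve_dot_names(config, dot_names)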
View File

@@ -2,7 +2,7 @@ from typing import Dict, Any, Optional, Iterable
 from pathlib import Path
 from spacy.training import Example
-from spacy.util import dot_to_object
+from spacy.util import resolve_dot_names
 from wasabi import msg
 from thinc.api import fix_random_seed, set_dropout_rate, Adam
 from thinc.api import Model, data_validation, set_gpu_allocator
@@ -15,7 +15,10 @@ from ..util import registry
 from .. import util


-@debug_cli.command("model")
+@debug_cli.command(
+    "model",
+    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
+)
 def debug_model_cli(
     # fmt: off
     ctx: typer.Context,  # This is only used to read additional arguments
@@ -57,15 +60,14 @@ def debug_model_cli(
     raw_config = util.load_config(
         config_path, overrides=config_overrides, interpolate=False
     )
-    config = raw_config.iterpolate()
+    config = raw_config.interpolate()
     allocator = config["training"]["gpu_allocator"]
     if use_gpu >= 0 and allocator:
         set_gpu_allocator(allocator)
     with show_validation_error(config_path):
         nlp = util.load_model_from_config(raw_config)
-        T = registry.resolve(
-            nlp.config.interpolate()["training"], schema=ConfigSchemaTraining
-        )
+        config = nlp.config.interpolate()
+        T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
     seed = T["seed"]
     if seed is not None:
         msg.info(f"Fixing random seed: {seed}")
@@ -77,11 +79,16 @@ def debug_model_cli(
             exits=1,
         )
     model = pipe.model
-    debug_model(T, nlp, model, print_settings=print_settings)
+    debug_model(config, T, nlp, model, print_settings=print_settings)


 def debug_model(
-    config, nlp, model: Model, *, print_settings: Optional[Dict[str, Any]] = None
+    config,
+    resolved_train_config,
+    nlp,
+    model: Model,
+    *,
+    print_settings: Optional[Dict[str, Any]] = None,
 ):
     if not isinstance(model, Model):
         msg.fail(
@@ -102,13 +109,16 @@ def debug_model(
     # The output vector might differ from the official type of the output layer
     with data_validation(False):
         try:
-            train_corpus = dot_to_object(config, config["training"]["train_corpus"])
-            nlp.initialize(lambda: train_corpus(nlp))
+            dot_names = [resolved_train_config["train_corpus"]]
+            with show_validation_error():
+                (train_corpus,) = resolve_dot_names(config, dot_names)
+                nlp.initialize(lambda: train_corpus(nlp))
             msg.info("Initialized the model with the training corpus.")
         except ValueError:
             try:
                 _set_output_dim(nO=7, model=model)
-                nlp.initialize(lambda: [Example.from_dict(x, {}) for x in X])
+                with show_validation_error():
+                    nlp.initialize(lambda: [Example.from_dict(x, {}) for x in X])
                 msg.info("Initialized the model with dummy data.")
             except Exception:
                 msg.fail(

View File

@@ -389,14 +389,12 @@ class ConfigSchema(BaseModel):
         arbitrary_types_allowed = True


-class TrainingSchema(BaseModel):
-    training: ConfigSchemaTraining
-    pretraining: Union[ConfigSchemaPretrain, ConfigSchemaPretrainEmpty] = {}
-    corpora: Dict[str, Reader]
-
-    class Config:
-        extra = "allow"
-        arbitrary_types_allowed = True
+CONFIG_SCHEMAS = {
+    "nlp": ConfigSchemaNlp,
+    "training": ConfigSchemaTraining,
+    "pretraining": ConfigSchemaPretrain,
+    "initialize": ConfigSchemaInit,
+}


 # Project config Schema

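Aside (not part of the commit): a hypothetical helper sketching how the new CONFIG_SCHEMAS mapping could be used to validate one top-level section by name. The function resolve_section and its fallback behaviour are assumptions, not spaCy API.

# Hypothetical helper: look up a section's schema in CONFIG_SCHEMAS and
# resolve that config block against it.
from spacy.schemas import CONFIG_SCHEMAS
from spacy.util import registry


def resolve_section(config, section):
    schema = CONFIG_SCHEMAS.get(section)
    if schema is None:
        # Sections without a registered schema are resolved unvalidated
        return registry.resolve(config[section])
    return registry.resolve(config[section], schema=schema)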
View File

@@ -128,10 +128,10 @@ def test_resolve_dot_names():
         "training": {"optimizer": {"@optimizers": "Adam.v1"}},
         "foo": {"bar": "training.optimizer", "baz": "training.xyz"},
     }
-    result = util.resolve_dot_names(config, ["foo.bar"])
+    result = util.resolve_dot_names(config, ["training.optimizer"])
     assert isinstance(result[0], Optimizer)
     with pytest.raises(ConfigValidationError) as e:
-        util.resolve_dot_names(config, ["foo.baz", "foo.bar"])
+        util.resolve_dot_names(config, ["training.xyz", "training.optimizer"])
     errors = e.value.errors
     assert len(errors) == 1
     assert errors[0]["loc"] == ["training", "xyz"]

View File

@@ -39,12 +39,12 @@ def test_readers():
     config = Config().from_str(config_string)
     nlp = load_model_from_config(config, auto_fill=True)
-    dot_names = ["training.train_corpus", "training.dev_corpus"]
-    train_corpus, dev_corpus = resolve_dot_names(nlp.config, dot_names)
-    assert isinstance(train_corpus, Callable)
     T = registry.resolve(
         nlp.config.interpolate()["training"], schema=ConfigSchemaTraining
     )
+    dot_names = [T["train_corpus"], T["dev_corpus"]]
+    train_corpus, dev_corpus = resolve_dot_names(nlp.config, dot_names)
+    assert isinstance(train_corpus, Callable)
     optimizer = T["optimizer"]
     # simulate a training loop
     nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer)
@@ -92,11 +92,11 @@ def test_cat_readers(reader, additional_config):
     config["corpora"]["@readers"] = reader
     config["corpora"].update(additional_config)
     nlp = load_model_from_config(config, auto_fill=True)
-    dot_names = ["training.train_corpus", "training.dev_corpus"]
-    train_corpus, dev_corpus = resolve_dot_names(nlp.config, dot_names)
     T = registry.resolve(
         nlp.config["training"].interpolate(), schema=ConfigSchemaTraining
     )
+    dot_names = [T["train_corpus"], T["dev_corpus"]]
+    train_corpus, dev_corpus = resolve_dot_names(nlp.config, dot_names)
     optimizer = T["optimizer"]
     # simulate a training loop
     nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer)

View File

@@ -130,12 +130,12 @@ def init_tok2vec(
     init_tok2vec = ensure_path(I["init_tok2vec"])
     if init_tok2vec is not None:
         if P["objective"].get("type") == "vectors" and not I["vectors"]:
-            err = 'need initialize.vocab.vectors if pretraining.objective.type is "vectors"'
-            errors = [{"loc": ["initialize", "vocab"], "msg": err}]
+            err = 'need initialize.vectors if pretraining.objective.type is "vectors"'
+            errors = [{"loc": ["initialize"], "msg": err}]
             raise ConfigValidationError(config=nlp.config, errors=errors)
         if not init_tok2vec.exists():
             err = f"can't find pretrained tok2vec: {init_tok2vec}"
-            errors = [{"loc": ["initialize", "vocab", "init_tok2vec"], "msg": err}]
+            errors = [{"loc": ["initialize", "init_tok2vec"], "msg": err}]
             raise ConfigValidationError(config=nlp.config, errors=errors)
         with init_tok2vec.open("rb") as file_:
             weights_data = file_.read()

View File

@@ -29,9 +29,7 @@ def train(
     output_path (Path): Optional output path to save trained model to.
     use_gpu (int): Whether to train on GPU. Make sure to call require_gpu
         before calling this function.
-    logger (Callable[[Any], Any]): Optional logger exposing the methods info,
-        error, debug and warn. Defaults to regular spaCy logger but can be
-        swapped for CLI logger.
+    silent (bool): Whether to pretty-print outputs.
     RETURNS (Path / None): The path to the final exported model.
     """
     msg = Printer(no_print=silent)

View File

@@ -392,7 +392,6 @@ def resolve_dot_names(config: Config, dot_names: List[Optional[str]]) -> Tuple[A
     we could find the lowest part of the tree.
     """
     # TODO: include schema?
-    # TODO: clean this up and avoid duplication
     resolved = {}
     output = []
     errors = []
@@ -403,34 +402,20 @@ def resolve_dot_names(config: Config, dot_names: List[Optional[str]]) -> Tuple[A
             section = name.split(".")[0]
             # We want to avoid resolving the same thing twice
             if section not in resolved:
-                resolved[section] = registry.resolve(config[section])
+                if registry.is_promise(config[section]):
+                    # Otherwise we can't resolve [corpus] if it's a promise
+                    result = registry.resolve({"config": config[section]})["config"]
+                else:
+                    result = registry.resolve(config[section])
+                resolved[section] = result
             try:
                 output.append(dot_to_object(resolved, name))
             except KeyError:
                 msg = f"not a valid section reference: {name}"
                 errors.append({"loc": name.split("."), "msg": msg})
-    objects = []
-    for ref in output:
-        if not isinstance(ref, str):
-            objects.append(ref)
-            continue
-        section = ref.split(".")[0]
-        # We want to avoid resolving the same thing twice
-        if section not in resolved:
-            if registry.is_promise(config[section]):
-                # Otherwise we can't resolve [corpus] if it's a promise
-                result = registry.resolve({"config": config[section]})["config"]
-            else:
-                result = registry.resolve(config[section])
-            resolved[section] = result
-        try:
-            objects.append(dot_to_object(resolved, ref))
-        except KeyError:
-            msg = f"not a valid section reference: {name}"
-            errors.append({"loc": ref.split("."), "msg": msg})
     if errors:
         raise ConfigValidationError(config=config, errors=errors)
-    return tuple(objects)
+    return tuple(output)


 def load_model_from_init_py(
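Usage note (not part of the commit): after this change resolve_dot_names makes a single pass, so a dot name must point directly at a key inside a resolvable section, as in the updated test above. A minimal sketch:

# The "training" section is resolved once, then "training.optimizer" is
# looked up in the resolved result.
from thinc.api import Optimizer
from spacy import util

config = {"training": {"optimizer": {"@optimizers": "Adam.v1"}}}
(optimizer,) = util.resolve_dot_names(config, ["training.optimizer"])
assert isinstance(optimizer, Optimizer)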