mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 09:14:32 +03:00
Fix small issues, resolve_dot_names and debug model
This commit is contained in:
parent
a4da3120b4
commit
2be80379ec
|
@ -7,6 +7,8 @@ import typer
|
|||
|
||||
from ._util import Arg, Opt, show_validation_error, parse_config_overrides
|
||||
from ._util import import_code, debug_cli
|
||||
from ..schemas import ConfigSchemaTraining
|
||||
from ..util import registry
|
||||
from .. import util
|
||||
|
||||
|
||||
|
@ -52,8 +54,10 @@ def debug_config(
|
|||
with show_validation_error(config_path):
|
||||
config = util.load_config(config_path, overrides=overrides)
|
||||
nlp = util.load_model_from_config(config)
|
||||
dot_names = ["training.dev_corpus", "training.train_corpus"]
|
||||
util.resolve_dot_names(nlp.config, dot_names)
|
||||
config = nlp.config.interpolate()
|
||||
T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
|
||||
dot_names = [T["train_corpus"], T["dev_corpus"]]
|
||||
util.resolve_dot_names(config, dot_names)
|
||||
msg.good("Config is valid")
|
||||
if show_vars:
|
||||
variables = get_variables(config)
|
||||
|
|
|
@ -2,7 +2,7 @@ from typing import Dict, Any, Optional, Iterable
|
|||
from pathlib import Path
|
||||
|
||||
from spacy.training import Example
|
||||
from spacy.util import dot_to_object
|
||||
from spacy.util import resolve_dot_names
|
||||
from wasabi import msg
|
||||
from thinc.api import fix_random_seed, set_dropout_rate, Adam
|
||||
from thinc.api import Model, data_validation, set_gpu_allocator
|
||||
|
@ -15,7 +15,10 @@ from ..util import registry
|
|||
from .. import util
|
||||
|
||||
|
||||
@debug_cli.command("model")
|
||||
@debug_cli.command(
|
||||
"model",
|
||||
context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
|
||||
)
|
||||
def debug_model_cli(
|
||||
# fmt: off
|
||||
ctx: typer.Context, # This is only used to read additional arguments
|
||||
|
@ -57,15 +60,14 @@ def debug_model_cli(
|
|||
raw_config = util.load_config(
|
||||
config_path, overrides=config_overrides, interpolate=False
|
||||
)
|
||||
config = raw_config.iterpolate()
|
||||
config = raw_config.interpolate()
|
||||
allocator = config["training"]["gpu_allocator"]
|
||||
if use_gpu >= 0 and allocator:
|
||||
set_gpu_allocator(allocator)
|
||||
with show_validation_error(config_path):
|
||||
nlp = util.load_model_from_config(raw_config)
|
||||
T = registry.resolve(
|
||||
nlp.config.interpolate()["training"], schema=ConfigSchemaTraining
|
||||
)
|
||||
config = nlp.config.interpolate()
|
||||
T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
|
||||
seed = T["seed"]
|
||||
if seed is not None:
|
||||
msg.info(f"Fixing random seed: {seed}")
|
||||
|
@ -77,11 +79,16 @@ def debug_model_cli(
|
|||
exits=1,
|
||||
)
|
||||
model = pipe.model
|
||||
debug_model(T, nlp, model, print_settings=print_settings)
|
||||
debug_model(config, T, nlp, model, print_settings=print_settings)
|
||||
|
||||
|
||||
def debug_model(
|
||||
config, nlp, model: Model, *, print_settings: Optional[Dict[str, Any]] = None
|
||||
config,
|
||||
resolved_train_config,
|
||||
nlp,
|
||||
model: Model,
|
||||
*,
|
||||
print_settings: Optional[Dict[str, Any]] = None,
|
||||
):
|
||||
if not isinstance(model, Model):
|
||||
msg.fail(
|
||||
|
@ -102,12 +109,15 @@ def debug_model(
|
|||
# The output vector might differ from the official type of the output layer
|
||||
with data_validation(False):
|
||||
try:
|
||||
train_corpus = dot_to_object(config, config["training"]["train_corpus"])
|
||||
dot_names = [resolved_train_config["train_corpus"]]
|
||||
with show_validation_error():
|
||||
(train_corpus,) = resolve_dot_names(config, dot_names)
|
||||
nlp.initialize(lambda: train_corpus(nlp))
|
||||
msg.info("Initialized the model with the training corpus.")
|
||||
except ValueError:
|
||||
try:
|
||||
_set_output_dim(nO=7, model=model)
|
||||
with show_validation_error():
|
||||
nlp.initialize(lambda: [Example.from_dict(x, {}) for x in X])
|
||||
msg.info("Initialized the model with dummy data.")
|
||||
except Exception:
|
||||
|
|
|
@ -389,14 +389,12 @@ class ConfigSchema(BaseModel):
|
|||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class TrainingSchema(BaseModel):
|
||||
training: ConfigSchemaTraining
|
||||
pretraining: Union[ConfigSchemaPretrain, ConfigSchemaPretrainEmpty] = {}
|
||||
corpora: Dict[str, Reader]
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
arbitrary_types_allowed = True
|
||||
CONFIG_SCHEMAS = {
|
||||
"nlp": ConfigSchemaNlp,
|
||||
"training": ConfigSchemaTraining,
|
||||
"pretraining": ConfigSchemaPretrain,
|
||||
"initialize": ConfigSchemaInit,
|
||||
}
|
||||
|
||||
|
||||
# Project config Schema
|
||||
|
|
|
@ -128,10 +128,10 @@ def test_resolve_dot_names():
|
|||
"training": {"optimizer": {"@optimizers": "Adam.v1"}},
|
||||
"foo": {"bar": "training.optimizer", "baz": "training.xyz"},
|
||||
}
|
||||
result = util.resolve_dot_names(config, ["foo.bar"])
|
||||
result = util.resolve_dot_names(config, ["training.optimizer"])
|
||||
assert isinstance(result[0], Optimizer)
|
||||
with pytest.raises(ConfigValidationError) as e:
|
||||
util.resolve_dot_names(config, ["foo.baz", "foo.bar"])
|
||||
util.resolve_dot_names(config, ["training.xyz", "training.optimizer"])
|
||||
errors = e.value.errors
|
||||
assert len(errors) == 1
|
||||
assert errors[0]["loc"] == ["training", "xyz"]
|
||||
|
|
|
@ -39,12 +39,12 @@ def test_readers():
|
|||
|
||||
config = Config().from_str(config_string)
|
||||
nlp = load_model_from_config(config, auto_fill=True)
|
||||
dot_names = ["training.train_corpus", "training.dev_corpus"]
|
||||
train_corpus, dev_corpus = resolve_dot_names(nlp.config, dot_names)
|
||||
assert isinstance(train_corpus, Callable)
|
||||
T = registry.resolve(
|
||||
nlp.config.interpolate()["training"], schema=ConfigSchemaTraining
|
||||
)
|
||||
dot_names = [T["train_corpus"], T["dev_corpus"]]
|
||||
train_corpus, dev_corpus = resolve_dot_names(nlp.config, dot_names)
|
||||
assert isinstance(train_corpus, Callable)
|
||||
optimizer = T["optimizer"]
|
||||
# simulate a training loop
|
||||
nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer)
|
||||
|
@ -92,11 +92,11 @@ def test_cat_readers(reader, additional_config):
|
|||
config["corpora"]["@readers"] = reader
|
||||
config["corpora"].update(additional_config)
|
||||
nlp = load_model_from_config(config, auto_fill=True)
|
||||
dot_names = ["training.train_corpus", "training.dev_corpus"]
|
||||
train_corpus, dev_corpus = resolve_dot_names(nlp.config, dot_names)
|
||||
T = registry.resolve(
|
||||
nlp.config["training"].interpolate(), schema=ConfigSchemaTraining
|
||||
)
|
||||
dot_names = [T["train_corpus"], T["dev_corpus"]]
|
||||
train_corpus, dev_corpus = resolve_dot_names(nlp.config, dot_names)
|
||||
optimizer = T["optimizer"]
|
||||
# simulate a training loop
|
||||
nlp.initialize(lambda: train_corpus(nlp), sgd=optimizer)
|
||||
|
|
|
@ -130,12 +130,12 @@ def init_tok2vec(
|
|||
init_tok2vec = ensure_path(I["init_tok2vec"])
|
||||
if init_tok2vec is not None:
|
||||
if P["objective"].get("type") == "vectors" and not I["vectors"]:
|
||||
err = 'need initialize.vocab.vectors if pretraining.objective.type is "vectors"'
|
||||
errors = [{"loc": ["initialize", "vocab"], "msg": err}]
|
||||
err = 'need initialize.vectors if pretraining.objective.type is "vectors"'
|
||||
errors = [{"loc": ["initialize"], "msg": err}]
|
||||
raise ConfigValidationError(config=nlp.config, errors=errors)
|
||||
if not init_tok2vec.exists():
|
||||
err = f"can't find pretrained tok2vec: {init_tok2vec}"
|
||||
errors = [{"loc": ["initialize", "vocab", "init_tok2vec"], "msg": err}]
|
||||
errors = [{"loc": ["initialize", "init_tok2vec"], "msg": err}]
|
||||
raise ConfigValidationError(config=nlp.config, errors=errors)
|
||||
with init_tok2vec.open("rb") as file_:
|
||||
weights_data = file_.read()
|
||||
|
|
|
@ -29,9 +29,7 @@ def train(
|
|||
output_path (Path): Optional output path to save trained model to.
|
||||
use_gpu (int): Whether to train on GPU. Make sure to call require_gpu
|
||||
before calling this function.
|
||||
logger (Callable[[Any], Any]): Optional logger exposing the methods info,
|
||||
error, debug and warn. Defaults to regular spaCy logger but can be
|
||||
swapped for CLI logger.
|
||||
silent (bool): Whether to pretty-print outputs.
|
||||
RETURNS (Path / None): The path to the final exported model.
|
||||
"""
|
||||
msg = Printer(no_print=silent)
|
||||
|
|
|
@ -392,7 +392,6 @@ def resolve_dot_names(config: Config, dot_names: List[Optional[str]]) -> Tuple[A
|
|||
we could find the lowest part of the tree.
|
||||
"""
|
||||
# TODO: include schema?
|
||||
# TODO: clean this up and avoid duplication
|
||||
resolved = {}
|
||||
output = []
|
||||
errors = []
|
||||
|
@ -402,20 +401,6 @@ def resolve_dot_names(config: Config, dot_names: List[Optional[str]]) -> Tuple[A
|
|||
else:
|
||||
section = name.split(".")[0]
|
||||
# We want to avoid resolving the same thing twice
|
||||
if section not in resolved:
|
||||
resolved[section] = registry.resolve(config[section])
|
||||
try:
|
||||
output.append(dot_to_object(resolved, name))
|
||||
except KeyError:
|
||||
msg = f"not a valid section reference: {name}"
|
||||
errors.append({"loc": name.split("."), "msg": msg})
|
||||
objects = []
|
||||
for ref in output:
|
||||
if not isinstance(ref, str):
|
||||
objects.append(ref)
|
||||
continue
|
||||
section = ref.split(".")[0]
|
||||
# We want to avoid resolving the same thing twice
|
||||
if section not in resolved:
|
||||
if registry.is_promise(config[section]):
|
||||
# Otherwise we can't resolve [corpus] if it's a promise
|
||||
|
@ -424,13 +409,13 @@ def resolve_dot_names(config: Config, dot_names: List[Optional[str]]) -> Tuple[A
|
|||
result = registry.resolve(config[section])
|
||||
resolved[section] = result
|
||||
try:
|
||||
objects.append(dot_to_object(resolved, ref))
|
||||
output.append(dot_to_object(resolved, name))
|
||||
except KeyError:
|
||||
msg = f"not a valid section reference: {name}"
|
||||
errors.append({"loc": ref.split("."), "msg": msg})
|
||||
errors.append({"loc": name.split("."), "msg": msg})
|
||||
if errors:
|
||||
raise ConfigValidationError(config=config, errors=errors)
|
||||
return tuple(objects)
|
||||
return tuple(output)
|
||||
|
||||
|
||||
def load_model_from_init_py(
|
||||
|
|
Loading…
Reference in New Issue
Block a user