Fix config resolution and interpolation

TODO: auto-interpolate in Thinc if config is dict (i.e. likely subsection)
Ines Montani 2020-09-28 15:34:00 +02:00
parent 02838a1d47
commit 2e9c9e74af
5 changed files with 23 additions and 8 deletions
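
For context, a minimal sketch (not part of this commit) of why the call sites below now interpolate the config before resolving a subsection: variable references such as ${system.seed} point at other top-level sections, so resolving nlp.config["training"] on its own leaves the ${...} placeholders as plain strings. The config string here is made up for illustration.

from thinc.api import Config

CONFIG_STR = """
[system]
seed = 0

[training]
seed = ${system.seed}
"""

config = Config().from_str(CONFIG_STR, interpolate=False)
# Without interpolation the variable reference survives as a raw string
print(config["training"]["seed"])                # "${system.seed}"
# Interpolating the whole config first fills it in from [system]
print(config.interpolate()["training"]["seed"])  # 0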

@@ -97,7 +97,9 @@ def debug_data(
     with show_validation_error(config_path):
         cfg = util.load_config(config_path, overrides=config_overrides)
         nlp = util.load_model_from_config(cfg)
-        T = registry.resolve(nlp.config["training"], schema=ConfigSchemaTraining)
+        T = registry.resolve(
+            nlp.config.interpolate()["training"], schema=ConfigSchemaTraining
+        )
     # Use original config here, not resolved version
     sourced_components = get_sourced_components(cfg)
     frozen_components = T["frozen_components"]

@@ -63,7 +63,9 @@ def debug_model_cli(
         set_gpu_allocator(allocator)
     with show_validation_error(config_path):
         nlp = util.load_model_from_config(raw_config)
-    T = registry.resolve(nlp.config["training"], schema=ConfigSchemaTraining)
+    T = registry.resolve(
+        nlp.config.interpolate()["training"], schema=ConfigSchemaTraining
+    )
     seed = T["seed"]
     if seed is not None:
         msg.info(f"Fixing random seed: {seed}")

@@ -42,7 +42,9 @@ def test_readers():
     dot_names = ["training.train_corpus", "training.dev_corpus"]
     train_corpus, dev_corpus = resolve_dot_names(nlp.config, dot_names)
     assert isinstance(train_corpus, Callable)
-    T = registry.resolve(nlp.config["training"], schema=ConfigSchemaTraining)
+    T = registry.resolve(
+        nlp.config.interpolate()["training"], schema=ConfigSchemaTraining
+    )
     optimizer = T["optimizer"]
     # simulate a training loop
     nlp.begin_training(lambda: train_corpus(nlp), sgd=optimizer)
@@ -53,7 +55,8 @@ def test_readers():
     # ensure the pipeline runs
     doc = nlp("Quick test")
     assert doc.cats
-    extra_corpus = registry.resolve(nlp.config["corpora"])["extra"]
+    corpora = {"corpora": nlp.config.interpolate()["corpora"]}
+    extra_corpus = registry.resolve(corpora)["corpora"]["extra"]
     assert isinstance(extra_corpus, Callable)
@@ -91,7 +94,9 @@ def test_cat_readers(reader, additional_config):
     nlp = load_model_from_config(config, auto_fill=True)
     dot_names = ["training.train_corpus", "training.dev_corpus"]
     train_corpus, dev_corpus = resolve_dot_names(nlp.config, dot_names)
-    T = registry.resolve(nlp.config["training"], schema=ConfigSchemaTraining)
+    T = registry.resolve(
+        nlp.config["training"].interpolate(), schema=ConfigSchemaTraining
+    )
     optimizer = T["optimizer"]
     # simulate a training loop
     nlp.begin_training(lambda: train_corpus(nlp), sgd=optimizer)

@@ -33,8 +33,9 @@ def pretrain(
     if use_gpu >= 0 and allocator:
         set_gpu_allocator(allocator)
     nlp = load_model_from_config(config)
-    T = registry.resolve(nlp.config["training"], schema=ConfigSchemaTraining)
-    P = registry.resolve(nlp.config["pretraining"], schema=ConfigSchemaPretrain)
+    _config = nlp.config.interpolate()
+    T = registry.resolve(_config["training"], schema=ConfigSchemaTraining)
+    P = registry.resolve(_config["pretraining"], schema=ConfigSchemaPretrain)
     corpus = dot_to_object(T, P["corpus"])
     batcher = P["batcher"]
     model = create_pretraining_model(nlp, P)

@@ -413,7 +413,12 @@ def resolve_dot_names(config: Config, dot_names: List[Optional[str]]) -> Tuple[A
             section = ref.split(".")[0]
             # We want to avoid resolving the same thing twice
             if section not in resolved:
-                resolved[section] = registry.resolve(config[section])
+                if registry.is_promise(config[section]):
+                    # Otherwise we can't resolve [corpus] if it's a promise
+                    result = registry.resolve({"config": config[section]})["config"]
+                else:
+                    result = registry.resolve(config[section])
+                resolved[section] = result
             try:
                 objects.append(dot_to_object(resolved, ref))
             except KeyError:
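
As a side note, a minimal sketch (separate from the diff above, assuming a current spaCy v3 install) of the wrapping trick used in resolve_dot_names: registry.resolve() expects a mapping of section names to values, so a bare "promise" block (a dict with an @-key, as produced for e.g. [corpora] readers) is wrapped under a dummy key and unwrapped again after resolution. The demo_corpus.v1 registry entry is invented for the example.

from spacy import registry


@registry.misc("demo_corpus.v1")
def make_demo_corpus(limit: int = 0):
    # Stand-in for a corpus reader: returns a callable taking the nlp object
    return lambda nlp: []


section = {"@misc": "demo_corpus.v1", "limit": 10}
if registry.is_promise(section):
    # Wrap the promise so resolve() sees a named section, then unwrap the result
    resolved = registry.resolve({"config": section})["config"]
else:
    resolved = registry.resolve(section)
assert callable(resolved)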