Fix up configuration keys

Daniël de Kok 2023-04-18 17:37:14 +02:00
parent 9a72ea0b91
commit b9324505d8
4 changed files with 11 additions and 11 deletions
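
In short: the "distill" config section is renamed to "distillation", and every site that reads the key is updated to match: the Language config handling, the CONFIG_SCHEMAS mapping, and the distillation initialization and training-loop code.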

@@ -1837,7 +1837,7 @@ class Language:
         # using the nlp.config with all defaults.
         config = util.copy_config(config)
         orig_pipeline = config.pop("components", {})
-        orig_distill = config.pop("distill", None)
+        orig_distill = config.pop("distillation", None)
         orig_pretraining = config.pop("pretraining", None)
         config["components"] = {}
         if auto_fill:
@@ -1847,8 +1847,8 @@ class Language:
         filled["components"] = orig_pipeline
         config["components"] = orig_pipeline
         if orig_distill is not None:
-            filled["distill"] = orig_distill
-            config["distill"] = orig_distill
+            filled["distillation"] = orig_distill
+            config["distillation"] = orig_distill
         if orig_pretraining is not None:
             filled["pretraining"] = orig_pretraining
             config["pretraining"] = orig_pretraining

@@ -462,7 +462,7 @@ CONFIG_SCHEMAS = {
     "training": ConfigSchemaTraining,
     "pretraining": ConfigSchemaPretrain,
     "initialize": ConfigSchemaInit,
-    "distill": ConfigSchemaDistill,
+    "distillation": ConfigSchemaDistill,
 }
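
CONFIG_SCHEMAS maps each top-level config section name to the schema that validates it, so this key must match the section name used by Language and the training code. A sketch of how such a lookup validates a section, using a hypothetical stand-in for ConfigSchemaDistill (the real schemas are pydantic models):

from pydantic import BaseModel

# Hypothetical stand-in for ConfigSchemaDistill; the real schema has
# more fields than just the corpus reference.
class DummyDistillSchema(BaseModel):
    corpus: str

CONFIG_SCHEMAS = {"distillation": DummyDistillSchema}

section = {"corpus": "corpora.distillation"}
validated = CONFIG_SCHEMAS["distillation"](**section)  # raises ValidationError if invalid
print(validated.corpus)  # corpora.distillation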

@@ -124,13 +124,11 @@ def init_nlp_distill(
     config = nlp.config.interpolate()
     # Resolve all training-relevant sections using the filled nlp config
     T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
-    D = registry.resolve(config["distill"], schema=ConfigSchemaDistill)
+    D = registry.resolve(config["distillation"], schema=ConfigSchemaDistill)
     dot_names = [D["corpus"], T["dev_corpus"]]
     if not isinstance(D["corpus"], str):
         raise ConfigValidationError(
-            desc=Errors.E897.format(
-                field="distill.corpus", type=type(D["corpus"])
-            )
+            desc=Errors.E897.format(field="distillation.corpus", type=type(D["corpus"]))
         )
     if not isinstance(T["dev_corpus"], str):
         raise ConfigValidationError(
@@ -158,7 +156,9 @@ def init_nlp_distill(
     labels = {}
     for name, pipe in nlp.pipeline:
         # Copy teacher labels.
-        teacher_pipe_name = student_to_teacher[name] if name in student_to_teacher else name
+        teacher_pipe_name = (
+            student_to_teacher[name] if name in student_to_teacher else name
+        )
         teacher_pipe = teacher_pipes.get(teacher_pipe_name, None)
         if (
             teacher_pipe is not None
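
The reformatted conditional in the second hunk maps a student pipe name to its teacher counterpart, falling back to the student's own name when no mapping exists. The same lookup can be spelled with dict.get; a small sketch with an invented mapping:

# Hypothetical student-to-teacher pipe name mapping.
student_to_teacher = {"ner_student": "ner"}

for name in ("ner_student", "tagger"):
    teacher_pipe_name = (
        student_to_teacher[name] if name in student_to_teacher else name
    )
    # Equivalent: teacher_pipe_name = student_to_teacher.get(name, name)
    print(name, "->", teacher_pipe_name)  # ner_student -> ner, tagger -> tagger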

@@ -57,7 +57,7 @@ def distill(
     if use_gpu >= 0 and allocator:
         set_gpu_allocator(allocator)
     T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
-    D = registry.resolve(config["distill"], schema=ConfigSchemaDistill)
+    D = registry.resolve(config["distillation"], schema=ConfigSchemaDistill)
     dot_names = [D["corpus"], T["dev_corpus"]]
     distill_corpus, dev_corpus = resolve_dot_names(config, dot_names)
     optimizer = D["optimizer"]
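
After the rename, D["corpus"] still holds a dotted reference (e.g. "corpora.distillation") that is looked up in the interpolated config; only the section it is read from changes from "distill" to "distillation". A dict-based sketch of that lookup (section contents and the helper below are simplified stand-ins, not the real resolve_dot_names):

# Dict stand-in for an interpolated config; the real one comes from the
# parsed config file's [distillation] and [corpora] sections.
config = {
    "distillation": {"corpus": "corpora.distillation"},
    "corpora": {"distillation": {"path": "distill.spacy"}},
}

def resolve_dot_name(config, dot_name):
    # Walk a dotted reference like "corpora.distillation" one key at a
    # time; a simplified stand-in for the real helper.
    obj = config
    for part in dot_name.split("."):
        obj = obj[part]
    return obj

distill_corpus = resolve_dot_name(config, config["distillation"]["corpus"])
print(distill_corpus)  # {'path': 'distill.spacy'}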