Fix up configuration keys

Daniël de Kok 2023-04-18 17:37:14 +02:00
parent 9a72ea0b91
commit b9324505d8
4 changed files with 11 additions and 11 deletions

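Note for orientation (not part of the diff itself): this commit renames the top-level config section for distillation settings from "distill" to "distillation" everywhere it is read or written. A minimal sketch of what that means for code reading the config, assuming a pipeline object; the lookup below is illustrative, since a blank pipeline has no distillation block:

# Hypothetical usage sketch, not taken from the diff: after this change the
# distillation settings are looked up under config["distillation"] rather
# than config["distill"].
import spacy

nlp = spacy.blank("en")
settings = nlp.config.get("distillation", {})
print(settings)  # {} here, since a blank pipeline has no distillation section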

@@ -1837,7 +1837,7 @@ class Language:
         # using the nlp.config with all defaults.
         config = util.copy_config(config)
         orig_pipeline = config.pop("components", {})
-        orig_distill = config.pop("distill", None)
+        orig_distill = config.pop("distillation", None)
         orig_pretraining = config.pop("pretraining", None)
         config["components"] = {}
         if auto_fill:
@@ -1847,8 +1847,8 @@ class Language:
         filled["components"] = orig_pipeline
         config["components"] = orig_pipeline
         if orig_distill is not None:
-            filled["distill"] = orig_distill
-            config["distill"] = orig_distill
+            filled["distillation"] = orig_distill
+            config["distillation"] = orig_distill
         if orig_pretraining is not None:
             filled["pretraining"] = orig_pretraining
             config["pretraining"] = orig_pretraining

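The two hunks above keep the renamed section intact while the rest of the config is filled with defaults: the section is popped before merging and re-attached afterwards. A minimal sketch of that pattern, assuming thinc's Config; the [distillation] contents are made up for illustration:

# Sketch of the pop/merge/restore pattern (assumption: thinc's Config API;
# the config contents here are invented for the example).
from thinc.api import Config

default_cfg = Config().from_str("""
[nlp]
lang = "en"
""")
user_cfg = Config().from_str("""
[nlp]
lang = "en"

[distillation]
max_epochs = 5
""")
# Pop the section so merging against the defaults (which have no
# "distillation" entry) leaves it untouched ...
orig_distill = user_cfg.pop("distillation", None)
filled = default_cfg.merge(user_cfg)
# ... then re-attach it under the renamed key.
if orig_distill is not None:
    filled["distillation"] = orig_distill
print(filled["distillation"]["max_epochs"])  # 5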

@@ -462,7 +462,7 @@ CONFIG_SCHEMAS = {
     "training": ConfigSchemaTraining,
     "pretraining": ConfigSchemaPretrain,
     "initialize": ConfigSchemaInit,
-    "distill": ConfigSchemaDistill,
+    "distillation": ConfigSchemaDistill,
 }

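CONFIG_SCHEMAS maps each top-level config section name to the schema used to validate it, so the key rename has to be mirrored here. A rough, self-contained illustration of that kind of lookup; MySectionSchema and the section contents are hypothetical stand-ins, not spaCy code:

# Illustrative only: a name -> schema mapping used to validate one section.
from pydantic import BaseModel

class MySectionSchema(BaseModel):  # hypothetical stand-in for ConfigSchemaDistill
    max_epochs: int = 0

SCHEMAS = {"distillation": MySectionSchema}

section = {"max_epochs": 5}
validated = SCHEMAS["distillation"](**section)  # raises if a value has the wrong type
print(validated.max_epochs)  # 5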

@@ -124,13 +124,11 @@ def init_nlp_distill(
     config = nlp.config.interpolate()
     # Resolve all training-relevant sections using the filled nlp config
     T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
-    D = registry.resolve(config["distill"], schema=ConfigSchemaDistill)
+    D = registry.resolve(config["distillation"], schema=ConfigSchemaDistill)
     dot_names = [D["corpus"], T["dev_corpus"]]
     if not isinstance(D["corpus"], str):
         raise ConfigValidationError(
-            desc=Errors.E897.format(
-                field="distill.corpus", type=type(D["corpus"])
-            )
+            desc=Errors.E897.format(field="distillation.corpus", type=type(D["corpus"]))
         )
     if not isinstance(T["dev_corpus"], str):
         raise ConfigValidationError(
@@ -158,7 +156,9 @@ def init_nlp_distill(
     labels = {}
     for name, pipe in nlp.pipeline:
         # Copy teacher labels.
-        teacher_pipe_name = student_to_teacher[name] if name in student_to_teacher else name
+        teacher_pipe_name = (
+            student_to_teacher[name] if name in student_to_teacher else name
+        )
         teacher_pipe = teacher_pipes.get(teacher_pipe_name, None)
         if (
             teacher_pipe is not None

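Besides the key rename, the second hunk above only re-wraps the ternary over multiple lines; the logic is unchanged: each student pipe looks up its teacher pipe under a mapped name if one is configured, otherwise under its own name. A small self-contained sketch of that lookup, with invented pipe names:

# Invented names, purely to illustrate the student -> teacher pipe lookup above.
student_to_teacher = {"tagger_student": "tagger"}
teacher_pipes = {"tagger": "teacher tagger pipe", "parser": "teacher parser pipe"}

for name in ("tagger_student", "parser", "ner"):
    teacher_pipe_name = (
        student_to_teacher[name] if name in student_to_teacher else name
    )
    teacher_pipe = teacher_pipes.get(teacher_pipe_name, None)
    print(name, "->", teacher_pipe_name, teacher_pipe)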

@@ -57,7 +57,7 @@ def distill(
     if use_gpu >= 0 and allocator:
         set_gpu_allocator(allocator)
     T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
-    D = registry.resolve(config["distill"], schema=ConfigSchemaDistill)
+    D = registry.resolve(config["distillation"], schema=ConfigSchemaDistill)
     dot_names = [D["corpus"], T["dev_corpus"]]
     distill_corpus, dev_corpus = resolve_dot_names(config, dot_names)
     optimizer = D["optimizer"]
@@ -333,7 +333,7 @@ def _distill_loop(
         if before_update:
             before_update_args = {"step": step, "epoch": epoch}
             before_update(student, before_update_args)
-            dropout = dropouts(optimizer.step)
+        dropout = dropouts(optimizer.step)
         for subbatch in subdivide_batch(batch, accumulate_gradient):
             student.distill(
                 teacher,
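The final hunk appears to change only the indentation of the dropout lookup, moving it out of the before_update guard so the rate is computed on every step rather than only when a callback is configured. A stand-in sketch of a step-keyed dropout schedule like the dropouts(optimizer.step) call above; the schedule itself is invented, not thinc's API:

# Hypothetical dropout schedule keyed on the optimizer step (illustration only).
def dropouts(step: int) -> float:
    # invented decay from 0.2 towards 0.1 over the first 1000 steps
    return max(0.1, 0.2 - 0.0001 * step)

for step in (0, 500, 1000, 2000):
    print(step, dropouts(step))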