Mirror of https://github.com/explosion/spaCy.git, synced 2025-04-21 17:41:59 +03:00
Fix up configuration keys
This commit is contained in:
parent 9a72ea0b91
commit b9324505d8
@@ -1837,7 +1837,7 @@ class Language:
         # using the nlp.config with all defaults.
         config = util.copy_config(config)
         orig_pipeline = config.pop("components", {})
-        orig_distill = config.pop("distill", None)
+        orig_distill = config.pop("distillation", None)
         orig_pretraining = config.pop("pretraining", None)
         config["components"] = {}
         if auto_fill:

@ -1847,8 +1847,8 @@ class Language:
|
|||
filled["components"] = orig_pipeline
|
||||
config["components"] = orig_pipeline
|
||||
if orig_distill is not None:
|
||||
filled["distill"] = orig_distill
|
||||
config["distill"] = orig_distill
|
||||
filled["distillation"] = orig_distill
|
||||
config["distillation"] = orig_distill
|
||||
if orig_pretraining is not None:
|
||||
filled["pretraining"] = orig_pretraining
|
||||
config["pretraining"] = orig_pretraining
|
||||
|
|
|
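Both hunks above follow the pop-and-restore pattern that from_config already applies to the components and pretraining sections: optional top-level blocks are popped off the user config before defaults are filled in, then restored afterwards under their key, now "distillation" instead of "distill". A minimal standalone sketch of that pattern (the defaults dict and merging are toy stand-ins, not spaCy's actual default-filling logic):

    def fill_config(config: dict) -> dict:
        # Toy defaults standing in for the language's default config.
        defaults = {"nlp": {"lang": "en"}, "components": {}}

        config = dict(config)  # work on a copy, like util.copy_config
        orig_pipeline = config.pop("components", {})
        orig_distill = config.pop("distillation", None)  # renamed from "distill"
        orig_pretraining = config.pop("pretraining", None)

        filled = {**defaults, **config}  # stand-in for auto-filling defaults

        # Restore the optional sections under their (renamed) keys.
        filled["components"] = orig_pipeline
        if orig_distill is not None:
            filled["distillation"] = orig_distill
        if orig_pretraining is not None:
            filled["pretraining"] = orig_pretraining
        return filled
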
@@ -462,7 +462,7 @@ CONFIG_SCHEMAS = {
     "training": ConfigSchemaTraining,
     "pretraining": ConfigSchemaPretrain,
     "initialize": ConfigSchemaInit,
-    "distill": ConfigSchemaDistill,
+    "distillation": ConfigSchemaDistill,
 }

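CONFIG_SCHEMAS maps each top-level section name of the config to the schema it is validated against, so after this commit a [distillation] block is checked with ConfigSchemaDistill and the old [distill] name no longer has a schema entry. A rough sketch of that lookup with a toy schema (the fields on the toy model and the helper function are assumptions, not spaCy's real ConfigSchemaDistill or API):

    from pydantic import BaseModel

    class ToySchemaDistill(BaseModel):
        # Illustrative fields only; see ConfigSchemaDistill for the real ones.
        corpus: str
        max_epochs: int = 0

    CONFIG_SCHEMAS = {
        "distillation": ToySchemaDistill,  # key was "distill" before this commit
    }

    def validate_section(name: str, data: dict):
        schema = CONFIG_SCHEMAS.get(name)
        return schema(**data) if schema is not None else data

    validate_section("distillation", {"corpus": "corpora.distillation"})
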
@@ -124,13 +124,11 @@ def init_nlp_distill(
     config = nlp.config.interpolate()
     # Resolve all training-relevant sections using the filled nlp config
     T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
-    D = registry.resolve(config["distill"], schema=ConfigSchemaDistill)
+    D = registry.resolve(config["distillation"], schema=ConfigSchemaDistill)
     dot_names = [D["corpus"], T["dev_corpus"]]
     if not isinstance(D["corpus"], str):
         raise ConfigValidationError(
-            desc=Errors.E897.format(
-                field="distill.corpus", type=type(D["corpus"])
-            )
+            desc=Errors.E897.format(field="distillation.corpus", type=type(D["corpus"]))
         )
     if not isinstance(T["dev_corpus"], str):
         raise ConfigValidationError(

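The E897 error above enforces that the corpus entry under the renamed section is a dot-notation reference (a string such as "corpora.distillation") rather than an inline block, matching the existing check for training.dev_corpus. A plain-Python illustration of what that kind of check accepts and rejects (the error message and corpus name are illustrative, not spaCy's exact wording):

    def check_corpus_reference(value):
        # Mirrors the isinstance(D["corpus"], str) check in the hunk above.
        if not isinstance(value, str):
            raise ValueError(
                f"distillation.corpus must be a dot-name string, got {type(value)}"
            )
        return value

    check_corpus_reference("corpora.distillation")   # passes
    # check_corpus_reference({"@readers": "..."})    # would raise ValueError
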
@@ -158,7 +156,9 @@ def init_nlp_distill(
     labels = {}
     for name, pipe in nlp.pipeline:
         # Copy teacher labels.
-        teacher_pipe_name = student_to_teacher[name] if name in student_to_teacher else name
+        teacher_pipe_name = (
+            student_to_teacher[name] if name in student_to_teacher else name
+        )
         teacher_pipe = teacher_pipes.get(teacher_pipe_name, None)
         if (
             teacher_pipe is not None

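The rewrapped conditional above only changes formatting; behaviourally it is a dict lookup that falls back to the student's own pipe name. A quick check with an assumed mapping:

    student_to_teacher = {"tagger_student": "tagger"}  # assumed example mapping

    def teacher_name(name: str) -> str:
        # Same behaviour as the wrapped conditional expression in the hunk above.
        return student_to_teacher[name] if name in student_to_teacher else name

    assert teacher_name("tagger_student") == "tagger"
    assert teacher_name("parser") == "parser"
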
@@ -57,7 +57,7 @@ def distill(
     if use_gpu >= 0 and allocator:
         set_gpu_allocator(allocator)
     T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
-    D = registry.resolve(config["distill"], schema=ConfigSchemaDistill)
+    D = registry.resolve(config["distillation"], schema=ConfigSchemaDistill)
     dot_names = [D["corpus"], T["dev_corpus"]]
     distill_corpus, dev_corpus = resolve_dot_names(config, dot_names)
     optimizer = D["optimizer"]

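As in init_nlp_distill, the distillation corpus and the dev corpus are referenced by dot names and looked up together via resolve_dot_names. A toy version of that lookup, simplified to plain nested-dict traversal (spaCy's real implementation also resolves registered reader functions):

    def resolve_dot_names_toy(config: dict, dot_names: list):
        # Walk "section.key" paths through a nested config dict.
        results = []
        for dot_name in dot_names:
            obj = config
            for part in dot_name.split("."):
                obj = obj[part]
            results.append(obj)
        return results

    config = {"corpora": {"distillation": "<distill reader>", "dev": "<dev reader>"}}
    distill_corpus, dev_corpus = resolve_dot_names_toy(
        config, ["corpora.distillation", "corpora.dev"]
    )
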
@@ -333,7 +333,7 @@ def _distill_loop(
             if before_update:
                 before_update_args = {"step": step, "epoch": epoch}
                 before_update(student, before_update_args)
-            dropout = dropouts(optimizer.step)
+            dropout = dropouts(optimizer.step)
             for subbatch in subdivide_batch(batch, accumulate_gradient):
                 student.distill(
                     teacher,

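In the loop above, dropouts is a schedule that is evaluated at the current optimizer step before the batch is subdivided for gradient accumulation. A minimal stand-in schedule showing the call shape (the constant rate is an assumption for illustration, not spaCy's default):

    def constant_dropout(rate: float):
        # Toy schedule: ignores the step and always returns the same rate.
        def schedule(step: int) -> float:
            return rate
        return schedule

    dropouts = constant_dropout(0.1)  # assumed rate
    dropout = dropouts(0)             # same call shape as dropouts(optimizer.step)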