Mirror of https://github.com/explosion/spaCy.git (synced 2025-07-13 09:42:26 +03:00)
Fix up configuration keys

commit b9324505d8
parent 9a72ea0b91
@@ -1837,7 +1837,7 @@ class Language:
         # using the nlp.config with all defaults.
         config = util.copy_config(config)
         orig_pipeline = config.pop("components", {})
-        orig_distill = config.pop("distill", None)
+        orig_distill = config.pop("distillation", None)
         orig_pretraining = config.pop("pretraining", None)
         config["components"] = {}
         if auto_fill:
@@ -1847,8 +1847,8 @@ class Language:
         filled["components"] = orig_pipeline
         config["components"] = orig_pipeline
         if orig_distill is not None:
-            filled["distill"] = orig_distill
-            config["distill"] = orig_distill
+            filled["distillation"] = orig_distill
+            config["distillation"] = orig_distill
         if orig_pretraining is not None:
             filled["pretraining"] = orig_pretraining
             config["pretraining"] = orig_pretraining
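The two hunks above keep the existing pop / auto-fill / restore pattern in Language.from_config and only rename the key. A minimal standalone sketch of that pattern on a plain dict (not spaCy's actual implementation; the component and corpus names are placeholders):

    # Set the special sections aside so only the rest of the config is auto-filled,
    # then re-attach them unchanged under the renamed "distillation" key.
    config = {
        "components": {"ner": {}},
        "distillation": {"corpus": "corpora.distillation"},
    }
    orig_pipeline = config.pop("components", {})
    orig_distill = config.pop("distillation", None)
    orig_pretraining = config.pop("pretraining", None)
    config["components"] = {}
    filled = dict(config)  # stand-in for the auto-filled config
    filled["components"] = orig_pipeline
    if orig_distill is not None:
        filled["distillation"] = orig_distill
    if orig_pretraining is not None:
        filled["pretraining"] = orig_pretraining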
@@ -462,7 +462,7 @@ CONFIG_SCHEMAS = {
     "training": ConfigSchemaTraining,
     "pretraining": ConfigSchemaPretrain,
     "initialize": ConfigSchemaInit,
-    "distill": ConfigSchemaDistill,
+    "distillation": ConfigSchemaDistill,
 }


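With the schema map updated, the top-level section validated by ConfigSchemaDistill is now named "distillation". A hedged sketch of a config carrying that section, built with thinc's Config class; only the section name and the "corpus"/"optimizer" keys come from this diff, the values and registry names are placeholders:

    from thinc.api import Config

    cfg = Config(
        {
            "distillation": {
                "corpus": "corpora.distillation",
                "optimizer": {"@optimizers": "Adam.v1"},
            }
        }
    )
    assert "distillation" in cfg  # the key now registered in CONFIG_SCHEMAS
    assert "distill" not in cfg   # the old section name is gone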
@@ -124,13 +124,11 @@ def init_nlp_distill(
     config = nlp.config.interpolate()
     # Resolve all training-relevant sections using the filled nlp config
     T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
-    D = registry.resolve(config["distill"], schema=ConfigSchemaDistill)
+    D = registry.resolve(config["distillation"], schema=ConfigSchemaDistill)
     dot_names = [D["corpus"], T["dev_corpus"]]
     if not isinstance(D["corpus"], str):
         raise ConfigValidationError(
-            desc=Errors.E897.format(
-                field="distill.corpus", type=type(D["corpus"])
-            )
+            desc=Errors.E897.format(field="distillation.corpus", type=type(D["corpus"]))
         )
     if not isinstance(T["dev_corpus"], str):
         raise ConfigValidationError(
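The guard above is unchanged apart from the field name in the error message: the corpus entry must be a dot-path string into the config, not an inline block, because it is resolved later by name. A standalone sketch of the same check, using a placeholder path:

    D = {"corpus": "corpora.distillation"}  # placeholder dot-path into the config
    if not isinstance(D["corpus"], str):
        raise TypeError(
            f"distillation.corpus must be a string reference, got {type(D['corpus'])}"
        )
    dot_names = [D["corpus"]]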
@@ -158,7 +156,9 @@ def init_nlp_distill(
     labels = {}
     for name, pipe in nlp.pipeline:
         # Copy teacher labels.
-        teacher_pipe_name = student_to_teacher[name] if name in student_to_teacher else name
+        teacher_pipe_name = (
+            student_to_teacher[name] if name in student_to_teacher else name
+        )
         teacher_pipe = teacher_pipes.get(teacher_pipe_name, None)
         if (
             teacher_pipe is not None
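The hunk above only wraps the conditional expression in parentheses; behaviour is unchanged, and it is equivalent to student_to_teacher.get(name, name). A tiny standalone illustration of the mapping, with hypothetical pipe names:

    student_to_teacher = {"sentencizer": "senter"}  # hypothetical pipe-name mapping
    for name in ("sentencizer", "tagger"):
        teacher_pipe_name = (
            student_to_teacher[name] if name in student_to_teacher else name
        )
        print(name, "->", teacher_pipe_name)  # sentencizer -> senter, tagger -> tagger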
@@ -57,7 +57,7 @@ def distill(
     if use_gpu >= 0 and allocator:
         set_gpu_allocator(allocator)
     T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
-    D = registry.resolve(config["distill"], schema=ConfigSchemaDistill)
+    D = registry.resolve(config["distillation"], schema=ConfigSchemaDistill)
     dot_names = [D["corpus"], T["dev_corpus"]]
     distill_corpus, dev_corpus = resolve_dot_names(config, dot_names)
     optimizer = D["optimizer"]
@@ -333,7 +333,7 @@ def _distill_loop(
         if before_update:
             before_update_args = {"step": step, "epoch": epoch}
             before_update(student, before_update_args)
         dropout = dropouts(optimizer.step)
         for subbatch in subdivide_batch(batch, accumulate_gradient):
             student.distill(
                 teacher,
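For context on the loop above: the before_update callback is called with the student pipeline and a dict holding the current step and epoch. A hedged sketch of such a callback; the warm-up condition is purely illustrative and not taken from the diff:

    def before_update(nlp, args: dict) -> None:
        step, epoch = args["step"], args["epoch"]
        # e.g. only adjust something during the first epoch's warm-up steps
        if epoch == 0 and step < 100:
            pass

    before_update(None, {"step": 0, "epoch": 0})  # matches the call signature above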