Add the configuration schema for distillation (#12201)

* Add the configuration schema for distillation

This also adds the default configuration and some tests. The schema will
be used by the training loop and `distill` subcommand.

* Format

* Change distillation shortopt to -d

* Fix description of max_epochs

* Rename distillation flag to -dt

* Rename `pipe_map` to `student_to_teacher`
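
For orientation, here is how the new option is exercised once merged; a minimal sketch, where the config file names are placeholders:

# Equivalent to: python -m spacy init fill-config base_config.cfg config.cfg --distillation
from pathlib import Path
from spacy.cli.init_config import fill_config

# Fill a partial config and merge in the [distillation] defaults
# ("base_config.cfg" and "config.cfg" are placeholder paths).
fill_config(Path("config.cfg"), Path("base_config.cfg"), distillation=True)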
Daniël de Kok 2023-01-31 13:06:02 +01:00 committed by GitHub
parent 1b5aba9e22
commit fb7f018ded
5 changed files with 137 additions and 3 deletions

spacy/cli/init_config.py

@@ -8,7 +8,7 @@ import re
 from jinja2 import Template
 from .. import util
-from ..language import DEFAULT_CONFIG_PRETRAIN_PATH
+from ..language import DEFAULT_CONFIG_DISTILL_PATH, DEFAULT_CONFIG_PRETRAIN_PATH
 from ..schemas import RecommendationSchema
 from ..util import SimpleFrozenList
 from ._util import init_cli, Arg, Opt, show_validation_error, COMMAND
@@ -83,6 +83,7 @@ def init_fill_config_cli(
     # fmt: off
     base_path: Path = Arg(..., help="Path to base config to fill", exists=True, dir_okay=False),
     output_file: Path = Arg("-", help="Path to output .cfg file (or - for stdout)", allow_dash=True),
+    distillation: bool = Opt(False, "--distillation", "-dt", help="Include config for distillation (with 'spacy distill')"),
     pretraining: bool = Opt(False, "--pretraining", "-pt", help="Include config for pretraining (with 'spacy pretrain')"),
     diff: bool = Opt(False, "--diff", "-D", help="Print a visual diff highlighting the changes"),
     code_path: Optional[Path] = Opt(None, "--code-path", "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
@@ -98,13 +99,20 @@ def init_fill_config_cli(
     DOCS: https://spacy.io/api/cli#init-fill-config
     """
     import_code(code_path)
-    fill_config(output_file, base_path, pretraining=pretraining, diff=diff)
+    fill_config(
+        output_file,
+        base_path,
+        distillation=distillation,
+        pretraining=pretraining,
+        diff=diff,
+    )
 
 
 def fill_config(
     output_file: Path,
     base_path: Path,
     *,
+    distillation: bool = False,
     pretraining: bool = False,
     diff: bool = False,
     silent: bool = False,
@@ -123,6 +131,9 @@ def fill_config(
     # replaced with their actual config after loading, so we have to re-add them
     sourced = util.get_sourced_components(config)
     filled["components"].update(sourced)
+    if distillation:
+        distillation_config = util.load_config(DEFAULT_CONFIG_DISTILL_PATH)
+        filled = distillation_config.merge(filled)
     if pretraining:
         validate_config_for_pretrain(filled, msg)
         pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH)
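
The merge order here matters: with thinc's Config.merge, values from the argument take priority, so merging the user's filled config into the distillation defaults lets user settings win. A minimal sketch of that behavior:

from thinc.api import Config

# The argument to merge() overrides the receiver, mirroring
# `distillation_config.merge(filled)` above.
defaults = Config().from_str("[distillation]\ndropout = 0.1")
user = Config().from_str("[distillation]\ndropout = 0.3")
merged = defaults.merge(user)
assert merged["distillation"]["dropout"] == 0.3  # the user's value wins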

spacy/default_config_distillation.cfg

@@ -0,0 +1,34 @@
+[paths]
+raw_text = null
+
+[distillation]
+corpus = "corpora.distillation"
+dropout = 0.1
+max_epochs = 1
+max_steps = 0
+student_to_teacher = {}
+
+[distillation.batcher]
+@batchers = "spacy.batch_by_words.v1"
+size = 3000
+discard_oversize = false
+tolerance = 0.2
+
+[distillation.optimizer]
+@optimizers = "Adam.v1"
+beta1 = 0.9
+beta2 = 0.999
+L2_is_weight_decay = true
+L2 = 0.01
+grad_clip = 1.0
+use_averages = true
+eps = 1e-8
+learn_rate = 1e-4
+
+[corpora]
+
+[corpora.distillation]
+@readers = "spacy.PlainTextCorpus.v1"
+path = ${paths.raw_text}
+min_length = 0
+max_length = 0
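
To poke at these defaults directly, one can load the file and override the raw-text path; a sketch, with "corpus.txt" as a placeholder:

from spacy import util
from spacy.language import DEFAULT_CONFIG_DISTILL_PATH

# Load the distillation defaults, pointing the corpus reader at real data
# ("corpus.txt" is a placeholder path).
config = util.load_config(
    DEFAULT_CONFIG_DISTILL_PATH, overrides={"paths.raw_text": "corpus.txt"}
)
print(config["distillation"]["corpus"])  # "corpora.distillation"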

spacy/language.py

@@ -48,6 +48,9 @@ PipeCallable = Callable[[Doc], Doc]
 # This is the base config with all settings (training etc.)
 DEFAULT_CONFIG_PATH = Path(__file__).parent / "default_config.cfg"
 DEFAULT_CONFIG = util.load_config(DEFAULT_CONFIG_PATH)
+# This is the base config for the [distillation] block and currently not included
+# in the main config and only added via the 'init fill-config' command
+DEFAULT_CONFIG_DISTILL_PATH = Path(__file__).parent / "default_config_distillation.cfg"
 # This is the base config for the [pretraining] block and currently not included
 # in the main config and only added via the 'init fill-config' command
 DEFAULT_CONFIG_PRETRAIN_PATH = Path(__file__).parent / "default_config_pretraining.cfg"

spacy/schemas.py

@@ -422,6 +422,27 @@ class ConfigSchemaInit(BaseModel):
         arbitrary_types_allowed = True
 
 
+class ConfigSchemaDistillEmpty(BaseModel):
+    class Config:
+        extra = "forbid"
+
+
+class ConfigSchemaDistill(BaseModel):
+    # fmt: off
+    batcher: Batcher = Field(..., title="Batcher for the training data")
+    corpus: StrictStr = Field(..., title="Path in the config to the distillation data")
+    dropout: StrictFloat = Field(..., title="Dropout rate")
+    max_epochs: StrictInt = Field(..., title="Maximum number of epochs to distill for")
+    max_steps: StrictInt = Field(..., title="Maximum number of steps to distill for")
+    optimizer: Optimizer = Field(..., title="The optimizer to use")
+    student_to_teacher: Dict[str, str] = Field(..., title="Mapping from student to teacher pipe")
+    # fmt: on
+
+    class Config:
+        extra = "forbid"
+        arbitrary_types_allowed = True
+
+
 class ConfigSchema(BaseModel):
     training: ConfigSchemaTraining
     nlp: ConfigSchemaNlp
@@ -429,6 +450,7 @@ class ConfigSchema(BaseModel):
     components: Dict[str, Dict[str, Any]]
     corpora: Dict[str, Reader]
     initialize: ConfigSchemaInit
+    distillation: Union[ConfigSchemaDistill, ConfigSchemaDistillEmpty] = {}  # type: ignore[assignment]
 
     class Config:
         extra = "allow"
@@ -440,6 +462,7 @@ CONFIG_SCHEMAS = {
     "training": ConfigSchemaTraining,
     "pretraining": ConfigSchemaPretrain,
     "initialize": ConfigSchemaInit,
+    "distill": ConfigSchemaDistill,
 }
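
Since ConfigSchemaDistillEmpty forbids extra fields, a config without a real [distillation] block still validates while stray keys are rejected; a minimal sketch, assuming the pydantic v1 semantics spaCy uses:

from pydantic import ValidationError
from spacy.schemas import ConfigSchemaDistillEmpty

ConfigSchemaDistillEmpty()  # an absent or empty block is valid
try:
    ConfigSchemaDistillEmpty(dropout=0.1)
except ValidationError:
    print("extra keys are forbidden on the empty schema")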

spacy/tests/serialize/test_serialize_config.py

@@ -6,10 +6,11 @@ import spacy
 from spacy.lang.de import German
 from spacy.lang.en import English
 from spacy.language import DEFAULT_CONFIG, DEFAULT_CONFIG_PRETRAIN_PATH
+from spacy.language import DEFAULT_CONFIG_DISTILL_PATH
 from spacy.language import Language
 from spacy.ml.models import MaxoutWindowEncoder, MultiHashEmbed
 from spacy.ml.models import build_tb_parser_model, build_Tok2Vec_model
-from spacy.schemas import ConfigSchema, ConfigSchemaPretrain
+from spacy.schemas import ConfigSchema, ConfigSchemaDistill, ConfigSchemaPretrain
 from spacy.util import load_config, load_config_from_str
 from spacy.util import load_model_from_config, registry
@@ -66,6 +67,60 @@ factory = "tagger"
 width = ${components.tok2vec.model.width}
 """
 
+distill_config_string = """
+[paths]
+train = null
+dev = null
+
+[corpora]
+
+[corpora.train]
+@readers = "spacy.Corpus.v1"
+path = ${paths.train}
+
+[corpora.dev]
+@readers = "spacy.Corpus.v1"
+path = ${paths.dev}
+
+[training]
+
+[training.batcher]
+@batchers = "spacy.batch_by_words.v1"
+size = 666
+
+[nlp]
+lang = "en"
+pipeline = ["tok2vec", "tagger"]
+
+[components]
+
+[components.tok2vec]
+factory = "tok2vec"
+
+[components.tok2vec.model]
+@architectures = "spacy.HashEmbedCNN.v1"
+pretrained_vectors = null
+width = 342
+depth = 4
+window_size = 1
+embed_size = 2000
+maxout_pieces = 3
+subword_features = true
+
+[components.tagger]
+factory = "tagger"
+
+[components.tagger.model]
+@architectures = "spacy.Tagger.v2"
+
+[components.tagger.model.tok2vec]
+@architectures = "spacy.Tok2VecListener.v1"
+width = ${components.tok2vec.model.width}
+
+[distill]
+"""
+
+
 pretrain_config_string = """
 [paths]
 train = null
@@ -201,6 +256,14 @@ def test_create_nlp_from_config():
     load_model_from_config(Config(bad_cfg), auto_fill=True)
 
 
+def test_nlp_from_distillation_config():
+    """Test that the default distillation config validates properly"""
+    config = Config().from_str(distill_config_string)
+    distill_config = load_config(DEFAULT_CONFIG_DISTILL_PATH)
+    filled = config.merge(distill_config)
+    registry.resolve(filled["distillation"], schema=ConfigSchemaDistill)
+
+
 def test_create_nlp_from_pretraining_config():
     """Test that the default pretraining config validates properly"""
     config = Config().from_str(pretrain_config_string)
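
The resolve call in the new test doubles as the validation path: if required settings are missing, resolution fails. A sketch of that failure mode, assuming thinc's ConfigValidationError:

from thinc.api import Config
from thinc.config import ConfigValidationError
from spacy.schemas import ConfigSchemaDistill
from spacy.util import registry

# A [distillation] block missing required settings (batcher, corpus,
# optimizer, ...) fails schema validation during resolution.
partial = Config().from_str("[distillation]\ndropout = 0.1")
try:
    registry.resolve(partial["distillation"], schema=ConfigSchemaDistill)
except ConfigValidationError:
    print("missing required distillation settings")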