From 3ce1d65dc35389bc044799ad2d7b2ea2ff7e7eec Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?=
Date: Mon, 30 Jan 2023 12:33:03 +0100
Subject: [PATCH] Add the configuration schema for distillation

This also adds the default configuration and some tests. The schema
will be used by the training loop and `distill` subcommand.
---
 spacy/cli/init_config.py              |  9 ++-
 spacy/default_config_distillation.cfg | 34 ++++++++++
 spacy/language.py                     |  3 +
 spacy/schemas.py                      | 23 +++++++
 .../tests/serialize/test_serialize_config.py | 65 ++++++++++++++++++-
 5 files changed, 131 insertions(+), 3 deletions(-)
 create mode 100644 spacy/default_config_distillation.cfg

diff --git a/spacy/cli/init_config.py b/spacy/cli/init_config.py
index b634caa4c..987fa8977 100644
--- a/spacy/cli/init_config.py
+++ b/spacy/cli/init_config.py
@@ -8,7 +8,7 @@ import re
 from jinja2 import Template

 from .. import util
-from ..language import DEFAULT_CONFIG_PRETRAIN_PATH
+from ..language import DEFAULT_CONFIG_DISTILL_PATH, DEFAULT_CONFIG_PRETRAIN_PATH
 from ..schemas import RecommendationSchema
 from ..util import SimpleFrozenList
 from ._util import init_cli, Arg, Opt, show_validation_error, COMMAND
@@ -83,6 +83,7 @@ def init_fill_config_cli(
     # fmt: off
     base_path: Path = Arg(..., help="Path to base config to fill", exists=True, dir_okay=False),
     output_file: Path = Arg("-", help="Path to output .cfg file (or - for stdout)", allow_dash=True),
+    distillation: bool = Opt(False, "--distillation", "-dt", help="Include config for distillation (with 'spacy distill')"),
     pretraining: bool = Opt(False, "--pretraining", "-pt", help="Include config for pretraining (with 'spacy pretrain')"),
     diff: bool = Opt(False, "--diff", "-D", help="Print a visual diff highlighting the changes"),
     code_path: Optional[Path] = Opt(None, "--code-path", "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
@@ -98,13 +99,14 @@ def init_fill_config_cli(

     DOCS: 
https://spacy.io/api/cli#init-fill-config """ import_code(code_path) - fill_config(output_file, base_path, pretraining=pretraining, diff=diff) + fill_config(output_file, base_path, distillation=distillation, pretraining=pretraining, diff=diff) def fill_config( output_file: Path, base_path: Path, *, + distillation: bool = False, pretraining: bool = False, diff: bool = False, silent: bool = False, @@ -123,6 +125,9 @@ def fill_config( # replaced with their actual config after loading, so we have to re-add them sourced = util.get_sourced_components(config) filled["components"].update(sourced) + if distillation: + distillation_config = util.load_config(DEFAULT_CONFIG_DISTILL_PATH) + filled = distillation_config.merge(filled) if pretraining: validate_config_for_pretrain(filled, msg) pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH) diff --git a/spacy/default_config_distillation.cfg b/spacy/default_config_distillation.cfg new file mode 100644 index 000000000..2993f8e48 --- /dev/null +++ b/spacy/default_config_distillation.cfg @@ -0,0 +1,34 @@ +[paths] +raw_text = null + +[distillation] +corpus = "corpora.distillation" +dropout = 0.1 +max_epochs = 1 +max_steps = 0 +pipe_map = {} + +[distillation.batcher] +@batchers = "spacy.batch_by_words.v1" +size = 3000 +discard_oversize = false +tolerance = 0.2 + +[distillation.optimizer] +@optimizers = "Adam.v1" +beta1 = 0.9 +beta2 = 0.999 +L2_is_weight_decay = true +L2 = 0.01 +grad_clip = 1.0 +use_averages = true +eps = 1e-8 +learn_rate = 1e-4 + +[corpora] + +[corpora.distillation] +@readers = "spacy.PlainTextCorpus.v1" +path = ${paths.raw_text} +min_length = 0 +max_length = 0 diff --git a/spacy/language.py b/spacy/language.py index dcb62aef0..1616ea678 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -49,6 +49,9 @@ PipeCallable = Callable[[Doc], Doc] # This is the base config will all settings (training etc.) 
DEFAULT_CONFIG_PATH = Path(__file__).parent / "default_config.cfg"
 DEFAULT_CONFIG = util.load_config(DEFAULT_CONFIG_PATH)
+# This is the base config for the [distillation] block and currently not included
+# in the main config and only added via the 'init fill-config' command
+DEFAULT_CONFIG_DISTILL_PATH = Path(__file__).parent / "default_config_distillation.cfg"
 # This is the base config for the [pretraining] block and currently not included
 # in the main config and only added via the 'init fill-config' command
 DEFAULT_CONFIG_PRETRAIN_PATH = Path(__file__).parent / "default_config_pretraining.cfg"
diff --git a/spacy/schemas.py b/spacy/schemas.py
index 15f7a499b..ea113fd7b 100644
--- a/spacy/schemas.py
+++ b/spacy/schemas.py
@@ -405,6 +405,27 @@ class ConfigSchemaInit(BaseModel):
         arbitrary_types_allowed = True


+class ConfigSchemaDistillEmpty(BaseModel):
+    class Config:
+        extra = "forbid"
+
+
+class ConfigSchemaDistill(BaseModel):
+    # fmt: off
+    batcher: Batcher = Field(..., title="Batcher for the training data")
+    corpus: StrictStr = Field(..., title="Path in the config to the distillation data")
+    dropout: StrictFloat = Field(..., title="Dropout rate")
+    max_epochs: StrictInt = Field(..., title="Maximum number of epochs to distill for")
+    max_steps: StrictInt = Field(..., title="Maximum number of steps to distill for")
+    optimizer: Optimizer = Field(..., title="The optimizer to use")
+    pipe_map: Dict[str, str] = Field(..., title="Mapping from teacher to student pipe")
+    # fmt: on
+
+    class Config:
+        extra = "forbid"
+        arbitrary_types_allowed = True
+
+
 class ConfigSchema(BaseModel):
     training: ConfigSchemaTraining
     nlp: ConfigSchemaNlp
@@ -412,6 +433,7 @@ class ConfigSchema(BaseModel):
     components: Dict[str, Dict[str, Any]]
     corpora: Dict[str, Reader]
     initialize: ConfigSchemaInit
+    distillation: Union[ConfigSchemaDistill, ConfigSchemaDistillEmpty] = {}  # type: ignore[assignment]

     class Config:
         extra = "allow"
@@ -423,6 +445,7 @@ CONFIG_SCHEMAS = {
    "training": 
ConfigSchemaTraining,
     "pretraining": ConfigSchemaPretrain,
     "initialize": ConfigSchemaInit,
+    "distillation": ConfigSchemaDistill,
 }

diff --git a/spacy/tests/serialize/test_serialize_config.py b/spacy/tests/serialize/test_serialize_config.py
index 82f01dcc2..6eb95001a 100644
--- a/spacy/tests/serialize/test_serialize_config.py
+++ b/spacy/tests/serialize/test_serialize_config.py
@@ -6,10 +6,11 @@ import spacy
 from spacy.lang.de import German
 from spacy.lang.en import English
 from spacy.language import DEFAULT_CONFIG, DEFAULT_CONFIG_PRETRAIN_PATH
+from spacy.language import DEFAULT_CONFIG_DISTILL_PATH
 from spacy.language import Language
 from spacy.ml.models import MaxoutWindowEncoder, MultiHashEmbed
 from spacy.ml.models import build_tb_parser_model, build_Tok2Vec_model
-from spacy.schemas import ConfigSchema, ConfigSchemaPretrain
+from spacy.schemas import ConfigSchema, ConfigSchemaDistill, ConfigSchemaPretrain
 from spacy.util import load_config, load_config_from_str
 from spacy.util import load_model_from_config, registry

@@ -66,6 +67,60 @@ factory = "tagger"
 width = ${components.tok2vec.model.width}
 """

+distill_config_string = """
+[paths]
+train = null
+dev = null
+
+[corpora]
+
+[corpora.train]
+@readers = "spacy.Corpus.v1"
+path = ${paths.train}
+
+[corpora.dev]
+@readers = "spacy.Corpus.v1"
+path = ${paths.dev}
+
+[training]
+
+[training.batcher]
+@batchers = "spacy.batch_by_words.v1"
+size = 666
+
+[nlp]
+lang = "en"
+pipeline = ["tok2vec", "tagger"]
+
+[components]
+
+[components.tok2vec]
+factory = "tok2vec"
+
+[components.tok2vec.model]
+@architectures = "spacy.HashEmbedCNN.v1"
+pretrained_vectors = null
+width = 342
+depth = 4
+window_size = 1
+embed_size = 2000
+maxout_pieces = 3
+subword_features = true
+
+[components.tagger]
+factory = "tagger"
+
+[components.tagger.model]
+@architectures = "spacy.Tagger.v2"
+
+[components.tagger.model.tok2vec]
+@architectures = "spacy.Tok2VecListener.v1"
+width = ${components.tok2vec.model.width}
+
+[distillation]
+"""
+
+
 
pretrain_config_string = """ [paths] train = null @@ -201,6 +256,14 @@ def test_create_nlp_from_config(): load_model_from_config(Config(bad_cfg), auto_fill=True) +def test_nlp_from_distillation_config(): + """Test that the default distillation config validates properly""" + config = Config().from_str(distill_config_string) + distill_config = load_config(DEFAULT_CONFIG_DISTILL_PATH) + filled = config.merge(distill_config) + registry.resolve(filled["distillation"], schema=ConfigSchemaDistill) + + def test_create_nlp_from_pretraining_config(): """Test that the default pretraining config validates properly""" config = Config().from_str(pretrain_config_string)