Mirror of https://github.com/explosion/spaCy.git
	Add the configuration schema for distillation (#12201)
* Add the configuration schema for distillation

  This also adds the default configuration and some tests. The schema will be
  used by the training loop and the `distill` subcommand.

* Format
* Change distillation shortopt to -d
* Fix description of max_epochs
* Rename distillation flag to -dt
* Rename `pipe_map` to `student_to_teacher`
This commit is contained in:
parent 1b5aba9e22
commit fb7f018ded
spacy/cli/init_config.py

@@ -8,7 +8,7 @@ import re
 from jinja2 import Template

 from .. import util
-from ..language import DEFAULT_CONFIG_PRETRAIN_PATH
+from ..language import DEFAULT_CONFIG_DISTILL_PATH, DEFAULT_CONFIG_PRETRAIN_PATH
 from ..schemas import RecommendationSchema
 from ..util import SimpleFrozenList
 from ._util import init_cli, Arg, Opt, show_validation_error, COMMAND
@@ -83,6 +83,7 @@ def init_fill_config_cli(
     # fmt: off
     base_path: Path = Arg(..., help="Path to base config to fill", exists=True, dir_okay=False),
     output_file: Path = Arg("-", help="Path to output .cfg file (or - for stdout)", allow_dash=True),
+    distillation: bool = Opt(False, "--distillation", "-dt", help="Include config for distillation (with 'spacy distill')"),
     pretraining: bool = Opt(False, "--pretraining", "-pt", help="Include config for pretraining (with 'spacy pretrain')"),
     diff: bool = Opt(False, "--diff", "-D", help="Print a visual diff highlighting the changes"),
     code_path: Optional[Path] = Opt(None, "--code-path", "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
@@ -98,13 +99,20 @@ def init_fill_config_cli(
     DOCS: https://spacy.io/api/cli#init-fill-config
     """
     import_code(code_path)
-    fill_config(output_file, base_path, pretraining=pretraining, diff=diff)
+    fill_config(
+        output_file,
+        base_path,
+        distillation=distillation,
+        pretraining=pretraining,
+        diff=diff,
+    )


 def fill_config(
     output_file: Path,
     base_path: Path,
     *,
+    distillation: bool = False,
     pretraining: bool = False,
     diff: bool = False,
     silent: bool = False,
@@ -123,6 +131,9 @@ def fill_config(
     # replaced with their actual config after loading, so we have to re-add them
     sourced = util.get_sourced_components(config)
     filled["components"].update(sourced)
+    if distillation:
+        distillation_config = util.load_config(DEFAULT_CONFIG_DISTILL_PATH)
+        filled = distillation_config.merge(filled)
     if pretraining:
         validate_config_for_pretrain(filled, msg)
         pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH)
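As a usage illustration (not part of the commit): the new `--distillation`/`-dt` flag on `spacy init fill-config` simply forwards to `fill_config`, so the Python equivalent is roughly the sketch below; the file paths are hypothetical placeholders.

# Hypothetical usage sketch, assuming spaCy at this commit.
# base.cfg / filled.cfg are placeholder paths, not files from the repo.
from pathlib import Path
from spacy.cli.init_config import fill_config

# Roughly: python -m spacy init fill-config base.cfg filled.cfg -dt
fill_config(
    Path("filled.cfg"),  # output_file; "-" would write to stdout instead
    Path("base.cfg"),    # partial base config to fill with defaults
    distillation=True,   # merge in default_config_distillation.cfg
)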
							
								
								
									
spacy/default_config_distillation.cfg (new file, 34 lines)

@@ -0,0 +1,34 @@
+[paths]
+raw_text = null
+
+[distillation]
+corpus = "corpora.distillation"
+dropout = 0.1
+max_epochs = 1
+max_steps = 0
+student_to_teacher = {}
+
+[distillation.batcher]
+@batchers = "spacy.batch_by_words.v1"
+size = 3000
+discard_oversize = false
+tolerance = 0.2
+
+[distillation.optimizer]
+@optimizers = "Adam.v1"
+beta1 = 0.9
+beta2 = 0.999
+L2_is_weight_decay = true
+L2 = 0.01
+grad_clip = 1.0
+use_averages = true
+eps = 1e-8
+learn_rate = 1e-4
+
+[corpora]
+
+[corpora.distillation]
+@readers = "spacy.PlainTextCorpus.v1"
+path = ${paths.raw_text}
+min_length = 0
+max_length = 0
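For orientation (a sketch, not part of the commit): the `corpus` setting names a section path inside the config rather than a file, and `${paths.raw_text}` stays a variable until interpolation. Loading the shipped defaults directly shows both:

# Sketch assuming spaCy at this commit; demonstrates that the distillation
# corpus reader is wired up via config references, not hard-coded paths.
from spacy.language import DEFAULT_CONFIG_DISTILL_PATH
from spacy.util import load_config

config = load_config(DEFAULT_CONFIG_DISTILL_PATH)
print(config["distillation"]["corpus"])         # "corpora.distillation"
interpolated = config.interpolate()
print(interpolated["corpora"]["distillation"])  # path is None until paths.raw_text is set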
spacy/language.py

@@ -48,6 +48,9 @@ PipeCallable = Callable[[Doc], Doc]
 # This is the base config will all settings (training etc.)
 DEFAULT_CONFIG_PATH = Path(__file__).parent / "default_config.cfg"
 DEFAULT_CONFIG = util.load_config(DEFAULT_CONFIG_PATH)
+# This is the base config for the [distillation] block and currently not included
+# in the main config and only added via the 'init fill-config' command
+DEFAULT_CONFIG_DISTILL_PATH = Path(__file__).parent / "default_config_distillation.cfg"
 # This is the base config for the [pretraining] block and currently not included
 # in the main config and only added via the 'init fill-config' command
 DEFAULT_CONFIG_PRETRAIN_PATH = Path(__file__).parent / "default_config_pretraining.cfg"
spacy/schemas.py

@@ -422,6 +422,27 @@ class ConfigSchemaInit(BaseModel):
         arbitrary_types_allowed = True


+class ConfigSchemaDistillEmpty(BaseModel):
+    class Config:
+        extra = "forbid"
+
+
+class ConfigSchemaDistill(BaseModel):
+    # fmt: off
+    batcher: Batcher = Field(..., title="Batcher for the training data")
+    corpus: StrictStr = Field(..., title="Path in the config to the distillation data")
+    dropout: StrictFloat = Field(..., title="Dropout rate")
+    max_epochs: StrictInt = Field(..., title="Maximum number of epochs to distill for")
+    max_steps: StrictInt = Field(..., title="Maximum number of steps to distill for")
+    optimizer: Optimizer = Field(..., title="The optimizer to use")
+    student_to_teacher: Dict[str, str] = Field(..., title="Mapping from student to teacher pipe")
+    # fmt: on
+
+    class Config:
+        extra = "forbid"
+        arbitrary_types_allowed = True
+
+
 class ConfigSchema(BaseModel):
     training: ConfigSchemaTraining
     nlp: ConfigSchemaNlp
@@ -429,6 +450,7 @@ class ConfigSchema(BaseModel):
     components: Dict[str, Dict[str, Any]]
     corpora: Dict[str, Reader]
     initialize: ConfigSchemaInit
+    distillation: Union[ConfigSchemaDistill, ConfigSchemaDistillEmpty] = {}  # type: ignore[assignment]

     class Config:
         extra = "allow"
@@ -440,6 +462,7 @@ CONFIG_SCHEMAS = {
     "training": ConfigSchemaTraining,
     "pretraining": ConfigSchemaPretrain,
     "initialize": ConfigSchemaInit,
+    "distill": ConfigSchemaDistill,
 }
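The `Union[ConfigSchemaDistill, ConfigSchemaDistillEmpty]` default means a config without a populated `[distillation]` block still validates, while anything inside the block is checked strictly. A minimal sketch of that behavior, assuming this commit is installed:

# Sketch: ConfigSchemaDistillEmpty accepts only an empty block, since
# extra = "forbid" rejects any unexpected keys.
from pydantic import ValidationError
from spacy.schemas import ConfigSchemaDistillEmpty

ConfigSchemaDistillEmpty()  # ok: an empty [distillation] block validates
try:
    ConfigSchemaDistillEmpty(dropout=0.1)
except ValidationError:
    print("extra keys are rejected")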
spacy/tests/serialize/test_serialize_config.py

@@ -6,10 +6,11 @@ import spacy
 from spacy.lang.de import German
 from spacy.lang.en import English
 from spacy.language import DEFAULT_CONFIG, DEFAULT_CONFIG_PRETRAIN_PATH
+from spacy.language import DEFAULT_CONFIG_DISTILL_PATH
 from spacy.language import Language
 from spacy.ml.models import MaxoutWindowEncoder, MultiHashEmbed
 from spacy.ml.models import build_tb_parser_model, build_Tok2Vec_model
-from spacy.schemas import ConfigSchema, ConfigSchemaPretrain
+from spacy.schemas import ConfigSchema, ConfigSchemaDistill, ConfigSchemaPretrain
 from spacy.util import load_config, load_config_from_str
 from spacy.util import load_model_from_config, registry

@@ -66,6 +67,60 @@ factory = "tagger"
 width = ${components.tok2vec.model.width}
 """

+distill_config_string = """
+[paths]
+train = null
+dev = null
+
+[corpora]
+
+[corpora.train]
+@readers = "spacy.Corpus.v1"
+path = ${paths.train}
+
+[corpora.dev]
+@readers = "spacy.Corpus.v1"
+path = ${paths.dev}
+
+[training]
+
+[training.batcher]
+@batchers = "spacy.batch_by_words.v1"
+size = 666
+
+[nlp]
+lang = "en"
+pipeline = ["tok2vec", "tagger"]
+
+[components]
+
+[components.tok2vec]
+factory = "tok2vec"
+
+[components.tok2vec.model]
+@architectures = "spacy.HashEmbedCNN.v1"
+pretrained_vectors = null
+width = 342
+depth = 4
+window_size = 1
+embed_size = 2000
+maxout_pieces = 3
+subword_features = true
+
+[components.tagger]
+factory = "tagger"
+
+[components.tagger.model]
+@architectures = "spacy.Tagger.v2"
+
+[components.tagger.model.tok2vec]
+@architectures = "spacy.Tok2VecListener.v1"
+width = ${components.tok2vec.model.width}
+
+[distill]
+"""
+
+
 pretrain_config_string = """
 [paths]
 train = null
@@ -201,6 +256,14 @@ def test_create_nlp_from_config():
         load_model_from_config(Config(bad_cfg), auto_fill=True)


+def test_nlp_from_distillation_config():
+    """Test that the default distillation config validates properly"""
+    config = Config().from_str(distill_config_string)
+    distill_config = load_config(DEFAULT_CONFIG_DISTILL_PATH)
+    filled = config.merge(distill_config)
+    registry.resolve(filled["distillation"], schema=ConfigSchemaDistill)
+
+
 def test_create_nlp_from_pretraining_config():
     """Test that the default pretraining config validates properly"""
     config = Config().from_str(pretrain_config_string)
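Beyond the happy path covered by the new test, every field on `ConfigSchemaDistill` is declared with `Field(...)` and is therefore required. A sketch of the expected failure mode when a key is missing (assumed behavior, not a test from the commit):

# Sketch assuming spaCy at this commit: dropping a required key from the
# shipped defaults should make schema resolution fail.
from spacy.language import DEFAULT_CONFIG_DISTILL_PATH
from spacy.schemas import ConfigSchemaDistill
from spacy.util import load_config, registry

config = load_config(DEFAULT_CONFIG_DISTILL_PATH)
block = dict(config["distillation"])
block.pop("dropout")  # remove a required setting
try:
    registry.resolve(block, schema=ConfigSchemaDistill)
except Exception:  # thinc raises a ConfigValidationError here (assumed)
    print("validation failed, as expected")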