Mirror of https://github.com/explosion/spaCy.git (synced 2025-01-12 18:26:30 +03:00)
Add the configuration schema for distillation (#12201)
* Add the configuration schema for distillation

  This also adds the default configuration and some tests. The schema will be
  used by the training loop and the `distill` subcommand.

* Format
* Change distillation shortopt to -d
* Fix description of max_epochs
* Rename distillation flag to -dt
* Rename `pipe_map` to `student_to_teacher`
This commit is contained in:
parent 1b5aba9e22
commit fb7f018ded
@@ -8,7 +8,7 @@ import re
 from jinja2 import Template

 from .. import util
-from ..language import DEFAULT_CONFIG_PRETRAIN_PATH
+from ..language import DEFAULT_CONFIG_DISTILL_PATH, DEFAULT_CONFIG_PRETRAIN_PATH
 from ..schemas import RecommendationSchema
 from ..util import SimpleFrozenList
 from ._util import init_cli, Arg, Opt, show_validation_error, COMMAND
@@ -83,6 +83,7 @@ def init_fill_config_cli(
     # fmt: off
     base_path: Path = Arg(..., help="Path to base config to fill", exists=True, dir_okay=False),
     output_file: Path = Arg("-", help="Path to output .cfg file (or - for stdout)", allow_dash=True),
+    distillation: bool = Opt(False, "--distillation", "-dt", help="Include config for distillation (with 'spacy distill')"),
     pretraining: bool = Opt(False, "--pretraining", "-pt", help="Include config for pretraining (with 'spacy pretrain')"),
     diff: bool = Opt(False, "--diff", "-D", help="Print a visual diff highlighting the changes"),
     code_path: Optional[Path] = Opt(None, "--code-path", "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
@@ -98,13 +99,20 @@ def init_fill_config_cli(
     DOCS: https://spacy.io/api/cli#init-fill-config
     """
     import_code(code_path)
-    fill_config(output_file, base_path, pretraining=pretraining, diff=diff)
+    fill_config(
+        output_file,
+        base_path,
+        distillation=distillation,
+        pretraining=pretraining,
+        diff=diff,
+    )


 def fill_config(
     output_file: Path,
     base_path: Path,
     *,
+    distillation: bool = False,
     pretraining: bool = False,
     diff: bool = False,
     silent: bool = False,
@@ -123,6 +131,9 @@ def fill_config(
     # replaced with their actual config after loading, so we have to re-add them
     sourced = util.get_sourced_components(config)
     filled["components"].update(sourced)
+    if distillation:
+        distillation_config = util.load_config(DEFAULT_CONFIG_DISTILL_PATH)
+        filled = distillation_config.merge(filled)
     if pretraining:
         validate_config_for_pretrain(filled, msg)
         pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH)
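With the new option wired through, the distillation defaults can be merged into a base config either from the CLI (`spacy init fill-config base.cfg filled.cfg --distillation`) or by calling `fill_config` directly. A minimal sketch of the programmatic route, based on the signature shown above; the file paths are placeholders:

from pathlib import Path

from spacy.cli.init_config import fill_config

# Fill an existing base config and merge in the distillation defaults; passing
# "-" as the output path would print the result to stdout instead of a file.
fill_config(
    Path("filled.cfg"),   # output .cfg file (placeholder path)
    Path("base.cfg"),     # base config to fill (placeholder path)
    distillation=True,    # adds [distillation] plus its batcher/optimizer/corpora blocks
    diff=True,            # print a visual diff highlighting the changes
)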
spacy/default_config_distillation.cfg (new file, 34 lines)
@@ -0,0 +1,34 @@
+[paths]
+raw_text = null
+
+[distillation]
+corpus = "corpora.distillation"
+dropout = 0.1
+max_epochs = 1
+max_steps = 0
+student_to_teacher = {}
+
+[distillation.batcher]
+@batchers = "spacy.batch_by_words.v1"
+size = 3000
+discard_oversize = false
+tolerance = 0.2
+
+[distillation.optimizer]
+@optimizers = "Adam.v1"
+beta1 = 0.9
+beta2 = 0.999
+L2_is_weight_decay = true
+L2 = 0.01
+grad_clip = 1.0
+use_averages = true
+eps = 1e-8
+learn_rate = 1e-4
+
+[corpora]
+
+[corpora.distillation]
+@readers = "spacy.PlainTextCorpus.v1"
+path = ${paths.raw_text}
+min_length = 0
+max_length = 0
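The `@batchers` and `@optimizers` entries above refer to registered functions, so resolving the blocks through the registry produces the actual batcher and optimizer objects. A small sketch, with values copied from the defaults above and assuming the registered functions `spacy.batch_by_words.v1` and `Adam.v1` from a standard spaCy/Thinc install:

from thinc.api import Config

from spacy.util import registry

cfg = Config().from_str("""
[batcher]
@batchers = "spacy.batch_by_words.v1"
size = 3000
discard_oversize = false
tolerance = 0.2

[optimizer]
@optimizers = "Adam.v1"
learn_rate = 1e-4
""")

# registry.resolve() builds the objects referenced by @batchers / @optimizers.
resolved = registry.resolve(cfg)
batcher, optimizer = resolved["batcher"], resolved["optimizer"]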
@@ -48,6 +48,9 @@ PipeCallable = Callable[[Doc], Doc]
 # This is the base config with all settings (training etc.)
 DEFAULT_CONFIG_PATH = Path(__file__).parent / "default_config.cfg"
 DEFAULT_CONFIG = util.load_config(DEFAULT_CONFIG_PATH)
+# This is the base config for the [distillation] block and currently not included
+# in the main config and only added via the 'init fill-config' command
+DEFAULT_CONFIG_DISTILL_PATH = Path(__file__).parent / "default_config_distillation.cfg"
 # This is the base config for the [pretraining] block and currently not included
 # in the main config and only added via the 'init fill-config' command
 DEFAULT_CONFIG_PRETRAIN_PATH = Path(__file__).parent / "default_config_pretraining.cfg"
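As the comments note, the distillation defaults live in a separate file and are only merged in on demand (e.g. by `init fill-config`). A rough sketch of that mechanism using the constants defined here; merging against `DEFAULT_CONFIG` is only for illustration, since `fill_config` merges against the user's filled config:

from spacy import util
from spacy.language import DEFAULT_CONFIG, DEFAULT_CONFIG_DISTILL_PATH

# Load the [distillation] defaults and merge them under a copy of the main
# default config, mirroring what fill_config() does for --distillation.
distill_defaults = util.load_config(DEFAULT_CONFIG_DISTILL_PATH)
merged = distill_defaults.merge(DEFAULT_CONFIG.copy())
assert "distillation" in merged
assert merged["distillation"]["max_epochs"] == 1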
@@ -422,6 +422,27 @@ class ConfigSchemaInit(BaseModel):
         arbitrary_types_allowed = True


+class ConfigSchemaDistillEmpty(BaseModel):
+    class Config:
+        extra = "forbid"
+
+
+class ConfigSchemaDistill(BaseModel):
+    # fmt: off
+    batcher: Batcher = Field(..., title="Batcher for the training data")
+    corpus: StrictStr = Field(..., title="Path in the config to the distillation data")
+    dropout: StrictFloat = Field(..., title="Dropout rate")
+    max_epochs: StrictInt = Field(..., title="Maximum number of epochs to distill for")
+    max_steps: StrictInt = Field(..., title="Maximum number of steps to distill for")
+    optimizer: Optimizer = Field(..., title="The optimizer to use")
+    student_to_teacher: Dict[str, str] = Field(..., title="Mapping from student to teacher pipe")
+    # fmt: on
+
+    class Config:
+        extra = "forbid"
+        arbitrary_types_allowed = True
+
+
 class ConfigSchema(BaseModel):
     training: ConfigSchemaTraining
     nlp: ConfigSchemaNlp
@@ -429,6 +450,7 @@ class ConfigSchema(BaseModel):
     components: Dict[str, Dict[str, Any]]
     corpora: Dict[str, Reader]
     initialize: ConfigSchemaInit
+    distillation: Union[ConfigSchemaDistill, ConfigSchemaDistillEmpty] = {}  # type: ignore[assignment]

     class Config:
         extra = "allow"
@@ -440,6 +462,7 @@ CONFIG_SCHEMAS = {
     "training": ConfigSchemaTraining,
     "pretraining": ConfigSchemaPretrain,
     "initialize": ConfigSchemaInit,
+    "distill": ConfigSchemaDistill,
 }


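Because `ConfigSchemaDistill` forbids extra fields, a `[distillation]` block with an unknown or misspelled key is rejected when it is resolved against the schema. A rough sketch of that behaviour; the misspelled `max_epoch` key is deliberate, and the try/print is only for illustration:

from thinc.api import Config, ConfigValidationError

from spacy.schemas import ConfigSchemaDistill
from spacy.util import registry

bad_block = Config().from_str("""
[distillation]
corpus = "corpora.distillation"
dropout = 0.1
max_epoch = 1
max_steps = 0
student_to_teacher = {}

[distillation.batcher]
@batchers = "spacy.batch_by_words.v1"
size = 3000
discard_oversize = false
tolerance = 0.2

[distillation.optimizer]
@optimizers = "Adam.v1"
""")

try:
    # Fails: "max_epoch" is an extra field and the required "max_epochs" is missing.
    registry.resolve(bad_block["distillation"], schema=ConfigSchemaDistill)
except ConfigValidationError as err:
    print(err)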
@@ -6,10 +6,11 @@ import spacy
 from spacy.lang.de import German
 from spacy.lang.en import English
 from spacy.language import DEFAULT_CONFIG, DEFAULT_CONFIG_PRETRAIN_PATH
+from spacy.language import DEFAULT_CONFIG_DISTILL_PATH
 from spacy.language import Language
 from spacy.ml.models import MaxoutWindowEncoder, MultiHashEmbed
 from spacy.ml.models import build_tb_parser_model, build_Tok2Vec_model
-from spacy.schemas import ConfigSchema, ConfigSchemaPretrain
+from spacy.schemas import ConfigSchema, ConfigSchemaDistill, ConfigSchemaPretrain
 from spacy.util import load_config, load_config_from_str
 from spacy.util import load_model_from_config, registry

@@ -66,6 +67,60 @@ factory = "tagger"
 width = ${components.tok2vec.model.width}
 """

+distill_config_string = """
+[paths]
+train = null
+dev = null
+
+[corpora]
+
+[corpora.train]
+@readers = "spacy.Corpus.v1"
+path = ${paths.train}
+
+[corpora.dev]
+@readers = "spacy.Corpus.v1"
+path = ${paths.dev}
+
+[training]
+
+[training.batcher]
+@batchers = "spacy.batch_by_words.v1"
+size = 666
+
+[nlp]
+lang = "en"
+pipeline = ["tok2vec", "tagger"]
+
+[components]
+
+[components.tok2vec]
+factory = "tok2vec"
+
+[components.tok2vec.model]
+@architectures = "spacy.HashEmbedCNN.v1"
+pretrained_vectors = null
+width = 342
+depth = 4
+window_size = 1
+embed_size = 2000
+maxout_pieces = 3
+subword_features = true
+
+[components.tagger]
+factory = "tagger"
+
+[components.tagger.model]
+@architectures = "spacy.Tagger.v2"
+
+[components.tagger.model.tok2vec]
+@architectures = "spacy.Tok2VecListener.v1"
+width = ${components.tok2vec.model.width}
+
+[distill]
+"""
+
+
 pretrain_config_string = """
 [paths]
 train = null
@@ -201,6 +256,14 @@ def test_create_nlp_from_config():
         load_model_from_config(Config(bad_cfg), auto_fill=True)


+def test_nlp_from_distillation_config():
+    """Test that the default distillation config validates properly"""
+    config = Config().from_str(distill_config_string)
+    distill_config = load_config(DEFAULT_CONFIG_DISTILL_PATH)
+    filled = config.merge(distill_config)
+    registry.resolve(filled["distillation"], schema=ConfigSchemaDistill)
+
+
 def test_create_nlp_from_pretraining_config():
     """Test that the default pretraining config validates properly"""
     config = Config().from_str(pretrain_config_string)