mirror of
https://github.com/explosion/spaCy.git
synced 2024-09-21 03:19:13 +03:00
Add the configuration schema for distillation (#12201)
* Add the configuration schema for distillation This also adds the default configuration and some tests. The schema will be used by the training loop and `distill` subcommand. * Format * Change distillation shortopt to -d * Fix description of max_epochs * Rename distillation flag to -dt * Rename `pipe_map` to `student_to_teacher`
This commit is contained in:
parent
1b5aba9e22
commit
fb7f018ded
|
@ -8,7 +8,7 @@ import re
|
||||||
from jinja2 import Template
|
from jinja2 import Template
|
||||||
|
|
||||||
from .. import util
|
from .. import util
|
||||||
from ..language import DEFAULT_CONFIG_PRETRAIN_PATH
|
from ..language import DEFAULT_CONFIG_DISTILL_PATH, DEFAULT_CONFIG_PRETRAIN_PATH
|
||||||
from ..schemas import RecommendationSchema
|
from ..schemas import RecommendationSchema
|
||||||
from ..util import SimpleFrozenList
|
from ..util import SimpleFrozenList
|
||||||
from ._util import init_cli, Arg, Opt, show_validation_error, COMMAND
|
from ._util import init_cli, Arg, Opt, show_validation_error, COMMAND
|
||||||
|
@ -83,6 +83,7 @@ def init_fill_config_cli(
|
||||||
# fmt: off
|
# fmt: off
|
||||||
base_path: Path = Arg(..., help="Path to base config to fill", exists=True, dir_okay=False),
|
base_path: Path = Arg(..., help="Path to base config to fill", exists=True, dir_okay=False),
|
||||||
output_file: Path = Arg("-", help="Path to output .cfg file (or - for stdout)", allow_dash=True),
|
output_file: Path = Arg("-", help="Path to output .cfg file (or - for stdout)", allow_dash=True),
|
||||||
|
distillation: bool = Opt(False, "--distillation", "-dt", help="Include config for distillation (with 'spacy distill')"),
|
||||||
pretraining: bool = Opt(False, "--pretraining", "-pt", help="Include config for pretraining (with 'spacy pretrain')"),
|
pretraining: bool = Opt(False, "--pretraining", "-pt", help="Include config for pretraining (with 'spacy pretrain')"),
|
||||||
diff: bool = Opt(False, "--diff", "-D", help="Print a visual diff highlighting the changes"),
|
diff: bool = Opt(False, "--diff", "-D", help="Print a visual diff highlighting the changes"),
|
||||||
code_path: Optional[Path] = Opt(None, "--code-path", "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
|
code_path: Optional[Path] = Opt(None, "--code-path", "--code", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
|
||||||
|
@ -98,13 +99,20 @@ def init_fill_config_cli(
|
||||||
DOCS: https://spacy.io/api/cli#init-fill-config
|
DOCS: https://spacy.io/api/cli#init-fill-config
|
||||||
"""
|
"""
|
||||||
import_code(code_path)
|
import_code(code_path)
|
||||||
fill_config(output_file, base_path, pretraining=pretraining, diff=diff)
|
fill_config(
|
||||||
|
output_file,
|
||||||
|
base_path,
|
||||||
|
distillation=distillation,
|
||||||
|
pretraining=pretraining,
|
||||||
|
diff=diff,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def fill_config(
|
def fill_config(
|
||||||
output_file: Path,
|
output_file: Path,
|
||||||
base_path: Path,
|
base_path: Path,
|
||||||
*,
|
*,
|
||||||
|
distillation: bool = False,
|
||||||
pretraining: bool = False,
|
pretraining: bool = False,
|
||||||
diff: bool = False,
|
diff: bool = False,
|
||||||
silent: bool = False,
|
silent: bool = False,
|
||||||
|
@ -123,6 +131,9 @@ def fill_config(
|
||||||
# replaced with their actual config after loading, so we have to re-add them
|
# replaced with their actual config after loading, so we have to re-add them
|
||||||
sourced = util.get_sourced_components(config)
|
sourced = util.get_sourced_components(config)
|
||||||
filled["components"].update(sourced)
|
filled["components"].update(sourced)
|
||||||
|
if distillation:
|
||||||
|
distillation_config = util.load_config(DEFAULT_CONFIG_DISTILL_PATH)
|
||||||
|
filled = distillation_config.merge(filled)
|
||||||
if pretraining:
|
if pretraining:
|
||||||
validate_config_for_pretrain(filled, msg)
|
validate_config_for_pretrain(filled, msg)
|
||||||
pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH)
|
pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH)
|
||||||
|
|
34
spacy/default_config_distillation.cfg
Normal file
34
spacy/default_config_distillation.cfg
Normal file
|
@ -0,0 +1,34 @@
|
||||||
|
[paths]
|
||||||
|
raw_text = null
|
||||||
|
|
||||||
|
[distillation]
|
||||||
|
corpus = "corpora.distillation"
|
||||||
|
dropout = 0.1
|
||||||
|
max_epochs = 1
|
||||||
|
max_steps = 0
|
||||||
|
student_to_teacher = {}
|
||||||
|
|
||||||
|
[distillation.batcher]
|
||||||
|
@batchers = "spacy.batch_by_words.v1"
|
||||||
|
size = 3000
|
||||||
|
discard_oversize = false
|
||||||
|
tolerance = 0.2
|
||||||
|
|
||||||
|
[distillation.optimizer]
|
||||||
|
@optimizers = "Adam.v1"
|
||||||
|
beta1 = 0.9
|
||||||
|
beta2 = 0.999
|
||||||
|
L2_is_weight_decay = true
|
||||||
|
L2 = 0.01
|
||||||
|
grad_clip = 1.0
|
||||||
|
use_averages = true
|
||||||
|
eps = 1e-8
|
||||||
|
learn_rate = 1e-4
|
||||||
|
|
||||||
|
[corpora]
|
||||||
|
|
||||||
|
[corpora.distillation]
|
||||||
|
@readers = "spacy.PlainTextCorpus.v1"
|
||||||
|
path = ${paths.raw_text}
|
||||||
|
min_length = 0
|
||||||
|
max_length = 0
|
|
@ -48,6 +48,9 @@ PipeCallable = Callable[[Doc], Doc]
|
||||||
# This is the base config with all settings (training etc.)
|
# This is the base config with all settings (training etc.)
|
||||||
DEFAULT_CONFIG_PATH = Path(__file__).parent / "default_config.cfg"
|
DEFAULT_CONFIG_PATH = Path(__file__).parent / "default_config.cfg"
|
||||||
DEFAULT_CONFIG = util.load_config(DEFAULT_CONFIG_PATH)
|
DEFAULT_CONFIG = util.load_config(DEFAULT_CONFIG_PATH)
|
||||||
|
# This is the base config for the [distillation] block and currently not included
|
||||||
|
# in the main config and only added via the 'init fill-config' command
|
||||||
|
DEFAULT_CONFIG_DISTILL_PATH = Path(__file__).parent / "default_config_distillation.cfg"
|
||||||
# This is the base config for the [pretraining] block and currently not included
|
# This is the base config for the [pretraining] block and currently not included
|
||||||
# in the main config and only added via the 'init fill-config' command
|
# in the main config and only added via the 'init fill-config' command
|
||||||
DEFAULT_CONFIG_PRETRAIN_PATH = Path(__file__).parent / "default_config_pretraining.cfg"
|
DEFAULT_CONFIG_PRETRAIN_PATH = Path(__file__).parent / "default_config_pretraining.cfg"
|
||||||
|
|
|
@ -422,6 +422,27 @@ class ConfigSchemaInit(BaseModel):
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigSchemaDistillEmpty(BaseModel):
|
||||||
|
class Config:
|
||||||
|
extra = "forbid"
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigSchemaDistill(BaseModel):
|
||||||
|
# fmt: off
|
||||||
|
batcher: Batcher = Field(..., title="Batcher for the training data")
|
||||||
|
corpus: StrictStr = Field(..., title="Path in the config to the distillation data")
|
||||||
|
dropout: StrictFloat = Field(..., title="Dropout rate")
|
||||||
|
max_epochs: StrictInt = Field(..., title="Maximum number of epochs to distill for")
|
||||||
|
max_steps: StrictInt = Field(..., title="Maximum number of steps to distill for")
|
||||||
|
optimizer: Optimizer = Field(..., title="The optimizer to use")
|
||||||
|
student_to_teacher: Dict[str, str] = Field(..., title="Mapping from student to teacher pipe")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
extra = "forbid"
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
class ConfigSchema(BaseModel):
|
class ConfigSchema(BaseModel):
|
||||||
training: ConfigSchemaTraining
|
training: ConfigSchemaTraining
|
||||||
nlp: ConfigSchemaNlp
|
nlp: ConfigSchemaNlp
|
||||||
|
@ -429,6 +450,7 @@ class ConfigSchema(BaseModel):
|
||||||
components: Dict[str, Dict[str, Any]]
|
components: Dict[str, Dict[str, Any]]
|
||||||
corpora: Dict[str, Reader]
|
corpora: Dict[str, Reader]
|
||||||
initialize: ConfigSchemaInit
|
initialize: ConfigSchemaInit
|
||||||
|
distillation: Union[ConfigSchemaDistill, ConfigSchemaDistillEmpty] = {} # type: ignore[assignment]
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
extra = "allow"
|
extra = "allow"
|
||||||
|
@ -440,6 +462,7 @@ CONFIG_SCHEMAS = {
|
||||||
"training": ConfigSchemaTraining,
|
"training": ConfigSchemaTraining,
|
||||||
"pretraining": ConfigSchemaPretrain,
|
"pretraining": ConfigSchemaPretrain,
|
||||||
"initialize": ConfigSchemaInit,
|
"initialize": ConfigSchemaInit,
|
||||||
|
"distill": ConfigSchemaDistill,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -6,10 +6,11 @@ import spacy
|
||||||
from spacy.lang.de import German
|
from spacy.lang.de import German
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from spacy.language import DEFAULT_CONFIG, DEFAULT_CONFIG_PRETRAIN_PATH
|
from spacy.language import DEFAULT_CONFIG, DEFAULT_CONFIG_PRETRAIN_PATH
|
||||||
|
from spacy.language import DEFAULT_CONFIG_DISTILL_PATH
|
||||||
from spacy.language import Language
|
from spacy.language import Language
|
||||||
from spacy.ml.models import MaxoutWindowEncoder, MultiHashEmbed
|
from spacy.ml.models import MaxoutWindowEncoder, MultiHashEmbed
|
||||||
from spacy.ml.models import build_tb_parser_model, build_Tok2Vec_model
|
from spacy.ml.models import build_tb_parser_model, build_Tok2Vec_model
|
||||||
from spacy.schemas import ConfigSchema, ConfigSchemaPretrain
|
from spacy.schemas import ConfigSchema, ConfigSchemaDistill, ConfigSchemaPretrain
|
||||||
from spacy.util import load_config, load_config_from_str
|
from spacy.util import load_config, load_config_from_str
|
||||||
from spacy.util import load_model_from_config, registry
|
from spacy.util import load_model_from_config, registry
|
||||||
|
|
||||||
|
@ -66,6 +67,60 @@ factory = "tagger"
|
||||||
width = ${components.tok2vec.model.width}
|
width = ${components.tok2vec.model.width}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
distill_config_string = """
|
||||||
|
[paths]
|
||||||
|
train = null
|
||||||
|
dev = null
|
||||||
|
|
||||||
|
[corpora]
|
||||||
|
|
||||||
|
[corpora.train]
|
||||||
|
@readers = "spacy.Corpus.v1"
|
||||||
|
path = ${paths.train}
|
||||||
|
|
||||||
|
[corpora.dev]
|
||||||
|
@readers = "spacy.Corpus.v1"
|
||||||
|
path = ${paths.dev}
|
||||||
|
|
||||||
|
[training]
|
||||||
|
|
||||||
|
[training.batcher]
|
||||||
|
@batchers = "spacy.batch_by_words.v1"
|
||||||
|
size = 666
|
||||||
|
|
||||||
|
[nlp]
|
||||||
|
lang = "en"
|
||||||
|
pipeline = ["tok2vec", "tagger"]
|
||||||
|
|
||||||
|
[components]
|
||||||
|
|
||||||
|
[components.tok2vec]
|
||||||
|
factory = "tok2vec"
|
||||||
|
|
||||||
|
[components.tok2vec.model]
|
||||||
|
@architectures = "spacy.HashEmbedCNN.v1"
|
||||||
|
pretrained_vectors = null
|
||||||
|
width = 342
|
||||||
|
depth = 4
|
||||||
|
window_size = 1
|
||||||
|
embed_size = 2000
|
||||||
|
maxout_pieces = 3
|
||||||
|
subword_features = true
|
||||||
|
|
||||||
|
[components.tagger]
|
||||||
|
factory = "tagger"
|
||||||
|
|
||||||
|
[components.tagger.model]
|
||||||
|
@architectures = "spacy.Tagger.v2"
|
||||||
|
|
||||||
|
[components.tagger.model.tok2vec]
|
||||||
|
@architectures = "spacy.Tok2VecListener.v1"
|
||||||
|
width = ${components.tok2vec.model.width}
|
||||||
|
|
||||||
|
[distill]
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
pretrain_config_string = """
|
pretrain_config_string = """
|
||||||
[paths]
|
[paths]
|
||||||
train = null
|
train = null
|
||||||
|
@ -201,6 +256,14 @@ def test_create_nlp_from_config():
|
||||||
load_model_from_config(Config(bad_cfg), auto_fill=True)
|
load_model_from_config(Config(bad_cfg), auto_fill=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_nlp_from_distillation_config():
|
||||||
|
"""Test that the default distillation config validates properly"""
|
||||||
|
config = Config().from_str(distill_config_string)
|
||||||
|
distill_config = load_config(DEFAULT_CONFIG_DISTILL_PATH)
|
||||||
|
filled = config.merge(distill_config)
|
||||||
|
registry.resolve(filled["distillation"], schema=ConfigSchemaDistill)
|
||||||
|
|
||||||
|
|
||||||
def test_create_nlp_from_pretraining_config():
|
def test_create_nlp_from_pretraining_config():
|
||||||
"""Test that the default pretraining config validates properly"""
|
"""Test that the default pretraining config validates properly"""
|
||||||
config = Config().from_str(pretrain_config_string)
|
config = Config().from_str(pretrain_config_string)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user