2020-07-25 16:01:15 +03:00
from typing import Dict , List , Union , Optional , Sequence , Any , Callable , Type
2020-08-05 17:00:59 +03:00
from typing import Iterable , TypeVar , TYPE_CHECKING
2019-12-25 14:39:49 +03:00
from enum import Enum
from pydantic import BaseModel , Field , ValidationError , validator
2020-07-10 14:31:27 +03:00
from pydantic import StrictStr , StrictInt , StrictFloat , StrictBool
2020-07-12 13:32:08 +03:00
from pydantic import root_validator
2019-12-25 14:39:49 +03:00
from collections import defaultdict
2020-07-22 14:42:59 +03:00
from thinc . api import Optimizer
2019-12-25 14:39:49 +03:00
from . attrs import NAMES
2020-08-05 17:00:59 +03:00
if TYPE_CHECKING :
# This lets us add type hints for mypy etc. without causing circular imports
from . language import Language # noqa: F401
from . gold import Example # noqa: F401
2020-08-04 16:09:37 +03:00
# Generic element type shared by the batcher signature below.
ItemT = TypeVar("ItemT")
# A batcher takes an iterable of items and yields lists (batches) of them.
Batcher = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
# A reader takes a Language object and a string and yields Example objects.
# The project types are referenced as strings so that they only need to be
# importable under TYPE_CHECKING.
Reader = Callable[["Language", str], Iterable["Example"]]
2020-08-04 16:09:37 +03:00
2019-12-25 14:39:49 +03:00
2020-07-25 16:01:15 +03:00
def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]:
    """Validate data against a given pydantic schema.

    obj (Dict[str, Any]): JSON-serializable data to validate.
    schema (pydantic.BaseModel): The schema to validate against.
    RETURNS (List[str]): A list of error messages, if available. The list is
        empty when the data validates.
    """
    try:
        schema(**obj)
    except ValidationError as e:
        # Group the messages by error location so that each location is
        # reported once, with all of its messages joined together.
        messages_by_loc = defaultdict(list)
        for error in e.errors():
            location = " -> ".join(str(part) for part in error.get("loc", []))
            messages_by_loc[location].append(error.get("msg"))
        return [
            f"[{location}] {', '.join(messages)}"
            for location, messages in messages_by_loc.items()
        ]
    return []
# Matcher token patterns
2020-07-25 16:01:15 +03:00
def validate_token_pattern(obj: list) -> List[str]:
    """Validate a Matcher token pattern against TokenPatternSchema.

    If the pattern is a list, integer keys of its dict entries are first
    converted to their string names via NAMES
    (e.g. {ORTH: "foo"} -> {"ORTH": "foo"}). Anything else is passed to the
    schema unchanged, so that the schema reports the problem.

    obj (list): The token pattern, a list of per-token dicts.
    RETURNS (List[str]): A list of error messages, empty if the pattern is valid.
    """

    def get_key(key):
        # Only ints that are valid non-negative positions in NAMES are
        # converted. Without the lower bound, a negative int would index
        # NAMES from the end and be silently rewritten to an unrelated name
        # (assuming NAMES is a sequence).
        if isinstance(key, int) and 0 <= key < len(NAMES):
            return NAMES[key]
        return key

    if isinstance(obj, list):
        # Non-dict entries are kept as they are and left for the schema to reject.
        obj = [
            {get_key(k): v for k, v in pattern.items()}
            if isinstance(pattern, dict)
            else pattern
            for pattern in obj
        ]
    return validate(TokenPatternSchema, {"pattern": obj})
class TokenPatternString(BaseModel):
    """Dict form of a string-valued token attribute in a token pattern.

    Accepts the keys REGEX (a string) and IN / NOT_IN (lists of strings).
    Any other key is rejected.
    """

    REGEX: Optional[StrictStr]
    IN: Optional[List[StrictStr]]
    NOT_IN: Optional[List[StrictStr]]

    class Config:
        # Unknown keys are validation errors.
        extra = "forbid"

    # Applied to every field before type validation (pre=True) and to each
    # element of list-valued fields (each_item=True).
    @validator("*", pre=True, each_item=True)
    def raise_for_none(cls, v):
        """Reject explicit None values; return any other value unchanged."""
        if v is None:
            raise ValueError("None / null is not allowed")
        return v
class TokenPatternNumber(BaseModel):
    """Dict form of a number-valued token attribute in a token pattern.

    Accepts REGEX, IN / NOT_IN (lists of ints) and the comparison fields,
    which are populated through their operator aliases ("==", "!=", ">=",
    "<=", ">", "<"). Any other key is rejected.
    """

    REGEX: Optional[StrictStr] = None
    IN: Optional[List[StrictInt]] = None
    NOT_IN: Optional[List[StrictInt]] = None
    # Comparison fields, keyed in the input data by their operator alias.
    EQ: Union[StrictInt, StrictFloat] = Field(None, alias="==")
    NEQ: Union[StrictInt, StrictFloat] = Field(None, alias="!=")
    GEQ: Union[StrictInt, StrictFloat] = Field(None, alias=">=")
    LEQ: Union[StrictInt, StrictFloat] = Field(None, alias="<=")
    GT: Union[StrictInt, StrictFloat] = Field(None, alias=">")
    LT: Union[StrictInt, StrictFloat] = Field(None, alias="<")

    class Config:
        # Unknown keys are validation errors.
        extra = "forbid"

    # Applied to every field before type validation (pre=True) and to each
    # element of list-valued fields (each_item=True).
    @validator("*", pre=True, each_item=True)
    def raise_for_none(cls, v):
        """Reject explicit None values; return any other value unchanged."""
        if v is None:
            raise ValueError("None / null is not allowed")
        return v
class TokenPatternOperator(str, Enum):
    """The operator strings accepted for the "op" field of a token pattern.

    Mixing in str makes each member compare equal to its plain string value.
    """

    plus: StrictStr = "+"
    start: StrictStr = "*"
    question: StrictStr = "?"
    exclamation: StrictStr = "!"
# A string attribute is either a plain string or a TokenPatternString dict.
StringValue = Union[TokenPatternString, StrictStr]
# A number attribute is either a plain int/float or a TokenPatternNumber dict.
NumberValue = Union[TokenPatternNumber, StrictInt, StrictFloat]
# Values under the "_" key: either dict form, or a plain (non-strict)
# str, int, float, list or bool.
UnderscoreValue = Union[
    TokenPatternString, TokenPatternNumber, str, int, float, list, bool,
]
class TokenPattern(BaseModel):
    """Schema for a single token dict inside a token pattern.

    Every field is optional, but an explicit None is rejected. Fields can be
    given under their own name or under the upper-cased alias produced by
    the alias generator (e.g. "orth" or "ORTH"). The "underscore" field is
    populated through the "_" key. Any key that matches no field is
    rejected.
    """

    # String attributes: a plain string or a TokenPatternString dict.
    orth: Optional[StringValue] = None
    text: Optional[StringValue] = None
    lower: Optional[StringValue] = None
    pos: Optional[StringValue] = None
    tag: Optional[StringValue] = None
    dep: Optional[StringValue] = None
    lemma: Optional[StringValue] = None
    shape: Optional[StringValue] = None
    ent_type: Optional[StringValue] = None
    norm: Optional[StringValue] = None
    # Number attribute: a plain number or a TokenPatternNumber dict.
    length: Optional[NumberValue] = None
    # Boolean attributes (strict: only real bools are accepted).
    spacy: Optional[StrictBool] = None
    is_alpha: Optional[StrictBool] = None
    is_ascii: Optional[StrictBool] = None
    is_digit: Optional[StrictBool] = None
    is_lower: Optional[StrictBool] = None
    is_upper: Optional[StrictBool] = None
    is_title: Optional[StrictBool] = None
    is_punct: Optional[StrictBool] = None
    is_space: Optional[StrictBool] = None
    is_bracket: Optional[StrictBool] = None
    is_quote: Optional[StrictBool] = None
    is_left_punct: Optional[StrictBool] = None
    is_right_punct: Optional[StrictBool] = None
    is_currency: Optional[StrictBool] = None
    is_stop: Optional[StrictBool] = None
    is_sent_start: Optional[StrictBool] = None
    sent_start: Optional[StrictBool] = None
    like_num: Optional[StrictBool] = None
    like_url: Optional[StrictBool] = None
    like_email: Optional[StrictBool] = None
    # Operator: one of the TokenPatternOperator values.
    op: Optional[TokenPatternOperator] = None
    # The explicit alias "_" is used here instead of the generated one.
    underscore: Optional[Dict[StrictStr, UnderscoreValue]] = Field(None, alias="_")

    class Config:
        # Unknown keys are validation errors.
        extra = "forbid"
        # Accept the field name as well as its alias.
        allow_population_by_field_name = True
        # The alias of each field is its upper-cased name.
        alias_generator = lambda value: value.upper()

    # Applied to every field before type validation.
    @validator("*", pre=True)
    def raise_for_none(cls, v):
        """Reject explicit None values; return any other value unchanged."""
        if v is None:
            raise ValueError("None / null is not allowed")
        return v
class TokenPatternSchema(BaseModel):
    """Schema for a complete token pattern: a non-empty list of token dicts.

    The pattern is wrapped in a model with a single required "pattern" field
    so that it can be validated as keyword data. Any other key is rejected.
    """

    # min_items is the keyword pydantic enforces. The JSON-schema spelling
    # "minItems" is only stored as extra schema metadata and would let an
    # empty pattern list pass validation.
    pattern: List[TokenPattern] = Field(..., min_items=1)

    class Config:
        # Unknown keys are validation errors.
        extra = "forbid"
# Model meta
class ModelMetaSchema(BaseModel):
    """Schema for model meta data.

    Only "lang", "name" and "version" are required; every other field has a
    default.
    """

    # fmt: off
    # Required fields
    lang: StrictStr = Field(..., title="Two-letter language code, e.g. 'en'")
    name: StrictStr = Field(..., title="Model name")
    version: StrictStr = Field(..., title="Model version")
    # Optional fields with defaults
    spacy_version: StrictStr = Field("", title="Compatible spaCy version identifier")
    parent_package: StrictStr = Field("spacy", title="Name of parent spaCy package, e.g. spacy or spacy-nightly")
    pipeline: List[StrictStr] = Field([], title="Names of pipeline components")
    description: StrictStr = Field("", title="Model description")
    license: StrictStr = Field("", title="Model license")
    author: StrictStr = Field("", title="Model author name")
    email: StrictStr = Field("", title="Model author email")
    url: StrictStr = Field("", title="Model author URL")
    sources: Optional[Union[List[StrictStr], List[Dict[str, str]]]] = Field(None, title="Training data sources")
    vectors: Dict[str, Any] = Field({}, title="Included word vectors")
    labels: Dict[str, Dict[str, List[str]]] = Field({}, title="Component labels, keyed by component name")
    accuracy: Dict[str, Union[float, Dict[str, float]]] = Field({}, title="Accuracy numbers")
    speed: Dict[str, Union[float, int]] = Field({}, title="Speed evaluation numbers")
    spacy_git_version: StrictStr = Field("", title="Commit of spaCy version used")
    # fmt: on
2020-07-10 14:31:27 +03:00
# Config schema
# We're not setting any defaults here (which is too messy) and are making all
# fields required, so we can raise validation errors for missing values. To
# provide a default, we include a separate .cfg file with all values and
# check that against this schema in the test suite to make sure it's always
# up to date.
class ConfigSchemaTraining(BaseModel):
    """Schema for the training section of the config.

    All fields except "raw_text" must be present. Fields typed as Optional
    may be set to None but cannot be left out. Any key that matches no field
    is rejected.
    """

    # fmt: off
    vectors: Optional[StrictStr] = Field(..., title="Path to vectors")
    train_corpus: Reader = Field(..., title="Reader for the training data")
    dev_corpus: Reader = Field(..., title="Reader for the dev data")
    batcher: Batcher = Field(..., title="Batcher for the training data")
    dropout: StrictFloat = Field(..., title="Dropout rate")
    patience: StrictInt = Field(..., title="How many steps to continue without improvement in evaluation score")
    max_epochs: StrictInt = Field(..., title="Maximum number of epochs to train for")
    max_steps: StrictInt = Field(..., title="Maximum number of update steps to train for")
    eval_frequency: StrictInt = Field(..., title="How often to evaluate during training (steps)")
    seed: Optional[StrictInt] = Field(..., title="Random seed")
    accumulate_gradient: StrictInt = Field(..., title="Whether to divide the batch up into substeps")
    score_weights: Dict[StrictStr, Union[StrictFloat, StrictInt]] = Field(..., title="Scores to report and their weights for selecting final model")
    init_tok2vec: Optional[StrictStr] = Field(..., title="Path to pretrained tok2vec weights")
    # The only field in this section that has a default.
    raw_text: Optional[StrictStr] = Field(default=None, title="Raw text")
    optimizer: Optimizer = Field(..., title="The optimizer to use")
    frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
    # fmt: on

    class Config:
        # Unknown keys are validation errors.
        extra = "forbid"
        # Lets fields use types pydantic has no built-in validator for
        # (e.g. Optimizer).
        arbitrary_types_allowed = True
2020-07-10 14:31:27 +03:00
class ConfigSchemaNlp(BaseModel):
    """Schema for the nlp section of the config.

    All fields must be present; the three callback fields may be set to
    None. Any key that matches no field is rejected.
    """

    # fmt: off
    lang: StrictStr = Field(..., title="The base language to use")
    pipeline: List[StrictStr] = Field(..., title="The pipeline component names in order")
    tokenizer: Callable = Field(..., title="The tokenizer to use")
    load_vocab_data: StrictBool = Field(..., title="Whether to load additional vocab data from spacy-lookups-data")
    # Optional callbacks; "Language" is referenced as a string because it is
    # only imported under TYPE_CHECKING.
    before_creation: Optional[Callable[[Type["Language"]], Type["Language"]]] = Field(..., title="Optional callback to modify Language class before initialization")
    after_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after creation and before the pipeline is constructed")
    after_pipeline_creation: Optional[Callable[["Language"], "Language"]] = Field(..., title="Optional callback to modify nlp object after the pipeline is constructed")
    # fmt: on

    class Config:
        # Unknown keys are validation errors.
        extra = "forbid"
        arbitrary_types_allowed = True
2020-07-11 14:03:53 +03:00
class ConfigSchemaPretrain(BaseModel):
    """Schema for the pretraining section of the config.

    All fields must be present; "n_save_every" and "seed" may be set to
    None. Any key that matches no field is rejected.
    """

    # fmt: off
    max_epochs: StrictInt = Field(..., title="Maximum number of epochs to train for")
    min_length: StrictInt = Field(..., title="Minimum length of examples")
    max_length: StrictInt = Field(..., title="Maximum length of examples")
    dropout: StrictFloat = Field(..., title="Dropout rate")
    n_save_every: Optional[StrictInt] = Field(..., title="Saving frequency")
    batch_size: Union[Sequence[int], int] = Field(..., title="The batch size or batch size schedule")
    seed: Optional[StrictInt] = Field(..., title="Random seed")
    use_pytorch_for_gpu_memory: StrictBool = Field(..., title="Allocate memory via PyTorch")
    tok2vec_model: StrictStr = Field(..., title="tok2vec model in config, e.g. components.tok2vec.model")
    optimizer: Optimizer = Field(..., title="The optimizer to use")
    # TODO: use a more detailed schema for this?
    objective: Dict[str, Any] = Field(..., title="Pretraining objective")
    # fmt: on

    class Config:
        # Unknown keys are validation errors.
        extra = "forbid"
        # Lets fields use types pydantic has no built-in validator for
        # (e.g. Optimizer).
        arbitrary_types_allowed = True
2020-07-10 14:31:27 +03:00
class ConfigSchema(BaseModel):
    """Top-level schema for a complete config.

    Combines the "training", "nlp" and (optional) "pretraining" sections with
    a "components" mapping of component name to component settings. Unlike
    the section schemas, additional top-level keys are allowed.
    """

    training: ConfigSchemaTraining
    nlp: ConfigSchemaNlp
    pretraining: Optional[ConfigSchemaPretrain]
    components: Dict[str, Dict[str, Any]]

    @root_validator
    def validate_config(cls, values):
        """Perform additional validation for settings with dependencies.

        values (dict): The validated field values of the model.
        RETURNS (dict): The unchanged values.
        RAISES (ValueError): If the pretraining objective type is "vectors"
            but no vectors are set in the training section.
        """
        pt = values.get("pretraining")
        if pt and pt.objective.get("type") == "vectors":
            # The "vectors" setting is declared on ConfigSchemaTraining.
            # ConfigSchemaNlp has no such field, so reading it from the nlp
            # section raised AttributeError instead of a validation error.
            # A section that failed its own validation is absent from
            # values; its errors are already reported, so skip the check.
            training = values.get("training")
            if training is not None and not training.vectors:
                err = "Need training.vectors if pretraining.objective.type is vectors"
                raise ValueError(err)
        return values

    class Config:
        # Additional top-level sections are accepted.
        extra = "allow"
        arbitrary_types_allowed = True
2020-06-21 14:44:00 +03:00
# Project config Schema
class ProjectConfigAsset(BaseModel):
    """Schema for a single asset entry in a project config.

    Only "dest" is required. "checksum", if given, has to match the pattern
    of 32 hexadecimal characters.
    """

    # fmt: off
    dest: StrictStr = Field(..., title="Destination of downloaded asset")
    url: Optional[StrictStr] = Field(None, title="URL of asset")
    checksum: str = Field(None, title="MD5 hash of file", regex=r"([a-fA-F\d]{32})")
    # fmt: on
2020-06-21 14:44:00 +03:00
class ProjectConfigCommand(BaseModel):
    """Schema for a single named command in a project config.

    Only "name" is required; the list fields default to empty lists and
    "no_skip" defaults to False. Any key that matches no field is rejected.
    """

    # fmt: off
    name: StrictStr = Field(..., title="Name of command")
    help: Optional[StrictStr] = Field(None, title="Command description")
    script: List[StrictStr] = Field([], title="List of CLI commands to run, in order")
    deps: List[StrictStr] = Field([], title="File dependencies required by this command")
    outputs: List[StrictStr] = Field([], title="Outputs produced by this command")
    outputs_no_cache: List[StrictStr] = Field([], title="Outputs not tracked by DVC (DVC only)")
    no_skip: bool = Field(False, title="Never skip this command, even if nothing changed")
    # fmt: on

    class Config:
        title = "A single named command specified in a project config"
        # Unknown keys are validation errors.
        extra = "forbid"
2020-06-21 14:44:00 +03:00
class ProjectConfigSchema(BaseModel):
    """Schema for a project configuration file.

    Every section is optional and defaults to an empty dict or list.
    """

    # fmt: off
    variables: Dict[StrictStr, Union[str, int, float, bool]] = Field({}, title="Optional variables to substitute in commands")
    assets: List[ProjectConfigAsset] = Field([], title="Data assets")
    workflows: Dict[StrictStr, List[StrictStr]] = Field({}, title="Named workflows, mapped to list of project commands to run in order")
    commands: List[ProjectConfigCommand] = Field([], title="Project command shortucts")
    # fmt: on

    class Config:
        title = "Schema for project configuration file"