Mirror of https://github.com/explosion/spaCy.git (synced 2025-11-04 09:57:26 +03:00)

Commit f25cff1e38
Merge branch 'develop' of https://github.com/explosion/spaCy into develop
@@ -51,7 +51,7 @@ def debug_model_cli(
     with show_validation_error(config_path):
         config = util.load_config(config_path, overrides=config_overrides)
         nlp, config = util.load_model_from_config(config_path)
-    seed = config["pretraining"]["seed"]
+    seed = config["training"]["seed"]
     if seed is not None:
         msg.info(f"Fixing random seed: {seed}")
         fix_random_seed(seed)
@@ -7,6 +7,7 @@ import srsly
 import re
 
 from .. import util
+from ..language import DEFAULT_CONFIG_PRETRAIN_PATH
 from ..schemas import RecommendationSchema
 from ._util import init_cli, Arg, Opt, show_validation_error, COMMAND
 
@@ -48,6 +49,7 @@ def init_fill_config_cli(
     # fmt: off
     base_path: Path = Arg(..., help="Base config to fill", exists=True, dir_okay=False),
     output_file: Path = Arg("-", help="File to save config.cfg to (or - for stdout)", allow_dash=True),
+    pretraining: bool = Opt(False, "--pretraining", "-p", help="Include config for pretraining (with 'spacy pretrain')"),
     diff: bool = Opt(False, "--diff", "-D", help="Print a visual diff highlighting the changes")
     # fmt: on
 ):
@@ -58,19 +60,24 @@ def init_fill_config_cli(
     can be used with a config generated via the training quickstart widget:
     https://nightly.spacy.io/usage/training#quickstart
     """
-    fill_config(output_file, base_path, diff=diff)
+    fill_config(output_file, base_path, pretraining=pretraining, diff=diff)
 
 
 def fill_config(
-    output_file: Path, base_path: Path, *, diff: bool = False
+    output_file: Path, base_path: Path, *, pretraining: bool = False, diff: bool = False
 ) -> Tuple[Config, Config]:
     is_stdout = str(output_file) == "-"
     msg = Printer(no_print=is_stdout)
     with show_validation_error(hint_fill=False):
         config = util.load_config(base_path)
         nlp, _ = util.load_model_from_config(config, auto_fill=True)
+    filled = nlp.config
+    if pretraining:
+        validate_config_for_pretrain(filled, msg)
+        pretrain_config = util.load_config(DEFAULT_CONFIG_PRETRAIN_PATH)
+        filled = pretrain_config.merge(filled)
     before = config.to_str()
-    after = nlp.config.to_str()
+    after = filled.to_str()
     if before == after:
         msg.warn("Nothing to auto-fill: base config is already complete")
     else:
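The interesting detail in this hunk is the merge order: the auto-filled nlp config is merged into the [pretraining] defaults, so values the user already set should win over the defaults. Below is a minimal, self-contained sketch of that behavior, assuming thinc's Config.merge gives precedence to the config passed as the argument; the section contents are invented for illustration.

from thinc.config import Config

# Stand-in for the defaults loaded from DEFAULT_CONFIG_PRETRAIN_PATH (values invented)
pretrain_defaults = Config().from_str("""
[pretraining]
max_epochs = 1000
dropout = 0.2
""")

# Stand-in for the auto-filled user config, which already sets one of the values
filled = Config().from_str("""
[pretraining]
dropout = 0.5
""")

merged = pretrain_defaults.merge(filled)
print(merged["pretraining"]["dropout"])     # expected: 0.5  (user value kept)
print(merged["pretraining"]["max_epochs"])  # expected: 1000 (default filled in)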
@@ -84,7 +91,7 @@ def fill_config(
             print(diff_strings(before, after))
             msg.divider("END CONFIG DIFF")
             print("")
-    save_config(nlp.config, output_file, is_stdout=is_stdout)
+    save_config(filled, output_file, is_stdout=is_stdout)
 
 
 def init_config(
@@ -132,12 +139,9 @@ def init_config(
     msg.info("Generated template specific for your use case")
     for label, value in use_case.items():
         msg.text(f"- {label}: {value}")
-    use_transformer = bool(template_vars.use_transformer)
     with show_validation_error(hint_fill=False):
         config = util.load_config_from_str(base_template)
         nlp, _ = util.load_model_from_config(config, auto_fill=True)
-    if use_transformer:
-        nlp.config.pop("pretraining", {})  # TODO: solve this better
     msg.good("Auto-filled config with all values")
     save_config(nlp.config, output_file, is_stdout=is_stdout)
 
@@ -161,3 +165,15 @@ def has_spacy_transformers() -> bool:
         return True
     except ImportError:
         return False
+
+
+def validate_config_for_pretrain(config: Config, msg: Printer) -> None:
+    if "tok2vec" not in config["nlp"]["pipeline"]:
+        msg.warn(
+            "No tok2vec component found in the pipeline. If your tok2vec "
+            "component has a different name, you may need to adjust the "
+            "tok2vec_model reference in the [pretraining] block. If you don't "
+            "have a tok2vec component, make sure to add it to your [components] "
+            "and the pipeline specified in the [nlp] block, so you can pretrain "
+            "weights for it."
+        )
@@ -90,19 +90,20 @@ def pretrain(
     with show_validation_error(config_path):
         config = util.load_config(config_path, overrides=config_overrides)
         nlp, config = util.load_model_from_config(config)
-    # TODO: validate that [pretraining] block exists
+    pretrain_config = config["pretraining"]
+    if not pretrain_config:
+        # TODO: What's the solution here? How do we handle optional blocks?
+        msg.fail("The [pretraining] block in your config is empty", exits=1)
     if not output_dir.exists():
         output_dir.mkdir()
         msg.good(f"Created output directory: {output_dir}")
-    seed = config["pretraining"]["seed"]
+    seed = pretrain_config["seed"]
     if seed is not None:
         fix_random_seed(seed)
-    if use_gpu >= 0 and config["pretraining"]["use_pytorch_for_gpu_memory"]:
+    if use_gpu >= 0 and pretrain_config["use_pytorch_for_gpu_memory"]:
         use_pytorch_for_gpu_memory()
     config.to_disk(output_dir / "config.cfg")
     msg.good("Saved config file in the output directory")
-    pretrain_config = config["pretraining"]
-
     if texts_loc != "-":  # reading from a file
         with msg.loading("Loading input texts..."):
             texts = list(srsly.read_jsonl(texts_loc))
@@ -117,7 +117,7 @@ def train(
 
     # Load a pretrained tok2vec model - cf. CLI command 'pretrain'
     if weights_data is not None:
-        tok2vec_path = config.get("pretraining", {}).get("tok2vec_model", None)
+        tok2vec_path = config["pretraining"].get("tok2vec_model", None)
         if tok2vec_path is None:
             msg.fail(
                 f"To use a pretrained tok2vec model, the config needs to specify which "
@@ -90,29 +90,3 @@ eps = 1e-8
 warmup_steps = 250
 total_steps = 20000
 initial_rate = 0.001
-
-[pretraining]
-max_epochs = 1000
-min_length = 5
-max_length = 500
-dropout = 0.2
-n_save_every = null
-batch_size = 3000
-seed = ${system.seed}
-use_pytorch_for_gpu_memory = ${system.use_pytorch_for_gpu_memory}
-tok2vec_model = "components.tok2vec.model"
-
-[pretraining.objective]
-type = "characters"
-n_characters = 4
-
-[pretraining.optimizer]
-@optimizers = "Adam.v1"
-beta1 = 0.9
-beta2 = 0.999
-L2_is_weight_decay = true
-L2 = 0.01
-grad_clip = 1.0
-use_averages = true
-eps = 1e-8
-learn_rate = 0.001
spacy/default_config_pretraining.cfg (new file, 25 lines)
@@ -0,0 +1,25 @@
+[pretraining]
+max_epochs = 1000
+min_length = 5
+max_length = 500
+dropout = 0.2
+n_save_every = null
+batch_size = 3000
+seed = ${system.seed}
+use_pytorch_for_gpu_memory = ${system.use_pytorch_for_gpu_memory}
+tok2vec_model = "components.tok2vec.model"
+
+[pretraining.objective]
+type = "characters"
+n_characters = 4
+
+[pretraining.optimizer]
+@optimizers = "Adam.v1"
+beta1 = 0.9
+beta2 = 0.999
+L2_is_weight_decay = true
+L2 = 0.01
+grad_clip = 1.0
+use_averages = true
+eps = 1e-8
+learn_rate = 0.001
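Note that seed and use_pytorch_for_gpu_memory above are references into the [system] block rather than literal values. A small, hedged sketch of how such references resolve with thinc's Config (the section contents are invented; the test further down relies on the same interpolate/round-trip behavior):

from thinc.config import Config

cfg_str = """
[system]
seed = 0

[pretraining]
seed = ${system.seed}
"""

# Load without interpolating so the ${system.seed} reference is kept as-is,
# then resolve it explicitly.
config = Config().from_str(cfg_str, interpolate=False)
resolved = config.interpolate()
print(resolved["pretraining"]["seed"])  # expected: 0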
@@ -37,6 +37,9 @@ from . import about
 # This is the base config will all settings (training etc.)
 DEFAULT_CONFIG_PATH = Path(__file__).parent / "default_config.cfg"
 DEFAULT_CONFIG = util.load_config(DEFAULT_CONFIG_PATH)
+# This is the base config for the [pretraining] block and currently not included
+# in the main config and only added via the 'init fill-config' command
+DEFAULT_CONFIG_PRETRAIN_PATH = Path(__file__).parent / "default_config_pretraining.cfg"
 
 
 class BaseDefaults:
@@ -233,6 +233,11 @@ class ConfigSchemaNlp(BaseModel):
         arbitrary_types_allowed = True
 
 
+class ConfigSchemaPretrainEmpty(BaseModel):
+    class Config:
+        extra = "forbid"
+
+
 class ConfigSchemaPretrain(BaseModel):
     # fmt: off
     max_epochs: StrictInt = Field(..., title="Maximum number of epochs to train for")
@@ -257,14 +262,15 @@ class ConfigSchemaPretrain(BaseModel):
 class ConfigSchema(BaseModel):
     training: ConfigSchemaTraining
     nlp: ConfigSchemaNlp
-    pretraining: Optional[ConfigSchemaPretrain]
+    pretraining: Union[ConfigSchemaPretrain, ConfigSchemaPretrainEmpty] = {}
     components: Dict[str, Dict[str, Any]]
 
     @root_validator
     def validate_config(cls, values):
         """Perform additional validation for settings with dependencies."""
         pt = values.get("pretraining")
-        if pt and pt.objective.get("type") == "vectors" and not values["nlp"].vectors:
+        if pt and not isinstance(pt, ConfigSchemaPretrainEmpty):
+            if pt.objective.get("type") == "vectors" and not values["nlp"].vectors:
                 err = "Need nlp.vectors if pretraining.objective.type is vectors"
                 raise ValueError(err)
         return values
@@ -3,10 +3,11 @@ from thinc.config import Config, ConfigValidationError
 import spacy
 from spacy.lang.en import English
 from spacy.lang.de import German
-from spacy.language import Language
+from spacy.language import Language, DEFAULT_CONFIG
 from spacy.util import registry, load_model_from_config
 from spacy.ml.models import build_Tok2Vec_model, build_tb_parser_model
 from spacy.ml.models import MultiHashEmbed, MaxoutWindowEncoder
+from spacy.schemas import ConfigSchema
 
 from ..util import make_tempdir
 
@@ -299,3 +300,16 @@ def test_config_interpolation():
     nlp2 = English.from_config(interpolated)
     assert nlp2.config["training"]["train_corpus"]["path"] == ""
     assert nlp2.config["components"]["tagger"]["model"]["tok2vec"]["width"] == 342
+
+
+def test_config_optional_sections():
+    config = Config().from_str(nlp_config_string)
+    config = DEFAULT_CONFIG.merge(config)
+    assert "pretraining" not in config
+    filled = registry.fill_config(config, schema=ConfigSchema, validate=False)
+    # Make sure that optional "pretraining" block doesn't default to None,
+    # which would (rightly) cause error because it'd result in a top-level
+    # key that's not a section (dict). Note that the following roundtrip is
+    # also how Config.interpolate works under the hood.
+    new_config = Config().from_str(filled.to_str())
+    assert new_config["pretraining"] == {}