mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Fix commands
This commit is contained in:
		
							parent
							
								
									44bad1474c
								
							
						
					
					
						commit
						553bfea641
					
				| 
						 | 
				
			
			@ -16,6 +16,7 @@ import os
 | 
			
		|||
 | 
			
		||||
from ..schemas import ProjectConfigSchema, validate
 | 
			
		||||
from ..util import import_file, run_command, make_tempdir, registry, logger
 | 
			
		||||
from ..util import ensure_path
 | 
			
		||||
 | 
			
		||||
if TYPE_CHECKING:
 | 
			
		||||
    from pathy import Pathy  # noqa: F401
 | 
			
		||||
| 
						 | 
				
			
			@ -458,3 +459,24 @@ def string_to_list(value: str, intify: bool = False) -> Union[List[str], List[in
 | 
			
		|||
            p = int(p)
 | 
			
		||||
        result.append(p)
 | 
			
		||||
    return result
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def load_from_paths(
 | 
			
		||||
    config: Config,
 | 
			
		||||
) -> Tuple[List[Dict[str, str]], Dict[str, dict], bytes]:
 | 
			
		||||
    # TODO: separate checks from loading
 | 
			
		||||
    raw_text = ensure_path(config["training"]["raw_text"])
 | 
			
		||||
    if raw_text is not None:
 | 
			
		||||
        if not raw_text.exists():
 | 
			
		||||
            msg.fail("Can't find raw text", raw_text, exits=1)
 | 
			
		||||
        raw_text = list(srsly.read_jsonl(config["training"]["raw_text"]))
 | 
			
		||||
    tag_map = {}
 | 
			
		||||
    morph_rules = {}
 | 
			
		||||
    weights_data = None
 | 
			
		||||
    init_tok2vec = ensure_path(config["training"]["init_tok2vec"])
 | 
			
		||||
    if init_tok2vec is not None:
 | 
			
		||||
        if not init_tok2vec.exists():
 | 
			
		||||
            msg.fail("Can't find pretrained tok2vec", init_tok2vec, exits=1)
 | 
			
		||||
        with init_tok2vec.open("rb") as file_:
 | 
			
		||||
            weights_data = file_.read()
 | 
			
		||||
    return raw_text, tag_map, morph_rules, weights_data
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,68 +0,0 @@
 | 
			
		|||
from typing import Optional, List, Dict, Any, Union, IO
 | 
			
		||||
import math
 | 
			
		||||
from tqdm import tqdm
 | 
			
		||||
import numpy
 | 
			
		||||
from ast import literal_eval
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from preshed.counter import PreshCounter
 | 
			
		||||
import tarfile
 | 
			
		||||
import gzip
 | 
			
		||||
import zipfile
 | 
			
		||||
import srsly
 | 
			
		||||
import warnings
 | 
			
		||||
from wasabi import msg, Printer
 | 
			
		||||
import typer
 | 
			
		||||
from ._util import init_cli, Arg, Opt, parse_config_overrides, show_validation_error
 | 
			
		||||
 | 
			
		||||
DEFAULT_OOV_PROB = -20
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
#@init_cli.command("vocab")
 | 
			
		||||
#@app.command(
 | 
			
		||||
#    "init-model",
 | 
			
		||||
#    context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
 | 
			
		||||
#    hidden=True,  # hide this from main CLI help but still allow it to work with warning
 | 
			
		||||
#)
 | 
			
		||||
def init_model_cli(
 | 
			
		||||
    # fmt: off
 | 
			
		||||
    ctx: typer.Context,  # This is only used to read additional arguments
 | 
			
		||||
    lang: str = Arg(..., help="Pipeline language"),
 | 
			
		||||
    output_dir: Path = Arg(..., help="Pipeline output directory"),
 | 
			
		||||
    freqs_loc: Optional[Path] = Arg(None, help="Location of words frequencies file", exists=True),
 | 
			
		||||
    clusters_loc: Optional[Path] = Opt(None, "--clusters-loc", "-c", help="Optional location of brown clusters data", exists=True),
 | 
			
		||||
    jsonl_loc: Optional[Path] = Opt(None, "--jsonl-loc", "-j", help="Location of JSONL-formatted attributes file", exists=True),
 | 
			
		||||
    vectors_loc: Optional[Path] = Opt(None, "--vectors-loc", "-v", help="Optional vectors file in Word2Vec format", exists=True),
 | 
			
		||||
    prune_vectors: int = Opt(-1, "--prune-vectors", "-V", help="Optional number of vectors to prune to"),
 | 
			
		||||
    truncate_vectors: int = Opt(0, "--truncate-vectors", "-t", help="Optional number of vectors to truncate to when reading in vectors file"),
 | 
			
		||||
    vectors_name: Optional[str] = Opt(None, "--vectors-name", "-vn", help="Optional name for the word vectors, e.g. en_core_web_lg.vectors"),
 | 
			
		||||
    model_name: Optional[str] = Opt(None, "--meta-name", "-mn", help="Optional name of the package for the pipeline meta"),
 | 
			
		||||
    base_model: Optional[str] = Opt(None, "--base", "-b", help="Name of or path to base pipeline to start with (mostly relevant for pipelines with custom tokenizers)")
 | 
			
		||||
    # fmt: on
 | 
			
		||||
):
 | 
			
		||||
    """
 | 
			
		||||
    Create a new blank pipeline directory with vocab and vectors from raw data.
 | 
			
		||||
    If vectors are provided in Word2Vec format, they can be either a .txt or
 | 
			
		||||
    zipped as a .zip or .tar.gz.
 | 
			
		||||
 | 
			
		||||
    DOCS: https://nightly.spacy.io/api/cli#init-vocab
 | 
			
		||||
    """
 | 
			
		||||
    if ctx.command.name == "init-model":
 | 
			
		||||
        msg.warn(
 | 
			
		||||
            "The init-model command is now called 'init vocab'. You can run "
 | 
			
		||||
            "'python -m spacy init --help' for an overview of the other "
 | 
			
		||||
            "available initialization commands."
 | 
			
		||||
        )
 | 
			
		||||
    init_vocab(
 | 
			
		||||
        lang,
 | 
			
		||||
        output_dir,
 | 
			
		||||
        freqs_loc=freqs_loc,
 | 
			
		||||
        clusters_loc=clusters_loc,
 | 
			
		||||
        jsonl_loc=jsonl_loc,
 | 
			
		||||
        vectors_loc=vectors_loc,
 | 
			
		||||
        prune_vectors=prune_vectors,
 | 
			
		||||
        truncate_vectors=truncate_vectors,
 | 
			
		||||
        vectors_name=vectors_name,
 | 
			
		||||
        model_name=model_name,
 | 
			
		||||
        base_model=base_model,
 | 
			
		||||
        silent=False,
 | 
			
		||||
    )
 | 
			
		||||
| 
						 | 
				
			
			@ -1,18 +1,17 @@
 | 
			
		|||
from typing import Optional, Dict, Any, Tuple, Union, Callable, List
 | 
			
		||||
from typing import Optional, Dict, Callable, Any
 | 
			
		||||
import logging
 | 
			
		||||
import srsly
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from wasabi import msg
 | 
			
		||||
import typer
 | 
			
		||||
from thinc.api import Config, fix_random_seed
 | 
			
		||||
from thinc.api import Config, fix_random_seed, set_gpu_allocator
 | 
			
		||||
 | 
			
		||||
from .train import create_before_to_disk_callback
 | 
			
		||||
from .. import util
 | 
			
		||||
from ..util import registry
 | 
			
		||||
from ..schemas import ConfigSchemaTraining
 | 
			
		||||
from ..util import registry, resolve_dot_names
 | 
			
		||||
from ..schemas import ConfigSchemaTraining, ConfigSchemaPretrain
 | 
			
		||||
from ..language import Language
 | 
			
		||||
from ..errors import Errors
 | 
			
		||||
from ._util import init_cli, Arg, Opt, parse_config_overrides, show_validation_error
 | 
			
		||||
from ._util import import_code, get_sourced_components
 | 
			
		||||
from ..util import resolve_dot_names
 | 
			
		||||
from ._util import import_code, get_sourced_components, load_from_paths
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@init_cli.command(
 | 
			
		||||
| 
						 | 
				
			
			@ -31,10 +30,12 @@ def init_pipeline_cli(
 | 
			
		|||
    util.logger.setLevel(logging.DEBUG if verbose else logging.ERROR)
 | 
			
		||||
    overrides = parse_config_overrides(ctx.args)
 | 
			
		||||
    import_code(code_path)
 | 
			
		||||
    config = util.load_config(config_path, overrides=overrides)
 | 
			
		||||
    with show_validation_error(config_path):
 | 
			
		||||
        config = util.load_config(config_path, overrides=overrides)
 | 
			
		||||
    nlp = init_pipeline(config)
 | 
			
		||||
    nlp.to_disk(output_path)
 | 
			
		||||
    # TODO: add more instructions
 | 
			
		||||
    msg.good(f"Saved initialized pipeline to {output_path}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def must_initialize(init_path: Path, config_path: Path, overrides: Dict) -> bool:
 | 
			
		||||
| 
						 | 
				
			
			@ -51,7 +52,7 @@ def must_initialize(init_path: Path, config_path: Path, overrides: Dict) -> bool
 | 
			
		|||
            return False
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def init_pipeline(config: Config, use_gpu=-1):
 | 
			
		||||
def init_pipeline(config: Config, use_gpu: int = -1) -> Language:
 | 
			
		||||
    raw_config = config
 | 
			
		||||
    config = raw_config.interpolate()
 | 
			
		||||
    if config["training"]["seed"] is not None:
 | 
			
		||||
| 
						 | 
				
			
			@ -61,22 +62,19 @@ def init_pipeline(config: Config, use_gpu=-1):
 | 
			
		|||
        set_gpu_allocator(allocator)
 | 
			
		||||
    # Use original config here before it's resolved to functions
 | 
			
		||||
    sourced_components = get_sourced_components(config)
 | 
			
		||||
    with show_validation_error():
 | 
			
		||||
        nlp = util.load_model_from_config(raw_config)
 | 
			
		||||
    msg.good("Set up nlp object from config")
 | 
			
		||||
    # Resolve all training-relevant sections using the filled nlp config
 | 
			
		||||
    T = registry.resolve(
 | 
			
		||||
        config["training"],
 | 
			
		||||
        schema=ConfigSchemaTraining,
 | 
			
		||||
        validate=True,
 | 
			
		||||
    )
 | 
			
		||||
    T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
 | 
			
		||||
    dot_names = [T["train_corpus"], T["dev_corpus"], T["raw_text"]]
 | 
			
		||||
    train_corpus, dev_corpus, raw_text = resolve_dot_names(config, dot_names)
 | 
			
		||||
    util.load_vocab_data_into_model(nlp, lookups=T["lookups"])
 | 
			
		||||
    msg.good("Created vocabulary")
 | 
			
		||||
    if T["vectors"] is not None:
 | 
			
		||||
        add_vectors(nlp, T["vectors"])
 | 
			
		||||
    score_weights = T["score_weights"]
 | 
			
		||||
        msg.good(f"Added vectors: {T['vectors']}")
 | 
			
		||||
    optimizer = T["optimizer"]
 | 
			
		||||
    batcher = T["batcher"]
 | 
			
		||||
    train_logger = T["logger"]
 | 
			
		||||
    before_to_disk = create_before_to_disk_callback(T["before_to_disk"])
 | 
			
		||||
    # Components that shouldn't be updated during training
 | 
			
		||||
    frozen_components = T["frozen_components"]
 | 
			
		||||
| 
						 | 
				
			
			@ -89,13 +87,23 @@ def init_pipeline(config: Config, use_gpu=-1):
 | 
			
		|||
            nlp.resume_training(sgd=optimizer)
 | 
			
		||||
    with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
 | 
			
		||||
        nlp.begin_training(lambda: train_corpus(nlp), sgd=optimizer)
 | 
			
		||||
        msg.good(f"Initialized pipeline components")
 | 
			
		||||
    # Verify the config after calling 'begin_training' to ensure labels
 | 
			
		||||
    # are properly initialized
 | 
			
		||||
    verify_config(nlp)
 | 
			
		||||
    if "pretraining" in config and config["pretraining"]:
 | 
			
		||||
        P = registry.resolve(config["pretraining"], schema=ConfigSchemaPretrain)
 | 
			
		||||
        add_tok2vec_weights({"training": T, "pretraining": P}, nlp)
 | 
			
		||||
    # TODO: this should be handled better?
 | 
			
		||||
    nlp = before_to_disk(nlp)
 | 
			
		||||
    return nlp
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def add_tok2vec_weights(config: Config, nlp: Language) -> None:
 | 
			
		||||
    # Load pretrained tok2vec weights - cf. CLI command 'pretrain'
 | 
			
		||||
    weights_data = load_from_paths(config)
 | 
			
		||||
    if weights_data is not None:
 | 
			
		||||
        tok2vec_component = C["pretraining"]["component"]
 | 
			
		||||
        tok2vec_component = config["pretraining"]["component"]
 | 
			
		||||
        if tok2vec_component is None:
 | 
			
		||||
            msg.fail(
 | 
			
		||||
                f"To use pretrained tok2vec weights, [pretraining.component] "
 | 
			
		||||
| 
						 | 
				
			
			@ -103,9 +111,63 @@ def init_pipeline(config: Config, use_gpu=-1):
 | 
			
		|||
                exits=1,
 | 
			
		||||
            )
 | 
			
		||||
        layer = nlp.get_pipe(tok2vec_component).model
 | 
			
		||||
        tok2vec_layer = C["pretraining"]["layer"]
 | 
			
		||||
        tok2vec_layer = config["pretraining"]["layer"]
 | 
			
		||||
        if tok2vec_layer:
 | 
			
		||||
            layer = layer.get_ref(tok2vec_layer)
 | 
			
		||||
        layer.from_bytes(weights_data)
 | 
			
		||||
        msg.info(f"Loaded pretrained weights into component '{tok2vec_component}'")
 | 
			
		||||
        msg.good(f"Loaded pretrained weights into component '{tok2vec_component}'")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def add_vectors(nlp: Language, vectors: str) -> None:
 | 
			
		||||
    title = f"Config validation error for vectors {vectors}"
 | 
			
		||||
    desc = (
 | 
			
		||||
        "This typically means that there's a problem in the config.cfg included "
 | 
			
		||||
        "with the packaged vectors. Make sure that the vectors package you're "
 | 
			
		||||
        "loading is compatible with the current version of spaCy."
 | 
			
		||||
    )
 | 
			
		||||
    with show_validation_error(
 | 
			
		||||
        title=title, desc=desc, hint_fill=False, show_config=False
 | 
			
		||||
    ):
 | 
			
		||||
        util.load_vectors_into_model(nlp, vectors)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def verify_config(nlp: Language) -> None:
 | 
			
		||||
    """Perform additional checks based on the config, loaded nlp object and training data."""
 | 
			
		||||
    # TODO: maybe we should validate based on the actual components, the list
 | 
			
		||||
    # in config["nlp"]["pipeline"] instead?
 | 
			
		||||
    for pipe_config in nlp.config["components"].values():
 | 
			
		||||
        # We can't assume that the component name == the factory
 | 
			
		||||
        factory = pipe_config["factory"]
 | 
			
		||||
        if factory == "textcat":
 | 
			
		||||
            verify_textcat_config(nlp, pipe_config)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def verify_textcat_config(nlp: Language, pipe_config: Dict[str, Any]) -> None:
 | 
			
		||||
    # if 'positive_label' is provided: double check whether it's in the data and
 | 
			
		||||
    # the task is binary
 | 
			
		||||
    if pipe_config.get("positive_label"):
 | 
			
		||||
        textcat_labels = nlp.get_pipe("textcat").labels
 | 
			
		||||
        pos_label = pipe_config.get("positive_label")
 | 
			
		||||
        if pos_label not in textcat_labels:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                Errors.E920.format(pos_label=pos_label, labels=textcat_labels)
 | 
			
		||||
            )
 | 
			
		||||
        if len(list(textcat_labels)) != 2:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                Errors.E919.format(pos_label=pos_label, labels=textcat_labels)
 | 
			
		||||
            )
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_before_to_disk_callback(
 | 
			
		||||
    callback: Optional[Callable[[Language], Language]]
 | 
			
		||||
) -> Callable[[Language], Language]:
 | 
			
		||||
    def before_to_disk(nlp: Language) -> Language:
 | 
			
		||||
        if not callback:
 | 
			
		||||
            return nlp
 | 
			
		||||
        modified_nlp = callback(nlp)
 | 
			
		||||
        if not isinstance(modified_nlp, Language):
 | 
			
		||||
            err = Errors.E914.format(name="before_to_disk", value=type(modified_nlp))
 | 
			
		||||
            raise ValueError(err)
 | 
			
		||||
        return modified_nlp
 | 
			
		||||
 | 
			
		||||
    return before_to_disk
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -1,6 +1,5 @@
 | 
			
		|||
from typing import Optional, Dict, Any, Tuple, Union, Callable, List
 | 
			
		||||
from timeit import default_timer as timer
 | 
			
		||||
import srsly
 | 
			
		||||
import tqdm
 | 
			
		||||
from pathlib import Path
 | 
			
		||||
from wasabi import msg
 | 
			
		||||
| 
						 | 
				
			
			@ -11,13 +10,17 @@ import random
 | 
			
		|||
import typer
 | 
			
		||||
import logging
 | 
			
		||||
 | 
			
		||||
from .init_pipeline import init_pipeline, must_initialize
 | 
			
		||||
from .init_pipeline import create_before_to_disk_callback
 | 
			
		||||
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
 | 
			
		||||
from ._util import import_code, get_sourced_components
 | 
			
		||||
from ._util import import_code
 | 
			
		||||
from ._util import load_from_paths  # noqa: F401 (needed for Ray extension for now)
 | 
			
		||||
from ..language import Language
 | 
			
		||||
from .. import util
 | 
			
		||||
from ..training.example import Example
 | 
			
		||||
from ..errors import Errors
 | 
			
		||||
from ..util import resolve_dot_names
 | 
			
		||||
from ..util import resolve_dot_names, registry
 | 
			
		||||
from ..schemas import ConfigSchemaTraining
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@app.command(
 | 
			
		||||
| 
						 | 
				
			
			@ -56,25 +59,35 @@ def train_cli(
 | 
			
		|||
        require_gpu(use_gpu)
 | 
			
		||||
    else:
 | 
			
		||||
        msg.info("Using CPU")
 | 
			
		||||
    config = util.load_config(
 | 
			
		||||
        config_path, overrides=config_overrides, interpolate=False
 | 
			
		||||
    )
 | 
			
		||||
    config = util.load_config(config_path, overrides=overrides, interpolate=False)
 | 
			
		||||
    msg.divider("Initializing pipeline")
 | 
			
		||||
    # TODO: add warnings / --initialize (?) argument
 | 
			
		||||
    if output_path is None:
 | 
			
		||||
        nlp = init_pipeline(config)
 | 
			
		||||
    else:
 | 
			
		||||
        init_path = output_path / "model-initial"
 | 
			
		||||
        if must_reinitialize(config, init_path):
 | 
			
		||||
        if must_initialize(config, init_path):
 | 
			
		||||
            nlp = init_pipeline(config)
 | 
			
		||||
            nlp.to_disk(init_path)
 | 
			
		||||
            msg.good(f"Saved initialized pipeline to {init_path}")
 | 
			
		||||
        else:
 | 
			
		||||
            nlp = spacy.load(output_path / "model-initial")
 | 
			
		||||
    msg.info("Start training")
 | 
			
		||||
    train(nlp, config, output_path)
 | 
			
		||||
            nlp = util.load_model(init_path)
 | 
			
		||||
            msg.good(f"Loaded initialized pipeline from {init_path}")
 | 
			
		||||
    msg.divider("Training pipeline")
 | 
			
		||||
    train(nlp, output_path, use_gpu=use_gpu)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def train(nlp: Language, output_path: Optional[Path]=None) -> None:
 | 
			
		||||
def train(
 | 
			
		||||
    nlp: Language, output_path: Optional[Path] = None, *, use_gpu: int = -1
 | 
			
		||||
) -> None:
 | 
			
		||||
    # TODO: random seed, GPU allocator
 | 
			
		||||
    # Create iterator, which yields out info after each optimization step.
 | 
			
		||||
    config = nlp.config.interpolate()
 | 
			
		||||
    if config["training"]["seed"] is not None:
 | 
			
		||||
        fix_random_seed(config["training"]["seed"])
 | 
			
		||||
    allocator = config["training"]["gpu_allocator"]
 | 
			
		||||
    if use_gpu >= 0 and allocator:
 | 
			
		||||
        set_gpu_allocator(allocator)
 | 
			
		||||
    T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
 | 
			
		||||
    dot_names = [T["train_corpus"], T["dev_corpus"], T["raw_text"]]
 | 
			
		||||
    train_corpus, dev_corpus, raw_text = resolve_dot_names(config, dot_names)
 | 
			
		||||
| 
						 | 
				
			
			@ -85,9 +98,7 @@ def train(nlp: Language, output_path: Optional[Path]=None) -> None:
 | 
			
		|||
    before_to_disk = create_before_to_disk_callback(T["before_to_disk"])
 | 
			
		||||
    # Components that shouldn't be updated during training
 | 
			
		||||
    frozen_components = T["frozen_components"]
 | 
			
		||||
 
 | 
			
		||||
    # Create iterator, which yields out info after each optimization step.
 | 
			
		||||
    msg.info("Start training")
 | 
			
		||||
    training_step_iterator = train_while_improving(
 | 
			
		||||
        nlp,
 | 
			
		||||
        optimizer,
 | 
			
		||||
| 
						 | 
				
			
			@ -101,7 +112,7 @@ def train(nlp: Language, output_path: Optional[Path]=None) -> None:
 | 
			
		|||
        raw_text=raw_text,
 | 
			
		||||
        exclude=frozen_components,
 | 
			
		||||
    )
 | 
			
		||||
    msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
 | 
			
		||||
    msg.info(f"Initial learn rate: {optimizer.learn_rate}")
 | 
			
		||||
    with nlp.select_pipes(disable=frozen_components):
 | 
			
		||||
        print_row, finalize_logger = train_logger(nlp)
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -145,7 +156,6 @@ def train(nlp: Language, output_path: Optional[Path]=None) -> None:
 | 
			
		|||
            msg.good(f"Saved pipeline to output directory {final_model_path}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def add_vectors(nlp: Language, vectors: str) -> None:
 | 
			
		||||
    title = f"Config validation error for vectors {vectors}"
 | 
			
		||||
    desc = (
 | 
			
		||||
| 
						 | 
				
			
			@ -199,21 +209,6 @@ def create_evaluation_callback(
 | 
			
		|||
    return evaluate
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_before_to_disk_callback(
 | 
			
		||||
    callback: Optional[Callable[[Language], Language]]
 | 
			
		||||
) -> Callable[[Language], Language]:
 | 
			
		||||
    def before_to_disk(nlp: Language) -> Language:
 | 
			
		||||
        if not callback:
 | 
			
		||||
            return nlp
 | 
			
		||||
        modified_nlp = callback(nlp)
 | 
			
		||||
        if not isinstance(modified_nlp, Language):
 | 
			
		||||
            err = Errors.E914.format(name="before_to_disk", value=type(modified_nlp))
 | 
			
		||||
            raise ValueError(err)
 | 
			
		||||
        return modified_nlp
 | 
			
		||||
 | 
			
		||||
    return before_to_disk
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def train_while_improving(
 | 
			
		||||
    nlp: Language,
 | 
			
		||||
    optimizer: Optimizer,
 | 
			
		||||
| 
						 | 
				
			
			@ -370,30 +365,3 @@ def verify_cli_args(config_path: Path, output_path: Optional[Path] = None) -> No
 | 
			
		|||
        if not output_path.exists():
 | 
			
		||||
            output_path.mkdir()
 | 
			
		||||
            msg.good(f"Created output directory: {output_path}")
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def verify_config(nlp: Language) -> None:
 | 
			
		||||
    """Perform additional checks based on the config, loaded nlp object and training data."""
 | 
			
		||||
    # TODO: maybe we should validate based on the actual components, the list
 | 
			
		||||
    # in config["nlp"]["pipeline"] instead?
 | 
			
		||||
    for pipe_config in nlp.config["components"].values():
 | 
			
		||||
        # We can't assume that the component name == the factory
 | 
			
		||||
        factory = pipe_config["factory"]
 | 
			
		||||
        if factory == "textcat":
 | 
			
		||||
            verify_textcat_config(nlp, pipe_config)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def verify_textcat_config(nlp: Language, pipe_config: Dict[str, Any]) -> None:
 | 
			
		||||
    # if 'positive_label' is provided: double check whether it's in the data and
 | 
			
		||||
    # the task is binary
 | 
			
		||||
    if pipe_config.get("positive_label"):
 | 
			
		||||
        textcat_labels = nlp.get_pipe("textcat").labels
 | 
			
		||||
        pos_label = pipe_config.get("positive_label")
 | 
			
		||||
        if pos_label not in textcat_labels:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                Errors.E920.format(pos_label=pos_label, labels=textcat_labels)
 | 
			
		||||
            )
 | 
			
		||||
        if len(list(textcat_labels)) != 2:
 | 
			
		||||
            raise ValueError(
 | 
			
		||||
                Errors.E919.format(pos_label=pos_label, labels=textcat_labels)
 | 
			
		||||
            )
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user