mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Fix commands
This commit is contained in:
parent
44bad1474c
commit
553bfea641
|
@ -16,6 +16,7 @@ import os
|
||||||
|
|
||||||
from ..schemas import ProjectConfigSchema, validate
|
from ..schemas import ProjectConfigSchema, validate
|
||||||
from ..util import import_file, run_command, make_tempdir, registry, logger
|
from ..util import import_file, run_command, make_tempdir, registry, logger
|
||||||
|
from ..util import ensure_path
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from pathy import Pathy # noqa: F401
|
from pathy import Pathy # noqa: F401
|
||||||
|
@ -458,3 +459,24 @@ def string_to_list(value: str, intify: bool = False) -> Union[List[str], List[in
|
||||||
p = int(p)
|
p = int(p)
|
||||||
result.append(p)
|
result.append(p)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def load_from_paths(
|
||||||
|
config: Config,
|
||||||
|
) -> Tuple[List[Dict[str, str]], Dict[str, dict], bytes]:
|
||||||
|
# TODO: separate checks from loading
|
||||||
|
raw_text = ensure_path(config["training"]["raw_text"])
|
||||||
|
if raw_text is not None:
|
||||||
|
if not raw_text.exists():
|
||||||
|
msg.fail("Can't find raw text", raw_text, exits=1)
|
||||||
|
raw_text = list(srsly.read_jsonl(config["training"]["raw_text"]))
|
||||||
|
tag_map = {}
|
||||||
|
morph_rules = {}
|
||||||
|
weights_data = None
|
||||||
|
init_tok2vec = ensure_path(config["training"]["init_tok2vec"])
|
||||||
|
if init_tok2vec is not None:
|
||||||
|
if not init_tok2vec.exists():
|
||||||
|
msg.fail("Can't find pretrained tok2vec", init_tok2vec, exits=1)
|
||||||
|
with init_tok2vec.open("rb") as file_:
|
||||||
|
weights_data = file_.read()
|
||||||
|
return raw_text, tag_map, morph_rules, weights_data
|
||||||
|
|
|
@ -1,68 +0,0 @@
|
||||||
from typing import Optional, List, Dict, Any, Union, IO
|
|
||||||
import math
|
|
||||||
from tqdm import tqdm
|
|
||||||
import numpy
|
|
||||||
from ast import literal_eval
|
|
||||||
from pathlib import Path
|
|
||||||
from preshed.counter import PreshCounter
|
|
||||||
import tarfile
|
|
||||||
import gzip
|
|
||||||
import zipfile
|
|
||||||
import srsly
|
|
||||||
import warnings
|
|
||||||
from wasabi import msg, Printer
|
|
||||||
import typer
|
|
||||||
from ._util import init_cli, Arg, Opt, parse_config_overrides, show_validation_error
|
|
||||||
|
|
||||||
DEFAULT_OOV_PROB = -20
|
|
||||||
|
|
||||||
|
|
||||||
#@init_cli.command("vocab")
|
|
||||||
#@app.command(
|
|
||||||
# "init-model",
|
|
||||||
# context_settings={"allow_extra_args": True, "ignore_unknown_options": True},
|
|
||||||
# hidden=True, # hide this from main CLI help but still allow it to work with warning
|
|
||||||
#)
|
|
||||||
def init_model_cli(
|
|
||||||
# fmt: off
|
|
||||||
ctx: typer.Context, # This is only used to read additional arguments
|
|
||||||
lang: str = Arg(..., help="Pipeline language"),
|
|
||||||
output_dir: Path = Arg(..., help="Pipeline output directory"),
|
|
||||||
freqs_loc: Optional[Path] = Arg(None, help="Location of words frequencies file", exists=True),
|
|
||||||
clusters_loc: Optional[Path] = Opt(None, "--clusters-loc", "-c", help="Optional location of brown clusters data", exists=True),
|
|
||||||
jsonl_loc: Optional[Path] = Opt(None, "--jsonl-loc", "-j", help="Location of JSONL-formatted attributes file", exists=True),
|
|
||||||
vectors_loc: Optional[Path] = Opt(None, "--vectors-loc", "-v", help="Optional vectors file in Word2Vec format", exists=True),
|
|
||||||
prune_vectors: int = Opt(-1, "--prune-vectors", "-V", help="Optional number of vectors to prune to"),
|
|
||||||
truncate_vectors: int = Opt(0, "--truncate-vectors", "-t", help="Optional number of vectors to truncate to when reading in vectors file"),
|
|
||||||
vectors_name: Optional[str] = Opt(None, "--vectors-name", "-vn", help="Optional name for the word vectors, e.g. en_core_web_lg.vectors"),
|
|
||||||
model_name: Optional[str] = Opt(None, "--meta-name", "-mn", help="Optional name of the package for the pipeline meta"),
|
|
||||||
base_model: Optional[str] = Opt(None, "--base", "-b", help="Name of or path to base pipeline to start with (mostly relevant for pipelines with custom tokenizers)")
|
|
||||||
# fmt: on
|
|
||||||
):
|
|
||||||
"""
|
|
||||||
Create a new blank pipeline directory with vocab and vectors from raw data.
|
|
||||||
If vectors are provided in Word2Vec format, they can be either a .txt or
|
|
||||||
zipped as a .zip or .tar.gz.
|
|
||||||
|
|
||||||
DOCS: https://nightly.spacy.io/api/cli#init-vocab
|
|
||||||
"""
|
|
||||||
if ctx.command.name == "init-model":
|
|
||||||
msg.warn(
|
|
||||||
"The init-model command is now called 'init vocab'. You can run "
|
|
||||||
"'python -m spacy init --help' for an overview of the other "
|
|
||||||
"available initialization commands."
|
|
||||||
)
|
|
||||||
init_vocab(
|
|
||||||
lang,
|
|
||||||
output_dir,
|
|
||||||
freqs_loc=freqs_loc,
|
|
||||||
clusters_loc=clusters_loc,
|
|
||||||
jsonl_loc=jsonl_loc,
|
|
||||||
vectors_loc=vectors_loc,
|
|
||||||
prune_vectors=prune_vectors,
|
|
||||||
truncate_vectors=truncate_vectors,
|
|
||||||
vectors_name=vectors_name,
|
|
||||||
model_name=model_name,
|
|
||||||
base_model=base_model,
|
|
||||||
silent=False,
|
|
||||||
)
|
|
|
@ -1,18 +1,17 @@
|
||||||
from typing import Optional, Dict, Any, Tuple, Union, Callable, List
|
from typing import Optional, Dict, Callable, Any
|
||||||
import logging
|
import logging
|
||||||
import srsly
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from wasabi import msg
|
from wasabi import msg
|
||||||
import typer
|
import typer
|
||||||
from thinc.api import Config, fix_random_seed
|
from thinc.api import Config, fix_random_seed, set_gpu_allocator
|
||||||
|
|
||||||
from .train import create_before_to_disk_callback
|
|
||||||
from .. import util
|
from .. import util
|
||||||
from ..util import registry
|
from ..util import registry, resolve_dot_names
|
||||||
from ..schemas import ConfigSchemaTraining
|
from ..schemas import ConfigSchemaTraining, ConfigSchemaPretrain
|
||||||
|
from ..language import Language
|
||||||
|
from ..errors import Errors
|
||||||
from ._util import init_cli, Arg, Opt, parse_config_overrides, show_validation_error
|
from ._util import init_cli, Arg, Opt, parse_config_overrides, show_validation_error
|
||||||
from ._util import import_code, get_sourced_components
|
from ._util import import_code, get_sourced_components, load_from_paths
|
||||||
from ..util import resolve_dot_names
|
|
||||||
|
|
||||||
|
|
||||||
@init_cli.command(
|
@init_cli.command(
|
||||||
|
@ -31,10 +30,12 @@ def init_pipeline_cli(
|
||||||
util.logger.setLevel(logging.DEBUG if verbose else logging.ERROR)
|
util.logger.setLevel(logging.DEBUG if verbose else logging.ERROR)
|
||||||
overrides = parse_config_overrides(ctx.args)
|
overrides = parse_config_overrides(ctx.args)
|
||||||
import_code(code_path)
|
import_code(code_path)
|
||||||
config = util.load_config(config_path, overrides=overrides)
|
|
||||||
with show_validation_error(config_path):
|
with show_validation_error(config_path):
|
||||||
nlp = init_pipeline(config)
|
config = util.load_config(config_path, overrides=overrides)
|
||||||
|
nlp = init_pipeline(config)
|
||||||
nlp.to_disk(output_path)
|
nlp.to_disk(output_path)
|
||||||
|
# TODO: add more instructions
|
||||||
|
msg.good(f"Saved initialized pipeline to {output_path}")
|
||||||
|
|
||||||
|
|
||||||
def must_initialize(init_path: Path, config_path: Path, overrides: Dict) -> bool:
|
def must_initialize(init_path: Path, config_path: Path, overrides: Dict) -> bool:
|
||||||
|
@ -51,7 +52,7 @@ def must_initialize(init_path: Path, config_path: Path, overrides: Dict) -> bool
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def init_pipeline(config: Config, use_gpu=-1):
|
def init_pipeline(config: Config, use_gpu: int = -1) -> Language:
|
||||||
raw_config = config
|
raw_config = config
|
||||||
config = raw_config.interpolate()
|
config = raw_config.interpolate()
|
||||||
if config["training"]["seed"] is not None:
|
if config["training"]["seed"] is not None:
|
||||||
|
@ -61,22 +62,19 @@ def init_pipeline(config: Config, use_gpu=-1):
|
||||||
set_gpu_allocator(allocator)
|
set_gpu_allocator(allocator)
|
||||||
# Use original config here before it's resolved to functions
|
# Use original config here before it's resolved to functions
|
||||||
sourced_components = get_sourced_components(config)
|
sourced_components = get_sourced_components(config)
|
||||||
nlp = util.load_model_from_config(raw_config)
|
with show_validation_error():
|
||||||
|
nlp = util.load_model_from_config(raw_config)
|
||||||
|
msg.good("Set up nlp object from config")
|
||||||
# Resolve all training-relevant sections using the filled nlp config
|
# Resolve all training-relevant sections using the filled nlp config
|
||||||
T = registry.resolve(
|
T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
|
||||||
config["training"],
|
|
||||||
schema=ConfigSchemaTraining,
|
|
||||||
validate=True,
|
|
||||||
)
|
|
||||||
dot_names = [T["train_corpus"], T["dev_corpus"], T["raw_text"]]
|
dot_names = [T["train_corpus"], T["dev_corpus"], T["raw_text"]]
|
||||||
train_corpus, dev_corpus, raw_text = resolve_dot_names(config, dot_names)
|
train_corpus, dev_corpus, raw_text = resolve_dot_names(config, dot_names)
|
||||||
util.load_vocab_data_into_model(nlp, lookups=T["lookups"])
|
util.load_vocab_data_into_model(nlp, lookups=T["lookups"])
|
||||||
|
msg.good("Created vocabulary")
|
||||||
if T["vectors"] is not None:
|
if T["vectors"] is not None:
|
||||||
add_vectors(nlp, T["vectors"])
|
add_vectors(nlp, T["vectors"])
|
||||||
score_weights = T["score_weights"]
|
msg.good(f"Added vectors: {T['vectors']}")
|
||||||
optimizer = T["optimizer"]
|
optimizer = T["optimizer"]
|
||||||
batcher = T["batcher"]
|
|
||||||
train_logger = T["logger"]
|
|
||||||
before_to_disk = create_before_to_disk_callback(T["before_to_disk"])
|
before_to_disk = create_before_to_disk_callback(T["before_to_disk"])
|
||||||
# Components that shouldn't be updated during training
|
# Components that shouldn't be updated during training
|
||||||
frozen_components = T["frozen_components"]
|
frozen_components = T["frozen_components"]
|
||||||
|
@ -89,13 +87,23 @@ def init_pipeline(config: Config, use_gpu=-1):
|
||||||
nlp.resume_training(sgd=optimizer)
|
nlp.resume_training(sgd=optimizer)
|
||||||
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
|
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
|
||||||
nlp.begin_training(lambda: train_corpus(nlp), sgd=optimizer)
|
nlp.begin_training(lambda: train_corpus(nlp), sgd=optimizer)
|
||||||
|
msg.good(f"Initialized pipeline components")
|
||||||
# Verify the config after calling 'begin_training' to ensure labels
|
# Verify the config after calling 'begin_training' to ensure labels
|
||||||
# are properly initialized
|
# are properly initialized
|
||||||
verify_config(nlp)
|
verify_config(nlp)
|
||||||
|
if "pretraining" in config and config["pretraining"]:
|
||||||
|
P = registry.resolve(config["pretraining"], schema=ConfigSchemaPretrain)
|
||||||
|
add_tok2vec_weights({"training": T, "pretraining": P}, nlp)
|
||||||
|
# TODO: this should be handled better?
|
||||||
|
nlp = before_to_disk(nlp)
|
||||||
|
return nlp
|
||||||
|
|
||||||
|
|
||||||
|
def add_tok2vec_weights(config: Config, nlp: Language) -> None:
|
||||||
# Load pretrained tok2vec weights - cf. CLI command 'pretrain'
|
# Load pretrained tok2vec weights - cf. CLI command 'pretrain'
|
||||||
|
weights_data = load_from_paths(config)
|
||||||
if weights_data is not None:
|
if weights_data is not None:
|
||||||
tok2vec_component = C["pretraining"]["component"]
|
tok2vec_component = config["pretraining"]["component"]
|
||||||
if tok2vec_component is None:
|
if tok2vec_component is None:
|
||||||
msg.fail(
|
msg.fail(
|
||||||
f"To use pretrained tok2vec weights, [pretraining.component] "
|
f"To use pretrained tok2vec weights, [pretraining.component] "
|
||||||
|
@ -103,9 +111,63 @@ def init_pipeline(config: Config, use_gpu=-1):
|
||||||
exits=1,
|
exits=1,
|
||||||
)
|
)
|
||||||
layer = nlp.get_pipe(tok2vec_component).model
|
layer = nlp.get_pipe(tok2vec_component).model
|
||||||
tok2vec_layer = C["pretraining"]["layer"]
|
tok2vec_layer = config["pretraining"]["layer"]
|
||||||
if tok2vec_layer:
|
if tok2vec_layer:
|
||||||
layer = layer.get_ref(tok2vec_layer)
|
layer = layer.get_ref(tok2vec_layer)
|
||||||
layer.from_bytes(weights_data)
|
layer.from_bytes(weights_data)
|
||||||
msg.info(f"Loaded pretrained weights into component '{tok2vec_component}'")
|
msg.good(f"Loaded pretrained weights into component '{tok2vec_component}'")
|
||||||
return nlp
|
|
||||||
|
|
||||||
|
def add_vectors(nlp: Language, vectors: str) -> None:
|
||||||
|
title = f"Config validation error for vectors {vectors}"
|
||||||
|
desc = (
|
||||||
|
"This typically means that there's a problem in the config.cfg included "
|
||||||
|
"with the packaged vectors. Make sure that the vectors package you're "
|
||||||
|
"loading is compatible with the current version of spaCy."
|
||||||
|
)
|
||||||
|
with show_validation_error(
|
||||||
|
title=title, desc=desc, hint_fill=False, show_config=False
|
||||||
|
):
|
||||||
|
util.load_vectors_into_model(nlp, vectors)
|
||||||
|
|
||||||
|
|
||||||
|
def verify_config(nlp: Language) -> None:
|
||||||
|
"""Perform additional checks based on the config, loaded nlp object and training data."""
|
||||||
|
# TODO: maybe we should validate based on the actual components, the list
|
||||||
|
# in config["nlp"]["pipeline"] instead?
|
||||||
|
for pipe_config in nlp.config["components"].values():
|
||||||
|
# We can't assume that the component name == the factory
|
||||||
|
factory = pipe_config["factory"]
|
||||||
|
if factory == "textcat":
|
||||||
|
verify_textcat_config(nlp, pipe_config)
|
||||||
|
|
||||||
|
|
||||||
|
def verify_textcat_config(nlp: Language, pipe_config: Dict[str, Any]) -> None:
|
||||||
|
# if 'positive_label' is provided: double check whether it's in the data and
|
||||||
|
# the task is binary
|
||||||
|
if pipe_config.get("positive_label"):
|
||||||
|
textcat_labels = nlp.get_pipe("textcat").labels
|
||||||
|
pos_label = pipe_config.get("positive_label")
|
||||||
|
if pos_label not in textcat_labels:
|
||||||
|
raise ValueError(
|
||||||
|
Errors.E920.format(pos_label=pos_label, labels=textcat_labels)
|
||||||
|
)
|
||||||
|
if len(list(textcat_labels)) != 2:
|
||||||
|
raise ValueError(
|
||||||
|
Errors.E919.format(pos_label=pos_label, labels=textcat_labels)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def create_before_to_disk_callback(
|
||||||
|
callback: Optional[Callable[[Language], Language]]
|
||||||
|
) -> Callable[[Language], Language]:
|
||||||
|
def before_to_disk(nlp: Language) -> Language:
|
||||||
|
if not callback:
|
||||||
|
return nlp
|
||||||
|
modified_nlp = callback(nlp)
|
||||||
|
if not isinstance(modified_nlp, Language):
|
||||||
|
err = Errors.E914.format(name="before_to_disk", value=type(modified_nlp))
|
||||||
|
raise ValueError(err)
|
||||||
|
return modified_nlp
|
||||||
|
|
||||||
|
return before_to_disk
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
from typing import Optional, Dict, Any, Tuple, Union, Callable, List
|
from typing import Optional, Dict, Any, Tuple, Union, Callable, List
|
||||||
from timeit import default_timer as timer
|
from timeit import default_timer as timer
|
||||||
import srsly
|
|
||||||
import tqdm
|
import tqdm
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from wasabi import msg
|
from wasabi import msg
|
||||||
|
@ -11,13 +10,17 @@ import random
|
||||||
import typer
|
import typer
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
from .init_pipeline import init_pipeline, must_initialize
|
||||||
|
from .init_pipeline import create_before_to_disk_callback
|
||||||
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
|
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
|
||||||
from ._util import import_code, get_sourced_components
|
from ._util import import_code
|
||||||
|
from ._util import load_from_paths # noqa: F401 (needed for Ray extension for now)
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
from .. import util
|
from .. import util
|
||||||
from ..training.example import Example
|
from ..training.example import Example
|
||||||
from ..errors import Errors
|
from ..errors import Errors
|
||||||
from ..util import resolve_dot_names
|
from ..util import resolve_dot_names, registry
|
||||||
|
from ..schemas import ConfigSchemaTraining
|
||||||
|
|
||||||
|
|
||||||
@app.command(
|
@app.command(
|
||||||
|
@ -56,25 +59,35 @@ def train_cli(
|
||||||
require_gpu(use_gpu)
|
require_gpu(use_gpu)
|
||||||
else:
|
else:
|
||||||
msg.info("Using CPU")
|
msg.info("Using CPU")
|
||||||
config = util.load_config(
|
config = util.load_config(config_path, overrides=overrides, interpolate=False)
|
||||||
config_path, overrides=config_overrides, interpolate=False
|
msg.divider("Initializing pipeline")
|
||||||
)
|
# TODO: add warnings / --initialize (?) argument
|
||||||
if output_path is None:
|
if output_path is None:
|
||||||
nlp = init_pipeline(config)
|
nlp = init_pipeline(config)
|
||||||
else:
|
else:
|
||||||
init_path = output_path / "model-initial"
|
init_path = output_path / "model-initial"
|
||||||
if must_reinitialize(config, init_path):
|
if must_initialize(config, init_path):
|
||||||
nlp = init_pipeline(config)
|
nlp = init_pipeline(config)
|
||||||
nlp.to_disk(init_path)
|
nlp.to_disk(init_path)
|
||||||
|
msg.good(f"Saved initialized pipeline to {init_path}")
|
||||||
else:
|
else:
|
||||||
nlp = spacy.load(output_path / "model-initial")
|
nlp = util.load_model(init_path)
|
||||||
msg.info("Start training")
|
msg.good(f"Loaded initialized pipeline from {init_path}")
|
||||||
train(nlp, config, output_path)
|
msg.divider("Training pipeline")
|
||||||
|
train(nlp, output_path, use_gpu=use_gpu)
|
||||||
|
|
||||||
|
|
||||||
def train(nlp: Language, output_path: Optional[Path]=None) -> None:
|
def train(
|
||||||
|
nlp: Language, output_path: Optional[Path] = None, *, use_gpu: int = -1
|
||||||
|
) -> None:
|
||||||
|
# TODO: random seed, GPU allocator
|
||||||
# Create iterator, which yields out info after each optimization step.
|
# Create iterator, which yields out info after each optimization step.
|
||||||
config = nlp.config.interpolate()
|
config = nlp.config.interpolate()
|
||||||
|
if config["training"]["seed"] is not None:
|
||||||
|
fix_random_seed(config["training"]["seed"])
|
||||||
|
allocator = config["training"]["gpu_allocator"]
|
||||||
|
if use_gpu >= 0 and allocator:
|
||||||
|
set_gpu_allocator(allocator)
|
||||||
T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
|
T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
|
||||||
dot_names = [T["train_corpus"], T["dev_corpus"], T["raw_text"]]
|
dot_names = [T["train_corpus"], T["dev_corpus"], T["raw_text"]]
|
||||||
train_corpus, dev_corpus, raw_text = resolve_dot_names(config, dot_names)
|
train_corpus, dev_corpus, raw_text = resolve_dot_names(config, dot_names)
|
||||||
|
@ -85,9 +98,7 @@ def train(nlp: Language, output_path: Optional[Path]=None) -> None:
|
||||||
before_to_disk = create_before_to_disk_callback(T["before_to_disk"])
|
before_to_disk = create_before_to_disk_callback(T["before_to_disk"])
|
||||||
# Components that shouldn't be updated during training
|
# Components that shouldn't be updated during training
|
||||||
frozen_components = T["frozen_components"]
|
frozen_components = T["frozen_components"]
|
||||||
|
|
||||||
# Create iterator, which yields out info after each optimization step.
|
# Create iterator, which yields out info after each optimization step.
|
||||||
msg.info("Start training")
|
|
||||||
training_step_iterator = train_while_improving(
|
training_step_iterator = train_while_improving(
|
||||||
nlp,
|
nlp,
|
||||||
optimizer,
|
optimizer,
|
||||||
|
@ -101,7 +112,7 @@ def train(nlp: Language, output_path: Optional[Path]=None) -> None:
|
||||||
raw_text=raw_text,
|
raw_text=raw_text,
|
||||||
exclude=frozen_components,
|
exclude=frozen_components,
|
||||||
)
|
)
|
||||||
msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
|
msg.info(f"Initial learn rate: {optimizer.learn_rate}")
|
||||||
with nlp.select_pipes(disable=frozen_components):
|
with nlp.select_pipes(disable=frozen_components):
|
||||||
print_row, finalize_logger = train_logger(nlp)
|
print_row, finalize_logger = train_logger(nlp)
|
||||||
|
|
||||||
|
@ -145,7 +156,6 @@ def train(nlp: Language, output_path: Optional[Path]=None) -> None:
|
||||||
msg.good(f"Saved pipeline to output directory {final_model_path}")
|
msg.good(f"Saved pipeline to output directory {final_model_path}")
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def add_vectors(nlp: Language, vectors: str) -> None:
|
def add_vectors(nlp: Language, vectors: str) -> None:
|
||||||
title = f"Config validation error for vectors {vectors}"
|
title = f"Config validation error for vectors {vectors}"
|
||||||
desc = (
|
desc = (
|
||||||
|
@ -199,21 +209,6 @@ def create_evaluation_callback(
|
||||||
return evaluate
|
return evaluate
|
||||||
|
|
||||||
|
|
||||||
def create_before_to_disk_callback(
|
|
||||||
callback: Optional[Callable[[Language], Language]]
|
|
||||||
) -> Callable[[Language], Language]:
|
|
||||||
def before_to_disk(nlp: Language) -> Language:
|
|
||||||
if not callback:
|
|
||||||
return nlp
|
|
||||||
modified_nlp = callback(nlp)
|
|
||||||
if not isinstance(modified_nlp, Language):
|
|
||||||
err = Errors.E914.format(name="before_to_disk", value=type(modified_nlp))
|
|
||||||
raise ValueError(err)
|
|
||||||
return modified_nlp
|
|
||||||
|
|
||||||
return before_to_disk
|
|
||||||
|
|
||||||
|
|
||||||
def train_while_improving(
|
def train_while_improving(
|
||||||
nlp: Language,
|
nlp: Language,
|
||||||
optimizer: Optimizer,
|
optimizer: Optimizer,
|
||||||
|
@ -370,30 +365,3 @@ def verify_cli_args(config_path: Path, output_path: Optional[Path] = None) -> No
|
||||||
if not output_path.exists():
|
if not output_path.exists():
|
||||||
output_path.mkdir()
|
output_path.mkdir()
|
||||||
msg.good(f"Created output directory: {output_path}")
|
msg.good(f"Created output directory: {output_path}")
|
||||||
|
|
||||||
|
|
||||||
def verify_config(nlp: Language) -> None:
|
|
||||||
"""Perform additional checks based on the config, loaded nlp object and training data."""
|
|
||||||
# TODO: maybe we should validate based on the actual components, the list
|
|
||||||
# in config["nlp"]["pipeline"] instead?
|
|
||||||
for pipe_config in nlp.config["components"].values():
|
|
||||||
# We can't assume that the component name == the factory
|
|
||||||
factory = pipe_config["factory"]
|
|
||||||
if factory == "textcat":
|
|
||||||
verify_textcat_config(nlp, pipe_config)
|
|
||||||
|
|
||||||
|
|
||||||
def verify_textcat_config(nlp: Language, pipe_config: Dict[str, Any]) -> None:
|
|
||||||
# if 'positive_label' is provided: double check whether it's in the data and
|
|
||||||
# the task is binary
|
|
||||||
if pipe_config.get("positive_label"):
|
|
||||||
textcat_labels = nlp.get_pipe("textcat").labels
|
|
||||||
pos_label = pipe_config.get("positive_label")
|
|
||||||
if pos_label not in textcat_labels:
|
|
||||||
raise ValueError(
|
|
||||||
Errors.E920.format(pos_label=pos_label, labels=textcat_labels)
|
|
||||||
)
|
|
||||||
if len(list(textcat_labels)) != 2:
|
|
||||||
raise ValueError(
|
|
||||||
Errors.E919.format(pos_label=pos_label, labels=textcat_labels)
|
|
||||||
)
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user