mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-27 17:54:39 +03:00
Update with WIP
This commit is contained in:
parent
a60562f208
commit
240e0a62ca
|
@ -1,13 +1,12 @@
|
||||||
from typing import Optional, Dict, List, Union, Sequence
|
from typing import Optional, Dict
|
||||||
from timeit import default_timer as timer
|
from timeit import default_timer as timer
|
||||||
import srsly
|
import srsly
|
||||||
import tqdm
|
import tqdm
|
||||||
from pydantic import BaseModel, FilePath
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from wasabi import msg
|
from wasabi import msg
|
||||||
import thinc
|
import thinc
|
||||||
import thinc.schedules
|
import thinc.schedules
|
||||||
from thinc.api import Model, use_pytorch_for_gpu_memory, require_gpu, fix_random_seed
|
from thinc.api import use_pytorch_for_gpu_memory, require_gpu, fix_random_seed
|
||||||
import random
|
import random
|
||||||
|
|
||||||
from ._app import app, Arg, Opt
|
from ._app import app, Arg, Opt
|
||||||
|
@ -15,108 +14,15 @@ from ..gold import Corpus, Example
|
||||||
from ..lookups import Lookups
|
from ..lookups import Lookups
|
||||||
from .. import util
|
from .. import util
|
||||||
from ..errors import Errors
|
from ..errors import Errors
|
||||||
|
from ..schemas import ConfigSchema
|
||||||
|
|
||||||
|
|
||||||
# Don't remove - required to load the built-in architectures
|
# Don't remove - required to load the built-in architectures
|
||||||
from ..ml import models # noqa: F401
|
from ..ml import models # noqa: F401
|
||||||
|
|
||||||
# from ..schemas import ConfigSchema # TODO: include?
|
|
||||||
|
|
||||||
|
|
||||||
registry = util.registry
|
registry = util.registry
|
||||||
|
|
||||||
CONFIG_STR = """
|
|
||||||
[training]
|
|
||||||
patience = 10
|
|
||||||
eval_frequency = 10
|
|
||||||
dropout = 0.2
|
|
||||||
init_tok2vec = null
|
|
||||||
max_epochs = 100
|
|
||||||
orth_variant_level = 0.0
|
|
||||||
gold_preproc = false
|
|
||||||
max_length = 0
|
|
||||||
use_gpu = 0
|
|
||||||
scores = ["ents_p", "ents_r", "ents_f"]
|
|
||||||
score_weights = {"ents_f": 1.0}
|
|
||||||
limit = 0
|
|
||||||
|
|
||||||
[training.batch_size]
|
|
||||||
@schedules = "compounding.v1"
|
|
||||||
start = 100
|
|
||||||
stop = 1000
|
|
||||||
compound = 1.001
|
|
||||||
|
|
||||||
[optimizer]
|
|
||||||
@optimizers = "Adam.v1"
|
|
||||||
learn_rate = 0.001
|
|
||||||
beta1 = 0.9
|
|
||||||
beta2 = 0.999
|
|
||||||
|
|
||||||
[nlp]
|
|
||||||
lang = "en"
|
|
||||||
vectors = null
|
|
||||||
|
|
||||||
[nlp.pipeline.tok2vec]
|
|
||||||
factory = "tok2vec"
|
|
||||||
|
|
||||||
[nlp.pipeline.ner]
|
|
||||||
factory = "ner"
|
|
||||||
|
|
||||||
[nlp.pipeline.ner.model]
|
|
||||||
@architectures = "spacy.TransitionBasedParser.v1"
|
|
||||||
nr_feature_tokens = 3
|
|
||||||
hidden_width = 64
|
|
||||||
maxout_pieces = 3
|
|
||||||
|
|
||||||
[nlp.pipeline.ner.model.tok2vec]
|
|
||||||
@architectures = "spacy.Tok2VecTensors.v1"
|
|
||||||
width = ${nlp.pipeline.tok2vec.model:width}
|
|
||||||
|
|
||||||
[nlp.pipeline.tok2vec.model]
|
|
||||||
@architectures = "spacy.HashEmbedCNN.v1"
|
|
||||||
pretrained_vectors = ${nlp:vectors}
|
|
||||||
width = 128
|
|
||||||
depth = 4
|
|
||||||
window_size = 1
|
|
||||||
embed_size = 10000
|
|
||||||
maxout_pieces = 3
|
|
||||||
subword_features = true
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
class PipelineComponent(BaseModel):
|
|
||||||
factory: str
|
|
||||||
model: Model
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigSchema(BaseModel):
|
|
||||||
optimizer: Optional["Optimizer"]
|
|
||||||
|
|
||||||
class training(BaseModel):
|
|
||||||
patience: int = 10
|
|
||||||
eval_frequency: int = 100
|
|
||||||
dropout: float = 0.2
|
|
||||||
init_tok2vec: Optional[FilePath] = None
|
|
||||||
max_epochs: int = 100
|
|
||||||
orth_variant_level: float = 0.0
|
|
||||||
gold_preproc: bool = False
|
|
||||||
max_length: int = 0
|
|
||||||
use_gpu: int = 0
|
|
||||||
scores: List[str] = ["ents_p", "ents_r", "ents_f"]
|
|
||||||
score_weights: Dict[str, Union[int, float]] = {"ents_f": 1.0}
|
|
||||||
limit: int = 0
|
|
||||||
batch_size: Union[Sequence[int], int]
|
|
||||||
|
|
||||||
class nlp(BaseModel):
|
|
||||||
lang: str
|
|
||||||
vectors: Optional[str]
|
|
||||||
pipeline: Optional[Dict[str, PipelineComponent]]
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
extra = "allow"
|
|
||||||
|
|
||||||
|
|
||||||
@app.command("train")
|
@app.command("train")
|
||||||
def train_cli(
|
def train_cli(
|
||||||
|
@ -126,12 +32,7 @@ def train_cli(
|
||||||
config_path: Path = Arg(..., help="Path to config file", exists=True),
|
config_path: Path = Arg(..., help="Path to config file", exists=True),
|
||||||
output_path: Optional[Path] = Opt(None, "--output", "--output-path", "-o", help="Output directory to store model in"),
|
output_path: Optional[Path] = Opt(None, "--output", "--output-path", "-o", help="Output directory to store model in"),
|
||||||
code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
|
code_path: Optional[Path] = Opt(None, "--code-path", "-c", help="Path to Python file with additional code (registered functions) to be imported"),
|
||||||
init_tok2vec: Optional[Path] = Opt(None, "--init-tok2vec", "-t2v", help="Path to pretrained weights for the tok2vec components. See 'spacy pretrain'. Experimental."),
|
|
||||||
raw_text: Optional[Path] = Opt(None, "--raw-text", "-rt", help="Path to jsonl file with unlabelled text documents."),
|
|
||||||
verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
|
verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
|
||||||
use_gpu: int = Opt(-1, "--use-gpu", "-g", help="Use GPU"),
|
|
||||||
tag_map_path: Optional[Path] = Opt(None, "--tag-map-path", "-tm", help="Location of JSON-formatted tag map"),
|
|
||||||
omit_extra_lookups: bool = Opt(False, "--omit-extra-lookups", "-OEL", help="Don't include extra lookups in model"),
|
|
||||||
# fmt: on
|
# fmt: on
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
|
@ -141,33 +42,11 @@ def train_cli(
|
||||||
"""
|
"""
|
||||||
util.set_env_log(verbose)
|
util.set_env_log(verbose)
|
||||||
verify_cli_args(**locals())
|
verify_cli_args(**locals())
|
||||||
|
try:
|
||||||
if raw_text is not None:
|
util.import_file("python_code", code_path)
|
||||||
raw_text = list(srsly.read_jsonl(raw_text))
|
except Exception as e:
|
||||||
tag_map = {}
|
msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)
|
||||||
if tag_map_path is not None:
|
train(config_path, {"train": train_path, "dev": dev_path}, output_path=output_path)
|
||||||
tag_map = srsly.read_json(tag_map_path)
|
|
||||||
|
|
||||||
weights_data = None
|
|
||||||
if init_tok2vec is not None:
|
|
||||||
with init_tok2vec.open("rb") as file_:
|
|
||||||
weights_data = file_.read()
|
|
||||||
|
|
||||||
if use_gpu >= 0:
|
|
||||||
msg.info("Using GPU: {use_gpu}")
|
|
||||||
require_gpu(use_gpu)
|
|
||||||
else:
|
|
||||||
msg.info("Using CPU")
|
|
||||||
|
|
||||||
train(
|
|
||||||
config_path,
|
|
||||||
{"train": train_path, "dev": dev_path},
|
|
||||||
output_path=output_path,
|
|
||||||
raw_text=raw_text,
|
|
||||||
tag_map=tag_map,
|
|
||||||
weights_data=weights_data,
|
|
||||||
omit_extra_lookups=omit_extra_lookups,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def train(
|
def train(
|
||||||
|
@ -175,19 +54,24 @@ def train(
|
||||||
data_paths: Dict[str, Path],
|
data_paths: Dict[str, Path],
|
||||||
raw_text: Optional[Path] = None,
|
raw_text: Optional[Path] = None,
|
||||||
output_path: Optional[Path] = None,
|
output_path: Optional[Path] = None,
|
||||||
tag_map: Optional[Path] = None,
|
|
||||||
weights_data: Optional[bytes] = None,
|
weights_data: Optional[bytes] = None,
|
||||||
omit_extra_lookups: bool = False,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
msg.info(f"Loading config from: {config_path}")
|
msg.info(f"Loading config from: {config_path}")
|
||||||
# Read the config first without creating objects, to get to the original nlp_config
|
# Read the config first without creating objects, to get to the original nlp_config
|
||||||
config = util.load_config(config_path, create_objects=False)
|
config = util.load_config(config_path, create_objects=False, schema=ConfigSchema)
|
||||||
|
use_gpu = config["training"]["use_gpu"]
|
||||||
|
if use_gpu >= 0:
|
||||||
|
msg.info(f"Using GPU: {use_gpu}")
|
||||||
|
require_gpu(use_gpu)
|
||||||
|
else:
|
||||||
|
msg.info("Using CPU")
|
||||||
|
raw_text, tag_map, weights_data = load_from_paths(config)
|
||||||
fix_random_seed(config["training"]["seed"])
|
fix_random_seed(config["training"]["seed"])
|
||||||
if config["training"].get("use_pytorch_for_gpu_memory"):
|
if config["training"].get("use_pytorch_for_gpu_memory"):
|
||||||
# It feels kind of weird to not have a default for this.
|
# It feels kind of weird to not have a default for this.
|
||||||
use_pytorch_for_gpu_memory()
|
use_pytorch_for_gpu_memory()
|
||||||
nlp_config = config["nlp"]
|
nlp_config = config["nlp"]
|
||||||
config = util.load_config(config_path, create_objects=True)
|
config = util.load_config(config_path, create_objects=True, schema=ConfigSchema)
|
||||||
training = config["training"]
|
training = config["training"]
|
||||||
msg.info("Creating nlp from config")
|
msg.info("Creating nlp from config")
|
||||||
nlp = util.load_model_from_config(nlp_config)
|
nlp = util.load_model_from_config(nlp_config)
|
||||||
|
@ -216,7 +100,7 @@ def train(
|
||||||
|
|
||||||
# Create empty extra lexeme tables so the data from spacy-lookups-data
|
# Create empty extra lexeme tables so the data from spacy-lookups-data
|
||||||
# isn't loaded if these features are accessed
|
# isn't loaded if these features are accessed
|
||||||
if omit_extra_lookups:
|
if config["omit_extra_lookups"]:
|
||||||
nlp.vocab.lookups_extra = Lookups()
|
nlp.vocab.lookups_extra = Lookups()
|
||||||
nlp.vocab.lookups_extra.add_table("lexeme_cluster")
|
nlp.vocab.lookups_extra.add_table("lexeme_cluster")
|
||||||
nlp.vocab.lookups_extra.add_table("lexeme_prob")
|
nlp.vocab.lookups_extra.add_table("lexeme_prob")
|
||||||
|
@ -556,18 +440,36 @@ def update_meta(training, nlp, info):
|
||||||
nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name]
|
nlp.meta["performance"][f"{pipe_name}_loss"] = info["losses"][pipe_name]
|
||||||
|
|
||||||
|
|
||||||
|
def load_from_paths(config):
|
||||||
|
# TODO: separate checks from loading
|
||||||
|
raw_text = util.ensure_path(config["training"]["raw_text"])
|
||||||
|
if raw_text is not None:
|
||||||
|
if not raw_text.exists():
|
||||||
|
msg.fail("Can't find raw text", raw_text, exits=1)
|
||||||
|
raw_text = list(srsly.read_jsonl(config["training"]["raw_text"]))
|
||||||
|
tag_map = {}
|
||||||
|
tag_map_path = util.ensure_path(config["training"]["tag_map"])
|
||||||
|
if tag_map_path is not None:
|
||||||
|
if not tag_map_path.exists():
|
||||||
|
msg.fail("Can't find tag map path", tag_map_path, exits=1)
|
||||||
|
tag_map = srsly.read_json(config["training"]["tag_map"])
|
||||||
|
weights_data = None
|
||||||
|
init_tok2vec = util.ensure_path(config["training"]["init_tok2vec"])
|
||||||
|
if init_tok2vec is not None:
|
||||||
|
if not init_tok2vec.exists():
|
||||||
|
msg.fail("Can't find pretrained tok2vec", init_tok2vec, exits=1)
|
||||||
|
with init_tok2vec.open("rb") as file_:
|
||||||
|
weights_data = file_.read()
|
||||||
|
return raw_text, tag_map, weights_data
|
||||||
|
|
||||||
|
|
||||||
def verify_cli_args(
|
def verify_cli_args(
|
||||||
train_path,
|
train_path: Path,
|
||||||
dev_path,
|
dev_path: Path,
|
||||||
config_path,
|
config_path: Path,
|
||||||
output_path=None,
|
output_path: Optional[Path] = None,
|
||||||
code_path=None,
|
code_path: Optional[Path] = None,
|
||||||
init_tok2vec=None,
|
verbose: bool = False,
|
||||||
raw_text=None,
|
|
||||||
verbose=False,
|
|
||||||
use_gpu=-1,
|
|
||||||
tag_map_path=None,
|
|
||||||
omit_extra_lookups=False,
|
|
||||||
):
|
):
|
||||||
# Make sure all files and paths exists if they are needed
|
# Make sure all files and paths exists if they are needed
|
||||||
if not config_path or not config_path.exists():
|
if not config_path or not config_path.exists():
|
||||||
|
@ -591,12 +493,6 @@ def verify_cli_args(
|
||||||
if code_path is not None:
|
if code_path is not None:
|
||||||
if not code_path.exists():
|
if not code_path.exists():
|
||||||
msg.fail("Path to Python code not found", code_path, exits=1)
|
msg.fail("Path to Python code not found", code_path, exits=1)
|
||||||
try:
|
|
||||||
util.import_file("python_code", code_path)
|
|
||||||
except Exception as e:
|
|
||||||
msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)
|
|
||||||
if init_tok2vec is not None and not init_tok2vec.exists():
|
|
||||||
msg.fail("Can't find pretrained tok2vec", init_tok2vec, exits=1)
|
|
||||||
|
|
||||||
|
|
||||||
def verify_textcat_config(nlp, nlp_config):
|
def verify_textcat_config(nlp, nlp_config):
|
||||||
|
|
110
spacy/schemas.py
110
spacy/schemas.py
|
@ -1,9 +1,10 @@
|
||||||
from typing import Dict, List, Union, Optional, Sequence, Any
|
from typing import Dict, List, Union, Optional, Sequence, Any
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pydantic import BaseModel, Field, ValidationError, validator
|
from pydantic import BaseModel, Field, ValidationError, validator
|
||||||
from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool, FilePath
|
from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
|
||||||
|
from pydantic import FilePath, DirectoryPath
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from thinc.api import Model
|
from thinc.api import Model, Optimizer
|
||||||
|
|
||||||
from .attrs import NAMES
|
from .attrs import NAMES
|
||||||
|
|
||||||
|
@ -173,41 +174,6 @@ class ModelMetaSchema(BaseModel):
|
||||||
# JSON training format
|
# JSON training format
|
||||||
|
|
||||||
|
|
||||||
class PipelineComponent(BaseModel):
|
|
||||||
factory: str
|
|
||||||
model: Model
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
arbitrary_types_allowed = True
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigSchema(BaseModel):
|
|
||||||
optimizer: Optional["Optimizer"]
|
|
||||||
|
|
||||||
class training(BaseModel):
|
|
||||||
patience: int = 10
|
|
||||||
eval_frequency: int = 100
|
|
||||||
dropout: float = 0.2
|
|
||||||
init_tok2vec: Optional[FilePath] = None
|
|
||||||
max_epochs: int = 100
|
|
||||||
orth_variant_level: float = 0.0
|
|
||||||
gold_preproc: bool = False
|
|
||||||
max_length: int = 0
|
|
||||||
use_gpu: int = 0
|
|
||||||
scores: List[str] = ["ents_p", "ents_r", "ents_f"]
|
|
||||||
score_weights: Dict[str, Union[int, float]] = {"ents_f": 1.0}
|
|
||||||
limit: int = 0
|
|
||||||
batch_size: Union[Sequence[int], int]
|
|
||||||
|
|
||||||
class nlp(BaseModel):
|
|
||||||
lang: str
|
|
||||||
vectors: Optional[str]
|
|
||||||
pipeline: Optional[Dict[str, PipelineComponent]]
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
extra = "allow"
|
|
||||||
|
|
||||||
|
|
||||||
class TrainingSchema(BaseModel):
|
class TrainingSchema(BaseModel):
|
||||||
# TODO: write
|
# TODO: write
|
||||||
|
|
||||||
|
@ -216,6 +182,76 @@ class TrainingSchema(BaseModel):
|
||||||
extra = "forbid"
|
extra = "forbid"
|
||||||
|
|
||||||
|
|
||||||
|
# Config schema
|
||||||
|
# We're not setting any defaults here (which is too messy) and are making all
|
||||||
|
# fields required, so we can raise validation errors for missing values. To
|
||||||
|
# provide a default, we include a separate .cfg file with all values and
|
||||||
|
# check that against this schema in the test suite to make sure it's always
|
||||||
|
# up to date.
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigSchemaTraining(BaseModel):
|
||||||
|
# fmt: off
|
||||||
|
gold_preproc: StrictBool = Field(..., title="Whether to train on gold-standard sentences and tokens")
|
||||||
|
max_length: StrictInt = Field(..., title="Maximum length of examples (longer examples are divided into sentences if possible)")
|
||||||
|
limit: StrictInt = Field(..., title="Number of examples to use (0 for all)")
|
||||||
|
orth_variant_level: StrictFloat = Field(..., title="Orth variants for data augmentation")
|
||||||
|
dropout: StrictFloat = Field(..., title="Dropout rate")
|
||||||
|
patience: StrictInt = Field(..., title="How many steps to continue without improvement in evaluation score")
|
||||||
|
max_epochs: StrictInt = Field(..., title="Maximum number of epochs to train for")
|
||||||
|
max_steps: StrictInt = Field(..., title="Maximum number of update steps to train for")
|
||||||
|
eval_frequency: StrictInt = Field(..., title="How often to evaluate during training (steps)")
|
||||||
|
seed: StrictInt = Field(..., title="Random seed")
|
||||||
|
accumulate_gradient: StrictInt = Field(..., title="Whether to divide the batch up into substeps")
|
||||||
|
use_pytorch_for_gpu_memory: StrictBool = Field(..., title="Allocate memory via PyTorch")
|
||||||
|
use_gpu: StrictInt = Field(..., title="GPU ID or -1 for CPU")
|
||||||
|
scores: List[StrictStr] = Field(..., title="Score types to be printed in overview")
|
||||||
|
score_weights: Dict[StrictStr, Union[StrictFloat, StrictInt]] = Field(..., title="Weights of each score type for selecting final model")
|
||||||
|
init_tok2vec: Optional[FilePath] = Field(..., title="Path to pretrained tok2vec weights")
|
||||||
|
discard_oversize: StrictBool = Field(..., title="Whether to skip examples longer than batch size")
|
||||||
|
omit_extra_lookups: StrictBool = Field(..., title="Don't include extra lookups in model")
|
||||||
|
batch_by: StrictStr = Field(..., title="Batch examples by type")
|
||||||
|
raw_text: Optional[FilePath] = Field(..., title="Raw text")
|
||||||
|
tag_map: Optional[FilePath] = Field(..., title="Path to JSON-formatted tag map")
|
||||||
|
batch_size: Union[Sequence[int], int] = Field(..., title="The batch size or batch size schedule")
|
||||||
|
optimizer: Optimizer = Field(..., title="The optimizer to use")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
extra = "forbid"
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigSchemaNlpComponent(BaseModel):
|
||||||
|
factory: StrictStr = Field(..., title="Component factory name")
|
||||||
|
model: Model = Field(..., title="Component model")
|
||||||
|
# TODO: add config schema / types for components so we can fill and validate
|
||||||
|
# component options like learn_tokens, min_action_freq etc.
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
extra = "allow"
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigSchemaNlp(BaseModel):
|
||||||
|
lang: StrictStr = Field(..., title="The base language to use")
|
||||||
|
vectors: Optional[DirectoryPath] = Field(..., title="Path to vectors")
|
||||||
|
pipeline: Optional[Dict[str, ConfigSchemaNlpComponent]]
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
extra = "forbid"
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigSchema(BaseModel):
|
||||||
|
training: ConfigSchemaTraining
|
||||||
|
nlp: ConfigSchemaNlp
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
extra = "allow"
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
# Project config Schema
|
# Project config Schema
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import List, Union
|
from typing import List, Union, Type, Dict, Any
|
||||||
import os
|
import os
|
||||||
import importlib
|
import importlib
|
||||||
import importlib.util
|
import importlib.util
|
||||||
|
@ -6,6 +6,8 @@ import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import thinc
|
import thinc
|
||||||
from thinc.api import NumpyOps, get_current_ops, Adam, Config
|
from thinc.api import NumpyOps, get_current_ops, Adam, Config
|
||||||
|
from thinc.config import EmptySchema
|
||||||
|
from pydantic import BaseModel
|
||||||
import functools
|
import functools
|
||||||
import itertools
|
import itertools
|
||||||
import numpy.random
|
import numpy.random
|
||||||
|
@ -20,6 +22,7 @@ import subprocess
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
import tempfile
|
import tempfile
|
||||||
import shutil
|
import shutil
|
||||||
|
import hashlib
|
||||||
import shlex
|
import shlex
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
@ -326,20 +329,29 @@ def get_base_version(version):
|
||||||
return Version(version).base_version
|
return Version(version).base_version
|
||||||
|
|
||||||
|
|
||||||
def load_config(path, create_objects=False):
|
def load_config(
|
||||||
|
path: Union[Path, str],
|
||||||
|
*,
|
||||||
|
create_objects: bool = False,
|
||||||
|
schema: Type[BaseModel] = EmptySchema,
|
||||||
|
validate: bool = True,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Load a Thinc-formatted config file, optionally filling in objects where
|
"""Load a Thinc-formatted config file, optionally filling in objects where
|
||||||
the config references registry entries. See "Thinc config files" for details.
|
the config references registry entries. See "Thinc config files" for details.
|
||||||
|
|
||||||
path (str / Path): Path to the config file
|
path (str / Path): Path to the config file
|
||||||
create_objects (bool): Whether to automatically create objects when the config
|
create_objects (bool): Whether to automatically create objects when the config
|
||||||
references registry entries. Defaults to False.
|
references registry entries. Defaults to False.
|
||||||
|
schema (BaseModel): Optional pydantic base schema to use for validation.
|
||||||
RETURNS (dict): The objects from the config file.
|
RETURNS (dict): The objects from the config file.
|
||||||
"""
|
"""
|
||||||
config = thinc.config.Config().from_disk(path)
|
config = thinc.config.Config().from_disk(path)
|
||||||
if create_objects:
|
if create_objects:
|
||||||
return registry.make_from_config(config, validate=True)
|
return registry.make_from_config(config, validate=validate, schema=schema)
|
||||||
else:
|
else:
|
||||||
|
# Just fill config here so we can validate and fail early
|
||||||
|
if validate and schema:
|
||||||
|
registry.fill_config(config, validate=validate, schema=schema)
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user