mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Upd train
This commit is contained in:
commit
a3e1791c9c
|
@ -6,7 +6,7 @@ requires = [
|
|||
"cymem>=2.0.2,<2.1.0",
|
||||
"preshed>=3.0.2,<3.1.0",
|
||||
"murmurhash>=0.28.0,<1.1.0",
|
||||
"thinc>=8.0.0a36,<8.0.0a40",
|
||||
"thinc>=8.0.0a41,<8.0.0a50",
|
||||
"blis>=0.4.0,<0.5.0",
|
||||
"pytokenizations",
|
||||
"pathy"
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# Our libraries
|
||||
cymem>=2.0.2,<2.1.0
|
||||
preshed>=3.0.2,<3.1.0
|
||||
thinc>=8.0.0a36,<8.0.0a40
|
||||
thinc>=8.0.0a41,<8.0.0a50
|
||||
blis>=0.4.0,<0.5.0
|
||||
ml_datasets==0.2.0a0
|
||||
murmurhash>=0.28.0,<1.1.0
|
||||
|
|
|
@ -34,13 +34,13 @@ setup_requires =
|
|||
cymem>=2.0.2,<2.1.0
|
||||
preshed>=3.0.2,<3.1.0
|
||||
murmurhash>=0.28.0,<1.1.0
|
||||
thinc>=8.0.0a36,<8.0.0a40
|
||||
thinc>=8.0.0a41,<8.0.0a50
|
||||
install_requires =
|
||||
# Our libraries
|
||||
murmurhash>=0.28.0,<1.1.0
|
||||
cymem>=2.0.2,<2.1.0
|
||||
preshed>=3.0.2,<3.1.0
|
||||
thinc>=8.0.0a36,<8.0.0a40
|
||||
thinc>=8.0.0a41,<8.0.0a50
|
||||
blis>=0.4.0,<0.5.0
|
||||
wasabi>=0.8.0,<1.1.0
|
||||
srsly>=2.1.0,<3.0.0
|
||||
|
|
|
@ -243,6 +243,8 @@ def show_validation_error(
|
|||
yield
|
||||
except ConfigValidationError as e:
|
||||
title = title if title is not None else e.title
|
||||
if e.desc:
|
||||
desc = f"{e.desc}" if not desc else f"{e.desc}\n\n{desc}"
|
||||
# Re-generate a new error object with overrides
|
||||
err = e.from_error(e, title="", desc=desc, show_config=show_config)
|
||||
msg.fail(title)
|
||||
|
|
|
@ -51,9 +51,10 @@ def debug_config(
|
|||
msg.divider("Config validation")
|
||||
with show_validation_error(config_path):
|
||||
config = util.load_config(config_path, overrides=overrides)
|
||||
nlp, resolved = util.load_model_from_config(config)
|
||||
nlp = util.load_model_from_config(config)
|
||||
# Use the resolved config here in case user has one function returning
|
||||
# a dict of corpora etc.
|
||||
resolved = util.resolve_training_config(nlp.config)
|
||||
check_section_refs(resolved, ["training.dev_corpus", "training.train_corpus"])
|
||||
msg.good("Config is valid")
|
||||
if show_vars:
|
||||
|
|
|
@ -93,18 +93,19 @@ def debug_data(
|
|||
msg.fail("Config file not found", config_path, exists=1)
|
||||
with show_validation_error(config_path):
|
||||
cfg = util.load_config(config_path, overrides=config_overrides)
|
||||
nlp, config = util.load_model_from_config(cfg)
|
||||
nlp = util.load_model_from_config(cfg)
|
||||
C = util.resolve_training_config(nlp.config)
|
||||
# Use original config here, not resolved version
|
||||
sourced_components = get_sourced_components(cfg)
|
||||
frozen_components = config["training"]["frozen_components"]
|
||||
frozen_components = C["training"]["frozen_components"]
|
||||
resume_components = [p for p in sourced_components if p not in frozen_components]
|
||||
pipeline = nlp.pipe_names
|
||||
factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names]
|
||||
tag_map_path = util.ensure_path(config["training"]["tag_map"])
|
||||
tag_map_path = util.ensure_path(C["training"]["tag_map"])
|
||||
tag_map = {}
|
||||
if tag_map_path is not None:
|
||||
tag_map = srsly.read_json(tag_map_path)
|
||||
morph_rules_path = util.ensure_path(config["training"]["morph_rules"])
|
||||
morph_rules_path = util.ensure_path(C["training"]["morph_rules"])
|
||||
morph_rules = {}
|
||||
if morph_rules_path is not None:
|
||||
morph_rules = srsly.read_json(morph_rules_path)
|
||||
|
@ -144,10 +145,10 @@ def debug_data(
|
|||
|
||||
train_texts = gold_train_data["texts"]
|
||||
dev_texts = gold_dev_data["texts"]
|
||||
frozen_components = config["training"]["frozen_components"]
|
||||
frozen_components = C["training"]["frozen_components"]
|
||||
|
||||
msg.divider("Training stats")
|
||||
msg.text(f"Language: {config['nlp']['lang']}")
|
||||
msg.text(f"Language: {C['nlp']['lang']}")
|
||||
msg.text(f"Training pipeline: {', '.join(pipeline)}")
|
||||
if resume_components:
|
||||
msg.text(f"Components from other pipelines: {', '.join(resume_components)}")
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
import warnings
|
||||
from typing import Dict, Any, Optional, Iterable
|
||||
from pathlib import Path
|
||||
|
||||
|
@ -57,14 +56,17 @@ def debug_model_cli(
|
|||
}
|
||||
config_overrides = parse_config_overrides(ctx.args)
|
||||
with show_validation_error(config_path):
|
||||
config = util.load_config(
|
||||
config_path, overrides=config_overrides, interpolate=True
|
||||
raw_config = util.load_config(
|
||||
config_path, overrides=config_overrides, interpolate=False
|
||||
)
|
||||
config = raw_config.iterpolate()
|
||||
allocator = config["training"]["gpu_allocator"]
|
||||
if use_gpu >= 0 and allocator:
|
||||
set_gpu_allocator(allocator)
|
||||
nlp, config = util.load_model_from_config(config)
|
||||
seed = config["training"]["seed"]
|
||||
with show_validation_error(config_path):
|
||||
nlp = util.load_model_from_config(raw_config)
|
||||
C = util.resolve_training_config(nlp.config)
|
||||
seed = C["training"]["seed"]
|
||||
if seed is not None:
|
||||
msg.info(f"Fixing random seed: {seed}")
|
||||
fix_random_seed(seed)
|
||||
|
@ -75,7 +77,7 @@ def debug_model_cli(
|
|||
exits=1,
|
||||
)
|
||||
model = pipe.model
|
||||
debug_model(config, nlp, model, print_settings=print_settings)
|
||||
debug_model(C, nlp, model, print_settings=print_settings)
|
||||
|
||||
|
||||
def debug_model(
|
||||
|
@ -108,7 +110,7 @@ def debug_model(
|
|||
_set_output_dim(nO=7, model=model)
|
||||
nlp.begin_training(lambda: [Example.from_dict(x, {}) for x in X])
|
||||
msg.info("Initialized the model with dummy data.")
|
||||
except:
|
||||
except Exception:
|
||||
msg.fail(
|
||||
"Could not initialize the model: you'll have to provide a valid train_corpus argument in the config file.",
|
||||
exits=1,
|
||||
|
|
|
@ -88,10 +88,10 @@ def fill_config(
|
|||
msg = Printer(no_print=no_print)
|
||||
with show_validation_error(hint_fill=False):
|
||||
config = util.load_config(base_path)
|
||||
nlp, _ = util.load_model_from_config(config, auto_fill=True, validate=False)
|
||||
nlp = util.load_model_from_config(config, auto_fill=True, validate=False)
|
||||
# Load a second time with validation to be extra sure that the produced
|
||||
# config result is a valid config
|
||||
nlp, _ = util.load_model_from_config(nlp.config)
|
||||
nlp = util.load_model_from_config(nlp.config)
|
||||
filled = nlp.config
|
||||
if pretraining:
|
||||
validate_config_for_pretrain(filled, msg)
|
||||
|
@ -169,7 +169,7 @@ def init_config(
|
|||
msg.text(f"- {label}: {value}")
|
||||
with show_validation_error(hint_fill=False):
|
||||
config = util.load_config_from_str(base_template)
|
||||
nlp, _ = util.load_model_from_config(config, auto_fill=True)
|
||||
nlp = util.load_model_from_config(config, auto_fill=True)
|
||||
config = nlp.config
|
||||
if pretraining:
|
||||
validate_config_for_pretrain(config, msg)
|
||||
|
|
|
@ -69,17 +69,18 @@ def pretrain_cli(
|
|||
msg.info(f"Loading config from: {config_path}")
|
||||
|
||||
with show_validation_error(config_path):
|
||||
config = util.load_config(
|
||||
config_path, overrides=config_overrides, interpolate=True
|
||||
raw_config = util.load_config(
|
||||
config_path, overrides=config_overrides, interpolate=False
|
||||
)
|
||||
config = raw_config.interpolate()
|
||||
if not config.get("pretraining"):
|
||||
# TODO: What's the solution here? How do we handle optional blocks?
|
||||
msg.fail("The [pretraining] block in your config is empty", exits=1)
|
||||
if not output_dir.exists():
|
||||
output_dir.mkdir()
|
||||
msg.good(f"Created output directory: {output_dir}")
|
||||
|
||||
config.to_disk(output_dir / "config.cfg")
|
||||
# Save non-interpolated config
|
||||
raw_config.to_disk(output_dir / "config.cfg")
|
||||
msg.good("Saved config file in the output directory")
|
||||
|
||||
pretrain(
|
||||
|
@ -103,14 +104,13 @@ def pretrain(
|
|||
allocator = config["training"]["gpu_allocator"]
|
||||
if use_gpu >= 0 and allocator:
|
||||
set_gpu_allocator(allocator)
|
||||
|
||||
nlp, config = util.load_model_from_config(config)
|
||||
P_cfg = config["pretraining"]
|
||||
corpus = dot_to_object(config, P_cfg["corpus"])
|
||||
nlp = util.load_model_from_config(config)
|
||||
C = util.resolve_training_config(nlp.config)
|
||||
P_cfg = C["pretraining"]
|
||||
corpus = dot_to_object(C, P_cfg["corpus"])
|
||||
batcher = P_cfg["batcher"]
|
||||
model = create_pretraining_model(nlp, config["pretraining"])
|
||||
optimizer = config["pretraining"]["optimizer"]
|
||||
|
||||
model = create_pretraining_model(nlp, C["pretraining"])
|
||||
optimizer = C["pretraining"]["optimizer"]
|
||||
# Load in pretrained weights to resume from
|
||||
if resume_path is not None:
|
||||
_resume_model(model, resume_path, epoch_resume)
|
||||
|
|
|
@ -58,7 +58,7 @@ def train_cli(
|
|||
else:
|
||||
msg.info("Using CPU")
|
||||
config = util.load_config(
|
||||
config_path, overrides=config_overrides, interpolate=True
|
||||
config_path, overrides=config_overrides, interpolate=False
|
||||
)
|
||||
if output_path is None:
|
||||
nlp = init_pipeline(config)
|
||||
|
@ -75,24 +75,32 @@ def train_cli(
|
|||
|
||||
def train(nlp: Language, output_path: Optional[Path]=None) -> None:
|
||||
# Create iterator, which yields out info after each optimization step.
|
||||
config = nlp.config
|
||||
T_cfg = config["training"]
|
||||
score_weights = T_cfg["score_weights"]
|
||||
optimizer = T_cfg["optimizer"]
|
||||
train_corpus = dot_to_object(config, T_cfg["train_corpus"])
|
||||
dev_corpus = dot_to_object(config, T_cfg["dev_corpus"])
|
||||
batcher = T_cfg["batcher"]
|
||||
config = nlp.config.interpolate()
|
||||
T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
|
||||
optimizer T["optimizer"]
|
||||
score_weights = T["score_weights"]
|
||||
# TODO: This might not be called corpora
|
||||
corpora = registry.resolve(config["corpora"], schema=ConfigSchemaCorpora)
|
||||
train_corpus = dot_to_object({"corpora": corpora}, T["train_corpus"])
|
||||
dev_corpus = dot_to_object({"corpora": corpora}, T["dev_corpus"])
|
||||
batcher = T["batcher"]
|
||||
train_logger = T["logger"]
|
||||
before_to_disk = create_before_to_disk_callback(T["before_to_disk"])
|
||||
# Components that shouldn't be updated during training
|
||||
frozen_components = T["frozen_components"]
|
||||
|
||||
# Create iterator, which yields out info after each optimization step.
|
||||
msg.info("Start training")
|
||||
training_step_iterator = train_while_improving(
|
||||
nlp,
|
||||
optimizer,
|
||||
create_train_batches(train_corpus(nlp), batcher, T_cfg["max_epochs"]),
|
||||
create_train_batches(train_corpus(nlp), batcher, T["max_epochs"]),
|
||||
create_evaluation_callback(nlp, dev_corpus, score_weights),
|
||||
dropout=T_cfg["dropout"],
|
||||
accumulate_gradient=T_cfg["accumulate_gradient"],
|
||||
patience=T_cfg["patience"],
|
||||
max_steps=T_cfg["max_steps"],
|
||||
eval_frequency=T_cfg["eval_frequency"],
|
||||
dropout=T["dropout"],
|
||||
accumulate_gradient=T["accumulate_gradient"],
|
||||
patience=T["patience"],
|
||||
max_steps=T["max_steps"],
|
||||
eval_frequency=T["eval_frequency"],
|
||||
raw_text=None,
|
||||
exclude=frozen_components,
|
||||
)
|
||||
|
@ -101,7 +109,7 @@ def train(nlp: Language, output_path: Optional[Path]=None) -> None:
|
|||
print_row, finalize_logger = train_logger(nlp)
|
||||
|
||||
try:
|
||||
progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)
|
||||
progress = tqdm.tqdm(total=T["eval_frequency"], leave=False)
|
||||
progress.set_description(f"Epoch 1")
|
||||
for batch, info, is_best_checkpoint in training_step_iterator:
|
||||
progress.update(1)
|
||||
|
@ -110,11 +118,11 @@ def train(nlp: Language, output_path: Optional[Path]=None) -> None:
|
|||
print_row(info)
|
||||
if is_best_checkpoint and output_path is not None:
|
||||
with nlp.select_pipes(disable=frozen_components):
|
||||
update_meta(T_cfg, nlp, info)
|
||||
update_meta(T, nlp, info)
|
||||
with nlp.use_params(optimizer.averages):
|
||||
nlp = before_to_disk(nlp)
|
||||
nlp.to_disk(output_path / "model-best")
|
||||
progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)
|
||||
progress = tqdm.tqdm(total=T["eval_frequency"], leave=False)
|
||||
progress.set_description(f"Epoch {info['epoch']}")
|
||||
except Exception as e:
|
||||
finalize_logger()
|
||||
|
|
|
@ -12,8 +12,10 @@ from .tag_bigram_map import TAG_BIGRAM_MAP
|
|||
from ...compat import copy_reg
|
||||
from ...errors import Errors
|
||||
from ...language import Language
|
||||
from ...scorer import Scorer
|
||||
from ...symbols import POS
|
||||
from ...tokens import Doc
|
||||
from ...training import validate_examples
|
||||
from ...util import DummyTokenizer, registry
|
||||
from ... import util
|
||||
|
||||
|
@ -130,6 +132,10 @@ class JapaneseTokenizer(DummyTokenizer):
|
|||
)
|
||||
return sub_tokens_list
|
||||
|
||||
def score(self, examples):
|
||||
validate_examples(examples, "JapaneseTokenizer.score")
|
||||
return Scorer.score_tokenization(examples)
|
||||
|
||||
def _get_config(self) -> Dict[str, Any]:
|
||||
return {"split_mode": self.split_mode}
|
||||
|
||||
|
|
|
@ -7,7 +7,9 @@ from .lex_attrs import LEX_ATTRS
|
|||
from ...language import Language
|
||||
from ...tokens import Doc
|
||||
from ...compat import copy_reg
|
||||
from ...scorer import Scorer
|
||||
from ...symbols import POS
|
||||
from ...training import validate_examples
|
||||
from ...util import DummyTokenizer, registry
|
||||
|
||||
|
||||
|
@ -62,6 +64,10 @@ class KoreanTokenizer(DummyTokenizer):
|
|||
lemma = surface
|
||||
yield {"surface": surface, "lemma": lemma, "tag": tag}
|
||||
|
||||
def score(self, examples):
|
||||
validate_examples(examples, "KoreanTokenizer.score")
|
||||
return Scorer.score_tokenization(examples)
|
||||
|
||||
|
||||
class KoreanDefaults(Language.Defaults):
|
||||
config = Config().from_str(DEFAULT_CONFIG)
|
||||
|
|
|
@ -8,7 +8,9 @@ from thinc.api import Config
|
|||
|
||||
from ...errors import Warnings, Errors
|
||||
from ...language import Language
|
||||
from ...scorer import Scorer
|
||||
from ...tokens import Doc
|
||||
from ...training import validate_examples
|
||||
from ...util import DummyTokenizer, registry
|
||||
from .lex_attrs import LEX_ATTRS
|
||||
from .stop_words import STOP_WORDS
|
||||
|
@ -136,6 +138,10 @@ class ChineseTokenizer(DummyTokenizer):
|
|||
warn_msg = Warnings.W104.format(target="pkuseg", current=self.segmenter)
|
||||
warnings.warn(warn_msg)
|
||||
|
||||
def score(self, examples):
|
||||
validate_examples(examples, "ChineseTokenizer.score")
|
||||
return Scorer.score_tokenization(examples)
|
||||
|
||||
def _get_config(self) -> Dict[str, Any]:
|
||||
return {
|
||||
"segmenter": self.segmenter,
|
||||
|
|
|
@ -27,7 +27,7 @@ from .lang.punctuation import TOKENIZER_INFIXES
|
|||
from .tokens import Doc
|
||||
from .tokenizer import Tokenizer
|
||||
from .errors import Errors, Warnings
|
||||
from .schemas import ConfigSchema
|
||||
from .schemas import ConfigSchema, ConfigSchemaNlp
|
||||
from .git_info import GIT_VERSION
|
||||
from . import util
|
||||
from . import about
|
||||
|
@ -166,11 +166,10 @@ class Language:
|
|||
self._components = []
|
||||
self._disabled = set()
|
||||
self.max_length = max_length
|
||||
self.resolved = {}
|
||||
# Create the default tokenizer from the default config
|
||||
if not create_tokenizer:
|
||||
tokenizer_cfg = {"tokenizer": self._config["nlp"]["tokenizer"]}
|
||||
create_tokenizer = registry.make_from_config(tokenizer_cfg)["tokenizer"]
|
||||
create_tokenizer = registry.resolve(tokenizer_cfg)["tokenizer"]
|
||||
self.tokenizer = create_tokenizer(self)
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
|
@ -467,7 +466,7 @@ class Language:
|
|||
if "nlp" not in arg_names or "name" not in arg_names:
|
||||
raise ValueError(Errors.E964.format(name=name))
|
||||
# Officially register the factory so we can later call
|
||||
# registry.make_from_config and refer to it in the config as
|
||||
# registry.resolve and refer to it in the config as
|
||||
# @factories = "spacy.Language.xyz". We use the class name here so
|
||||
# different classes can have different factories.
|
||||
registry.factories.register(internal_name, func=factory_func)
|
||||
|
@ -650,8 +649,9 @@ class Language:
|
|||
cfg = {factory_name: config}
|
||||
# We're calling the internal _fill here to avoid constructing the
|
||||
# registered functions twice
|
||||
resolved, filled = registry.resolve(cfg, validate=validate)
|
||||
filled = Config(filled[factory_name])
|
||||
resolved = registry.resolve(cfg, validate=validate)
|
||||
filled = registry.fill({"cfg": cfg[factory_name]}, validate=validate)["cfg"]
|
||||
filled = Config(filled)
|
||||
filled["factory"] = factory_name
|
||||
filled.pop("@factories", None)
|
||||
# Remove the extra values we added because we don't want to keep passing
|
||||
|
@ -1518,15 +1518,19 @@ class Language:
|
|||
config = util.copy_config(config)
|
||||
orig_pipeline = config.pop("components", {})
|
||||
config["components"] = {}
|
||||
resolved, filled = registry.resolve(
|
||||
config, validate=validate, schema=ConfigSchema
|
||||
)
|
||||
if auto_fill:
|
||||
filled = registry.fill(config, validate=validate, schema=ConfigSchema)
|
||||
else:
|
||||
filled = config
|
||||
filled["components"] = orig_pipeline
|
||||
config["components"] = orig_pipeline
|
||||
create_tokenizer = resolved["nlp"]["tokenizer"]
|
||||
before_creation = resolved["nlp"]["before_creation"]
|
||||
after_creation = resolved["nlp"]["after_creation"]
|
||||
after_pipeline_creation = resolved["nlp"]["after_pipeline_creation"]
|
||||
resolved_nlp = registry.resolve(
|
||||
filled["nlp"], validate=validate, schema=ConfigSchemaNlp
|
||||
)
|
||||
create_tokenizer = resolved_nlp["tokenizer"]
|
||||
before_creation = resolved_nlp["before_creation"]
|
||||
after_creation = resolved_nlp["after_creation"]
|
||||
after_pipeline_creation = resolved_nlp["after_pipeline_creation"]
|
||||
lang_cls = cls
|
||||
if before_creation is not None:
|
||||
lang_cls = before_creation(cls)
|
||||
|
@ -1587,7 +1591,6 @@ class Language:
|
|||
disabled_pipes = [*config["nlp"]["disabled"], *disable]
|
||||
nlp._disabled = set(p for p in disabled_pipes if p not in exclude)
|
||||
nlp.config = filled if auto_fill else config
|
||||
nlp.resolved = resolved
|
||||
if after_pipeline_creation is not None:
|
||||
nlp = after_pipeline_creation(nlp)
|
||||
if not isinstance(nlp, cls):
|
||||
|
|
|
@ -29,7 +29,8 @@ cdef class Morphology:
|
|||
FEATURE_SEP = "|"
|
||||
FIELD_SEP = "="
|
||||
VALUE_SEP = ","
|
||||
EMPTY_MORPH = "_" # not an empty string so that the PreshMap key is not 0
|
||||
# not an empty string so that the PreshMap key is not 0
|
||||
EMPTY_MORPH = symbols.NAMES[symbols._]
|
||||
|
||||
def __init__(self, StringStore strings):
|
||||
self.mem = Pool()
|
||||
|
|
|
@ -4,6 +4,7 @@ from enum import Enum
|
|||
from pydantic import BaseModel, Field, ValidationError, validator
|
||||
from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
|
||||
from pydantic import root_validator
|
||||
from thinc.config import Promise
|
||||
from collections import defaultdict
|
||||
from thinc.api import Optimizer
|
||||
|
||||
|
@ -16,10 +17,12 @@ if TYPE_CHECKING:
|
|||
from .training import Example # noqa: F401
|
||||
|
||||
|
||||
# fmt: off
|
||||
ItemT = TypeVar("ItemT")
|
||||
Batcher = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
|
||||
Reader = Callable[["Language", str], Iterable["Example"]]
|
||||
Logger = Callable[["Language"], Tuple[Callable[[Dict[str, Any]], None], Callable]]
|
||||
Batcher = Union[Callable[[Iterable[ItemT]], Iterable[List[ItemT]]], Promise]
|
||||
Reader = Union[Callable[["Language", str], Iterable["Example"]], Promise]
|
||||
Logger = Union[Callable[["Language"], Tuple[Callable[[Dict[str, Any]], None], Callable]], Promise]
|
||||
# fmt: on
|
||||
|
||||
|
||||
def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]:
|
||||
|
@ -292,6 +295,16 @@ class ConfigSchema(BaseModel):
|
|||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
class TrainingSchema(BaseModel):
|
||||
training: ConfigSchemaTraining
|
||||
pretraining: Union[ConfigSchemaPretrain, ConfigSchemaPretrainEmpty] = {}
|
||||
corpora: Dict[str, Reader]
|
||||
|
||||
class Config:
|
||||
extra = "allow"
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
|
||||
# Project config Schema
|
||||
|
||||
|
||||
|
|
|
@ -466,3 +466,4 @@ cdef enum symbol_t:
|
|||
ENT_ID
|
||||
|
||||
IDX
|
||||
_
|
||||
|
|
|
@ -465,6 +465,7 @@ IDS = {
|
|||
"acl": acl,
|
||||
"LAW": LAW,
|
||||
"MORPH": MORPH,
|
||||
"_": _,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -24,7 +24,7 @@ def test_doc_add_entities_set_ents_iob(en_vocab):
|
|||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
cfg = {"model": DEFAULT_NER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
model = registry.resolve(cfg, validate=True)["model"]
|
||||
ner = EntityRecognizer(en_vocab, model, **config)
|
||||
ner.begin_training(lambda: [_ner_example(ner)])
|
||||
ner(doc)
|
||||
|
@ -46,7 +46,7 @@ def test_ents_reset(en_vocab):
|
|||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
cfg = {"model": DEFAULT_NER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
model = registry.resolve(cfg, validate=True)["model"]
|
||||
ner = EntityRecognizer(en_vocab, model, **config)
|
||||
ner.begin_training(lambda: [_ner_example(ner)])
|
||||
ner(doc)
|
||||
|
|
|
@ -23,7 +23,7 @@ def parser(vocab):
|
|||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
model = registry.resolve(cfg, validate=True)["model"]
|
||||
parser = DependencyParser(vocab, model, **config)
|
||||
return parser
|
||||
|
||||
|
@ -82,7 +82,7 @@ def test_add_label_deserializes_correctly():
|
|||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
cfg = {"model": DEFAULT_NER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
model = registry.resolve(cfg, validate=True)["model"]
|
||||
ner1 = EntityRecognizer(Vocab(), model, **config)
|
||||
ner1.add_label("C")
|
||||
ner1.add_label("B")
|
||||
|
@ -111,7 +111,7 @@ def test_add_label_get_label(pipe_cls, n_moves, model_config):
|
|||
splitting the move names.
|
||||
"""
|
||||
labels = ["A", "B", "C"]
|
||||
model = registry.make_from_config({"model": model_config}, validate=True)["model"]
|
||||
model = registry.resolve({"model": model_config}, validate=True)["model"]
|
||||
config = {
|
||||
"learn_tokens": False,
|
||||
"min_action_freq": 30,
|
||||
|
|
|
@ -127,7 +127,7 @@ def test_get_oracle_actions():
|
|||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
model = registry.resolve(cfg, validate=True)["model"]
|
||||
parser = DependencyParser(doc.vocab, model, **config)
|
||||
parser.moves.add_action(0, "")
|
||||
parser.moves.add_action(1, "")
|
||||
|
|
|
@ -25,7 +25,7 @@ def arc_eager(vocab):
|
|||
@pytest.fixture
|
||||
def tok2vec():
|
||||
cfg = {"model": DEFAULT_TOK2VEC_MODEL}
|
||||
tok2vec = registry.make_from_config(cfg, validate=True)["model"]
|
||||
tok2vec = registry.resolve(cfg, validate=True)["model"]
|
||||
tok2vec.initialize()
|
||||
return tok2vec
|
||||
|
||||
|
@ -38,14 +38,14 @@ def parser(vocab, arc_eager):
|
|||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
model = registry.resolve(cfg, validate=True)["model"]
|
||||
return Parser(vocab, model, moves=arc_eager, **config)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model(arc_eager, tok2vec, vocab):
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
model = registry.resolve(cfg, validate=True)["model"]
|
||||
model.attrs["resize_output"](model, arc_eager.n_moves)
|
||||
model.initialize()
|
||||
return model
|
||||
|
@ -72,7 +72,7 @@ def test_build_model(parser, vocab):
|
|||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
model = registry.resolve(cfg, validate=True)["model"]
|
||||
parser.model = Parser(vocab, model=model, moves=parser.moves, **config).model
|
||||
assert parser.model is not None
|
||||
|
||||
|
|
|
@ -28,7 +28,7 @@ def parser(vocab):
|
|||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
model = registry.resolve(cfg, validate=True)["model"]
|
||||
parser = DependencyParser(vocab, model, **config)
|
||||
parser.cfg["token_vector_width"] = 4
|
||||
parser.cfg["hidden_width"] = 32
|
||||
|
|
|
@ -139,7 +139,7 @@ TRAIN_DATA = [
|
|||
|
||||
def test_tok2vec_listener():
|
||||
orig_config = Config().from_str(cfg_string)
|
||||
nlp, config = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
|
||||
nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
|
||||
assert nlp.pipe_names == ["tok2vec", "tagger"]
|
||||
tagger = nlp.get_pipe("tagger")
|
||||
tok2vec = nlp.get_pipe("tok2vec")
|
||||
|
@ -173,7 +173,7 @@ def test_tok2vec_listener():
|
|||
|
||||
def test_tok2vec_listener_callback():
|
||||
orig_config = Config().from_str(cfg_string)
|
||||
nlp, config = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
|
||||
nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
|
||||
assert nlp.pipe_names == ["tok2vec", "tagger"]
|
||||
tagger = nlp.get_pipe("tagger")
|
||||
tok2vec = nlp.get_pipe("tok2vec")
|
||||
|
|
|
@ -195,7 +195,7 @@ def test_issue3345():
|
|||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
cfg = {"model": DEFAULT_NER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
model = registry.resolve(cfg, validate=True)["model"]
|
||||
ner = EntityRecognizer(doc.vocab, model, **config)
|
||||
# Add the OUT action. I wouldn't have thought this would be necessary...
|
||||
ner.moves.add_action(5, "")
|
||||
|
|
|
@ -264,9 +264,7 @@ def test_issue3830_no_subtok():
|
|||
"min_action_freq": 30,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)[
|
||||
"model"
|
||||
]
|
||||
model = registry.resolve({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
|
||||
parser = DependencyParser(Vocab(), model, **config)
|
||||
parser.add_label("nsubj")
|
||||
assert "subtok" not in parser.labels
|
||||
|
@ -281,9 +279,7 @@ def test_issue3830_with_subtok():
|
|||
"min_action_freq": 30,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)[
|
||||
"model"
|
||||
]
|
||||
model = registry.resolve({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
|
||||
parser = DependencyParser(Vocab(), model, **config)
|
||||
parser.add_label("nsubj")
|
||||
assert "subtok" not in parser.labels
|
||||
|
|
|
@ -108,8 +108,8 @@ def my_parser():
|
|||
def test_create_nlp_from_config():
|
||||
config = Config().from_str(nlp_config_string)
|
||||
with pytest.raises(ConfigValidationError):
|
||||
nlp, _ = load_model_from_config(config, auto_fill=False)
|
||||
nlp, resolved = load_model_from_config(config, auto_fill=True)
|
||||
load_model_from_config(config, auto_fill=False)
|
||||
nlp = load_model_from_config(config, auto_fill=True)
|
||||
assert nlp.config["training"]["batcher"]["size"] == 666
|
||||
assert len(nlp.config["training"]) > 1
|
||||
assert nlp.pipe_names == ["tok2vec", "tagger"]
|
||||
|
@ -136,7 +136,7 @@ def test_create_nlp_from_config_multiple_instances():
|
|||
"tagger2": config["components"]["tagger"],
|
||||
}
|
||||
config["nlp"]["pipeline"] = list(config["components"].keys())
|
||||
nlp, _ = load_model_from_config(config, auto_fill=True)
|
||||
nlp = load_model_from_config(config, auto_fill=True)
|
||||
assert nlp.pipe_names == ["t2v", "tagger1", "tagger2"]
|
||||
assert nlp.get_pipe_meta("t2v").factory == "tok2vec"
|
||||
assert nlp.get_pipe_meta("tagger1").factory == "tagger"
|
||||
|
@ -150,7 +150,7 @@ def test_create_nlp_from_config_multiple_instances():
|
|||
def test_serialize_nlp():
|
||||
""" Create a custom nlp pipeline from config and ensure it serializes it correctly """
|
||||
nlp_config = Config().from_str(nlp_config_string)
|
||||
nlp, _ = load_model_from_config(nlp_config, auto_fill=True)
|
||||
nlp = load_model_from_config(nlp_config, auto_fill=True)
|
||||
nlp.get_pipe("tagger").add_label("A")
|
||||
nlp.begin_training()
|
||||
assert "tok2vec" in nlp.pipe_names
|
||||
|
@ -209,7 +209,7 @@ def test_config_nlp_roundtrip():
|
|||
nlp = English()
|
||||
nlp.add_pipe("entity_ruler")
|
||||
nlp.add_pipe("ner")
|
||||
new_nlp, new_config = load_model_from_config(nlp.config, auto_fill=False)
|
||||
new_nlp = load_model_from_config(nlp.config, auto_fill=False)
|
||||
assert new_nlp.config == nlp.config
|
||||
assert new_nlp.pipe_names == nlp.pipe_names
|
||||
assert new_nlp._pipe_configs == nlp._pipe_configs
|
||||
|
@ -280,12 +280,12 @@ def test_config_overrides():
|
|||
overrides_dot = {"nlp.lang": "de", "nlp.pipeline": ["tagger"]}
|
||||
# load_model from config with overrides passed directly to Config
|
||||
config = Config().from_str(nlp_config_string, overrides=overrides_dot)
|
||||
nlp, _ = load_model_from_config(config, auto_fill=True)
|
||||
nlp = load_model_from_config(config, auto_fill=True)
|
||||
assert isinstance(nlp, German)
|
||||
assert nlp.pipe_names == ["tagger"]
|
||||
# Serialized roundtrip with config passed in
|
||||
base_config = Config().from_str(nlp_config_string)
|
||||
base_nlp, _ = load_model_from_config(base_config, auto_fill=True)
|
||||
base_nlp = load_model_from_config(base_config, auto_fill=True)
|
||||
assert isinstance(base_nlp, English)
|
||||
assert base_nlp.pipe_names == ["tok2vec", "tagger"]
|
||||
with make_tempdir() as d:
|
||||
|
@ -328,7 +328,7 @@ def test_config_optional_sections():
|
|||
config = Config().from_str(nlp_config_string)
|
||||
config = DEFAULT_CONFIG.merge(config)
|
||||
assert "pretraining" not in config
|
||||
filled = registry.fill_config(config, schema=ConfigSchema, validate=False)
|
||||
filled = registry.fill(config, schema=ConfigSchema, validate=False)
|
||||
# Make sure that optional "pretraining" block doesn't default to None,
|
||||
# which would (rightly) cause error because it'd result in a top-level
|
||||
# key that's not a section (dict). Note that the following roundtrip is
|
||||
|
@ -341,7 +341,7 @@ def test_config_auto_fill_extra_fields():
|
|||
config = Config({"nlp": {"lang": "en"}, "training": {}})
|
||||
assert load_model_from_config(config, auto_fill=True)
|
||||
config = Config({"nlp": {"lang": "en"}, "training": {"extra": "hello"}})
|
||||
nlp, _ = load_model_from_config(config, auto_fill=True, validate=False)
|
||||
nlp = load_model_from_config(config, auto_fill=True, validate=False)
|
||||
assert "extra" not in nlp.config["training"]
|
||||
# Make sure the config generated is valid
|
||||
load_model_from_config(nlp.config)
|
||||
|
|
|
@ -23,7 +23,7 @@ def parser(en_vocab):
|
|||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
model = registry.resolve(cfg, validate=True)["model"]
|
||||
parser = DependencyParser(en_vocab, model, **config)
|
||||
parser.add_label("nsubj")
|
||||
return parser
|
||||
|
@ -37,7 +37,7 @@ def blank_parser(en_vocab):
|
|||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
model = registry.resolve(cfg, validate=True)["model"]
|
||||
parser = DependencyParser(en_vocab, model, **config)
|
||||
return parser
|
||||
|
||||
|
@ -45,7 +45,7 @@ def blank_parser(en_vocab):
|
|||
@pytest.fixture
|
||||
def taggers(en_vocab):
|
||||
cfg = {"model": DEFAULT_TAGGER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
model = registry.resolve(cfg, validate=True)["model"]
|
||||
tagger1 = Tagger(en_vocab, model)
|
||||
tagger2 = Tagger(en_vocab, model)
|
||||
return tagger1, tagger2
|
||||
|
@ -59,7 +59,7 @@ def test_serialize_parser_roundtrip_bytes(en_vocab, Parser):
|
|||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
model = registry.resolve(cfg, validate=True)["model"]
|
||||
parser = Parser(en_vocab, model, **config)
|
||||
new_parser = Parser(en_vocab, model, **config)
|
||||
new_parser = new_parser.from_bytes(parser.to_bytes(exclude=["vocab"]))
|
||||
|
@ -77,7 +77,7 @@ def test_serialize_parser_roundtrip_disk(en_vocab, Parser):
|
|||
"update_with_oracle_cut_size": 100,
|
||||
}
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
model = registry.resolve(cfg, validate=True)["model"]
|
||||
parser = Parser(en_vocab, model, **config)
|
||||
with make_tempdir() as d:
|
||||
file_path = d / "parser"
|
||||
|
@ -111,7 +111,7 @@ def test_serialize_tagger_roundtrip_bytes(en_vocab, taggers):
|
|||
tagger1 = tagger1.from_bytes(tagger1_b)
|
||||
assert tagger1.to_bytes() == tagger1_b
|
||||
cfg = {"model": DEFAULT_TAGGER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
model = registry.resolve(cfg, validate=True)["model"]
|
||||
new_tagger1 = Tagger(en_vocab, model).from_bytes(tagger1_b)
|
||||
new_tagger1_b = new_tagger1.to_bytes()
|
||||
assert len(new_tagger1_b) == len(tagger1_b)
|
||||
|
@ -126,7 +126,7 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
|
|||
tagger1.to_disk(file_path1)
|
||||
tagger2.to_disk(file_path2)
|
||||
cfg = {"model": DEFAULT_TAGGER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
model = registry.resolve(cfg, validate=True)["model"]
|
||||
tagger1_d = Tagger(en_vocab, model).from_disk(file_path1)
|
||||
tagger2_d = Tagger(en_vocab, model).from_disk(file_path2)
|
||||
assert tagger1_d.to_bytes() == tagger2_d.to_bytes()
|
||||
|
@ -135,7 +135,7 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
|
|||
def test_serialize_textcat_empty(en_vocab):
|
||||
# See issue #1105
|
||||
cfg = {"model": DEFAULT_TEXTCAT_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
model = registry.resolve(cfg, validate=True)["model"]
|
||||
textcat = TextCategorizer(
|
||||
en_vocab,
|
||||
model,
|
||||
|
@ -149,7 +149,7 @@ def test_serialize_textcat_empty(en_vocab):
|
|||
@pytest.mark.parametrize("Parser", test_parsers)
|
||||
def test_serialize_pipe_exclude(en_vocab, Parser):
|
||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
model = registry.resolve(cfg, validate=True)["model"]
|
||||
config = {
|
||||
"learn_tokens": False,
|
||||
"min_action_freq": 0,
|
||||
|
@ -176,7 +176,7 @@ def test_serialize_pipe_exclude(en_vocab, Parser):
|
|||
|
||||
def test_serialize_sentencerecognizer(en_vocab):
|
||||
cfg = {"model": DEFAULT_SENTER_MODEL}
|
||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
||||
model = registry.resolve(cfg, validate=True)["model"]
|
||||
sr = SentenceRecognizer(en_vocab, model)
|
||||
sr_b = sr.to_bytes()
|
||||
sr_d = SentenceRecognizer(en_vocab, model).from_bytes(sr_b)
|
||||
|
|
|
@ -7,6 +7,7 @@ from spacy import util
|
|||
from spacy import prefer_gpu, require_gpu
|
||||
from spacy.ml._precomputable_affine import PrecomputableAffine
|
||||
from spacy.ml._precomputable_affine import _backprop_precomputable_affine_padding
|
||||
from thinc.api import Optimizer
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
@ -157,3 +158,16 @@ def test_dot_to_dict(dot_notation, expected):
|
|||
result = util.dot_to_dict(dot_notation)
|
||||
assert result == expected
|
||||
assert util.dict_to_dot(result) == dot_notation
|
||||
|
||||
|
||||
def test_resolve_training_config():
|
||||
config = {
|
||||
"nlp": {"lang": "en", "disabled": []},
|
||||
"training": {"dropout": 0.1, "optimizer": {"@optimizers": "Adam.v1"}},
|
||||
"corpora": {},
|
||||
}
|
||||
resolved = util.resolve_training_config(config)
|
||||
assert resolved["training"]["dropout"] == 0.1
|
||||
assert isinstance(resolved["training"]["optimizer"], Optimizer)
|
||||
assert resolved["corpora"] == {}
|
||||
assert "nlp" not in resolved
|
||||
|
|
|
@ -82,10 +82,10 @@ def test_util_dot_section():
|
|||
no_output_layer = false
|
||||
"""
|
||||
nlp_config = Config().from_str(cfg_string)
|
||||
en_nlp, en_config = util.load_model_from_config(nlp_config, auto_fill=True)
|
||||
en_nlp = util.load_model_from_config(nlp_config, auto_fill=True)
|
||||
default_config = Config().from_disk(DEFAULT_CONFIG_PATH)
|
||||
default_config["nlp"]["lang"] = "nl"
|
||||
nl_nlp, nl_config = util.load_model_from_config(default_config, auto_fill=True)
|
||||
nl_nlp = util.load_model_from_config(default_config, auto_fill=True)
|
||||
# Test that creation went OK
|
||||
assert isinstance(en_nlp, English)
|
||||
assert isinstance(nl_nlp, Dutch)
|
||||
|
@ -94,14 +94,15 @@ def test_util_dot_section():
|
|||
# not exclusive_classes
|
||||
assert en_nlp.get_pipe("textcat").model.attrs["multi_label"] is False
|
||||
# Test that default values got overwritten
|
||||
assert en_config["nlp"]["pipeline"] == ["textcat"]
|
||||
assert nl_config["nlp"]["pipeline"] == [] # default value []
|
||||
assert en_nlp.config["nlp"]["pipeline"] == ["textcat"]
|
||||
assert nl_nlp.config["nlp"]["pipeline"] == [] # default value []
|
||||
# Test proper functioning of 'dot_to_object'
|
||||
with pytest.raises(KeyError):
|
||||
dot_to_object(en_config, "nlp.pipeline.tagger")
|
||||
dot_to_object(en_nlp.config, "nlp.pipeline.tagger")
|
||||
with pytest.raises(KeyError):
|
||||
dot_to_object(en_config, "nlp.unknownattribute")
|
||||
assert isinstance(dot_to_object(nl_config, "training.optimizer"), Optimizer)
|
||||
dot_to_object(en_nlp.config, "nlp.unknownattribute")
|
||||
resolved = util.resolve_training_config(nl_nlp.config)
|
||||
assert isinstance(dot_to_object(resolved, "training.optimizer"), Optimizer)
|
||||
|
||||
|
||||
def test_simple_frozen_list():
|
||||
|
|
|
@ -3,6 +3,7 @@ import pytest
|
|||
from thinc.api import Config
|
||||
from spacy import Language
|
||||
from spacy.util import load_model_from_config, registry, dot_to_object
|
||||
from spacy.util import resolve_training_config
|
||||
from spacy.training import Example
|
||||
|
||||
|
||||
|
@ -37,8 +38,8 @@ def test_readers():
|
|||
return {"train": reader, "dev": reader, "extra": reader, "something": reader}
|
||||
|
||||
config = Config().from_str(config_string)
|
||||
nlp, resolved = load_model_from_config(config, auto_fill=True)
|
||||
|
||||
nlp = load_model_from_config(config, auto_fill=True)
|
||||
resolved = resolve_training_config(nlp.config)
|
||||
train_corpus = dot_to_object(resolved, resolved["training"]["train_corpus"])
|
||||
assert isinstance(train_corpus, Callable)
|
||||
optimizer = resolved["training"]["optimizer"]
|
||||
|
@ -87,8 +88,8 @@ def test_cat_readers(reader, additional_config):
|
|||
config = Config().from_str(nlp_config_string)
|
||||
config["corpora"]["@readers"] = reader
|
||||
config["corpora"].update(additional_config)
|
||||
nlp, resolved = load_model_from_config(config, auto_fill=True)
|
||||
|
||||
nlp = load_model_from_config(config, auto_fill=True)
|
||||
resolved = resolve_training_config(nlp.config)
|
||||
train_corpus = dot_to_object(resolved, resolved["training"]["train_corpus"])
|
||||
optimizer = resolved["training"]["optimizer"]
|
||||
# simulate a training loop
|
||||
|
|
|
@ -86,7 +86,7 @@ class registry(thinc.registry):
|
|||
# spacy_factories entry point. This registry only exists so we can easily
|
||||
# load them via the entry points. The "true" factories are added via the
|
||||
# Language.factory decorator (in the spaCy code base and user code) and those
|
||||
# are the factories used to initialize components via registry.make_from_config.
|
||||
# are the factories used to initialize components via registry.resolve.
|
||||
_entry_point_factories = catalogue.create("spacy", "factories", entry_points=True)
|
||||
factories = catalogue.create("spacy", "internal_factories")
|
||||
# This is mostly used to get a list of all installed models in the current
|
||||
|
@ -351,9 +351,7 @@ def load_model_from_path(
|
|||
meta = get_model_meta(model_path)
|
||||
config_path = model_path / "config.cfg"
|
||||
config = load_config(config_path, overrides=dict_to_dot(config))
|
||||
nlp, _ = load_model_from_config(
|
||||
config, vocab=vocab, disable=disable, exclude=exclude
|
||||
)
|
||||
nlp = load_model_from_config(config, vocab=vocab, disable=disable, exclude=exclude)
|
||||
return nlp.from_disk(model_path, exclude=exclude)
|
||||
|
||||
|
||||
|
@ -365,7 +363,7 @@ def load_model_from_config(
|
|||
exclude: Iterable[str] = SimpleFrozenList(),
|
||||
auto_fill: bool = False,
|
||||
validate: bool = True,
|
||||
) -> Tuple["Language", Config]:
|
||||
) -> "Language":
|
||||
"""Create an nlp object from a config. Expects the full config file including
|
||||
a section "nlp" containing the settings for the nlp object.
|
||||
|
||||
|
@ -398,7 +396,30 @@ def load_model_from_config(
|
|||
auto_fill=auto_fill,
|
||||
validate=validate,
|
||||
)
|
||||
return nlp, nlp.resolved
|
||||
return nlp
|
||||
|
||||
|
||||
def resolve_training_config(
|
||||
config: Config,
|
||||
exclude: Iterable[str] = ("nlp", "components"),
|
||||
validate: bool = True,
|
||||
) -> Dict[str, Any]:
|
||||
"""Resolve the config sections relevant for trainig and create all objects.
|
||||
Mostly used in the CLI to separate training config (not resolved by default
|
||||
because not runtime-relevant – an nlp object should load fine even if it's
|
||||
[training] block refers to functions that are not available etc.).
|
||||
|
||||
config (Config): The config to resolve.
|
||||
exclude (Iterable[str]): The config blocks to exclude. Those blocks won't
|
||||
be available in the final resolved config.
|
||||
validate (bool): Whether to validate the config.
|
||||
RETURNS (Dict[str, Any]): The resolved config.
|
||||
"""
|
||||
config = config.copy()
|
||||
for key in exclude:
|
||||
if key in config:
|
||||
config.pop(key)
|
||||
return registry.resolve(config, validate=validate)
|
||||
|
||||
|
||||
def load_model_from_init_py(
|
||||
|
|
Loading…
Reference in New Issue
Block a user