mirror of
https://github.com/explosion/spaCy.git
synced 2025-10-24 04:31:17 +03:00
Upd train
This commit is contained in:
commit
a3e1791c9c
|
@ -6,7 +6,7 @@ requires = [
|
||||||
"cymem>=2.0.2,<2.1.0",
|
"cymem>=2.0.2,<2.1.0",
|
||||||
"preshed>=3.0.2,<3.1.0",
|
"preshed>=3.0.2,<3.1.0",
|
||||||
"murmurhash>=0.28.0,<1.1.0",
|
"murmurhash>=0.28.0,<1.1.0",
|
||||||
"thinc>=8.0.0a36,<8.0.0a40",
|
"thinc>=8.0.0a41,<8.0.0a50",
|
||||||
"blis>=0.4.0,<0.5.0",
|
"blis>=0.4.0,<0.5.0",
|
||||||
"pytokenizations",
|
"pytokenizations",
|
||||||
"pathy"
|
"pathy"
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
# Our libraries
|
# Our libraries
|
||||||
cymem>=2.0.2,<2.1.0
|
cymem>=2.0.2,<2.1.0
|
||||||
preshed>=3.0.2,<3.1.0
|
preshed>=3.0.2,<3.1.0
|
||||||
thinc>=8.0.0a36,<8.0.0a40
|
thinc>=8.0.0a41,<8.0.0a50
|
||||||
blis>=0.4.0,<0.5.0
|
blis>=0.4.0,<0.5.0
|
||||||
ml_datasets==0.2.0a0
|
ml_datasets==0.2.0a0
|
||||||
murmurhash>=0.28.0,<1.1.0
|
murmurhash>=0.28.0,<1.1.0
|
||||||
|
|
|
@ -34,13 +34,13 @@ setup_requires =
|
||||||
cymem>=2.0.2,<2.1.0
|
cymem>=2.0.2,<2.1.0
|
||||||
preshed>=3.0.2,<3.1.0
|
preshed>=3.0.2,<3.1.0
|
||||||
murmurhash>=0.28.0,<1.1.0
|
murmurhash>=0.28.0,<1.1.0
|
||||||
thinc>=8.0.0a36,<8.0.0a40
|
thinc>=8.0.0a41,<8.0.0a50
|
||||||
install_requires =
|
install_requires =
|
||||||
# Our libraries
|
# Our libraries
|
||||||
murmurhash>=0.28.0,<1.1.0
|
murmurhash>=0.28.0,<1.1.0
|
||||||
cymem>=2.0.2,<2.1.0
|
cymem>=2.0.2,<2.1.0
|
||||||
preshed>=3.0.2,<3.1.0
|
preshed>=3.0.2,<3.1.0
|
||||||
thinc>=8.0.0a36,<8.0.0a40
|
thinc>=8.0.0a41,<8.0.0a50
|
||||||
blis>=0.4.0,<0.5.0
|
blis>=0.4.0,<0.5.0
|
||||||
wasabi>=0.8.0,<1.1.0
|
wasabi>=0.8.0,<1.1.0
|
||||||
srsly>=2.1.0,<3.0.0
|
srsly>=2.1.0,<3.0.0
|
||||||
|
|
|
@ -243,6 +243,8 @@ def show_validation_error(
|
||||||
yield
|
yield
|
||||||
except ConfigValidationError as e:
|
except ConfigValidationError as e:
|
||||||
title = title if title is not None else e.title
|
title = title if title is not None else e.title
|
||||||
|
if e.desc:
|
||||||
|
desc = f"{e.desc}" if not desc else f"{e.desc}\n\n{desc}"
|
||||||
# Re-generate a new error object with overrides
|
# Re-generate a new error object with overrides
|
||||||
err = e.from_error(e, title="", desc=desc, show_config=show_config)
|
err = e.from_error(e, title="", desc=desc, show_config=show_config)
|
||||||
msg.fail(title)
|
msg.fail(title)
|
||||||
|
|
|
@ -51,9 +51,10 @@ def debug_config(
|
||||||
msg.divider("Config validation")
|
msg.divider("Config validation")
|
||||||
with show_validation_error(config_path):
|
with show_validation_error(config_path):
|
||||||
config = util.load_config(config_path, overrides=overrides)
|
config = util.load_config(config_path, overrides=overrides)
|
||||||
nlp, resolved = util.load_model_from_config(config)
|
nlp = util.load_model_from_config(config)
|
||||||
# Use the resolved config here in case user has one function returning
|
# Use the resolved config here in case user has one function returning
|
||||||
# a dict of corpora etc.
|
# a dict of corpora etc.
|
||||||
|
resolved = util.resolve_training_config(nlp.config)
|
||||||
check_section_refs(resolved, ["training.dev_corpus", "training.train_corpus"])
|
check_section_refs(resolved, ["training.dev_corpus", "training.train_corpus"])
|
||||||
msg.good("Config is valid")
|
msg.good("Config is valid")
|
||||||
if show_vars:
|
if show_vars:
|
||||||
|
|
|
@ -93,18 +93,19 @@ def debug_data(
|
||||||
msg.fail("Config file not found", config_path, exists=1)
|
msg.fail("Config file not found", config_path, exists=1)
|
||||||
with show_validation_error(config_path):
|
with show_validation_error(config_path):
|
||||||
cfg = util.load_config(config_path, overrides=config_overrides)
|
cfg = util.load_config(config_path, overrides=config_overrides)
|
||||||
nlp, config = util.load_model_from_config(cfg)
|
nlp = util.load_model_from_config(cfg)
|
||||||
|
C = util.resolve_training_config(nlp.config)
|
||||||
# Use original config here, not resolved version
|
# Use original config here, not resolved version
|
||||||
sourced_components = get_sourced_components(cfg)
|
sourced_components = get_sourced_components(cfg)
|
||||||
frozen_components = config["training"]["frozen_components"]
|
frozen_components = C["training"]["frozen_components"]
|
||||||
resume_components = [p for p in sourced_components if p not in frozen_components]
|
resume_components = [p for p in sourced_components if p not in frozen_components]
|
||||||
pipeline = nlp.pipe_names
|
pipeline = nlp.pipe_names
|
||||||
factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names]
|
factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names]
|
||||||
tag_map_path = util.ensure_path(config["training"]["tag_map"])
|
tag_map_path = util.ensure_path(C["training"]["tag_map"])
|
||||||
tag_map = {}
|
tag_map = {}
|
||||||
if tag_map_path is not None:
|
if tag_map_path is not None:
|
||||||
tag_map = srsly.read_json(tag_map_path)
|
tag_map = srsly.read_json(tag_map_path)
|
||||||
morph_rules_path = util.ensure_path(config["training"]["morph_rules"])
|
morph_rules_path = util.ensure_path(C["training"]["morph_rules"])
|
||||||
morph_rules = {}
|
morph_rules = {}
|
||||||
if morph_rules_path is not None:
|
if morph_rules_path is not None:
|
||||||
morph_rules = srsly.read_json(morph_rules_path)
|
morph_rules = srsly.read_json(morph_rules_path)
|
||||||
|
@ -144,10 +145,10 @@ def debug_data(
|
||||||
|
|
||||||
train_texts = gold_train_data["texts"]
|
train_texts = gold_train_data["texts"]
|
||||||
dev_texts = gold_dev_data["texts"]
|
dev_texts = gold_dev_data["texts"]
|
||||||
frozen_components = config["training"]["frozen_components"]
|
frozen_components = C["training"]["frozen_components"]
|
||||||
|
|
||||||
msg.divider("Training stats")
|
msg.divider("Training stats")
|
||||||
msg.text(f"Language: {config['nlp']['lang']}")
|
msg.text(f"Language: {C['nlp']['lang']}")
|
||||||
msg.text(f"Training pipeline: {', '.join(pipeline)}")
|
msg.text(f"Training pipeline: {', '.join(pipeline)}")
|
||||||
if resume_components:
|
if resume_components:
|
||||||
msg.text(f"Components from other pipelines: {', '.join(resume_components)}")
|
msg.text(f"Components from other pipelines: {', '.join(resume_components)}")
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
import warnings
|
|
||||||
from typing import Dict, Any, Optional, Iterable
|
from typing import Dict, Any, Optional, Iterable
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
|
@ -57,14 +56,17 @@ def debug_model_cli(
|
||||||
}
|
}
|
||||||
config_overrides = parse_config_overrides(ctx.args)
|
config_overrides = parse_config_overrides(ctx.args)
|
||||||
with show_validation_error(config_path):
|
with show_validation_error(config_path):
|
||||||
config = util.load_config(
|
raw_config = util.load_config(
|
||||||
config_path, overrides=config_overrides, interpolate=True
|
config_path, overrides=config_overrides, interpolate=False
|
||||||
)
|
)
|
||||||
allocator = config["training"]["gpu_allocator"]
|
config = raw_config.iterpolate()
|
||||||
if use_gpu >= 0 and allocator:
|
allocator = config["training"]["gpu_allocator"]
|
||||||
set_gpu_allocator(allocator)
|
if use_gpu >= 0 and allocator:
|
||||||
nlp, config = util.load_model_from_config(config)
|
set_gpu_allocator(allocator)
|
||||||
seed = config["training"]["seed"]
|
with show_validation_error(config_path):
|
||||||
|
nlp = util.load_model_from_config(raw_config)
|
||||||
|
C = util.resolve_training_config(nlp.config)
|
||||||
|
seed = C["training"]["seed"]
|
||||||
if seed is not None:
|
if seed is not None:
|
||||||
msg.info(f"Fixing random seed: {seed}")
|
msg.info(f"Fixing random seed: {seed}")
|
||||||
fix_random_seed(seed)
|
fix_random_seed(seed)
|
||||||
|
@ -75,7 +77,7 @@ def debug_model_cli(
|
||||||
exits=1,
|
exits=1,
|
||||||
)
|
)
|
||||||
model = pipe.model
|
model = pipe.model
|
||||||
debug_model(config, nlp, model, print_settings=print_settings)
|
debug_model(C, nlp, model, print_settings=print_settings)
|
||||||
|
|
||||||
|
|
||||||
def debug_model(
|
def debug_model(
|
||||||
|
@ -108,7 +110,7 @@ def debug_model(
|
||||||
_set_output_dim(nO=7, model=model)
|
_set_output_dim(nO=7, model=model)
|
||||||
nlp.begin_training(lambda: [Example.from_dict(x, {}) for x in X])
|
nlp.begin_training(lambda: [Example.from_dict(x, {}) for x in X])
|
||||||
msg.info("Initialized the model with dummy data.")
|
msg.info("Initialized the model with dummy data.")
|
||||||
except:
|
except Exception:
|
||||||
msg.fail(
|
msg.fail(
|
||||||
"Could not initialize the model: you'll have to provide a valid train_corpus argument in the config file.",
|
"Could not initialize the model: you'll have to provide a valid train_corpus argument in the config file.",
|
||||||
exits=1,
|
exits=1,
|
||||||
|
|
|
@ -88,10 +88,10 @@ def fill_config(
|
||||||
msg = Printer(no_print=no_print)
|
msg = Printer(no_print=no_print)
|
||||||
with show_validation_error(hint_fill=False):
|
with show_validation_error(hint_fill=False):
|
||||||
config = util.load_config(base_path)
|
config = util.load_config(base_path)
|
||||||
nlp, _ = util.load_model_from_config(config, auto_fill=True, validate=False)
|
nlp = util.load_model_from_config(config, auto_fill=True, validate=False)
|
||||||
# Load a second time with validation to be extra sure that the produced
|
# Load a second time with validation to be extra sure that the produced
|
||||||
# config result is a valid config
|
# config result is a valid config
|
||||||
nlp, _ = util.load_model_from_config(nlp.config)
|
nlp = util.load_model_from_config(nlp.config)
|
||||||
filled = nlp.config
|
filled = nlp.config
|
||||||
if pretraining:
|
if pretraining:
|
||||||
validate_config_for_pretrain(filled, msg)
|
validate_config_for_pretrain(filled, msg)
|
||||||
|
@ -169,7 +169,7 @@ def init_config(
|
||||||
msg.text(f"- {label}: {value}")
|
msg.text(f"- {label}: {value}")
|
||||||
with show_validation_error(hint_fill=False):
|
with show_validation_error(hint_fill=False):
|
||||||
config = util.load_config_from_str(base_template)
|
config = util.load_config_from_str(base_template)
|
||||||
nlp, _ = util.load_model_from_config(config, auto_fill=True)
|
nlp = util.load_model_from_config(config, auto_fill=True)
|
||||||
config = nlp.config
|
config = nlp.config
|
||||||
if pretraining:
|
if pretraining:
|
||||||
validate_config_for_pretrain(config, msg)
|
validate_config_for_pretrain(config, msg)
|
||||||
|
|
|
@ -69,17 +69,18 @@ def pretrain_cli(
|
||||||
msg.info(f"Loading config from: {config_path}")
|
msg.info(f"Loading config from: {config_path}")
|
||||||
|
|
||||||
with show_validation_error(config_path):
|
with show_validation_error(config_path):
|
||||||
config = util.load_config(
|
raw_config = util.load_config(
|
||||||
config_path, overrides=config_overrides, interpolate=True
|
config_path, overrides=config_overrides, interpolate=False
|
||||||
)
|
)
|
||||||
|
config = raw_config.interpolate()
|
||||||
if not config.get("pretraining"):
|
if not config.get("pretraining"):
|
||||||
# TODO: What's the solution here? How do we handle optional blocks?
|
# TODO: What's the solution here? How do we handle optional blocks?
|
||||||
msg.fail("The [pretraining] block in your config is empty", exits=1)
|
msg.fail("The [pretraining] block in your config is empty", exits=1)
|
||||||
if not output_dir.exists():
|
if not output_dir.exists():
|
||||||
output_dir.mkdir()
|
output_dir.mkdir()
|
||||||
msg.good(f"Created output directory: {output_dir}")
|
msg.good(f"Created output directory: {output_dir}")
|
||||||
|
# Save non-interpolated config
|
||||||
config.to_disk(output_dir / "config.cfg")
|
raw_config.to_disk(output_dir / "config.cfg")
|
||||||
msg.good("Saved config file in the output directory")
|
msg.good("Saved config file in the output directory")
|
||||||
|
|
||||||
pretrain(
|
pretrain(
|
||||||
|
@ -103,14 +104,13 @@ def pretrain(
|
||||||
allocator = config["training"]["gpu_allocator"]
|
allocator = config["training"]["gpu_allocator"]
|
||||||
if use_gpu >= 0 and allocator:
|
if use_gpu >= 0 and allocator:
|
||||||
set_gpu_allocator(allocator)
|
set_gpu_allocator(allocator)
|
||||||
|
nlp = util.load_model_from_config(config)
|
||||||
nlp, config = util.load_model_from_config(config)
|
C = util.resolve_training_config(nlp.config)
|
||||||
P_cfg = config["pretraining"]
|
P_cfg = C["pretraining"]
|
||||||
corpus = dot_to_object(config, P_cfg["corpus"])
|
corpus = dot_to_object(C, P_cfg["corpus"])
|
||||||
batcher = P_cfg["batcher"]
|
batcher = P_cfg["batcher"]
|
||||||
model = create_pretraining_model(nlp, config["pretraining"])
|
model = create_pretraining_model(nlp, C["pretraining"])
|
||||||
optimizer = config["pretraining"]["optimizer"]
|
optimizer = C["pretraining"]["optimizer"]
|
||||||
|
|
||||||
# Load in pretrained weights to resume from
|
# Load in pretrained weights to resume from
|
||||||
if resume_path is not None:
|
if resume_path is not None:
|
||||||
_resume_model(model, resume_path, epoch_resume)
|
_resume_model(model, resume_path, epoch_resume)
|
||||||
|
|
|
@ -58,7 +58,7 @@ def train_cli(
|
||||||
else:
|
else:
|
||||||
msg.info("Using CPU")
|
msg.info("Using CPU")
|
||||||
config = util.load_config(
|
config = util.load_config(
|
||||||
config_path, overrides=config_overrides, interpolate=True
|
config_path, overrides=config_overrides, interpolate=False
|
||||||
)
|
)
|
||||||
if output_path is None:
|
if output_path is None:
|
||||||
nlp = init_pipeline(config)
|
nlp = init_pipeline(config)
|
||||||
|
@ -75,24 +75,32 @@ def train_cli(
|
||||||
|
|
||||||
def train(nlp: Language, output_path: Optional[Path]=None) -> None:
|
def train(nlp: Language, output_path: Optional[Path]=None) -> None:
|
||||||
# Create iterator, which yields out info after each optimization step.
|
# Create iterator, which yields out info after each optimization step.
|
||||||
config = nlp.config
|
config = nlp.config.interpolate()
|
||||||
T_cfg = config["training"]
|
T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
|
||||||
score_weights = T_cfg["score_weights"]
|
optimizer T["optimizer"]
|
||||||
optimizer = T_cfg["optimizer"]
|
score_weights = T["score_weights"]
|
||||||
train_corpus = dot_to_object(config, T_cfg["train_corpus"])
|
# TODO: This might not be called corpora
|
||||||
dev_corpus = dot_to_object(config, T_cfg["dev_corpus"])
|
corpora = registry.resolve(config["corpora"], schema=ConfigSchemaCorpora)
|
||||||
batcher = T_cfg["batcher"]
|
train_corpus = dot_to_object({"corpora": corpora}, T["train_corpus"])
|
||||||
|
dev_corpus = dot_to_object({"corpora": corpora}, T["dev_corpus"])
|
||||||
|
batcher = T["batcher"]
|
||||||
|
train_logger = T["logger"]
|
||||||
|
before_to_disk = create_before_to_disk_callback(T["before_to_disk"])
|
||||||
|
# Components that shouldn't be updated during training
|
||||||
|
frozen_components = T["frozen_components"]
|
||||||
|
|
||||||
|
# Create iterator, which yields out info after each optimization step.
|
||||||
|
msg.info("Start training")
|
||||||
training_step_iterator = train_while_improving(
|
training_step_iterator = train_while_improving(
|
||||||
nlp,
|
nlp,
|
||||||
optimizer,
|
optimizer,
|
||||||
create_train_batches(train_corpus(nlp), batcher, T_cfg["max_epochs"]),
|
create_train_batches(train_corpus(nlp), batcher, T["max_epochs"]),
|
||||||
create_evaluation_callback(nlp, dev_corpus, score_weights),
|
create_evaluation_callback(nlp, dev_corpus, score_weights),
|
||||||
dropout=T_cfg["dropout"],
|
dropout=T["dropout"],
|
||||||
accumulate_gradient=T_cfg["accumulate_gradient"],
|
accumulate_gradient=T["accumulate_gradient"],
|
||||||
patience=T_cfg["patience"],
|
patience=T["patience"],
|
||||||
max_steps=T_cfg["max_steps"],
|
max_steps=T["max_steps"],
|
||||||
eval_frequency=T_cfg["eval_frequency"],
|
eval_frequency=T["eval_frequency"],
|
||||||
raw_text=None,
|
raw_text=None,
|
||||||
exclude=frozen_components,
|
exclude=frozen_components,
|
||||||
)
|
)
|
||||||
|
@ -101,7 +109,7 @@ def train(nlp: Language, output_path: Optional[Path]=None) -> None:
|
||||||
print_row, finalize_logger = train_logger(nlp)
|
print_row, finalize_logger = train_logger(nlp)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)
|
progress = tqdm.tqdm(total=T["eval_frequency"], leave=False)
|
||||||
progress.set_description(f"Epoch 1")
|
progress.set_description(f"Epoch 1")
|
||||||
for batch, info, is_best_checkpoint in training_step_iterator:
|
for batch, info, is_best_checkpoint in training_step_iterator:
|
||||||
progress.update(1)
|
progress.update(1)
|
||||||
|
@ -110,11 +118,11 @@ def train(nlp: Language, output_path: Optional[Path]=None) -> None:
|
||||||
print_row(info)
|
print_row(info)
|
||||||
if is_best_checkpoint and output_path is not None:
|
if is_best_checkpoint and output_path is not None:
|
||||||
with nlp.select_pipes(disable=frozen_components):
|
with nlp.select_pipes(disable=frozen_components):
|
||||||
update_meta(T_cfg, nlp, info)
|
update_meta(T, nlp, info)
|
||||||
with nlp.use_params(optimizer.averages):
|
with nlp.use_params(optimizer.averages):
|
||||||
nlp = before_to_disk(nlp)
|
nlp = before_to_disk(nlp)
|
||||||
nlp.to_disk(output_path / "model-best")
|
nlp.to_disk(output_path / "model-best")
|
||||||
progress = tqdm.tqdm(total=T_cfg["eval_frequency"], leave=False)
|
progress = tqdm.tqdm(total=T["eval_frequency"], leave=False)
|
||||||
progress.set_description(f"Epoch {info['epoch']}")
|
progress.set_description(f"Epoch {info['epoch']}")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
finalize_logger()
|
finalize_logger()
|
||||||
|
|
|
@ -12,8 +12,10 @@ from .tag_bigram_map import TAG_BIGRAM_MAP
|
||||||
from ...compat import copy_reg
|
from ...compat import copy_reg
|
||||||
from ...errors import Errors
|
from ...errors import Errors
|
||||||
from ...language import Language
|
from ...language import Language
|
||||||
|
from ...scorer import Scorer
|
||||||
from ...symbols import POS
|
from ...symbols import POS
|
||||||
from ...tokens import Doc
|
from ...tokens import Doc
|
||||||
|
from ...training import validate_examples
|
||||||
from ...util import DummyTokenizer, registry
|
from ...util import DummyTokenizer, registry
|
||||||
from ... import util
|
from ... import util
|
||||||
|
|
||||||
|
@ -130,6 +132,10 @@ class JapaneseTokenizer(DummyTokenizer):
|
||||||
)
|
)
|
||||||
return sub_tokens_list
|
return sub_tokens_list
|
||||||
|
|
||||||
|
def score(self, examples):
|
||||||
|
validate_examples(examples, "JapaneseTokenizer.score")
|
||||||
|
return Scorer.score_tokenization(examples)
|
||||||
|
|
||||||
def _get_config(self) -> Dict[str, Any]:
|
def _get_config(self) -> Dict[str, Any]:
|
||||||
return {"split_mode": self.split_mode}
|
return {"split_mode": self.split_mode}
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,9 @@ from .lex_attrs import LEX_ATTRS
|
||||||
from ...language import Language
|
from ...language import Language
|
||||||
from ...tokens import Doc
|
from ...tokens import Doc
|
||||||
from ...compat import copy_reg
|
from ...compat import copy_reg
|
||||||
|
from ...scorer import Scorer
|
||||||
from ...symbols import POS
|
from ...symbols import POS
|
||||||
|
from ...training import validate_examples
|
||||||
from ...util import DummyTokenizer, registry
|
from ...util import DummyTokenizer, registry
|
||||||
|
|
||||||
|
|
||||||
|
@ -62,6 +64,10 @@ class KoreanTokenizer(DummyTokenizer):
|
||||||
lemma = surface
|
lemma = surface
|
||||||
yield {"surface": surface, "lemma": lemma, "tag": tag}
|
yield {"surface": surface, "lemma": lemma, "tag": tag}
|
||||||
|
|
||||||
|
def score(self, examples):
|
||||||
|
validate_examples(examples, "KoreanTokenizer.score")
|
||||||
|
return Scorer.score_tokenization(examples)
|
||||||
|
|
||||||
|
|
||||||
class KoreanDefaults(Language.Defaults):
|
class KoreanDefaults(Language.Defaults):
|
||||||
config = Config().from_str(DEFAULT_CONFIG)
|
config = Config().from_str(DEFAULT_CONFIG)
|
||||||
|
|
|
@ -8,7 +8,9 @@ from thinc.api import Config
|
||||||
|
|
||||||
from ...errors import Warnings, Errors
|
from ...errors import Warnings, Errors
|
||||||
from ...language import Language
|
from ...language import Language
|
||||||
|
from ...scorer import Scorer
|
||||||
from ...tokens import Doc
|
from ...tokens import Doc
|
||||||
|
from ...training import validate_examples
|
||||||
from ...util import DummyTokenizer, registry
|
from ...util import DummyTokenizer, registry
|
||||||
from .lex_attrs import LEX_ATTRS
|
from .lex_attrs import LEX_ATTRS
|
||||||
from .stop_words import STOP_WORDS
|
from .stop_words import STOP_WORDS
|
||||||
|
@ -136,6 +138,10 @@ class ChineseTokenizer(DummyTokenizer):
|
||||||
warn_msg = Warnings.W104.format(target="pkuseg", current=self.segmenter)
|
warn_msg = Warnings.W104.format(target="pkuseg", current=self.segmenter)
|
||||||
warnings.warn(warn_msg)
|
warnings.warn(warn_msg)
|
||||||
|
|
||||||
|
def score(self, examples):
|
||||||
|
validate_examples(examples, "ChineseTokenizer.score")
|
||||||
|
return Scorer.score_tokenization(examples)
|
||||||
|
|
||||||
def _get_config(self) -> Dict[str, Any]:
|
def _get_config(self) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"segmenter": self.segmenter,
|
"segmenter": self.segmenter,
|
||||||
|
|
|
@ -27,7 +27,7 @@ from .lang.punctuation import TOKENIZER_INFIXES
|
||||||
from .tokens import Doc
|
from .tokens import Doc
|
||||||
from .tokenizer import Tokenizer
|
from .tokenizer import Tokenizer
|
||||||
from .errors import Errors, Warnings
|
from .errors import Errors, Warnings
|
||||||
from .schemas import ConfigSchema
|
from .schemas import ConfigSchema, ConfigSchemaNlp
|
||||||
from .git_info import GIT_VERSION
|
from .git_info import GIT_VERSION
|
||||||
from . import util
|
from . import util
|
||||||
from . import about
|
from . import about
|
||||||
|
@ -166,11 +166,10 @@ class Language:
|
||||||
self._components = []
|
self._components = []
|
||||||
self._disabled = set()
|
self._disabled = set()
|
||||||
self.max_length = max_length
|
self.max_length = max_length
|
||||||
self.resolved = {}
|
|
||||||
# Create the default tokenizer from the default config
|
# Create the default tokenizer from the default config
|
||||||
if not create_tokenizer:
|
if not create_tokenizer:
|
||||||
tokenizer_cfg = {"tokenizer": self._config["nlp"]["tokenizer"]}
|
tokenizer_cfg = {"tokenizer": self._config["nlp"]["tokenizer"]}
|
||||||
create_tokenizer = registry.make_from_config(tokenizer_cfg)["tokenizer"]
|
create_tokenizer = registry.resolve(tokenizer_cfg)["tokenizer"]
|
||||||
self.tokenizer = create_tokenizer(self)
|
self.tokenizer = create_tokenizer(self)
|
||||||
|
|
||||||
def __init_subclass__(cls, **kwargs):
|
def __init_subclass__(cls, **kwargs):
|
||||||
|
@ -467,7 +466,7 @@ class Language:
|
||||||
if "nlp" not in arg_names or "name" not in arg_names:
|
if "nlp" not in arg_names or "name" not in arg_names:
|
||||||
raise ValueError(Errors.E964.format(name=name))
|
raise ValueError(Errors.E964.format(name=name))
|
||||||
# Officially register the factory so we can later call
|
# Officially register the factory so we can later call
|
||||||
# registry.make_from_config and refer to it in the config as
|
# registry.resolve and refer to it in the config as
|
||||||
# @factories = "spacy.Language.xyz". We use the class name here so
|
# @factories = "spacy.Language.xyz". We use the class name here so
|
||||||
# different classes can have different factories.
|
# different classes can have different factories.
|
||||||
registry.factories.register(internal_name, func=factory_func)
|
registry.factories.register(internal_name, func=factory_func)
|
||||||
|
@ -650,8 +649,9 @@ class Language:
|
||||||
cfg = {factory_name: config}
|
cfg = {factory_name: config}
|
||||||
# We're calling the internal _fill here to avoid constructing the
|
# We're calling the internal _fill here to avoid constructing the
|
||||||
# registered functions twice
|
# registered functions twice
|
||||||
resolved, filled = registry.resolve(cfg, validate=validate)
|
resolved = registry.resolve(cfg, validate=validate)
|
||||||
filled = Config(filled[factory_name])
|
filled = registry.fill({"cfg": cfg[factory_name]}, validate=validate)["cfg"]
|
||||||
|
filled = Config(filled)
|
||||||
filled["factory"] = factory_name
|
filled["factory"] = factory_name
|
||||||
filled.pop("@factories", None)
|
filled.pop("@factories", None)
|
||||||
# Remove the extra values we added because we don't want to keep passing
|
# Remove the extra values we added because we don't want to keep passing
|
||||||
|
@ -1518,15 +1518,19 @@ class Language:
|
||||||
config = util.copy_config(config)
|
config = util.copy_config(config)
|
||||||
orig_pipeline = config.pop("components", {})
|
orig_pipeline = config.pop("components", {})
|
||||||
config["components"] = {}
|
config["components"] = {}
|
||||||
resolved, filled = registry.resolve(
|
if auto_fill:
|
||||||
config, validate=validate, schema=ConfigSchema
|
filled = registry.fill(config, validate=validate, schema=ConfigSchema)
|
||||||
)
|
else:
|
||||||
|
filled = config
|
||||||
filled["components"] = orig_pipeline
|
filled["components"] = orig_pipeline
|
||||||
config["components"] = orig_pipeline
|
config["components"] = orig_pipeline
|
||||||
create_tokenizer = resolved["nlp"]["tokenizer"]
|
resolved_nlp = registry.resolve(
|
||||||
before_creation = resolved["nlp"]["before_creation"]
|
filled["nlp"], validate=validate, schema=ConfigSchemaNlp
|
||||||
after_creation = resolved["nlp"]["after_creation"]
|
)
|
||||||
after_pipeline_creation = resolved["nlp"]["after_pipeline_creation"]
|
create_tokenizer = resolved_nlp["tokenizer"]
|
||||||
|
before_creation = resolved_nlp["before_creation"]
|
||||||
|
after_creation = resolved_nlp["after_creation"]
|
||||||
|
after_pipeline_creation = resolved_nlp["after_pipeline_creation"]
|
||||||
lang_cls = cls
|
lang_cls = cls
|
||||||
if before_creation is not None:
|
if before_creation is not None:
|
||||||
lang_cls = before_creation(cls)
|
lang_cls = before_creation(cls)
|
||||||
|
@ -1587,7 +1591,6 @@ class Language:
|
||||||
disabled_pipes = [*config["nlp"]["disabled"], *disable]
|
disabled_pipes = [*config["nlp"]["disabled"], *disable]
|
||||||
nlp._disabled = set(p for p in disabled_pipes if p not in exclude)
|
nlp._disabled = set(p for p in disabled_pipes if p not in exclude)
|
||||||
nlp.config = filled if auto_fill else config
|
nlp.config = filled if auto_fill else config
|
||||||
nlp.resolved = resolved
|
|
||||||
if after_pipeline_creation is not None:
|
if after_pipeline_creation is not None:
|
||||||
nlp = after_pipeline_creation(nlp)
|
nlp = after_pipeline_creation(nlp)
|
||||||
if not isinstance(nlp, cls):
|
if not isinstance(nlp, cls):
|
||||||
|
|
|
@ -29,7 +29,8 @@ cdef class Morphology:
|
||||||
FEATURE_SEP = "|"
|
FEATURE_SEP = "|"
|
||||||
FIELD_SEP = "="
|
FIELD_SEP = "="
|
||||||
VALUE_SEP = ","
|
VALUE_SEP = ","
|
||||||
EMPTY_MORPH = "_" # not an empty string so that the PreshMap key is not 0
|
# not an empty string so that the PreshMap key is not 0
|
||||||
|
EMPTY_MORPH = symbols.NAMES[symbols._]
|
||||||
|
|
||||||
def __init__(self, StringStore strings):
|
def __init__(self, StringStore strings):
|
||||||
self.mem = Pool()
|
self.mem = Pool()
|
||||||
|
|
|
@ -4,6 +4,7 @@ from enum import Enum
|
||||||
from pydantic import BaseModel, Field, ValidationError, validator
|
from pydantic import BaseModel, Field, ValidationError, validator
|
||||||
from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
|
from pydantic import StrictStr, StrictInt, StrictFloat, StrictBool
|
||||||
from pydantic import root_validator
|
from pydantic import root_validator
|
||||||
|
from thinc.config import Promise
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from thinc.api import Optimizer
|
from thinc.api import Optimizer
|
||||||
|
|
||||||
|
@ -16,10 +17,12 @@ if TYPE_CHECKING:
|
||||||
from .training import Example # noqa: F401
|
from .training import Example # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
ItemT = TypeVar("ItemT")
|
ItemT = TypeVar("ItemT")
|
||||||
Batcher = Callable[[Iterable[ItemT]], Iterable[List[ItemT]]]
|
Batcher = Union[Callable[[Iterable[ItemT]], Iterable[List[ItemT]]], Promise]
|
||||||
Reader = Callable[["Language", str], Iterable["Example"]]
|
Reader = Union[Callable[["Language", str], Iterable["Example"]], Promise]
|
||||||
Logger = Callable[["Language"], Tuple[Callable[[Dict[str, Any]], None], Callable]]
|
Logger = Union[Callable[["Language"], Tuple[Callable[[Dict[str, Any]], None], Callable]], Promise]
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]:
|
def validate(schema: Type[BaseModel], obj: Dict[str, Any]) -> List[str]:
|
||||||
|
@ -292,6 +295,16 @@ class ConfigSchema(BaseModel):
|
||||||
arbitrary_types_allowed = True
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
|
class TrainingSchema(BaseModel):
|
||||||
|
training: ConfigSchemaTraining
|
||||||
|
pretraining: Union[ConfigSchemaPretrain, ConfigSchemaPretrainEmpty] = {}
|
||||||
|
corpora: Dict[str, Reader]
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
extra = "allow"
|
||||||
|
arbitrary_types_allowed = True
|
||||||
|
|
||||||
|
|
||||||
# Project config Schema
|
# Project config Schema
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -466,3 +466,4 @@ cdef enum symbol_t:
|
||||||
ENT_ID
|
ENT_ID
|
||||||
|
|
||||||
IDX
|
IDX
|
||||||
|
_
|
||||||
|
|
|
@ -465,6 +465,7 @@ IDS = {
|
||||||
"acl": acl,
|
"acl": acl,
|
||||||
"LAW": LAW,
|
"LAW": LAW,
|
||||||
"MORPH": MORPH,
|
"MORPH": MORPH,
|
||||||
|
"_": _,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -24,7 +24,7 @@ def test_doc_add_entities_set_ents_iob(en_vocab):
|
||||||
"update_with_oracle_cut_size": 100,
|
"update_with_oracle_cut_size": 100,
|
||||||
}
|
}
|
||||||
cfg = {"model": DEFAULT_NER_MODEL}
|
cfg = {"model": DEFAULT_NER_MODEL}
|
||||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
model = registry.resolve(cfg, validate=True)["model"]
|
||||||
ner = EntityRecognizer(en_vocab, model, **config)
|
ner = EntityRecognizer(en_vocab, model, **config)
|
||||||
ner.begin_training(lambda: [_ner_example(ner)])
|
ner.begin_training(lambda: [_ner_example(ner)])
|
||||||
ner(doc)
|
ner(doc)
|
||||||
|
@ -46,7 +46,7 @@ def test_ents_reset(en_vocab):
|
||||||
"update_with_oracle_cut_size": 100,
|
"update_with_oracle_cut_size": 100,
|
||||||
}
|
}
|
||||||
cfg = {"model": DEFAULT_NER_MODEL}
|
cfg = {"model": DEFAULT_NER_MODEL}
|
||||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
model = registry.resolve(cfg, validate=True)["model"]
|
||||||
ner = EntityRecognizer(en_vocab, model, **config)
|
ner = EntityRecognizer(en_vocab, model, **config)
|
||||||
ner.begin_training(lambda: [_ner_example(ner)])
|
ner.begin_training(lambda: [_ner_example(ner)])
|
||||||
ner(doc)
|
ner(doc)
|
||||||
|
|
|
@ -23,7 +23,7 @@ def parser(vocab):
|
||||||
"update_with_oracle_cut_size": 100,
|
"update_with_oracle_cut_size": 100,
|
||||||
}
|
}
|
||||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
model = registry.resolve(cfg, validate=True)["model"]
|
||||||
parser = DependencyParser(vocab, model, **config)
|
parser = DependencyParser(vocab, model, **config)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
@ -82,7 +82,7 @@ def test_add_label_deserializes_correctly():
|
||||||
"update_with_oracle_cut_size": 100,
|
"update_with_oracle_cut_size": 100,
|
||||||
}
|
}
|
||||||
cfg = {"model": DEFAULT_NER_MODEL}
|
cfg = {"model": DEFAULT_NER_MODEL}
|
||||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
model = registry.resolve(cfg, validate=True)["model"]
|
||||||
ner1 = EntityRecognizer(Vocab(), model, **config)
|
ner1 = EntityRecognizer(Vocab(), model, **config)
|
||||||
ner1.add_label("C")
|
ner1.add_label("C")
|
||||||
ner1.add_label("B")
|
ner1.add_label("B")
|
||||||
|
@ -111,7 +111,7 @@ def test_add_label_get_label(pipe_cls, n_moves, model_config):
|
||||||
splitting the move names.
|
splitting the move names.
|
||||||
"""
|
"""
|
||||||
labels = ["A", "B", "C"]
|
labels = ["A", "B", "C"]
|
||||||
model = registry.make_from_config({"model": model_config}, validate=True)["model"]
|
model = registry.resolve({"model": model_config}, validate=True)["model"]
|
||||||
config = {
|
config = {
|
||||||
"learn_tokens": False,
|
"learn_tokens": False,
|
||||||
"min_action_freq": 30,
|
"min_action_freq": 30,
|
||||||
|
|
|
@ -127,7 +127,7 @@ def test_get_oracle_actions():
|
||||||
"update_with_oracle_cut_size": 100,
|
"update_with_oracle_cut_size": 100,
|
||||||
}
|
}
|
||||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
model = registry.resolve(cfg, validate=True)["model"]
|
||||||
parser = DependencyParser(doc.vocab, model, **config)
|
parser = DependencyParser(doc.vocab, model, **config)
|
||||||
parser.moves.add_action(0, "")
|
parser.moves.add_action(0, "")
|
||||||
parser.moves.add_action(1, "")
|
parser.moves.add_action(1, "")
|
||||||
|
|
|
@ -25,7 +25,7 @@ def arc_eager(vocab):
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def tok2vec():
|
def tok2vec():
|
||||||
cfg = {"model": DEFAULT_TOK2VEC_MODEL}
|
cfg = {"model": DEFAULT_TOK2VEC_MODEL}
|
||||||
tok2vec = registry.make_from_config(cfg, validate=True)["model"]
|
tok2vec = registry.resolve(cfg, validate=True)["model"]
|
||||||
tok2vec.initialize()
|
tok2vec.initialize()
|
||||||
return tok2vec
|
return tok2vec
|
||||||
|
|
||||||
|
@ -38,14 +38,14 @@ def parser(vocab, arc_eager):
|
||||||
"update_with_oracle_cut_size": 100,
|
"update_with_oracle_cut_size": 100,
|
||||||
}
|
}
|
||||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
model = registry.resolve(cfg, validate=True)["model"]
|
||||||
return Parser(vocab, model, moves=arc_eager, **config)
|
return Parser(vocab, model, moves=arc_eager, **config)
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def model(arc_eager, tok2vec, vocab):
|
def model(arc_eager, tok2vec, vocab):
|
||||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
model = registry.resolve(cfg, validate=True)["model"]
|
||||||
model.attrs["resize_output"](model, arc_eager.n_moves)
|
model.attrs["resize_output"](model, arc_eager.n_moves)
|
||||||
model.initialize()
|
model.initialize()
|
||||||
return model
|
return model
|
||||||
|
@ -72,7 +72,7 @@ def test_build_model(parser, vocab):
|
||||||
"update_with_oracle_cut_size": 100,
|
"update_with_oracle_cut_size": 100,
|
||||||
}
|
}
|
||||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
model = registry.resolve(cfg, validate=True)["model"]
|
||||||
parser.model = Parser(vocab, model=model, moves=parser.moves, **config).model
|
parser.model = Parser(vocab, model=model, moves=parser.moves, **config).model
|
||||||
assert parser.model is not None
|
assert parser.model is not None
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,7 @@ def parser(vocab):
|
||||||
"update_with_oracle_cut_size": 100,
|
"update_with_oracle_cut_size": 100,
|
||||||
}
|
}
|
||||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
model = registry.resolve(cfg, validate=True)["model"]
|
||||||
parser = DependencyParser(vocab, model, **config)
|
parser = DependencyParser(vocab, model, **config)
|
||||||
parser.cfg["token_vector_width"] = 4
|
parser.cfg["token_vector_width"] = 4
|
||||||
parser.cfg["hidden_width"] = 32
|
parser.cfg["hidden_width"] = 32
|
||||||
|
|
|
@ -139,7 +139,7 @@ TRAIN_DATA = [
|
||||||
|
|
||||||
def test_tok2vec_listener():
|
def test_tok2vec_listener():
|
||||||
orig_config = Config().from_str(cfg_string)
|
orig_config = Config().from_str(cfg_string)
|
||||||
nlp, config = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
|
nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
|
||||||
assert nlp.pipe_names == ["tok2vec", "tagger"]
|
assert nlp.pipe_names == ["tok2vec", "tagger"]
|
||||||
tagger = nlp.get_pipe("tagger")
|
tagger = nlp.get_pipe("tagger")
|
||||||
tok2vec = nlp.get_pipe("tok2vec")
|
tok2vec = nlp.get_pipe("tok2vec")
|
||||||
|
@ -173,7 +173,7 @@ def test_tok2vec_listener():
|
||||||
|
|
||||||
def test_tok2vec_listener_callback():
|
def test_tok2vec_listener_callback():
|
||||||
orig_config = Config().from_str(cfg_string)
|
orig_config = Config().from_str(cfg_string)
|
||||||
nlp, config = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
|
nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
|
||||||
assert nlp.pipe_names == ["tok2vec", "tagger"]
|
assert nlp.pipe_names == ["tok2vec", "tagger"]
|
||||||
tagger = nlp.get_pipe("tagger")
|
tagger = nlp.get_pipe("tagger")
|
||||||
tok2vec = nlp.get_pipe("tok2vec")
|
tok2vec = nlp.get_pipe("tok2vec")
|
||||||
|
|
|
@ -195,7 +195,7 @@ def test_issue3345():
|
||||||
"update_with_oracle_cut_size": 100,
|
"update_with_oracle_cut_size": 100,
|
||||||
}
|
}
|
||||||
cfg = {"model": DEFAULT_NER_MODEL}
|
cfg = {"model": DEFAULT_NER_MODEL}
|
||||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
model = registry.resolve(cfg, validate=True)["model"]
|
||||||
ner = EntityRecognizer(doc.vocab, model, **config)
|
ner = EntityRecognizer(doc.vocab, model, **config)
|
||||||
# Add the OUT action. I wouldn't have thought this would be necessary...
|
# Add the OUT action. I wouldn't have thought this would be necessary...
|
||||||
ner.moves.add_action(5, "")
|
ner.moves.add_action(5, "")
|
||||||
|
|
|
@ -264,9 +264,7 @@ def test_issue3830_no_subtok():
|
||||||
"min_action_freq": 30,
|
"min_action_freq": 30,
|
||||||
"update_with_oracle_cut_size": 100,
|
"update_with_oracle_cut_size": 100,
|
||||||
}
|
}
|
||||||
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)[
|
model = registry.resolve({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
|
||||||
"model"
|
|
||||||
]
|
|
||||||
parser = DependencyParser(Vocab(), model, **config)
|
parser = DependencyParser(Vocab(), model, **config)
|
||||||
parser.add_label("nsubj")
|
parser.add_label("nsubj")
|
||||||
assert "subtok" not in parser.labels
|
assert "subtok" not in parser.labels
|
||||||
|
@ -281,9 +279,7 @@ def test_issue3830_with_subtok():
|
||||||
"min_action_freq": 30,
|
"min_action_freq": 30,
|
||||||
"update_with_oracle_cut_size": 100,
|
"update_with_oracle_cut_size": 100,
|
||||||
}
|
}
|
||||||
model = registry.make_from_config({"model": DEFAULT_PARSER_MODEL}, validate=True)[
|
model = registry.resolve({"model": DEFAULT_PARSER_MODEL}, validate=True)["model"]
|
||||||
"model"
|
|
||||||
]
|
|
||||||
parser = DependencyParser(Vocab(), model, **config)
|
parser = DependencyParser(Vocab(), model, **config)
|
||||||
parser.add_label("nsubj")
|
parser.add_label("nsubj")
|
||||||
assert "subtok" not in parser.labels
|
assert "subtok" not in parser.labels
|
||||||
|
|
|
@ -108,8 +108,8 @@ def my_parser():
|
||||||
def test_create_nlp_from_config():
|
def test_create_nlp_from_config():
|
||||||
config = Config().from_str(nlp_config_string)
|
config = Config().from_str(nlp_config_string)
|
||||||
with pytest.raises(ConfigValidationError):
|
with pytest.raises(ConfigValidationError):
|
||||||
nlp, _ = load_model_from_config(config, auto_fill=False)
|
load_model_from_config(config, auto_fill=False)
|
||||||
nlp, resolved = load_model_from_config(config, auto_fill=True)
|
nlp = load_model_from_config(config, auto_fill=True)
|
||||||
assert nlp.config["training"]["batcher"]["size"] == 666
|
assert nlp.config["training"]["batcher"]["size"] == 666
|
||||||
assert len(nlp.config["training"]) > 1
|
assert len(nlp.config["training"]) > 1
|
||||||
assert nlp.pipe_names == ["tok2vec", "tagger"]
|
assert nlp.pipe_names == ["tok2vec", "tagger"]
|
||||||
|
@ -136,7 +136,7 @@ def test_create_nlp_from_config_multiple_instances():
|
||||||
"tagger2": config["components"]["tagger"],
|
"tagger2": config["components"]["tagger"],
|
||||||
}
|
}
|
||||||
config["nlp"]["pipeline"] = list(config["components"].keys())
|
config["nlp"]["pipeline"] = list(config["components"].keys())
|
||||||
nlp, _ = load_model_from_config(config, auto_fill=True)
|
nlp = load_model_from_config(config, auto_fill=True)
|
||||||
assert nlp.pipe_names == ["t2v", "tagger1", "tagger2"]
|
assert nlp.pipe_names == ["t2v", "tagger1", "tagger2"]
|
||||||
assert nlp.get_pipe_meta("t2v").factory == "tok2vec"
|
assert nlp.get_pipe_meta("t2v").factory == "tok2vec"
|
||||||
assert nlp.get_pipe_meta("tagger1").factory == "tagger"
|
assert nlp.get_pipe_meta("tagger1").factory == "tagger"
|
||||||
|
@ -150,7 +150,7 @@ def test_create_nlp_from_config_multiple_instances():
|
||||||
def test_serialize_nlp():
|
def test_serialize_nlp():
|
||||||
""" Create a custom nlp pipeline from config and ensure it serializes it correctly """
|
""" Create a custom nlp pipeline from config and ensure it serializes it correctly """
|
||||||
nlp_config = Config().from_str(nlp_config_string)
|
nlp_config = Config().from_str(nlp_config_string)
|
||||||
nlp, _ = load_model_from_config(nlp_config, auto_fill=True)
|
nlp = load_model_from_config(nlp_config, auto_fill=True)
|
||||||
nlp.get_pipe("tagger").add_label("A")
|
nlp.get_pipe("tagger").add_label("A")
|
||||||
nlp.begin_training()
|
nlp.begin_training()
|
||||||
assert "tok2vec" in nlp.pipe_names
|
assert "tok2vec" in nlp.pipe_names
|
||||||
|
@ -209,7 +209,7 @@ def test_config_nlp_roundtrip():
|
||||||
nlp = English()
|
nlp = English()
|
||||||
nlp.add_pipe("entity_ruler")
|
nlp.add_pipe("entity_ruler")
|
||||||
nlp.add_pipe("ner")
|
nlp.add_pipe("ner")
|
||||||
new_nlp, new_config = load_model_from_config(nlp.config, auto_fill=False)
|
new_nlp = load_model_from_config(nlp.config, auto_fill=False)
|
||||||
assert new_nlp.config == nlp.config
|
assert new_nlp.config == nlp.config
|
||||||
assert new_nlp.pipe_names == nlp.pipe_names
|
assert new_nlp.pipe_names == nlp.pipe_names
|
||||||
assert new_nlp._pipe_configs == nlp._pipe_configs
|
assert new_nlp._pipe_configs == nlp._pipe_configs
|
||||||
|
@ -280,12 +280,12 @@ def test_config_overrides():
|
||||||
overrides_dot = {"nlp.lang": "de", "nlp.pipeline": ["tagger"]}
|
overrides_dot = {"nlp.lang": "de", "nlp.pipeline": ["tagger"]}
|
||||||
# load_model from config with overrides passed directly to Config
|
# load_model from config with overrides passed directly to Config
|
||||||
config = Config().from_str(nlp_config_string, overrides=overrides_dot)
|
config = Config().from_str(nlp_config_string, overrides=overrides_dot)
|
||||||
nlp, _ = load_model_from_config(config, auto_fill=True)
|
nlp = load_model_from_config(config, auto_fill=True)
|
||||||
assert isinstance(nlp, German)
|
assert isinstance(nlp, German)
|
||||||
assert nlp.pipe_names == ["tagger"]
|
assert nlp.pipe_names == ["tagger"]
|
||||||
# Serialized roundtrip with config passed in
|
# Serialized roundtrip with config passed in
|
||||||
base_config = Config().from_str(nlp_config_string)
|
base_config = Config().from_str(nlp_config_string)
|
||||||
base_nlp, _ = load_model_from_config(base_config, auto_fill=True)
|
base_nlp = load_model_from_config(base_config, auto_fill=True)
|
||||||
assert isinstance(base_nlp, English)
|
assert isinstance(base_nlp, English)
|
||||||
assert base_nlp.pipe_names == ["tok2vec", "tagger"]
|
assert base_nlp.pipe_names == ["tok2vec", "tagger"]
|
||||||
with make_tempdir() as d:
|
with make_tempdir() as d:
|
||||||
|
@ -328,7 +328,7 @@ def test_config_optional_sections():
|
||||||
config = Config().from_str(nlp_config_string)
|
config = Config().from_str(nlp_config_string)
|
||||||
config = DEFAULT_CONFIG.merge(config)
|
config = DEFAULT_CONFIG.merge(config)
|
||||||
assert "pretraining" not in config
|
assert "pretraining" not in config
|
||||||
filled = registry.fill_config(config, schema=ConfigSchema, validate=False)
|
filled = registry.fill(config, schema=ConfigSchema, validate=False)
|
||||||
# Make sure that optional "pretraining" block doesn't default to None,
|
# Make sure that optional "pretraining" block doesn't default to None,
|
||||||
# which would (rightly) cause error because it'd result in a top-level
|
# which would (rightly) cause error because it'd result in a top-level
|
||||||
# key that's not a section (dict). Note that the following roundtrip is
|
# key that's not a section (dict). Note that the following roundtrip is
|
||||||
|
@ -341,7 +341,7 @@ def test_config_auto_fill_extra_fields():
|
||||||
config = Config({"nlp": {"lang": "en"}, "training": {}})
|
config = Config({"nlp": {"lang": "en"}, "training": {}})
|
||||||
assert load_model_from_config(config, auto_fill=True)
|
assert load_model_from_config(config, auto_fill=True)
|
||||||
config = Config({"nlp": {"lang": "en"}, "training": {"extra": "hello"}})
|
config = Config({"nlp": {"lang": "en"}, "training": {"extra": "hello"}})
|
||||||
nlp, _ = load_model_from_config(config, auto_fill=True, validate=False)
|
nlp = load_model_from_config(config, auto_fill=True, validate=False)
|
||||||
assert "extra" not in nlp.config["training"]
|
assert "extra" not in nlp.config["training"]
|
||||||
# Make sure the config generated is valid
|
# Make sure the config generated is valid
|
||||||
load_model_from_config(nlp.config)
|
load_model_from_config(nlp.config)
|
||||||
|
|
|
@ -23,7 +23,7 @@ def parser(en_vocab):
|
||||||
"update_with_oracle_cut_size": 100,
|
"update_with_oracle_cut_size": 100,
|
||||||
}
|
}
|
||||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
model = registry.resolve(cfg, validate=True)["model"]
|
||||||
parser = DependencyParser(en_vocab, model, **config)
|
parser = DependencyParser(en_vocab, model, **config)
|
||||||
parser.add_label("nsubj")
|
parser.add_label("nsubj")
|
||||||
return parser
|
return parser
|
||||||
|
@ -37,7 +37,7 @@ def blank_parser(en_vocab):
|
||||||
"update_with_oracle_cut_size": 100,
|
"update_with_oracle_cut_size": 100,
|
||||||
}
|
}
|
||||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
model = registry.resolve(cfg, validate=True)["model"]
|
||||||
parser = DependencyParser(en_vocab, model, **config)
|
parser = DependencyParser(en_vocab, model, **config)
|
||||||
return parser
|
return parser
|
||||||
|
|
||||||
|
@ -45,7 +45,7 @@ def blank_parser(en_vocab):
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def taggers(en_vocab):
|
def taggers(en_vocab):
|
||||||
cfg = {"model": DEFAULT_TAGGER_MODEL}
|
cfg = {"model": DEFAULT_TAGGER_MODEL}
|
||||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
model = registry.resolve(cfg, validate=True)["model"]
|
||||||
tagger1 = Tagger(en_vocab, model)
|
tagger1 = Tagger(en_vocab, model)
|
||||||
tagger2 = Tagger(en_vocab, model)
|
tagger2 = Tagger(en_vocab, model)
|
||||||
return tagger1, tagger2
|
return tagger1, tagger2
|
||||||
|
@ -59,7 +59,7 @@ def test_serialize_parser_roundtrip_bytes(en_vocab, Parser):
|
||||||
"update_with_oracle_cut_size": 100,
|
"update_with_oracle_cut_size": 100,
|
||||||
}
|
}
|
||||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
model = registry.resolve(cfg, validate=True)["model"]
|
||||||
parser = Parser(en_vocab, model, **config)
|
parser = Parser(en_vocab, model, **config)
|
||||||
new_parser = Parser(en_vocab, model, **config)
|
new_parser = Parser(en_vocab, model, **config)
|
||||||
new_parser = new_parser.from_bytes(parser.to_bytes(exclude=["vocab"]))
|
new_parser = new_parser.from_bytes(parser.to_bytes(exclude=["vocab"]))
|
||||||
|
@ -77,7 +77,7 @@ def test_serialize_parser_roundtrip_disk(en_vocab, Parser):
|
||||||
"update_with_oracle_cut_size": 100,
|
"update_with_oracle_cut_size": 100,
|
||||||
}
|
}
|
||||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
model = registry.resolve(cfg, validate=True)["model"]
|
||||||
parser = Parser(en_vocab, model, **config)
|
parser = Parser(en_vocab, model, **config)
|
||||||
with make_tempdir() as d:
|
with make_tempdir() as d:
|
||||||
file_path = d / "parser"
|
file_path = d / "parser"
|
||||||
|
@ -111,7 +111,7 @@ def test_serialize_tagger_roundtrip_bytes(en_vocab, taggers):
|
||||||
tagger1 = tagger1.from_bytes(tagger1_b)
|
tagger1 = tagger1.from_bytes(tagger1_b)
|
||||||
assert tagger1.to_bytes() == tagger1_b
|
assert tagger1.to_bytes() == tagger1_b
|
||||||
cfg = {"model": DEFAULT_TAGGER_MODEL}
|
cfg = {"model": DEFAULT_TAGGER_MODEL}
|
||||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
model = registry.resolve(cfg, validate=True)["model"]
|
||||||
new_tagger1 = Tagger(en_vocab, model).from_bytes(tagger1_b)
|
new_tagger1 = Tagger(en_vocab, model).from_bytes(tagger1_b)
|
||||||
new_tagger1_b = new_tagger1.to_bytes()
|
new_tagger1_b = new_tagger1.to_bytes()
|
||||||
assert len(new_tagger1_b) == len(tagger1_b)
|
assert len(new_tagger1_b) == len(tagger1_b)
|
||||||
|
@ -126,7 +126,7 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
|
||||||
tagger1.to_disk(file_path1)
|
tagger1.to_disk(file_path1)
|
||||||
tagger2.to_disk(file_path2)
|
tagger2.to_disk(file_path2)
|
||||||
cfg = {"model": DEFAULT_TAGGER_MODEL}
|
cfg = {"model": DEFAULT_TAGGER_MODEL}
|
||||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
model = registry.resolve(cfg, validate=True)["model"]
|
||||||
tagger1_d = Tagger(en_vocab, model).from_disk(file_path1)
|
tagger1_d = Tagger(en_vocab, model).from_disk(file_path1)
|
||||||
tagger2_d = Tagger(en_vocab, model).from_disk(file_path2)
|
tagger2_d = Tagger(en_vocab, model).from_disk(file_path2)
|
||||||
assert tagger1_d.to_bytes() == tagger2_d.to_bytes()
|
assert tagger1_d.to_bytes() == tagger2_d.to_bytes()
|
||||||
|
@ -135,7 +135,7 @@ def test_serialize_tagger_roundtrip_disk(en_vocab, taggers):
|
||||||
def test_serialize_textcat_empty(en_vocab):
|
def test_serialize_textcat_empty(en_vocab):
|
||||||
# See issue #1105
|
# See issue #1105
|
||||||
cfg = {"model": DEFAULT_TEXTCAT_MODEL}
|
cfg = {"model": DEFAULT_TEXTCAT_MODEL}
|
||||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
model = registry.resolve(cfg, validate=True)["model"]
|
||||||
textcat = TextCategorizer(
|
textcat = TextCategorizer(
|
||||||
en_vocab,
|
en_vocab,
|
||||||
model,
|
model,
|
||||||
|
@ -149,7 +149,7 @@ def test_serialize_textcat_empty(en_vocab):
|
||||||
@pytest.mark.parametrize("Parser", test_parsers)
|
@pytest.mark.parametrize("Parser", test_parsers)
|
||||||
def test_serialize_pipe_exclude(en_vocab, Parser):
|
def test_serialize_pipe_exclude(en_vocab, Parser):
|
||||||
cfg = {"model": DEFAULT_PARSER_MODEL}
|
cfg = {"model": DEFAULT_PARSER_MODEL}
|
||||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
model = registry.resolve(cfg, validate=True)["model"]
|
||||||
config = {
|
config = {
|
||||||
"learn_tokens": False,
|
"learn_tokens": False,
|
||||||
"min_action_freq": 0,
|
"min_action_freq": 0,
|
||||||
|
@ -176,7 +176,7 @@ def test_serialize_pipe_exclude(en_vocab, Parser):
|
||||||
|
|
||||||
def test_serialize_sentencerecognizer(en_vocab):
|
def test_serialize_sentencerecognizer(en_vocab):
|
||||||
cfg = {"model": DEFAULT_SENTER_MODEL}
|
cfg = {"model": DEFAULT_SENTER_MODEL}
|
||||||
model = registry.make_from_config(cfg, validate=True)["model"]
|
model = registry.resolve(cfg, validate=True)["model"]
|
||||||
sr = SentenceRecognizer(en_vocab, model)
|
sr = SentenceRecognizer(en_vocab, model)
|
||||||
sr_b = sr.to_bytes()
|
sr_b = sr.to_bytes()
|
||||||
sr_d = SentenceRecognizer(en_vocab, model).from_bytes(sr_b)
|
sr_d = SentenceRecognizer(en_vocab, model).from_bytes(sr_b)
|
||||||
|
|
|
@ -7,6 +7,7 @@ from spacy import util
|
||||||
from spacy import prefer_gpu, require_gpu
|
from spacy import prefer_gpu, require_gpu
|
||||||
from spacy.ml._precomputable_affine import PrecomputableAffine
|
from spacy.ml._precomputable_affine import PrecomputableAffine
|
||||||
from spacy.ml._precomputable_affine import _backprop_precomputable_affine_padding
|
from spacy.ml._precomputable_affine import _backprop_precomputable_affine_padding
|
||||||
|
from thinc.api import Optimizer
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
|
@ -157,3 +158,16 @@ def test_dot_to_dict(dot_notation, expected):
|
||||||
result = util.dot_to_dict(dot_notation)
|
result = util.dot_to_dict(dot_notation)
|
||||||
assert result == expected
|
assert result == expected
|
||||||
assert util.dict_to_dot(result) == dot_notation
|
assert util.dict_to_dot(result) == dot_notation
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_training_config():
|
||||||
|
config = {
|
||||||
|
"nlp": {"lang": "en", "disabled": []},
|
||||||
|
"training": {"dropout": 0.1, "optimizer": {"@optimizers": "Adam.v1"}},
|
||||||
|
"corpora": {},
|
||||||
|
}
|
||||||
|
resolved = util.resolve_training_config(config)
|
||||||
|
assert resolved["training"]["dropout"] == 0.1
|
||||||
|
assert isinstance(resolved["training"]["optimizer"], Optimizer)
|
||||||
|
assert resolved["corpora"] == {}
|
||||||
|
assert "nlp" not in resolved
|
||||||
|
|
|
@ -82,10 +82,10 @@ def test_util_dot_section():
|
||||||
no_output_layer = false
|
no_output_layer = false
|
||||||
"""
|
"""
|
||||||
nlp_config = Config().from_str(cfg_string)
|
nlp_config = Config().from_str(cfg_string)
|
||||||
en_nlp, en_config = util.load_model_from_config(nlp_config, auto_fill=True)
|
en_nlp = util.load_model_from_config(nlp_config, auto_fill=True)
|
||||||
default_config = Config().from_disk(DEFAULT_CONFIG_PATH)
|
default_config = Config().from_disk(DEFAULT_CONFIG_PATH)
|
||||||
default_config["nlp"]["lang"] = "nl"
|
default_config["nlp"]["lang"] = "nl"
|
||||||
nl_nlp, nl_config = util.load_model_from_config(default_config, auto_fill=True)
|
nl_nlp = util.load_model_from_config(default_config, auto_fill=True)
|
||||||
# Test that creation went OK
|
# Test that creation went OK
|
||||||
assert isinstance(en_nlp, English)
|
assert isinstance(en_nlp, English)
|
||||||
assert isinstance(nl_nlp, Dutch)
|
assert isinstance(nl_nlp, Dutch)
|
||||||
|
@ -94,14 +94,15 @@ def test_util_dot_section():
|
||||||
# not exclusive_classes
|
# not exclusive_classes
|
||||||
assert en_nlp.get_pipe("textcat").model.attrs["multi_label"] is False
|
assert en_nlp.get_pipe("textcat").model.attrs["multi_label"] is False
|
||||||
# Test that default values got overwritten
|
# Test that default values got overwritten
|
||||||
assert en_config["nlp"]["pipeline"] == ["textcat"]
|
assert en_nlp.config["nlp"]["pipeline"] == ["textcat"]
|
||||||
assert nl_config["nlp"]["pipeline"] == [] # default value []
|
assert nl_nlp.config["nlp"]["pipeline"] == [] # default value []
|
||||||
# Test proper functioning of 'dot_to_object'
|
# Test proper functioning of 'dot_to_object'
|
||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
dot_to_object(en_config, "nlp.pipeline.tagger")
|
dot_to_object(en_nlp.config, "nlp.pipeline.tagger")
|
||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
dot_to_object(en_config, "nlp.unknownattribute")
|
dot_to_object(en_nlp.config, "nlp.unknownattribute")
|
||||||
assert isinstance(dot_to_object(nl_config, "training.optimizer"), Optimizer)
|
resolved = util.resolve_training_config(nl_nlp.config)
|
||||||
|
assert isinstance(dot_to_object(resolved, "training.optimizer"), Optimizer)
|
||||||
|
|
||||||
|
|
||||||
def test_simple_frozen_list():
|
def test_simple_frozen_list():
|
||||||
|
|
|
@ -3,6 +3,7 @@ import pytest
|
||||||
from thinc.api import Config
|
from thinc.api import Config
|
||||||
from spacy import Language
|
from spacy import Language
|
||||||
from spacy.util import load_model_from_config, registry, dot_to_object
|
from spacy.util import load_model_from_config, registry, dot_to_object
|
||||||
|
from spacy.util import resolve_training_config
|
||||||
from spacy.training import Example
|
from spacy.training import Example
|
||||||
|
|
||||||
|
|
||||||
|
@ -37,8 +38,8 @@ def test_readers():
|
||||||
return {"train": reader, "dev": reader, "extra": reader, "something": reader}
|
return {"train": reader, "dev": reader, "extra": reader, "something": reader}
|
||||||
|
|
||||||
config = Config().from_str(config_string)
|
config = Config().from_str(config_string)
|
||||||
nlp, resolved = load_model_from_config(config, auto_fill=True)
|
nlp = load_model_from_config(config, auto_fill=True)
|
||||||
|
resolved = resolve_training_config(nlp.config)
|
||||||
train_corpus = dot_to_object(resolved, resolved["training"]["train_corpus"])
|
train_corpus = dot_to_object(resolved, resolved["training"]["train_corpus"])
|
||||||
assert isinstance(train_corpus, Callable)
|
assert isinstance(train_corpus, Callable)
|
||||||
optimizer = resolved["training"]["optimizer"]
|
optimizer = resolved["training"]["optimizer"]
|
||||||
|
@ -87,8 +88,8 @@ def test_cat_readers(reader, additional_config):
|
||||||
config = Config().from_str(nlp_config_string)
|
config = Config().from_str(nlp_config_string)
|
||||||
config["corpora"]["@readers"] = reader
|
config["corpora"]["@readers"] = reader
|
||||||
config["corpora"].update(additional_config)
|
config["corpora"].update(additional_config)
|
||||||
nlp, resolved = load_model_from_config(config, auto_fill=True)
|
nlp = load_model_from_config(config, auto_fill=True)
|
||||||
|
resolved = resolve_training_config(nlp.config)
|
||||||
train_corpus = dot_to_object(resolved, resolved["training"]["train_corpus"])
|
train_corpus = dot_to_object(resolved, resolved["training"]["train_corpus"])
|
||||||
optimizer = resolved["training"]["optimizer"]
|
optimizer = resolved["training"]["optimizer"]
|
||||||
# simulate a training loop
|
# simulate a training loop
|
||||||
|
|
|
@ -86,7 +86,7 @@ class registry(thinc.registry):
|
||||||
# spacy_factories entry point. This registry only exists so we can easily
|
# spacy_factories entry point. This registry only exists so we can easily
|
||||||
# load them via the entry points. The "true" factories are added via the
|
# load them via the entry points. The "true" factories are added via the
|
||||||
# Language.factory decorator (in the spaCy code base and user code) and those
|
# Language.factory decorator (in the spaCy code base and user code) and those
|
||||||
# are the factories used to initialize components via registry.make_from_config.
|
# are the factories used to initialize components via registry.resolve.
|
||||||
_entry_point_factories = catalogue.create("spacy", "factories", entry_points=True)
|
_entry_point_factories = catalogue.create("spacy", "factories", entry_points=True)
|
||||||
factories = catalogue.create("spacy", "internal_factories")
|
factories = catalogue.create("spacy", "internal_factories")
|
||||||
# This is mostly used to get a list of all installed models in the current
|
# This is mostly used to get a list of all installed models in the current
|
||||||
|
@ -351,9 +351,7 @@ def load_model_from_path(
|
||||||
meta = get_model_meta(model_path)
|
meta = get_model_meta(model_path)
|
||||||
config_path = model_path / "config.cfg"
|
config_path = model_path / "config.cfg"
|
||||||
config = load_config(config_path, overrides=dict_to_dot(config))
|
config = load_config(config_path, overrides=dict_to_dot(config))
|
||||||
nlp, _ = load_model_from_config(
|
nlp = load_model_from_config(config, vocab=vocab, disable=disable, exclude=exclude)
|
||||||
config, vocab=vocab, disable=disable, exclude=exclude
|
|
||||||
)
|
|
||||||
return nlp.from_disk(model_path, exclude=exclude)
|
return nlp.from_disk(model_path, exclude=exclude)
|
||||||
|
|
||||||
|
|
||||||
|
@ -365,7 +363,7 @@ def load_model_from_config(
|
||||||
exclude: Iterable[str] = SimpleFrozenList(),
|
exclude: Iterable[str] = SimpleFrozenList(),
|
||||||
auto_fill: bool = False,
|
auto_fill: bool = False,
|
||||||
validate: bool = True,
|
validate: bool = True,
|
||||||
) -> Tuple["Language", Config]:
|
) -> "Language":
|
||||||
"""Create an nlp object from a config. Expects the full config file including
|
"""Create an nlp object from a config. Expects the full config file including
|
||||||
a section "nlp" containing the settings for the nlp object.
|
a section "nlp" containing the settings for the nlp object.
|
||||||
|
|
||||||
|
@ -398,7 +396,30 @@ def load_model_from_config(
|
||||||
auto_fill=auto_fill,
|
auto_fill=auto_fill,
|
||||||
validate=validate,
|
validate=validate,
|
||||||
)
|
)
|
||||||
return nlp, nlp.resolved
|
return nlp
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_training_config(
|
||||||
|
config: Config,
|
||||||
|
exclude: Iterable[str] = ("nlp", "components"),
|
||||||
|
validate: bool = True,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""Resolve the config sections relevant for trainig and create all objects.
|
||||||
|
Mostly used in the CLI to separate training config (not resolved by default
|
||||||
|
because not runtime-relevant – an nlp object should load fine even if it's
|
||||||
|
[training] block refers to functions that are not available etc.).
|
||||||
|
|
||||||
|
config (Config): The config to resolve.
|
||||||
|
exclude (Iterable[str]): The config blocks to exclude. Those blocks won't
|
||||||
|
be available in the final resolved config.
|
||||||
|
validate (bool): Whether to validate the config.
|
||||||
|
RETURNS (Dict[str, Any]): The resolved config.
|
||||||
|
"""
|
||||||
|
config = config.copy()
|
||||||
|
for key in exclude:
|
||||||
|
if key in config:
|
||||||
|
config.pop(key)
|
||||||
|
return registry.resolve(config, validate=validate)
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_init_py(
|
def load_model_from_init_py(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user