mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
Allow adding pipeline components from source model (#5857)
* Allow adding pipeline components from source model * Config: name -> component * Improve error messages * Fix error and test * Add frozen components and exclude logic * Remove exclude from Language.evaluate * Init sourced components with current vocab * Fix error codes
This commit is contained in:
parent
34873c4911
commit
b795f02fbd
|
@ -19,6 +19,7 @@ max_epochs = 0
|
|||
patience = 10000
|
||||
eval_frequency = 200
|
||||
score_weights = {"dep_las": 0.4, "ents_f": 0.4, "tag_acc": 0.2}
|
||||
frozen_components = []
|
||||
|
||||
[training.train_corpus]
|
||||
@readers = "spacy.Corpus.v1"
|
||||
|
|
|
@ -6,7 +6,7 @@ import hashlib
|
|||
import typer
|
||||
from typer.main import get_command
|
||||
from contextlib import contextmanager
|
||||
from thinc.config import ConfigValidationError
|
||||
from thinc.config import Config, ConfigValidationError
|
||||
from configparser import InterpolationError
|
||||
import sys
|
||||
|
||||
|
@ -217,3 +217,15 @@ def import_code(code_path: Optional[Union[Path, str]]) -> None:
|
|||
import_file("python_code", code_path)
|
||||
except Exception as e:
|
||||
msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)
|
||||
|
||||
|
||||
def get_sourced_components(config: Union[Dict[str, Any], Config]) -> List[str]:
|
||||
"""RETURNS (List[str]): All sourced components in the original config,
|
||||
e.g. {"source": "en_core_web_sm"}. If the config contains a key
|
||||
"factory", we assume it refers to a component factory.
|
||||
"""
|
||||
return [
|
||||
name
|
||||
for name, cfg in config.get("components", {}).items()
|
||||
if "factory" not in cfg and "source" in cfg
|
||||
]
|
||||
|
|
|
@ -8,7 +8,7 @@ import typer
|
|||
from thinc.api import Config
|
||||
|
||||
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
|
||||
from ._util import import_code, debug_cli
|
||||
from ._util import import_code, debug_cli, get_sourced_components
|
||||
from ..gold import Corpus, Example
|
||||
from ..pipeline._parser_internals import nonproj
|
||||
from ..language import Language
|
||||
|
@ -138,9 +138,10 @@ def debug_data(
|
|||
with show_validation_error(config_path):
|
||||
cfg = Config().from_disk(config_path)
|
||||
nlp, config = util.load_model_from_config(cfg, overrides=config_overrides)
|
||||
# TODO: handle base model
|
||||
lang = config["nlp"]["lang"]
|
||||
base_model = config["training"]["base_model"]
|
||||
# Use original config here, not resolved version
|
||||
sourced_components = get_sourced_components(cfg)
|
||||
frozen_components = config["training"]["frozen_components"]
|
||||
resume_components = [p for p in sourced_components if p not in frozen_components]
|
||||
pipeline = nlp.pipe_names
|
||||
factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names]
|
||||
tag_map_path = util.ensure_path(config["training"]["tag_map"])
|
||||
|
@ -187,13 +188,15 @@ def debug_data(
|
|||
|
||||
train_texts = gold_train_data["texts"]
|
||||
dev_texts = gold_dev_data["texts"]
|
||||
frozen_components = config["training"]["frozen_components"]
|
||||
|
||||
msg.divider("Training stats")
|
||||
msg.text(f"Language: {config['nlp']['lang']}")
|
||||
msg.text(f"Training pipeline: {', '.join(pipeline)}")
|
||||
if base_model:
|
||||
msg.text(f"Starting with base model '{base_model}'")
|
||||
else:
|
||||
msg.text(f"Starting with blank model '{lang}'")
|
||||
if resume_components:
|
||||
msg.text(f"Components from other models: {', '.join(resume_components)}")
|
||||
if frozen_components:
|
||||
msg.text(f"Frozen components: {', '.join(frozen_components)}")
|
||||
msg.text(f"{len(train_dataset)} training docs")
|
||||
msg.text(f"{len(dev_dataset)} evaluation docs")
|
||||
|
||||
|
@ -204,7 +207,9 @@ def debug_data(
|
|||
msg.warn(f"{overlap} training examples also in evaluation data")
|
||||
else:
|
||||
msg.good("No overlap between training and evaluation data")
|
||||
if not base_model and len(train_dataset) < BLANK_MODEL_THRESHOLD:
|
||||
# TODO: make this feedback more fine-grained and report on updated
|
||||
# components vs. blank components
|
||||
if not resume_components and len(train_dataset) < BLANK_MODEL_THRESHOLD:
|
||||
text = (
|
||||
f"Low number of examples to train from a blank model ({len(train_dataset)})"
|
||||
)
|
||||
|
|
|
@ -11,7 +11,7 @@ import random
|
|||
import typer
|
||||
|
||||
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
|
||||
from ._util import import_code
|
||||
from ._util import import_code, get_sourced_components
|
||||
from ..language import Language
|
||||
from .. import util
|
||||
from ..gold.example import Example
|
||||
|
@ -78,6 +78,8 @@ def train(
|
|||
config = Config().from_disk(config_path)
|
||||
if config.get("training", {}).get("seed") is not None:
|
||||
fix_random_seed(config["training"]["seed"])
|
||||
# Use original config here before it's resolved to functions
|
||||
sourced_components = get_sourced_components(config)
|
||||
with show_validation_error(config_path):
|
||||
nlp, config = util.load_model_from_config(config, overrides=config_overrides)
|
||||
if config["training"]["vectors"] is not None:
|
||||
|
@ -92,11 +94,16 @@ def train(
|
|||
train_corpus = T_cfg["train_corpus"]
|
||||
dev_corpus = T_cfg["dev_corpus"]
|
||||
batcher = T_cfg["batcher"]
|
||||
if resume_training:
|
||||
msg.info("Resuming training")
|
||||
# Components that shouldn't be updated during training
|
||||
frozen_components = T_cfg["frozen_components"]
|
||||
# Sourced components that require resume_training
|
||||
resume_components = [p for p in sourced_components if p not in frozen_components]
|
||||
msg.info(f"Pipeline: {nlp.pipe_names}")
|
||||
if resume_components:
|
||||
with nlp.select_pipes(enable=resume_components):
|
||||
msg.info(f"Resuming training for: {resume_components}")
|
||||
nlp.resume_training()
|
||||
else:
|
||||
msg.info(f"Initializing the nlp pipeline: {nlp.pipe_names}")
|
||||
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
|
||||
nlp.begin_training(lambda: train_corpus(nlp))
|
||||
|
||||
if tag_map:
|
||||
|
@ -136,7 +143,7 @@ def train(
|
|||
patience=T_cfg["patience"],
|
||||
max_steps=T_cfg["max_steps"],
|
||||
eval_frequency=T_cfg["eval_frequency"],
|
||||
raw_text=None
|
||||
raw_text=None,
|
||||
)
|
||||
msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
|
||||
print_row = setup_printer(T_cfg, nlp)
|
||||
|
@ -192,9 +199,7 @@ def create_train_batches(iterator, batcher, max_epochs: int):
|
|||
|
||||
|
||||
def create_evaluation_callback(
|
||||
nlp: Language,
|
||||
dev_corpus: Callable,
|
||||
weights: Dict[str, float],
|
||||
nlp: Language, dev_corpus: Callable, weights: Dict[str, float],
|
||||
) -> Callable[[], Tuple[float, Dict[str, float]]]:
|
||||
def evaluate() -> Tuple[float, Dict[str, float]]:
|
||||
dev_examples = list(dev_corpus(nlp))
|
||||
|
@ -223,6 +228,7 @@ def train_while_improving(
|
|||
patience: int,
|
||||
max_steps: int,
|
||||
raw_text: List[Dict[str, str]],
|
||||
exclude: List[str],
|
||||
):
|
||||
"""Train until an evaluation stops improving. Works as a generator,
|
||||
with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
|
||||
|
@ -268,8 +274,6 @@ def train_while_improving(
|
|||
dropouts = dropout
|
||||
results = []
|
||||
losses = {}
|
||||
to_enable = [name for name, proc in nlp.pipeline if hasattr(proc, "model")]
|
||||
|
||||
if raw_text:
|
||||
random.shuffle(raw_text)
|
||||
raw_examples = [
|
||||
|
@ -279,16 +283,18 @@ def train_while_improving(
|
|||
|
||||
for step, (epoch, batch) in enumerate(train_data):
|
||||
dropout = next(dropouts)
|
||||
with nlp.select_pipes(enable=to_enable):
|
||||
for subbatch in subdivide_batch(batch, accumulate_gradient):
|
||||
nlp.update(subbatch, drop=dropout, losses=losses, sgd=False)
|
||||
nlp.update(
|
||||
subbatch, drop=dropout, losses=losses, sgd=False, exclude=exclude
|
||||
)
|
||||
if raw_text:
|
||||
# If raw text is available, perform 'rehearsal' updates,
|
||||
# which use unlabelled data to reduce overfitting.
|
||||
raw_batch = list(next(raw_batches))
|
||||
nlp.rehearse(raw_batch, sgd=optimizer, losses=losses)
|
||||
nlp.rehearse(raw_batch, sgd=optimizer, losses=losses, exclude=exclude)
|
||||
# TODO: refactor this so we don't have to run it separately in here
|
||||
for name, proc in nlp.pipeline:
|
||||
if hasattr(proc, "model"):
|
||||
if name not in exclude and hasattr(proc, "model"):
|
||||
proc.model.finish_update(optimizer)
|
||||
optimizer.step_schedules()
|
||||
if not (step % eval_frequency):
|
||||
|
@ -418,10 +424,7 @@ def load_from_paths(
|
|||
return raw_text, tag_map, morph_rules, weights_data
|
||||
|
||||
|
||||
def verify_cli_args(
|
||||
config_path: Path,
|
||||
output_path: Optional[Path] = None,
|
||||
) -> None:
|
||||
def verify_cli_args(config_path: Path, output_path: Optional[Path] = None,) -> None:
|
||||
# Make sure all files and paths exists if they are needed
|
||||
if not config_path or not config_path.exists():
|
||||
msg.fail("Config file not found", config_path, exits=1)
|
||||
|
|
|
@ -37,6 +37,8 @@ max_steps = 20000
|
|||
eval_frequency = 200
|
||||
# Control how scores are printed and checkpoints are evaluated.
|
||||
score_weights = {}
|
||||
# Names of pipeline components that shouldn't be updated during training
|
||||
frozen_components = []
|
||||
|
||||
[training.train_corpus]
|
||||
@readers = "spacy.Corpus.v1"
|
||||
|
|
|
@ -482,6 +482,10 @@ class Errors:
|
|||
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
|
||||
|
||||
# TODO: fix numbering after merging develop into master
|
||||
E944 = ("Can't copy pipeline component '{name}' from source model '{model}': "
|
||||
"not found in pipeline. Available components: {opts}")
|
||||
E945 = ("Can't copy pipeline component '{name}' from source. Expected loaded "
|
||||
"nlp object, but got: {source}")
|
||||
E946 = ("The Vocab for the knowledge base is not initialized. Did you forget to "
|
||||
"call kb.initialize()?")
|
||||
E947 = ("Matcher.add received invalid 'greedy' argument: expected "
|
||||
|
@ -571,11 +575,13 @@ class Errors:
|
|||
"into {values}, but found {value}.")
|
||||
E983 = ("Invalid key for '{dict}': {key}. Available keys: "
|
||||
"{keys}")
|
||||
E984 = ("Invalid component config for '{name}': no 'factory' key "
|
||||
"specifying the registered function used to initialize the "
|
||||
"component. For example, factory = \"ner\" will use the 'ner' "
|
||||
"factory and all other settings in the block will be passed "
|
||||
"to it as arguments.\n\n{config}")
|
||||
E984 = ("Invalid component config for '{name}': component block needs either "
|
||||
"a key 'factory' specifying the registered function used to "
|
||||
"initialize the component, or a key 'source' key specifying a "
|
||||
"spaCy model to copy the component from. For example, factory = "
|
||||
"\"ner\" will use the 'ner' factory and all other settings in the "
|
||||
"block will be passed to it as arguments. Alternatively, source = "
|
||||
"\"en_core_web_sm\" will copy the component from that model.\n\n{config}")
|
||||
E985 = ("Can't load model from config file: no 'nlp' section found.\n\n{config}")
|
||||
E986 = ("Could not create any training batches: check your input. "
|
||||
"Perhaps discard_oversize should be set to False ?")
|
||||
|
|
|
@ -620,6 +620,32 @@ class Language:
|
|||
self._pipe_configs[name] = filled
|
||||
return resolved[factory_name]
|
||||
|
||||
def create_pipe_from_source(
|
||||
self, source_name: str, source: "Language", *, name: str,
|
||||
) -> Tuple[Callable[[Doc], Doc], str]:
|
||||
"""Create a pipeline component by copying it from an existing model.
|
||||
|
||||
source_name (str): Name of the component in the source pipeline.
|
||||
source (Language): The source nlp object to copy from.
|
||||
name (str): Optional alternative name to use in current pipeline.
|
||||
RETURNS (Tuple[Callable, str]): The component and its factory name.
|
||||
"""
|
||||
# TODO: handle errors and mismatches (vectors etc.)
|
||||
if not isinstance(source, self.__class__):
|
||||
raise ValueError(Errors.E945.format(name=source_name, source=type(source)))
|
||||
if not source.has_pipe(source_name):
|
||||
raise KeyError(
|
||||
Errors.E944.format(
|
||||
name=source_name,
|
||||
model=f"{source.meta['lang']}_{source.meta['name']}",
|
||||
opts=", ".join(source.pipe_names),
|
||||
)
|
||||
)
|
||||
pipe = source.get_pipe(source_name)
|
||||
pipe_config = util.copy_config(source.config["components"][source_name])
|
||||
self._pipe_configs[name] = pipe_config
|
||||
return pipe, pipe_config["factory"]
|
||||
|
||||
def add_pipe(
|
||||
self,
|
||||
factory_name: str,
|
||||
|
@ -629,6 +655,7 @@ class Language:
|
|||
after: Optional[Union[str, int]] = None,
|
||||
first: Optional[bool] = None,
|
||||
last: Optional[bool] = None,
|
||||
source: Optional["Language"] = None,
|
||||
config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
|
||||
overrides: Optional[Dict[str, Any]] = SimpleFrozenDict(),
|
||||
validate: bool = True,
|
||||
|
@ -648,6 +675,8 @@ class Language:
|
|||
component directly after.
|
||||
first (bool): If True, insert component first in the pipeline.
|
||||
last (bool): If True, insert component last in the pipeline.
|
||||
source (Language): Optional loaded nlp object to copy the pipeline
|
||||
component from.
|
||||
config (Optional[Dict[str, Any]]): Config parameters to use for this
|
||||
component. Will be merged with default config, if available.
|
||||
overrides (Optional[Dict[str, Any]]): Config overrides, typically
|
||||
|
@ -662,6 +691,16 @@ class Language:
|
|||
bad_val = repr(factory_name)
|
||||
err = Errors.E966.format(component=bad_val, name=name)
|
||||
raise ValueError(err)
|
||||
name = name if name is not None else factory_name
|
||||
if name in self.pipe_names:
|
||||
raise ValueError(Errors.E007.format(name=name, opts=self.pipe_names))
|
||||
if source is not None:
|
||||
# We're loading the component from a model. After loading the
|
||||
# component, we know its real factory name
|
||||
pipe_component, factory_name = self.create_pipe_from_source(
|
||||
factory_name, source, name=name
|
||||
)
|
||||
else:
|
||||
if not self.has_factory(factory_name):
|
||||
err = Errors.E002.format(
|
||||
name=factory_name,
|
||||
|
@ -670,9 +709,6 @@ class Language:
|
|||
lang=util.get_object_name(self),
|
||||
lang_code=self.lang,
|
||||
)
|
||||
name = name if name is not None else factory_name
|
||||
if name in self.pipe_names:
|
||||
raise ValueError(Errors.E007.format(name=name, opts=self.pipe_names))
|
||||
pipe_component = self.create_pipe(
|
||||
factory_name,
|
||||
name=name,
|
||||
|
@ -911,6 +947,7 @@ class Language:
|
|||
sgd: Optional[Optimizer] = None,
|
||||
losses: Optional[Dict[str, float]] = None,
|
||||
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
exclude: Iterable[str] = tuple(),
|
||||
):
|
||||
"""Update the models in the pipeline.
|
||||
|
||||
|
@ -921,6 +958,7 @@ class Language:
|
|||
losses (Dict[str, float]): Dictionary to update with the loss, keyed by component.
|
||||
component_cfg (Dict[str, Dict]): Config parameters for specific pipeline
|
||||
components, keyed by component name.
|
||||
exclude (Iterable[str]): Names of components that shouldn't be updated.
|
||||
RETURNS (Dict[str, float]): The updated losses dictionary
|
||||
|
||||
DOCS: https://spacy.io/api/language#update
|
||||
|
@ -953,12 +991,12 @@ class Language:
|
|||
component_cfg[name].setdefault("drop", drop)
|
||||
component_cfg[name].setdefault("set_annotations", False)
|
||||
for name, proc in self.pipeline:
|
||||
if not hasattr(proc, "update"):
|
||||
if name in exclude or not hasattr(proc, "update"):
|
||||
continue
|
||||
proc.update(examples, sgd=None, losses=losses, **component_cfg[name])
|
||||
if sgd not in (None, False):
|
||||
for name, proc in self.pipeline:
|
||||
if hasattr(proc, "model"):
|
||||
if name not in exclude and hasattr(proc, "model"):
|
||||
proc.model.finish_update(sgd)
|
||||
return losses
|
||||
|
||||
|
@ -969,6 +1007,7 @@ class Language:
|
|||
sgd: Optional[Optimizer] = None,
|
||||
losses: Optional[Dict[str, float]] = None,
|
||||
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||
exclude: Iterable[str] = tuple(),
|
||||
) -> Dict[str, float]:
|
||||
"""Make a "rehearsal" update to the models in the pipeline, to prevent
|
||||
forgetting. Rehearsal updates run an initial copy of the model over some
|
||||
|
@ -980,6 +1019,7 @@ class Language:
|
|||
sgd (Optional[Optimizer]): An optimizer.
|
||||
component_cfg (Dict[str, Dict]): Config parameters for specific pipeline
|
||||
components, keyed by component name.
|
||||
exclude (Iterable[str]): Names of components that shouldn't be updated.
|
||||
RETURNS (dict): Results from the update.
|
||||
|
||||
EXAMPLE:
|
||||
|
@ -1023,7 +1063,7 @@ class Language:
|
|||
get_grads.b1 = sgd.b1
|
||||
get_grads.b2 = sgd.b2
|
||||
for name, proc in pipes:
|
||||
if not hasattr(proc, "rehearse"):
|
||||
if name in exclude or not hasattr(proc, "rehearse"):
|
||||
continue
|
||||
grads = {}
|
||||
proc.rehearse(
|
||||
|
@ -1074,7 +1114,7 @@ class Language:
|
|||
return self._optimizer
|
||||
|
||||
def resume_training(
|
||||
self, *, sgd: Optional[Optimizer] = None, device: int = -1
|
||||
self, *, sgd: Optional[Optimizer] = None, device: int = -1,
|
||||
) -> Optimizer:
|
||||
"""Continue training a pretrained model.
|
||||
|
||||
|
@ -1373,6 +1413,7 @@ class Language:
|
|||
cls,
|
||||
config: Union[Dict[str, Any], Config] = {},
|
||||
*,
|
||||
vocab: Union[Vocab, bool] = True,
|
||||
disable: Iterable[str] = tuple(),
|
||||
overrides: Dict[str, Any] = {},
|
||||
auto_fill: bool = True,
|
||||
|
@ -1383,6 +1424,7 @@ class Language:
|
|||
the default config of the given language is used.
|
||||
|
||||
config (Dict[str, Any] / Config): The loaded config.
|
||||
vocab (Vocab): A Vocab object. If True, a vocab is created.
|
||||
disable (Iterable[str]): List of pipeline component names to disable.
|
||||
auto_fill (bool): Automatically fill in missing values in config based
|
||||
on defaults and function argument annotations.
|
||||
|
@ -1422,25 +1464,31 @@ class Language:
|
|||
create_tokenizer = resolved["nlp"]["tokenizer"]
|
||||
create_lemmatizer = resolved["nlp"]["lemmatizer"]
|
||||
nlp = cls(
|
||||
create_tokenizer=create_tokenizer, create_lemmatizer=create_lemmatizer,
|
||||
vocab=vocab,
|
||||
create_tokenizer=create_tokenizer,
|
||||
create_lemmatizer=create_lemmatizer,
|
||||
)
|
||||
# Note that we don't load vectors here, instead they get loaded explicitly
|
||||
# inside stuff like the spacy train function. If we loaded them here,
|
||||
# then we would load them twice at runtime: once when we make from config,
|
||||
# and then again when we load from disk.
|
||||
pipeline = config.get("components", {})
|
||||
# If components are loaded from a source (existing models), we cache
|
||||
# them here so they're only loaded once
|
||||
source_nlps = {}
|
||||
for pipe_name in config["nlp"]["pipeline"]:
|
||||
if pipe_name not in pipeline:
|
||||
opts = ", ".join(pipeline.keys())
|
||||
raise ValueError(Errors.E956.format(name=pipe_name, opts=opts))
|
||||
pipe_cfg = util.copy_config(pipeline[pipe_name])
|
||||
if pipe_name not in disable:
|
||||
if "factory" not in pipe_cfg:
|
||||
if "factory" not in pipe_cfg and "source" not in pipe_cfg:
|
||||
err = Errors.E984.format(name=pipe_name, config=pipe_cfg)
|
||||
raise ValueError(err)
|
||||
if "factory" in pipe_cfg:
|
||||
factory = pipe_cfg.pop("factory")
|
||||
# The pipe name (key in the config) here is the unique name of the
|
||||
# component, not necessarily the factory
|
||||
# The pipe name (key in the config) here is the unique name
|
||||
# of the component, not necessarily the factory
|
||||
nlp.add_pipe(
|
||||
factory,
|
||||
name=pipe_name,
|
||||
|
@ -1448,6 +1496,16 @@ class Language:
|
|||
overrides=pipe_overrides,
|
||||
validate=validate,
|
||||
)
|
||||
else:
|
||||
model = pipe_cfg["source"]
|
||||
if model not in source_nlps:
|
||||
# We only need the components here and we need to init
|
||||
# model with the same vocab as the current nlp object
|
||||
source_nlps[model] = util.load_model(
|
||||
model, vocab=nlp.vocab, disable=["vocab", "tokenizer"]
|
||||
)
|
||||
source_name = pipe_cfg.get("component", pipe_name)
|
||||
nlp.add_pipe(source_name, source=source_nlps[model], name=pipe_name)
|
||||
nlp.config = filled if auto_fill else config
|
||||
nlp.resolved = resolved
|
||||
return nlp
|
||||
|
|
|
@ -202,6 +202,7 @@ class ConfigSchemaTraining(BaseModel):
|
|||
init_tok2vec: Optional[StrictStr] = Field(..., title="Path to pretrained tok2vec weights")
|
||||
raw_text: Optional[StrictStr] = Field(default=None, title="Raw text")
|
||||
optimizer: Optimizer = Field(..., title="The optimizer to use")
|
||||
frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
|
|
|
@ -8,6 +8,8 @@ from thinc.api import Model, Linear
|
|||
from thinc.config import ConfigValidationError
|
||||
from pydantic import StrictInt, StrictStr
|
||||
|
||||
from ..util import make_tempdir
|
||||
|
||||
|
||||
def test_pipe_function_component():
|
||||
name = "test_component"
|
||||
|
@ -374,3 +376,65 @@ def test_language_factories_scores():
|
|||
cfg = nlp.config["training"]
|
||||
expected_weights = {"a1": 0.25, "a2": 0.25, "b1": 0.1, "b2": 0.35, "b3": 0.05}
|
||||
assert cfg["score_weights"] == expected_weights
|
||||
|
||||
|
||||
def test_pipe_factories_from_source():
|
||||
"""Test adding components from a source model."""
|
||||
source_nlp = English()
|
||||
source_nlp.add_pipe("tagger", name="my_tagger")
|
||||
nlp = English()
|
||||
with pytest.raises(ValueError):
|
||||
nlp.add_pipe("my_tagger", source="en_core_web_sm")
|
||||
nlp.add_pipe("my_tagger", source=source_nlp)
|
||||
assert "my_tagger" in nlp.pipe_names
|
||||
with pytest.raises(KeyError):
|
||||
nlp.add_pipe("custom", source=source_nlp)
|
||||
|
||||
|
||||
def test_pipe_factories_from_source_custom():
|
||||
"""Test adding components from a source model with custom components."""
|
||||
name = "test_pipe_factories_from_source_custom"
|
||||
|
||||
@Language.factory(name, default_config={"arg": "hello"})
|
||||
def test_factory(nlp, name, arg: str):
|
||||
return lambda doc: doc
|
||||
|
||||
source_nlp = English()
|
||||
source_nlp.add_pipe("tagger")
|
||||
source_nlp.add_pipe(name, config={"arg": "world"})
|
||||
nlp = English()
|
||||
nlp.add_pipe(name, source=source_nlp)
|
||||
assert name in nlp.pipe_names
|
||||
assert nlp.get_pipe_meta(name).default_config["arg"] == "hello"
|
||||
config = nlp.config["components"][name]
|
||||
assert config["factory"] == name
|
||||
assert config["arg"] == "world"
|
||||
|
||||
|
||||
def test_pipe_factories_from_source_config():
|
||||
name = "test_pipe_factories_from_source_config"
|
||||
|
||||
@Language.factory(name, default_config={"arg": "hello"})
|
||||
def test_factory(nlp, name, arg: str):
|
||||
return lambda doc: doc
|
||||
|
||||
source_nlp = English()
|
||||
source_nlp.add_pipe("tagger")
|
||||
source_nlp.add_pipe(name, name="yolo", config={"arg": "world"})
|
||||
dest_nlp_cfg = {"lang": "en", "pipeline": ["parser", "custom"]}
|
||||
with make_tempdir() as tempdir:
|
||||
source_nlp.to_disk(tempdir)
|
||||
dest_components_cfg = {
|
||||
"parser": {"factory": "parser"},
|
||||
"custom": {"source": str(tempdir), "component": "yolo"},
|
||||
}
|
||||
dest_config = {"nlp": dest_nlp_cfg, "components": dest_components_cfg}
|
||||
nlp = English.from_config(dest_config)
|
||||
assert nlp.pipe_names == ["parser", "custom"]
|
||||
assert nlp.pipe_factories == {"parser": "parser", "custom": name}
|
||||
meta = nlp.get_pipe_meta("custom")
|
||||
assert meta.factory == name
|
||||
assert meta.default_config["arg"] == "hello"
|
||||
config = nlp.config["components"]["custom"]
|
||||
assert config["factory"] == name
|
||||
assert config["arg"] == "world"
|
||||
|
|
|
@ -205,43 +205,51 @@ def load_vectors_into_model(
|
|||
|
||||
def load_model(
|
||||
name: Union[str, Path],
|
||||
*,
|
||||
vocab: Union["Vocab", bool] = True,
|
||||
disable: Iterable[str] = tuple(),
|
||||
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
|
||||
) -> "Language":
|
||||
"""Load a model from a package or data path.
|
||||
|
||||
name (str): Package name or model path.
|
||||
vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
|
||||
a new Vocab object will be created.
|
||||
disable (Iterable[str]): Names of pipeline components to disable.
|
||||
component_cfg (Dict[str, dict]): Config overrides for pipeline components,
|
||||
keyed by component names.
|
||||
RETURNS (Language): The loaded nlp object.
|
||||
"""
|
||||
cfg = component_cfg
|
||||
kwargs = {"vocab": vocab, "disable": disable, "component_cfg": component_cfg}
|
||||
if isinstance(name, str): # name or string path
|
||||
if name.startswith("blank:"): # shortcut for blank model
|
||||
return get_lang_class(name.replace("blank:", ""))()
|
||||
if is_package(name): # installed as package
|
||||
return load_model_from_package(name, disable=disable, component_cfg=cfg)
|
||||
return load_model_from_package(name, **kwargs)
|
||||
if Path(name).exists(): # path to model data directory
|
||||
return load_model_from_path(Path(name), disable=disable, component_cfg=cfg)
|
||||
return load_model_from_path(Path(name), **kwargs)
|
||||
elif hasattr(name, "exists"): # Path or Path-like to model data
|
||||
return load_model_from_path(name, disable=disable, component_cfg=cfg)
|
||||
return load_model_from_path(name, **kwargs)
|
||||
raise IOError(Errors.E050.format(name=name))
|
||||
|
||||
|
||||
def load_model_from_package(
|
||||
name: str,
|
||||
*,
|
||||
vocab: Union["Vocab", bool] = True,
|
||||
disable: Iterable[str] = tuple(),
|
||||
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
|
||||
) -> "Language":
|
||||
"""Load a model from an installed package."""
|
||||
cls = importlib.import_module(name)
|
||||
return cls.load(disable=disable, component_cfg=component_cfg)
|
||||
return cls.load(vocab=vocab, disable=disable, component_cfg=component_cfg)
|
||||
|
||||
|
||||
def load_model_from_path(
|
||||
model_path: Union[str, Path],
|
||||
*,
|
||||
meta: Optional[Dict[str, Any]] = None,
|
||||
vocab: Union["Vocab", bool] = True,
|
||||
disable: Iterable[str] = tuple(),
|
||||
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
|
||||
) -> "Language":
|
||||
|
@ -257,12 +265,16 @@ def load_model_from_path(
|
|||
config = Config().from_disk(config_path)
|
||||
override_cfg = {"components": {p: dict_to_dot(c) for p, c in component_cfg.items()}}
|
||||
overrides = dict_to_dot(override_cfg)
|
||||
nlp, _ = load_model_from_config(config, disable=disable, overrides=overrides)
|
||||
nlp, _ = load_model_from_config(
|
||||
config, vocab=vocab, disable=disable, overrides=overrides
|
||||
)
|
||||
return nlp.from_disk(model_path, exclude=disable)
|
||||
|
||||
|
||||
def load_model_from_config(
|
||||
config: Union[Dict[str, Any], Config],
|
||||
*,
|
||||
vocab: Union["Vocab", bool] = True,
|
||||
disable: Iterable[str] = tuple(),
|
||||
overrides: Dict[str, Any] = {},
|
||||
auto_fill: bool = False,
|
||||
|
@ -281,6 +293,7 @@ def load_model_from_config(
|
|||
lang_cls = get_lang_class(nlp_config["lang"])
|
||||
nlp = lang_cls.from_config(
|
||||
config,
|
||||
vocab=vocab,
|
||||
disable=disable,
|
||||
overrides=overrides,
|
||||
auto_fill=auto_fill,
|
||||
|
@ -291,6 +304,8 @@ def load_model_from_config(
|
|||
|
||||
def load_model_from_init_py(
|
||||
init_file: Union[Path, str],
|
||||
*,
|
||||
vocab: Union["Vocab", bool] = True,
|
||||
disable: Iterable[str] = tuple(),
|
||||
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
|
||||
) -> "Language":
|
||||
|
@ -308,7 +323,7 @@ def load_model_from_init_py(
|
|||
if not model_path.exists():
|
||||
raise IOError(Errors.E052.format(path=data_path))
|
||||
return load_model_from_path(
|
||||
data_path, meta, disable=disable, component_cfg=component_cfg
|
||||
data_path, vocab=vocab, meta=meta, disable=disable, component_cfg=component_cfg
|
||||
)
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user