Allow adding pipeline components from source model (#5857)

* Allow adding pipeline components from source model

* Config: name -> component

* Improve error messages

* Fix error and test

* Add frozen components and exclude logic

* Remove exclude from Language.evaluate

* Init sourced components with current vocab

* Fix error codes
This commit is contained in:
Ines Montani 2020-08-04 23:39:19 +02:00 committed by GitHub
parent 34873c4911
commit b795f02fbd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 247 additions and 80 deletions

View File

@ -19,6 +19,7 @@ max_epochs = 0
patience = 10000 patience = 10000
eval_frequency = 200 eval_frequency = 200
score_weights = {"dep_las": 0.4, "ents_f": 0.4, "tag_acc": 0.2} score_weights = {"dep_las": 0.4, "ents_f": 0.4, "tag_acc": 0.2}
frozen_components = []
[training.train_corpus] [training.train_corpus]
@readers = "spacy.Corpus.v1" @readers = "spacy.Corpus.v1"

View File

@ -6,7 +6,7 @@ import hashlib
import typer import typer
from typer.main import get_command from typer.main import get_command
from contextlib import contextmanager from contextlib import contextmanager
from thinc.config import ConfigValidationError from thinc.config import Config, ConfigValidationError
from configparser import InterpolationError from configparser import InterpolationError
import sys import sys
@ -217,3 +217,15 @@ def import_code(code_path: Optional[Union[Path, str]]) -> None:
import_file("python_code", code_path) import_file("python_code", code_path)
except Exception as e: except Exception as e:
msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1) msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)
def get_sourced_components(config: Union[Dict[str, Any], Config]) -> List[str]:
"""RETURNS (List[str]): All sourced components in the original config,
e.g. {"source": "en_core_web_sm"}. If the config contains a key
"factory", we assume it refers to a component factory.
"""
return [
name
for name, cfg in config.get("components", {}).items()
if "factory" not in cfg and "source" in cfg
]

View File

@ -8,7 +8,7 @@ import typer
from thinc.api import Config from thinc.api import Config
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
from ._util import import_code, debug_cli from ._util import import_code, debug_cli, get_sourced_components
from ..gold import Corpus, Example from ..gold import Corpus, Example
from ..pipeline._parser_internals import nonproj from ..pipeline._parser_internals import nonproj
from ..language import Language from ..language import Language
@ -138,9 +138,10 @@ def debug_data(
with show_validation_error(config_path): with show_validation_error(config_path):
cfg = Config().from_disk(config_path) cfg = Config().from_disk(config_path)
nlp, config = util.load_model_from_config(cfg, overrides=config_overrides) nlp, config = util.load_model_from_config(cfg, overrides=config_overrides)
# TODO: handle base model # Use original config here, not resolved version
lang = config["nlp"]["lang"] sourced_components = get_sourced_components(cfg)
base_model = config["training"]["base_model"] frozen_components = config["training"]["frozen_components"]
resume_components = [p for p in sourced_components if p not in frozen_components]
pipeline = nlp.pipe_names pipeline = nlp.pipe_names
factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names] factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names]
tag_map_path = util.ensure_path(config["training"]["tag_map"]) tag_map_path = util.ensure_path(config["training"]["tag_map"])
@ -187,13 +188,15 @@ def debug_data(
train_texts = gold_train_data["texts"] train_texts = gold_train_data["texts"]
dev_texts = gold_dev_data["texts"] dev_texts = gold_dev_data["texts"]
frozen_components = config["training"]["frozen_components"]
msg.divider("Training stats") msg.divider("Training stats")
msg.text(f"Language: {config['nlp']['lang']}")
msg.text(f"Training pipeline: {', '.join(pipeline)}") msg.text(f"Training pipeline: {', '.join(pipeline)}")
if base_model: if resume_components:
msg.text(f"Starting with base model '{base_model}'") msg.text(f"Components from other models: {', '.join(resume_components)}")
else: if frozen_components:
msg.text(f"Starting with blank model '{lang}'") msg.text(f"Frozen components: {', '.join(frozen_components)}")
msg.text(f"{len(train_dataset)} training docs") msg.text(f"{len(train_dataset)} training docs")
msg.text(f"{len(dev_dataset)} evaluation docs") msg.text(f"{len(dev_dataset)} evaluation docs")
@ -204,7 +207,9 @@ def debug_data(
msg.warn(f"{overlap} training examples also in evaluation data") msg.warn(f"{overlap} training examples also in evaluation data")
else: else:
msg.good("No overlap between training and evaluation data") msg.good("No overlap between training and evaluation data")
if not base_model and len(train_dataset) < BLANK_MODEL_THRESHOLD: # TODO: make this feedback more fine-grained and report on updated
# components vs. blank components
if not resume_components and len(train_dataset) < BLANK_MODEL_THRESHOLD:
text = ( text = (
f"Low number of examples to train from a blank model ({len(train_dataset)})" f"Low number of examples to train from a blank model ({len(train_dataset)})"
) )

View File

@ -11,7 +11,7 @@ import random
import typer import typer
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
from ._util import import_code from ._util import import_code, get_sourced_components
from ..language import Language from ..language import Language
from .. import util from .. import util
from ..gold.example import Example from ..gold.example import Example
@ -78,6 +78,8 @@ def train(
config = Config().from_disk(config_path) config = Config().from_disk(config_path)
if config.get("training", {}).get("seed") is not None: if config.get("training", {}).get("seed") is not None:
fix_random_seed(config["training"]["seed"]) fix_random_seed(config["training"]["seed"])
# Use original config here before it's resolved to functions
sourced_components = get_sourced_components(config)
with show_validation_error(config_path): with show_validation_error(config_path):
nlp, config = util.load_model_from_config(config, overrides=config_overrides) nlp, config = util.load_model_from_config(config, overrides=config_overrides)
if config["training"]["vectors"] is not None: if config["training"]["vectors"] is not None:
@ -92,11 +94,16 @@ def train(
train_corpus = T_cfg["train_corpus"] train_corpus = T_cfg["train_corpus"]
dev_corpus = T_cfg["dev_corpus"] dev_corpus = T_cfg["dev_corpus"]
batcher = T_cfg["batcher"] batcher = T_cfg["batcher"]
if resume_training: # Components that shouldn't be updated during training
msg.info("Resuming training") frozen_components = T_cfg["frozen_components"]
nlp.resume_training() # Sourced components that require resume_training
else: resume_components = [p for p in sourced_components if p not in frozen_components]
msg.info(f"Initializing the nlp pipeline: {nlp.pipe_names}") msg.info(f"Pipeline: {nlp.pipe_names}")
if resume_components:
with nlp.select_pipes(enable=resume_components):
msg.info(f"Resuming training for: {resume_components}")
nlp.resume_training()
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
nlp.begin_training(lambda: train_corpus(nlp)) nlp.begin_training(lambda: train_corpus(nlp))
if tag_map: if tag_map:
@ -136,7 +143,7 @@ def train(
patience=T_cfg["patience"], patience=T_cfg["patience"],
max_steps=T_cfg["max_steps"], max_steps=T_cfg["max_steps"],
eval_frequency=T_cfg["eval_frequency"], eval_frequency=T_cfg["eval_frequency"],
raw_text=None raw_text=None,
) )
msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}") msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
print_row = setup_printer(T_cfg, nlp) print_row = setup_printer(T_cfg, nlp)
@ -192,9 +199,7 @@ def create_train_batches(iterator, batcher, max_epochs: int):
def create_evaluation_callback( def create_evaluation_callback(
nlp: Language, nlp: Language, dev_corpus: Callable, weights: Dict[str, float],
dev_corpus: Callable,
weights: Dict[str, float],
) -> Callable[[], Tuple[float, Dict[str, float]]]: ) -> Callable[[], Tuple[float, Dict[str, float]]]:
def evaluate() -> Tuple[float, Dict[str, float]]: def evaluate() -> Tuple[float, Dict[str, float]]:
dev_examples = list(dev_corpus(nlp)) dev_examples = list(dev_corpus(nlp))
@ -223,6 +228,7 @@ def train_while_improving(
patience: int, patience: int,
max_steps: int, max_steps: int,
raw_text: List[Dict[str, str]], raw_text: List[Dict[str, str]],
exclude: List[str],
): ):
"""Train until an evaluation stops improving. Works as a generator, """Train until an evaluation stops improving. Works as a generator,
with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`, with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
@ -268,8 +274,6 @@ def train_while_improving(
dropouts = dropout dropouts = dropout
results = [] results = []
losses = {} losses = {}
to_enable = [name for name, proc in nlp.pipeline if hasattr(proc, "model")]
if raw_text: if raw_text:
random.shuffle(raw_text) random.shuffle(raw_text)
raw_examples = [ raw_examples = [
@ -279,17 +283,19 @@ def train_while_improving(
for step, (epoch, batch) in enumerate(train_data): for step, (epoch, batch) in enumerate(train_data):
dropout = next(dropouts) dropout = next(dropouts)
with nlp.select_pipes(enable=to_enable): for subbatch in subdivide_batch(batch, accumulate_gradient):
for subbatch in subdivide_batch(batch, accumulate_gradient): nlp.update(
nlp.update(subbatch, drop=dropout, losses=losses, sgd=False) subbatch, drop=dropout, losses=losses, sgd=False, exclude=exclude
if raw_text: )
# If raw text is available, perform 'rehearsal' updates, if raw_text:
# which use unlabelled data to reduce overfitting. # If raw text is available, perform 'rehearsal' updates,
raw_batch = list(next(raw_batches)) # which use unlabelled data to reduce overfitting.
nlp.rehearse(raw_batch, sgd=optimizer, losses=losses) raw_batch = list(next(raw_batches))
for name, proc in nlp.pipeline: nlp.rehearse(raw_batch, sgd=optimizer, losses=losses, exclude=exclude)
if hasattr(proc, "model"): # TODO: refactor this so we don't have to run it separately in here
proc.model.finish_update(optimizer) for name, proc in nlp.pipeline:
if name not in exclude and hasattr(proc, "model"):
proc.model.finish_update(optimizer)
optimizer.step_schedules() optimizer.step_schedules()
if not (step % eval_frequency): if not (step % eval_frequency):
if optimizer.averages: if optimizer.averages:
@ -418,10 +424,7 @@ def load_from_paths(
return raw_text, tag_map, morph_rules, weights_data return raw_text, tag_map, morph_rules, weights_data
def verify_cli_args( def verify_cli_args(config_path: Path, output_path: Optional[Path] = None,) -> None:
config_path: Path,
output_path: Optional[Path] = None,
) -> None:
# Make sure all files and paths exists if they are needed # Make sure all files and paths exists if they are needed
if not config_path or not config_path.exists(): if not config_path or not config_path.exists():
msg.fail("Config file not found", config_path, exits=1) msg.fail("Config file not found", config_path, exits=1)

View File

@ -37,6 +37,8 @@ max_steps = 20000
eval_frequency = 200 eval_frequency = 200
# Control how scores are printed and checkpoints are evaluated. # Control how scores are printed and checkpoints are evaluated.
score_weights = {} score_weights = {}
# Names of pipeline components that shouldn't be updated during training
frozen_components = []
[training.train_corpus] [training.train_corpus]
@readers = "spacy.Corpus.v1" @readers = "spacy.Corpus.v1"

View File

@ -482,6 +482,10 @@ class Errors:
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].") E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
# TODO: fix numbering after merging develop into master # TODO: fix numbering after merging develop into master
E944 = ("Can't copy pipeline component '{name}' from source model '{model}': "
"not found in pipeline. Available components: {opts}")
E945 = ("Can't copy pipeline component '{name}' from source. Expected loaded "
"nlp object, but got: {source}")
E946 = ("The Vocab for the knowledge base is not initialized. Did you forget to " E946 = ("The Vocab for the knowledge base is not initialized. Did you forget to "
"call kb.initialize()?") "call kb.initialize()?")
E947 = ("Matcher.add received invalid 'greedy' argument: expected " E947 = ("Matcher.add received invalid 'greedy' argument: expected "
@ -571,11 +575,13 @@ class Errors:
"into {values}, but found {value}.") "into {values}, but found {value}.")
E983 = ("Invalid key for '{dict}': {key}. Available keys: " E983 = ("Invalid key for '{dict}': {key}. Available keys: "
"{keys}") "{keys}")
E984 = ("Invalid component config for '{name}': no 'factory' key " E984 = ("Invalid component config for '{name}': component block needs either "
"specifying the registered function used to initialize the " "a key 'factory' specifying the registered function used to "
"component. For example, factory = \"ner\" will use the 'ner' " "initialize the component, or a key 'source' key specifying a "
"factory and all other settings in the block will be passed " "spaCy model to copy the component from. For example, factory = "
"to it as arguments.\n\n{config}") "\"ner\" will use the 'ner' factory and all other settings in the "
"block will be passed to it as arguments. Alternatively, source = "
"\"en_core_web_sm\" will copy the component from that model.\n\n{config}")
E985 = ("Can't load model from config file: no 'nlp' section found.\n\n{config}") E985 = ("Can't load model from config file: no 'nlp' section found.\n\n{config}")
E986 = ("Could not create any training batches: check your input. " E986 = ("Could not create any training batches: check your input. "
"Perhaps discard_oversize should be set to False ?") "Perhaps discard_oversize should be set to False ?")

View File

@ -620,6 +620,32 @@ class Language:
self._pipe_configs[name] = filled self._pipe_configs[name] = filled
return resolved[factory_name] return resolved[factory_name]
def create_pipe_from_source(
self, source_name: str, source: "Language", *, name: str,
) -> Tuple[Callable[[Doc], Doc], str]:
"""Create a pipeline component by copying it from an existing model.
source_name (str): Name of the component in the source pipeline.
source (Language): The source nlp object to copy from.
name (str): Optional alternative name to use in current pipeline.
RETURNS (Tuple[Callable, str]): The component and its factory name.
"""
# TODO: handle errors and mismatches (vectors etc.)
if not isinstance(source, self.__class__):
raise ValueError(Errors.E945.format(name=source_name, source=type(source)))
if not source.has_pipe(source_name):
raise KeyError(
Errors.E944.format(
name=source_name,
model=f"{source.meta['lang']}_{source.meta['name']}",
opts=", ".join(source.pipe_names),
)
)
pipe = source.get_pipe(source_name)
pipe_config = util.copy_config(source.config["components"][source_name])
self._pipe_configs[name] = pipe_config
return pipe, pipe_config["factory"]
def add_pipe( def add_pipe(
self, self,
factory_name: str, factory_name: str,
@ -629,6 +655,7 @@ class Language:
after: Optional[Union[str, int]] = None, after: Optional[Union[str, int]] = None,
first: Optional[bool] = None, first: Optional[bool] = None,
last: Optional[bool] = None, last: Optional[bool] = None,
source: Optional["Language"] = None,
config: Optional[Dict[str, Any]] = SimpleFrozenDict(), config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
overrides: Optional[Dict[str, Any]] = SimpleFrozenDict(), overrides: Optional[Dict[str, Any]] = SimpleFrozenDict(),
validate: bool = True, validate: bool = True,
@ -648,6 +675,8 @@ class Language:
component directly after. component directly after.
first (bool): If True, insert component first in the pipeline. first (bool): If True, insert component first in the pipeline.
last (bool): If True, insert component last in the pipeline. last (bool): If True, insert component last in the pipeline.
source (Language): Optional loaded nlp object to copy the pipeline
component from.
config (Optional[Dict[str, Any]]): Config parameters to use for this config (Optional[Dict[str, Any]]): Config parameters to use for this
component. Will be merged with default config, if available. component. Will be merged with default config, if available.
overrides (Optional[Dict[str, Any]]): Config overrides, typically overrides (Optional[Dict[str, Any]]): Config overrides, typically
@ -662,24 +691,31 @@ class Language:
bad_val = repr(factory_name) bad_val = repr(factory_name)
err = Errors.E966.format(component=bad_val, name=name) err = Errors.E966.format(component=bad_val, name=name)
raise ValueError(err) raise ValueError(err)
if not self.has_factory(factory_name):
err = Errors.E002.format(
name=factory_name,
opts=", ".join(self.factory_names),
method="add_pipe",
lang=util.get_object_name(self),
lang_code=self.lang,
)
name = name if name is not None else factory_name name = name if name is not None else factory_name
if name in self.pipe_names: if name in self.pipe_names:
raise ValueError(Errors.E007.format(name=name, opts=self.pipe_names)) raise ValueError(Errors.E007.format(name=name, opts=self.pipe_names))
pipe_component = self.create_pipe( if source is not None:
factory_name, # We're loading the component from a model. After loading the
name=name, # component, we know its real factory name
config=config, pipe_component, factory_name = self.create_pipe_from_source(
overrides=overrides, factory_name, source, name=name
validate=validate, )
) else:
if not self.has_factory(factory_name):
err = Errors.E002.format(
name=factory_name,
opts=", ".join(self.factory_names),
method="add_pipe",
lang=util.get_object_name(self),
lang_code=self.lang,
)
pipe_component = self.create_pipe(
factory_name,
name=name,
config=config,
overrides=overrides,
validate=validate,
)
pipe_index = self._get_pipe_index(before, after, first, last) pipe_index = self._get_pipe_index(before, after, first, last)
self._pipe_meta[name] = self.get_factory_meta(factory_name) self._pipe_meta[name] = self.get_factory_meta(factory_name)
self.pipeline.insert(pipe_index, (name, pipe_component)) self.pipeline.insert(pipe_index, (name, pipe_component))
@ -911,6 +947,7 @@ class Language:
sgd: Optional[Optimizer] = None, sgd: Optional[Optimizer] = None,
losses: Optional[Dict[str, float]] = None, losses: Optional[Dict[str, float]] = None,
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None, component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
exclude: Iterable[str] = tuple(),
): ):
"""Update the models in the pipeline. """Update the models in the pipeline.
@ -921,6 +958,7 @@ class Language:
losses (Dict[str, float]): Dictionary to update with the loss, keyed by component. losses (Dict[str, float]): Dictionary to update with the loss, keyed by component.
component_cfg (Dict[str, Dict]): Config parameters for specific pipeline component_cfg (Dict[str, Dict]): Config parameters for specific pipeline
components, keyed by component name. components, keyed by component name.
exclude (Iterable[str]): Names of components that shouldn't be updated.
RETURNS (Dict[str, float]): The updated losses dictionary RETURNS (Dict[str, float]): The updated losses dictionary
DOCS: https://spacy.io/api/language#update DOCS: https://spacy.io/api/language#update
@ -953,12 +991,12 @@ class Language:
component_cfg[name].setdefault("drop", drop) component_cfg[name].setdefault("drop", drop)
component_cfg[name].setdefault("set_annotations", False) component_cfg[name].setdefault("set_annotations", False)
for name, proc in self.pipeline: for name, proc in self.pipeline:
if not hasattr(proc, "update"): if name in exclude or not hasattr(proc, "update"):
continue continue
proc.update(examples, sgd=None, losses=losses, **component_cfg[name]) proc.update(examples, sgd=None, losses=losses, **component_cfg[name])
if sgd not in (None, False): if sgd not in (None, False):
for name, proc in self.pipeline: for name, proc in self.pipeline:
if hasattr(proc, "model"): if name not in exclude and hasattr(proc, "model"):
proc.model.finish_update(sgd) proc.model.finish_update(sgd)
return losses return losses
@ -969,6 +1007,7 @@ class Language:
sgd: Optional[Optimizer] = None, sgd: Optional[Optimizer] = None,
losses: Optional[Dict[str, float]] = None, losses: Optional[Dict[str, float]] = None,
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None, component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
exclude: Iterable[str] = tuple(),
) -> Dict[str, float]: ) -> Dict[str, float]:
"""Make a "rehearsal" update to the models in the pipeline, to prevent """Make a "rehearsal" update to the models in the pipeline, to prevent
forgetting. Rehearsal updates run an initial copy of the model over some forgetting. Rehearsal updates run an initial copy of the model over some
@ -980,6 +1019,7 @@ class Language:
sgd (Optional[Optimizer]): An optimizer. sgd (Optional[Optimizer]): An optimizer.
component_cfg (Dict[str, Dict]): Config parameters for specific pipeline component_cfg (Dict[str, Dict]): Config parameters for specific pipeline
components, keyed by component name. components, keyed by component name.
exclude (Iterable[str]): Names of components that shouldn't be updated.
RETURNS (dict): Results from the update. RETURNS (dict): Results from the update.
EXAMPLE: EXAMPLE:
@ -1023,7 +1063,7 @@ class Language:
get_grads.b1 = sgd.b1 get_grads.b1 = sgd.b1
get_grads.b2 = sgd.b2 get_grads.b2 = sgd.b2
for name, proc in pipes: for name, proc in pipes:
if not hasattr(proc, "rehearse"): if name in exclude or not hasattr(proc, "rehearse"):
continue continue
grads = {} grads = {}
proc.rehearse( proc.rehearse(
@ -1074,7 +1114,7 @@ class Language:
return self._optimizer return self._optimizer
def resume_training( def resume_training(
self, *, sgd: Optional[Optimizer] = None, device: int = -1 self, *, sgd: Optional[Optimizer] = None, device: int = -1,
) -> Optimizer: ) -> Optimizer:
"""Continue training a pretrained model. """Continue training a pretrained model.
@ -1373,6 +1413,7 @@ class Language:
cls, cls,
config: Union[Dict[str, Any], Config] = {}, config: Union[Dict[str, Any], Config] = {},
*, *,
vocab: Union[Vocab, bool] = True,
disable: Iterable[str] = tuple(), disable: Iterable[str] = tuple(),
overrides: Dict[str, Any] = {}, overrides: Dict[str, Any] = {},
auto_fill: bool = True, auto_fill: bool = True,
@ -1383,6 +1424,7 @@ class Language:
the default config of the given language is used. the default config of the given language is used.
config (Dict[str, Any] / Config): The loaded config. config (Dict[str, Any] / Config): The loaded config.
vocab (Vocab): A Vocab object. If True, a vocab is created.
disable (Iterable[str]): List of pipeline component names to disable. disable (Iterable[str]): List of pipeline component names to disable.
auto_fill (bool): Automatically fill in missing values in config based auto_fill (bool): Automatically fill in missing values in config based
on defaults and function argument annotations. on defaults and function argument annotations.
@ -1422,32 +1464,48 @@ class Language:
create_tokenizer = resolved["nlp"]["tokenizer"] create_tokenizer = resolved["nlp"]["tokenizer"]
create_lemmatizer = resolved["nlp"]["lemmatizer"] create_lemmatizer = resolved["nlp"]["lemmatizer"]
nlp = cls( nlp = cls(
create_tokenizer=create_tokenizer, create_lemmatizer=create_lemmatizer, vocab=vocab,
create_tokenizer=create_tokenizer,
create_lemmatizer=create_lemmatizer,
) )
# Note that we don't load vectors here, instead they get loaded explicitly # Note that we don't load vectors here, instead they get loaded explicitly
# inside stuff like the spacy train function. If we loaded them here, # inside stuff like the spacy train function. If we loaded them here,
# then we would load them twice at runtime: once when we make from config, # then we would load them twice at runtime: once when we make from config,
# and then again when we load from disk. # and then again when we load from disk.
pipeline = config.get("components", {}) pipeline = config.get("components", {})
# If components are loaded from a source (existing models), we cache
# them here so they're only loaded once
source_nlps = {}
for pipe_name in config["nlp"]["pipeline"]: for pipe_name in config["nlp"]["pipeline"]:
if pipe_name not in pipeline: if pipe_name not in pipeline:
opts = ", ".join(pipeline.keys()) opts = ", ".join(pipeline.keys())
raise ValueError(Errors.E956.format(name=pipe_name, opts=opts)) raise ValueError(Errors.E956.format(name=pipe_name, opts=opts))
pipe_cfg = util.copy_config(pipeline[pipe_name]) pipe_cfg = util.copy_config(pipeline[pipe_name])
if pipe_name not in disable: if pipe_name not in disable:
if "factory" not in pipe_cfg: if "factory" not in pipe_cfg and "source" not in pipe_cfg:
err = Errors.E984.format(name=pipe_name, config=pipe_cfg) err = Errors.E984.format(name=pipe_name, config=pipe_cfg)
raise ValueError(err) raise ValueError(err)
factory = pipe_cfg.pop("factory") if "factory" in pipe_cfg:
# The pipe name (key in the config) here is the unique name of the factory = pipe_cfg.pop("factory")
# component, not necessarily the factory # The pipe name (key in the config) here is the unique name
nlp.add_pipe( # of the component, not necessarily the factory
factory, nlp.add_pipe(
name=pipe_name, factory,
config=pipe_cfg, name=pipe_name,
overrides=pipe_overrides, config=pipe_cfg,
validate=validate, overrides=pipe_overrides,
) validate=validate,
)
else:
model = pipe_cfg["source"]
if model not in source_nlps:
# We only need the components here and we need to init
# model with the same vocab as the current nlp object
source_nlps[model] = util.load_model(
model, vocab=nlp.vocab, disable=["vocab", "tokenizer"]
)
source_name = pipe_cfg.get("component", pipe_name)
nlp.add_pipe(source_name, source=source_nlps[model], name=pipe_name)
nlp.config = filled if auto_fill else config nlp.config = filled if auto_fill else config
nlp.resolved = resolved nlp.resolved = resolved
return nlp return nlp

View File

@ -202,6 +202,7 @@ class ConfigSchemaTraining(BaseModel):
init_tok2vec: Optional[StrictStr] = Field(..., title="Path to pretrained tok2vec weights") init_tok2vec: Optional[StrictStr] = Field(..., title="Path to pretrained tok2vec weights")
raw_text: Optional[StrictStr] = Field(default=None, title="Raw text") raw_text: Optional[StrictStr] = Field(default=None, title="Raw text")
optimizer: Optimizer = Field(..., title="The optimizer to use") optimizer: Optimizer = Field(..., title="The optimizer to use")
frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
# fmt: on # fmt: on
class Config: class Config:

View File

@ -8,6 +8,8 @@ from thinc.api import Model, Linear
from thinc.config import ConfigValidationError from thinc.config import ConfigValidationError
from pydantic import StrictInt, StrictStr from pydantic import StrictInt, StrictStr
from ..util import make_tempdir
def test_pipe_function_component(): def test_pipe_function_component():
name = "test_component" name = "test_component"
@ -374,3 +376,65 @@ def test_language_factories_scores():
cfg = nlp.config["training"] cfg = nlp.config["training"]
expected_weights = {"a1": 0.25, "a2": 0.25, "b1": 0.1, "b2": 0.35, "b3": 0.05} expected_weights = {"a1": 0.25, "a2": 0.25, "b1": 0.1, "b2": 0.35, "b3": 0.05}
assert cfg["score_weights"] == expected_weights assert cfg["score_weights"] == expected_weights
def test_pipe_factories_from_source():
"""Test adding components from a source model."""
source_nlp = English()
source_nlp.add_pipe("tagger", name="my_tagger")
nlp = English()
with pytest.raises(ValueError):
nlp.add_pipe("my_tagger", source="en_core_web_sm")
nlp.add_pipe("my_tagger", source=source_nlp)
assert "my_tagger" in nlp.pipe_names
with pytest.raises(KeyError):
nlp.add_pipe("custom", source=source_nlp)
def test_pipe_factories_from_source_custom():
"""Test adding components from a source model with custom components."""
name = "test_pipe_factories_from_source_custom"
@Language.factory(name, default_config={"arg": "hello"})
def test_factory(nlp, name, arg: str):
return lambda doc: doc
source_nlp = English()
source_nlp.add_pipe("tagger")
source_nlp.add_pipe(name, config={"arg": "world"})
nlp = English()
nlp.add_pipe(name, source=source_nlp)
assert name in nlp.pipe_names
assert nlp.get_pipe_meta(name).default_config["arg"] == "hello"
config = nlp.config["components"][name]
assert config["factory"] == name
assert config["arg"] == "world"
def test_pipe_factories_from_source_config():
name = "test_pipe_factories_from_source_config"
@Language.factory(name, default_config={"arg": "hello"})
def test_factory(nlp, name, arg: str):
return lambda doc: doc
source_nlp = English()
source_nlp.add_pipe("tagger")
source_nlp.add_pipe(name, name="yolo", config={"arg": "world"})
dest_nlp_cfg = {"lang": "en", "pipeline": ["parser", "custom"]}
with make_tempdir() as tempdir:
source_nlp.to_disk(tempdir)
dest_components_cfg = {
"parser": {"factory": "parser"},
"custom": {"source": str(tempdir), "component": "yolo"},
}
dest_config = {"nlp": dest_nlp_cfg, "components": dest_components_cfg}
nlp = English.from_config(dest_config)
assert nlp.pipe_names == ["parser", "custom"]
assert nlp.pipe_factories == {"parser": "parser", "custom": name}
meta = nlp.get_pipe_meta("custom")
assert meta.factory == name
assert meta.default_config["arg"] == "hello"
config = nlp.config["components"]["custom"]
assert config["factory"] == name
assert config["arg"] == "world"

View File

@ -205,43 +205,51 @@ def load_vectors_into_model(
def load_model( def load_model(
name: Union[str, Path], name: Union[str, Path],
*,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(), disable: Iterable[str] = tuple(),
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(), component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
) -> "Language": ) -> "Language":
"""Load a model from a package or data path. """Load a model from a package or data path.
name (str): Package name or model path. name (str): Package name or model path.
vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
a new Vocab object will be created.
disable (Iterable[str]): Names of pipeline components to disable. disable (Iterable[str]): Names of pipeline components to disable.
component_cfg (Dict[str, dict]): Config overrides for pipeline components, component_cfg (Dict[str, dict]): Config overrides for pipeline components,
keyed by component names. keyed by component names.
RETURNS (Language): The loaded nlp object. RETURNS (Language): The loaded nlp object.
""" """
cfg = component_cfg kwargs = {"vocab": vocab, "disable": disable, "component_cfg": component_cfg}
if isinstance(name, str): # name or string path if isinstance(name, str): # name or string path
if name.startswith("blank:"): # shortcut for blank model if name.startswith("blank:"): # shortcut for blank model
return get_lang_class(name.replace("blank:", ""))() return get_lang_class(name.replace("blank:", ""))()
if is_package(name): # installed as package if is_package(name): # installed as package
return load_model_from_package(name, disable=disable, component_cfg=cfg) return load_model_from_package(name, **kwargs)
if Path(name).exists(): # path to model data directory if Path(name).exists(): # path to model data directory
return load_model_from_path(Path(name), disable=disable, component_cfg=cfg) return load_model_from_path(Path(name), **kwargs)
elif hasattr(name, "exists"): # Path or Path-like to model data elif hasattr(name, "exists"): # Path or Path-like to model data
return load_model_from_path(name, disable=disable, component_cfg=cfg) return load_model_from_path(name, **kwargs)
raise IOError(Errors.E050.format(name=name)) raise IOError(Errors.E050.format(name=name))
def load_model_from_package( def load_model_from_package(
name: str, name: str,
*,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(), disable: Iterable[str] = tuple(),
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(), component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
) -> "Language": ) -> "Language":
"""Load a model from an installed package.""" """Load a model from an installed package."""
cls = importlib.import_module(name) cls = importlib.import_module(name)
return cls.load(disable=disable, component_cfg=component_cfg) return cls.load(vocab=vocab, disable=disable, component_cfg=component_cfg)
def load_model_from_path( def load_model_from_path(
model_path: Union[str, Path], model_path: Union[str, Path],
*,
meta: Optional[Dict[str, Any]] = None, meta: Optional[Dict[str, Any]] = None,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(), disable: Iterable[str] = tuple(),
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(), component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
) -> "Language": ) -> "Language":
@ -257,12 +265,16 @@ def load_model_from_path(
config = Config().from_disk(config_path) config = Config().from_disk(config_path)
override_cfg = {"components": {p: dict_to_dot(c) for p, c in component_cfg.items()}} override_cfg = {"components": {p: dict_to_dot(c) for p, c in component_cfg.items()}}
overrides = dict_to_dot(override_cfg) overrides = dict_to_dot(override_cfg)
nlp, _ = load_model_from_config(config, disable=disable, overrides=overrides) nlp, _ = load_model_from_config(
config, vocab=vocab, disable=disable, overrides=overrides
)
return nlp.from_disk(model_path, exclude=disable) return nlp.from_disk(model_path, exclude=disable)
def load_model_from_config( def load_model_from_config(
config: Union[Dict[str, Any], Config], config: Union[Dict[str, Any], Config],
*,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(), disable: Iterable[str] = tuple(),
overrides: Dict[str, Any] = {}, overrides: Dict[str, Any] = {},
auto_fill: bool = False, auto_fill: bool = False,
@ -281,6 +293,7 @@ def load_model_from_config(
lang_cls = get_lang_class(nlp_config["lang"]) lang_cls = get_lang_class(nlp_config["lang"])
nlp = lang_cls.from_config( nlp = lang_cls.from_config(
config, config,
vocab=vocab,
disable=disable, disable=disable,
overrides=overrides, overrides=overrides,
auto_fill=auto_fill, auto_fill=auto_fill,
@ -291,6 +304,8 @@ def load_model_from_config(
def load_model_from_init_py( def load_model_from_init_py(
init_file: Union[Path, str], init_file: Union[Path, str],
*,
vocab: Union["Vocab", bool] = True,
disable: Iterable[str] = tuple(), disable: Iterable[str] = tuple(),
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(), component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
) -> "Language": ) -> "Language":
@ -308,7 +323,7 @@ def load_model_from_init_py(
if not model_path.exists(): if not model_path.exists():
raise IOError(Errors.E052.format(path=data_path)) raise IOError(Errors.E052.format(path=data_path))
return load_model_from_path( return load_model_from_path(
data_path, meta, disable=disable, component_cfg=component_cfg data_path, vocab=vocab, meta=meta, disable=disable, component_cfg=component_cfg
) )