mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
Allow adding pipeline components from source model (#5857)
* Allow adding pipeline components from source model * Config: name -> component * Improve error messages * Fix error and test * Add frozen components and exclude logic * Remove exclude from Language.evaluate * Init sourced components with current vocab * Fix error codes
This commit is contained in:
parent
34873c4911
commit
b795f02fbd
|
@ -19,6 +19,7 @@ max_epochs = 0
|
||||||
patience = 10000
|
patience = 10000
|
||||||
eval_frequency = 200
|
eval_frequency = 200
|
||||||
score_weights = {"dep_las": 0.4, "ents_f": 0.4, "tag_acc": 0.2}
|
score_weights = {"dep_las": 0.4, "ents_f": 0.4, "tag_acc": 0.2}
|
||||||
|
frozen_components = []
|
||||||
|
|
||||||
[training.train_corpus]
|
[training.train_corpus]
|
||||||
@readers = "spacy.Corpus.v1"
|
@readers = "spacy.Corpus.v1"
|
||||||
|
|
|
@ -6,7 +6,7 @@ import hashlib
|
||||||
import typer
|
import typer
|
||||||
from typer.main import get_command
|
from typer.main import get_command
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from thinc.config import ConfigValidationError
|
from thinc.config import Config, ConfigValidationError
|
||||||
from configparser import InterpolationError
|
from configparser import InterpolationError
|
||||||
import sys
|
import sys
|
||||||
|
|
||||||
|
@ -217,3 +217,15 @@ def import_code(code_path: Optional[Union[Path, str]]) -> None:
|
||||||
import_file("python_code", code_path)
|
import_file("python_code", code_path)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)
|
msg.fail(f"Couldn't load Python code: {code_path}", e, exits=1)
|
||||||
|
|
||||||
|
|
||||||
|
def get_sourced_components(config: Union[Dict[str, Any], Config]) -> List[str]:
|
||||||
|
"""RETURNS (List[str]): All sourced components in the original config,
|
||||||
|
e.g. {"source": "en_core_web_sm"}. If the config contains a key
|
||||||
|
"factory", we assume it refers to a component factory.
|
||||||
|
"""
|
||||||
|
return [
|
||||||
|
name
|
||||||
|
for name, cfg in config.get("components", {}).items()
|
||||||
|
if "factory" not in cfg and "source" in cfg
|
||||||
|
]
|
||||||
|
|
|
@ -8,7 +8,7 @@ import typer
|
||||||
from thinc.api import Config
|
from thinc.api import Config
|
||||||
|
|
||||||
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
|
from ._util import app, Arg, Opt, show_validation_error, parse_config_overrides
|
||||||
from ._util import import_code, debug_cli
|
from ._util import import_code, debug_cli, get_sourced_components
|
||||||
from ..gold import Corpus, Example
|
from ..gold import Corpus, Example
|
||||||
from ..pipeline._parser_internals import nonproj
|
from ..pipeline._parser_internals import nonproj
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
|
@ -138,9 +138,10 @@ def debug_data(
|
||||||
with show_validation_error(config_path):
|
with show_validation_error(config_path):
|
||||||
cfg = Config().from_disk(config_path)
|
cfg = Config().from_disk(config_path)
|
||||||
nlp, config = util.load_model_from_config(cfg, overrides=config_overrides)
|
nlp, config = util.load_model_from_config(cfg, overrides=config_overrides)
|
||||||
# TODO: handle base model
|
# Use original config here, not resolved version
|
||||||
lang = config["nlp"]["lang"]
|
sourced_components = get_sourced_components(cfg)
|
||||||
base_model = config["training"]["base_model"]
|
frozen_components = config["training"]["frozen_components"]
|
||||||
|
resume_components = [p for p in sourced_components if p not in frozen_components]
|
||||||
pipeline = nlp.pipe_names
|
pipeline = nlp.pipe_names
|
||||||
factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names]
|
factory_names = [nlp.get_pipe_meta(pipe).factory for pipe in nlp.pipe_names]
|
||||||
tag_map_path = util.ensure_path(config["training"]["tag_map"])
|
tag_map_path = util.ensure_path(config["training"]["tag_map"])
|
||||||
|
@ -187,13 +188,15 @@ def debug_data(
|
||||||
|
|
||||||
train_texts = gold_train_data["texts"]
|
train_texts = gold_train_data["texts"]
|
||||||
dev_texts = gold_dev_data["texts"]
|
dev_texts = gold_dev_data["texts"]
|
||||||
|
frozen_components = config["training"]["frozen_components"]
|
||||||
|
|
||||||
msg.divider("Training stats")
|
msg.divider("Training stats")
|
||||||
|
msg.text(f"Language: {config['nlp']['lang']}")
|
||||||
msg.text(f"Training pipeline: {', '.join(pipeline)}")
|
msg.text(f"Training pipeline: {', '.join(pipeline)}")
|
||||||
if base_model:
|
if resume_components:
|
||||||
msg.text(f"Starting with base model '{base_model}'")
|
msg.text(f"Components from other models: {', '.join(resume_components)}")
|
||||||
else:
|
if frozen_components:
|
||||||
msg.text(f"Starting with blank model '{lang}'")
|
msg.text(f"Frozen components: {', '.join(frozen_components)}")
|
||||||
msg.text(f"{len(train_dataset)} training docs")
|
msg.text(f"{len(train_dataset)} training docs")
|
||||||
msg.text(f"{len(dev_dataset)} evaluation docs")
|
msg.text(f"{len(dev_dataset)} evaluation docs")
|
||||||
|
|
||||||
|
@ -204,7 +207,9 @@ def debug_data(
|
||||||
msg.warn(f"{overlap} training examples also in evaluation data")
|
msg.warn(f"{overlap} training examples also in evaluation data")
|
||||||
else:
|
else:
|
||||||
msg.good("No overlap between training and evaluation data")
|
msg.good("No overlap between training and evaluation data")
|
||||||
if not base_model and len(train_dataset) < BLANK_MODEL_THRESHOLD:
|
# TODO: make this feedback more fine-grained and report on updated
|
||||||
|
# components vs. blank components
|
||||||
|
if not resume_components and len(train_dataset) < BLANK_MODEL_THRESHOLD:
|
||||||
text = (
|
text = (
|
||||||
f"Low number of examples to train from a blank model ({len(train_dataset)})"
|
f"Low number of examples to train from a blank model ({len(train_dataset)})"
|
||||||
)
|
)
|
||||||
|
|
|
@ -11,7 +11,7 @@ import random
|
||||||
import typer
|
import typer
|
||||||
|
|
||||||
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
|
from ._util import app, Arg, Opt, parse_config_overrides, show_validation_error
|
||||||
from ._util import import_code
|
from ._util import import_code, get_sourced_components
|
||||||
from ..language import Language
|
from ..language import Language
|
||||||
from .. import util
|
from .. import util
|
||||||
from ..gold.example import Example
|
from ..gold.example import Example
|
||||||
|
@ -78,6 +78,8 @@ def train(
|
||||||
config = Config().from_disk(config_path)
|
config = Config().from_disk(config_path)
|
||||||
if config.get("training", {}).get("seed") is not None:
|
if config.get("training", {}).get("seed") is not None:
|
||||||
fix_random_seed(config["training"]["seed"])
|
fix_random_seed(config["training"]["seed"])
|
||||||
|
# Use original config here before it's resolved to functions
|
||||||
|
sourced_components = get_sourced_components(config)
|
||||||
with show_validation_error(config_path):
|
with show_validation_error(config_path):
|
||||||
nlp, config = util.load_model_from_config(config, overrides=config_overrides)
|
nlp, config = util.load_model_from_config(config, overrides=config_overrides)
|
||||||
if config["training"]["vectors"] is not None:
|
if config["training"]["vectors"] is not None:
|
||||||
|
@ -92,11 +94,16 @@ def train(
|
||||||
train_corpus = T_cfg["train_corpus"]
|
train_corpus = T_cfg["train_corpus"]
|
||||||
dev_corpus = T_cfg["dev_corpus"]
|
dev_corpus = T_cfg["dev_corpus"]
|
||||||
batcher = T_cfg["batcher"]
|
batcher = T_cfg["batcher"]
|
||||||
if resume_training:
|
# Components that shouldn't be updated during training
|
||||||
msg.info("Resuming training")
|
frozen_components = T_cfg["frozen_components"]
|
||||||
nlp.resume_training()
|
# Sourced components that require resume_training
|
||||||
else:
|
resume_components = [p for p in sourced_components if p not in frozen_components]
|
||||||
msg.info(f"Initializing the nlp pipeline: {nlp.pipe_names}")
|
msg.info(f"Pipeline: {nlp.pipe_names}")
|
||||||
|
if resume_components:
|
||||||
|
with nlp.select_pipes(enable=resume_components):
|
||||||
|
msg.info(f"Resuming training for: {resume_components}")
|
||||||
|
nlp.resume_training()
|
||||||
|
with nlp.select_pipes(disable=[*frozen_components, *resume_components]):
|
||||||
nlp.begin_training(lambda: train_corpus(nlp))
|
nlp.begin_training(lambda: train_corpus(nlp))
|
||||||
|
|
||||||
if tag_map:
|
if tag_map:
|
||||||
|
@ -136,7 +143,7 @@ def train(
|
||||||
patience=T_cfg["patience"],
|
patience=T_cfg["patience"],
|
||||||
max_steps=T_cfg["max_steps"],
|
max_steps=T_cfg["max_steps"],
|
||||||
eval_frequency=T_cfg["eval_frequency"],
|
eval_frequency=T_cfg["eval_frequency"],
|
||||||
raw_text=None
|
raw_text=None,
|
||||||
)
|
)
|
||||||
msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
|
msg.info(f"Training. Initial learn rate: {optimizer.learn_rate}")
|
||||||
print_row = setup_printer(T_cfg, nlp)
|
print_row = setup_printer(T_cfg, nlp)
|
||||||
|
@ -192,9 +199,7 @@ def create_train_batches(iterator, batcher, max_epochs: int):
|
||||||
|
|
||||||
|
|
||||||
def create_evaluation_callback(
|
def create_evaluation_callback(
|
||||||
nlp: Language,
|
nlp: Language, dev_corpus: Callable, weights: Dict[str, float],
|
||||||
dev_corpus: Callable,
|
|
||||||
weights: Dict[str, float],
|
|
||||||
) -> Callable[[], Tuple[float, Dict[str, float]]]:
|
) -> Callable[[], Tuple[float, Dict[str, float]]]:
|
||||||
def evaluate() -> Tuple[float, Dict[str, float]]:
|
def evaluate() -> Tuple[float, Dict[str, float]]:
|
||||||
dev_examples = list(dev_corpus(nlp))
|
dev_examples = list(dev_corpus(nlp))
|
||||||
|
@ -223,6 +228,7 @@ def train_while_improving(
|
||||||
patience: int,
|
patience: int,
|
||||||
max_steps: int,
|
max_steps: int,
|
||||||
raw_text: List[Dict[str, str]],
|
raw_text: List[Dict[str, str]],
|
||||||
|
exclude: List[str],
|
||||||
):
|
):
|
||||||
"""Train until an evaluation stops improving. Works as a generator,
|
"""Train until an evaluation stops improving. Works as a generator,
|
||||||
with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
|
with each iteration yielding a tuple `(batch, info, is_best_checkpoint)`,
|
||||||
|
@ -268,8 +274,6 @@ def train_while_improving(
|
||||||
dropouts = dropout
|
dropouts = dropout
|
||||||
results = []
|
results = []
|
||||||
losses = {}
|
losses = {}
|
||||||
to_enable = [name for name, proc in nlp.pipeline if hasattr(proc, "model")]
|
|
||||||
|
|
||||||
if raw_text:
|
if raw_text:
|
||||||
random.shuffle(raw_text)
|
random.shuffle(raw_text)
|
||||||
raw_examples = [
|
raw_examples = [
|
||||||
|
@ -279,17 +283,19 @@ def train_while_improving(
|
||||||
|
|
||||||
for step, (epoch, batch) in enumerate(train_data):
|
for step, (epoch, batch) in enumerate(train_data):
|
||||||
dropout = next(dropouts)
|
dropout = next(dropouts)
|
||||||
with nlp.select_pipes(enable=to_enable):
|
for subbatch in subdivide_batch(batch, accumulate_gradient):
|
||||||
for subbatch in subdivide_batch(batch, accumulate_gradient):
|
nlp.update(
|
||||||
nlp.update(subbatch, drop=dropout, losses=losses, sgd=False)
|
subbatch, drop=dropout, losses=losses, sgd=False, exclude=exclude
|
||||||
if raw_text:
|
)
|
||||||
# If raw text is available, perform 'rehearsal' updates,
|
if raw_text:
|
||||||
# which use unlabelled data to reduce overfitting.
|
# If raw text is available, perform 'rehearsal' updates,
|
||||||
raw_batch = list(next(raw_batches))
|
# which use unlabelled data to reduce overfitting.
|
||||||
nlp.rehearse(raw_batch, sgd=optimizer, losses=losses)
|
raw_batch = list(next(raw_batches))
|
||||||
for name, proc in nlp.pipeline:
|
nlp.rehearse(raw_batch, sgd=optimizer, losses=losses, exclude=exclude)
|
||||||
if hasattr(proc, "model"):
|
# TODO: refactor this so we don't have to run it separately in here
|
||||||
proc.model.finish_update(optimizer)
|
for name, proc in nlp.pipeline:
|
||||||
|
if name not in exclude and hasattr(proc, "model"):
|
||||||
|
proc.model.finish_update(optimizer)
|
||||||
optimizer.step_schedules()
|
optimizer.step_schedules()
|
||||||
if not (step % eval_frequency):
|
if not (step % eval_frequency):
|
||||||
if optimizer.averages:
|
if optimizer.averages:
|
||||||
|
@ -418,10 +424,7 @@ def load_from_paths(
|
||||||
return raw_text, tag_map, morph_rules, weights_data
|
return raw_text, tag_map, morph_rules, weights_data
|
||||||
|
|
||||||
|
|
||||||
def verify_cli_args(
|
def verify_cli_args(config_path: Path, output_path: Optional[Path] = None,) -> None:
|
||||||
config_path: Path,
|
|
||||||
output_path: Optional[Path] = None,
|
|
||||||
) -> None:
|
|
||||||
# Make sure all files and paths exists if they are needed
|
# Make sure all files and paths exists if they are needed
|
||||||
if not config_path or not config_path.exists():
|
if not config_path or not config_path.exists():
|
||||||
msg.fail("Config file not found", config_path, exits=1)
|
msg.fail("Config file not found", config_path, exits=1)
|
||||||
|
|
|
@ -37,6 +37,8 @@ max_steps = 20000
|
||||||
eval_frequency = 200
|
eval_frequency = 200
|
||||||
# Control how scores are printed and checkpoints are evaluated.
|
# Control how scores are printed and checkpoints are evaluated.
|
||||||
score_weights = {}
|
score_weights = {}
|
||||||
|
# Names of pipeline components that shouldn't be updated during training
|
||||||
|
frozen_components = []
|
||||||
|
|
||||||
[training.train_corpus]
|
[training.train_corpus]
|
||||||
@readers = "spacy.Corpus.v1"
|
@readers = "spacy.Corpus.v1"
|
||||||
|
|
|
@ -482,6 +482,10 @@ class Errors:
|
||||||
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
|
E199 = ("Unable to merge 0-length span at doc[{start}:{end}].")
|
||||||
|
|
||||||
# TODO: fix numbering after merging develop into master
|
# TODO: fix numbering after merging develop into master
|
||||||
|
E944 = ("Can't copy pipeline component '{name}' from source model '{model}': "
|
||||||
|
"not found in pipeline. Available components: {opts}")
|
||||||
|
E945 = ("Can't copy pipeline component '{name}' from source. Expected loaded "
|
||||||
|
"nlp object, but got: {source}")
|
||||||
E946 = ("The Vocab for the knowledge base is not initialized. Did you forget to "
|
E946 = ("The Vocab for the knowledge base is not initialized. Did you forget to "
|
||||||
"call kb.initialize()?")
|
"call kb.initialize()?")
|
||||||
E947 = ("Matcher.add received invalid 'greedy' argument: expected "
|
E947 = ("Matcher.add received invalid 'greedy' argument: expected "
|
||||||
|
@ -571,11 +575,13 @@ class Errors:
|
||||||
"into {values}, but found {value}.")
|
"into {values}, but found {value}.")
|
||||||
E983 = ("Invalid key for '{dict}': {key}. Available keys: "
|
E983 = ("Invalid key for '{dict}': {key}. Available keys: "
|
||||||
"{keys}")
|
"{keys}")
|
||||||
E984 = ("Invalid component config for '{name}': no 'factory' key "
|
E984 = ("Invalid component config for '{name}': component block needs either "
|
||||||
"specifying the registered function used to initialize the "
|
"a key 'factory' specifying the registered function used to "
|
||||||
"component. For example, factory = \"ner\" will use the 'ner' "
|
"initialize the component, or a key 'source' key specifying a "
|
||||||
"factory and all other settings in the block will be passed "
|
"spaCy model to copy the component from. For example, factory = "
|
||||||
"to it as arguments.\n\n{config}")
|
"\"ner\" will use the 'ner' factory and all other settings in the "
|
||||||
|
"block will be passed to it as arguments. Alternatively, source = "
|
||||||
|
"\"en_core_web_sm\" will copy the component from that model.\n\n{config}")
|
||||||
E985 = ("Can't load model from config file: no 'nlp' section found.\n\n{config}")
|
E985 = ("Can't load model from config file: no 'nlp' section found.\n\n{config}")
|
||||||
E986 = ("Could not create any training batches: check your input. "
|
E986 = ("Could not create any training batches: check your input. "
|
||||||
"Perhaps discard_oversize should be set to False ?")
|
"Perhaps discard_oversize should be set to False ?")
|
||||||
|
|
|
@ -620,6 +620,32 @@ class Language:
|
||||||
self._pipe_configs[name] = filled
|
self._pipe_configs[name] = filled
|
||||||
return resolved[factory_name]
|
return resolved[factory_name]
|
||||||
|
|
||||||
|
def create_pipe_from_source(
|
||||||
|
self, source_name: str, source: "Language", *, name: str,
|
||||||
|
) -> Tuple[Callable[[Doc], Doc], str]:
|
||||||
|
"""Create a pipeline component by copying it from an existing model.
|
||||||
|
|
||||||
|
source_name (str): Name of the component in the source pipeline.
|
||||||
|
source (Language): The source nlp object to copy from.
|
||||||
|
name (str): Optional alternative name to use in current pipeline.
|
||||||
|
RETURNS (Tuple[Callable, str]): The component and its factory name.
|
||||||
|
"""
|
||||||
|
# TODO: handle errors and mismatches (vectors etc.)
|
||||||
|
if not isinstance(source, self.__class__):
|
||||||
|
raise ValueError(Errors.E945.format(name=source_name, source=type(source)))
|
||||||
|
if not source.has_pipe(source_name):
|
||||||
|
raise KeyError(
|
||||||
|
Errors.E944.format(
|
||||||
|
name=source_name,
|
||||||
|
model=f"{source.meta['lang']}_{source.meta['name']}",
|
||||||
|
opts=", ".join(source.pipe_names),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
pipe = source.get_pipe(source_name)
|
||||||
|
pipe_config = util.copy_config(source.config["components"][source_name])
|
||||||
|
self._pipe_configs[name] = pipe_config
|
||||||
|
return pipe, pipe_config["factory"]
|
||||||
|
|
||||||
def add_pipe(
|
def add_pipe(
|
||||||
self,
|
self,
|
||||||
factory_name: str,
|
factory_name: str,
|
||||||
|
@ -629,6 +655,7 @@ class Language:
|
||||||
after: Optional[Union[str, int]] = None,
|
after: Optional[Union[str, int]] = None,
|
||||||
first: Optional[bool] = None,
|
first: Optional[bool] = None,
|
||||||
last: Optional[bool] = None,
|
last: Optional[bool] = None,
|
||||||
|
source: Optional["Language"] = None,
|
||||||
config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
|
config: Optional[Dict[str, Any]] = SimpleFrozenDict(),
|
||||||
overrides: Optional[Dict[str, Any]] = SimpleFrozenDict(),
|
overrides: Optional[Dict[str, Any]] = SimpleFrozenDict(),
|
||||||
validate: bool = True,
|
validate: bool = True,
|
||||||
|
@ -648,6 +675,8 @@ class Language:
|
||||||
component directly after.
|
component directly after.
|
||||||
first (bool): If True, insert component first in the pipeline.
|
first (bool): If True, insert component first in the pipeline.
|
||||||
last (bool): If True, insert component last in the pipeline.
|
last (bool): If True, insert component last in the pipeline.
|
||||||
|
source (Language): Optional loaded nlp object to copy the pipeline
|
||||||
|
component from.
|
||||||
config (Optional[Dict[str, Any]]): Config parameters to use for this
|
config (Optional[Dict[str, Any]]): Config parameters to use for this
|
||||||
component. Will be merged with default config, if available.
|
component. Will be merged with default config, if available.
|
||||||
overrides (Optional[Dict[str, Any]]): Config overrides, typically
|
overrides (Optional[Dict[str, Any]]): Config overrides, typically
|
||||||
|
@ -662,24 +691,31 @@ class Language:
|
||||||
bad_val = repr(factory_name)
|
bad_val = repr(factory_name)
|
||||||
err = Errors.E966.format(component=bad_val, name=name)
|
err = Errors.E966.format(component=bad_val, name=name)
|
||||||
raise ValueError(err)
|
raise ValueError(err)
|
||||||
if not self.has_factory(factory_name):
|
|
||||||
err = Errors.E002.format(
|
|
||||||
name=factory_name,
|
|
||||||
opts=", ".join(self.factory_names),
|
|
||||||
method="add_pipe",
|
|
||||||
lang=util.get_object_name(self),
|
|
||||||
lang_code=self.lang,
|
|
||||||
)
|
|
||||||
name = name if name is not None else factory_name
|
name = name if name is not None else factory_name
|
||||||
if name in self.pipe_names:
|
if name in self.pipe_names:
|
||||||
raise ValueError(Errors.E007.format(name=name, opts=self.pipe_names))
|
raise ValueError(Errors.E007.format(name=name, opts=self.pipe_names))
|
||||||
pipe_component = self.create_pipe(
|
if source is not None:
|
||||||
factory_name,
|
# We're loading the component from a model. After loading the
|
||||||
name=name,
|
# component, we know its real factory name
|
||||||
config=config,
|
pipe_component, factory_name = self.create_pipe_from_source(
|
||||||
overrides=overrides,
|
factory_name, source, name=name
|
||||||
validate=validate,
|
)
|
||||||
)
|
else:
|
||||||
|
if not self.has_factory(factory_name):
|
||||||
|
err = Errors.E002.format(
|
||||||
|
name=factory_name,
|
||||||
|
opts=", ".join(self.factory_names),
|
||||||
|
method="add_pipe",
|
||||||
|
lang=util.get_object_name(self),
|
||||||
|
lang_code=self.lang,
|
||||||
|
)
|
||||||
|
pipe_component = self.create_pipe(
|
||||||
|
factory_name,
|
||||||
|
name=name,
|
||||||
|
config=config,
|
||||||
|
overrides=overrides,
|
||||||
|
validate=validate,
|
||||||
|
)
|
||||||
pipe_index = self._get_pipe_index(before, after, first, last)
|
pipe_index = self._get_pipe_index(before, after, first, last)
|
||||||
self._pipe_meta[name] = self.get_factory_meta(factory_name)
|
self._pipe_meta[name] = self.get_factory_meta(factory_name)
|
||||||
self.pipeline.insert(pipe_index, (name, pipe_component))
|
self.pipeline.insert(pipe_index, (name, pipe_component))
|
||||||
|
@ -911,6 +947,7 @@ class Language:
|
||||||
sgd: Optional[Optimizer] = None,
|
sgd: Optional[Optimizer] = None,
|
||||||
losses: Optional[Dict[str, float]] = None,
|
losses: Optional[Dict[str, float]] = None,
|
||||||
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
|
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||||
|
exclude: Iterable[str] = tuple(),
|
||||||
):
|
):
|
||||||
"""Update the models in the pipeline.
|
"""Update the models in the pipeline.
|
||||||
|
|
||||||
|
@ -921,6 +958,7 @@ class Language:
|
||||||
losses (Dict[str, float]): Dictionary to update with the loss, keyed by component.
|
losses (Dict[str, float]): Dictionary to update with the loss, keyed by component.
|
||||||
component_cfg (Dict[str, Dict]): Config parameters for specific pipeline
|
component_cfg (Dict[str, Dict]): Config parameters for specific pipeline
|
||||||
components, keyed by component name.
|
components, keyed by component name.
|
||||||
|
exclude (Iterable[str]): Names of components that shouldn't be updated.
|
||||||
RETURNS (Dict[str, float]): The updated losses dictionary
|
RETURNS (Dict[str, float]): The updated losses dictionary
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/language#update
|
DOCS: https://spacy.io/api/language#update
|
||||||
|
@ -953,12 +991,12 @@ class Language:
|
||||||
component_cfg[name].setdefault("drop", drop)
|
component_cfg[name].setdefault("drop", drop)
|
||||||
component_cfg[name].setdefault("set_annotations", False)
|
component_cfg[name].setdefault("set_annotations", False)
|
||||||
for name, proc in self.pipeline:
|
for name, proc in self.pipeline:
|
||||||
if not hasattr(proc, "update"):
|
if name in exclude or not hasattr(proc, "update"):
|
||||||
continue
|
continue
|
||||||
proc.update(examples, sgd=None, losses=losses, **component_cfg[name])
|
proc.update(examples, sgd=None, losses=losses, **component_cfg[name])
|
||||||
if sgd not in (None, False):
|
if sgd not in (None, False):
|
||||||
for name, proc in self.pipeline:
|
for name, proc in self.pipeline:
|
||||||
if hasattr(proc, "model"):
|
if name not in exclude and hasattr(proc, "model"):
|
||||||
proc.model.finish_update(sgd)
|
proc.model.finish_update(sgd)
|
||||||
return losses
|
return losses
|
||||||
|
|
||||||
|
@ -969,6 +1007,7 @@ class Language:
|
||||||
sgd: Optional[Optimizer] = None,
|
sgd: Optional[Optimizer] = None,
|
||||||
losses: Optional[Dict[str, float]] = None,
|
losses: Optional[Dict[str, float]] = None,
|
||||||
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
|
component_cfg: Optional[Dict[str, Dict[str, Any]]] = None,
|
||||||
|
exclude: Iterable[str] = tuple(),
|
||||||
) -> Dict[str, float]:
|
) -> Dict[str, float]:
|
||||||
"""Make a "rehearsal" update to the models in the pipeline, to prevent
|
"""Make a "rehearsal" update to the models in the pipeline, to prevent
|
||||||
forgetting. Rehearsal updates run an initial copy of the model over some
|
forgetting. Rehearsal updates run an initial copy of the model over some
|
||||||
|
@ -980,6 +1019,7 @@ class Language:
|
||||||
sgd (Optional[Optimizer]): An optimizer.
|
sgd (Optional[Optimizer]): An optimizer.
|
||||||
component_cfg (Dict[str, Dict]): Config parameters for specific pipeline
|
component_cfg (Dict[str, Dict]): Config parameters for specific pipeline
|
||||||
components, keyed by component name.
|
components, keyed by component name.
|
||||||
|
exclude (Iterable[str]): Names of components that shouldn't be updated.
|
||||||
RETURNS (dict): Results from the update.
|
RETURNS (dict): Results from the update.
|
||||||
|
|
||||||
EXAMPLE:
|
EXAMPLE:
|
||||||
|
@ -1023,7 +1063,7 @@ class Language:
|
||||||
get_grads.b1 = sgd.b1
|
get_grads.b1 = sgd.b1
|
||||||
get_grads.b2 = sgd.b2
|
get_grads.b2 = sgd.b2
|
||||||
for name, proc in pipes:
|
for name, proc in pipes:
|
||||||
if not hasattr(proc, "rehearse"):
|
if name in exclude or not hasattr(proc, "rehearse"):
|
||||||
continue
|
continue
|
||||||
grads = {}
|
grads = {}
|
||||||
proc.rehearse(
|
proc.rehearse(
|
||||||
|
@ -1074,7 +1114,7 @@ class Language:
|
||||||
return self._optimizer
|
return self._optimizer
|
||||||
|
|
||||||
def resume_training(
|
def resume_training(
|
||||||
self, *, sgd: Optional[Optimizer] = None, device: int = -1
|
self, *, sgd: Optional[Optimizer] = None, device: int = -1,
|
||||||
) -> Optimizer:
|
) -> Optimizer:
|
||||||
"""Continue training a pretrained model.
|
"""Continue training a pretrained model.
|
||||||
|
|
||||||
|
@ -1373,6 +1413,7 @@ class Language:
|
||||||
cls,
|
cls,
|
||||||
config: Union[Dict[str, Any], Config] = {},
|
config: Union[Dict[str, Any], Config] = {},
|
||||||
*,
|
*,
|
||||||
|
vocab: Union[Vocab, bool] = True,
|
||||||
disable: Iterable[str] = tuple(),
|
disable: Iterable[str] = tuple(),
|
||||||
overrides: Dict[str, Any] = {},
|
overrides: Dict[str, Any] = {},
|
||||||
auto_fill: bool = True,
|
auto_fill: bool = True,
|
||||||
|
@ -1383,6 +1424,7 @@ class Language:
|
||||||
the default config of the given language is used.
|
the default config of the given language is used.
|
||||||
|
|
||||||
config (Dict[str, Any] / Config): The loaded config.
|
config (Dict[str, Any] / Config): The loaded config.
|
||||||
|
vocab (Vocab): A Vocab object. If True, a vocab is created.
|
||||||
disable (Iterable[str]): List of pipeline component names to disable.
|
disable (Iterable[str]): List of pipeline component names to disable.
|
||||||
auto_fill (bool): Automatically fill in missing values in config based
|
auto_fill (bool): Automatically fill in missing values in config based
|
||||||
on defaults and function argument annotations.
|
on defaults and function argument annotations.
|
||||||
|
@ -1422,32 +1464,48 @@ class Language:
|
||||||
create_tokenizer = resolved["nlp"]["tokenizer"]
|
create_tokenizer = resolved["nlp"]["tokenizer"]
|
||||||
create_lemmatizer = resolved["nlp"]["lemmatizer"]
|
create_lemmatizer = resolved["nlp"]["lemmatizer"]
|
||||||
nlp = cls(
|
nlp = cls(
|
||||||
create_tokenizer=create_tokenizer, create_lemmatizer=create_lemmatizer,
|
vocab=vocab,
|
||||||
|
create_tokenizer=create_tokenizer,
|
||||||
|
create_lemmatizer=create_lemmatizer,
|
||||||
)
|
)
|
||||||
# Note that we don't load vectors here, instead they get loaded explicitly
|
# Note that we don't load vectors here, instead they get loaded explicitly
|
||||||
# inside stuff like the spacy train function. If we loaded them here,
|
# inside stuff like the spacy train function. If we loaded them here,
|
||||||
# then we would load them twice at runtime: once when we make from config,
|
# then we would load them twice at runtime: once when we make from config,
|
||||||
# and then again when we load from disk.
|
# and then again when we load from disk.
|
||||||
pipeline = config.get("components", {})
|
pipeline = config.get("components", {})
|
||||||
|
# If components are loaded from a source (existing models), we cache
|
||||||
|
# them here so they're only loaded once
|
||||||
|
source_nlps = {}
|
||||||
for pipe_name in config["nlp"]["pipeline"]:
|
for pipe_name in config["nlp"]["pipeline"]:
|
||||||
if pipe_name not in pipeline:
|
if pipe_name not in pipeline:
|
||||||
opts = ", ".join(pipeline.keys())
|
opts = ", ".join(pipeline.keys())
|
||||||
raise ValueError(Errors.E956.format(name=pipe_name, opts=opts))
|
raise ValueError(Errors.E956.format(name=pipe_name, opts=opts))
|
||||||
pipe_cfg = util.copy_config(pipeline[pipe_name])
|
pipe_cfg = util.copy_config(pipeline[pipe_name])
|
||||||
if pipe_name not in disable:
|
if pipe_name not in disable:
|
||||||
if "factory" not in pipe_cfg:
|
if "factory" not in pipe_cfg and "source" not in pipe_cfg:
|
||||||
err = Errors.E984.format(name=pipe_name, config=pipe_cfg)
|
err = Errors.E984.format(name=pipe_name, config=pipe_cfg)
|
||||||
raise ValueError(err)
|
raise ValueError(err)
|
||||||
factory = pipe_cfg.pop("factory")
|
if "factory" in pipe_cfg:
|
||||||
# The pipe name (key in the config) here is the unique name of the
|
factory = pipe_cfg.pop("factory")
|
||||||
# component, not necessarily the factory
|
# The pipe name (key in the config) here is the unique name
|
||||||
nlp.add_pipe(
|
# of the component, not necessarily the factory
|
||||||
factory,
|
nlp.add_pipe(
|
||||||
name=pipe_name,
|
factory,
|
||||||
config=pipe_cfg,
|
name=pipe_name,
|
||||||
overrides=pipe_overrides,
|
config=pipe_cfg,
|
||||||
validate=validate,
|
overrides=pipe_overrides,
|
||||||
)
|
validate=validate,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
model = pipe_cfg["source"]
|
||||||
|
if model not in source_nlps:
|
||||||
|
# We only need the components here and we need to init
|
||||||
|
# model with the same vocab as the current nlp object
|
||||||
|
source_nlps[model] = util.load_model(
|
||||||
|
model, vocab=nlp.vocab, disable=["vocab", "tokenizer"]
|
||||||
|
)
|
||||||
|
source_name = pipe_cfg.get("component", pipe_name)
|
||||||
|
nlp.add_pipe(source_name, source=source_nlps[model], name=pipe_name)
|
||||||
nlp.config = filled if auto_fill else config
|
nlp.config = filled if auto_fill else config
|
||||||
nlp.resolved = resolved
|
nlp.resolved = resolved
|
||||||
return nlp
|
return nlp
|
||||||
|
|
|
@ -202,6 +202,7 @@ class ConfigSchemaTraining(BaseModel):
|
||||||
init_tok2vec: Optional[StrictStr] = Field(..., title="Path to pretrained tok2vec weights")
|
init_tok2vec: Optional[StrictStr] = Field(..., title="Path to pretrained tok2vec weights")
|
||||||
raw_text: Optional[StrictStr] = Field(default=None, title="Raw text")
|
raw_text: Optional[StrictStr] = Field(default=None, title="Raw text")
|
||||||
optimizer: Optimizer = Field(..., title="The optimizer to use")
|
optimizer: Optimizer = Field(..., title="The optimizer to use")
|
||||||
|
frozen_components: List[str] = Field(..., title="Pipeline components that shouldn't be updated during training")
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
|
|
|
@ -8,6 +8,8 @@ from thinc.api import Model, Linear
|
||||||
from thinc.config import ConfigValidationError
|
from thinc.config import ConfigValidationError
|
||||||
from pydantic import StrictInt, StrictStr
|
from pydantic import StrictInt, StrictStr
|
||||||
|
|
||||||
|
from ..util import make_tempdir
|
||||||
|
|
||||||
|
|
||||||
def test_pipe_function_component():
|
def test_pipe_function_component():
|
||||||
name = "test_component"
|
name = "test_component"
|
||||||
|
@ -374,3 +376,65 @@ def test_language_factories_scores():
|
||||||
cfg = nlp.config["training"]
|
cfg = nlp.config["training"]
|
||||||
expected_weights = {"a1": 0.25, "a2": 0.25, "b1": 0.1, "b2": 0.35, "b3": 0.05}
|
expected_weights = {"a1": 0.25, "a2": 0.25, "b1": 0.1, "b2": 0.35, "b3": 0.05}
|
||||||
assert cfg["score_weights"] == expected_weights
|
assert cfg["score_weights"] == expected_weights
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipe_factories_from_source():
|
||||||
|
"""Test adding components from a source model."""
|
||||||
|
source_nlp = English()
|
||||||
|
source_nlp.add_pipe("tagger", name="my_tagger")
|
||||||
|
nlp = English()
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
nlp.add_pipe("my_tagger", source="en_core_web_sm")
|
||||||
|
nlp.add_pipe("my_tagger", source=source_nlp)
|
||||||
|
assert "my_tagger" in nlp.pipe_names
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
nlp.add_pipe("custom", source=source_nlp)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipe_factories_from_source_custom():
|
||||||
|
"""Test adding components from a source model with custom components."""
|
||||||
|
name = "test_pipe_factories_from_source_custom"
|
||||||
|
|
||||||
|
@Language.factory(name, default_config={"arg": "hello"})
|
||||||
|
def test_factory(nlp, name, arg: str):
|
||||||
|
return lambda doc: doc
|
||||||
|
|
||||||
|
source_nlp = English()
|
||||||
|
source_nlp.add_pipe("tagger")
|
||||||
|
source_nlp.add_pipe(name, config={"arg": "world"})
|
||||||
|
nlp = English()
|
||||||
|
nlp.add_pipe(name, source=source_nlp)
|
||||||
|
assert name in nlp.pipe_names
|
||||||
|
assert nlp.get_pipe_meta(name).default_config["arg"] == "hello"
|
||||||
|
config = nlp.config["components"][name]
|
||||||
|
assert config["factory"] == name
|
||||||
|
assert config["arg"] == "world"
|
||||||
|
|
||||||
|
|
||||||
|
def test_pipe_factories_from_source_config():
|
||||||
|
name = "test_pipe_factories_from_source_config"
|
||||||
|
|
||||||
|
@Language.factory(name, default_config={"arg": "hello"})
|
||||||
|
def test_factory(nlp, name, arg: str):
|
||||||
|
return lambda doc: doc
|
||||||
|
|
||||||
|
source_nlp = English()
|
||||||
|
source_nlp.add_pipe("tagger")
|
||||||
|
source_nlp.add_pipe(name, name="yolo", config={"arg": "world"})
|
||||||
|
dest_nlp_cfg = {"lang": "en", "pipeline": ["parser", "custom"]}
|
||||||
|
with make_tempdir() as tempdir:
|
||||||
|
source_nlp.to_disk(tempdir)
|
||||||
|
dest_components_cfg = {
|
||||||
|
"parser": {"factory": "parser"},
|
||||||
|
"custom": {"source": str(tempdir), "component": "yolo"},
|
||||||
|
}
|
||||||
|
dest_config = {"nlp": dest_nlp_cfg, "components": dest_components_cfg}
|
||||||
|
nlp = English.from_config(dest_config)
|
||||||
|
assert nlp.pipe_names == ["parser", "custom"]
|
||||||
|
assert nlp.pipe_factories == {"parser": "parser", "custom": name}
|
||||||
|
meta = nlp.get_pipe_meta("custom")
|
||||||
|
assert meta.factory == name
|
||||||
|
assert meta.default_config["arg"] == "hello"
|
||||||
|
config = nlp.config["components"]["custom"]
|
||||||
|
assert config["factory"] == name
|
||||||
|
assert config["arg"] == "world"
|
||||||
|
|
|
@ -205,43 +205,51 @@ def load_vectors_into_model(
|
||||||
|
|
||||||
def load_model(
|
def load_model(
|
||||||
name: Union[str, Path],
|
name: Union[str, Path],
|
||||||
|
*,
|
||||||
|
vocab: Union["Vocab", bool] = True,
|
||||||
disable: Iterable[str] = tuple(),
|
disable: Iterable[str] = tuple(),
|
||||||
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
|
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
|
||||||
) -> "Language":
|
) -> "Language":
|
||||||
"""Load a model from a package or data path.
|
"""Load a model from a package or data path.
|
||||||
|
|
||||||
name (str): Package name or model path.
|
name (str): Package name or model path.
|
||||||
|
vocab (Vocab / True): Optional vocab to pass in on initialization. If True,
|
||||||
|
a new Vocab object will be created.
|
||||||
disable (Iterable[str]): Names of pipeline components to disable.
|
disable (Iterable[str]): Names of pipeline components to disable.
|
||||||
component_cfg (Dict[str, dict]): Config overrides for pipeline components,
|
component_cfg (Dict[str, dict]): Config overrides for pipeline components,
|
||||||
keyed by component names.
|
keyed by component names.
|
||||||
RETURNS (Language): The loaded nlp object.
|
RETURNS (Language): The loaded nlp object.
|
||||||
"""
|
"""
|
||||||
cfg = component_cfg
|
kwargs = {"vocab": vocab, "disable": disable, "component_cfg": component_cfg}
|
||||||
if isinstance(name, str): # name or string path
|
if isinstance(name, str): # name or string path
|
||||||
if name.startswith("blank:"): # shortcut for blank model
|
if name.startswith("blank:"): # shortcut for blank model
|
||||||
return get_lang_class(name.replace("blank:", ""))()
|
return get_lang_class(name.replace("blank:", ""))()
|
||||||
if is_package(name): # installed as package
|
if is_package(name): # installed as package
|
||||||
return load_model_from_package(name, disable=disable, component_cfg=cfg)
|
return load_model_from_package(name, **kwargs)
|
||||||
if Path(name).exists(): # path to model data directory
|
if Path(name).exists(): # path to model data directory
|
||||||
return load_model_from_path(Path(name), disable=disable, component_cfg=cfg)
|
return load_model_from_path(Path(name), **kwargs)
|
||||||
elif hasattr(name, "exists"): # Path or Path-like to model data
|
elif hasattr(name, "exists"): # Path or Path-like to model data
|
||||||
return load_model_from_path(name, disable=disable, component_cfg=cfg)
|
return load_model_from_path(name, **kwargs)
|
||||||
raise IOError(Errors.E050.format(name=name))
|
raise IOError(Errors.E050.format(name=name))
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_package(
|
def load_model_from_package(
|
||||||
name: str,
|
name: str,
|
||||||
|
*,
|
||||||
|
vocab: Union["Vocab", bool] = True,
|
||||||
disable: Iterable[str] = tuple(),
|
disable: Iterable[str] = tuple(),
|
||||||
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
|
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
|
||||||
) -> "Language":
|
) -> "Language":
|
||||||
"""Load a model from an installed package."""
|
"""Load a model from an installed package."""
|
||||||
cls = importlib.import_module(name)
|
cls = importlib.import_module(name)
|
||||||
return cls.load(disable=disable, component_cfg=component_cfg)
|
return cls.load(vocab=vocab, disable=disable, component_cfg=component_cfg)
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_path(
|
def load_model_from_path(
|
||||||
model_path: Union[str, Path],
|
model_path: Union[str, Path],
|
||||||
|
*,
|
||||||
meta: Optional[Dict[str, Any]] = None,
|
meta: Optional[Dict[str, Any]] = None,
|
||||||
|
vocab: Union["Vocab", bool] = True,
|
||||||
disable: Iterable[str] = tuple(),
|
disable: Iterable[str] = tuple(),
|
||||||
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
|
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
|
||||||
) -> "Language":
|
) -> "Language":
|
||||||
|
@ -257,12 +265,16 @@ def load_model_from_path(
|
||||||
config = Config().from_disk(config_path)
|
config = Config().from_disk(config_path)
|
||||||
override_cfg = {"components": {p: dict_to_dot(c) for p, c in component_cfg.items()}}
|
override_cfg = {"components": {p: dict_to_dot(c) for p, c in component_cfg.items()}}
|
||||||
overrides = dict_to_dot(override_cfg)
|
overrides = dict_to_dot(override_cfg)
|
||||||
nlp, _ = load_model_from_config(config, disable=disable, overrides=overrides)
|
nlp, _ = load_model_from_config(
|
||||||
|
config, vocab=vocab, disable=disable, overrides=overrides
|
||||||
|
)
|
||||||
return nlp.from_disk(model_path, exclude=disable)
|
return nlp.from_disk(model_path, exclude=disable)
|
||||||
|
|
||||||
|
|
||||||
def load_model_from_config(
|
def load_model_from_config(
|
||||||
config: Union[Dict[str, Any], Config],
|
config: Union[Dict[str, Any], Config],
|
||||||
|
*,
|
||||||
|
vocab: Union["Vocab", bool] = True,
|
||||||
disable: Iterable[str] = tuple(),
|
disable: Iterable[str] = tuple(),
|
||||||
overrides: Dict[str, Any] = {},
|
overrides: Dict[str, Any] = {},
|
||||||
auto_fill: bool = False,
|
auto_fill: bool = False,
|
||||||
|
@ -281,6 +293,7 @@ def load_model_from_config(
|
||||||
lang_cls = get_lang_class(nlp_config["lang"])
|
lang_cls = get_lang_class(nlp_config["lang"])
|
||||||
nlp = lang_cls.from_config(
|
nlp = lang_cls.from_config(
|
||||||
config,
|
config,
|
||||||
|
vocab=vocab,
|
||||||
disable=disable,
|
disable=disable,
|
||||||
overrides=overrides,
|
overrides=overrides,
|
||||||
auto_fill=auto_fill,
|
auto_fill=auto_fill,
|
||||||
|
@ -291,6 +304,8 @@ def load_model_from_config(
|
||||||
|
|
||||||
def load_model_from_init_py(
|
def load_model_from_init_py(
|
||||||
init_file: Union[Path, str],
|
init_file: Union[Path, str],
|
||||||
|
*,
|
||||||
|
vocab: Union["Vocab", bool] = True,
|
||||||
disable: Iterable[str] = tuple(),
|
disable: Iterable[str] = tuple(),
|
||||||
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
|
component_cfg: Dict[str, Dict[str, Any]] = SimpleFrozenDict(),
|
||||||
) -> "Language":
|
) -> "Language":
|
||||||
|
@ -308,7 +323,7 @@ def load_model_from_init_py(
|
||||||
if not model_path.exists():
|
if not model_path.exists():
|
||||||
raise IOError(Errors.E052.format(path=data_path))
|
raise IOError(Errors.E052.format(path=data_path))
|
||||||
return load_model_from_path(
|
return load_model_from_path(
|
||||||
data_path, meta, disable=disable, component_cfg=component_cfg
|
data_path, vocab=vocab, meta=meta, disable=disable, component_cfg=component_cfg
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user