Merge pull request #6852 from explosion/feature/replace-listeners

This commit is contained in:
Ines Montani 2021-01-30 00:58:08 +11:00 committed by GitHub
commit 95e958a229
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 410 additions and 35 deletions

View File

@@ -68,7 +68,7 @@ console_scripts =
lookups =
spacy_lookups_data>=1.0.0,<1.1.0
transformers =
spacy_transformers>=1.0.0rc0,<1.1.0
spacy_transformers>=1.0.0rc4,<1.1.0
ray =
spacy_ray>=0.1.0,<1.0.0
cuda =

View File

@@ -80,12 +80,22 @@ class Warnings:
# TODO: fix numbering after merging develop into master
W086 = ("Component '{listener}' will be (re)trained, but it needs the component "
"'{name}' which is frozen. You should either freeze both, or neither "
"of the two.")
"'{name}' which is frozen. You can either freeze both, or neither "
"of the two. If you're sourcing the component from "
"an existing pipeline, you can use the `replace_listeners` setting in "
"the config block to replace its token-to-vector listener with a copy "
"and make it independent. For example, `replace_listeners = "
"[\"model.tok2vec\"]` See the documentation for details: "
"https://nightly.spacy.io/usage/training#config-components-listeners")
W087 = ("Component '{name}' will be (re)trained, but the component '{listener}' "
"depends on it and is frozen. This means that the performance of "
"'{listener}' will be degraded. You should either freeze both, or "
"neither of the two.")
"depends on it via a listener and is frozen. This means that the "
"performance of '{listener}' will be degraded. You can either freeze "
"both, or neither of the two. If you're sourcing the component from "
"an existing pipeline, you can use the `replace_listeners` setting in "
"the config block to replace its token-to-vector listener with a copy "
"and make it independent. For example, `replace_listeners = "
"[\"model.tok2vec\"]` See the documentation for details: "
"https://nightly.spacy.io/usage/training#config-components-listeners")
W088 = ("The pipeline component {name} implements a `begin_training` "
"method, which won't be called by spaCy. As of v3.0, `begin_training` "
"has been renamed to `initialize`, so you likely want to rename the "
@@ -475,7 +485,19 @@ class Errors:
"issue tracker: http://github.com/explosion/spaCy/issues")
# TODO: fix numbering after merging develop into master
E890 = ("Can not add the alias '{alias}' to the Knowledge base. "
E886 = ("Can't replace {name} -> {tok2vec} listeners: path '{path}' not "
"found in config for component '{name}'.")
E887 = ("Can't replace {name} -> {tok2vec} listeners: the paths to replace "
"({paths}) don't match the available listeners in the model ({n_listeners}).")
E888 = ("Can't replace listeners for '{name}' ({pipe}): invalid upstream "
"component that doesn't seem to support listeners. Expected Tok2Vec "
"or Transformer component. If you didn't call nlp.replace_listeners "
"manually, this is likely a bug in spaCy.")
E889 = ("Can't replace '{tok2vec}' listeners of component '{name}' because "
"'{unknown}' is not in the pipeline. Available components: {opts}. "
"If you didn't call nlp.replace_listeners manually, this is likely "
"a bug in spaCy.")
E890 = ("Cannot add the alias '{alias}' to the Knowledge base. "
"Each alias should be a meaningful string.")
E891 = ("Alias '{alias}' could not be added to the Knowledge base. "
"This is likely a bug in spaCy.")

View File

@@ -1629,6 +1629,7 @@ class Language:
# Later we replace the component config with the raw config again.
interpolated = filled.interpolate() if not filled.is_interpolated else filled
pipeline = interpolated.get("components", {})
sourced = util.get_sourced_components(interpolated)
# If components are loaded from a source (existing models), we cache
# them here so they're only loaded once
source_nlps = {}
@@ -1671,8 +1672,103 @@
raise ValueError(
Errors.E942.format(name="pipeline_creation", value=type(nlp))
)
# Detect components with listeners that are not frozen consistently
for name, proc in nlp.pipeline:
if getattr(proc, "listening_components", None): # e.g. tok2vec/transformer
for listener in proc.listening_components:
# If it's a component sourced from another pipeline, we check if
# the tok2vec listeners should be replaced with standalone tok2vec
# models (e.g. so component can be frozen without its performance
# degrading when other components/tok2vec are updated)
paths = sourced.get(listener, {}).get("replace_listeners", [])
if paths:
nlp.replace_listeners(name, listener, paths)
return nlp
def replace_listeners(
self, tok2vec_name: str, pipe_name: str, listeners: Iterable[str],
) -> None:
"""Find listener layers (connecting to a token-to-vector embedding
component) of a given pipeline component model and replace
them with a standalone copy of the token-to-vector layer. This can be
useful when training a pipeline with components sourced from an existing
pipeline: if multiple components (e.g. tagger, parser, NER) listen to
the same tok2vec component, but some of them are frozen and not updated,
their performance may degrade significantly as the tok2vec component is
updated with new data. To prevent this, listeners can be replaced with
a standalone tok2vec layer that is owned by the component and doesn't
change if the component isn't updated.
tok2vec_name (str): Name of the token-to-vector component, typically
"tok2vec" or "transformer".
pipe_name (str): Name of pipeline component to replace listeners for.
listeners (Iterable[str]): The paths to the listeners, relative to the
component config, e.g. ["model.tok2vec"]. Typically, implementations
will only connect to one tok2vec component, [model.tok2vec], but in
theory, custom models can use multiple listeners. The value here can
either be an empty list to not replace any listeners, or a complete
(!) list of the paths to all listener layers used by the model.
DOCS: https://nightly.spacy.io/api/language#replace_listeners
"""
if tok2vec_name not in self.pipe_names:
err = Errors.E889.format(
tok2vec=tok2vec_name,
name=pipe_name,
unknown=tok2vec_name,
opts=", ".join(self.pipe_names),
)
raise ValueError(err)
if pipe_name not in self.pipe_names:
err = Errors.E889.format(
tok2vec=tok2vec_name,
name=pipe_name,
unknown=pipe_name,
opts=", ".join(self.pipe_names),
)
raise ValueError(err)
tok2vec = self.get_pipe(tok2vec_name)
tok2vec_cfg = self.get_pipe_config(tok2vec_name)
if (
not hasattr(tok2vec, "model")
or not hasattr(tok2vec, "listener_map")
or not hasattr(tok2vec, "remove_listener")
or "model" not in tok2vec_cfg
):
raise ValueError(Errors.E888.format(name=tok2vec_name, pipe=type(tok2vec)))
pipe_listeners = tok2vec.listener_map.get(pipe_name, [])
pipe_cfg = self._pipe_configs[pipe_name]
if listeners:
util.logger.debug(f"Replacing listeners of component '{pipe_name}'")
if len(listeners) != len(pipe_listeners):
# The number of listeners defined in the component model doesn't
# match the listeners to replace, so we won't be able to update
# the nodes and generate a matching config
err = Errors.E887.format(
name=pipe_name,
tok2vec=tok2vec_name,
paths=listeners,
n_listeners=len(pipe_listeners),
)
raise ValueError(err)
pipe = self.get_pipe(pipe_name)
# Update the config accordingly by copying the tok2vec model to all
# sections defined in the listener paths
for listener_path in listeners:
# Check if the path actually exists in the config
try:
util.dot_to_object(pipe_cfg, listener_path)
except KeyError:
err = Errors.E886.format(
name=pipe_name, tok2vec=tok2vec_name, path=listener_path
)
raise ValueError(err)
util.set_dot_to_object(pipe_cfg, listener_path, tok2vec_cfg["model"])
# Go over the listener layers and replace them
for listener in pipe_listeners:
util.replace_model_node(pipe.model, listener, tok2vec.model.copy())
tok2vec.remove_listener(listener, pipe_name)
def to_disk(
self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
) -> None:
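
As a usage sketch, condensed from the tests added below (it assumes `nlp` is a pipeline whose `tagger` contains a `Tok2VecListener` connected to a shared `tok2vec` component, as in the test configs):

```python
from spacy.pipeline.tok2vec import Tok2VecListener

tagger = nlp.get_pipe("tagger")
assert isinstance(tagger.model.layers[0], Tok2VecListener)

# Replace the listener with a standalone copy of the tok2vec layer
nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec"])

# The listener node is gone, and the tagger's config now embeds a full
# copy of the tok2vec model in place of the listener architecture
assert not isinstance(tagger.model.layers[0], Tok2VecListener)
t2v_cfg = nlp.config["components"]["tok2vec"]["model"]
assert nlp.config["components"]["tagger"]["model"]["tok2vec"] == t2v_cfg
```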

View File

@@ -82,6 +82,17 @@ class Tok2Vec(TrainablePipe):
self.listener_map.setdefault(component_name, [])
self.listener_map[component_name].append(listener)
def remove_listener(self, listener: "Tok2VecListener", component_name: str) -> bool:
"""Remove a listener for a downstream component. Usually internals."""
if component_name in self.listener_map:
if listener in self.listener_map[component_name]:
self.listener_map[component_name].remove(listener)
# If no listeners are left, remove entry
if not self.listener_map[component_name]:
del self.listener_map[component_name]
return True
return False
def find_listeners(self, component) -> None:
"""Walk over a model of a processing component, looking for layers that
are Tok2vecListener subclasses that have an upstream_name that matches

View File

@@ -1,5 +1,4 @@
import pytest
from spacy.ml.models.tok2vec import build_Tok2Vec_model
from spacy.ml.models.tok2vec import MultiHashEmbed, CharacterEmbed
from spacy.ml.models.tok2vec import MishWindowEncoder, MaxoutWindowEncoder
@@ -7,14 +6,14 @@ from spacy.pipeline.tok2vec import Tok2Vec, Tok2VecListener
from spacy.vocab import Vocab
from spacy.tokens import Doc
from spacy.training import Example
from spacy.training.initialize import init_nlp
from spacy import util
from spacy.lang.en import English
from ..util import get_batch
from thinc.api import Config
from numpy.testing import assert_equal
from ..util import get_batch, make_tempdir
def test_empty_doc():
width = 128
@@ -57,17 +56,17 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
assert doc_vec.shape == (len(doc), width)
# fmt: off
@pytest.mark.parametrize(
"width,embed_arch,embed_config,encode_arch,encode_config",
# fmt: off
[
(8, MultiHashEmbed, {"rows": [100, 100], "attrs": ["SHAPE", "LOWER"], "include_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 2}),
(8, MultiHashEmbed, {"rows": [100, 20], "attrs": ["ORTH", "PREFIX"], "include_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 6}),
(8, CharacterEmbed, {"rows": 100, "nM": 64, "nC": 8, "include_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 3}),
(8, CharacterEmbed, {"rows": 100, "nM": 16, "nC": 2, "include_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 3}),
],
# fmt: on
)
# fmt: on
def test_tok2vec_configs(width, embed_arch, embed_config, encode_arch, encode_config):
embed_config["width"] = width
encode_config["width"] = width
@@ -187,3 +186,128 @@ def test_tok2vec_listener_callback():
Y, get_dX = tagger.model.begin_update(docs)
# assure that the backprop call works (and doesn't hit a 'None' callback)
assert get_dX(Y) is not None
def test_replace_listeners():
orig_config = Config().from_str(cfg_string)
nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
examples = [Example.from_dict(nlp.make_doc("x y"), {"tags": ["V", "Z"]})]
nlp.initialize(lambda: examples)
tok2vec = nlp.get_pipe("tok2vec")
tagger = nlp.get_pipe("tagger")
assert isinstance(tagger.model.layers[0], Tok2VecListener)
assert tok2vec.listener_map["tagger"][0] == tagger.model.layers[0]
assert (
nlp.config["components"]["tok2vec"]["model"]["@architectures"]
== "spacy.Tok2Vec.v2"
)
assert (
nlp.config["components"]["tagger"]["model"]["tok2vec"]["@architectures"]
== "spacy.Tok2VecListener.v1"
)
nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec"])
assert not isinstance(tagger.model.layers[0], Tok2VecListener)
t2v_cfg = nlp.config["components"]["tok2vec"]["model"]
assert t2v_cfg["@architectures"] == "spacy.Tok2Vec.v2"
assert nlp.config["components"]["tagger"]["model"]["tok2vec"] == t2v_cfg
with pytest.raises(ValueError):
nlp.replace_listeners("invalid", "tagger", ["model.tok2vec"])
with pytest.raises(ValueError):
nlp.replace_listeners("tok2vec", "parser", ["model.tok2vec"])
with pytest.raises(ValueError):
nlp.replace_listeners("tok2vec", "tagger", ["model.yolo"])
with pytest.raises(ValueError):
nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec", "model.yolo"])
cfg_string_multi = """
[nlp]
lang = "en"
pipeline = ["tok2vec","tagger", "ner"]
[components]
[components.tagger]
factory = "tagger"
[components.tagger.model]
@architectures = "spacy.Tagger.v1"
nO = null
[components.tagger.model.tok2vec]
@architectures = "spacy.Tok2VecListener.v1"
width = ${components.tok2vec.model.encode.width}
[components.ner]
factory = "ner"
[components.ner.model]
@architectures = "spacy.TransitionBasedParser.v2"
[components.ner.model.tok2vec]
@architectures = "spacy.Tok2VecListener.v1"
width = ${components.tok2vec.model.encode.width}
[components.tok2vec]
factory = "tok2vec"
[components.tok2vec.model]
@architectures = "spacy.Tok2Vec.v2"
[components.tok2vec.model.embed]
@architectures = "spacy.MultiHashEmbed.v1"
width = ${components.tok2vec.model.encode.width}
rows = [2000, 1000, 1000, 1000]
attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
include_static_vectors = false
[components.tok2vec.model.encode]
@architectures = "spacy.MaxoutWindowEncoder.v2"
width = 96
depth = 4
window_size = 1
maxout_pieces = 3
"""
def test_replace_listeners_from_config():
orig_config = Config().from_str(cfg_string_multi)
nlp = util.load_model_from_config(orig_config, auto_fill=True)
annots = {"tags": ["V", "Z"], "entities": [(0, 1, "A"), (1, 2, "B")]}
examples = [Example.from_dict(nlp.make_doc("x y"), annots)]
nlp.initialize(lambda: examples)
tok2vec = nlp.get_pipe("tok2vec")
tagger = nlp.get_pipe("tagger")
ner = nlp.get_pipe("ner")
assert tok2vec.listening_components == ["tagger", "ner"]
assert any(isinstance(node, Tok2VecListener) for node in ner.model.walk())
assert any(isinstance(node, Tok2VecListener) for node in tagger.model.walk())
with make_tempdir() as dir_path:
nlp.to_disk(dir_path)
base_model = str(dir_path)
new_config = {
"nlp": {"lang": "en", "pipeline": ["tok2vec", "tagger", "ner"]},
"components": {
"tok2vec": {"source": base_model},
"tagger": {
"source": base_model,
"replace_listeners": ["model.tok2vec"],
},
"ner": {"source": base_model},
},
}
new_nlp = util.load_model_from_config(new_config, auto_fill=True)
new_nlp.initialize(lambda: examples)
tok2vec = new_nlp.get_pipe("tok2vec")
tagger = new_nlp.get_pipe("tagger")
ner = new_nlp.get_pipe("ner")
assert tok2vec.listening_components == ["ner"]
assert any(isinstance(node, Tok2VecListener) for node in ner.model.walk())
assert not any(isinstance(node, Tok2VecListener) for node in tagger.model.walk())
t2v_cfg = new_nlp.config["components"]["tok2vec"]["model"]
assert t2v_cfg["@architectures"] == "spacy.Tok2Vec.v2"
assert new_nlp.config["components"]["tagger"]["model"]["tok2vec"] == t2v_cfg
assert (
new_nlp.config["components"]["ner"]["model"]["tok2vec"]["@architectures"]
== "spacy.Tok2VecListener.v1"
)

View File

@@ -205,6 +205,25 @@ def test_dot_to_dict(dot_notation, expected):
assert util.dict_to_dot(result) == dot_notation
def test_set_dot_to_object():
config = {"foo": {"bar": 1, "baz": {"x": "y"}}, "test": {"a": {"b": "c"}}}
with pytest.raises(KeyError):
util.set_dot_to_object(config, "foo.bar.baz", 100)
with pytest.raises(KeyError):
util.set_dot_to_object(config, "hello.world", 100)
with pytest.raises(KeyError):
util.set_dot_to_object(config, "test.a.b.c", 100)
util.set_dot_to_object(config, "foo.bar", 100)
assert config["foo"]["bar"] == 100
util.set_dot_to_object(config, "foo.baz.x", {"hello": "world"})
assert config["foo"]["baz"]["x"]["hello"] == "world"
assert config["test"]["a"]["b"] == "c"
util.set_dot_to_object(config, "foo", 123)
assert config["foo"] == 123
util.set_dot_to_object(config, "test", "hello")
assert dict(config) == {"foo": 123, "test": "hello"}
@pytest.mark.parametrize(
"doc_sizes, expected_batches",
[

View File

@@ -1,4 +1,4 @@
from typing import Union, Dict, Optional, Any, List, IO, TYPE_CHECKING
from typing import Union, Dict, Optional, Any, IO, TYPE_CHECKING
from thinc.api import Config, fix_random_seed, set_gpu_allocator
from thinc.api import ConfigValidationError
from pathlib import Path
@@ -14,7 +14,8 @@ from ..vectors import Vectors
from ..errors import Errors, Warnings
from ..schemas import ConfigSchemaTraining
from ..util import registry, load_model_from_config, resolve_dot_names, logger
from ..util import load_model, ensure_path, OOV_RANK, DEFAULT_OOV_PROB
from ..util import load_model, ensure_path, get_sourced_components
from ..util import OOV_RANK, DEFAULT_OOV_PROB
if TYPE_CHECKING:
from ..language import Language # noqa: F401
@@ -33,7 +34,7 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
if use_gpu >= 0 and allocator:
set_gpu_allocator(allocator)
# Use original config here before it's resolved to functions
sourced_components = get_sourced_components(config)
sourced = get_sourced_components(config)
nlp = load_model_from_config(raw_config, auto_fill=True)
logger.info("Set up nlp object from config")
config = nlp.config.interpolate()
@@ -57,7 +58,7 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
# Components that shouldn't be updated during training
frozen_components = T["frozen_components"]
# Sourced components that require resume_training
resume_components = [p for p in sourced_components if p not in frozen_components]
resume_components = [p for p in sourced if p not in frozen_components]
logger.info(f"Pipeline: {nlp.pipe_names}")
if resume_components:
with nlp.select_pipes(enable=resume_components):
@@ -68,10 +69,11 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
logger.info(f"Initialized pipeline components: {nlp.pipe_names}")
# Detect components with listeners that are not frozen consistently
for name, proc in nlp.pipeline:
if getattr(proc, "listening_components", None):
if getattr(proc, "listening_components", None): # e.g. tok2vec/transformer
for listener in proc.listening_components:
if listener in frozen_components and name not in frozen_components:
logger.warning(Warnings.W087.format(name=name, listener=listener))
# We always check this regardless, in case user freezes tok2vec
if listener not in frozen_components and name in frozen_components:
logger.warning(Warnings.W086.format(name=name, listener=listener))
return nlp
@@ -173,18 +175,6 @@ def init_tok2vec(
return False
def get_sourced_components(config: Union[Dict[str, Any], Config]) -> List[str]:
"""RETURNS (List[str]): All sourced components in the original config,
e.g. {"source": "en_core_web_sm"}. If the config contains a key
"factory", we assume it refers to a component factory.
"""
return [
name
for name, cfg in config.get("components", {}).items()
if "factory" not in cfg and "source" in cfg
]
def convert_vectors(
nlp: "Language",
vectors_loc: Optional[Path],

View File

@@ -8,7 +8,7 @@ import re
from pathlib import Path
import thinc
from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer
from thinc.api import ConfigValidationError
from thinc.api import ConfigValidationError, Model
import functools
import itertools
import numpy.random
@@ -434,6 +434,20 @@ def load_model_from_config(
return nlp
def get_sourced_components(
config: Union[Dict[str, Any], Config]
) -> Dict[str, Dict[str, Any]]:
"""RETURNS (List[str]): All sourced components in the original config,
e.g. {"source": "en_core_web_sm"}. If the config contains a key
"factory", we assume it refers to a component factory.
"""
return {
name: cfg
for name, cfg in config.get("components", {}).items()
if "factory" not in cfg and "source" in cfg
}
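
The switch from a list to a dict is what lets `Language.from_config` read per-component settings like `replace_listeners` off the sourced entries. A small sketch of the new return value (the component names and source package are illustrative):

```python
config = {
    "components": {
        "tok2vec": {"factory": "tok2vec"},  # built by factory: not sourced
        "tagger": {"source": "en_core_web_sm", "replace_listeners": ["model.tok2vec"]},
    }
}
sourced = get_sourced_components(config)
assert sourced == {
    "tagger": {"source": "en_core_web_sm", "replace_listeners": ["model.tok2vec"]}
}
# e.g. as used in Language.from_config:
paths = sourced.get("tagger", {}).get("replace_listeners", [])
```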
def resolve_dot_names(config: Config, dot_names: List[Optional[str]]) -> Tuple[Any]:
"""Resolve one or more "dot notation" names, e.g. corpora.train.
The paths could point anywhere into the config, so we don't know which
@@ -738,6 +752,24 @@ def get_package_path(name: str) -> Path:
return Path(pkg.__file__).parent
def replace_model_node(model: Model, target: Model, replacement: Model) -> None:
"""Replace a node within a model with a new one, updating refs.
model (Model): The parent model.
target (Model): The target node.
replacement (Model): The node to replace the target with.
"""
# Place the node into the sublayers
for node in model.walk():
if target in node.layers:
node.layers[node.layers.index(target)] = replacement
# Now fix any node references
for node in model.walk():
for ref_name in node.ref_names:
if node.maybe_get_ref(ref_name) is target:
node.set_ref(ref_name, replacement)
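
A minimal sketch of the node swap on a plain Thinc model (the layer sizes are arbitrary):

```python
from thinc.api import Linear, chain

model = chain(Linear(4, 4), Linear(4, 4))
target = model.layers[1]
replacement = Linear(4, 4)
replace_model_node(model, target, replacement)
assert model.layers[1] is replacement
```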
def split_command(command: str) -> List[str]:
"""Split a string command using shlex. Handles platform compatibility.
@@ -1279,6 +1311,25 @@ def dot_to_object(config: Config, section: str):
return component
def set_dot_to_object(config: Config, section: str, value: Any) -> None:
"""Update a config at a given position from a dot notation.
config (Config): The config.
section (str): The dot notation of the section in the config.
value (Any): The value to set in the config.
"""
component = config
parts = section.split(".")
for i, item in enumerate(parts):
try:
if i == len(parts) - 1:
component[item] = value
else:
component = component[item]
except (KeyError, TypeError):
raise KeyError(Errors.E952.format(name=section)) from None
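
A short sketch of the semantics, mirroring the new test above: existing positions can be overwritten, but a missing intermediate section raises a `KeyError` (E952) rather than being created on the fly:

```python
config = {"foo": {"bar": 1, "baz": {"x": "y"}}}
set_dot_to_object(config, "foo.bar", 100)  # overwrite an existing leaf
assert config["foo"]["bar"] == 100
try:
    set_dot_to_object(config, "hello.world", 100)  # "hello" doesn't exist
except KeyError:
    pass
```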
def walk_dict(
node: Dict[str, Any], parent: List[str] = []
) -> Iterator[Tuple[List[str], Any]]:
@@ -1443,5 +1494,6 @@ def _pipe(docs, proc, name, default_error_handler, kwargs):
def raise_error(proc_name, proc, docs, e):
raise e
def ignore_error(proc_name, proc, docs, e):
pass

View File

@@ -833,6 +833,51 @@ token.ent_iob, token.ent_type
| `pretty` | Pretty-print the results as a table. Defaults to `False`. ~~bool~~ |
| **RETURNS** | Dictionary containing the pipe analysis, keyed by `"summary"` (component meta by pipe), `"problems"` (attribute names by pipe) and `"attrs"` (pipes that assign and require an attribute, keyed by attribute). ~~Optional[Dict[str, Any]]~~ |
## Language.replace_listeners {#replace_listeners tag="method" new="3"}
Find [listener layers](/usage/embeddings-transformers#embedding-layers)
(connecting to a shared token-to-vector embedding component) of a given pipeline
component model and replace them with a standalone copy of the token-to-vector
layer. The listener layer allows other components to connect to a shared
token-to-vector embedding component like [`Tok2Vec`](/api/tok2vec) or
[`Transformer`](/api/transformer). Replacing listeners can be useful when
training a pipeline with components sourced from an existing pipeline: if
multiple components (e.g. tagger, parser, NER) listen to the same
token-to-vector component, but some of them are frozen and not updated, their
performance may degrade significantly as the token-to-vector component is updated
with new data. To prevent this, listeners can be replaced with a standalone
token-to-vector layer that is owned by the component and doesn't change if the
component isn't updated.
This method is typically not called directly and only executed under the hood
when loading a config with
[sourced components](/usage/training#config-components) that define
`replace_listeners`.
> ```python
> ### Example
> nlp = spacy.load("en_core_web_sm")
> nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec"])
> ```
>
> ```ini
> ### config.cfg (excerpt)
> [training]
> frozen_components = ["tagger"]
>
> [components]
>
> [components.tagger]
> source = "en_core_web_sm"
> replace_listeners = ["model.tok2vec"]
> ```
| Name | Description |
| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `tok2vec_name` | Name of the token-to-vector component, typically `"tok2vec"` or `"transformer"`. ~~str~~ |
| `pipe_name` | Name of pipeline component to replace listeners for. ~~str~~ |
| `listeners` | The paths to the listeners, relative to the component config, e.g. `["model.tok2vec"]`. Typically, implementations will only connect to one tok2vec component, `model.tok2vec`, but in theory, custom models can use multiple listeners. The value here can either be an empty list to not replace any listeners, or a _complete_ list of the paths to all listener layers used by the model that should be replaced. ~~Iterable[str]~~ |
## Language.meta {#meta tag="property"}
Meta data for the `Language` class, including name, version, data sources,

View File

@@ -419,13 +419,29 @@ pipeline = ["parser", "ner", "textcat", "custom"]
frozen_components = ["parser", "custom"]
```
<Infobox variant="warning" title="Shared Tok2Vec layer">
<Infobox variant="warning" title="Shared Tok2Vec listener layer" id="config-components-listeners">
When the components in your pipeline
[share an embedding layer](/usage/embeddings-transformers#embedding-layers), the
**performance** of your frozen component will be **degraded** if you continue training
other layers with the same underlying `Tok2Vec` instance. As a rule of thumb,
ensure that your frozen components are truly **independent** in the pipeline.
**performance** of your frozen component will be **degraded** if you continue
training other layers with the same underlying `Tok2Vec` instance. As a rule of
thumb, ensure that your frozen components are truly **independent** in the
pipeline.
To automatically replace a shared token-to-vector listener with an independent
copy of the token-to-vector layer, you can use the `replace_listeners` setting
of a sourced component, pointing to the listener layer(s) in the config. For
more details on how this works under the hood, see
[`Language.replace_listeners`](/api/language#replace_listeners).
```ini
[training]
frozen_components = ["tagger"]
[components.tagger]
source = "en_core_web_sm"
replace_listeners = ["model.tok2vec"]
```
</Infobox>