mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Merge pull request #6852 from explosion/feature/replace-listeners
This commit is contained in:
commit
95e958a229
|
@ -68,7 +68,7 @@ console_scripts =
|
||||||
lookups =
|
lookups =
|
||||||
spacy_lookups_data>=1.0.0,<1.1.0
|
spacy_lookups_data>=1.0.0,<1.1.0
|
||||||
transformers =
|
transformers =
|
||||||
spacy_transformers>=1.0.0rc0,<1.1.0
|
spacy_transformers>=1.0.0rc4,<1.1.0
|
||||||
ray =
|
ray =
|
||||||
spacy_ray>=0.1.0,<1.0.0
|
spacy_ray>=0.1.0,<1.0.0
|
||||||
cuda =
|
cuda =
|
||||||
|
|
|
@ -80,12 +80,22 @@ class Warnings:
|
||||||
|
|
||||||
# TODO: fix numbering after merging develop into master
|
# TODO: fix numbering after merging develop into master
|
||||||
W086 = ("Component '{listener}' will be (re)trained, but it needs the component "
|
W086 = ("Component '{listener}' will be (re)trained, but it needs the component "
|
||||||
"'{name}' which is frozen. You should either freeze both, or neither "
|
"'{name}' which is frozen. You can either freeze both, or neither "
|
||||||
"of the two.")
|
"of the two. If you're sourcing the component from "
|
||||||
|
"an existing pipeline, you can use the `replace_listeners` setting in "
|
||||||
|
"the config block to replace its token-to-vector listener with a copy "
|
||||||
|
"and make it independent. For example, `replace_listeners = "
|
||||||
|
"[\"model.tok2vec\"]` See the documentation for details: "
|
||||||
|
"https://nightly.spacy.io/usage/training#config-components-listeners")
|
||||||
W087 = ("Component '{name}' will be (re)trained, but the component '{listener}' "
|
W087 = ("Component '{name}' will be (re)trained, but the component '{listener}' "
|
||||||
"depends on it and is frozen. This means that the performance of "
|
"depends on it via a listener and is frozen. This means that the "
|
||||||
"'{listener}' will be degraded. You should either freeze both, or "
|
"performance of '{listener}' will be degraded. You can either freeze "
|
||||||
"neither of the two.")
|
"both, or neither of the two. If you're sourcing the component from "
|
||||||
|
"an existing pipeline, you can use the `replace_listeners` setting in "
|
||||||
|
"the config block to replace its token-to-vector listener with a copy "
|
||||||
|
"and make it independent. For example, `replace_listeners = "
|
||||||
|
"[\"model.tok2vec\"]` See the documentation for details: "
|
||||||
|
"https://nightly.spacy.io/usage/training#config-components-listeners")
|
||||||
W088 = ("The pipeline component {name} implements a `begin_training` "
|
W088 = ("The pipeline component {name} implements a `begin_training` "
|
||||||
"method, which won't be called by spaCy. As of v3.0, `begin_training` "
|
"method, which won't be called by spaCy. As of v3.0, `begin_training` "
|
||||||
"has been renamed to `initialize`, so you likely want to rename the "
|
"has been renamed to `initialize`, so you likely want to rename the "
|
||||||
|
@ -475,7 +485,19 @@ class Errors:
|
||||||
"issue tracker: http://github.com/explosion/spaCy/issues")
|
"issue tracker: http://github.com/explosion/spaCy/issues")
|
||||||
|
|
||||||
# TODO: fix numbering after merging develop into master
|
# TODO: fix numbering after merging develop into master
|
||||||
E890 = ("Can not add the alias '{alias}' to the Knowledge base. "
|
E886 = ("Can't replace {name} -> {tok2vec} listeners: path '{path}' not "
|
||||||
|
"found in config for component '{name}'.")
|
||||||
|
E887 = ("Can't replace {name} -> {tok2vec} listeners: the paths to replace "
|
||||||
|
"({paths}) don't match the available listeners in the model ({n_listeners}).")
|
||||||
|
E888 = ("Can't replace listeners for '{name}' ({pipe}): invalid upstream "
|
||||||
|
"component that doesn't seem to support listeners. Expected Tok2Vec "
|
||||||
|
"or Transformer component. If you didn't call nlp.replace_listeners "
|
||||||
|
"manually, this is likely a bug in spaCy.")
|
||||||
|
E889 = ("Can't replace '{tok2vec}' listeners of component '{name}' because "
|
||||||
|
"'{unknown}' is not in the pipeline. Available components: {opts}. "
|
||||||
|
"If you didn't call nlp.replace_listeners manually, this is likely "
|
||||||
|
"a bug in spaCy.")
|
||||||
|
E890 = ("Cannot add the alias '{alias}' to the Knowledge base. "
|
||||||
"Each alias should be a meaningful string.")
|
"Each alias should be a meaningful string.")
|
||||||
E891 = ("Alias '{alias}' could not be added to the Knowledge base. "
|
E891 = ("Alias '{alias}' could not be added to the Knowledge base. "
|
||||||
"This is likely a bug in spaCy.")
|
"This is likely a bug in spaCy.")
|
||||||
|
|
|
@ -1629,6 +1629,7 @@ class Language:
|
||||||
# Later we replace the component config with the raw config again.
|
# Later we replace the component config with the raw config again.
|
||||||
interpolated = filled.interpolate() if not filled.is_interpolated else filled
|
interpolated = filled.interpolate() if not filled.is_interpolated else filled
|
||||||
pipeline = interpolated.get("components", {})
|
pipeline = interpolated.get("components", {})
|
||||||
|
sourced = util.get_sourced_components(interpolated)
|
||||||
# If components are loaded from a source (existing models), we cache
|
# If components are loaded from a source (existing models), we cache
|
||||||
# them here so they're only loaded once
|
# them here so they're only loaded once
|
||||||
source_nlps = {}
|
source_nlps = {}
|
||||||
|
@ -1671,8 +1672,103 @@ class Language:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
Errors.E942.format(name="pipeline_creation", value=type(nlp))
|
Errors.E942.format(name="pipeline_creation", value=type(nlp))
|
||||||
)
|
)
|
||||||
|
# Detect components with listeners that are not frozen consistently
|
||||||
|
for name, proc in nlp.pipeline:
|
||||||
|
if getattr(proc, "listening_components", None): # e.g. tok2vec/transformer
|
||||||
|
for listener in proc.listening_components:
|
||||||
|
# If it's a component sourced from another pipeline, we check if
|
||||||
|
# the tok2vec listeners should be replaced with standalone tok2vec
|
||||||
|
# models (e.g. so component can be frozen without its performance
|
||||||
|
# degrading when other components/tok2vec are updated)
|
||||||
|
paths = sourced.get(listener, {}).get("replace_listeners", [])
|
||||||
|
if paths:
|
||||||
|
nlp.replace_listeners(name, listener, paths)
|
||||||
return nlp
|
return nlp
|
||||||
|
|
||||||
|
def replace_listeners(
|
||||||
|
self, tok2vec_name: str, pipe_name: str, listeners: Iterable[str],
|
||||||
|
) -> None:
|
||||||
|
"""Find listener layers (connecting to a token-to-vector embedding
|
||||||
|
component) of a given pipeline component model and replace
|
||||||
|
them with a standalone copy of the token-to-vector layer. This can be
|
||||||
|
useful when training a pipeline with components sourced from an existing
|
||||||
|
pipeline: if multiple components (e.g. tagger, parser, NER) listen to
|
||||||
|
the same tok2vec component, but some of them are frozen and not updated,
|
||||||
|
their performance may degrade significally as the tok2vec component is
|
||||||
|
updated with new data. To prevent this, listeners can be replaced with
|
||||||
|
a standalone tok2vec layer that is owned by the component and doesn't
|
||||||
|
change if the component isn't updated.
|
||||||
|
|
||||||
|
tok2vec_name (str): Name of the token-to-vector component, typically
|
||||||
|
"tok2vec" or "transformer".
|
||||||
|
pipe_name (str): Name of pipeline component to replace listeners for.
|
||||||
|
listeners (Iterable[str]): The paths to the listeners, relative to the
|
||||||
|
component config, e.g. ["model.tok2vec"]. Typically, implementations
|
||||||
|
will only connect to one tok2vec component, [model.tok2vec], but in
|
||||||
|
theory, custom models can use multiple listeners. The value here can
|
||||||
|
either be an empty list to not replace any listeners, or a complete
|
||||||
|
(!) list of the paths to all listener layers used by the model.
|
||||||
|
|
||||||
|
DOCS: https://nightly.spacy.io/api/language#replace_listeners
|
||||||
|
"""
|
||||||
|
if tok2vec_name not in self.pipe_names:
|
||||||
|
err = Errors.E889.format(
|
||||||
|
tok2vec=tok2vec_name,
|
||||||
|
name=pipe_name,
|
||||||
|
unknown=tok2vec_name,
|
||||||
|
opts=", ".join(self.pipe_names),
|
||||||
|
)
|
||||||
|
raise ValueError(err)
|
||||||
|
if pipe_name not in self.pipe_names:
|
||||||
|
err = Errors.E889.format(
|
||||||
|
tok2vec=tok2vec_name,
|
||||||
|
name=pipe_name,
|
||||||
|
unknown=pipe_name,
|
||||||
|
opts=", ".join(self.pipe_names),
|
||||||
|
)
|
||||||
|
raise ValueError(err)
|
||||||
|
tok2vec = self.get_pipe(tok2vec_name)
|
||||||
|
tok2vec_cfg = self.get_pipe_config(tok2vec_name)
|
||||||
|
if (
|
||||||
|
not hasattr(tok2vec, "model")
|
||||||
|
or not hasattr(tok2vec, "listener_map")
|
||||||
|
or not hasattr(tok2vec, "remove_listener")
|
||||||
|
or "model" not in tok2vec_cfg
|
||||||
|
):
|
||||||
|
raise ValueError(Errors.E888.format(name=tok2vec_name, pipe=type(tok2vec)))
|
||||||
|
pipe_listeners = tok2vec.listener_map.get(pipe_name, [])
|
||||||
|
pipe_cfg = self._pipe_configs[pipe_name]
|
||||||
|
if listeners:
|
||||||
|
util.logger.debug(f"Replacing listeners of component '{pipe_name}'")
|
||||||
|
if len(listeners) != len(pipe_listeners):
|
||||||
|
# The number of listeners defined in the component model doesn't
|
||||||
|
# match the listeners to replace, so we won't be able to update
|
||||||
|
# the nodes and generate a matching config
|
||||||
|
err = Errors.E887.format(
|
||||||
|
name=pipe_name,
|
||||||
|
tok2vec=tok2vec_name,
|
||||||
|
paths=listeners,
|
||||||
|
n_listeners=len(pipe_listeners),
|
||||||
|
)
|
||||||
|
raise ValueError(err)
|
||||||
|
pipe = self.get_pipe(pipe_name)
|
||||||
|
# Update the config accordingly by copying the tok2vec model to all
|
||||||
|
# sections defined in the listener paths
|
||||||
|
for listener_path in listeners:
|
||||||
|
# Check if the path actually exists in the config
|
||||||
|
try:
|
||||||
|
util.dot_to_object(pipe_cfg, listener_path)
|
||||||
|
except KeyError:
|
||||||
|
err = Errors.E886.format(
|
||||||
|
name=pipe_name, tok2vec=tok2vec_name, path=listener_path
|
||||||
|
)
|
||||||
|
raise ValueError(err)
|
||||||
|
util.set_dot_to_object(pipe_cfg, listener_path, tok2vec_cfg["model"])
|
||||||
|
# Go over the listener layers and replace them
|
||||||
|
for listener in pipe_listeners:
|
||||||
|
util.replace_model_node(pipe.model, listener, tok2vec.model.copy())
|
||||||
|
tok2vec.remove_listener(listener, pipe_name)
|
||||||
|
|
||||||
def to_disk(
|
def to_disk(
|
||||||
self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
|
self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
|
||||||
) -> None:
|
) -> None:
|
||||||
|
|
|
@ -82,6 +82,17 @@ class Tok2Vec(TrainablePipe):
|
||||||
self.listener_map.setdefault(component_name, [])
|
self.listener_map.setdefault(component_name, [])
|
||||||
self.listener_map[component_name].append(listener)
|
self.listener_map[component_name].append(listener)
|
||||||
|
|
||||||
|
def remove_listener(self, listener: "Tok2VecListener", component_name: str) -> bool:
|
||||||
|
"""Remove a listener for a downstream component. Usually internals."""
|
||||||
|
if component_name in self.listener_map:
|
||||||
|
if listener in self.listener_map[component_name]:
|
||||||
|
self.listener_map[component_name].remove(listener)
|
||||||
|
# If no listeners are left, remove entry
|
||||||
|
if not self.listener_map[component_name]:
|
||||||
|
del self.listener_map[component_name]
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
def find_listeners(self, component) -> None:
|
def find_listeners(self, component) -> None:
|
||||||
"""Walk over a model of a processing component, looking for layers that
|
"""Walk over a model of a processing component, looking for layers that
|
||||||
are Tok2vecListener subclasses that have an upstream_name that matches
|
are Tok2vecListener subclasses that have an upstream_name that matches
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from spacy.ml.models.tok2vec import build_Tok2Vec_model
|
from spacy.ml.models.tok2vec import build_Tok2Vec_model
|
||||||
from spacy.ml.models.tok2vec import MultiHashEmbed, CharacterEmbed
|
from spacy.ml.models.tok2vec import MultiHashEmbed, CharacterEmbed
|
||||||
from spacy.ml.models.tok2vec import MishWindowEncoder, MaxoutWindowEncoder
|
from spacy.ml.models.tok2vec import MishWindowEncoder, MaxoutWindowEncoder
|
||||||
|
@ -7,14 +6,14 @@ from spacy.pipeline.tok2vec import Tok2Vec, Tok2VecListener
|
||||||
from spacy.vocab import Vocab
|
from spacy.vocab import Vocab
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc
|
||||||
from spacy.training import Example
|
from spacy.training import Example
|
||||||
|
from spacy.training.initialize import init_nlp
|
||||||
from spacy import util
|
from spacy import util
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from ..util import get_batch
|
|
||||||
|
|
||||||
from thinc.api import Config
|
from thinc.api import Config
|
||||||
|
|
||||||
from numpy.testing import assert_equal
|
from numpy.testing import assert_equal
|
||||||
|
|
||||||
|
from ..util import get_batch, make_tempdir
|
||||||
|
|
||||||
|
|
||||||
def test_empty_doc():
|
def test_empty_doc():
|
||||||
width = 128
|
width = 128
|
||||||
|
@ -57,17 +56,17 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size):
|
||||||
assert doc_vec.shape == (len(doc), width)
|
assert doc_vec.shape == (len(doc), width)
|
||||||
|
|
||||||
|
|
||||||
# fmt: off
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"width,embed_arch,embed_config,encode_arch,encode_config",
|
"width,embed_arch,embed_config,encode_arch,encode_config",
|
||||||
|
# fmt: off
|
||||||
[
|
[
|
||||||
(8, MultiHashEmbed, {"rows": [100, 100], "attrs": ["SHAPE", "LOWER"], "include_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 2}),
|
(8, MultiHashEmbed, {"rows": [100, 100], "attrs": ["SHAPE", "LOWER"], "include_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 2}),
|
||||||
(8, MultiHashEmbed, {"rows": [100, 20], "attrs": ["ORTH", "PREFIX"], "include_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 6}),
|
(8, MultiHashEmbed, {"rows": [100, 20], "attrs": ["ORTH", "PREFIX"], "include_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 6}),
|
||||||
(8, CharacterEmbed, {"rows": 100, "nM": 64, "nC": 8, "include_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 3}),
|
(8, CharacterEmbed, {"rows": 100, "nM": 64, "nC": 8, "include_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 3}),
|
||||||
(8, CharacterEmbed, {"rows": 100, "nM": 16, "nC": 2, "include_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 3}),
|
(8, CharacterEmbed, {"rows": 100, "nM": 16, "nC": 2, "include_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 3}),
|
||||||
],
|
],
|
||||||
|
# fmt: on
|
||||||
)
|
)
|
||||||
# fmt: on
|
|
||||||
def test_tok2vec_configs(width, embed_arch, embed_config, encode_arch, encode_config):
|
def test_tok2vec_configs(width, embed_arch, embed_config, encode_arch, encode_config):
|
||||||
embed_config["width"] = width
|
embed_config["width"] = width
|
||||||
encode_config["width"] = width
|
encode_config["width"] = width
|
||||||
|
@ -187,3 +186,128 @@ def test_tok2vec_listener_callback():
|
||||||
Y, get_dX = tagger.model.begin_update(docs)
|
Y, get_dX = tagger.model.begin_update(docs)
|
||||||
# assure that the backprop call works (and doesn't hit a 'None' callback)
|
# assure that the backprop call works (and doesn't hit a 'None' callback)
|
||||||
assert get_dX(Y) is not None
|
assert get_dX(Y) is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_replace_listeners():
|
||||||
|
orig_config = Config().from_str(cfg_string)
|
||||||
|
nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
|
||||||
|
examples = [Example.from_dict(nlp.make_doc("x y"), {"tags": ["V", "Z"]})]
|
||||||
|
nlp.initialize(lambda: examples)
|
||||||
|
tok2vec = nlp.get_pipe("tok2vec")
|
||||||
|
tagger = nlp.get_pipe("tagger")
|
||||||
|
assert isinstance(tagger.model.layers[0], Tok2VecListener)
|
||||||
|
assert tok2vec.listener_map["tagger"][0] == tagger.model.layers[0]
|
||||||
|
assert (
|
||||||
|
nlp.config["components"]["tok2vec"]["model"]["@architectures"]
|
||||||
|
== "spacy.Tok2Vec.v2"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
nlp.config["components"]["tagger"]["model"]["tok2vec"]["@architectures"]
|
||||||
|
== "spacy.Tok2VecListener.v1"
|
||||||
|
)
|
||||||
|
nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec"])
|
||||||
|
assert not isinstance(tagger.model.layers[0], Tok2VecListener)
|
||||||
|
t2v_cfg = nlp.config["components"]["tok2vec"]["model"]
|
||||||
|
assert t2v_cfg["@architectures"] == "spacy.Tok2Vec.v2"
|
||||||
|
assert nlp.config["components"]["tagger"]["model"]["tok2vec"] == t2v_cfg
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
nlp.replace_listeners("invalid", "tagger", ["model.tok2vec"])
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
nlp.replace_listeners("tok2vec", "parser", ["model.tok2vec"])
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
nlp.replace_listeners("tok2vec", "tagger", ["model.yolo"])
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec", "model.yolo"])
|
||||||
|
|
||||||
|
|
||||||
|
cfg_string_multi = """
|
||||||
|
[nlp]
|
||||||
|
lang = "en"
|
||||||
|
pipeline = ["tok2vec","tagger", "ner"]
|
||||||
|
|
||||||
|
[components]
|
||||||
|
|
||||||
|
[components.tagger]
|
||||||
|
factory = "tagger"
|
||||||
|
|
||||||
|
[components.tagger.model]
|
||||||
|
@architectures = "spacy.Tagger.v1"
|
||||||
|
nO = null
|
||||||
|
|
||||||
|
[components.tagger.model.tok2vec]
|
||||||
|
@architectures = "spacy.Tok2VecListener.v1"
|
||||||
|
width = ${components.tok2vec.model.encode.width}
|
||||||
|
|
||||||
|
[components.ner]
|
||||||
|
factory = "ner"
|
||||||
|
|
||||||
|
[components.ner.model]
|
||||||
|
@architectures = "spacy.TransitionBasedParser.v2"
|
||||||
|
|
||||||
|
[components.ner.model.tok2vec]
|
||||||
|
@architectures = "spacy.Tok2VecListener.v1"
|
||||||
|
width = ${components.tok2vec.model.encode.width}
|
||||||
|
|
||||||
|
[components.tok2vec]
|
||||||
|
factory = "tok2vec"
|
||||||
|
|
||||||
|
[components.tok2vec.model]
|
||||||
|
@architectures = "spacy.Tok2Vec.v2"
|
||||||
|
|
||||||
|
[components.tok2vec.model.embed]
|
||||||
|
@architectures = "spacy.MultiHashEmbed.v1"
|
||||||
|
width = ${components.tok2vec.model.encode.width}
|
||||||
|
rows = [2000, 1000, 1000, 1000]
|
||||||
|
attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
|
||||||
|
include_static_vectors = false
|
||||||
|
|
||||||
|
[components.tok2vec.model.encode]
|
||||||
|
@architectures = "spacy.MaxoutWindowEncoder.v2"
|
||||||
|
width = 96
|
||||||
|
depth = 4
|
||||||
|
window_size = 1
|
||||||
|
maxout_pieces = 3
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def test_replace_listeners_from_config():
|
||||||
|
orig_config = Config().from_str(cfg_string_multi)
|
||||||
|
nlp = util.load_model_from_config(orig_config, auto_fill=True)
|
||||||
|
annots = {"tags": ["V", "Z"], "entities": [(0, 1, "A"), (1, 2, "B")]}
|
||||||
|
examples = [Example.from_dict(nlp.make_doc("x y"), annots)]
|
||||||
|
nlp.initialize(lambda: examples)
|
||||||
|
tok2vec = nlp.get_pipe("tok2vec")
|
||||||
|
tagger = nlp.get_pipe("tagger")
|
||||||
|
ner = nlp.get_pipe("ner")
|
||||||
|
assert tok2vec.listening_components == ["tagger", "ner"]
|
||||||
|
assert any(isinstance(node, Tok2VecListener) for node in ner.model.walk())
|
||||||
|
assert any(isinstance(node, Tok2VecListener) for node in tagger.model.walk())
|
||||||
|
with make_tempdir() as dir_path:
|
||||||
|
nlp.to_disk(dir_path)
|
||||||
|
base_model = str(dir_path)
|
||||||
|
new_config = {
|
||||||
|
"nlp": {"lang": "en", "pipeline": ["tok2vec", "tagger", "ner"]},
|
||||||
|
"components": {
|
||||||
|
"tok2vec": {"source": base_model},
|
||||||
|
"tagger": {
|
||||||
|
"source": base_model,
|
||||||
|
"replace_listeners": ["model.tok2vec"],
|
||||||
|
},
|
||||||
|
"ner": {"source": base_model},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
new_nlp = util.load_model_from_config(new_config, auto_fill=True)
|
||||||
|
new_nlp.initialize(lambda: examples)
|
||||||
|
tok2vec = new_nlp.get_pipe("tok2vec")
|
||||||
|
tagger = new_nlp.get_pipe("tagger")
|
||||||
|
ner = new_nlp.get_pipe("ner")
|
||||||
|
assert tok2vec.listening_components == ["ner"]
|
||||||
|
assert any(isinstance(node, Tok2VecListener) for node in ner.model.walk())
|
||||||
|
assert not any(isinstance(node, Tok2VecListener) for node in tagger.model.walk())
|
||||||
|
t2v_cfg = new_nlp.config["components"]["tok2vec"]["model"]
|
||||||
|
assert t2v_cfg["@architectures"] == "spacy.Tok2Vec.v2"
|
||||||
|
assert new_nlp.config["components"]["tagger"]["model"]["tok2vec"] == t2v_cfg
|
||||||
|
assert (
|
||||||
|
new_nlp.config["components"]["ner"]["model"]["tok2vec"]["@architectures"]
|
||||||
|
== "spacy.Tok2VecListener.v1"
|
||||||
|
)
|
||||||
|
|
|
@ -205,6 +205,25 @@ def test_dot_to_dict(dot_notation, expected):
|
||||||
assert util.dict_to_dot(result) == dot_notation
|
assert util.dict_to_dot(result) == dot_notation
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_dot_to_object():
|
||||||
|
config = {"foo": {"bar": 1, "baz": {"x": "y"}}, "test": {"a": {"b": "c"}}}
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
util.set_dot_to_object(config, "foo.bar.baz", 100)
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
util.set_dot_to_object(config, "hello.world", 100)
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
util.set_dot_to_object(config, "test.a.b.c", 100)
|
||||||
|
util.set_dot_to_object(config, "foo.bar", 100)
|
||||||
|
assert config["foo"]["bar"] == 100
|
||||||
|
util.set_dot_to_object(config, "foo.baz.x", {"hello": "world"})
|
||||||
|
assert config["foo"]["baz"]["x"]["hello"] == "world"
|
||||||
|
assert config["test"]["a"]["b"] == "c"
|
||||||
|
util.set_dot_to_object(config, "foo", 123)
|
||||||
|
assert config["foo"] == 123
|
||||||
|
util.set_dot_to_object(config, "test", "hello")
|
||||||
|
assert dict(config) == {"foo": 123, "test": "hello"}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"doc_sizes, expected_batches",
|
"doc_sizes, expected_batches",
|
||||||
[
|
[
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
from typing import Union, Dict, Optional, Any, List, IO, TYPE_CHECKING
|
from typing import Union, Dict, Optional, Any, IO, TYPE_CHECKING
|
||||||
from thinc.api import Config, fix_random_seed, set_gpu_allocator
|
from thinc.api import Config, fix_random_seed, set_gpu_allocator
|
||||||
from thinc.api import ConfigValidationError
|
from thinc.api import ConfigValidationError
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
@ -14,7 +14,8 @@ from ..vectors import Vectors
|
||||||
from ..errors import Errors, Warnings
|
from ..errors import Errors, Warnings
|
||||||
from ..schemas import ConfigSchemaTraining
|
from ..schemas import ConfigSchemaTraining
|
||||||
from ..util import registry, load_model_from_config, resolve_dot_names, logger
|
from ..util import registry, load_model_from_config, resolve_dot_names, logger
|
||||||
from ..util import load_model, ensure_path, OOV_RANK, DEFAULT_OOV_PROB
|
from ..util import load_model, ensure_path, get_sourced_components
|
||||||
|
from ..util import OOV_RANK, DEFAULT_OOV_PROB
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..language import Language # noqa: F401
|
from ..language import Language # noqa: F401
|
||||||
|
@ -33,7 +34,7 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
|
||||||
if use_gpu >= 0 and allocator:
|
if use_gpu >= 0 and allocator:
|
||||||
set_gpu_allocator(allocator)
|
set_gpu_allocator(allocator)
|
||||||
# Use original config here before it's resolved to functions
|
# Use original config here before it's resolved to functions
|
||||||
sourced_components = get_sourced_components(config)
|
sourced = get_sourced_components(config)
|
||||||
nlp = load_model_from_config(raw_config, auto_fill=True)
|
nlp = load_model_from_config(raw_config, auto_fill=True)
|
||||||
logger.info("Set up nlp object from config")
|
logger.info("Set up nlp object from config")
|
||||||
config = nlp.config.interpolate()
|
config = nlp.config.interpolate()
|
||||||
|
@ -57,7 +58,7 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
|
||||||
# Components that shouldn't be updated during training
|
# Components that shouldn't be updated during training
|
||||||
frozen_components = T["frozen_components"]
|
frozen_components = T["frozen_components"]
|
||||||
# Sourced components that require resume_training
|
# Sourced components that require resume_training
|
||||||
resume_components = [p for p in sourced_components if p not in frozen_components]
|
resume_components = [p for p in sourced if p not in frozen_components]
|
||||||
logger.info(f"Pipeline: {nlp.pipe_names}")
|
logger.info(f"Pipeline: {nlp.pipe_names}")
|
||||||
if resume_components:
|
if resume_components:
|
||||||
with nlp.select_pipes(enable=resume_components):
|
with nlp.select_pipes(enable=resume_components):
|
||||||
|
@ -68,10 +69,11 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
|
||||||
logger.info(f"Initialized pipeline components: {nlp.pipe_names}")
|
logger.info(f"Initialized pipeline components: {nlp.pipe_names}")
|
||||||
# Detect components with listeners that are not frozen consistently
|
# Detect components with listeners that are not frozen consistently
|
||||||
for name, proc in nlp.pipeline:
|
for name, proc in nlp.pipeline:
|
||||||
if getattr(proc, "listening_components", None):
|
if getattr(proc, "listening_components", None): # e.g. tok2vec/transformer
|
||||||
for listener in proc.listening_components:
|
for listener in proc.listening_components:
|
||||||
if listener in frozen_components and name not in frozen_components:
|
if listener in frozen_components and name not in frozen_components:
|
||||||
logger.warning(Warnings.W087.format(name=name, listener=listener))
|
logger.warning(Warnings.W087.format(name=name, listener=listener))
|
||||||
|
# We always check this regardless, in case user freezes tok2vec
|
||||||
if listener not in frozen_components and name in frozen_components:
|
if listener not in frozen_components and name in frozen_components:
|
||||||
logger.warning(Warnings.W086.format(name=name, listener=listener))
|
logger.warning(Warnings.W086.format(name=name, listener=listener))
|
||||||
return nlp
|
return nlp
|
||||||
|
@ -173,18 +175,6 @@ def init_tok2vec(
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_sourced_components(config: Union[Dict[str, Any], Config]) -> List[str]:
|
|
||||||
"""RETURNS (List[str]): All sourced components in the original config,
|
|
||||||
e.g. {"source": "en_core_web_sm"}. If the config contains a key
|
|
||||||
"factory", we assume it refers to a component factory.
|
|
||||||
"""
|
|
||||||
return [
|
|
||||||
name
|
|
||||||
for name, cfg in config.get("components", {}).items()
|
|
||||||
if "factory" not in cfg and "source" in cfg
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def convert_vectors(
|
def convert_vectors(
|
||||||
nlp: "Language",
|
nlp: "Language",
|
||||||
vectors_loc: Optional[Path],
|
vectors_loc: Optional[Path],
|
||||||
|
|
|
@ -8,7 +8,7 @@ import re
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import thinc
|
import thinc
|
||||||
from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer
|
from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer
|
||||||
from thinc.api import ConfigValidationError
|
from thinc.api import ConfigValidationError, Model
|
||||||
import functools
|
import functools
|
||||||
import itertools
|
import itertools
|
||||||
import numpy.random
|
import numpy.random
|
||||||
|
@ -434,6 +434,20 @@ def load_model_from_config(
|
||||||
return nlp
|
return nlp
|
||||||
|
|
||||||
|
|
||||||
|
def get_sourced_components(
|
||||||
|
config: Union[Dict[str, Any], Config]
|
||||||
|
) -> Dict[str, Dict[str, Any]]:
|
||||||
|
"""RETURNS (List[str]): All sourced components in the original config,
|
||||||
|
e.g. {"source": "en_core_web_sm"}. If the config contains a key
|
||||||
|
"factory", we assume it refers to a component factory.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
name: cfg
|
||||||
|
for name, cfg in config.get("components", {}).items()
|
||||||
|
if "factory" not in cfg and "source" in cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def resolve_dot_names(config: Config, dot_names: List[Optional[str]]) -> Tuple[Any]:
|
def resolve_dot_names(config: Config, dot_names: List[Optional[str]]) -> Tuple[Any]:
|
||||||
"""Resolve one or more "dot notation" names, e.g. corpora.train.
|
"""Resolve one or more "dot notation" names, e.g. corpora.train.
|
||||||
The paths could point anywhere into the config, so we don't know which
|
The paths could point anywhere into the config, so we don't know which
|
||||||
|
@ -738,6 +752,24 @@ def get_package_path(name: str) -> Path:
|
||||||
return Path(pkg.__file__).parent
|
return Path(pkg.__file__).parent
|
||||||
|
|
||||||
|
|
||||||
|
def replace_model_node(model: Model, target: Model, replacement: Model) -> None:
|
||||||
|
"""Replace a node within a model with a new one, updating refs.
|
||||||
|
|
||||||
|
model (Model): The parent model.
|
||||||
|
target (Model): The target node.
|
||||||
|
replacement (Model): The node to replace the target with.
|
||||||
|
"""
|
||||||
|
# Place the node into the sublayers
|
||||||
|
for node in model.walk():
|
||||||
|
if target in node.layers:
|
||||||
|
node.layers[node.layers.index(target)] = replacement
|
||||||
|
# Now fix any node references
|
||||||
|
for node in model.walk():
|
||||||
|
for ref_name in node.ref_names:
|
||||||
|
if node.maybe_get_ref(ref_name) is target:
|
||||||
|
node.set_ref(ref_name, replacement)
|
||||||
|
|
||||||
|
|
||||||
def split_command(command: str) -> List[str]:
|
def split_command(command: str) -> List[str]:
|
||||||
"""Split a string command using shlex. Handles platform compatibility.
|
"""Split a string command using shlex. Handles platform compatibility.
|
||||||
|
|
||||||
|
@ -1279,6 +1311,25 @@ def dot_to_object(config: Config, section: str):
|
||||||
return component
|
return component
|
||||||
|
|
||||||
|
|
||||||
|
def set_dot_to_object(config: Config, section: str, value: Any) -> None:
|
||||||
|
"""Update a config at a given position from a dot notation.
|
||||||
|
|
||||||
|
config (Config): The config.
|
||||||
|
section (str): The dot notation of the section in the config.
|
||||||
|
value (Any): The value to set in the config.
|
||||||
|
"""
|
||||||
|
component = config
|
||||||
|
parts = section.split(".")
|
||||||
|
for i, item in enumerate(parts):
|
||||||
|
try:
|
||||||
|
if i == len(parts) - 1:
|
||||||
|
component[item] = value
|
||||||
|
else:
|
||||||
|
component = component[item]
|
||||||
|
except (KeyError, TypeError):
|
||||||
|
raise KeyError(Errors.E952.format(name=section)) from None
|
||||||
|
|
||||||
|
|
||||||
def walk_dict(
|
def walk_dict(
|
||||||
node: Dict[str, Any], parent: List[str] = []
|
node: Dict[str, Any], parent: List[str] = []
|
||||||
) -> Iterator[Tuple[List[str], Any]]:
|
) -> Iterator[Tuple[List[str], Any]]:
|
||||||
|
@ -1443,5 +1494,6 @@ def _pipe(docs, proc, name, default_error_handler, kwargs):
|
||||||
def raise_error(proc_name, proc, docs, e):
|
def raise_error(proc_name, proc, docs, e):
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def ignore_error(proc_name, proc, docs, e):
|
def ignore_error(proc_name, proc, docs, e):
|
||||||
pass
|
pass
|
||||||
|
|
|
@ -833,6 +833,51 @@ token.ent_iob, token.ent_type
|
||||||
| `pretty` | Pretty-print the results as a table. Defaults to `False`. ~~bool~~ |
|
| `pretty` | Pretty-print the results as a table. Defaults to `False`. ~~bool~~ |
|
||||||
| **RETURNS** | Dictionary containing the pipe analysis, keyed by `"summary"` (component meta by pipe), `"problems"` (attribute names by pipe) and `"attrs"` (pipes that assign and require an attribute, keyed by attribute). ~~Optional[Dict[str, Any]]~~ |
|
| **RETURNS** | Dictionary containing the pipe analysis, keyed by `"summary"` (component meta by pipe), `"problems"` (attribute names by pipe) and `"attrs"` (pipes that assign and require an attribute, keyed by attribute). ~~Optional[Dict[str, Any]]~~ |
|
||||||
|
|
||||||
|
## Language.replace_listeners {#replace_listeners tag="method" new="3"}
|
||||||
|
|
||||||
|
Find [listener layers](/usage/embeddings-transformers#embedding-layers)
|
||||||
|
(connecting to a shared token-to-vector embedding component) of a given pipeline
|
||||||
|
component model and replace them with a standalone copy of the token-to-vector
|
||||||
|
layer. The listener layer allows other components to connect to a shared
|
||||||
|
token-to-vector embedding component like [`Tok2Vec`](/api/tok2vec) or
|
||||||
|
[`Transformer`](/api/transformer). Replacing listeners can be useful when
|
||||||
|
training a pipeline with components sourced from an existing pipeline: if
|
||||||
|
multiple components (e.g. tagger, parser, NER) listen to the same
|
||||||
|
token-to-vector component, but some of them are frozen and not updated, their
|
||||||
|
performance may degrade significally as the token-to-vector component is updated
|
||||||
|
with new data. To prevent this, listeners can be replaced with a standalone
|
||||||
|
token-to-vector layer that is owned by the component and doesn't change if the
|
||||||
|
component isn't updated.
|
||||||
|
|
||||||
|
This method is typically not called directly and only executed under the hood
|
||||||
|
when loading a config with
|
||||||
|
[sourced components](/usage/training#config-components) that define
|
||||||
|
`replace_listeners`.
|
||||||
|
|
||||||
|
> ```python
|
||||||
|
> ### Example
|
||||||
|
> nlp = spacy.load("en_core_web_sm")
|
||||||
|
> nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec"])
|
||||||
|
> ```
|
||||||
|
>
|
||||||
|
> ```ini
|
||||||
|
> ### config.cfg (excerpt)
|
||||||
|
> [training]
|
||||||
|
> frozen_components = ["tagger"]
|
||||||
|
>
|
||||||
|
> [components]
|
||||||
|
>
|
||||||
|
> [components.tagger]
|
||||||
|
> source = "en_core_web_sm"
|
||||||
|
> replace_listeners = ["model.tok2vec"]
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| `tok2vec_name` | Name of the token-to-vector component, typically `"tok2vec"` or `"transformer"`.~~str~~ |
|
||||||
|
| `pipe_name` | Name of pipeline component to replace listeners for. ~~str~~ |
|
||||||
|
| `listeners` | The paths to the listeners, relative to the component config, e.g. `["model.tok2vec"]`. Typically, implementations will only connect to one tok2vec component, `model.tok2vec`, but in theory, custom models can use multiple listeners. The value here can either be an empty list to not replace any listeners, or a _complete_ list of the paths to all listener layers used by the model that should be replaced.~~Iterable[str]~~ |
|
||||||
|
|
||||||
## Language.meta {#meta tag="property"}
|
## Language.meta {#meta tag="property"}
|
||||||
|
|
||||||
Meta data for the `Language` class, including name, version, data sources,
|
Meta data for the `Language` class, including name, version, data sources,
|
||||||
|
|
|
@ -419,13 +419,29 @@ pipeline = ["parser", "ner", "textcat", "custom"]
|
||||||
frozen_components = ["parser", "custom"]
|
frozen_components = ["parser", "custom"]
|
||||||
```
|
```
|
||||||
|
|
||||||
<Infobox variant="warning" title="Shared Tok2Vec layer">
|
<Infobox variant="warning" title="Shared Tok2Vec listener layer" id="config-components-listeners">
|
||||||
|
|
||||||
When the components in your pipeline
|
When the components in your pipeline
|
||||||
[share an embedding layer](/usage/embeddings-transformers#embedding-layers), the
|
[share an embedding layer](/usage/embeddings-transformers#embedding-layers), the
|
||||||
**performance** of your frozen component will be **degraded** if you continue training
|
**performance** of your frozen component will be **degraded** if you continue
|
||||||
other layers with the same underlying `Tok2Vec` instance. As a rule of thumb,
|
training other layers with the same underlying `Tok2Vec` instance. As a rule of
|
||||||
ensure that your frozen components are truly **independent** in the pipeline.
|
thumb, ensure that your frozen components are truly **independent** in the
|
||||||
|
pipeline.
|
||||||
|
|
||||||
|
To automatically replace a shared token-to-vector listener with an independent
|
||||||
|
copy of the token-to-vector layer, you can use the `replace_listeners` setting
|
||||||
|
of a sourced component, pointing to the listener layer(s) in the config. For
|
||||||
|
more details on how this works under the hood, see
|
||||||
|
[`Language.replace_listeners`](/api/language#replace_listeners).
|
||||||
|
|
||||||
|
```ini
|
||||||
|
[training]
|
||||||
|
frozen_components = ["tagger"]
|
||||||
|
|
||||||
|
[components.tagger]
|
||||||
|
source = "en_core_web_sm"
|
||||||
|
replace_listeners = ["model.tok2vec"]
|
||||||
|
```
|
||||||
|
|
||||||
</Infobox>
|
</Infobox>
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user