From 911dfcccfcdfe1e9effc38b013f0266dbbd79afb Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Fri, 29 Jan 2021 15:57:04 +1100 Subject: [PATCH 01/15] Add option to replace listeners for sourced components --- spacy/language.py | 43 +++++++++++++++++++++++++++- spacy/tests/pipeline/test_tok2vec.py | 32 ++++++++++++++++++--- spacy/tests/test_misc.py | 19 ++++++++++++ spacy/training/initialize.py | 28 ++++++++++++------ spacy/util.py | 39 ++++++++++++++++++++++++- 5 files changed, 146 insertions(+), 15 deletions(-) diff --git a/spacy/language.py b/spacy/language.py index 6e617e31c..7749ba360 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -8,7 +8,7 @@ from contextlib import contextmanager from copy import deepcopy from pathlib import Path import warnings -from thinc.api import Model, get_current_ops, Config, Optimizer +from thinc.api import get_current_ops, Config, Optimizer import srsly import multiprocessing as mp from itertools import chain, cycle @@ -670,6 +670,47 @@ class Language: self._pipe_configs[name] = filled return resolved[factory_name] + def replace_listeners( + self, + tok2vec_name: str, + pipe_name: str, + listeners: Iterable[str] = SimpleFrozenList(), + ): + if tok2vec_name not in self.pipe_names: + raise ValueError # TODO: + if pipe_name not in self.pipe_names: + raise ValueError # TODO: + tok2vec = self.get_pipe(tok2vec_name) + tok2vec_cfg = self.get_pipe_config(tok2vec_name) + if ( + not hasattr(tok2vec, "model") + or not hasattr(tok2vec, "listener_map") + or "model" not in tok2vec_cfg + ): + raise ValueError # TODO: likely bug in spaCy if this happens + pipe_listeners = tok2vec.listener_map.get(pipe_name, []) + pipe_cfg = self._pipe_configs[pipe_name] + if listeners: + util.logger.debug(f"Replacing listeners of component '{pipe_name}'") + if len(listeners) != len(pipe_listeners): + # The number of listeners defined in the component model doesn't + # match the listeners to replace, so we won't be able to update + # the nodes and generate a matching config + raise ValueError(f"{listeners}, {pipe_listeners}") # TODO: + pipe = self.get_pipe(pipe_name) + # Go over the listener layers and replace them + for listener in pipe_listeners: + util.replace_model_node(pipe.model, listener, tok2vec.model.copy()) + # Update the config accordingly by coping the tok2vec model to all + # sections defined in the listener paths + for listener_path in listeners: + # Check if the path actually exists in the config + try: + util.dot_to_object(pipe_cfg, listener_path) + except KeyError: + raise ValueError # TODO: + util.set_dot_to_object(pipe_cfg, listener_path, tok2vec_cfg["model"]) + def create_pipe_from_source( self, source_name: str, source: "Language", *, name: str ) -> Tuple[Callable[[Doc], Doc], str]: diff --git a/spacy/tests/pipeline/test_tok2vec.py b/spacy/tests/pipeline/test_tok2vec.py index 90052a9c8..56037e4b8 100644 --- a/spacy/tests/pipeline/test_tok2vec.py +++ b/spacy/tests/pipeline/test_tok2vec.py @@ -1,5 +1,4 @@ import pytest - from spacy.ml.models.tok2vec import build_Tok2Vec_model from spacy.ml.models.tok2vec import MultiHashEmbed, CharacterEmbed from spacy.ml.models.tok2vec import MishWindowEncoder, MaxoutWindowEncoder @@ -9,12 +8,11 @@ from spacy.tokens import Doc from spacy.training import Example from spacy import util from spacy.lang.en import English -from ..util import get_batch - from thinc.api import Config - from numpy.testing import assert_equal +from ..util import get_batch + def test_empty_doc(): width = 128 @@ -187,3 +185,29 @@ def test_tok2vec_listener_callback(): Y, get_dX = tagger.model.begin_update(docs) # assure that the backprop call works (and doesn't hit a 'None' callback) assert get_dX(Y) is not None + + +def test_replace_listeners(): + orig_config = Config().from_str(cfg_string) + nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True) + examples = [Example.from_dict(nlp.make_doc("x y"), {"tags": ["V", "Z"]})] + nlp.initialize(lambda: examples) + tok2vec = nlp.get_pipe("tok2vec") + tagger = nlp.get_pipe("tagger") + assert isinstance(tagger.model.layers[0], Tok2VecListener) + assert tok2vec.listener_map["tagger"][0] == tagger.model.layers[0] + assert nlp.config["components"]["tok2vec"]["model"]["@architectures"] == "spacy.Tok2Vec.v2" + assert nlp.config["components"]["tagger"]["model"]["tok2vec"]["@architectures"] == "spacy.Tok2VecListener.v1" + nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec"]) + assert not isinstance(tagger.model.layers[0], Tok2VecListener) + t2v_cfg = nlp.config["components"]["tok2vec"]["model"] + assert t2v_cfg["@architectures"] == "spacy.Tok2Vec.v2" + assert nlp.config["components"]["tagger"]["model"]["tok2vec"] == t2v_cfg + with pytest.raises(ValueError): + nlp.replace_listeners("invalid", "tagger", ["model.tok2vec"]) + with pytest.raises(ValueError): + nlp.replace_listeners("tok2vec", "parser", ["model.tok2vec"]) + with pytest.raises(ValueError): + nlp.replace_listeners("tok2vec", "tagger", ["model.yolo"]) + with pytest.raises(ValueError): + nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec", "model.yolo"]) diff --git a/spacy/tests/test_misc.py b/spacy/tests/test_misc.py index bdb2b9752..e694baa40 100644 --- a/spacy/tests/test_misc.py +++ b/spacy/tests/test_misc.py @@ -205,6 +205,25 @@ def test_dot_to_dict(dot_notation, expected): assert util.dict_to_dot(result) == dot_notation +def test_set_dot_to_object(): + config = {"foo": {"bar": 1, "baz": {"x": "y"}}, "test": {"a": {"b": "c"}}} + with pytest.raises(KeyError): + util.set_dot_to_object(config, "foo.bar.baz", 100) + with pytest.raises(KeyError): + util.set_dot_to_object(config, "hello.world", 100) + with pytest.raises(KeyError): + util.set_dot_to_object(config, "test.a.b.c", 100) + util.set_dot_to_object(config, "foo.bar", 100) + assert config["foo"]["bar"] == 100 + util.set_dot_to_object(config, "foo.baz.x", {"hello": "world"}) + assert config["foo"]["baz"]["x"]["hello"] == "world" + assert config["test"]["a"]["b"] == "c" + util.set_dot_to_object(config, "foo", 123) + assert config["foo"] == 123 + util.set_dot_to_object(config, "test", "hello") + assert dict(config) == {"foo": 123, "test": "hello"} + + @pytest.mark.parametrize( "doc_sizes, expected_batches", [ diff --git a/spacy/training/initialize.py b/spacy/training/initialize.py index 42bab6b4f..4cf8fa354 100644 --- a/spacy/training/initialize.py +++ b/spacy/training/initialize.py @@ -1,4 +1,4 @@ -from typing import Union, Dict, Optional, Any, List, IO, TYPE_CHECKING +from typing import Union, Dict, Optional, Any, IO, TYPE_CHECKING from thinc.api import Config, fix_random_seed, set_gpu_allocator from thinc.api import ConfigValidationError from pathlib import Path @@ -33,7 +33,7 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language": if use_gpu >= 0 and allocator: set_gpu_allocator(allocator) # Use original config here before it's resolved to functions - sourced_components = get_sourced_components(config) + sourced = get_sourced_components(config) nlp = load_model_from_config(raw_config, auto_fill=True) logger.info("Set up nlp object from config") config = nlp.config.interpolate() @@ -57,7 +57,7 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language": # Components that shouldn't be updated during training frozen_components = T["frozen_components"] # Sourced components that require resume_training - resume_components = [p for p in sourced_components if p not in frozen_components] + resume_components = [p for p in sourced if p not in frozen_components] logger.info(f"Pipeline: {nlp.pipe_names}") if resume_components: with nlp.select_pipes(enable=resume_components): @@ -68,10 +68,18 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language": logger.info(f"Initialized pipeline components: {nlp.pipe_names}") # Detect components with listeners that are not frozen consistently for name, proc in nlp.pipeline: - if getattr(proc, "listening_components", None): + if getattr(proc, "listening_components", None): # e.g. tok2vec/transformer for listener in proc.listening_components: - if listener in frozen_components and name not in frozen_components: + # If it's a component sourced from another pipeline, we check if + # the tok2vec listeners should be replaced with standalone tok2vec + # models (e.g. so component can be frozen without its performance + # degrading when other components/tok2vec are updated) + paths = sourced.get(listener, {}).get("replace_listeners", []) + if paths: + nlp.replace_listeners(name, listener, paths) + elif listener in frozen_components and name not in frozen_components: logger.warning(Warnings.W087.format(name=name, listener=listener)) + # We always check this regardless, in case user freezes tok2vec if listener not in frozen_components and name in frozen_components: logger.warning(Warnings.W086.format(name=name, listener=listener)) return nlp @@ -173,16 +181,18 @@ def init_tok2vec( return False -def get_sourced_components(config: Union[Dict[str, Any], Config]) -> List[str]: +def get_sourced_components( + config: Union[Dict[str, Any], Config] +) -> Dict[str, Dict[str, Any]]: """RETURNS (List[str]): All sourced components in the original config, e.g. {"source": "en_core_web_sm"}. If the config contains a key "factory", we assume it refers to a component factory. """ - return [ - name + return { + name: cfg for name, cfg in config.get("components", {}).items() if "factory" not in cfg and "source" in cfg - ] + } def convert_vectors( diff --git a/spacy/util.py b/spacy/util.py index 77aa712d1..dbd862687 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -8,7 +8,7 @@ import re from pathlib import Path import thinc from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer -from thinc.api import ConfigValidationError +from thinc.api import ConfigValidationError, Model import functools import itertools import numpy.random @@ -738,6 +738,24 @@ def get_package_path(name: str) -> Path: return Path(pkg.__file__).parent +def replace_model_node(model: Model, target: Model, replacement: Model) -> None: + """Replace a node within a model with a new one, updating refs. + + model (Model): The parent model. + target (Model): The target node. + replacement (Model): The node to replace the target with. + """ + # Place the node into the sublayers + for node in model.walk(): + if target in node.layers: + node.layers[node.layers.index(target)] = replacement + # Now fix any node references + for node in model.walk(): + for ref_name in node.ref_names: + if node.maybe_get_ref(ref_name) is target: + node.set_ref(ref_name, replacement) + + def split_command(command: str) -> List[str]: """Split a string command using shlex. Handles platform compatibility. @@ -1279,6 +1297,25 @@ def dot_to_object(config: Config, section: str): return component +def set_dot_to_object(config: Config, section: str, value: Any) -> None: + """Update a config at a given position from a dot notation. + + config (Config): The config. + section (str): The dot notation of the section in the config. + value (Any): The value to set in the config. + """ + component = config + parts = section.split(".") + for i, item in enumerate(parts): + try: + if i == len(parts) - 1: + component[item] = value + else: + component = component[item] + except (KeyError, TypeError): + raise KeyError(Errors.E952.format(name=section)) from None + + def walk_dict( node: Dict[str, Any], parent: List[str] = [] ) -> Iterator[Tuple[List[str], Any]]: From bbb94b37c61d090e307da16519390839513dfc52 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Fri, 29 Jan 2021 16:27:49 +1100 Subject: [PATCH 02/15] Update error handling and docstring --- spacy/errors.py | 13 ++++++++++++- spacy/language.py | 46 ++++++++++++++++++++++++++++++++++++++++------ 2 files changed, 52 insertions(+), 7 deletions(-) diff --git a/spacy/errors.py b/spacy/errors.py index a50e986ac..4d66fd0ef 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -475,7 +475,18 @@ class Errors: "issue tracker: http://github.com/explosion/spaCy/issues") # TODO: fix numbering after merging develop into master - E890 = ("Can not add the alias '{alias}' to the Knowledge base. " + E886 = ("Can't replace {name} -> {tok2vec} listeners: path '{path}' not " + "found in config for component '{name}'.") + E887 = ("Can't replace {name} -> {tok2vec} listeners: the paths to replace " + "({paths}) don't match the available listeners in the model ({n_listeners}).") + E888 = ("Can't replace listeners for '{name}' ({pipe}): invalid upstream " + "component that doesn't seem to support listeners. Expected Tok2Vec " + "or Transformer component. If you didn't call nlp.replace_listeners " + "manually, this is likely a bug in spaCy.") + E889 = ("Can't replace listeners of component '{name}' because it's not " + "in the pipeline. Available components: {opts}. If you didn't call " + "nlp.replace_listeners manually, this is likely a bug in spaCy.") + E890 = ("Cannot add the alias '{alias}' to the Knowledge base. " "Each alias should be a meaningful string.") E891 = ("Alias '{alias}' could not be added to the Knowledge base. " "This is likely a bug in spaCy.") diff --git a/spacy/language.py b/spacy/language.py index cc079af62..f0d311e5d 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -676,11 +676,36 @@ class Language: tok2vec_name: str, pipe_name: str, listeners: Iterable[str] = SimpleFrozenList(), - ): + ) -> None: + """Find listener layers (connecting to a token-to-vector embedding + component) of a given pipeline component model and replace + them with a standalone copy of the token-to-vector layer. This can be + useful when training a pipeline with components sourced from an existing + pipeline: if multiple components (e.g. tagger, parser, NER) listen to + the same tok2vec component, but some of them are frozen and not updated, + their performance may degrade significally as the tok2vec component is + updated with new data. To prevent this, listeners can be replaced with + a standalone tok2vec layer that is owned by the component and doesn't + change if the component isn't updated. + + tok2vec_name (str): Name of the token-to-vector component, typically + "tok2vec" or "transformer". + pipe_name (str): Name of pipeline component to replace listeners for. + listeners (Iterable[str]): The paths to the listeners, relative to the + component config, e.g. ["model.tok2vec"]. Typically, implementations + will only connect to one tok2vec component, [model.tok2vec], but in + theory, custom models can use multiple listeners. The value here can + either be an empty list to not replace any listeners, or a complete + (!) list of the paths to all listener layers used by the model. + + DOCS: https://nightly.spacy.io/api/language#replace_listeners + """ if tok2vec_name not in self.pipe_names: - raise ValueError # TODO: + err = Errors.E889.format(name=tok2vec_name, opts=", ".join(self.pipe_names)) + raise ValueError(err) if pipe_name not in self.pipe_names: - raise ValueError # TODO: + err = Errors.E889.format(name=pipe_name, opts=", ".join(self.pipe_names)) + raise ValueError(err) tok2vec = self.get_pipe(tok2vec_name) tok2vec_cfg = self.get_pipe_config(tok2vec_name) if ( @@ -688,7 +713,7 @@ class Language: or not hasattr(tok2vec, "listener_map") or "model" not in tok2vec_cfg ): - raise ValueError # TODO: likely bug in spaCy if this happens + raise ValueError(Errors.E888.format(name=tok2vec_name, pipe=type(tok2vec))) pipe_listeners = tok2vec.listener_map.get(pipe_name, []) pipe_cfg = self._pipe_configs[pipe_name] if listeners: @@ -697,7 +722,13 @@ class Language: # The number of listeners defined in the component model doesn't # match the listeners to replace, so we won't be able to update # the nodes and generate a matching config - raise ValueError(f"{listeners}, {pipe_listeners}") # TODO: + err = Errors.E887.format( + name=pipe_name, + tok2vec=tok2vec_name, + paths=listeners, + n_listeners=len(pipe_listeners), + ) + raise ValueError(err) pipe = self.get_pipe(pipe_name) # Go over the listener layers and replace them for listener in pipe_listeners: @@ -709,7 +740,10 @@ class Language: try: util.dot_to_object(pipe_cfg, listener_path) except KeyError: - raise ValueError # TODO: + err = Errors.E886.format( + name=pipe_name, tok2vec=tok2vec_name, path=listener_path + ) + raise ValueError(err) util.set_dot_to_object(pipe_cfg, listener_path, tok2vec_cfg["model"]) def create_pipe_from_source( From 8c15d1daeca1bb05255788a48a22e24a414820a3 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Fri, 29 Jan 2021 18:24:47 +1100 Subject: [PATCH 03/15] Update and validate config first and exit early if paths don't exist --- spacy/language.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/spacy/language.py b/spacy/language.py index f0d311e5d..5a2a0cd65 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -730,9 +730,6 @@ class Language: ) raise ValueError(err) pipe = self.get_pipe(pipe_name) - # Go over the listener layers and replace them - for listener in pipe_listeners: - util.replace_model_node(pipe.model, listener, tok2vec.model.copy()) # Update the config accordingly by coping the tok2vec model to all # sections defined in the listener paths for listener_path in listeners: @@ -745,6 +742,9 @@ class Language: ) raise ValueError(err) util.set_dot_to_object(pipe_cfg, listener_path, tok2vec_cfg["model"]) + # Go over the listener layers and replace them + for listener in pipe_listeners: + util.replace_model_node(pipe.model, listener, tok2vec.model.copy()) def create_pipe_from_source( self, source_name: str, source: "Language", *, name: str From 44b5542d144b26eaa3c69b7b7f26162c57060458 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Fri, 29 Jan 2021 18:42:41 +1100 Subject: [PATCH 04/15] Change method order --- spacy/language.py | 150 +++++++++++++++++++++++----------------------- 1 file changed, 75 insertions(+), 75 deletions(-) diff --git a/spacy/language.py b/spacy/language.py index 5a2a0cd65..12b319fd3 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -671,81 +671,6 @@ class Language: self._pipe_configs[name] = filled return resolved[factory_name] - def replace_listeners( - self, - tok2vec_name: str, - pipe_name: str, - listeners: Iterable[str] = SimpleFrozenList(), - ) -> None: - """Find listener layers (connecting to a token-to-vector embedding - component) of a given pipeline component model and replace - them with a standalone copy of the token-to-vector layer. This can be - useful when training a pipeline with components sourced from an existing - pipeline: if multiple components (e.g. tagger, parser, NER) listen to - the same tok2vec component, but some of them are frozen and not updated, - their performance may degrade significally as the tok2vec component is - updated with new data. To prevent this, listeners can be replaced with - a standalone tok2vec layer that is owned by the component and doesn't - change if the component isn't updated. - - tok2vec_name (str): Name of the token-to-vector component, typically - "tok2vec" or "transformer". - pipe_name (str): Name of pipeline component to replace listeners for. - listeners (Iterable[str]): The paths to the listeners, relative to the - component config, e.g. ["model.tok2vec"]. Typically, implementations - will only connect to one tok2vec component, [model.tok2vec], but in - theory, custom models can use multiple listeners. The value here can - either be an empty list to not replace any listeners, or a complete - (!) list of the paths to all listener layers used by the model. - - DOCS: https://nightly.spacy.io/api/language#replace_listeners - """ - if tok2vec_name not in self.pipe_names: - err = Errors.E889.format(name=tok2vec_name, opts=", ".join(self.pipe_names)) - raise ValueError(err) - if pipe_name not in self.pipe_names: - err = Errors.E889.format(name=pipe_name, opts=", ".join(self.pipe_names)) - raise ValueError(err) - tok2vec = self.get_pipe(tok2vec_name) - tok2vec_cfg = self.get_pipe_config(tok2vec_name) - if ( - not hasattr(tok2vec, "model") - or not hasattr(tok2vec, "listener_map") - or "model" not in tok2vec_cfg - ): - raise ValueError(Errors.E888.format(name=tok2vec_name, pipe=type(tok2vec))) - pipe_listeners = tok2vec.listener_map.get(pipe_name, []) - pipe_cfg = self._pipe_configs[pipe_name] - if listeners: - util.logger.debug(f"Replacing listeners of component '{pipe_name}'") - if len(listeners) != len(pipe_listeners): - # The number of listeners defined in the component model doesn't - # match the listeners to replace, so we won't be able to update - # the nodes and generate a matching config - err = Errors.E887.format( - name=pipe_name, - tok2vec=tok2vec_name, - paths=listeners, - n_listeners=len(pipe_listeners), - ) - raise ValueError(err) - pipe = self.get_pipe(pipe_name) - # Update the config accordingly by coping the tok2vec model to all - # sections defined in the listener paths - for listener_path in listeners: - # Check if the path actually exists in the config - try: - util.dot_to_object(pipe_cfg, listener_path) - except KeyError: - err = Errors.E886.format( - name=pipe_name, tok2vec=tok2vec_name, path=listener_path - ) - raise ValueError(err) - util.set_dot_to_object(pipe_cfg, listener_path, tok2vec_cfg["model"]) - # Go over the listener layers and replace them - for listener in pipe_listeners: - util.replace_model_node(pipe.model, listener, tok2vec.model.copy()) - def create_pipe_from_source( self, source_name: str, source: "Language", *, name: str ) -> Tuple[Callable[[Doc], Doc], str]: @@ -1748,6 +1673,81 @@ class Language: ) return nlp + def replace_listeners( + self, + tok2vec_name: str, + pipe_name: str, + listeners: Iterable[str] = SimpleFrozenList(), + ) -> None: + """Find listener layers (connecting to a token-to-vector embedding + component) of a given pipeline component model and replace + them with a standalone copy of the token-to-vector layer. This can be + useful when training a pipeline with components sourced from an existing + pipeline: if multiple components (e.g. tagger, parser, NER) listen to + the same tok2vec component, but some of them are frozen and not updated, + their performance may degrade significally as the tok2vec component is + updated with new data. To prevent this, listeners can be replaced with + a standalone tok2vec layer that is owned by the component and doesn't + change if the component isn't updated. + + tok2vec_name (str): Name of the token-to-vector component, typically + "tok2vec" or "transformer". + pipe_name (str): Name of pipeline component to replace listeners for. + listeners (Iterable[str]): The paths to the listeners, relative to the + component config, e.g. ["model.tok2vec"]. Typically, implementations + will only connect to one tok2vec component, [model.tok2vec], but in + theory, custom models can use multiple listeners. The value here can + either be an empty list to not replace any listeners, or a complete + (!) list of the paths to all listener layers used by the model. + + DOCS: https://nightly.spacy.io/api/language#replace_listeners + """ + if tok2vec_name not in self.pipe_names: + err = Errors.E889.format(name=tok2vec_name, opts=", ".join(self.pipe_names)) + raise ValueError(err) + if pipe_name not in self.pipe_names: + err = Errors.E889.format(name=pipe_name, opts=", ".join(self.pipe_names)) + raise ValueError(err) + tok2vec = self.get_pipe(tok2vec_name) + tok2vec_cfg = self.get_pipe_config(tok2vec_name) + if ( + not hasattr(tok2vec, "model") + or not hasattr(tok2vec, "listener_map") + or "model" not in tok2vec_cfg + ): + raise ValueError(Errors.E888.format(name=tok2vec_name, pipe=type(tok2vec))) + pipe_listeners = tok2vec.listener_map.get(pipe_name, []) + pipe_cfg = self._pipe_configs[pipe_name] + if listeners: + util.logger.debug(f"Replacing listeners of component '{pipe_name}'") + if len(listeners) != len(pipe_listeners): + # The number of listeners defined in the component model doesn't + # match the listeners to replace, so we won't be able to update + # the nodes and generate a matching config + err = Errors.E887.format( + name=pipe_name, + tok2vec=tok2vec_name, + paths=listeners, + n_listeners=len(pipe_listeners), + ) + raise ValueError(err) + pipe = self.get_pipe(pipe_name) + # Update the config accordingly by coping the tok2vec model to all + # sections defined in the listener paths + for listener_path in listeners: + # Check if the path actually exists in the config + try: + util.dot_to_object(pipe_cfg, listener_path) + except KeyError: + err = Errors.E886.format( + name=pipe_name, tok2vec=tok2vec_name, path=listener_path + ) + raise ValueError(err) + util.set_dot_to_object(pipe_cfg, listener_path, tok2vec_cfg["model"]) + # Go over the listener layers and replace them + for listener in pipe_listeners: + util.replace_model_node(pipe.model, listener, tok2vec.model.copy()) + def to_disk( self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList() ) -> None: From 99842387cbe9a5ba8d38e292d7c3c3ff0df0db51 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Fri, 29 Jan 2021 18:45:37 +1100 Subject: [PATCH 05/15] Remove default value --- spacy/language.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/spacy/language.py b/spacy/language.py index 12b319fd3..3f4755d92 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1674,10 +1674,7 @@ class Language: return nlp def replace_listeners( - self, - tok2vec_name: str, - pipe_name: str, - listeners: Iterable[str] = SimpleFrozenList(), + self, tok2vec_name: str, pipe_name: str, listeners: Iterable[str], ) -> None: """Find listener layers (connecting to a token-to-vector embedding component) of a given pipeline component model and replace From 99af9e7125652412bb4b7efcd20b372ad2a6c670 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Fri, 29 Jan 2021 18:45:48 +1100 Subject: [PATCH 06/15] Update documentation --- website/docs/api/language.md | 45 ++++++++++++++++++++++++++++++++++ website/docs/usage/training.md | 24 +++++++++++++++--- 2 files changed, 65 insertions(+), 4 deletions(-) diff --git a/website/docs/api/language.md b/website/docs/api/language.md index 280a2011f..15ebb255c 100644 --- a/website/docs/api/language.md +++ b/website/docs/api/language.md @@ -833,6 +833,51 @@ token.ent_iob, token.ent_type | `pretty` | Pretty-print the results as a table. Defaults to `False`. ~~bool~~ | | **RETURNS** | Dictionary containing the pipe analysis, keyed by `"summary"` (component meta by pipe), `"problems"` (attribute names by pipe) and `"attrs"` (pipes that assign and require an attribute, keyed by attribute). ~~Optional[Dict[str, Any]]~~ | +## Language.replace_listeners {#replace_listeners tag="method" new="3"} + +Find [listener layers](/usage/embeddings-transformers#embedding-layers) +(connecting to a shared token-to-vector embedding component) of a given pipeline +component model and replace them with a standalone copy of the token-to-vector +layer. The listener layer allows other components to connect to a shared +token-to-vector embedding component like [`Tok2Vec`](/api/tok2vec) or +[`Transformer`](/api/transformer). Replacing listeners can be useful when +training a pipeline with components sourced from an existing pipeline: if +multiple components (e.g. tagger, parser, NER) listen to the same +token-to-vector component, but some of them are frozen and not updated, their +performance may degrade significally as the token-to-vector component is updated +with new data. To prevent this, listeners can be replaced with a standalone +token-to-vector layer that is owned by the component and doesn't change if the +component isn't updated. + +This method is typically not called directly and only executed under the hood +when loading a config with +[sourced components](/usage/training#config-components) that define +`replace_listeners`. + +> ```python +> ### Example +> nlp = spacy.load("en_core_web_sm") +> nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec"]) +> ``` +> +> ```ini +> ### config.cfg (excerpt) +> [training] +> frozen_components = ["tagger"] +> +> [components] +> +> [components.tagger] +> source = "en_core_web_sm" +> replace_listeners = ["model.tok2vec"] +> ``` + +| Name | Description | +| -------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `tok2vec_name` | Name of the token-to-vector component, typically `"tok2vec"` or `"transformer"`.~~str~~ | +| `pipe_name` | Name of pipeline component to replace listeners for. ~~str~~ | +| `listeners` | The paths to the listeners, relative to the component config, e.g. `["model.tok2vec"]`. Typically, implementations will only connect to one tok2vec component, `model.tok2vec`, but in theory, custom models can use multiple listeners. The value here can either be an empty list to not replace any listeners, or a _complete_ list of the paths to all listener layers used by the model.~~Iterable[str]~~ | + ## Language.meta {#meta tag="property"} Meta data for the `Language` class, including name, version, data sources, diff --git a/website/docs/usage/training.md b/website/docs/usage/training.md index 16b2b0f5a..b8e2c512d 100644 --- a/website/docs/usage/training.md +++ b/website/docs/usage/training.md @@ -419,13 +419,29 @@ pipeline = ["parser", "ner", "textcat", "custom"] frozen_components = ["parser", "custom"] ``` - + When the components in your pipeline [share an embedding layer](/usage/embeddings-transformers#embedding-layers), the -**performance** of your frozen component will be **degraded** if you continue training -other layers with the same underlying `Tok2Vec` instance. As a rule of thumb, -ensure that your frozen components are truly **independent** in the pipeline. +**performance** of your frozen component will be **degraded** if you continue +training other layers with the same underlying `Tok2Vec` instance. As a rule of +thumb, ensure that your frozen components are truly **independent** in the +pipeline. + +To automatically replace a shared token-to-vector listener with an independent +copy of the token-to-vector layer, you can use the `replace_listeners` setting +of a sourced component, pointing to the listener layer(s) in the config. For +more details on how this works under the hood, see +[`Language.replace_listeners`](/api/language#replace_listeners). + +```ini +[training] +frozen_components = ["tagger"] + +[components.tagger] +source = "en_core_web_sm" +replace_listeners = ["model.tok2vec"] +``` From 0f3e3eedc2c68e923ac71e39f4fbbacc4d2b2e8b Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Fri, 29 Jan 2021 19:36:38 +1100 Subject: [PATCH 07/15] Add Tok2vec.remove_listener --- spacy/pipeline/tok2vec.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/spacy/pipeline/tok2vec.py b/spacy/pipeline/tok2vec.py index e6ed84530..eb6679834 100644 --- a/spacy/pipeline/tok2vec.py +++ b/spacy/pipeline/tok2vec.py @@ -82,6 +82,15 @@ class Tok2Vec(TrainablePipe): self.listener_map.setdefault(component_name, []) self.listener_map[component_name].append(listener) + def remove_listener(self, listener: "Tok2VecListener", component_name: str) -> None: + """Remove a listener for a downstream component. Usually internals.""" + if component_name in self.listener_map: + if listener in self.listener_map[component_name]: + self.listener_map[component_name].remove(listener) + # If no listeners are left, remove entry + if not self.listener_map[component_name]: + del self.listener_map[component_name] + def find_listeners(self, component) -> None: """Walk over a model of a processing component, looking for layers that are Tok2vecListener subclasses that have an upstream_name that matches From 325f47500d5265ad4240d2903601f178e51efc17 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Fri, 29 Jan 2021 19:37:04 +1100 Subject: [PATCH 08/15] Move replacement logic to Language.from_config --- spacy/language.py | 13 +++++++++++++ spacy/training/initialize.py | 26 +++----------------------- spacy/util.py | 15 +++++++++++++++ 3 files changed, 31 insertions(+), 23 deletions(-) diff --git a/spacy/language.py b/spacy/language.py index 3f4755d92..372220119 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1629,6 +1629,7 @@ class Language: # Later we replace the component config with the raw config again. interpolated = filled.interpolate() if not filled.is_interpolated else filled pipeline = interpolated.get("components", {}) + sourced = util.get_sourced_components(interpolated) # If components are loaded from a source (existing models), we cache # them here so they're only loaded once source_nlps = {} @@ -1671,6 +1672,17 @@ class Language: raise ValueError( Errors.E942.format(name="pipeline_creation", value=type(nlp)) ) + # Detect components with listeners that are not frozen consistently + for name, proc in nlp.pipeline: + if getattr(proc, "listening_components", None): # e.g. tok2vec/transformer + for listener in proc.listening_components: + # If it's a component sourced from another pipeline, we check if + # the tok2vec listeners should be replaced with standalone tok2vec + # models (e.g. so component can be frozen without its performance + # degrading when other components/tok2vec are updated) + paths = sourced.get(listener, {}).get("replace_listeners", []) + if paths: + nlp.replace_listeners(name, listener, paths) return nlp def replace_listeners( @@ -1744,6 +1756,7 @@ class Language: # Go over the listener layers and replace them for listener in pipe_listeners: util.replace_model_node(pipe.model, listener, tok2vec.model.copy()) + tok2vec.remove_listener(listener, pipe_name) def to_disk( self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList() diff --git a/spacy/training/initialize.py b/spacy/training/initialize.py index 4cf8fa354..956e4ec37 100644 --- a/spacy/training/initialize.py +++ b/spacy/training/initialize.py @@ -14,7 +14,8 @@ from ..vectors import Vectors from ..errors import Errors, Warnings from ..schemas import ConfigSchemaTraining from ..util import registry, load_model_from_config, resolve_dot_names, logger -from ..util import load_model, ensure_path, OOV_RANK, DEFAULT_OOV_PROB +from ..util import load_model, ensure_path, get_sourced_components +from ..util import OOV_RANK, DEFAULT_OOV_PROB if TYPE_CHECKING: from ..language import Language # noqa: F401 @@ -70,14 +71,7 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language": for name, proc in nlp.pipeline: if getattr(proc, "listening_components", None): # e.g. tok2vec/transformer for listener in proc.listening_components: - # If it's a component sourced from another pipeline, we check if - # the tok2vec listeners should be replaced with standalone tok2vec - # models (e.g. so component can be frozen without its performance - # degrading when other components/tok2vec are updated) - paths = sourced.get(listener, {}).get("replace_listeners", []) - if paths: - nlp.replace_listeners(name, listener, paths) - elif listener in frozen_components and name not in frozen_components: + if listener in frozen_components and name not in frozen_components: logger.warning(Warnings.W087.format(name=name, listener=listener)) # We always check this regardless, in case user freezes tok2vec if listener not in frozen_components and name in frozen_components: @@ -181,20 +175,6 @@ def init_tok2vec( return False -def get_sourced_components( - config: Union[Dict[str, Any], Config] -) -> Dict[str, Dict[str, Any]]: - """RETURNS (List[str]): All sourced components in the original config, - e.g. {"source": "en_core_web_sm"}. If the config contains a key - "factory", we assume it refers to a component factory. - """ - return { - name: cfg - for name, cfg in config.get("components", {}).items() - if "factory" not in cfg and "source" in cfg - } - - def convert_vectors( nlp: "Language", vectors_loc: Optional[Path], diff --git a/spacy/util.py b/spacy/util.py index 738856563..f1292e327 100644 --- a/spacy/util.py +++ b/spacy/util.py @@ -434,6 +434,20 @@ def load_model_from_config( return nlp +def get_sourced_components( + config: Union[Dict[str, Any], Config] +) -> Dict[str, Dict[str, Any]]: + """RETURNS (List[str]): All sourced components in the original config, + e.g. {"source": "en_core_web_sm"}. If the config contains a key + "factory", we assume it refers to a component factory. + """ + return { + name: cfg + for name, cfg in config.get("components", {}).items() + if "factory" not in cfg and "source" in cfg + } + + def resolve_dot_names(config: Config, dot_names: List[Optional[str]]) -> Tuple[Any]: """Resolve one or more "dot notation" names, e.g. corpora.train. The paths could point anywhere into the config, so we don't know which @@ -1480,5 +1494,6 @@ def _pipe(docs, proc, name, default_error_handler, kwargs): def raise_error(proc_name, proc, docs, e): raise e + def ignore_error(proc_name, proc, docs, e): pass From bc089b693c69b43d0c8197b6bf5aa49739edeb33 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Fri, 29 Jan 2021 19:38:09 +1100 Subject: [PATCH 09/15] Update tests --- spacy/tests/pipeline/test_tok2vec.py | 110 +++++++++++++++++++++++++-- 1 file changed, 105 insertions(+), 5 deletions(-) diff --git a/spacy/tests/pipeline/test_tok2vec.py b/spacy/tests/pipeline/test_tok2vec.py index 56037e4b8..c6ac42dd2 100644 --- a/spacy/tests/pipeline/test_tok2vec.py +++ b/spacy/tests/pipeline/test_tok2vec.py @@ -6,12 +6,13 @@ from spacy.pipeline.tok2vec import Tok2Vec, Tok2VecListener from spacy.vocab import Vocab from spacy.tokens import Doc from spacy.training import Example +from spacy.training.initialize import init_nlp from spacy import util from spacy.lang.en import English from thinc.api import Config from numpy.testing import assert_equal -from ..util import get_batch +from ..util import get_batch, make_tempdir def test_empty_doc(): @@ -55,17 +56,17 @@ def test_tok2vec_batch_sizes(batch_size, width, embed_size): assert doc_vec.shape == (len(doc), width) -# fmt: off @pytest.mark.parametrize( "width,embed_arch,embed_config,encode_arch,encode_config", + # fmt: off [ (8, MultiHashEmbed, {"rows": [100, 100], "attrs": ["SHAPE", "LOWER"], "include_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 2}), (8, MultiHashEmbed, {"rows": [100, 20], "attrs": ["ORTH", "PREFIX"], "include_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 6}), (8, CharacterEmbed, {"rows": 100, "nM": 64, "nC": 8, "include_static_vectors": False}, MaxoutWindowEncoder, {"window_size": 1, "maxout_pieces": 3, "depth": 3}), (8, CharacterEmbed, {"rows": 100, "nM": 16, "nC": 2, "include_static_vectors": False}, MishWindowEncoder, {"window_size": 1, "depth": 3}), ], + # fmt: on ) -# fmt: on def test_tok2vec_configs(width, embed_arch, embed_config, encode_arch, encode_config): embed_config["width"] = width encode_config["width"] = width @@ -196,8 +197,14 @@ def test_replace_listeners(): tagger = nlp.get_pipe("tagger") assert isinstance(tagger.model.layers[0], Tok2VecListener) assert tok2vec.listener_map["tagger"][0] == tagger.model.layers[0] - assert nlp.config["components"]["tok2vec"]["model"]["@architectures"] == "spacy.Tok2Vec.v2" - assert nlp.config["components"]["tagger"]["model"]["tok2vec"]["@architectures"] == "spacy.Tok2VecListener.v1" + assert ( + nlp.config["components"]["tok2vec"]["model"]["@architectures"] + == "spacy.Tok2Vec.v2" + ) + assert ( + nlp.config["components"]["tagger"]["model"]["tok2vec"]["@architectures"] + == "spacy.Tok2VecListener.v1" + ) nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec"]) assert not isinstance(tagger.model.layers[0], Tok2VecListener) t2v_cfg = nlp.config["components"]["tok2vec"]["model"] @@ -211,3 +218,96 @@ def test_replace_listeners(): nlp.replace_listeners("tok2vec", "tagger", ["model.yolo"]) with pytest.raises(ValueError): nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec", "model.yolo"]) + + +cfg_string_multi = """ + [nlp] + lang = "en" + pipeline = ["tok2vec","tagger", "ner"] + + [components] + + [components.tagger] + factory = "tagger" + + [components.tagger.model] + @architectures = "spacy.Tagger.v1" + nO = null + + [components.tagger.model.tok2vec] + @architectures = "spacy.Tok2VecListener.v1" + width = ${components.tok2vec.model.encode.width} + + [components.ner] + factory = "ner" + + [components.ner.model] + @architectures = "spacy.TransitionBasedParser.v2" + + [components.ner.model.tok2vec] + @architectures = "spacy.Tok2VecListener.v1" + width = ${components.tok2vec.model.encode.width} + + [components.tok2vec] + factory = "tok2vec" + + [components.tok2vec.model] + @architectures = "spacy.Tok2Vec.v2" + + [components.tok2vec.model.embed] + @architectures = "spacy.MultiHashEmbed.v1" + width = ${components.tok2vec.model.encode.width} + rows = [2000, 1000, 1000, 1000] + attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"] + include_static_vectors = false + + [components.tok2vec.model.encode] + @architectures = "spacy.MaxoutWindowEncoder.v2" + width = 96 + depth = 4 + window_size = 1 + maxout_pieces = 3 + """ + + +def test_replace_listeners_from_config(): + orig_config = Config().from_str(cfg_string_multi) + nlp = util.load_model_from_config(orig_config, auto_fill=True) + annots = {"tags": ["V", "Z"], "entities": [(0, 1, "A"), (1, 2, "B")]} + examples = [Example.from_dict(nlp.make_doc("x y"), annots)] + nlp.initialize(lambda: examples) + tok2vec = nlp.get_pipe("tok2vec") + tagger = nlp.get_pipe("tagger") + ner = nlp.get_pipe("ner") + assert tok2vec.listening_components == ["tagger", "ner"] + assert any(isinstance(node, Tok2VecListener) for node in ner.model.walk()) + assert any(isinstance(node, Tok2VecListener) for node in tagger.model.walk()) + with make_tempdir() as dir_path: + nlp.to_disk(dir_path) + base_model = str(dir_path) + new_config = { + "nlp": {"lang": "en", "pipeline": ["tok2vec", "tagger", "ner"]}, + "components": { + "tok2vec": {"source": base_model}, + "tagger": { + "source": base_model, + "replace_listeners": ["model.tok2vec"], + }, + "ner": {"source": base_model}, + }, + } + new_nlp = util.load_model_from_config(new_config, auto_fill=True) + new_nlp.initialize(lambda: examples) + tok2vec = new_nlp.get_pipe("tok2vec") + tagger = new_nlp.get_pipe("tagger") + ner = new_nlp.get_pipe("ner") + assert tok2vec.listening_components == ["ner"] + assert any(isinstance(node, Tok2VecListener) for node in ner.model.walk()) + assert not any(isinstance(node, Tok2VecListener) for node in tagger.model.walk()) + t2v_cfg = new_nlp.config["components"]["tok2vec"]["model"] + assert t2v_cfg["@architectures"] == "spacy.Tok2Vec.v2" + assert new_nlp.config["components"]["tagger"]["model"]["tok2vec"] == t2v_cfg + assert ( + new_nlp.config["components"]["ner"]["model"]["tok2vec"]["@architectures"] + == "spacy.Tok2VecListener.v1" + ) From e766e8c56d2fe6ddf9800659cd408d8f4b7141a6 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Fri, 29 Jan 2021 21:41:17 +1100 Subject: [PATCH 10/15] Apply suggestions from code review Co-authored-by: Sofie Van Landeghem --- spacy/language.py | 2 +- website/docs/api/language.md | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/spacy/language.py b/spacy/language.py index 372220119..92710024f 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1741,7 +1741,7 @@ class Language: ) raise ValueError(err) pipe = self.get_pipe(pipe_name) - # Update the config accordingly by coping the tok2vec model to all + # Update the config accordingly by copying the tok2vec model to all # sections defined in the listener paths for listener_path in listeners: # Check if the path actually exists in the config diff --git a/website/docs/api/language.md b/website/docs/api/language.md index 15ebb255c..3d60b7750 100644 --- a/website/docs/api/language.md +++ b/website/docs/api/language.md @@ -876,7 +876,7 @@ when loading a config with | -------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | `tok2vec_name` | Name of the token-to-vector component, typically `"tok2vec"` or `"transformer"`.~~str~~ | | `pipe_name` | Name of pipeline component to replace listeners for. ~~str~~ | -| `listeners` | The paths to the listeners, relative to the component config, e.g. `["model.tok2vec"]`. Typically, implementations will only connect to one tok2vec component, `model.tok2vec`, but in theory, custom models can use multiple listeners. The value here can either be an empty list to not replace any listeners, or a _complete_ list of the paths to all listener layers used by the model.~~Iterable[str]~~ | +| `listeners` | The paths to the listeners, relative to the component config, e.g. `["model.tok2vec"]`. Typically, implementations will only connect to one tok2vec component, `model.tok2vec`, but in theory, custom models can use multiple listeners. The value here can either be an empty list to not replace any listeners, or a _complete_ list of the paths to all listener layers used by the model that should be replaced.~~Iterable[str]~~ | ## Language.meta {#meta tag="property"} From 2102082478e04ce3e1de96c6fafcc8a935fb4693 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Fri, 29 Jan 2021 21:41:38 +1100 Subject: [PATCH 11/15] Make Tok2Vec.remove_listener return bool Whether listener was removed --- spacy/pipeline/tok2vec.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/spacy/pipeline/tok2vec.py b/spacy/pipeline/tok2vec.py index eb6679834..fecf7029b 100644 --- a/spacy/pipeline/tok2vec.py +++ b/spacy/pipeline/tok2vec.py @@ -82,7 +82,7 @@ class Tok2Vec(TrainablePipe): self.listener_map.setdefault(component_name, []) self.listener_map[component_name].append(listener) - def remove_listener(self, listener: "Tok2VecListener", component_name: str) -> None: + def remove_listener(self, listener: "Tok2VecListener", component_name: str) -> bool: """Remove a listener for a downstream component. Usually internals.""" if component_name in self.listener_map: if listener in self.listener_map[component_name]: @@ -90,6 +90,8 @@ class Tok2Vec(TrainablePipe): # If no listeners are left, remove entry if not self.listener_map[component_name]: del self.listener_map[component_name] + return True + return False def find_listeners(self, component) -> None: """Walk over a model of a processing component, looking for layers that From 94232aea088108d3babf27a36d583ffca92f5283 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Fri, 29 Jan 2021 23:39:23 +1100 Subject: [PATCH 12/15] Improve E889 --- spacy/errors.py | 7 ++++--- spacy/language.py | 14 ++++++++++++-- 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/spacy/errors.py b/spacy/errors.py index 4d66fd0ef..17868b605 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -483,9 +483,10 @@ class Errors: "component that doesn't seem to support listeners. Expected Tok2Vec " "or Transformer component. If you didn't call nlp.replace_listeners " "manually, this is likely a bug in spaCy.") - E889 = ("Can't replace listeners of component '{name}' because it's not " - "in the pipeline. Available components: {opts}. If you didn't call " - "nlp.replace_listeners manually, this is likely a bug in spaCy.") + E889 = ("Can't replace '{tok2vec}' listeners of component '{name}' because " + "'{unknown}' is not in the pipeline. Available components: {opts}. " + "If you didn't call nlp.replace_listeners manually, this is likely " + "a bug in spaCy.") E890 = ("Cannot add the alias '{alias}' to the Knowledge base. " "Each alias should be a meaningful string.") E891 = ("Alias '{alias}' could not be added to the Knowledge base. " diff --git a/spacy/language.py b/spacy/language.py index 92710024f..3c7f91dd6 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1712,10 +1712,20 @@ class Language: DOCS: https://nightly.spacy.io/api/language#replace_listeners """ if tok2vec_name not in self.pipe_names: - err = Errors.E889.format(name=tok2vec_name, opts=", ".join(self.pipe_names)) + err = Errors.E889.format( + tok2vec=tok2vec_name, + name=pipe_name, + unknown=tok2vec_name, + opts=", ".join(self.pipe_names), + ) raise ValueError(err) if pipe_name not in self.pipe_names: - err = Errors.E889.format(name=pipe_name, opts=", ".join(self.pipe_names)) + err = Errors.E889.format( + tok2vec=tok2vec_name, + name=pipe_name, + unknown=pipe_name, + opts=", ".join(self.pipe_names), + ) raise ValueError(err) tok2vec = self.get_pipe(tok2vec_name) tok2vec_cfg = self.get_pipe_config(tok2vec_name) From 7694f76dd1a14112f900439dfff6a6d40768a243 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Fri, 29 Jan 2021 23:46:01 +1100 Subject: [PATCH 13/15] Update warning and mention replace_listeners --- spacy/errors.py | 20 +++++++++++++++----- website/docs/usage/training.md | 2 +- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/spacy/errors.py b/spacy/errors.py index 17868b605..6874f9a0c 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -80,12 +80,22 @@ class Warnings: # TODO: fix numbering after merging develop into master W086 = ("Component '{listener}' will be (re)trained, but it needs the component " - "'{name}' which is frozen. You should either freeze both, or neither " - "of the two.") + "'{name}' which is frozen. You can either freeze both, or neither " + "of the two. If you're sourcing the component from " + "an existing pipeline, you can use the `replace_listeners` setting in " + "the config block to replace its token-to-vector listener with a copy " + "and make it independent. For example, `replace_listeners = " + "[\"model.tok2vec\"]` See the documentation for details: " + "https://nightly.spacy.io/usage/training#config-components-listeners") W087 = ("Component '{name}' will be (re)trained, but the component '{listener}' " - "depends on it and is frozen. This means that the performance of " - "'{listener}' will be degraded. You should either freeze both, or " - "neither of the two.") + "depends on it via a listener and is frozen. This means that the " + "performance of '{listener}' will be degraded. You can either freeze " + "both, or neither of the two. If you're sourcing the component from " + "an existing pipeline, you can use the `replace_listeners` setting in " + "the config block to replace its token-to-vector listener with a copy " + "and make it independent. For example, `replace_listeners = " + "[\"model.tok2vec\"]` See the documentation for details: " + "https://nightly.spacy.io/usage/training#config-components-listeners") W088 = ("The pipeline component {name} implements a `begin_training` " "method, which won't be called by spaCy. As of v3.0, `begin_training` " "has been renamed to `initialize`, so you likely want to rename the " diff --git a/website/docs/usage/training.md b/website/docs/usage/training.md index b8e2c512d..ef7e5a157 100644 --- a/website/docs/usage/training.md +++ b/website/docs/usage/training.md @@ -419,7 +419,7 @@ pipeline = ["parser", "ner", "textcat", "custom"] frozen_components = ["parser", "custom"] ``` - + When the components in your pipeline [share an embedding layer](/usage/embeddings-transformers#embedding-layers), the From 7886d59c5650f4d8abe1e97d78300c707053f8c8 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Fri, 29 Jan 2021 23:47:30 +1100 Subject: [PATCH 14/15] Add check for remove_listener method --- spacy/language.py | 1 + 1 file changed, 1 insertion(+) diff --git a/spacy/language.py b/spacy/language.py index 3c7f91dd6..eca311e8f 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -1732,6 +1732,7 @@ class Language: if ( not hasattr(tok2vec, "model") or not hasattr(tok2vec, "listener_map") + or not hasattr(tok2vec, "remove_listener") or "model" not in tok2vec_cfg ): raise ValueError(Errors.E888.format(name=tok2vec_name, pipe=type(tok2vec))) From 7ba29f2d033145638aef30086e34e52fbacfab61 Mon Sep 17 00:00:00 2001 From: Ines Montani Date: Sat, 30 Jan 2021 00:06:07 +1100 Subject: [PATCH 15/15] Update spacy-transformers pin --- setup.cfg | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index a3de23a3f..0c0bdbe1a 100644 --- a/setup.cfg +++ b/setup.cfg @@ -68,7 +68,7 @@ console_scripts = lookups = spacy_lookups_data>=1.0.0rc0,<1.1.0 transformers = - spacy_transformers>=1.0.0rc0,<1.1.0 + spacy_transformers>=1.0.0rc4,<1.1.0 ray = spacy_ray>=0.1.0,<1.0.0 cuda =