Add option to replace listeners for sourced components

This commit is contained in:
Ines Montani 2021-01-29 15:57:04 +11:00
parent 78d6ff4dd4
commit 911dfcccfc
5 changed files with 146 additions and 15 deletions

View File

@ -8,7 +8,7 @@ from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
from pathlib import Path from pathlib import Path
import warnings import warnings
from thinc.api import Model, get_current_ops, Config, Optimizer from thinc.api import get_current_ops, Config, Optimizer
import srsly import srsly
import multiprocessing as mp import multiprocessing as mp
from itertools import chain, cycle from itertools import chain, cycle
@ -670,6 +670,47 @@ class Language:
self._pipe_configs[name] = filled self._pipe_configs[name] = filled
return resolved[factory_name] return resolved[factory_name]
def replace_listeners(
self,
tok2vec_name: str,
pipe_name: str,
listeners: Iterable[str] = SimpleFrozenList(),
):
if tok2vec_name not in self.pipe_names:
raise ValueError # TODO:
if pipe_name not in self.pipe_names:
raise ValueError # TODO:
tok2vec = self.get_pipe(tok2vec_name)
tok2vec_cfg = self.get_pipe_config(tok2vec_name)
if (
not hasattr(tok2vec, "model")
or not hasattr(tok2vec, "listener_map")
or "model" not in tok2vec_cfg
):
raise ValueError # TODO: likely bug in spaCy if this happens
pipe_listeners = tok2vec.listener_map.get(pipe_name, [])
pipe_cfg = self._pipe_configs[pipe_name]
if listeners:
util.logger.debug(f"Replacing listeners of component '{pipe_name}'")
if len(listeners) != len(pipe_listeners):
# The number of listeners defined in the component model doesn't
# match the listeners to replace, so we won't be able to update
# the nodes and generate a matching config
raise ValueError(f"{listeners}, {pipe_listeners}") # TODO:
pipe = self.get_pipe(pipe_name)
# Go over the listener layers and replace them
for listener in pipe_listeners:
util.replace_model_node(pipe.model, listener, tok2vec.model.copy())
# Update the config accordingly by coping the tok2vec model to all
# sections defined in the listener paths
for listener_path in listeners:
# Check if the path actually exists in the config
try:
util.dot_to_object(pipe_cfg, listener_path)
except KeyError:
raise ValueError # TODO:
util.set_dot_to_object(pipe_cfg, listener_path, tok2vec_cfg["model"])
def create_pipe_from_source( def create_pipe_from_source(
self, source_name: str, source: "Language", *, name: str self, source_name: str, source: "Language", *, name: str
) -> Tuple[Callable[[Doc], Doc], str]: ) -> Tuple[Callable[[Doc], Doc], str]:

View File

@ -1,5 +1,4 @@
import pytest import pytest
from spacy.ml.models.tok2vec import build_Tok2Vec_model from spacy.ml.models.tok2vec import build_Tok2Vec_model
from spacy.ml.models.tok2vec import MultiHashEmbed, CharacterEmbed from spacy.ml.models.tok2vec import MultiHashEmbed, CharacterEmbed
from spacy.ml.models.tok2vec import MishWindowEncoder, MaxoutWindowEncoder from spacy.ml.models.tok2vec import MishWindowEncoder, MaxoutWindowEncoder
@ -9,12 +8,11 @@ from spacy.tokens import Doc
from spacy.training import Example from spacy.training import Example
from spacy import util from spacy import util
from spacy.lang.en import English from spacy.lang.en import English
from ..util import get_batch
from thinc.api import Config from thinc.api import Config
from numpy.testing import assert_equal from numpy.testing import assert_equal
from ..util import get_batch
def test_empty_doc(): def test_empty_doc():
width = 128 width = 128
@ -187,3 +185,29 @@ def test_tok2vec_listener_callback():
Y, get_dX = tagger.model.begin_update(docs) Y, get_dX = tagger.model.begin_update(docs)
# assure that the backprop call works (and doesn't hit a 'None' callback) # assure that the backprop call works (and doesn't hit a 'None' callback)
assert get_dX(Y) is not None assert get_dX(Y) is not None
def test_replace_listeners():
orig_config = Config().from_str(cfg_string)
nlp = util.load_model_from_config(orig_config, auto_fill=True, validate=True)
examples = [Example.from_dict(nlp.make_doc("x y"), {"tags": ["V", "Z"]})]
nlp.initialize(lambda: examples)
tok2vec = nlp.get_pipe("tok2vec")
tagger = nlp.get_pipe("tagger")
assert isinstance(tagger.model.layers[0], Tok2VecListener)
assert tok2vec.listener_map["tagger"][0] == tagger.model.layers[0]
assert nlp.config["components"]["tok2vec"]["model"]["@architectures"] == "spacy.Tok2Vec.v2"
assert nlp.config["components"]["tagger"]["model"]["tok2vec"]["@architectures"] == "spacy.Tok2VecListener.v1"
nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec"])
assert not isinstance(tagger.model.layers[0], Tok2VecListener)
t2v_cfg = nlp.config["components"]["tok2vec"]["model"]
assert t2v_cfg["@architectures"] == "spacy.Tok2Vec.v2"
assert nlp.config["components"]["tagger"]["model"]["tok2vec"] == t2v_cfg
with pytest.raises(ValueError):
nlp.replace_listeners("invalid", "tagger", ["model.tok2vec"])
with pytest.raises(ValueError):
nlp.replace_listeners("tok2vec", "parser", ["model.tok2vec"])
with pytest.raises(ValueError):
nlp.replace_listeners("tok2vec", "tagger", ["model.yolo"])
with pytest.raises(ValueError):
nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec", "model.yolo"])

View File

@ -205,6 +205,25 @@ def test_dot_to_dict(dot_notation, expected):
assert util.dict_to_dot(result) == dot_notation assert util.dict_to_dot(result) == dot_notation
def test_set_dot_to_object():
config = {"foo": {"bar": 1, "baz": {"x": "y"}}, "test": {"a": {"b": "c"}}}
with pytest.raises(KeyError):
util.set_dot_to_object(config, "foo.bar.baz", 100)
with pytest.raises(KeyError):
util.set_dot_to_object(config, "hello.world", 100)
with pytest.raises(KeyError):
util.set_dot_to_object(config, "test.a.b.c", 100)
util.set_dot_to_object(config, "foo.bar", 100)
assert config["foo"]["bar"] == 100
util.set_dot_to_object(config, "foo.baz.x", {"hello": "world"})
assert config["foo"]["baz"]["x"]["hello"] == "world"
assert config["test"]["a"]["b"] == "c"
util.set_dot_to_object(config, "foo", 123)
assert config["foo"] == 123
util.set_dot_to_object(config, "test", "hello")
assert dict(config) == {"foo": 123, "test": "hello"}
@pytest.mark.parametrize( @pytest.mark.parametrize(
"doc_sizes, expected_batches", "doc_sizes, expected_batches",
[ [

View File

@ -1,4 +1,4 @@
from typing import Union, Dict, Optional, Any, List, IO, TYPE_CHECKING from typing import Union, Dict, Optional, Any, IO, TYPE_CHECKING
from thinc.api import Config, fix_random_seed, set_gpu_allocator from thinc.api import Config, fix_random_seed, set_gpu_allocator
from thinc.api import ConfigValidationError from thinc.api import ConfigValidationError
from pathlib import Path from pathlib import Path
@ -33,7 +33,7 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
if use_gpu >= 0 and allocator: if use_gpu >= 0 and allocator:
set_gpu_allocator(allocator) set_gpu_allocator(allocator)
# Use original config here before it's resolved to functions # Use original config here before it's resolved to functions
sourced_components = get_sourced_components(config) sourced = get_sourced_components(config)
nlp = load_model_from_config(raw_config, auto_fill=True) nlp = load_model_from_config(raw_config, auto_fill=True)
logger.info("Set up nlp object from config") logger.info("Set up nlp object from config")
config = nlp.config.interpolate() config = nlp.config.interpolate()
@ -57,7 +57,7 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
# Components that shouldn't be updated during training # Components that shouldn't be updated during training
frozen_components = T["frozen_components"] frozen_components = T["frozen_components"]
# Sourced components that require resume_training # Sourced components that require resume_training
resume_components = [p for p in sourced_components if p not in frozen_components] resume_components = [p for p in sourced if p not in frozen_components]
logger.info(f"Pipeline: {nlp.pipe_names}") logger.info(f"Pipeline: {nlp.pipe_names}")
if resume_components: if resume_components:
with nlp.select_pipes(enable=resume_components): with nlp.select_pipes(enable=resume_components):
@ -68,10 +68,18 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
logger.info(f"Initialized pipeline components: {nlp.pipe_names}") logger.info(f"Initialized pipeline components: {nlp.pipe_names}")
# Detect components with listeners that are not frozen consistently # Detect components with listeners that are not frozen consistently
for name, proc in nlp.pipeline: for name, proc in nlp.pipeline:
if getattr(proc, "listening_components", None): if getattr(proc, "listening_components", None): # e.g. tok2vec/transformer
for listener in proc.listening_components: for listener in proc.listening_components:
if listener in frozen_components and name not in frozen_components: # If it's a component sourced from another pipeline, we check if
# the tok2vec listeners should be replaced with standalone tok2vec
# models (e.g. so component can be frozen without its performance
# degrading when other components/tok2vec are updated)
paths = sourced.get(listener, {}).get("replace_listeners", [])
if paths:
nlp.replace_listeners(name, listener, paths)
elif listener in frozen_components and name not in frozen_components:
logger.warning(Warnings.W087.format(name=name, listener=listener)) logger.warning(Warnings.W087.format(name=name, listener=listener))
# We always check this regardless, in case user freezes tok2vec
if listener not in frozen_components and name in frozen_components: if listener not in frozen_components and name in frozen_components:
logger.warning(Warnings.W086.format(name=name, listener=listener)) logger.warning(Warnings.W086.format(name=name, listener=listener))
return nlp return nlp
@ -173,16 +181,18 @@ def init_tok2vec(
return False return False
def get_sourced_components(config: Union[Dict[str, Any], Config]) -> List[str]: def get_sourced_components(
config: Union[Dict[str, Any], Config]
) -> Dict[str, Dict[str, Any]]:
"""RETURNS (List[str]): All sourced components in the original config, """RETURNS (List[str]): All sourced components in the original config,
e.g. {"source": "en_core_web_sm"}. If the config contains a key e.g. {"source": "en_core_web_sm"}. If the config contains a key
"factory", we assume it refers to a component factory. "factory", we assume it refers to a component factory.
""" """
return [ return {
name name: cfg
for name, cfg in config.get("components", {}).items() for name, cfg in config.get("components", {}).items()
if "factory" not in cfg and "source" in cfg if "factory" not in cfg and "source" in cfg
] }
def convert_vectors( def convert_vectors(

View File

@ -8,7 +8,7 @@ import re
from pathlib import Path from pathlib import Path
import thinc import thinc
from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer from thinc.api import NumpyOps, get_current_ops, Adam, Config, Optimizer
from thinc.api import ConfigValidationError from thinc.api import ConfigValidationError, Model
import functools import functools
import itertools import itertools
import numpy.random import numpy.random
@ -738,6 +738,24 @@ def get_package_path(name: str) -> Path:
return Path(pkg.__file__).parent return Path(pkg.__file__).parent
def replace_model_node(model: Model, target: Model, replacement: Model) -> None:
"""Replace a node within a model with a new one, updating refs.
model (Model): The parent model.
target (Model): The target node.
replacement (Model): The node to replace the target with.
"""
# Place the node into the sublayers
for node in model.walk():
if target in node.layers:
node.layers[node.layers.index(target)] = replacement
# Now fix any node references
for node in model.walk():
for ref_name in node.ref_names:
if node.maybe_get_ref(ref_name) is target:
node.set_ref(ref_name, replacement)
def split_command(command: str) -> List[str]: def split_command(command: str) -> List[str]:
"""Split a string command using shlex. Handles platform compatibility. """Split a string command using shlex. Handles platform compatibility.
@ -1279,6 +1297,25 @@ def dot_to_object(config: Config, section: str):
return component return component
def set_dot_to_object(config: Config, section: str, value: Any) -> None:
"""Update a config at a given position from a dot notation.
config (Config): The config.
section (str): The dot notation of the section in the config.
value (Any): The value to set in the config.
"""
component = config
parts = section.split(".")
for i, item in enumerate(parts):
try:
if i == len(parts) - 1:
component[item] = value
else:
component = component[item]
except (KeyError, TypeError):
raise KeyError(Errors.E952.format(name=section)) from None
def walk_dict( def walk_dict(
node: Dict[str, Any], parent: List[str] = [] node: Dict[str, Any], parent: List[str] = []
) -> Iterator[Tuple[List[str], Any]]: ) -> Iterator[Tuple[List[str], Any]]: