Move replacement logic to Language.from_config

This commit is contained in:
Ines Montani 2021-01-29 19:37:04 +11:00
parent 0f3e3eedc2
commit 325f47500d
3 changed files with 31 additions and 23 deletions

View File

@ -1629,6 +1629,7 @@ class Language:
# Later we replace the component config with the raw config again. # Later we replace the component config with the raw config again.
interpolated = filled.interpolate() if not filled.is_interpolated else filled interpolated = filled.interpolate() if not filled.is_interpolated else filled
pipeline = interpolated.get("components", {}) pipeline = interpolated.get("components", {})
sourced = util.get_sourced_components(interpolated)
# If components are loaded from a source (existing models), we cache # If components are loaded from a source (existing models), we cache
# them here so they're only loaded once # them here so they're only loaded once
source_nlps = {} source_nlps = {}
@ -1671,6 +1672,17 @@ class Language:
raise ValueError( raise ValueError(
Errors.E942.format(name="pipeline_creation", value=type(nlp)) Errors.E942.format(name="pipeline_creation", value=type(nlp))
) )
# Replace the listeners of sourced components with standalone copies
# wherever the component's config sets "replace_listeners"
for name, proc in nlp.pipeline:
if getattr(proc, "listening_components", None): # e.g. tok2vec/transformer
for listener in proc.listening_components:
# If it's a component sourced from another pipeline, we check if
# the tok2vec listeners should be replaced with standalone tok2vec
# models (e.g. so component can be frozen without its performance
# degrading when other components/tok2vec are updated)
paths = sourced.get(listener, {}).get("replace_listeners", [])
if paths:
nlp.replace_listeners(name, listener, paths)
return nlp return nlp
def replace_listeners( def replace_listeners(
@ -1744,6 +1756,7 @@ class Language:
# Go over the listener layers and replace them # Go over the listener layers and replace them
for listener in pipe_listeners: for listener in pipe_listeners:
util.replace_model_node(pipe.model, listener, tok2vec.model.copy()) util.replace_model_node(pipe.model, listener, tok2vec.model.copy())
tok2vec.remove_listener(listener, pipe_name)
def to_disk( def to_disk(
self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList() self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()

View File

@ -14,7 +14,8 @@ from ..vectors import Vectors
from ..errors import Errors, Warnings from ..errors import Errors, Warnings
from ..schemas import ConfigSchemaTraining from ..schemas import ConfigSchemaTraining
from ..util import registry, load_model_from_config, resolve_dot_names, logger from ..util import registry, load_model_from_config, resolve_dot_names, logger
from ..util import load_model, ensure_path, OOV_RANK, DEFAULT_OOV_PROB from ..util import load_model, ensure_path, get_sourced_components
from ..util import OOV_RANK, DEFAULT_OOV_PROB
if TYPE_CHECKING: if TYPE_CHECKING:
from ..language import Language # noqa: F401 from ..language import Language # noqa: F401
@ -70,14 +71,7 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
for name, proc in nlp.pipeline: for name, proc in nlp.pipeline:
if getattr(proc, "listening_components", None): # e.g. tok2vec/transformer if getattr(proc, "listening_components", None): # e.g. tok2vec/transformer
for listener in proc.listening_components: for listener in proc.listening_components:
# If it's a component sourced from another pipeline, we check if if listener in frozen_components and name not in frozen_components:
# the tok2vec listeners should be replaced with standalone tok2vec
# models (e.g. so component can be frozen without its performance
# degrading when other components/tok2vec are updated)
paths = sourced.get(listener, {}).get("replace_listeners", [])
if paths:
nlp.replace_listeners(name, listener, paths)
elif listener in frozen_components and name not in frozen_components:
logger.warning(Warnings.W087.format(name=name, listener=listener)) logger.warning(Warnings.W087.format(name=name, listener=listener))
# We always check this regardless, in case user freezes tok2vec # We always check this regardless, in case user freezes tok2vec
if listener not in frozen_components and name in frozen_components: if listener not in frozen_components and name in frozen_components:
@ -181,20 +175,6 @@ def init_tok2vec(
return False return False
def get_sourced_components(
    config: Union[Dict[str, Any], Config]
) -> Dict[str, Dict[str, Any]]:
    """Collect the component blocks of a config that are sourced from
    another pipeline.

    config (Union[Dict[str, Any], Config]): The config. Its "components"
        section, if present, is read as a mapping of component name to
        that component's settings; a missing section yields an empty result.
    RETURNS (Dict[str, Dict[str, Any]]): The settings of all sourced
        components in the original config, keyed by component name, e.g.
        {"ner": {"source": "en_core_web_sm"}}. A block counts as sourced if
        it has a "source" key and no "factory" key: if the block contains a
        key "factory", we assume it refers to a component factory.
    """
    return {
        name: cfg
        for name, cfg in config.get("components", {}).items()
        if "factory" not in cfg and "source" in cfg
    }
def convert_vectors( def convert_vectors(
nlp: "Language", nlp: "Language",
vectors_loc: Optional[Path], vectors_loc: Optional[Path],

View File

@ -434,6 +434,20 @@ def load_model_from_config(
return nlp return nlp
def get_sourced_components(
    config: Union[Dict[str, Any], Config]
) -> Dict[str, Dict[str, Any]]:
    """Collect the component blocks of a config that are sourced from
    another pipeline.

    config (Union[Dict[str, Any], Config]): The config. Its "components"
        section, if present, is read as a mapping of component name to
        that component's settings; a missing section yields an empty result.
    RETURNS (Dict[str, Dict[str, Any]]): The settings of all sourced
        components in the original config, keyed by component name, e.g.
        {"ner": {"source": "en_core_web_sm"}}. A block counts as sourced if
        it has a "source" key and no "factory" key: if the block contains a
        key "factory", we assume it refers to a component factory.
    """
    return {
        name: cfg
        for name, cfg in config.get("components", {}).items()
        if "factory" not in cfg and "source" in cfg
    }
def resolve_dot_names(config: Config, dot_names: List[Optional[str]]) -> Tuple[Any]: def resolve_dot_names(config: Config, dot_names: List[Optional[str]]) -> Tuple[Any]:
"""Resolve one or more "dot notation" names, e.g. corpora.train. """Resolve one or more "dot notation" names, e.g. corpora.train.
The paths could point anywhere into the config, so we don't know which The paths could point anywhere into the config, so we don't know which
@ -1480,5 +1494,6 @@ def _pipe(docs, proc, name, default_error_handler, kwargs):
def raise_error(proc_name, proc, docs, e):
    """Re-raise the given exception unchanged.

    proc_name, proc, docs: Accepted but not used by this function.
    e (Exception): The exception to raise.
    """
    # NOTE(review): presumably meant as an error-handler callback; confirm
    # against the callers, which are not visible here.
    raise e
def ignore_error(proc_name, proc, docs, e):
    """Do nothing: every argument, including the exception, is discarded.

    proc_name, proc, docs, e: Accepted but not used by this function.
    RETURNS (None): Always None.
    """
    return None