mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Move replacement logic to Language.from_config
This commit is contained in:
parent
0f3e3eedc2
commit
325f47500d
|
@ -1629,6 +1629,7 @@ class Language:
|
||||||
# Later we replace the component config with the raw config again.
|
# Later we replace the component config with the raw config again.
|
||||||
interpolated = filled.interpolate() if not filled.is_interpolated else filled
|
interpolated = filled.interpolate() if not filled.is_interpolated else filled
|
||||||
pipeline = interpolated.get("components", {})
|
pipeline = interpolated.get("components", {})
|
||||||
|
sourced = util.get_sourced_components(interpolated)
|
||||||
# If components are loaded from a source (existing models), we cache
|
# If components are loaded from a source (existing models), we cache
|
||||||
# them here so they're only loaded once
|
# them here so they're only loaded once
|
||||||
source_nlps = {}
|
source_nlps = {}
|
||||||
|
@ -1671,6 +1672,17 @@ class Language:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
Errors.E942.format(name="pipeline_creation", value=type(nlp))
|
Errors.E942.format(name="pipeline_creation", value=type(nlp))
|
||||||
)
|
)
|
||||||
|
# Detect components with listeners that are not frozen consistently
|
||||||
|
for name, proc in nlp.pipeline:
|
||||||
|
if getattr(proc, "listening_components", None): # e.g. tok2vec/transformer
|
||||||
|
for listener in proc.listening_components:
|
||||||
|
# If it's a component sourced from another pipeline, we check if
|
||||||
|
# the tok2vec listeners should be replaced with standalone tok2vec
|
||||||
|
# models (e.g. so component can be frozen without its performance
|
||||||
|
# degrading when other components/tok2vec are updated)
|
||||||
|
paths = sourced.get(listener, {}).get("replace_listeners", [])
|
||||||
|
if paths:
|
||||||
|
nlp.replace_listeners(name, listener, paths)
|
||||||
return nlp
|
return nlp
|
||||||
|
|
||||||
def replace_listeners(
|
def replace_listeners(
|
||||||
|
@ -1744,6 +1756,7 @@ class Language:
|
||||||
# Go over the listener layers and replace them
|
# Go over the listener layers and replace them
|
||||||
for listener in pipe_listeners:
|
for listener in pipe_listeners:
|
||||||
util.replace_model_node(pipe.model, listener, tok2vec.model.copy())
|
util.replace_model_node(pipe.model, listener, tok2vec.model.copy())
|
||||||
|
tok2vec.remove_listener(listener, pipe_name)
|
||||||
|
|
||||||
def to_disk(
|
def to_disk(
|
||||||
self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
|
self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
|
||||||
|
|
|
@ -14,7 +14,8 @@ from ..vectors import Vectors
|
||||||
from ..errors import Errors, Warnings
|
from ..errors import Errors, Warnings
|
||||||
from ..schemas import ConfigSchemaTraining
|
from ..schemas import ConfigSchemaTraining
|
||||||
from ..util import registry, load_model_from_config, resolve_dot_names, logger
|
from ..util import registry, load_model_from_config, resolve_dot_names, logger
|
||||||
from ..util import load_model, ensure_path, OOV_RANK, DEFAULT_OOV_PROB
|
from ..util import load_model, ensure_path, get_sourced_components
|
||||||
|
from ..util import OOV_RANK, DEFAULT_OOV_PROB
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from ..language import Language # noqa: F401
|
from ..language import Language # noqa: F401
|
||||||
|
@ -70,14 +71,7 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
|
||||||
for name, proc in nlp.pipeline:
|
for name, proc in nlp.pipeline:
|
||||||
if getattr(proc, "listening_components", None): # e.g. tok2vec/transformer
|
if getattr(proc, "listening_components", None): # e.g. tok2vec/transformer
|
||||||
for listener in proc.listening_components:
|
for listener in proc.listening_components:
|
||||||
# If it's a component sourced from another pipeline, we check if
|
if listener in frozen_components and name not in frozen_components:
|
||||||
# the tok2vec listeners should be replaced with standalone tok2vec
|
|
||||||
# models (e.g. so component can be frozen without its performance
|
|
||||||
# degrading when other components/tok2vec are updated)
|
|
||||||
paths = sourced.get(listener, {}).get("replace_listeners", [])
|
|
||||||
if paths:
|
|
||||||
nlp.replace_listeners(name, listener, paths)
|
|
||||||
elif listener in frozen_components and name not in frozen_components:
|
|
||||||
logger.warning(Warnings.W087.format(name=name, listener=listener))
|
logger.warning(Warnings.W087.format(name=name, listener=listener))
|
||||||
# We always check this regardless, in case user freezes tok2vec
|
# We always check this regardless, in case user freezes tok2vec
|
||||||
if listener not in frozen_components and name in frozen_components:
|
if listener not in frozen_components and name in frozen_components:
|
||||||
|
@ -181,20 +175,6 @@ def init_tok2vec(
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def get_sourced_components(
|
|
||||||
config: Union[Dict[str, Any], Config]
|
|
||||||
) -> Dict[str, Dict[str, Any]]:
|
|
||||||
"""RETURNS (List[str]): All sourced components in the original config,
|
|
||||||
e.g. {"source": "en_core_web_sm"}. If the config contains a key
|
|
||||||
"factory", we assume it refers to a component factory.
|
|
||||||
"""
|
|
||||||
return {
|
|
||||||
name: cfg
|
|
||||||
for name, cfg in config.get("components", {}).items()
|
|
||||||
if "factory" not in cfg and "source" in cfg
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def convert_vectors(
|
def convert_vectors(
|
||||||
nlp: "Language",
|
nlp: "Language",
|
||||||
vectors_loc: Optional[Path],
|
vectors_loc: Optional[Path],
|
||||||
|
|
|
@ -434,6 +434,20 @@ def load_model_from_config(
|
||||||
return nlp
|
return nlp
|
||||||
|
|
||||||
|
|
||||||
|
def get_sourced_components(
|
||||||
|
config: Union[Dict[str, Any], Config]
|
||||||
|
) -> Dict[str, Dict[str, Any]]:
|
||||||
|
"""RETURNS (List[str]): All sourced components in the original config,
|
||||||
|
e.g. {"source": "en_core_web_sm"}. If the config contains a key
|
||||||
|
"factory", we assume it refers to a component factory.
|
||||||
|
"""
|
||||||
|
return {
|
||||||
|
name: cfg
|
||||||
|
for name, cfg in config.get("components", {}).items()
|
||||||
|
if "factory" not in cfg and "source" in cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
def resolve_dot_names(config: Config, dot_names: List[Optional[str]]) -> Tuple[Any]:
|
def resolve_dot_names(config: Config, dot_names: List[Optional[str]]) -> Tuple[Any]:
|
||||||
"""Resolve one or more "dot notation" names, e.g. corpora.train.
|
"""Resolve one or more "dot notation" names, e.g. corpora.train.
|
||||||
The paths could point anywhere into the config, so we don't know which
|
The paths could point anywhere into the config, so we don't know which
|
||||||
|
@ -1480,5 +1494,6 @@ def _pipe(docs, proc, name, default_error_handler, kwargs):
|
||||||
def raise_error(proc_name, proc, docs, e):
|
def raise_error(proc_name, proc, docs, e):
|
||||||
raise e
|
raise e
|
||||||
|
|
||||||
|
|
||||||
def ignore_error(proc_name, proc, docs, e):
|
def ignore_error(proc_name, proc, docs, e):
|
||||||
pass
|
pass
|
||||||
|
|
Loading…
Reference in New Issue
Block a user