Mirror of https://github.com/explosion/spaCy.git
Custom component types in spacy.ty (#9469)
* add custom protocols in spacy.ty
* add a test for the new types in spacy.ty
* import Example when type checking
* some type fixes
* put Protocol in compat
* revert update check back to hasattr
* runtime_checkable in compat as well
This commit is contained in:
parent d0631e3005
commit 5a38f79f18
spacy/compat.py

@@ -23,9 +23,9 @@ except ImportError:
     cupy = None
 
 if sys.version_info[:2] >= (3, 8):  # Python 3.8+
-    from typing import Literal
+    from typing import Literal, Protocol, runtime_checkable
 else:
-    from typing_extensions import Literal  # noqa: F401
+    from typing_extensions import Literal, Protocol, runtime_checkable  # noqa: F401
 
 # Important note: The importlib_metadata "backport" includes functionality
 # that's not part of the built-in importlib.metadata. We should treat this
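For context: Protocol and runtime_checkable enable structural ("duck typing")
isinstance() checks, which is what lets the hasattr() tests below be replaced
without requiring components to inherit from anything. A minimal sketch, using
only the standard library (Quacker and Duck are invented for illustration):

    from typing import Protocol, runtime_checkable  # Python 3.8+

    @runtime_checkable
    class Quacker(Protocol):
        def quack(self) -> str: ...

    class Duck:  # note: does not inherit from Quacker
        def quack(self) -> str:
            return "quack"

    # Passes because Duck has a quack() method; @runtime_checkable makes the
    # Protocol usable with isinstance(), which also narrows the static type.
    assert isinstance(Duck(), Quacker)
    assert not isinstance(object(), Quacker)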
spacy/language.py

@@ -17,6 +17,7 @@ from itertools import chain, cycle
 from timeit import default_timer as timer
 import traceback
 
+from . import ty
 from .tokens.underscore import Underscore
 from .vocab import Vocab, create_vocab
 from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis
@@ -1135,11 +1136,11 @@ class Language:
             if sgd not in (None, False):
                 if (
                     name not in exclude
-                    and hasattr(proc, "is_trainable")
+                    and isinstance(proc, ty.TrainableComponent)
                     and proc.is_trainable
-                    and proc.model not in (True, False, None)  # type: ignore
+                    and proc.model not in (True, False, None)
                 ):
-                    proc.finish_update(sgd)  # type: ignore
+                    proc.finish_update(sgd)
             if name in annotates:
                 for doc, eg in zip(
                     _pipe(
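Replacing hasattr(proc, "is_trainable") with isinstance(proc, ty.TrainableComponent)
performs the same runtime duck check, but also tells the type checker which
attributes proc has, which is why the # type: ignore comments on proc.model and
proc.finish_update() can be dropped. A sketch of the narrowing, assuming an
invented ToyPipe component (not part of spaCy):

    from typing import Any
    from spacy import ty

    class ToyPipe:  # hypothetical component, satisfies the protocol structurally
        model: Any = None
        is_trainable: bool = True

        def update(self, examples, *, drop=0.0, sgd=None, losses=None):
            return losses or {}

        def finish_update(self, sgd) -> None:
            pass

    proc: object = ToyPipe()
    if isinstance(proc, ty.TrainableComponent):
        # Inside this branch, type checkers know proc has model,
        # is_trainable, update() and finish_update().
        proc.finish_update(sgd=None)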
@@ -1278,12 +1279,12 @@ class Language:
             )
             self.tokenizer.initialize(get_examples, nlp=self, **tok_settings)  # type: ignore[union-attr]
         for name, proc in self.pipeline:
-            if hasattr(proc, "initialize"):
+            if isinstance(proc, ty.InitializableComponent):
                 p_settings = I["components"].get(name, {})
                 p_settings = validate_init_settings(
                     proc.initialize, p_settings, section="components", name=name
                 )
-                proc.initialize(get_examples, nlp=self, **p_settings)  # type: ignore[call-arg]
+                proc.initialize(get_examples, nlp=self, **p_settings)
         pretrain_cfg = config.get("pretraining")
         if pretrain_cfg:
             P = registry.resolve(pretrain_cfg, schema=ConfigSchemaPretrain)
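InitializableComponent only requires an initialize() method, so any pipe that
builds state from examples before training qualifies, with no base class needed.
A sketch with an invented LexiconPipe (the class and its behavior are
illustrative, not from this diff):

    from spacy import ty

    class LexiconPipe:
        def __init__(self):
            self.lexicon = {}

        def __call__(self, doc):
            return doc

        def initialize(self, get_examples, nlp=None, **kwargs):
            # Build state from the training examples before training starts;
            # eg.reference is the gold-standard Doc of each Example.
            for eg in get_examples():
                for token in eg.reference:
                    self.lexicon.setdefault(token.text, len(self.lexicon))

    # Recognized purely by its structure:
    assert isinstance(LexiconPipe(), ty.InitializableComponent)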
@@ -1622,9 +1623,9 @@ class Language:
         # components don't receive the pipeline then. So this does have to be
         # here :(
         for i, (name1, proc1) in enumerate(self.pipeline):
-            if hasattr(proc1, "find_listeners"):
+            if isinstance(proc1, ty.ListenedToComponent):
                 for name2, proc2 in self.pipeline[i + 1 :]:
-                    proc1.find_listeners(proc2)  # type: ignore[attr-defined]
+                    proc1.find_listeners(proc2)
 
     @classmethod
     def from_config(
@@ -1810,25 +1811,25 @@ class Language:
(indentation-only changes in this hunk are not shown)
         )
         # Detect components with listeners that are not frozen consistently
         for name, proc in nlp.pipeline:
+            if isinstance(proc, ty.ListenedToComponent):
                 # Remove listeners not in the pipeline
-                listener_names = getattr(proc, "listening_components", [])
+                listener_names = proc.listening_components
                 unused_listener_names = [
                     ll for ll in listener_names if ll not in nlp.pipe_names
                 ]
                 for listener_name in unused_listener_names:
-                    for listener in proc.listener_map.get(listener_name, []):  # type: ignore[attr-defined]
-                        proc.remove_listener(listener, listener_name)  # type: ignore[attr-defined]
+                    for listener in proc.listener_map.get(listener_name, []):
+                        proc.remove_listener(listener, listener_name)
 
-                for listener in getattr(
-                    proc, "listening_components", []
-                ):  # e.g. tok2vec/transformer
+                for listener_name in proc.listening_components:
+                    # e.g. tok2vec/transformer
                     # If it's a component sourced from another pipeline, we check if
                     # the tok2vec listeners should be replaced with standalone tok2vec
                     # models (e.g. so component can be frozen without its performance
                     # degrading when other components/tok2vec are updated)
-                    paths = sourced.get(listener, {}).get("replace_listeners", [])
+                    paths = sourced.get(listener_name, {}).get("replace_listeners", [])
                     if paths:
-                        nlp.replace_listeners(name, listener, paths)
+                        nlp.replace_listeners(name, listener_name, paths)
         return nlp
 
     def replace_listeners(
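Here sourced maps component names to their resolved config when a component is
sourced from another pipeline, and replace_listeners is read from that config
per listening component. A hedged sketch of the corresponding training-config
excerpt (the component name and listener path are illustrative):

    # config.cfg excerpt: source a tagger from a packaged pipeline and
    # replace its tok2vec listener with a private copy of the model, so
    # the shared tok2vec can be retrained without degrading the tagger.
    [components.tagger]
    source = "en_core_web_sm"
    replace_listeners = ["model.tok2vec"]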
@@ -1878,15 +1879,10 @@ class Language:
             raise ValueError(err)
         tok2vec = self.get_pipe(tok2vec_name)
         tok2vec_cfg = self.get_pipe_config(tok2vec_name)
-        if (
-            not hasattr(tok2vec, "model")
-            or not hasattr(tok2vec, "listener_map")
-            or not hasattr(tok2vec, "remove_listener")
-            or "model" not in tok2vec_cfg
-        ):
+        if not isinstance(tok2vec, ty.ListenedToComponent):
            raise ValueError(Errors.E888.format(name=tok2vec_name, pipe=type(tok2vec)))
-        tok2vec_model = tok2vec.model  # type: ignore[attr-defined]
-        pipe_listeners = tok2vec.listener_map.get(pipe_name, [])  # type: ignore[attr-defined]
+        tok2vec_model = tok2vec.model
+        pipe_listeners = tok2vec.listener_map.get(pipe_name, [])
         pipe = self.get_pipe(pipe_name)
         pipe_cfg = self._pipe_configs[pipe_name]
         if listeners:
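The four hasattr()/config checks collapse into a single protocol check before
raising E888. For reference, the public method being guarded is used like this
(a sketch that assumes a trained pipeline where the tagger listens to a shared
tok2vec, e.g. en_core_web_sm):

    import spacy

    nlp = spacy.load("en_core_web_sm")
    # "model.tok2vec" is the path to the listener node inside the tagger's
    # model. After the call, the tagger owns a standalone copy of tok2vec
    # and no longer depends on the shared component.
    nlp.replace_listeners("tok2vec", "tagger", ["model.tok2vec"])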
@@ -1926,7 +1922,7 @@ class Language:
                 if "replace_listener" in tok2vec_model.attrs:
                     new_model = tok2vec_model.attrs["replace_listener"](new_model)
                 util.replace_model_node(pipe.model, listener, new_model)  # type: ignore[attr-defined]
-                tok2vec.remove_listener(listener, pipe_name)  # type: ignore[attr-defined]
+                tok2vec.remove_listener(listener, pipe_name)
 
     def to_disk(
         self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()
spacy/pipeline/spancat.py

@@ -1,10 +1,10 @@
 import numpy
 from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast
-from typing_extensions import Protocol, runtime_checkable
 from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
 from thinc.api import Optimizer
 from thinc.types import Ragged, Ints2d, Floats2d, Ints1d
 
+from ..compat import Protocol, runtime_checkable
 from ..scorer import Scorer
 from ..language import Language
 from .trainable_pipe import TrainablePipe
spacy/tests/test_ty.py (new file, 18 lines added)
@@ -0,0 +1,18 @@
+import spacy
+from spacy import ty
+
+
+def test_component_types():
+    nlp = spacy.blank("en")
+    tok2vec = nlp.create_pipe("tok2vec")
+    tagger = nlp.create_pipe("tagger")
+    entity_ruler = nlp.create_pipe("entity_ruler")
+    assert isinstance(tok2vec, ty.TrainableComponent)
+    assert isinstance(tagger, ty.TrainableComponent)
+    assert not isinstance(entity_ruler, ty.TrainableComponent)
+    assert isinstance(tok2vec, ty.InitializableComponent)
+    assert isinstance(tagger, ty.InitializableComponent)
+    assert isinstance(entity_ruler, ty.InitializableComponent)
+    assert isinstance(tok2vec, ty.ListenedToComponent)
+    assert not isinstance(tagger, ty.ListenedToComponent)
+    assert not isinstance(entity_ruler, ty.ListenedToComponent)
spacy/ty.py (new file, 55 lines added)
@@ -0,0 +1,55 @@
+from typing import TYPE_CHECKING
+from typing import Optional, Any, Iterable, Dict, Callable, Sequence, List
+from .compat import Protocol, runtime_checkable
+
+from thinc.api import Optimizer, Model
+
+if TYPE_CHECKING:
+    from .training import Example
+
+
+@runtime_checkable
+class TrainableComponent(Protocol):
+    model: Any
+    is_trainable: bool
+
+    def update(
+        self,
+        examples: Iterable["Example"],
+        *,
+        drop: float = 0.0,
+        sgd: Optional[Optimizer] = None,
+        losses: Optional[Dict[str, float]] = None
+    ) -> Dict[str, float]:
+        ...
+
+    def finish_update(self, sgd: Optimizer) -> None:
+        ...
+
+
+@runtime_checkable
+class InitializableComponent(Protocol):
+    def initialize(
+        self,
+        get_examples: Callable[[], Iterable["Example"]],
+        nlp: Iterable["Example"],
+        **kwargs: Any
+    ):
+        ...
+
+
+@runtime_checkable
+class ListenedToComponent(Protocol):
+    model: Any
+    listeners: Sequence[Model]
+    listener_map: Dict[str, Sequence[Model]]
+    listening_components: List[str]
+
+    def add_listener(self, listener: Model, component_name: str) -> None:
+        ...
+
+    def remove_listener(self, listener: Model, component_name: str) -> bool:
+        ...
+
+    def find_listeners(self, component) -> None:
+        ...
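Because the protocols are runtime_checkable, user code can probe any pipeline
for capabilities the same way Language now does internally. A small usage
sketch mirroring the checks in the test above:

    import spacy
    from spacy import ty

    nlp = spacy.blank("en")
    nlp.add_pipe("tok2vec")
    nlp.add_pipe("tagger")

    # Report which protocols each pipe satisfies structurally.
    for name, proc in nlp.pipeline:
        caps = [
            label
            for label, proto in [
                ("trainable", ty.TrainableComponent),
                ("initializable", ty.InitializableComponent),
                ("listened-to", ty.ListenedToComponent),
            ]
            if isinstance(proc, proto)
        ]
        print(name, "->", ", ".join(caps) or "none")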