Custom component types in spacy.ty (#9469)

* add custom protocols in spacy.ty

* add a test for the new types in spacy.ty

* import Example when type checking

* some type fixes

* put Protocol in compat

* revert update check back to hasattr

* runtime_checkable in compat as well
This commit is contained in:
Sofie Van Landeghem 2021-10-21 15:31:06 +02:00 committed by GitHub
parent d0631e3005
commit 5a38f79f18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 106 additions and 37 deletions

View File

@ -23,9 +23,9 @@ except ImportError:
cupy = None
if sys.version_info[:2] >= (3, 8): # Python 3.8+
from typing import Literal
from typing import Literal, Protocol, runtime_checkable
else:
from typing_extensions import Literal # noqa: F401
from typing_extensions import Literal, Protocol, runtime_checkable # noqa: F401
# Important note: The importlib_metadata "backport" includes functionality
# that's not part of the built-in importlib.metadata. We should treat this

View File

@ -17,6 +17,7 @@ from itertools import chain, cycle
from timeit import default_timer as timer
import traceback
from . import ty
from .tokens.underscore import Underscore
from .vocab import Vocab, create_vocab
from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis
@ -1135,11 +1136,11 @@ class Language:
if sgd not in (None, False):
if (
name not in exclude
and hasattr(proc, "is_trainable")
and isinstance(proc, ty.TrainableComponent)
and proc.is_trainable
and proc.model not in (True, False, None) # type: ignore
and proc.model not in (True, False, None)
):
proc.finish_update(sgd) # type: ignore
proc.finish_update(sgd)
if name in annotates:
for doc, eg in zip(
_pipe(
@ -1278,12 +1279,12 @@ class Language:
)
self.tokenizer.initialize(get_examples, nlp=self, **tok_settings) # type: ignore[union-attr]
for name, proc in self.pipeline:
if hasattr(proc, "initialize"):
if isinstance(proc, ty.InitializableComponent):
p_settings = I["components"].get(name, {})
p_settings = validate_init_settings(
proc.initialize, p_settings, section="components", name=name
)
proc.initialize(get_examples, nlp=self, **p_settings) # type: ignore[call-arg]
proc.initialize(get_examples, nlp=self, **p_settings)
pretrain_cfg = config.get("pretraining")
if pretrain_cfg:
P = registry.resolve(pretrain_cfg, schema=ConfigSchemaPretrain)
@ -1622,9 +1623,9 @@ class Language:
# components don't receive the pipeline then. So this does have to be
# here :(
for i, (name1, proc1) in enumerate(self.pipeline):
if hasattr(proc1, "find_listeners"):
if isinstance(proc1, ty.ListenedToComponent):
for name2, proc2 in self.pipeline[i + 1 :]:
proc1.find_listeners(proc2) # type: ignore[attr-defined]
proc1.find_listeners(proc2)
@classmethod
def from_config(
@ -1810,25 +1811,25 @@ class Language:
)
# Detect components with listeners that are not frozen consistently
for name, proc in nlp.pipeline:
# Remove listeners not in the pipeline
listener_names = getattr(proc, "listening_components", [])
unused_listener_names = [
ll for ll in listener_names if ll not in nlp.pipe_names
]
for listener_name in unused_listener_names:
for listener in proc.listener_map.get(listener_name, []): # type: ignore[attr-defined]
proc.remove_listener(listener, listener_name) # type: ignore[attr-defined]
if isinstance(proc, ty.ListenedToComponent):
# Remove listeners not in the pipeline
listener_names = proc.listening_components
unused_listener_names = [
ll for ll in listener_names if ll not in nlp.pipe_names
]
for listener_name in unused_listener_names:
for listener in proc.listener_map.get(listener_name, []):
proc.remove_listener(listener, listener_name)
for listener in getattr(
proc, "listening_components", []
): # e.g. tok2vec/transformer
# If it's a component sourced from another pipeline, we check if
# the tok2vec listeners should be replaced with standalone tok2vec
# models (e.g. so component can be frozen without its performance
# degrading when other components/tok2vec are updated)
paths = sourced.get(listener, {}).get("replace_listeners", [])
if paths:
nlp.replace_listeners(name, listener, paths)
for listener_name in proc.listening_components:
# e.g. tok2vec/transformer
# If it's a component sourced from another pipeline, we check if
# the tok2vec listeners should be replaced with standalone tok2vec
# models (e.g. so component can be frozen without its performance
# degrading when other components/tok2vec are updated)
paths = sourced.get(listener_name, {}).get("replace_listeners", [])
if paths:
nlp.replace_listeners(name, listener_name, paths)
return nlp
def replace_listeners(
@ -1878,15 +1879,10 @@ class Language:
raise ValueError(err)
tok2vec = self.get_pipe(tok2vec_name)
tok2vec_cfg = self.get_pipe_config(tok2vec_name)
if (
not hasattr(tok2vec, "model")
or not hasattr(tok2vec, "listener_map")
or not hasattr(tok2vec, "remove_listener")
or "model" not in tok2vec_cfg
):
if not isinstance(tok2vec, ty.ListenedToComponent):
raise ValueError(Errors.E888.format(name=tok2vec_name, pipe=type(tok2vec)))
tok2vec_model = tok2vec.model # type: ignore[attr-defined]
pipe_listeners = tok2vec.listener_map.get(pipe_name, []) # type: ignore[attr-defined]
tok2vec_model = tok2vec.model
pipe_listeners = tok2vec.listener_map.get(pipe_name, [])
pipe = self.get_pipe(pipe_name)
pipe_cfg = self._pipe_configs[pipe_name]
if listeners:
@ -1926,7 +1922,7 @@ class Language:
if "replace_listener" in tok2vec_model.attrs:
new_model = tok2vec_model.attrs["replace_listener"](new_model)
util.replace_model_node(pipe.model, listener, new_model) # type: ignore[attr-defined]
tok2vec.remove_listener(listener, pipe_name) # type: ignore[attr-defined]
tok2vec.remove_listener(listener, pipe_name)
def to_disk(
self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()

View File

@ -1,10 +1,10 @@
import numpy
from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast
from typing_extensions import Protocol, runtime_checkable
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
from thinc.api import Optimizer
from thinc.types import Ragged, Ints2d, Floats2d, Ints1d
from ..compat import Protocol, runtime_checkable
from ..scorer import Scorer
from ..language import Language
from .trainable_pipe import TrainablePipe

18
spacy/tests/test_ty.py Normal file
View File

@ -0,0 +1,18 @@
import spacy
from spacy import ty
def test_component_types():
    """Check which built-in pipes satisfy each runtime protocol in spacy.ty.

    Three pipes are created from a blank English pipeline and every one is
    checked against every protocol with ``isinstance``.
    """
    nlp = spacy.blank("en")
    pipe_names = ("tok2vec", "tagger", "entity_ruler")
    pipes = {pipe_name: nlp.create_pipe(pipe_name) for pipe_name in pipe_names}
    # One row per protocol; the flags line up with pipe_names and say
    # whether the pipe is expected to be an instance of that protocol.
    expectations = (
        (ty.TrainableComponent, (True, True, False)),
        (ty.InitializableComponent, (True, True, True)),
        (ty.ListenedToComponent, (True, False, False)),
    )
    for protocol, flags in expectations:
        for pipe_name, should_match in zip(pipe_names, flags):
            if should_match:
                assert isinstance(pipes[pipe_name], protocol)
            else:
                assert not isinstance(pipes[pipe_name], protocol)

55
spacy/ty.py Normal file
View File

@ -0,0 +1,55 @@
from typing import TYPE_CHECKING
from typing import Optional, Any, Iterable, Dict, Callable, Sequence, List

from thinc.api import Optimizer, Model

from .compat import Protocol, runtime_checkable

if TYPE_CHECKING:
    from .language import Language
    from .training import Example
@runtime_checkable
class TrainableComponent(Protocol):
    """Structural type for pipeline components that can be updated.

    Because the protocol is runtime-checkable, ``isinstance`` succeeds for
    any object exposing all four members below. Only their presence is
    checked at runtime, not their signatures or values.
    """

    # Boolean flag; callers still have to inspect its value themselves,
    # since isinstance() only establishes that the attribute exists.
    is_trainable: bool
    # Deliberately untyped (Any) in this protocol.
    model: Any

    def finish_update(self, sgd: Optimizer) -> None:
        """Complete an update using the optimizer ``sgd``."""

    def update(
        self,
        examples: Iterable["Example"],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None
    ) -> Dict[str, float]:
        """Update the component from ``examples``.

        ``drop``, ``sgd`` and ``losses`` are keyword-only. The return value
        is a dict of floats keyed by string (presumably the losses).
        """
@runtime_checkable
class InitializableComponent(Protocol):
    """Structural type for pipeline components that have an ``initialize`` step.

    Because the protocol is runtime-checkable, ``isinstance`` succeeds for
    any object that has an ``initialize`` attribute. The signature below is
    only enforced by static type checkers.
    """

    def initialize(
        self,
        get_examples: Callable[[], Iterable["Example"]],
        # The pipeline object itself, not an iterable of examples.
        nlp: "Language",
        **kwargs: Any
    ):
        """Initialize the component.

        get_examples: Zero-argument callable returning an iterable of examples.
        nlp: The pipeline the component belongs to.
        **kwargs: Additional component-specific settings.
        """
        ...
@runtime_checkable
class ListenedToComponent(Protocol):
    """Structural type for components that other components listen to.

    Because the protocol is runtime-checkable, ``isinstance`` succeeds for
    any object exposing all seven members below (four attributes and three
    methods). Only their presence is checked at runtime.
    """

    # Deliberately untyped (Any) in this protocol.
    model: Any
    # Listener models grouped under a string key (presumably the name of
    # the listening component; confirm against implementations).
    listener_map: Dict[str, Sequence[Model]]
    listeners: Sequence[Model]
    listening_components: List[str]

    def find_listeners(self, component) -> None:
        """Inspect ``component`` for listeners; returns nothing."""

    def add_listener(self, listener: Model, component_name: str) -> None:
        """Register ``listener`` under ``component_name``."""

    def remove_listener(self, listener: Model, component_name: str) -> bool:
        """Unregister ``listener`` from ``component_name``; returns a bool."""