Custom component types in spacy.ty (#9469)

* add custom protocols in spacy.ty

* add a test for the new types in spacy.ty

* import Example when type checking

* some type fixes

* put Protocol in compat

* revert update check back to hasattr

* runtime_checkable in compat as well
This commit is contained in:
Sofie Van Landeghem 2021-10-21 15:31:06 +02:00 committed by GitHub
parent d0631e3005
commit 5a38f79f18
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 106 additions and 37 deletions

View File

@ -23,9 +23,9 @@ except ImportError:
cupy = None cupy = None
if sys.version_info[:2] >= (3, 8): # Python 3.8+ if sys.version_info[:2] >= (3, 8): # Python 3.8+
from typing import Literal from typing import Literal, Protocol, runtime_checkable
else: else:
from typing_extensions import Literal # noqa: F401 from typing_extensions import Literal, Protocol, runtime_checkable # noqa: F401
# Important note: The importlib_metadata "backport" includes functionality # Important note: The importlib_metadata "backport" includes functionality
# that's not part of the built-in importlib.metadata. We should treat this # that's not part of the built-in importlib.metadata. We should treat this

View File

@ -17,6 +17,7 @@ from itertools import chain, cycle
from timeit import default_timer as timer from timeit import default_timer as timer
import traceback import traceback
from . import ty
from .tokens.underscore import Underscore from .tokens.underscore import Underscore
from .vocab import Vocab, create_vocab from .vocab import Vocab, create_vocab
from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis from .pipe_analysis import validate_attrs, analyze_pipes, print_pipe_analysis
@ -1135,11 +1136,11 @@ class Language:
if sgd not in (None, False): if sgd not in (None, False):
if ( if (
name not in exclude name not in exclude
and hasattr(proc, "is_trainable") and isinstance(proc, ty.TrainableComponent)
and proc.is_trainable and proc.is_trainable
and proc.model not in (True, False, None) # type: ignore and proc.model not in (True, False, None)
): ):
proc.finish_update(sgd) # type: ignore proc.finish_update(sgd)
if name in annotates: if name in annotates:
for doc, eg in zip( for doc, eg in zip(
_pipe( _pipe(
@ -1278,12 +1279,12 @@ class Language:
) )
self.tokenizer.initialize(get_examples, nlp=self, **tok_settings) # type: ignore[union-attr] self.tokenizer.initialize(get_examples, nlp=self, **tok_settings) # type: ignore[union-attr]
for name, proc in self.pipeline: for name, proc in self.pipeline:
if hasattr(proc, "initialize"): if isinstance(proc, ty.InitializableComponent):
p_settings = I["components"].get(name, {}) p_settings = I["components"].get(name, {})
p_settings = validate_init_settings( p_settings = validate_init_settings(
proc.initialize, p_settings, section="components", name=name proc.initialize, p_settings, section="components", name=name
) )
proc.initialize(get_examples, nlp=self, **p_settings) # type: ignore[call-arg] proc.initialize(get_examples, nlp=self, **p_settings)
pretrain_cfg = config.get("pretraining") pretrain_cfg = config.get("pretraining")
if pretrain_cfg: if pretrain_cfg:
P = registry.resolve(pretrain_cfg, schema=ConfigSchemaPretrain) P = registry.resolve(pretrain_cfg, schema=ConfigSchemaPretrain)
@ -1622,9 +1623,9 @@ class Language:
# components don't receive the pipeline then. So this does have to be # components don't receive the pipeline then. So this does have to be
# here :( # here :(
for i, (name1, proc1) in enumerate(self.pipeline): for i, (name1, proc1) in enumerate(self.pipeline):
if hasattr(proc1, "find_listeners"): if isinstance(proc1, ty.ListenedToComponent):
for name2, proc2 in self.pipeline[i + 1 :]: for name2, proc2 in self.pipeline[i + 1 :]:
proc1.find_listeners(proc2) # type: ignore[attr-defined] proc1.find_listeners(proc2)
@classmethod @classmethod
def from_config( def from_config(
@ -1810,25 +1811,25 @@ class Language:
) )
# Detect components with listeners that are not frozen consistently # Detect components with listeners that are not frozen consistently
for name, proc in nlp.pipeline: for name, proc in nlp.pipeline:
# Remove listeners not in the pipeline if isinstance(proc, ty.ListenedToComponent):
listener_names = getattr(proc, "listening_components", []) # Remove listeners not in the pipeline
unused_listener_names = [ listener_names = proc.listening_components
ll for ll in listener_names if ll not in nlp.pipe_names unused_listener_names = [
] ll for ll in listener_names if ll not in nlp.pipe_names
for listener_name in unused_listener_names: ]
for listener in proc.listener_map.get(listener_name, []): # type: ignore[attr-defined] for listener_name in unused_listener_names:
proc.remove_listener(listener, listener_name) # type: ignore[attr-defined] for listener in proc.listener_map.get(listener_name, []):
proc.remove_listener(listener, listener_name)
for listener in getattr( for listener_name in proc.listening_components:
proc, "listening_components", [] # e.g. tok2vec/transformer
): # e.g. tok2vec/transformer # If it's a component sourced from another pipeline, we check if
# If it's a component sourced from another pipeline, we check if # the tok2vec listeners should be replaced with standalone tok2vec
# the tok2vec listeners should be replaced with standalone tok2vec # models (e.g. so component can be frozen without its performance
# models (e.g. so component can be frozen without its performance # degrading when other components/tok2vec are updated)
# degrading when other components/tok2vec are updated) paths = sourced.get(listener_name, {}).get("replace_listeners", [])
paths = sourced.get(listener, {}).get("replace_listeners", []) if paths:
if paths: nlp.replace_listeners(name, listener_name, paths)
nlp.replace_listeners(name, listener, paths)
return nlp return nlp
def replace_listeners( def replace_listeners(
@ -1878,15 +1879,10 @@ class Language:
raise ValueError(err) raise ValueError(err)
tok2vec = self.get_pipe(tok2vec_name) tok2vec = self.get_pipe(tok2vec_name)
tok2vec_cfg = self.get_pipe_config(tok2vec_name) tok2vec_cfg = self.get_pipe_config(tok2vec_name)
if ( if not isinstance(tok2vec, ty.ListenedToComponent):
not hasattr(tok2vec, "model")
or not hasattr(tok2vec, "listener_map")
or not hasattr(tok2vec, "remove_listener")
or "model" not in tok2vec_cfg
):
raise ValueError(Errors.E888.format(name=tok2vec_name, pipe=type(tok2vec))) raise ValueError(Errors.E888.format(name=tok2vec_name, pipe=type(tok2vec)))
tok2vec_model = tok2vec.model # type: ignore[attr-defined] tok2vec_model = tok2vec.model
pipe_listeners = tok2vec.listener_map.get(pipe_name, []) # type: ignore[attr-defined] pipe_listeners = tok2vec.listener_map.get(pipe_name, [])
pipe = self.get_pipe(pipe_name) pipe = self.get_pipe(pipe_name)
pipe_cfg = self._pipe_configs[pipe_name] pipe_cfg = self._pipe_configs[pipe_name]
if listeners: if listeners:
@ -1926,7 +1922,7 @@ class Language:
if "replace_listener" in tok2vec_model.attrs: if "replace_listener" in tok2vec_model.attrs:
new_model = tok2vec_model.attrs["replace_listener"](new_model) new_model = tok2vec_model.attrs["replace_listener"](new_model)
util.replace_model_node(pipe.model, listener, new_model) # type: ignore[attr-defined] util.replace_model_node(pipe.model, listener, new_model) # type: ignore[attr-defined]
tok2vec.remove_listener(listener, pipe_name) # type: ignore[attr-defined] tok2vec.remove_listener(listener, pipe_name)
def to_disk( def to_disk(
self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList() self, path: Union[str, Path], *, exclude: Iterable[str] = SimpleFrozenList()

View File

@ -1,10 +1,10 @@
import numpy import numpy
from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast from typing import List, Dict, Callable, Tuple, Optional, Iterable, Any, cast
from typing_extensions import Protocol, runtime_checkable
from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops from thinc.api import Config, Model, get_current_ops, set_dropout_rate, Ops
from thinc.api import Optimizer from thinc.api import Optimizer
from thinc.types import Ragged, Ints2d, Floats2d, Ints1d from thinc.types import Ragged, Ints2d, Floats2d, Ints1d
from ..compat import Protocol, runtime_checkable
from ..scorer import Scorer from ..scorer import Scorer
from ..language import Language from ..language import Language
from .trainable_pipe import TrainablePipe from .trainable_pipe import TrainablePipe

18
spacy/tests/test_ty.py Normal file
View File

@ -0,0 +1,18 @@
import spacy
from spacy import ty
def test_component_types():
    """Verify that factory-built pipes match (or fail) the spacy.ty protocols."""
    nlp = spacy.blank("en")
    pipes = {
        name: nlp.create_pipe(name)
        for name in ("tok2vec", "tagger", "entity_ruler")
    }
    # Which pipe names are expected to satisfy each structural protocol.
    expected_members = {
        ty.TrainableComponent: {"tok2vec", "tagger"},
        ty.InitializableComponent: {"tok2vec", "tagger", "entity_ruler"},
        ty.ListenedToComponent: {"tok2vec"},
    }
    for protocol, names in expected_members.items():
        for pipe_name, pipe in pipes.items():
            assert isinstance(pipe, protocol) == (pipe_name in names)

55
spacy/ty.py Normal file
View File

@ -0,0 +1,55 @@
from typing import TYPE_CHECKING
from typing import Optional, Any, Iterable, Dict, Callable, Sequence, List
from .compat import Protocol, runtime_checkable
from thinc.api import Optimizer, Model
if TYPE_CHECKING:
from .training import Example
@runtime_checkable
class TrainableComponent(Protocol):
    """Structural type for pipeline components that can be trained.

    Marked @runtime_checkable so callers can use isinstance() instead of
    hasattr-style duck typing (e.g. Language.update checks this protocol
    before calling finish_update). Note that runtime isinstance checks only
    verify member presence, not signatures.
    """

    # The component's model; callers compare it against True/False/None
    # before treating it as trainable.
    model: Any
    # Whether the component should currently be updated during training.
    is_trainable: bool

    def update(
        self,
        examples: Iterable["Example"],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None
    ) -> Dict[str, float]:
        """Update the component on a batch of Examples; returns the losses dict."""
        ...

    def finish_update(self, sgd: Optimizer) -> None:
        """Apply accumulated gradients with the optimizer after update() calls."""
        ...
@runtime_checkable
class InitializableComponent(Protocol):
    """Structural type for components with a data-driven ``initialize`` step.

    Language.initialize matches pipeline components against this protocol
    (via isinstance) and calls ``proc.initialize(get_examples, nlp=self,
    **settings)`` on each match.
    """

    def initialize(
        self,
        get_examples: Callable[[], Iterable["Example"]],
        # ``nlp`` is the pipeline object itself (Language.initialize passes
        # nlp=self). Previously mis-annotated as Iterable["Example"] — a
        # copy-paste of get_examples' type. Annotated Any rather than
        # "Language" to avoid an import cycle with spacy.language.
        nlp: Any,
        **kwargs: Any
    ):
        """Initialize the component from example data and pipeline settings."""
        ...
@runtime_checkable
class ListenedToComponent(Protocol):
    """Structural type for components other pipes can listen to, e.g.
    tok2vec/transformer.

    Language.from_config and Language.replace_listeners match components
    against this protocol to prune unused listeners, replace listeners with
    standalone models, and wire up new listeners via find_listeners.
    """

    model: Any
    # All listener models attached to this component.
    listeners: Sequence[Model]
    # Listener models keyed by the name of the listening component
    # (see Language.replace_listeners: listener_map.get(pipe_name, [])).
    listener_map: Dict[str, Sequence[Model]]
    # Names of the pipeline components that listen to this one.
    listening_components: List[str]

    def add_listener(self, listener: Model, component_name: str) -> None:
        """Register a listener model for the named component."""
        ...

    def remove_listener(self, listener: Model, component_name: str) -> bool:
        """Detach a listener; returns whether a listener was removed."""
        ...

    def find_listeners(self, component: Any) -> None:
        """Scan a downstream pipeline component and attach any of its listeners."""
        ...