mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-16 06:37:04 +03:00
5a38f79f18
* add custom protocols in spacy.ty * add a test for the new types in spacy.ty * import Example when type checking * some type fixes * put Protocol in compat * revert update check back to hasattr * runtime_checkable in compat as well
56 lines
1.3 KiB
Python
56 lines
1.3 KiB
Python
from typing import TYPE_CHECKING
|
|
from typing import Optional, Any, Iterable, Dict, Callable, Sequence, List
|
|
from .compat import Protocol, runtime_checkable
|
|
|
|
from thinc.api import Optimizer, Model
|
|
|
|
if TYPE_CHECKING:
|
|
from .training import Example
|
|
|
|
|
|
@runtime_checkable
|
|
class TrainableComponent(Protocol):
|
|
model: Any
|
|
is_trainable: bool
|
|
|
|
def update(
|
|
self,
|
|
examples: Iterable["Example"],
|
|
*,
|
|
drop: float = 0.0,
|
|
sgd: Optional[Optimizer] = None,
|
|
losses: Optional[Dict[str, float]] = None
|
|
) -> Dict[str, float]:
|
|
...
|
|
|
|
def finish_update(self, sgd: Optimizer) -> None:
|
|
...
|
|
|
|
|
|
@runtime_checkable
|
|
class InitializableComponent(Protocol):
|
|
def initialize(
|
|
self,
|
|
get_examples: Callable[[], Iterable["Example"]],
|
|
nlp: Iterable["Example"],
|
|
**kwargs: Any
|
|
):
|
|
...
|
|
|
|
|
|
@runtime_checkable
|
|
class ListenedToComponent(Protocol):
|
|
model: Any
|
|
listeners: Sequence[Model]
|
|
listener_map: Dict[str, Sequence[Model]]
|
|
listening_components: List[str]
|
|
|
|
def add_listener(self, listener: Model, component_name: str) -> None:
|
|
...
|
|
|
|
def remove_listener(self, listener: Model, component_name: str) -> bool:
|
|
...
|
|
|
|
def find_listeners(self, component) -> None:
|
|
...
|