2023-06-14 18:48:41 +03:00
|
|
|
from typing import (
|
|
|
|
TYPE_CHECKING,
|
|
|
|
Any,
|
|
|
|
Callable,
|
|
|
|
Dict,
|
|
|
|
Iterable,
|
|
|
|
List,
|
|
|
|
Optional,
|
|
|
|
Sequence,
|
|
|
|
)
|
2023-06-02 15:29:52 +03:00
|
|
|
|
2023-06-14 18:48:41 +03:00
|
|
|
from thinc.api import Model, Optimizer
|
2021-10-21 16:31:06 +03:00
|
|
|
|
2023-06-14 18:48:41 +03:00
|
|
|
from .compat import Protocol, runtime_checkable
|
2021-10-21 16:31:06 +03:00
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
2023-06-02 15:29:52 +03:00
|
|
|
from .language import Language
|
2023-06-14 18:48:41 +03:00
|
|
|
from .training import Example
|
2021-10-21 16:31:06 +03:00
|
|
|
|
|
|
|
|
|
|
|
@runtime_checkable
class TrainableComponent(Protocol):
    """Structural type for pipeline components that can be trained.

    Because the protocol is decorated with ``@runtime_checkable``,
    ``isinstance(obj, TrainableComponent)`` succeeds for any object exposing
    every member declared below. Only the presence of the members is checked
    at runtime, not their signatures or types. Since the protocol declares
    data members, ``issubclass`` cannot be used with it.
    """

    # The component's model, typed loosely as ``Any``.
    model: Any
    # Flag that trainable components are expected to declare.
    is_trainable: bool

    def update(
        self,
        examples: Iterable["Example"],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        """Run an update step over ``examples``.

        ``drop``, ``sgd`` and ``losses`` are keyword-only. The return value is
        a ``Dict[str, float]``. It is presumably the (accumulated) losses keyed
        by name; confirm against concrete implementations.
        """
        ...

    def finish_update(self, sgd: Optimizer) -> None:
        """Complete an update using the optimizer ``sgd``.

        What exactly is applied here is left to implementations (presumably
        pending gradients; confirm against implementations).
        """
        ...
|
|
|
|
|
|
|
|
|
|
|
|
@runtime_checkable
class InitializableComponent(Protocol):
    """Structural type for components that expose an ``initialize`` hook.

    With ``@runtime_checkable``, ``isinstance(obj, InitializableComponent)``
    only verifies that ``obj`` has an ``initialize`` attribute. The signature
    below is documentation for static type checkers and is not enforced at
    runtime.
    """

    def initialize(
        self,
        get_examples: Callable[[], Iterable["Example"]],
        nlp: "Language",
        **kwargs: Any,
    ):
        """Initialize the component.

        ``get_examples`` is a zero-argument callable returning an iterable of
        ``Example`` objects. ``nlp`` is a ``Language`` object. Any further
        implementation-specific settings arrive through ``kwargs``. No return
        type is declared.
        """
        ...
|
|
|
|
|
|
|
|
|
|
|
|
@runtime_checkable
class ListenedToComponent(Protocol):
    """Structural type for components that other components can listen to.

    ``isinstance(obj, ListenedToComponent)`` checks only that every member
    declared below is present on ``obj``. Types and signatures are not
    verified at runtime. Because data members are declared, ``issubclass``
    cannot be used with this protocol.
    """

    # The component's model, typed loosely as ``Any``.
    model: Any
    # Listener models attached to this component.
    listeners: Sequence[Model]
    # Listener models grouped under string keys (presumably the listening
    # component's name, matching ``component_name`` below; confirm).
    listener_map: Dict[str, Sequence[Model]]
    # Names of the components listening to this one.
    listening_components: List[str]

    def add_listener(self, listener: Model, component_name: str) -> None:
        """Register ``listener`` under ``component_name``."""
        ...

    def remove_listener(self, listener: Model, component_name: str) -> bool:
        """Unregister ``listener`` from ``component_name``.

        Returns a ``bool``. Presumably it reports whether anything was
        removed; confirm against implementations.
        """
        ...

    def find_listeners(self, component) -> None:
        """Inspect ``component`` for listeners (``component`` is untyped here)."""
        ...
|