2023-01-27 17:48:20 +03:00
|
|
|
from typing import TYPE_CHECKING, Protocol, runtime_checkable
|
2021-10-21 16:31:06 +03:00
|
|
|
from typing import Optional, Any, Iterable, Dict, Callable, Sequence, List
|
|
|
|
|
|
|
|
from thinc.api import Optimizer, Model
|
|
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
|
from .training import Example
|
|
|
|
|
|
|
|
|
|
|
|
@runtime_checkable
class TrainableComponent(Protocol):
    """Structural interface for components that can be updated.

    Any object exposing the members declared below satisfies this protocol;
    no inheritance is required. Because the class is ``runtime_checkable``,
    ``isinstance(obj, TrainableComponent)`` only tests that every member is
    present -- it does not check method signatures or attribute types. As the
    protocol declares data members, ``issubclass()`` against it raises
    ``TypeError``.
    """

    # Typed loosely; presumably the component's underlying model object.
    model: Any
    # Boolean flag; by its name, advertises that the component is trainable.
    is_trainable: bool

    def update(
        self,
        examples: Iterable["Example"],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        """Update the component from a batch of ``examples``.

        Keyword-only options are ``drop`` (defaults to ``0.0``), an optional
        ``sgd`` optimizer and an optional ``losses`` mapping. Implementations
        return a ``Dict[str, float]`` -- presumably the accumulated losses.
        """
        ...

    def finish_update(self, sgd: Optimizer) -> None:
        """Finalize an update step; presumably applies pending changes via ``sgd``."""
        ...
|
|
|
|
|
|
|
|
|
2023-01-30 14:44:11 +03:00
|
|
|
@runtime_checkable
class DistillableComponent(Protocol):
    """Structural interface for components exposing a ``distill`` step.

    Objects match this protocol when they provide ``is_distillable``,
    ``distill`` and ``finish_update``; no inheritance is required. The
    ``runtime_checkable`` ``isinstance`` check only verifies that those
    members exist, not their signatures. Because a data member is declared,
    ``issubclass()`` against this protocol raises ``TypeError``.
    """

    # Boolean flag; by its name, advertises that the component supports distill().
    is_distillable: bool

    def distill(
        self,
        teacher_pipe: Optional[TrainableComponent],
        examples: Iterable["Example"],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        """Run one distillation step on ``examples``.

        ``teacher_pipe`` may be ``None`` or any ``TrainableComponent``; the
        name suggests it acts as the teacher. Keyword-only options are
        ``drop`` (defaults to ``0.0``), an optional ``sgd`` optimizer and an
        optional ``losses`` mapping. Returns a ``Dict[str, float]`` --
        presumably the accumulated losses.
        """
        ...

    def finish_update(self, sgd: Optimizer) -> None:
        """Finalize an update step; presumably applies pending changes via ``sgd``."""
        ...
|
|
|
|
|
|
|
|
|
2021-10-21 16:31:06 +03:00
|
|
|
@runtime_checkable
class InitializableComponent(Protocol):
    """Structural interface for components that expose an ``initialize`` hook.

    Any object with an ``initialize`` method satisfies this protocol; no
    inheritance is required. Since the protocol declares only a method, it
    supports both ``isinstance`` and ``issubclass``. Either check tests
    only for the method's presence, not its signature.
    """

    def initialize(
        self,
        get_examples: Callable[[], Iterable["Example"]],
        # Typed as ``Any``: ``nlp`` is a separate argument from
        # ``get_examples`` (which already supplies the examples), so an
        # ``Iterable["Example"]`` annotation cannot describe it. That
        # annotation would also make implementations that accept the pipeline
        # object fail static protocol checks. This is presumably the pipeline
        # (spaCy ``Language``) object -- TODO tighten via a TYPE_CHECKING import.
        nlp: Any,
        **kwargs: Any
    ):
        """Initialization hook; implementations define what it sets up.

        Args:
            get_examples: Zero-argument callable returning an iterable of
                ``Example`` objects.
            nlp: The object passed as ``nlp`` (see the comment on the
                parameter).
            **kwargs: Additional implementation-specific settings.
        """
        ...
|
|
|
|
|
|
|
|
|
|
|
|
@runtime_checkable
class ListenedToComponent(Protocol):
    """Structural interface for components that manage listener models.

    Matching objects expose the data members and the three listener-management
    methods declared below; no inheritance is required. The
    ``runtime_checkable`` ``isinstance`` check only confirms that these members
    exist. Because data members are declared, ``issubclass()`` against this
    protocol raises ``TypeError``.
    """

    # Typed loosely; presumably the component's underlying model object.
    model: Any
    # Sequence of listener models.
    listeners: Sequence[Model]
    # Keyed by str; presumably component name -> that component's listeners.
    listener_map: Dict[str, Sequence[Model]]
    # Names (strings); presumably the components that hold listeners.
    listening_components: List[str]

    def add_listener(self, listener: Model, component_name: str) -> None:
        """Register ``listener`` in association with ``component_name``."""
        ...

    def remove_listener(self, listener: Model, component_name: str) -> bool:
        """Unregister ``listener`` for ``component_name``.

        Returns a ``bool`` -- presumably whether anything was removed.
        """
        ...

    def find_listeners(self, component) -> None:
        """Look up listeners related to ``component`` (untyped argument)."""
        ...
|