# Mirror of https://github.com/explosion/spaCy.git
# (synced 2025-01-01 04:46:38 +03:00)
# 85 lines, 1.7 KiB, Python
from typing import (
|
|
TYPE_CHECKING,
|
|
Any,
|
|
Callable,
|
|
Dict,
|
|
Iterable,
|
|
List,
|
|
Optional,
|
|
Protocol,
|
|
Sequence,
|
|
runtime_checkable,
|
|
)
|
|
|
|
from thinc.api import Model, Optimizer
|
|
|
|
if TYPE_CHECKING:
|
|
from .language import Language
|
|
from .training import Example
|
|
|
|
|
|
@runtime_checkable
class TrainableComponent(Protocol):
    """Structural type for a pipeline component that can be trained.

    Any object providing the members declared below satisfies this protocol
    for static type checking. Because the class is ``runtime_checkable``,
    ``isinstance`` checks against it only verify that each member exists,
    not that its type or signature matches.
    """

    # The component's underlying model object (deliberately left as Any).
    model: Any
    # Whether the component reports itself as trainable.
    is_trainable: bool

    def update(
        self,
        examples: Iterable["Example"],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None
    ) -> Dict[str, float]:
        """Learn from a batch of ``examples``.

        ``drop``, ``sgd`` and ``losses`` are keyword-only. The declared
        return is a str -> float mapping; presumably the (updated) losses,
        to be confirmed against concrete implementations.
        """

    def finish_update(self, sgd: Optimizer) -> None:
        """Complete an update step with the optimizer ``sgd``.

        The exact semantics are left to implementers.
        """
|
|
|
|
|
|
@runtime_checkable
class DistillableComponent(Protocol):
    """Structural type for a component that can be trained by distillation.

    Being ``runtime_checkable``, ``isinstance`` checks against this protocol
    only confirm that the declared members are present; signatures are not
    compared at runtime.
    """

    # Whether the component reports itself as distillable.
    is_distillable: bool

    def distill(
        self,
        teacher_pipe: Optional[TrainableComponent],
        examples: Iterable["Example"],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None
    ) -> Dict[str, float]:
        """Train from ``examples`` using ``teacher_pipe`` as the teacher.

        ``teacher_pipe`` may be None. ``drop``, ``sgd`` and ``losses`` are
        keyword-only. The declared return is a str -> float mapping;
        presumably the (updated) losses, to be confirmed against
        implementations.
        """

    def finish_update(self, sgd: Optimizer) -> None:
        """Complete an update step with the optimizer ``sgd``.

        The exact semantics are left to implementers.
        """
|
|
|
|
|
|
@runtime_checkable
class InitializableComponent(Protocol):
    """Structural type for a component that exposes an ``initialize`` hook.

    ``runtime_checkable`` allows ``isinstance`` checks against this protocol;
    such checks only verify that an ``initialize`` attribute exists, not
    its signature.
    """

    def initialize(
        self,
        get_examples: Callable[[], Iterable["Example"]],
        nlp: "Language",
        **kwargs: Any,
    ) -> Any:
        """Initialize the component.

        get_examples: zero-argument callable returning an iterable of
            ``Example`` objects.
        nlp: a ``Language`` object (presumably the owning pipeline).
        **kwargs: additional, implementation-specific settings.

        The return type is spelled out as ``Any``. That is the same type a
        checker assumes for an unannotated return, so every existing
        implementation still conforms. The explicit annotation means
        strict settings such as mypy's ``--disallow-incomplete-defs`` no
        longer flag this method as partially annotated.
        """
|
|
|
|
|
|
@runtime_checkable
class ListenedToComponent(Protocol):
    """Structural type for a component that other components can listen to.

    Because the class is ``runtime_checkable``, ``isinstance`` checks only
    confirm that every declared attribute and method is present.
    """

    # The component's underlying model object (deliberately left as Any).
    model: Any
    # Listener models attached to this component.
    listeners: Sequence[Model]
    # Maps a str key (presumably a component name, matching the
    # ``component_name`` argument below) to that key's listener models.
    listener_map: Dict[str, Sequence[Model]]
    # Names of the components listening to this one.
    listening_components: List[str]

    def add_listener(self, listener: Model, component_name: str) -> None:
        """Register ``listener`` on behalf of ``component_name``."""

    def remove_listener(self, listener: Model, component_name: str) -> bool:
        """Unregister ``listener`` for ``component_name``.

        Returns a bool; presumably whether the listener was found and
        removed. Confirm against implementations.
        """

    def find_listeners(self, component: Any) -> None:
        """Inspect ``component`` for listeners to attach.

        ``component`` is annotated ``Any``, which is the type a checker
        already assumed for the unannotated parameter. The annotation only
        completes the signature, so strict checkers stop reporting it as
        partially annotated. The expected kind of object is left to
        implementers.
        """
|