mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-27 10:26:35 +03:00
6b07be2110
* Add `Language.distill` This method is the distillation counterpart of `Language.update`. It takes a teacher `Language` instance and distills the student pipes on the teacher pipes. * Apply suggestions from code review Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com> * Clarify how `Example` is used in distillation * Update transition parser distill docstring for examples argument * Pass optimizer to `TrainablePipe.distill` * Annotate pipe before update As discussed internally, we want to let a pipe annotate before doing an update with gold/silver data. Otherwise, the output may be (too) informed by the gold/silver data. * Rename `component_map` to `student_to_teacher` * Better synopsis in `Language.distill` docstring * `name` -> `student_name` * Fix labels type in docstring * Mark distill test as slow * Fix `student_to_teacher` type in docs --------- Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>
74 lines
1.7 KiB
Python
74 lines
1.7 KiB
Python
from typing import TYPE_CHECKING, Protocol, runtime_checkable
from typing import Optional, Any, Iterable, Dict, Callable, Sequence, List

from thinc.api import Optimizer, Model

if TYPE_CHECKING:
    from .language import Language
    from .training import Example
|
|
|
|
|
|
@runtime_checkable
class TrainableComponent(Protocol):
    """Structural type for components that expose a training interface.

    A conforming object provides the `model` and `is_trainable` attributes
    plus the `update` and `finish_update` methods. The protocol is
    runtime-checkable, so `isinstance` verifies that these members are
    present; it does not validate their signatures or types.
    """

    # Typed as `Any`: the protocol places no constraint on the model's type.
    model: Any
    is_trainable: bool

    def update(
        self,
        examples: Iterable["Example"],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        """Update the component from a batch of examples.

        examples (Iterable[Example]): The examples to update on.
        drop (float): Keyword-only, defaults to 0.0. Presumably the dropout
            rate -- confirm against implementers.
        sgd (Optional[Optimizer]): Optional optimizer; may be None.
        losses (Optional[Dict[str, float]]): Optional mapping of string keys
            to float values; may be None.
        RETURNS (Dict[str, float]): A mapping of string keys to floats
            (presumably the updated losses -- confirm against implementers).
        """

    def finish_update(self, sgd: Optimizer) -> None:
        """Finish an update using the optimizer `sgd`."""
|
|
|
|
|
|
@runtime_checkable
class DistillableComponent(Protocol):
    """Structural type for components that expose a distillation interface.

    A conforming object provides the `is_distillable` attribute plus the
    `distill` and `finish_update` methods. The protocol is runtime-checkable,
    so `isinstance` verifies that these members are present; it does not
    validate their signatures or types.
    """

    is_distillable: bool

    def distill(
        self,
        teacher_pipe: Optional[TrainableComponent],
        examples: Iterable["Example"],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None,
    ) -> Dict[str, float]:
        """Distill this (student) component using a teacher component.

        teacher_pipe (Optional[TrainableComponent]): The teacher component;
            may be None.
        examples (Iterable[Example]): The examples to distill on.
        drop (float): Keyword-only, defaults to 0.0. Presumably the dropout
            rate -- confirm against implementers.
        sgd (Optional[Optimizer]): Optional optimizer; may be None.
        losses (Optional[Dict[str, float]]): Optional mapping of string keys
            to float values; may be None.
        RETURNS (Dict[str, float]): A mapping of string keys to floats
            (presumably the updated losses -- confirm against implementers).
        """

    def finish_update(self, sgd: Optimizer) -> None:
        """Finish an update using the optimizer `sgd`."""
|
|
|
|
|
|
@runtime_checkable
class InitializableComponent(Protocol):
    """Structural type for components that expose an `initialize` method.

    The protocol is runtime-checkable, so `isinstance` verifies only that an
    `initialize` member is present; it does not validate its signature.
    """

    def initialize(
        self,
        get_examples: Callable[[], Iterable["Example"]],
        # Previously annotated as Iterable["Example"], a copy-paste of the
        # element type above; this parameter is the pipeline object itself.
        nlp: "Language",
        **kwargs: Any
    ):
        """Initialize the component.

        get_examples (Callable[[], Iterable[Example]]): Zero-argument callable
            that returns an iterable of examples.
        nlp (Language): The pipeline object the component belongs to.
        **kwargs (Any): Additional component-specific settings.
        """
|
|
|
|
|
|
@runtime_checkable
class ListenedToComponent(Protocol):
    """Structural type for components that keep track of listener models.

    A conforming object provides the `model`, `listeners`, `listener_map` and
    `listening_components` attributes plus the `add_listener`,
    `remove_listener` and `find_listeners` methods. The protocol is
    runtime-checkable, so `isinstance` verifies that these members are
    present; it does not validate their signatures or types.
    """

    # Typed as `Any`: the protocol places no constraint on the model's type.
    model: Any
    listeners: Sequence[Model]
    # String keys mapped to sequences of listener models (presumably keyed by
    # the listening component's name -- confirm against implementers).
    listener_map: Dict[str, Sequence[Model]]
    listening_components: List[str]

    def add_listener(self, listener: Model, component_name: str) -> None:
        """Add the model `listener` for the component `component_name`."""

    def remove_listener(self, listener: Model, component_name: str) -> bool:
        """Remove the model `listener` for the component `component_name`.

        RETURNS (bool): Presumably whether a listener was actually removed --
            confirm against implementers.
        """

    def find_listeners(self, component) -> None:
        """Look for listeners in `component`.

        Presumably any listeners found are registered on this object --
        confirm against implementers.
        """
|