spaCy/spacy/ty.py
Daniël de Kok 6b07be2110
Add Language.distill (#12116)
* Add `Language.distill`

This method is the distillation counterpart of `Language.update`. It
takes a teacher `Language` instance and distills the student pipes from
the corresponding teacher pipes.

* Apply suggestions from code review

Co-authored-by: Madeesh Kannan <shadeMe@users.noreply.github.com>

* Clarify how `Example` is used in distillation

* Update transition parser distill docstring for examples argument

* Pass optimizer to `TrainablePipe.distill`

* Annotate pipe before update

As discussed internally, we want to let a pipe annotate before doing an
update with gold/silver data. Otherwise, the output may be (too)
informed by the gold/silver data.

* Rename `component_map` to `student_to_teacher`

* Better synopsis in `Language.distill` docstring

* `name` -> `student_name`

* Fix labels type in docstring

* Mark distill test as slow

* Fix `student_to_teacher` type in docs

2023-01-30 12:44:11 +01:00

74 lines
1.7 KiB
Python

from typing import TYPE_CHECKING, Protocol, runtime_checkable
from typing import Optional, Any, Iterable, Dict, Callable, Sequence, List

from thinc.api import Optimizer, Model

if TYPE_CHECKING:
    # Only needed for annotations, so these are not imported at runtime.
    from .training import Example
    from .language import Language


@runtime_checkable
class TrainableComponent(Protocol):
    """A pipeline component that can be updated from examples."""

    model: Any
    is_trainable: bool

    def update(
        self,
        examples: Iterable["Example"],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None
    ) -> Dict[str, float]:
        ...

    def finish_update(self, sgd: Optimizer) -> None:
        ...


@runtime_checkable
class DistillableComponent(Protocol):
    """A student component that can be distilled from a teacher pipe."""

    is_distillable: bool

    def distill(
        self,
        teacher_pipe: Optional[TrainableComponent],
        examples: Iterable["Example"],
        *,
        drop: float = 0.0,
        sgd: Optional[Optimizer] = None,
        losses: Optional[Dict[str, float]] = None
    ) -> Dict[str, float]:
        ...

    def finish_update(self, sgd: Optimizer) -> None:
        ...


@runtime_checkable
class InitializableComponent(Protocol):
    """A component that exposes an `initialize` hook."""

    def initialize(
        self,
        get_examples: Callable[[], Iterable["Example"]],
        # The pipeline object itself is passed here, not an iterable of examples.
        nlp: "Language",
        **kwargs: Any
    ):
        ...


@runtime_checkable
class ListenedToComponent(Protocol):
    """A component whose output is consumed by listener models."""

    model: Any
    listeners: Sequence[Model]
    listener_map: Dict[str, Sequence[Model]]
    listening_components: List[str]

    def add_listener(self, listener: Model, component_name: str) -> None:
        ...

    def remove_listener(self, listener: Model, component_name: str) -> bool:
        ...

    def find_listeners(self, component) -> None:
        ...
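
Because every protocol in this file is decorated with `@runtime_checkable`, it can be used in `isinstance` checks to test structurally whether a pipe supports a given capability. The following is a minimal sketch of that idea, not spaCy code: `TrainableLike` is a simplified stand-in for `TrainableComponent` that drops the `thinc` and `Example` types so the snippet runs on its own, and `ToyTagger`, `RuleBasedPipe` and `trainable_pipes` are hypothetical names introduced only for illustration.

```python
from typing import Any, Dict, Iterable, List, Optional, Protocol, Tuple
from typing import runtime_checkable


@runtime_checkable
class TrainableLike(Protocol):
    # Hypothetical, simplified stand-in for TrainableComponent.
    model: Any
    is_trainable: bool

    def update(
        self,
        examples: Iterable[Any],
        *,
        drop: float = 0.0,
        sgd: Optional[Any] = None,
        losses: Optional[Dict[str, float]] = None
    ) -> Dict[str, float]:
        ...

    def finish_update(self, sgd: Any) -> None:
        ...


class ToyTagger:
    """Hypothetical pipe that provides every member of the protocol."""

    def __init__(self) -> None:
        self.model = object()
        self.is_trainable = True

    def update(self, examples, *, drop=0.0, sgd=None, losses=None):
        # Fake "loss": the number of examples seen in this call.
        losses = {} if losses is None else losses
        seen = float(len(list(examples)))
        losses["toy_tagger"] = losses.get("toy_tagger", 0.0) + seen
        return losses

    def finish_update(self, sgd) -> None:
        pass


class RuleBasedPipe:
    """Hypothetical pipe with no model and no update method."""

    def __call__(self, doc):
        return doc


def trainable_pipes(pipeline: List[Tuple[str, Any]]) -> List[str]:
    # Keep the names of pipes that structurally match the protocol.
    return [name for name, pipe in pipeline if isinstance(pipe, TrainableLike)]


pipeline = [("rules", RuleBasedPipe()), ("toy_tagger", ToyTagger())]
names = trainable_pipes(pipeline)
print(names)
```

Note that an `isinstance` check against a runtime-checkable protocol only verifies that the required attributes and methods are present. It does not validate their signatures or types, so it is a coarse capability test rather than a full conformance check.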