Typing fixes

Daniël de Kok 2024-01-24 10:28:46 +01:00
parent 82ef6783a8
commit c621e251b8
5 changed files with 27 additions and 13 deletions

View File

@@ -31,7 +31,7 @@ pytest-timeout>=1.3.0,<2.0.0
 mock>=2.0.0,<3.0.0
 flake8>=3.8.0,<6.0.0
 hypothesis>=3.27.0,<7.0.0
-mypy>=0.990,<1.1.0; platform_machine != "aarch64"
+mypy>=1.5.0,<1.6.0; platform_machine != "aarch64" and python_version >= "3.8"
 types-mock>=0.1.1
 types-setuptools>=57.0.0
 types-requests
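
The updated pin bumps mypy to the 1.5.x series and adds a Python-version condition to the environment marker. A small sketch (using the third-party packaging library, which is not part of this diff) to check whether the marker applies to the current interpreter:

from packaging.markers import Marker

# Marker text copied from the updated requirements line.
marker = Marker('platform_machine != "aarch64" and python_version >= "3.8"')

# True only on non-aarch64 machines running Python 3.8 or newer.
print(marker.evaluate())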

View File

@@ -139,6 +139,8 @@ class Span:
     def lemma_(self) -> str: ...
     @property
     def label_(self) -> str: ...
+    @label_.setter
+    def label_(self, label: str): ...
     @property
     def kb_id_(self) -> str: ...
     @property
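
The two added stub lines declare a setter for Span.label_, so assigning to the attribute now type-checks. A minimal usage sketch (assuming a spaCy build where the runtime setter is available, as this stub change implies):

import spacy
from spacy.tokens import Span

nlp = spacy.blank("en")
doc = nlp("Apple is looking at buying a startup")
span = Span(doc, 0, 1, label="ORG")

# With the @label_.setter stub in place, this assignment passes mypy.
span.label_ = "COMPANY"
print(span.label_)  # COMPANY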

View File

@@ -9,6 +9,10 @@ def annotations_to_doc(
     tok_annot: Dict[str, Any],
     doc_annot: Dict[str, Any],
 ) -> Doc: ...
+def validate_distillation_examples(
+    examples: Iterable[Example],
+    method: str,
+) -> None: ...
 def validate_examples(
     examples: Iterable[Example],
     method: str,

View File

@@ -58,6 +58,12 @@ def validate_examples(examples, method):
 def validate_distillation_examples(examples, method):
+    """Check that a batch of examples received during processing is valid
+    for distillation.
+
+    examples (Iterable[Examples]): A batch of examples.
+    method (str): The method name to show in error messages.
+    """
     validate_examples(examples, method)
     for eg in examples:
         if [token.text for token in eg.reference] != [token.text for token in eg.predicted]:
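
The check at the end of this hunk requires that the reference and predicted sides of each Example use the same tokenization. A minimal sketch of that invariant, re-implementing the comparison inline rather than calling the helper itself:

import spacy
from spacy.tokens import Doc
from spacy.training import Example

nlp = spacy.blank("en")
words = ["Distillation", "needs", "matching", "tokens"]
predicted = Doc(nlp.vocab, words=words)
reference = Doc(nlp.vocab, words=words)
eg = Example(predicted, reference)

# Same condition as in the diff: token texts must line up exactly.
assert [t.text for t in eg.reference] == [t.text for t in eg.predicted]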

View File

@@ -12,7 +12,9 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Sized,
     Tuple,
+    TypeVar,
     Union,
 )
@@ -22,7 +24,6 @@ from wasabi import Printer
 from .. import ty
 from ..errors import Errors
 from ..schemas import ConfigSchemaDistill, ConfigSchemaTraining
-from ..tokens.doc import Doc
 from ..util import (
     logger,
     registry,
@@ -282,7 +283,7 @@ def _distill_loop(
     teacher: "Language",
     student: "Language",
     optimizer: Optimizer,
-    distill_data: Iterable[List[Example]],
+    distill_data: Iterable[Tuple[int, List[Example]]],
     evaluate: Callable[[], Tuple[float, Dict[str, float]]],
     *,
     dropout: float,
@@ -401,7 +402,7 @@
 def train_while_improving(
     nlp: "Language",
     optimizer: Optimizer,
-    train_data: Iterable[List[Example]],
+    train_data: Iterable[Tuple[int, List[Example]]],
     evaluate: Callable[[], Tuple[float, Dict[str, float]]],
     *,
     dropout: float,
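
In both signatures the batch stream is now typed as Iterable[Tuple[int, List[Example]]], i.e. every batch arrives tagged with its epoch number. A sketch of the shape this annotation describes (consume is a hypothetical helper, not code from this file):

from typing import Iterable, List, Tuple

from spacy.training import Example

def consume(train_data: Iterable[Tuple[int, List[Example]]]) -> None:
    # Each item pairs an epoch number with one batch of examples.
    for step, (epoch, batch) in enumerate(train_data):
        print(f"step={step} epoch={epoch} batch_size={len(batch)}")
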
@@ -520,15 +521,16 @@
             break
+ItemT = TypeVar("ItemT", bound=Sized)
 def subdivide_batch(
-    batch: Union[Iterable[Doc], Iterable[Example]], accumulate_gradient: int
-):
+    batch: Iterable[ItemT], accumulate_gradient: int
+) -> Iterable[List[ItemT]]:
     batch = list(batch)
     if len(batch):
-        if isinstance(batch[0], Example):
-            batch.sort(key=lambda eg: len(eg.predicted))
-        else:
-            batch.sort(key=lambda doc: len(doc))
+        # Examples are sorted by their predicted length.
+        batch.sort(key=lambda item: len(item))
     sub_len = len(batch) // accumulate_gradient
     start = 0
     for i in range(accumulate_gradient):
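
The rewritten subdivide_batch is generic over any sized item: ItemT is bound to Sized, so both Doc and Example batches can be sorted with len() and the element type is preserved in the return type. A simplified, hypothetical analogue (split_batches is not the actual implementation) showing the same bound-TypeVar pattern:

from typing import Iterable, List, Sized, TypeVar

T = TypeVar("T", bound=Sized)

def split_batches(batch: Iterable[T], n: int) -> List[List[T]]:
    # The Sized bound only guarantees len(), so sort by length.
    items = sorted(batch, key=len)
    size = max(1, len(items) // n)
    return [items[i : i + size] for i in range(0, len(items), size)]

# The element type is preserved for the caller: List[List[str]] here.
print(split_batches(["a", "bbb", "cc"], 2))
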
@@ -578,7 +580,7 @@ def create_distill_batches(
     corpus: Callable[["Language"], Iterable[Example]],
     batcher: Callable[[Iterable[Example]], Iterable[List[Example]]],
     max_epochs: int,
-):
+) -> Iterable[Tuple[int, List[Example]]]:
     """Create distillation batches. In contrast to training, the corpus
     is normally too large to load into memory and shuffle."""
     epoch = 0
@@ -592,9 +594,9 @@
 def create_train_batches(
     nlp: "Language",
     corpus: Callable[["Language"], Iterable[Example]],
-    batcher: Callable[[Iterable[Example]], Iterable[Example]],
+    batcher: Callable[[Iterable[Example]], Iterable[List[Example]]],
     max_epochs: int,
-):
+) -> Iterable[Tuple[int, List[Example]]]:
     epoch = 0
     if max_epochs >= 0:
         examples = list(corpus(nlp))  # type: Iterable[Example]
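
Both batch creators now declare that they yield (epoch, batch) tuples. A sketch of a generator with the same return annotation (epoch_batches is hypothetical, not the spaCy function):

from typing import Iterable, List, Tuple

from spacy.training import Example

def epoch_batches(
    batches: List[List[Example]], max_epochs: int
) -> Iterable[Tuple[int, List[Example]]]:
    # Tag every batch with the epoch it belongs to, matching the new annotation.
    for epoch in range(max_epochs):
        for batch in batches:
            yield epoch, batch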