Typing fixes

Daniël de Kok 2024-01-24 10:28:46 +01:00
parent 82ef6783a8
commit c621e251b8
5 changed files with 27 additions and 13 deletions

View File

@@ -31,7 +31,7 @@ pytest-timeout>=1.3.0,<2.0.0
 mock>=2.0.0,<3.0.0
 flake8>=3.8.0,<6.0.0
 hypothesis>=3.27.0,<7.0.0
-mypy>=0.990,<1.1.0; platform_machine != "aarch64"
+mypy>=1.5.0,<1.6.0; platform_machine != "aarch64" and python_version >= "3.8"
 types-mock>=0.1.1
 types-setuptools>=57.0.0
 types-requests

View File

@@ -139,6 +139,8 @@ class Span:
     def lemma_(self) -> str: ...
     @property
     def label_(self) -> str: ...
+    @label_.setter
+    def label_(self, label: str): ...
     @property
     def kb_id_(self) -> str: ...
     @property
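The added stub lines declare the `label_` setter on `Span`, which already exists at runtime; the stub previously exposed only the getter, so assignments were flagged by type checkers. A minimal sketch of the pattern the stub now covers (the label value is illustrative):

```python
import spacy

nlp = spacy.blank("en")
doc = nlp("Apple is looking at buying a startup")

span = doc[0:1]
# Assigning through the string property is what the new @label_.setter declares.
span.label_ = "ORG"
assert span.label_ == "ORG"
```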

View File

@@ -9,6 +9,10 @@ def annotations_to_doc(
     tok_annot: Dict[str, Any],
     doc_annot: Dict[str, Any],
 ) -> Doc: ...
+def validate_distillation_examples(
+    examples: Iterable[Example],
+    method: str,
+) -> None: ...
 def validate_examples(
     examples: Iterable[Example],
     method: str,

View File

@@ -58,6 +58,12 @@ def validate_examples(examples, method):
 def validate_distillation_examples(examples, method):
+    """Check that a batch of examples received during processing is valid
+    for distillation.
+
+    examples (Iterable[Examples]): A batch of examples.
+    method (str): The method name to show in error messages.
+    """
     validate_examples(examples, method)
     for eg in examples:
         if [token.text for token in eg.reference] != [token.text for token in eg.predicted]:
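The docstring and loop above encode the requirement that, for distillation, each example's predicted tokenization must match its reference tokenization. A hand-built illustration of an example that passes this check, assuming the helper is exposed alongside `validate_examples` in `spacy.training` (that import path is an assumption, not shown in the diff):

```python
from spacy.tokens import Doc
from spacy.training import Example, validate_distillation_examples  # assumed export
from spacy.vocab import Vocab

vocab = Vocab()
# Reference and predicted docs with identical tokenization pass the check;
# mismatched tokens would trigger the distillation-specific error.
predicted = Doc(vocab, words=["New", "York", "is", "big"])
reference = Doc(vocab, words=["New", "York", "is", "big"])

validate_distillation_examples([Example(predicted, reference)], "distill")
```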

View File

@@ -12,7 +12,9 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Sized,
     Tuple,
+    TypeVar,
     Union,
 )
@@ -22,7 +24,6 @@ from wasabi import Printer
 from .. import ty
 from ..errors import Errors
 from ..schemas import ConfigSchemaDistill, ConfigSchemaTraining
-from ..tokens.doc import Doc
 from ..util import (
     logger,
     registry,
@@ -282,7 +283,7 @@ def _distill_loop(
     teacher: "Language",
     student: "Language",
     optimizer: Optimizer,
-    distill_data: Iterable[List[Example]],
+    distill_data: Iterable[Tuple[int, List[Example]]],
     evaluate: Callable[[], Tuple[float, Dict[str, float]]],
     *,
     dropout: float,
@@ -401,7 +402,7 @@ def _distill_loop(
 def train_while_improving(
     nlp: "Language",
     optimizer: Optimizer,
-    train_data: Iterable[List[Example]],
+    train_data: Iterable[Tuple[int, List[Example]]],
     evaluate: Callable[[], Tuple[float, Dict[str, float]]],
     *,
     dropout: float,
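With these annotations, both `_distill_loop` and `train_while_improving` state that their data iterables carry the epoch index alongside each batch. A toy, self-contained illustration of the unpacking shape that `Iterable[Tuple[int, List[Example]]]` describes (plain strings stand in for `Example` objects):

```python
from typing import Iterable, List, Tuple

# Toy stand-in for distill_data / train_data: (epoch, batch) pairs.
batches: Iterable[Tuple[int, List[str]]] = [
    (0, ["eg1", "eg2"]),
    (0, ["eg3"]),
    (1, ["eg4", "eg5"]),
]

for step, (epoch, batch) in enumerate(batches):
    # The training loop consumes pairs of this shape; here we just print them.
    print(step, epoch, batch)
```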
@@ -520,15 +521,16 @@ def train_while_improving(
             break


+ItemT = TypeVar("ItemT", bound=Sized)
+
+
 def subdivide_batch(
-    batch: Union[Iterable[Doc], Iterable[Example]], accumulate_gradient: int
-):
+    batch: Iterable[ItemT], accumulate_gradient: int
+) -> Iterable[List[ItemT]]:
     batch = list(batch)
     if len(batch):
-        if isinstance(batch[0], Example):
-            batch.sort(key=lambda eg: len(eg.predicted))
-        else:
-            batch.sort(key=lambda doc: len(doc))
+        # Examples are sorted by their predicted length.
+        batch.sort(key=lambda item: len(item))
     sub_len = len(batch) // accumulate_gradient
     start = 0
     for i in range(accumulate_gradient):
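`subdivide_batch` is now generic over any `Sized` item: both `Doc` and `Example` define `__len__` (an `Example`'s length is the length of its predicted doc), so the old `isinstance` branching is unnecessary. A rough usage sketch under the updated signature, using plain strings as the sized items:

```python
from spacy.training.loop import subdivide_batch  # function changed in this diff

# Strings satisfy the Sized bound; real callers pass Doc or Example batches.
texts = ["a", "bbb", "cc", "dddd", "e"]
for sub in subdivide_batch(texts, 2):
    # Items are sorted by length, then split into roughly
    # len(batch) // accumulate_gradient chunks for gradient accumulation.
    print(sub)
```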
@@ -578,7 +580,7 @@ def create_distill_batches(
     corpus: Callable[["Language"], Iterable[Example]],
     batcher: Callable[[Iterable[Example]], Iterable[List[Example]]],
     max_epochs: int,
-):
+) -> Iterable[Tuple[int, List[Example]]]:
     """Create distillation batches. In contrast to training, the corpus
     is normally too large to load into memory and shuffle."""
     epoch = 0
@@ -592,9 +594,9 @@
 def create_train_batches(
     nlp: "Language",
     corpus: Callable[["Language"], Iterable[Example]],
-    batcher: Callable[[Iterable[Example]], Iterable[Example]],
+    batcher: Callable[[Iterable[Example]], Iterable[List[Example]]],
     max_epochs: int,
-):
+) -> Iterable[Tuple[int, List[Example]]]:
     epoch = 0
     if max_epochs >= 0:
         examples = list(corpus(nlp))  # type: Iterable[Example]
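The new return annotations on `create_distill_batches` and `create_train_batches` make the epoch pairing explicit on the producer side as well. A hand-written generator with the same return shape, not the commit's implementation (batch size and epoch handling are simplified):

```python
from typing import Iterable, List, Tuple

from spacy.training import Example


def toy_batches(
    examples: List[Example], max_epochs: int, batch_size: int = 8
) -> Iterable[Tuple[int, List[Example]]]:
    """Yield (epoch, batch) pairs, mirroring the annotated return type."""
    epoch = 0
    while max_epochs < 1 or epoch < max_epochs:
        for start in range(0, len(examples), batch_size):
            yield epoch, examples[start : start + batch_size]
        epoch += 1
```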