diff --git a/requirements.txt b/requirements.txt
index 1a2459498..1a3cdd8f2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -31,7 +31,7 @@ pytest-timeout>=1.3.0,<2.0.0
 mock>=2.0.0,<3.0.0
 flake8>=3.8.0,<6.0.0
 hypothesis>=3.27.0,<7.0.0
-mypy>=0.990,<1.1.0; platform_machine != "aarch64"
+mypy>=1.5.0,<1.6.0; platform_machine != "aarch64" and python_version >= "3.8"
 types-mock>=0.1.1
 types-setuptools>=57.0.0
 types-requests
diff --git a/spacy/tokens/span.pyi b/spacy/tokens/span.pyi
index 2a529593e..f1030278c 100644
--- a/spacy/tokens/span.pyi
+++ b/spacy/tokens/span.pyi
@@ -139,6 +139,8 @@ class Span:
     def lemma_(self) -> str: ...
     @property
     def label_(self) -> str: ...
+    @label_.setter
+    def label_(self, label: str): ...
     @property
     def kb_id_(self) -> str: ...
     @property
diff --git a/spacy/training/example.pyi b/spacy/training/example.pyi
index 06639d70c..33cf07b09 100644
--- a/spacy/training/example.pyi
+++ b/spacy/training/example.pyi
@@ -9,6 +9,10 @@ def annotations_to_doc(
     tok_annot: Dict[str, Any],
     doc_annot: Dict[str, Any],
 ) -> Doc: ...
+def validate_distillation_examples(
+    examples: Iterable[Example],
+    method: str,
+) -> None: ...
 def validate_examples(
     examples: Iterable[Example],
     method: str,
diff --git a/spacy/training/example.pyx b/spacy/training/example.pyx
index 914e877f5..9cc9e04ab 100644
--- a/spacy/training/example.pyx
+++ b/spacy/training/example.pyx
@@ -58,6 +58,12 @@ def validate_examples(examples, method):
 
 
 def validate_distillation_examples(examples, method):
+    """Check that a batch of examples received during processing is valid
+    for distillation.
+
+    examples (Iterable[Example]): A batch of examples.
+    method (str): The method name to show in error messages.
+    """
     validate_examples(examples, method)
     for eg in examples:
         if [token.text for token in eg.reference] != [token.text for token in eg.predicted]:
diff --git a/spacy/training/loop.py b/spacy/training/loop.py
index 63715ec2c..575a583b7 100644
--- a/spacy/training/loop.py
+++ b/spacy/training/loop.py
@@ -12,7 +12,9 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Sized,
     Tuple,
+    TypeVar,
     Union,
 )
 
@@ -22,7 +24,6 @@ from wasabi import Printer
 from .. import ty
 from ..errors import Errors
 from ..schemas import ConfigSchemaDistill, ConfigSchemaTraining
-from ..tokens.doc import Doc
 from ..util import (
     logger,
     registry,
@@ -282,7 +283,7 @@ def _distill_loop(
     teacher: "Language",
     student: "Language",
     optimizer: Optimizer,
-    distill_data: Iterable[List[Example]],
+    distill_data: Iterable[Tuple[int, List[Example]]],
     evaluate: Callable[[], Tuple[float, Dict[str, float]]],
     *,
     dropout: float,
@@ -401,7 +402,7 @@ def _distill_loop(
 def train_while_improving(
     nlp: "Language",
     optimizer: Optimizer,
-    train_data: Iterable[List[Example]],
+    train_data: Iterable[Tuple[int, List[Example]]],
     evaluate: Callable[[], Tuple[float, Dict[str, float]]],
     *,
     dropout: float,
@@ -520,15 +521,16 @@ def train_while_improving(
             break
 
 
+ItemT = TypeVar("ItemT", bound=Sized)
+
+
 def subdivide_batch(
-    batch: Union[Iterable[Doc], Iterable[Example]], accumulate_gradient: int
-):
+    batch: Iterable[ItemT], accumulate_gradient: int
+) -> Iterable[List[ItemT]]:
     batch = list(batch)
     if len(batch):
-        if isinstance(batch[0], Example):
-            batch.sort(key=lambda eg: len(eg.predicted))
-        else:
-            batch.sort(key=lambda doc: len(doc))
+        # Examples are sorted by their predicted length.
+        batch.sort(key=lambda item: len(item))
     sub_len = len(batch) // accumulate_gradient
     start = 0
     for i in range(accumulate_gradient):
@@ -578,7 +580,7 @@ def create_distill_batches(
     corpus: Callable[["Language"], Iterable[Example]],
     batcher: Callable[[Iterable[Example]], Iterable[List[Example]]],
     max_epochs: int,
-):
+) -> Iterable[Tuple[int, List[Example]]]:
     """Create distillation batches. In contrast to training, the corpus
     is normally too large to load into memory and shuffle."""
     epoch = 0
@@ -592,9 +594,9 @@
 def create_train_batches(
     nlp: "Language",
     corpus: Callable[["Language"], Iterable[Example]],
-    batcher: Callable[[Iterable[Example]], Iterable[Example]],
+    batcher: Callable[[Iterable[Example]], Iterable[List[Example]]],
     max_epochs: int,
-):
+) -> Iterable[Tuple[int, List[Example]]]:
     epoch = 0
     if max_epochs >= 0:
         examples = list(corpus(nlp))  # type: Iterable[Example]