Typing fixes
parent 82ef6783a8
commit c621e251b8

@@ -31,7 +31,7 @@ pytest-timeout>=1.3.0,<2.0.0
 mock>=2.0.0,<3.0.0
 flake8>=3.8.0,<6.0.0
 hypothesis>=3.27.0,<7.0.0
-mypy>=0.990,<1.1.0; platform_machine != "aarch64"
+mypy>=1.5.0,<1.6.0; platform_machine != "aarch64" and python_version >= "3.8"
 types-mock>=0.1.1
 types-setuptools>=57.0.0
 types-requests

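The mypy pin moves from the 0.9x line to 1.5.x and its environment marker gains a Python-version floor, since newer mypy releases no longer support Python 3.7. A minimal sketch of how pip evaluates such markers, assuming the third-party packaging library is installed; the marker strings are copied from the requirement lines above:

# Sketch: evaluating the requirement markers with packaging.markers.
# evaluate() checks a marker against the current interpreter and platform.
from packaging.markers import Marker

old_marker = Marker('platform_machine != "aarch64"')
new_marker = Marker('platform_machine != "aarch64" and python_version >= "3.8"')

# On x86_64 with Python 3.8+, both evaluate to True; the new pin simply
# adds a Python-version floor to match the newer mypy's supported versions.
print(old_marker.evaluate(), new_marker.evaluate())
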
@@ -139,6 +139,8 @@ class Span:
     def lemma_(self) -> str: ...
     @property
     def label_(self) -> str: ...
+    @label_.setter
+    def label_(self, label: str): ...
     @property
     def kb_id_(self) -> str: ...
     @property

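The added setter stub tells type checkers that Span.label_ can be assigned to, not just read. A minimal, self-contained sketch of the property/setter pattern the stub describes; SpanLike and its dictionary string store are toy stand-ins, not spaCy internals:

# Sketch of a readable and writable label_ property, as the new stub declares.
class SpanLike:
    def __init__(self, label: int = 0) -> None:
        self._label = label           # spaCy stores the label as an integer hash
        self._strings = {0: ""}       # toy stand-in for the shared StringStore

    @property
    def label_(self) -> str:
        return self._strings[self._label]

    @label_.setter
    def label_(self, label: str) -> None:
        key = hash(label)
        self._strings[key] = label
        self._label = key


span = SpanLike()
span.label_ = "ORG"   # the added setter stub lets type checkers accept this assignment
print(span.label_)    # ORG
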
@@ -9,6 +9,10 @@ def annotations_to_doc(
     tok_annot: Dict[str, Any],
     doc_annot: Dict[str, Any],
 ) -> Doc: ...
+def validate_distillation_examples(
+    examples: Iterable[Example],
+    method: str,
+) -> None: ...
 def validate_examples(
     examples: Iterable[Example],
     method: str,

@@ -58,6 +58,12 @@ def validate_examples(examples, method):
 
 
 def validate_distillation_examples(examples, method):
+    """Check that a batch of examples received during processing is valid
+    for distillation.
+
+    examples (Iterable[Examples]): A batch of examples.
+    method (str): The method name to show in error messages.
+    """
     validate_examples(examples, method)
     for eg in examples:
         if [token.text for token in eg.reference] != [token.text for token in eg.predicted]:

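The check above requires that the reference (teacher) and predicted (student) docs of every example share exactly the same tokens. A hedged sketch of that constraint using plain token lists instead of spaCy Example objects; the function name and error message are illustrative, not spaCy's:

# Sketch of the token-alignment requirement for distillation examples.
from typing import List


def check_distillation_alignment(reference: List[str], predicted: List[str], method: str) -> None:
    # Distillation compares teacher and student predictions token by token,
    # so both sides must have exactly the same tokenization.
    if reference != predicted:
        raise ValueError(f"{method}: reference and predicted tokens do not match")


check_distillation_alignment(["Apple", "is", "nice"], ["Apple", "is", "nice"], "distill")  # passes
# check_distillation_alignment(["Apple", "is"], ["Apple", "was"], "distill")  # would raise ValueError
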
@@ -12,7 +12,9 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Sized,
     Tuple,
+    TypeVar,
     Union,
 )
 

@@ -22,7 +24,6 @@ from wasabi import Printer
 from .. import ty
 from ..errors import Errors
 from ..schemas import ConfigSchemaDistill, ConfigSchemaTraining
-from ..tokens.doc import Doc
 from ..util import (
     logger,
     registry,

@@ -282,7 +283,7 @@ def _distill_loop(
     teacher: "Language",
     student: "Language",
     optimizer: Optimizer,
-    distill_data: Iterable[List[Example]],
+    distill_data: Iterable[Tuple[int, List[Example]]],
     evaluate: Callable[[], Tuple[float, Dict[str, float]]],
     *,
     dropout: float,

@@ -401,7 +402,7 @@ def _distill_loop(
 def train_while_improving(
     nlp: "Language",
     optimizer: Optimizer,
-    train_data: Iterable[List[Example]],
+    train_data: Iterable[Tuple[int, List[Example]]],
     evaluate: Callable[[], Tuple[float, Dict[str, float]]],
     *,
     dropout: float,

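Both loops now declare that each element of distill_data/train_data is an (epoch, batch) pair rather than a bare batch. A hedged sketch of what that type means for the consuming side; ToyExample and consume are illustrative stand-ins, not spaCy's loop body:

# Sketch: unpacking (epoch, batch) pairs as the tightened type describes.
from typing import Iterable, List, Tuple


class ToyExample:
    def __init__(self, n_tokens: int) -> None:
        self.n_tokens = n_tokens


def consume(train_data: Iterable[Tuple[int, List[ToyExample]]]) -> None:
    for epoch, batch in train_data:
        print(f"epoch {epoch}: batch of {len(batch)} examples")


consume([(0, [ToyExample(3), ToyExample(5)]), (1, [ToyExample(2)])])
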
@@ -520,15 +521,16 @@ def train_while_improving(
             break
 
 
+ItemT = TypeVar("ItemT", bound=Sized)
+
+
 def subdivide_batch(
-    batch: Union[Iterable[Doc], Iterable[Example]], accumulate_gradient: int
-):
+    batch: Iterable[ItemT], accumulate_gradient: int
+) -> Iterable[List[ItemT]]:
     batch = list(batch)
     if len(batch):
-        if isinstance(batch[0], Example):
-            batch.sort(key=lambda eg: len(eg.predicted))
-        else:
-            batch.sort(key=lambda doc: len(doc))
+        # Examples are sorted by their predicted length.
+        batch.sort(key=lambda item: len(item))
     sub_len = len(batch) // accumulate_gradient
     start = 0
     for i in range(accumulate_gradient):

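The rewritten subdivide_batch is generic over any Sized item (Doc-like or Example-like), so a single sort by length replaces the old isinstance branching. A standalone approximation of that pattern, assuming the tail of the batch is emitted as a final sub-batch; the remainder handling is not shown in the hunk above:

# Sketch: split a length-sorted batch into accumulate_gradient sub-batches.
from typing import Iterable, List, Sized, TypeVar

ItemT = TypeVar("ItemT", bound=Sized)


def subdivide_batch_sketch(
    batch: Iterable[ItemT], accumulate_gradient: int
) -> Iterable[List[ItemT]]:
    items = list(batch)
    if items:
        # Sorting by length keeps similarly sized items together,
        # which reduces padding inside each sub-batch.
        items.sort(key=len)
    sub_len = len(items) // accumulate_gradient
    start = 0
    for _ in range(accumulate_gradient):
        subbatch = items[start : start + sub_len]
        if subbatch:
            yield subbatch
        start += sub_len
    # Whatever is left over becomes one final sub-batch (an assumption here).
    if items[start:]:
        yield items[start:]


for sub in subdivide_batch_sketch(["aa", "b", "cccc", "ddd"], 2):
    print(sub)  # ['b', 'aa'] then ['ddd', 'cccc']
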
@@ -578,7 +580,7 @@ def create_distill_batches(
     corpus: Callable[["Language"], Iterable[Example]],
     batcher: Callable[[Iterable[Example]], Iterable[List[Example]]],
     max_epochs: int,
-):
+) -> Iterable[Tuple[int, List[Example]]]:
     """Create distillation batches. In contrast to training, the corpus
     is normally too large to load into memory and shuffle."""
     epoch = 0

@@ -592,9 +594,9 @@ def create_distill_batches(
 def create_train_batches(
     nlp: "Language",
     corpus: Callable[["Language"], Iterable[Example]],
-    batcher: Callable[[Iterable[Example]], Iterable[Example]],
+    batcher: Callable[[Iterable[Example]], Iterable[List[Example]]],
     max_epochs: int,
-):
+) -> Iterable[Tuple[int, List[Example]]]:
     epoch = 0
     if max_epochs >= 0:
         examples = list(corpus(nlp))  # type: Iterable[Example]

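The batch creators now advertise the same (epoch, batch) element type that the loops consume. A hedged sketch of the generator shape implied by the new return annotation Iterable[Tuple[int, List[Example]]]; it re-batches plain lists per epoch and is a toy approximation, not the create_train_batches implementation:

# Sketch: yield every batch together with its epoch number.
from typing import Iterable, List, Tuple, TypeVar

T = TypeVar("T")


def batches_with_epoch(
    items: List[T], batch_size: int, max_epochs: int
) -> Iterable[Tuple[int, List[T]]]:
    for epoch in range(max_epochs):
        for start in range(0, len(items), batch_size):
            yield epoch, items[start : start + batch_size]


for epoch, batch in batches_with_epoch(["a", "b", "c", "d", "e"], 2, 2):
    print(epoch, batch)
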