Mirror of https://github.com/explosion/spaCy.git
Typing fixes

parent 82ef6783a8
commit c621e251b8
@@ -31,7 +31,7 @@ pytest-timeout>=1.3.0,<2.0.0
 mock>=2.0.0,<3.0.0
 flake8>=3.8.0,<6.0.0
 hypothesis>=3.27.0,<7.0.0
-mypy>=0.990,<1.1.0; platform_machine != "aarch64"
+mypy>=1.5.0,<1.6.0; platform_machine != "aarch64" and python_version >= "3.8"
 types-mock>=0.1.1
 types-setuptools>=57.0.0
 types-requests
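This hunk (likely requirements.txt) bumps the mypy pin and restricts it with PEP 508 environment markers. As a side note, not part of the commit, the marker can be evaluated with the third-party packaging library, assuming it is installed:

# Sketch: evaluating the PEP 508 marker used in the mypy pin above.
# Assumes the "packaging" library is available; not part of this commit.
from packaging.markers import Marker

marker = Marker('platform_machine != "aarch64" and python_version >= "3.8"')
print(marker.evaluate())  # True/False for the current interpreter and platform
print(marker.evaluate({"platform_machine": "aarch64", "python_version": "3.11"}))  # False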
@@ -139,6 +139,8 @@ class Span:
     def lemma_(self) -> str: ...
     @property
     def label_(self) -> str: ...
+    @label_.setter
+    def label_(self, label: str): ...
     @property
     def kb_id_(self) -> str: ...
     @property
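This hunk (presumably spacy/tokens/span.pyi) declares a setter for Span.label_ in the stub, so type checkers stop treating the attribute as read-only. A usage sketch, not from the diff, assuming a spaCy version whose runtime already allows the assignment (which this stub change implies):

# With the @label_.setter declared in the stub, this assignment type-checks;
# runtime behaviour is unchanged.
import spacy

nlp = spacy.blank("en")
doc = nlp("Berlin is a city")
span = doc[0:1]
span.label_ = "GPE"  # now consistent with the .pyi declaration
print(span.label_)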
@@ -9,6 +9,10 @@ def annotations_to_doc(
     tok_annot: Dict[str, Any],
     doc_annot: Dict[str, Any],
 ) -> Doc: ...
+def validate_distillation_examples(
+    examples: Iterable[Example],
+    method: str,
+) -> None: ...
 def validate_examples(
     examples: Iterable[Example],
     method: str,
@@ -58,6 +58,12 @@ def validate_examples(examples, method):
 
 
 def validate_distillation_examples(examples, method):
+    """Check that a batch of examples received during processing is valid
+    for distillation.
+
+    examples (Iterable[Examples]): A batch of examples.
+    method (str): The method name to show in error messages.
+    """
     validate_examples(examples, method)
     for eg in examples:
         if [token.text for token in eg.reference] != [token.text for token in eg.predicted]:
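These two hunks (presumably spacy/training/example.pyi and example.pyx) add a stub signature and a docstring for validate_distillation_examples, which delegates to validate_examples and additionally requires that reference and predicted share the same token texts. A usage sketch under assumptions: a spaCy build that includes this distillation API, and that the function is importable from spacy.training (the import path is an inference, not shown in the diff):

import spacy
from spacy.training import Example, validate_distillation_examples

nlp = spacy.blank("en")
predicted = nlp.make_doc("I like cats")
reference = nlp.make_doc("I like cats")
eg = Example(predicted, reference)

# Passes: both sides tokenize to identical token texts.
validate_distillation_examples([eg], "Language.distill")

# An example whose reference tokenization differs from the predicted one
# would fail the loop-based check shown above.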
@@ -12,7 +12,9 @@ from typing import (
     Iterable,
     List,
     Optional,
+    Sized,
     Tuple,
+    TypeVar,
     Union,
 )
 
@@ -22,7 +24,6 @@ from wasabi import Printer
 from .. import ty
 from ..errors import Errors
 from ..schemas import ConfigSchemaDistill, ConfigSchemaTraining
-from ..tokens.doc import Doc
 from ..util import (
     logger,
     registry,
@@ -282,7 +283,7 @@ def _distill_loop(
     teacher: "Language",
     student: "Language",
     optimizer: Optimizer,
-    distill_data: Iterable[List[Example]],
+    distill_data: Iterable[Tuple[int, List[Example]]],
     evaluate: Callable[[], Tuple[float, Dict[str, float]]],
     *,
     dropout: float,
@@ -401,7 +402,7 @@ def _distill_loop(
 def train_while_improving(
     nlp: "Language",
     optimizer: Optimizer,
-    train_data: Iterable[List[Example]],
+    train_data: Iterable[Tuple[int, List[Example]]],
     evaluate: Callable[[], Tuple[float, Dict[str, float]]],
     *,
     dropout: float,
@@ -520,15 +521,16 @@ def train_while_improving(
             break
 
 
+ItemT = TypeVar("ItemT", bound=Sized)
+
+
 def subdivide_batch(
-    batch: Union[Iterable[Doc], Iterable[Example]], accumulate_gradient: int
-):
+    batch: Iterable[ItemT], accumulate_gradient: int
+) -> Iterable[List[ItemT]]:
     batch = list(batch)
     if len(batch):
-        if isinstance(batch[0], Example):
-            batch.sort(key=lambda eg: len(eg.predicted))
-        else:
-            batch.sort(key=lambda doc: len(doc))
+        # Examples are sorted by their predicted length.
+        batch.sort(key=lambda item: len(item))
         sub_len = len(batch) // accumulate_gradient
         start = 0
         for i in range(accumulate_gradient):
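The hunk above (presumably spacy/training/loop.py) replaces the Union plus isinstance branching with a TypeVar bound to Sized: anything that supports len(), such as Doc or Example, can be batched with a single sort key. A standalone, simplified sketch of the pattern, not the actual spaCy implementation (remainder handling is omitted):

from typing import Iterable, List, Sized, TypeVar

ItemT = TypeVar("ItemT", bound=Sized)

def subdivide(batch: Iterable[ItemT], accumulate_gradient: int) -> Iterable[List[ItemT]]:
    items = sorted(batch, key=len)                 # shortest items first
    sub_len = max(1, len(items) // accumulate_gradient)
    for start in range(0, len(items), sub_len):
        yield items[start:start + sub_len]

# Works for plain strings too, since str is Sized:
print(list(subdivide(["bbb", "a", "cc", "dddd"], 2)))  # [['a', 'cc'], ['bbb', 'dddd']]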
@@ -578,7 +580,7 @@ def create_distill_batches(
     corpus: Callable[["Language"], Iterable[Example]],
     batcher: Callable[[Iterable[Example]], Iterable[List[Example]]],
     max_epochs: int,
-):
+) -> Iterable[Tuple[int, List[Example]]]:
     """Create distillation batches. In contrast to training, the corpus
     is normally too large to load into memory and shuffle."""
     epoch = 0
@@ -592,9 +594,9 @@ def create_distill_batches(
 def create_train_batches(
     nlp: "Language",
     corpus: Callable[["Language"], Iterable[Example]],
-    batcher: Callable[[Iterable[Example]], Iterable[Example]],
+    batcher: Callable[[Iterable[Example]], Iterable[List[Example]]],
     max_epochs: int,
-):
+) -> Iterable[Tuple[int, List[Example]]]:
     epoch = 0
     if max_epochs >= 0:
         examples = list(corpus(nlp))  # type: Iterable[Example]
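The return annotations added to create_distill_batches and create_train_batches describe a stream of (epoch, batch) pairs, which is what the Iterable[Tuple[int, List[Example]]] parameters of _distill_loop and train_while_improving above now expect. A simplified sketch of that shape, not the spaCy implementation (shuffling and the stream-forever handling for non-positive max_epochs are omitted):

from typing import Iterable, List, Tuple

def epoch_batches(items: List[str], batch_size: int, max_epochs: int) -> Iterable[Tuple[int, List[str]]]:
    # Yield (epoch, batch) pairs over the data for a fixed number of epochs.
    for epoch in range(max_epochs):
        for start in range(0, len(items), batch_size):
            yield epoch, items[start:start + batch_size]

for epoch, batch in epoch_batches(["a", "b", "c", "d"], batch_size=2, max_epochs=2):
    print(epoch, batch)  # 0 ['a', 'b'] / 0 ['c', 'd'] / 1 ['a', 'b'] / 1 ['c', 'd']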