mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-30 23:47:31 +03:00 
			
		
		
		
	Typing fixes
This commit is contained in:
		
							parent
							
								
									82ef6783a8
								
							
						
					
					
						commit
						c621e251b8
					
				|  | @ -31,7 +31,7 @@ pytest-timeout>=1.3.0,<2.0.0 | |||
| mock>=2.0.0,<3.0.0 | ||||
| flake8>=3.8.0,<6.0.0 | ||||
| hypothesis>=3.27.0,<7.0.0 | ||||
| mypy>=0.990,<1.1.0; platform_machine != "aarch64" | ||||
| mypy>=1.5.0,<1.6.0; platform_machine != "aarch64" and python_version >= "3.8" | ||||
| types-mock>=0.1.1 | ||||
| types-setuptools>=57.0.0 | ||||
| types-requests | ||||
|  |  | |||
|  | @ -139,6 +139,8 @@ class Span: | |||
|     def lemma_(self) -> str: ... | ||||
|     @property | ||||
|     def label_(self) -> str: ... | ||||
|     @label_.setter | ||||
|     def label_(self, label: str): ... | ||||
|     @property | ||||
|     def kb_id_(self) -> str: ... | ||||
|     @property | ||||
|  |  | |||
|  | @ -9,6 +9,10 @@ def annotations_to_doc( | |||
|     tok_annot: Dict[str, Any], | ||||
|     doc_annot: Dict[str, Any], | ||||
| ) -> Doc: ... | ||||
| def validate_distillation_examples( | ||||
|     examples: Iterable[Example], | ||||
|     method: str, | ||||
| ) -> None: ... | ||||
| def validate_examples( | ||||
|     examples: Iterable[Example], | ||||
|     method: str, | ||||
|  |  | |||
|  | @ -58,6 +58,12 @@ def validate_examples(examples, method): | |||
| 
 | ||||
| 
 | ||||
| def validate_distillation_examples(examples, method): | ||||
|     """Check that a batch of examples received during processing is valid | ||||
|     for distillation. | ||||
| 
 | ||||
|     examples (Iterable[Example]): A batch of examples. | ||||
|     method (str): The method name to show in error messages. | ||||
|     """ | ||||
|     validate_examples(examples, method) | ||||
|     for eg in examples: | ||||
|         if [token.text for token in eg.reference] != [token.text for token in eg.predicted]: | ||||
|  |  | |||
|  | @ -12,7 +12,9 @@ from typing import ( | |||
|     Iterable, | ||||
|     List, | ||||
|     Optional, | ||||
|     Sized, | ||||
|     Tuple, | ||||
|     TypeVar, | ||||
|     Union, | ||||
| ) | ||||
| 
 | ||||
|  | @ -22,7 +24,6 @@ from wasabi import Printer | |||
| from .. import ty | ||||
| from ..errors import Errors | ||||
| from ..schemas import ConfigSchemaDistill, ConfigSchemaTraining | ||||
| from ..tokens.doc import Doc | ||||
| from ..util import ( | ||||
|     logger, | ||||
|     registry, | ||||
|  | @ -282,7 +283,7 @@ def _distill_loop( | |||
|     teacher: "Language", | ||||
|     student: "Language", | ||||
|     optimizer: Optimizer, | ||||
|     distill_data: Iterable[List[Example]], | ||||
|     distill_data: Iterable[Tuple[int, List[Example]]], | ||||
|     evaluate: Callable[[], Tuple[float, Dict[str, float]]], | ||||
|     *, | ||||
|     dropout: float, | ||||
|  | @ -401,7 +402,7 @@ def _distill_loop( | |||
| def train_while_improving( | ||||
|     nlp: "Language", | ||||
|     optimizer: Optimizer, | ||||
|     train_data: Iterable[List[Example]], | ||||
|     train_data: Iterable[Tuple[int, List[Example]]], | ||||
|     evaluate: Callable[[], Tuple[float, Dict[str, float]]], | ||||
|     *, | ||||
|     dropout: float, | ||||
|  | @ -520,15 +521,16 @@ def train_while_improving( | |||
|             break | ||||
| 
 | ||||
| 
 | ||||
| ItemT = TypeVar("ItemT", bound=Sized) | ||||
| 
 | ||||
| 
 | ||||
| def subdivide_batch( | ||||
|     batch: Union[Iterable[Doc], Iterable[Example]], accumulate_gradient: int | ||||
| ): | ||||
|     batch: Iterable[ItemT], accumulate_gradient: int | ||||
| ) -> Iterable[List[ItemT]]: | ||||
|     batch = list(batch) | ||||
|     if len(batch): | ||||
|         if isinstance(batch[0], Example): | ||||
|             batch.sort(key=lambda eg: len(eg.predicted)) | ||||
|         else: | ||||
|             batch.sort(key=lambda doc: len(doc)) | ||||
|         # Examples are sorted by their predicted length. | ||||
|         batch.sort(key=lambda item: len(item)) | ||||
|     sub_len = len(batch) // accumulate_gradient | ||||
|     start = 0 | ||||
|     for i in range(accumulate_gradient): | ||||
|  | @ -578,7 +580,7 @@ def create_distill_batches( | |||
|     corpus: Callable[["Language"], Iterable[Example]], | ||||
|     batcher: Callable[[Iterable[Example]], Iterable[List[Example]]], | ||||
|     max_epochs: int, | ||||
| ): | ||||
| ) -> Iterable[Tuple[int, List[Example]]]: | ||||
|     """Create distillation batches. In contrast to training, the corpus | ||||
|     is normally too large to load into memory and shuffle.""" | ||||
|     epoch = 0 | ||||
|  | @ -592,9 +594,9 @@ def create_distill_batches( | |||
| def create_train_batches( | ||||
|     nlp: "Language", | ||||
|     corpus: Callable[["Language"], Iterable[Example]], | ||||
|     batcher: Callable[[Iterable[Example]], Iterable[Example]], | ||||
|     batcher: Callable[[Iterable[Example]], Iterable[List[Example]]], | ||||
|     max_epochs: int, | ||||
| ): | ||||
| ) -> Iterable[Tuple[int, List[Example]]]: | ||||
|     epoch = 0 | ||||
|     if max_epochs >= 0: | ||||
|         examples = list(corpus(nlp))  # type: Iterable[Example] | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user