mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-30 23:47:31 +03:00 
			
		
		
		
	Typing fixes
This commit is contained in:
		
							parent
							
								
									82ef6783a8
								
							
						
					
					
						commit
						c621e251b8
					
				|  | @ -31,7 +31,7 @@ pytest-timeout>=1.3.0,<2.0.0 | ||||||
| mock>=2.0.0,<3.0.0 | mock>=2.0.0,<3.0.0 | ||||||
| flake8>=3.8.0,<6.0.0 | flake8>=3.8.0,<6.0.0 | ||||||
| hypothesis>=3.27.0,<7.0.0 | hypothesis>=3.27.0,<7.0.0 | ||||||
| mypy>=0.990,<1.1.0; platform_machine != "aarch64" | mypy>=1.5.0,<1.6.0; platform_machine != "aarch64" and python_version >= "3.8" | ||||||
| types-mock>=0.1.1 | types-mock>=0.1.1 | ||||||
| types-setuptools>=57.0.0 | types-setuptools>=57.0.0 | ||||||
| types-requests | types-requests | ||||||
|  |  | ||||||
|  | @ -139,6 +139,8 @@ class Span: | ||||||
|     def lemma_(self) -> str: ... |     def lemma_(self) -> str: ... | ||||||
|     @property |     @property | ||||||
|     def label_(self) -> str: ... |     def label_(self) -> str: ... | ||||||
|  |     @label_.setter | ||||||
|  |     def label_(self, label: str): ... | ||||||
|     @property |     @property | ||||||
|     def kb_id_(self) -> str: ... |     def kb_id_(self) -> str: ... | ||||||
|     @property |     @property | ||||||
|  |  | ||||||
|  | @ -9,6 +9,10 @@ def annotations_to_doc( | ||||||
|     tok_annot: Dict[str, Any], |     tok_annot: Dict[str, Any], | ||||||
|     doc_annot: Dict[str, Any], |     doc_annot: Dict[str, Any], | ||||||
| ) -> Doc: ... | ) -> Doc: ... | ||||||
|  | def validate_distillation_examples( | ||||||
|  |     examples: Iterable[Example], | ||||||
|  |     method: str, | ||||||
|  | ) -> None: ... | ||||||
| def validate_examples( | def validate_examples( | ||||||
|     examples: Iterable[Example], |     examples: Iterable[Example], | ||||||
|     method: str, |     method: str, | ||||||
|  |  | ||||||
|  | @ -58,6 +58,12 @@ def validate_examples(examples, method): | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| def validate_distillation_examples(examples, method): | def validate_distillation_examples(examples, method): | ||||||
|  |     """Check that a batch of examples received during processing is valid | ||||||
|  |     for distillation. | ||||||
|  | 
 | ||||||
|  |     examples (Iterable[Examples]): A batch of examples. | ||||||
|  |     method (str): The method name to show in error messages. | ||||||
|  |     """ | ||||||
|     validate_examples(examples, method) |     validate_examples(examples, method) | ||||||
|     for eg in examples: |     for eg in examples: | ||||||
|         if [token.text for token in eg.reference] != [token.text for token in eg.predicted]: |         if [token.text for token in eg.reference] != [token.text for token in eg.predicted]: | ||||||
|  |  | ||||||
|  | @ -12,7 +12,9 @@ from typing import ( | ||||||
|     Iterable, |     Iterable, | ||||||
|     List, |     List, | ||||||
|     Optional, |     Optional, | ||||||
|  |     Sized, | ||||||
|     Tuple, |     Tuple, | ||||||
|  |     TypeVar, | ||||||
|     Union, |     Union, | ||||||
| ) | ) | ||||||
| 
 | 
 | ||||||
|  | @ -22,7 +24,6 @@ from wasabi import Printer | ||||||
| from .. import ty | from .. import ty | ||||||
| from ..errors import Errors | from ..errors import Errors | ||||||
| from ..schemas import ConfigSchemaDistill, ConfigSchemaTraining | from ..schemas import ConfigSchemaDistill, ConfigSchemaTraining | ||||||
| from ..tokens.doc import Doc |  | ||||||
| from ..util import ( | from ..util import ( | ||||||
|     logger, |     logger, | ||||||
|     registry, |     registry, | ||||||
|  | @ -282,7 +283,7 @@ def _distill_loop( | ||||||
|     teacher: "Language", |     teacher: "Language", | ||||||
|     student: "Language", |     student: "Language", | ||||||
|     optimizer: Optimizer, |     optimizer: Optimizer, | ||||||
|     distill_data: Iterable[List[Example]], |     distill_data: Iterable[Tuple[int, List[Example]]], | ||||||
|     evaluate: Callable[[], Tuple[float, Dict[str, float]]], |     evaluate: Callable[[], Tuple[float, Dict[str, float]]], | ||||||
|     *, |     *, | ||||||
|     dropout: float, |     dropout: float, | ||||||
|  | @ -401,7 +402,7 @@ def _distill_loop( | ||||||
| def train_while_improving( | def train_while_improving( | ||||||
|     nlp: "Language", |     nlp: "Language", | ||||||
|     optimizer: Optimizer, |     optimizer: Optimizer, | ||||||
|     train_data: Iterable[List[Example]], |     train_data: Iterable[Tuple[int, List[Example]]], | ||||||
|     evaluate: Callable[[], Tuple[float, Dict[str, float]]], |     evaluate: Callable[[], Tuple[float, Dict[str, float]]], | ||||||
|     *, |     *, | ||||||
|     dropout: float, |     dropout: float, | ||||||
|  | @ -520,15 +521,16 @@ def train_while_improving( | ||||||
|             break |             break | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | ItemT = TypeVar("ItemT", bound=Sized) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
| def subdivide_batch( | def subdivide_batch( | ||||||
|     batch: Union[Iterable[Doc], Iterable[Example]], accumulate_gradient: int |     batch: Iterable[ItemT], accumulate_gradient: int | ||||||
| ): | ) -> Iterable[List[ItemT]]: | ||||||
|     batch = list(batch) |     batch = list(batch) | ||||||
|     if len(batch): |     if len(batch): | ||||||
|         if isinstance(batch[0], Example): |         # Examples are sorted by their predicted length. | ||||||
|             batch.sort(key=lambda eg: len(eg.predicted)) |         batch.sort(key=lambda item: len(item)) | ||||||
|         else: |  | ||||||
|             batch.sort(key=lambda doc: len(doc)) |  | ||||||
|     sub_len = len(batch) // accumulate_gradient |     sub_len = len(batch) // accumulate_gradient | ||||||
|     start = 0 |     start = 0 | ||||||
|     for i in range(accumulate_gradient): |     for i in range(accumulate_gradient): | ||||||
|  | @ -578,7 +580,7 @@ def create_distill_batches( | ||||||
|     corpus: Callable[["Language"], Iterable[Example]], |     corpus: Callable[["Language"], Iterable[Example]], | ||||||
|     batcher: Callable[[Iterable[Example]], Iterable[List[Example]]], |     batcher: Callable[[Iterable[Example]], Iterable[List[Example]]], | ||||||
|     max_epochs: int, |     max_epochs: int, | ||||||
| ): | ) -> Iterable[Tuple[int, List[Example]]]: | ||||||
|     """Create distillation batches. In contrast to training, the corpus |     """Create distillation batches. In contrast to training, the corpus | ||||||
|     is normally too large to load into memory and shuffle.""" |     is normally too large to load into memory and shuffle.""" | ||||||
|     epoch = 0 |     epoch = 0 | ||||||
|  | @ -592,9 +594,9 @@ def create_distill_batches( | ||||||
| def create_train_batches( | def create_train_batches( | ||||||
|     nlp: "Language", |     nlp: "Language", | ||||||
|     corpus: Callable[["Language"], Iterable[Example]], |     corpus: Callable[["Language"], Iterable[Example]], | ||||||
|     batcher: Callable[[Iterable[Example]], Iterable[Example]], |     batcher: Callable[[Iterable[Example]], Iterable[List[Example]]], | ||||||
|     max_epochs: int, |     max_epochs: int, | ||||||
| ): | ) -> Iterable[Tuple[int, List[Example]]]: | ||||||
|     epoch = 0 |     epoch = 0 | ||||||
|     if max_epochs >= 0: |     if max_epochs >= 0: | ||||||
|         examples = list(corpus(nlp))  # type: Iterable[Example] |         examples = list(corpus(nlp))  # type: Iterable[Example] | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user