mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-25 13:11:03 +03:00 
			
		
		
		
	* Add `TrainablePipe.{distill,get_teacher_student_loss}`
This change adds two methods:
- `TrainablePipe::distill`, which performs a training step of a
   student pipe on a teacher pipe, given a batch of `Doc`s.
- `TrainablePipe::get_teacher_student_loss` computes the loss
  of a student relative to the teacher.
The `distill` or `get_teacher_student_loss` methods are also implemented
in the tagger, edit tree lemmatizer, and parser pipes, to enable
distillation in those pipes and as an example for other pipes.
* Fix stray `Beam` import
* Fix incorrect import
* Apply suggestions from code review
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
* Apply suggestions from code review
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
* TrainablePipe.distill: use `Iterable[Example]`
* Add Pipe.is_distillable method
* Add `validate_distillation_examples`
This first calls `validate_examples` and then checks that the
student/teacher tokens are the same.
* Update distill documentation
* Add distill documentation for all pipes that support distillation
* Fix incorrect identifier
* Apply suggestions from code review
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
* Add comment to explain `is_distillable`
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
		
	
			
		
			
				
	
	
		
			147 lines
		
	
	
		
			5.3 KiB
		
	
	
	
		
			Cython
		
	
	
	
	
	
			
		
		
	
	
			147 lines
		
	
	
		
			5.3 KiB
		
	
	
	
		
			Cython
		
	
	
	
	
	
| # cython: infer_types=True, profile=True, binding=True
 | |
| from typing import Optional, Tuple, Iterable, Iterator, Callable, Union, Dict
 | |
| import srsly
 | |
| import warnings
 | |
| 
 | |
| from ..tokens.doc cimport Doc
 | |
| 
 | |
| from ..training import Example
 | |
| from ..errors import Errors, Warnings
 | |
| from ..language import Language
 | |
| from ..util import raise_error
 | |
| 
 | |
cdef class Pipe:
    """Abstract base for pipeline components; not meant to be instantiated
    on its own. It lays out the interface that pipeline components implement.
    Trainable components such as the EntityRecognizer or TextCategorizer
    should derive from the subclass 'TrainablePipe' instead.

    DOCS: https://spacy.io/api/pipe
    """

    def __call__(self, Doc doc) -> Doc:
        """Process a single document, modifying it in place and returning it.
        This is what normally runs behind the scenes when the nlp object is
        called on a text and every component is applied to the Doc. The base
        class has no implementation and always raises NotImplementedError.

        doc (Doc): The Doc to process.
        RETURNS (Doc): The processed Doc.

        DOCS: https://spacy.io/api/pipe#call
        """
        message = Errors.E931.format(parent="Pipe", method="__call__", name=self.name)
        raise NotImplementedError(message)

    def pipe(self, stream: Iterable[Doc], *, batch_size: int=128) -> Iterator[Doc]:
        """Process a stream of documents one at a time by calling the
        component on each of them. Any exception raised for a document is
        passed to the component's error handler instead of propagating
        directly.

        stream (Iterable[Doc]): A stream of documents.
        batch_size (int): The number of documents to buffer. Not used by
            this base implementation.
        YIELDS (Doc): Processed documents in order.

        DOCS: https://spacy.io/api/pipe#pipe
        """
        handle_error = self.get_error_handler()
        for current in stream:
            try:
                current = self(current)
                yield current
            except Exception as err:
                # Report the failing document together with the component's
                # name and the component itself.
                handle_error(self.name, self, [current], err)

    def initialize(self, get_examples: Callable[[], Iterable[Example]], *, nlp: Language=None):
        """Initialize the pipe. Non-trainable components may skip this
        method. Trainable components, which should inherit from the subclass
        TrainablePipe, should use the provided examples to make sure the
        internal model is set up properly and all input/output dimensions
        throughout the network are inferred.

        get_examples (Callable[[], Iterable[Example]]): Function that
            returns a representative sample of gold-standard Example objects.
        nlp (Language): The current nlp object the component is part of.

        DOCS: https://spacy.io/api/pipe#initialize
        """
        # Intentionally a no-op in the base class.

    def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Union[float, Dict[str, float]]]:
        """Score a batch of examples with the component's scorer, if any.
        Without a scorer (attribute missing or None) an empty dict is
        returned.

        examples (Iterable[Example]): The examples to score.
        RETURNS (Dict[str, Any]): The scores.

        DOCS: https://spacy.io/api/pipe#score
        """
        if not hasattr(self, "scorer") or self.scorer is None:
            return {}
        settings = {}
        # Lowest priority: defaults stored in the component's cfg
        # (e.g., threshold).
        if hasattr(self, "cfg") and isinstance(self.cfg, dict):
            settings.update(self.cfg)
        # self.labels takes precedence over any "labels" entry from cfg.
        if hasattr(self, "labels"):
            settings["labels"] = self.labels
        # Highest priority: settings passed in explicitly by the caller.
        settings.update(kwargs)
        return self.scorer(examples, **settings)

    @property
    def is_distillable(self) -> bool:
        """Distillation flag; the base implementation always returns False."""
        return False

    @property
    def is_trainable(self) -> bool:
        """Trainability flag; the base implementation always returns False."""
        return False

    @property
    def labels(self) -> Tuple[str, ...]:
        """Labels of the component; the base implementation returns an
        empty tuple.
        """
        return tuple()

    @property
    def hide_labels(self) -> bool:
        """Label-hiding flag; the base implementation always returns False."""
        return False

    @property
    def label_data(self):
        """Optional JSON-serializable data that would be enough to rebuild
        the label set when passed to the `pipe.initialize()` method. The
        base implementation returns None.
        """
        return None

    def _require_labels(self) -> None:
        """Raise a ValueError if this component has no labels defined, i.e.
        the labels are empty or consist of a single empty string.
        """
        if self.labels and list(self.labels) != [""]:
            return
        raise ValueError(Errors.E143.format(name=self.name))

    def set_error_handler(self, error_handler: Callable) -> None:
        """Store the function used to deal with failing documents.

        error_handler (Callable[[str, Callable[[Doc], Doc], List[Doc], Exception], None]):
            Function that deals with a failing batch of documents. It should
            accept the component's name, the component itself, the offending
            batch of documents, and the exception that was thrown.

        DOCS: https://spacy.io/api/pipe#set_error_handler
        """
        self.error_handler = error_handler

    def get_error_handler(self) -> Callable:
        """Return the error handler function of this component.

        RETURNS (Callable): The error handler, or `raise_error` as the
            default when no handler attribute is present.

        DOCS: https://spacy.io/api/pipe#get_error_handler
        """
        if not hasattr(self, "error_handler"):
            return raise_error
        return self.error_handler
 | |
| 
 | |
| 
 | |
def deserialize_config(path):
    """Read a JSON config from disk, falling back to an empty dict.

    path: Location of the JSON file. It must provide an `exists()` method
        and be accepted by `srsly.read_json`.
    RETURNS: The content loaded by `srsly.read_json` if the path exists,
        otherwise an empty dict.
    """
    if not path.exists():
        return {}
    return srsly.read_json(path)
 |