mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-30 23:47:31 +03:00 
			
		
		
		
	* Add `TrainablePipe.{distill,get_teacher_student_loss}`
This change adds two methods:
- `TrainablePipe.distill`, which performs a training step of a
  student pipe on a teacher pipe, given a batch of `Doc`s.
- `TrainablePipe.get_teacher_student_loss`, which computes the loss
  of a student relative to the teacher.
The `distill` or `get_teacher_student_loss` methods are also implemented
in the tagger, edit tree lemmatizer, and parser pipes, to enable
distillation in those pipes and as an example for other pipes.
* Fix stray `Beam` import
* Fix incorrect import
* Apply suggestions from code review
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
* Apply suggestions from code review
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
* TrainablePipe.distill: use `Iterable[Example]`
* Add Pipe.is_distillable method
* Add `validate_distillation_examples`
This first calls `validate_examples` and then checks that the
student/teacher tokens are the same.
* Update distill documentation
* Add distill documentation for all pipes that support distillation
* Fix incorrect identifier
* Apply suggestions from code review
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
* Add comment to explain `is_distillable`
Co-authored-by: Sofie Van Landeghem <svlandeg@users.noreply.github.com>
		
	
			
		
			
				
	
	
		
			126 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			126 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from typing import Type, Callable, Dict, TYPE_CHECKING, List, Optional, Set
 | |
| import functools
 | |
| import inspect
 | |
| import types
 | |
| import warnings
 | |
| 
 | |
| from thinc.layers import with_nvtx_range
 | |
| from thinc.model import Model, wrap_model_recursive
 | |
| from thinc.util import use_nvtx_range
 | |
| 
 | |
| from ..errors import Warnings
 | |
| from ..util import registry
 | |
| 
 | |
| if TYPE_CHECKING:
 | |
|     # This lets us add type hints for mypy etc. without causing circular imports
 | |
|     from ..language import Language  # noqa: F401
 | |
| 
 | |
| 
 | |
# Pipe methods that ``pipes_with_nvtx_range`` wraps in an NVTX range by
# default. Methods a pipe does not define are skipped without a warning.
# Each name must appear only once: a duplicate entry makes the method get
# wrapped twice and any W121/W122 warning for it be emitted twice.
DEFAULT_NVTX_ANNOTATABLE_PIPE_METHODS = [
    "pipe",
    "predict",
    "set_annotations",
    "update",
    "rehearse",
    "get_loss",
    "get_teacher_student_loss",
    "initialize",
    "begin_update",
    "finish_update",
]
 | |
| 
 | |
| 
 | |
def models_with_nvtx_range(nlp, forward_color: int, backprop_color: int):
    """Annotate the model layers of every trainable pipe with NVTX ranges.

    Each layer reachable through ``pipe.model.walk()`` is handed to thinc's
    ``with_nvtx_range`` exactly once, even when several pipes share a
    sub-model (layers are tracked by ``id``).

    nlp: The pipeline. Only components that have a truthy ``is_trainable``
        attribute are visited.
    forward_color: Passed through to ``with_nvtx_range`` as ``forward_color``.
    backprop_color: Passed through to ``with_nvtx_range`` as ``backprop_color``.
    RETURNS: The same ``nlp`` object that was passed in.
    """
    # First pass: select the trainable components.
    trainable_pipes = []
    for _name, component in nlp.components:
        if hasattr(component, "is_trainable") and component.is_trainable:
            trainable_pipes.append(component)

    # Second pass: instrument each distinct layer a single time.
    instrumented: Set[int] = set()
    for component in trainable_pipes:
        for layer in component.model.walk():
            layer_id = id(layer)
            if layer_id not in instrumented:
                instrumented.add(layer_id)
                with_nvtx_range(
                    layer,
                    forward_color=forward_color,
                    backprop_color=backprop_color,
                )

    return nlp
 | |
| 
 | |
| 
 | |
@registry.callbacks("spacy.models_with_nvtx_range.v1")
def create_models_with_nvtx_range(
    forward_color: int = -1, backprop_color: int = -1
) -> Callable[["Language"], "Language"]:
    """Build the registered callback that NVTX-annotates model layers.

    forward_color: Colour forwarded to ``models_with_nvtx_range`` (default -1).
    backprop_color: Colour forwarded to ``models_with_nvtx_range`` (default -1).
    RETURNS: A ``functools.partial`` of ``models_with_nvtx_range`` with both
        colours bound, to be called with the ``nlp`` object.
    """
    bound_colors = {"forward_color": forward_color, "backprop_color": backprop_color}
    return functools.partial(models_with_nvtx_range, **bound_colors)
 | |
| 
 | |
| 
 | |
def nvtx_range_wrapper_for_pipe_method(self, func, *args, **kwargs):
    """Invoke ``func`` inside an NVTX range named "<pipe name> <method name>".

    ``self`` is the pipe; ``pipes_with_nvtx_range`` binds it through
    ``types.MethodType``. If ``func`` is already a ``functools.partial``
    (an earlier wrapping of the same method), it is called directly so no
    second range is opened around it.
    """
    if not isinstance(func, functools.partial):
        range_label = f"{self.name} {func.__name__}"
        with use_nvtx_range(range_label):
            return func(*args, **kwargs)
    return func(*args, **kwargs)
 | |
| 
 | |
| 
 | |
def pipes_with_nvtx_range(
    nlp, additional_pipe_functions: Optional[Dict[str, List[str]]]
):
    """Wrap selected methods of each pipe in ``nlp`` in an NVTX range.

    For every component, each method named in
    ``DEFAULT_NVTX_ANNOTATABLE_PIPE_METHODS``, plus any extra names listed for
    that pipe in ``additional_pipe_functions``, is replaced on the pipe
    instance by a ``functools.partial`` that routes calls through
    ``nvtx_range_wrapper_for_pipe_method``.

    nlp: The pipeline whose ``components`` are instrumented in place.
    additional_pipe_functions: Optional mapping from a pipe's ``name`` to the
        extra method names to annotate for that pipe.
    RETURNS: The same ``nlp`` object.

    Warns with W121 when an explicitly requested extra method does not exist
    on the pipe. Warns with W122 when a method's signature cannot be read or
    the wrapped method cannot be assigned back onto the pipe.
    """
    for _, pipe in nlp.components:
        if additional_pipe_functions:
            extra_funcs = additional_pipe_functions.get(pipe.name, [])
        else:
            extra_funcs = []

        # dict.fromkeys drops repeated names while keeping their first-seen
        # order, so a method listed twice (in the defaults, or in both the
        # defaults and extra_funcs) is wrapped and warned about only once.
        method_names = dict.fromkeys(
            [*DEFAULT_NVTX_ANNOTATABLE_PIPE_METHODS, *extra_funcs]
        )
        for name in method_names:
            func = getattr(pipe, name, None)
            if func is None:
                # Missing default methods are expected; only warn about
                # methods the caller asked for explicitly.
                if name in extra_funcs:
                    warnings.warn(Warnings.W121.format(method=name, pipe=pipe.name))
                continue

            wrapped_func = functools.partial(
                types.MethodType(nvtx_range_wrapper_for_pipe_method, pipe), func
            )

            # We need to preserve the original function signature so that
            # the original parameters are passed to pydantic for validation downstream.
            try:
                wrapped_func.__signature__ = inspect.signature(func)  # type: ignore
            except (TypeError, ValueError):
                # inspect.signature raises ValueError when no signature can be
                # provided (e.g. Cython methods that do not have bindings) and
                # TypeError for unsupported objects. A bare ``except:`` here
                # would also swallow KeyboardInterrupt and SystemExit.
                warnings.warn(Warnings.W122.format(method=name, pipe=pipe.name))
                continue

            try:
                setattr(pipe, name, wrapped_func)
            except AttributeError:
                # NOTE(review): presumably raised by pipe types that do not
                # allow instance attribute assignment (e.g. extension types).
                warnings.warn(Warnings.W122.format(method=name, pipe=pipe.name))

    return nlp
 | |
| 
 | |
| 
 | |
@registry.callbacks("spacy.models_and_pipes_with_nvtx_range.v1")
def create_models_and_pipes_with_nvtx_range(
    forward_color: int = -1,
    backprop_color: int = -1,
    additional_pipe_functions: Optional[Dict[str, List[str]]] = None,
) -> Callable[["Language"], "Language"]:
    """Build the registered callback that NVTX-annotates layers and pipe methods.

    forward_color: Colour forwarded to ``models_with_nvtx_range`` (default -1).
    backprop_color: Colour forwarded to ``models_with_nvtx_range`` (default -1).
    additional_pipe_functions: Per-pipe extra method names forwarded to
        ``pipes_with_nvtx_range`` (default ``None``).
    RETURNS: A callable that takes ``nlp``, annotates its model layers and
        then its pipe methods, and returns the resulting ``nlp``.
    """

    def apply_nvtx_ranges(nlp):
        # Model layers first, then pipe methods, threading nlp through both.
        annotated = models_with_nvtx_range(nlp, forward_color, backprop_color)
        return pipes_with_nvtx_range(annotated, additional_pipe_functions)

    return apply_nvtx_ranges
 |