from typing import Type, Callable, Dict, TYPE_CHECKING, List, Optional, Set import functools import inspect import types import warnings from thinc.layers import with_nvtx_range from thinc.model import Model, wrap_model_recursive from thinc.util import use_nvtx_range from ..errors import Warnings from ..util import registry if TYPE_CHECKING: # This lets us add type hints for mypy etc. without causing circular imports from ..language import Language # noqa: F401 DEFAULT_NVTX_ANNOTATABLE_PIPE_METHODS = [ "pipe", "predict", "set_annotations", "update", "rehearse", "get_loss", "get_teacher_student_loss", "initialize", "begin_update", "finish_update", "update", ] def models_with_nvtx_range(nlp, forward_color: int, backprop_color: int): pipes = [ pipe for _, pipe in nlp.components if hasattr(pipe, "is_trainable") and pipe.is_trainable ] seen_models: Set[int] = set() for pipe in pipes: for node in pipe.model.walk(): if id(node) in seen_models: continue seen_models.add(id(node)) with_nvtx_range( node, forward_color=forward_color, backprop_color=backprop_color ) return nlp @registry.callbacks("spacy.models_with_nvtx_range.v1") def create_models_with_nvtx_range( forward_color: int = -1, backprop_color: int = -1 ) -> Callable[["Language"], "Language"]: return functools.partial( models_with_nvtx_range, forward_color=forward_color, backprop_color=backprop_color, ) def nvtx_range_wrapper_for_pipe_method(self, func, *args, **kwargs): if isinstance(func, functools.partial): return func(*args, **kwargs) else: with use_nvtx_range(f"{self.name} {func.__name__}"): return func(*args, **kwargs) def pipes_with_nvtx_range( nlp, additional_pipe_functions: Optional[Dict[str, List[str]]] ): for _, pipe in nlp.components: if additional_pipe_functions: extra_funcs = additional_pipe_functions.get(pipe.name, []) else: extra_funcs = [] for name in DEFAULT_NVTX_ANNOTATABLE_PIPE_METHODS + extra_funcs: func = getattr(pipe, name, None) if func is None: if name in extra_funcs: warnings.warn(Warnings.W121.format(method=name, pipe=pipe.name)) continue wrapped_func = functools.partial( types.MethodType(nvtx_range_wrapper_for_pipe_method, pipe), func ) # We need to preserve the original function signature so that # the original parameters are passed to pydantic for validation downstream. try: wrapped_func.__signature__ = inspect.signature(func) # type: ignore except: # Can fail for Cython methods that do not have bindings. warnings.warn(Warnings.W122.format(method=name, pipe=pipe.name)) continue try: setattr( pipe, name, wrapped_func, ) except AttributeError: warnings.warn(Warnings.W122.format(method=name, pipe=pipe.name)) return nlp @registry.callbacks("spacy.models_and_pipes_with_nvtx_range.v1") def create_models_and_pipes_with_nvtx_range( forward_color: int = -1, backprop_color: int = -1, additional_pipe_functions: Optional[Dict[str, List[str]]] = None, ) -> Callable[["Language"], "Language"]: def inner(nlp): nlp = models_with_nvtx_range(nlp, forward_color, backprop_color) nlp = pipes_with_nvtx_range(nlp, additional_pipe_functions) return nlp return inner