mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	* Use isort with Black profile * isort all the things * Fix import cycles as a result of import sorting * Add DOCBIN_ALL_ATTRS type definition * Add isort to requirements * Remove isort from build dependencies check * Typo
		
			
				
	
	
		
			125 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			125 lines
		
	
	
		
			3.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
import functools
 | 
						|
import inspect
 | 
						|
import types
 | 
						|
import warnings
 | 
						|
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Type
 | 
						|
 | 
						|
from thinc.layers import with_nvtx_range
 | 
						|
from thinc.model import Model, wrap_model_recursive
 | 
						|
from thinc.util import use_nvtx_range
 | 
						|
 | 
						|
from ..errors import Warnings
 | 
						|
from ..util import registry
 | 
						|
 | 
						|
if TYPE_CHECKING:
 | 
						|
    # This lets us add type hints for mypy etc. without causing circular imports
 | 
						|
    from ..language import Language  # noqa: F401
 | 
						|
 | 
						|
 | 
						|
# Names of pipe methods that pipes_with_nvtx_range() tries to wrap on every
# component. Each name is listed once: a repeated entry would make the
# wrapping loop re-wrap an already wrapped method.
# Kept as a list because it is concatenated with user-supplied lists via `+`.
DEFAULT_NVTX_ANNOTATABLE_PIPE_METHODS: List[str] = [
    "pipe",
    "predict",
    "set_annotations",
    "update",
    "rehearse",
    "get_loss",
    "initialize",
    "begin_update",
    "finish_update",
]
 | 
						|
 | 
						|
 | 
						|
def models_with_nvtx_range(nlp, forward_color: int, backprop_color: int):
    """Apply `with_nvtx_range` to every model node of the trainable pipes.

    nlp: Object exposing `components` as an iterable of (name, pipe) pairs.
    forward_color (int): Forwarded to `with_nvtx_range` as `forward_color`.
    backprop_color (int): Forwarded to `with_nvtx_range` as `backprop_color`.
    RETURNS: The same `nlp` object that was passed in.
    """
    # Only components that expose a truthy `is_trainable` are considered.
    trainable_pipes = []
    for _, component in nlp.components:
        if getattr(component, "is_trainable", False):
            trainable_pipes.append(component)

    # Nodes are tracked by identity so that a node reachable from several
    # pipes (or several times within one model tree) is handled only once.
    visited_ids: Set[int] = set()
    for component in trainable_pipes:
        for layer in component.model.walk():
            layer_id = id(layer)
            if layer_id not in visited_ids:
                visited_ids.add(layer_id)
                # The return value is discarded, so this relies on
                # with_nvtx_range modifying the node in place.
                with_nvtx_range(
                    layer,
                    forward_color=forward_color,
                    backprop_color=backprop_color,
                )

    return nlp
 | 
						|
 | 
						|
 | 
						|
@registry.callbacks("spacy.models_with_nvtx_range.v1")
def create_models_with_nvtx_range(
    forward_color: int = -1, backprop_color: int = -1
) -> Callable[["Language"], "Language"]:
    """Build a callback that runs `models_with_nvtx_range` on an nlp object.

    forward_color (int): Bound as the `forward_color` keyword of the callback.
    backprop_color (int): Bound as the `backprop_color` keyword of the callback.
    RETURNS (Callable[[Language], Language]): A `functools.partial` of
        `models_with_nvtx_range` with both colors pre-bound; it takes the
        nlp object as its only remaining argument.
    """
    color_kwargs = {
        "forward_color": forward_color,
        "backprop_color": backprop_color,
    }
    return functools.partial(models_with_nvtx_range, **color_kwargs)
 | 
						|
 | 
						|
 | 
						|
def nvtx_range_wrapper_for_pipe_method(self, func, *args, **kwargs):
    """Call `func(*args, **kwargs)`, inside an NVTX range where applicable.

    self: Object whose `name` attribute is used in the range label; it is
        only read when `func` is not a `functools.partial`.
    func: The callable to invoke.
    *args, **kwargs: Passed through to `func` unchanged.
    RETURNS: Whatever `func` returns.
    """
    is_partial = isinstance(func, functools.partial)
    if not is_partial:
        # The range label is "<pipe name> <function name>".
        with use_nvtx_range(f"{self.name} {func.__name__}"):
            return func(*args, **kwargs)

    # functools.partial objects are called directly, without opening a range
    # (presumably because they are methods that were already wrapped).
    return func(*args, **kwargs)
 | 
						|
 | 
						|
 | 
						|
def pipes_with_nvtx_range(
    nlp, additional_pipe_functions: Optional[Dict[str, List[str]]]
):
    """Replace selected methods of every pipe with NVTX-range wrappers.

    For each component, the names in DEFAULT_NVTX_ANNOTATABLE_PIPE_METHODS
    plus any extra names configured for that pipe are looked up on the pipe.
    Every method found is replaced by a `functools.partial` that routes the
    call through `nvtx_range_wrapper_for_pipe_method`.

    nlp: Object exposing `components` as an iterable of (name, pipe) pairs.
    additional_pipe_functions (Optional[Dict[str, List[str]]]): Maps a pipe
        name to extra method names to wrap on that pipe. May be None or empty.
    RETURNS: The same `nlp` object that was passed in.

    Warns with W121 when an explicitly requested extra method is missing on
    the pipe, and with W122 when a method cannot be wrapped (its signature
    cannot be inspected or the attribute cannot be assigned).
    """
    for _, pipe in nlp.components:
        if additional_pipe_functions:
            extra_funcs = additional_pipe_functions.get(pipe.name, [])
        else:
            extra_funcs = []

        # De-duplicate while keeping order, so a name that occurs more than
        # once (in the defaults or in the extras) is not wrapped twice.
        method_names = dict.fromkeys(
            DEFAULT_NVTX_ANNOTATABLE_PIPE_METHODS + extra_funcs
        )

        for name in method_names:
            func = getattr(pipe, name, None)
            if func is None:
                # Missing default methods are skipped silently; only names the
                # user asked for explicitly produce a warning.
                if name in extra_funcs:
                    warnings.warn(Warnings.W121.format(method=name, pipe=pipe.name))
                continue

            wrapped_func = functools.partial(
                types.MethodType(nvtx_range_wrapper_for_pipe_method, pipe), func
            )

            # We need to preserve the original function signature so that
            # the original parameters are passed to pydantic for validation downstream.
            try:
                wrapped_func.__signature__ = inspect.signature(func)  # type: ignore
            except Exception:
                # Can fail for Cython methods that do not have bindings.
                # Deliberately best-effort, but limited to Exception so that
                # KeyboardInterrupt and SystemExit still propagate.
                warnings.warn(Warnings.W122.format(method=name, pipe=pipe.name))
                continue

            try:
                setattr(pipe, name, wrapped_func)
            except AttributeError:
                # The pipe does not allow assigning this attribute.
                warnings.warn(Warnings.W122.format(method=name, pipe=pipe.name))

    return nlp
 | 
						|
 | 
						|
 | 
						|
@registry.callbacks("spacy.models_and_pipes_with_nvtx_range.v1")
def create_models_and_pipes_with_nvtx_range(
    forward_color: int = -1,
    backprop_color: int = -1,
    additional_pipe_functions: Optional[Dict[str, List[str]]] = None,
) -> Callable[["Language"], "Language"]:
    """Build a callback that annotates both models and pipe methods.

    forward_color (int): Passed to `models_with_nvtx_range`.
    backprop_color (int): Passed to `models_with_nvtx_range`.
    additional_pipe_functions (Optional[Dict[str, List[str]]]): Passed to
        `pipes_with_nvtx_range`.
    RETURNS (Callable[[Language], Language]): A function that takes an nlp
        object, runs `models_with_nvtx_range` and then
        `pipes_with_nvtx_range` on it, and returns the result of the latter.
    """

    def inner(nlp):
        # Models first, then pipe methods, each step feeding the next.
        with_models = models_with_nvtx_range(nlp, forward_color, backprop_color)
        return pipes_with_nvtx_range(with_models, additional_pipe_functions)

    return inner
 |