mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	
		
			
				
	
	
		
			40 lines
		
	
	
		
			1.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			40 lines
		
	
	
		
			1.2 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from functools import partial
 | 
						|
from typing import Type, Callable, TYPE_CHECKING
 | 
						|
 | 
						|
from thinc.layers import with_nvtx_range
 | 
						|
from thinc.model import Model, wrap_model_recursive
 | 
						|
 | 
						|
from ..util import registry
 | 
						|
 | 
						|
if TYPE_CHECKING:
 | 
						|
    # This lets us add type hints for mypy etc. without causing circular imports
 | 
						|
    from ..language import Language  # noqa: F401
 | 
						|
 | 
						|
 | 
						|
@registry.callbacks("spacy.models_with_nvtx_range.v1")
def create_models_with_nvtx_range(
    forward_color: int = -1, backprop_color: int = -1
) -> Callable[["Language"], "Language"]:
    """Create a callback that applies `with_nvtx_range` to the models of all
    trainable pipeline components.

    forward_color (int): Forwarded unchanged to `with_nvtx_range` as its
        `forward_color` keyword argument.
    backprop_color (int): Forwarded unchanged to `with_nvtx_range` as its
        `backprop_color` keyword argument.
    RETURNS (Callable[[Language], Language]): A callback that receives the
        `nlp` object, processes the component models and hands back the very
        same `nlp` object.
    """

    def models_with_nvtx_range(nlp):
        # Keep only components that expose a truthy `is_trainable` attribute;
        # components without the attribute are skipped.
        trainable_components = []
        for _, component in nlp.components:
            if getattr(component, "is_trainable", False):
                trainable_components.append(component)

        def _placeholder_forward(model, X, is_train):
            # Never meant to produce real output: the container model below
            # exists only so the component models can be traversed together.
            return ...

        # All models have to be processed jointly (through one shared root)
        # to avoid wrapping callbacks twice.
        root = Model(
            "wrap_with_nvtx_range",
            forward=_placeholder_forward,
            layers=[component.model for component in trainable_components],
        )

        for layer in root.walk():
            # The return value is intentionally discarded; the call is made
            # for its effect on `layer`.
            with_nvtx_range(
                layer, forward_color=forward_color, backprop_color=backprop_color
            )

        return nlp

    return models_with_nvtx_range
 |