mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-24 00:04:15 +03:00
Add NVTX ranges to TrainablePipe
components (#10965)
* `TrainablePipe`: Add NVTX range decorator * Annotate `TrainablePipe` subclasses with NVTX ranges * Export function signature to allow introspection of args in tests * Revert "Annotate `TrainablePipe` subclasses with NVTX ranges" This reverts commitd8684f7372
. * Revert "Export function signature to allow introspection of args in tests" This reverts commitf4405ca3ad
. * Revert "`TrainablePipe`: Add NVTX range decorator" This reverts commit26536eb6b8
. * Add `spacy.pipes_with_nvtx_range` pipeline callback * Show warnings for all missing user-defined pipe functions that need to be annotated Fix imports, typos * Rename `DEFAULT_ANNOTATABLE_PIPE_METHODS` to `DEFAULT_NVTX_ANNOTATABLE_PIPE_METHODS` Reorder import * Walk model nodes directly whilst applying NVTX ranges Ignore pipe method wrapper when applying range
This commit is contained in:
parent
3fe9f47de4
commit
eaf66e7431
|
@ -209,6 +209,9 @@ class Warnings(metaclass=ErrorsWithCodes):
|
||||||
"Only the last span group will be loaded under "
|
"Only the last span group will be loaded under "
|
||||||
"Doc.spans['{group_name}']. Skipping span group with values: "
|
"Doc.spans['{group_name}']. Skipping span group with values: "
|
||||||
"{group_values}")
|
"{group_values}")
|
||||||
|
W121 = ("Attempting to trace non-existent method '{method}' in pipe '{pipe}'")
|
||||||
|
W122 = ("Couldn't trace method '{method}' in pipe '{pipe}'. This can happen if the pipe class "
|
||||||
|
"is a Cython extension type.")
|
||||||
|
|
||||||
|
|
||||||
class Errors(metaclass=ErrorsWithCodes):
|
class Errors(metaclass=ErrorsWithCodes):
|
||||||
|
|
|
@ -1,9 +1,14 @@
|
||||||
from functools import partial
|
from typing import Type, Callable, Dict, TYPE_CHECKING, List, Optional, Set
|
||||||
from typing import Type, Callable, TYPE_CHECKING
|
import functools
|
||||||
|
import inspect
|
||||||
|
import types
|
||||||
|
import warnings
|
||||||
|
|
||||||
from thinc.layers import with_nvtx_range
|
from thinc.layers import with_nvtx_range
|
||||||
from thinc.model import Model, wrap_model_recursive
|
from thinc.model import Model, wrap_model_recursive
|
||||||
|
from thinc.util import use_nvtx_range
|
||||||
|
|
||||||
|
from ..errors import Warnings
|
||||||
from ..util import registry
|
from ..util import registry
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
|
@ -11,29 +16,106 @@ if TYPE_CHECKING:
|
||||||
from ..language import Language # noqa: F401
|
from ..language import Language # noqa: F401
|
||||||
|
|
||||||
|
|
||||||
@registry.callbacks("spacy.models_with_nvtx_range.v1")
|
DEFAULT_NVTX_ANNOTATABLE_PIPE_METHODS = [
|
||||||
def create_models_with_nvtx_range(
|
"pipe",
|
||||||
forward_color: int = -1, backprop_color: int = -1
|
"predict",
|
||||||
) -> Callable[["Language"], "Language"]:
|
"set_annotations",
|
||||||
def models_with_nvtx_range(nlp):
|
"update",
|
||||||
|
"rehearse",
|
||||||
|
"get_loss",
|
||||||
|
"initialize",
|
||||||
|
"begin_update",
|
||||||
|
"finish_update",
|
||||||
|
"update",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def models_with_nvtx_range(nlp, forward_color: int, backprop_color: int):
|
||||||
pipes = [
|
pipes = [
|
||||||
pipe
|
pipe
|
||||||
for _, pipe in nlp.components
|
for _, pipe in nlp.components
|
||||||
if hasattr(pipe, "is_trainable") and pipe.is_trainable
|
if hasattr(pipe, "is_trainable") and pipe.is_trainable
|
||||||
]
|
]
|
||||||
|
|
||||||
# We need process all models jointly to avoid wrapping callbacks twice.
|
seen_models: Set[int] = set()
|
||||||
models = Model(
|
for pipe in pipes:
|
||||||
"wrap_with_nvtx_range",
|
for node in pipe.model.walk():
|
||||||
forward=lambda model, X, is_train: ...,
|
if id(node) in seen_models:
|
||||||
layers=[pipe.model for pipe in pipes],
|
continue
|
||||||
)
|
seen_models.add(id(node))
|
||||||
|
|
||||||
for node in models.walk():
|
|
||||||
with_nvtx_range(
|
with_nvtx_range(
|
||||||
node, forward_color=forward_color, backprop_color=backprop_color
|
node, forward_color=forward_color, backprop_color=backprop_color
|
||||||
)
|
)
|
||||||
|
|
||||||
return nlp
|
return nlp
|
||||||
|
|
||||||
return models_with_nvtx_range
|
|
||||||
|
@registry.callbacks("spacy.models_with_nvtx_range.v1")
|
||||||
|
def create_models_with_nvtx_range(
|
||||||
|
forward_color: int = -1, backprop_color: int = -1
|
||||||
|
) -> Callable[["Language"], "Language"]:
|
||||||
|
return functools.partial(
|
||||||
|
models_with_nvtx_range,
|
||||||
|
forward_color=forward_color,
|
||||||
|
backprop_color=backprop_color,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def nvtx_range_wrapper_for_pipe_method(self, func, *args, **kwargs):
|
||||||
|
if isinstance(func, functools.partial):
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
else:
|
||||||
|
with use_nvtx_range(f"{self.name} {func.__name__}"):
|
||||||
|
return func(*args, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def pipes_with_nvtx_range(
|
||||||
|
nlp, additional_pipe_functions: Optional[Dict[str, List[str]]]
|
||||||
|
):
|
||||||
|
for _, pipe in nlp.components:
|
||||||
|
if additional_pipe_functions:
|
||||||
|
extra_funcs = additional_pipe_functions.get(pipe.name, [])
|
||||||
|
else:
|
||||||
|
extra_funcs = []
|
||||||
|
|
||||||
|
for name in DEFAULT_NVTX_ANNOTATABLE_PIPE_METHODS + extra_funcs:
|
||||||
|
func = getattr(pipe, name, None)
|
||||||
|
if func is None:
|
||||||
|
if name in extra_funcs:
|
||||||
|
warnings.warn(Warnings.W121.format(method=name, pipe=pipe.name))
|
||||||
|
continue
|
||||||
|
|
||||||
|
wrapped_func = functools.partial(
|
||||||
|
types.MethodType(nvtx_range_wrapper_for_pipe_method, pipe), func
|
||||||
|
)
|
||||||
|
|
||||||
|
# Try to preserve the original function signature.
|
||||||
|
try:
|
||||||
|
wrapped_func.__signature__ = inspect.signature(func) # type: ignore
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
try:
|
||||||
|
setattr(
|
||||||
|
pipe,
|
||||||
|
name,
|
||||||
|
wrapped_func,
|
||||||
|
)
|
||||||
|
except AttributeError:
|
||||||
|
warnings.warn(Warnings.W122.format(method=name, pipe=pipe.name))
|
||||||
|
|
||||||
|
return nlp
|
||||||
|
|
||||||
|
|
||||||
|
@registry.callbacks("spacy.models_and_pipes_with_nvtx_range.v1")
|
||||||
|
def create_models_and_pipes_with_nvtx_range(
|
||||||
|
forward_color: int = -1,
|
||||||
|
backprop_color: int = -1,
|
||||||
|
additional_pipe_functions: Optional[Dict[str, List[str]]] = None,
|
||||||
|
) -> Callable[["Language"], "Language"]:
|
||||||
|
def inner(nlp):
|
||||||
|
nlp = models_with_nvtx_range(nlp, forward_color, backprop_color)
|
||||||
|
nlp = pipes_with_nvtx_range(nlp, additional_pipe_functions)
|
||||||
|
return nlp
|
||||||
|
|
||||||
|
return inner
|
||||||
|
|
Loading…
Reference in New Issue
Block a user