mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 09:56:28 +03:00
0ec9a696e6
* Enable Cython<->Python bindings for `Pipe` and `TrainablePipe` methods * `pipes_with_nvtx_range`: Skip hooking methods whose signature cannot be ascertained When loading pipelines from a config file, the arguments passed to individual pipeline components are validated by `pydantic` during init. For this, the validation model attempts to parse the function signature of the component's c'tor/entry point so that it can check if all mandatory parameters are present in the config file. When using the `models_and_pipes_with_nvtx_range` as an `after_pipeline_creation` callback, the methods of all pipeline components get replaced by an NVTX range wrapper **before** the above-mentioned validation takes place. This can be problematic for components that are implemented as Cython extension types - if the extension type is not compiled with Python bindings for its methods, they will have no signatures at runtime. This resulted in `pydantic` matching the *wrapper's* parameters with those in the config and raising errors. To avoid this, we now skip applying the wrapper to any (Cython) methods that do not have signatures.
125 lines
3.7 KiB
Python
125 lines
3.7 KiB
Python
from typing import Type, Callable, Dict, TYPE_CHECKING, List, Optional, Set
|
|
import functools
|
|
import inspect
|
|
import types
|
|
import warnings
|
|
|
|
from thinc.layers import with_nvtx_range
|
|
from thinc.model import Model, wrap_model_recursive
|
|
from thinc.util import use_nvtx_range
|
|
|
|
from ..errors import Warnings
|
|
from ..util import registry
|
|
|
|
if TYPE_CHECKING:
|
|
# This lets us add type hints for mypy etc. without causing circular imports
|
|
from ..language import Language # noqa: F401
|
|
|
|
|
|
# Names of the pipeline component methods that `pipes_with_nvtx_range` wraps
# in an NVTX range by default. Every name is listed exactly once: a repeated
# entry would make the same method get wrapped a second time.
DEFAULT_NVTX_ANNOTATABLE_PIPE_METHODS = [
    "pipe",
    "predict",
    "set_annotations",
    "update",
    "rehearse",
    "get_loss",
    "initialize",
    "begin_update",
    "finish_update",
]
|
|
|
|
|
|
def models_with_nvtx_range(nlp, forward_color: int, backprop_color: int):
    """Apply `with_nvtx_range` to the model nodes of all trainable pipes.

    Only components that have a truthy `is_trainable` attribute are
    considered. Every node yielded by `pipe.model.walk()` is handed to
    `with_nvtx_range` at most once, even when it is reachable from several
    components.

    nlp: The pipeline object whose `components` are inspected.
    forward_color (int): Passed to `with_nvtx_range` as `forward_color`.
    backprop_color (int): Passed to `with_nvtx_range` as `backprop_color`.
    RETURNS: The `nlp` object that was passed in.
    """
    trainable_pipes = []
    for _, component in nlp.components:
        if getattr(component, "is_trainable", False):
            trainable_pipes.append(component)

    # Nodes are tracked by identity so that shared nodes are annotated once.
    visited: Set[int] = set()
    for component in trainable_pipes:
        for layer in component.model.walk():
            layer_id = id(layer)
            if layer_id not in visited:
                visited.add(layer_id)
                with_nvtx_range(
                    layer,
                    forward_color=forward_color,
                    backprop_color=backprop_color,
                )

    return nlp
|
|
|
|
|
|
@registry.callbacks("spacy.models_with_nvtx_range.v1")
def create_models_with_nvtx_range(
    forward_color: int = -1, backprop_color: int = -1
) -> Callable[["Language"], "Language"]:
    """Create a callback that runs `models_with_nvtx_range` on a pipeline.

    forward_color (int): Bound as the `forward_color` keyword argument.
    backprop_color (int): Bound as the `backprop_color` keyword argument.
    RETURNS (Callable[[Language], Language]): A `functools.partial` of
        `models_with_nvtx_range` with both colors pre-bound.
    """
    range_colors = {
        "forward_color": forward_color,
        "backprop_color": backprop_color,
    }
    return functools.partial(models_with_nvtx_range, **range_colors)
|
|
|
|
|
|
def nvtx_range_wrapper_for_pipe_method(self, func, *args, **kwargs):
    """Call `func(*args, **kwargs)`, inside an NVTX range where applicable.

    A `functools.partial` is called directly without opening a range. Any
    other callable is run inside a range labelled with `self.name` and
    `func.__name__`.

    self: The object whose `name` is used in the range label.
    func (Callable): The callable to invoke.
    RETURNS: Whatever `func` returns.
    """
    if not isinstance(func, functools.partial):
        range_label = f"{self.name} {func.__name__}"
        with use_nvtx_range(range_label):
            return func(*args, **kwargs)

    # `pipes_with_nvtx_range` installs its wrappers as partials, so a partial
    # is passed straight through instead of being given a range of its own.
    return func(*args, **kwargs)
|
|
|
|
|
|
def pipes_with_nvtx_range(
    nlp, additional_pipe_functions: Optional[Dict[str, List[str]]]
):
    """Wrap methods of every pipeline component in NVTX ranges.

    For each component, the methods named in
    `DEFAULT_NVTX_ANNOTATABLE_PIPE_METHODS`, plus any extra ones requested
    for that component, are replaced by a wrapper around
    `nvtx_range_wrapper_for_pipe_method`. A method is skipped (with warning
    W122) if its signature cannot be determined or if the attribute cannot
    be set on the component. Extra methods that do not exist on the component
    trigger warning W121.

    nlp: The pipeline object whose `components` are modified in place.
    additional_pipe_functions (Optional[Dict[str, List[str]]]): Maps a
        component name to additional method names to wrap. May be None.
    RETURNS: The `nlp` object that was passed in.
    """
    for component_name, pipe in nlp.components:
        # Not every component exposes a `name` attribute (e.g. plain function
        # components), so fall back to the name it has in the pipeline
        # instead of failing with an AttributeError.
        pipe_name = getattr(pipe, "name", component_name)

        if additional_pipe_functions:
            extra_funcs = additional_pipe_functions.get(pipe_name, [])
        else:
            extra_funcs = []

        # De-duplicate while keeping the order, so that a method that is
        # listed more than once is only wrapped a single time.
        method_names = dict.fromkeys(
            [*DEFAULT_NVTX_ANNOTATABLE_PIPE_METHODS, *extra_funcs]
        )

        for name in method_names:
            func = getattr(pipe, name, None)
            if func is None:
                # Missing default methods are expected; only explicitly
                # requested methods are worth a warning.
                if name in extra_funcs:
                    warnings.warn(Warnings.W121.format(method=name, pipe=pipe_name))
                continue

            wrapped_func = functools.partial(
                types.MethodType(nvtx_range_wrapper_for_pipe_method, pipe), func
            )

            # We need to preserve the original function signature so that
            # the original parameters are passed to pydantic for validation downstream.
            try:
                wrapped_func.__signature__ = inspect.signature(func)  # type: ignore
            except Exception:
                # Can fail for Cython methods that do not have bindings.
                # `Exception` (rather than a bare `except`) keeps this
                # best-effort without swallowing KeyboardInterrupt/SystemExit.
                warnings.warn(Warnings.W122.format(method=name, pipe=pipe_name))
                continue

            try:
                setattr(pipe, name, wrapped_func)
            except AttributeError:
                # The component does not allow this attribute to be replaced.
                warnings.warn(Warnings.W122.format(method=name, pipe=pipe_name))

    return nlp
|
|
|
|
|
|
@registry.callbacks("spacy.models_and_pipes_with_nvtx_range.v1")
def create_models_and_pipes_with_nvtx_range(
    forward_color: int = -1,
    backprop_color: int = -1,
    additional_pipe_functions: Optional[Dict[str, List[str]]] = None,
) -> Callable[["Language"], "Language"]:
    """Create a callback that adds NVTX ranges to both models and pipe methods.

    forward_color (int): Forwarded to `models_with_nvtx_range`.
    backprop_color (int): Forwarded to `models_with_nvtx_range`.
    additional_pipe_functions (Optional[Dict[str, List[str]]]): Forwarded to
        `pipes_with_nvtx_range`.
    RETURNS (Callable[[Language], Language]): A function that first applies
        `models_with_nvtx_range` and then `pipes_with_nvtx_range` to the
        pipeline it receives, and returns the result.
    """

    def apply_nvtx_ranges(nlp):
        # Models are annotated first, then the pipe methods.
        with_model_ranges = models_with_nvtx_range(
            nlp, forward_color, backprop_color
        )
        return pipes_with_nvtx_range(with_model_ranges, additional_pipe_functions)

    return apply_nvtx_ranges
|