spaCy/spacy/ml/callbacks.py
github-actions[bot] b0b115ff39
Auto-format code with black (#9530)
Co-authored-by: explosion-bot <explosion-bot@users.noreply.github.com>
2021-10-22 13:03:10 +02:00

40 lines
1.2 KiB
Python

from functools import partial
from typing import Type, Callable, TYPE_CHECKING
from thinc.layers import with_nvtx_range
from thinc.model import Model, wrap_model_recursive
from ..util import registry
if TYPE_CHECKING:
# This lets us add type hints for mypy etc. without causing circular imports
from ..language import Language # noqa: F401
@registry.callbacks("spacy.models_with_nvtx_range.v1")
def create_models_with_nvtx_range(
forward_color: int = -1, backprop_color: int = -1
) -> Callable[["Language"], "Language"]:
def models_with_nvtx_range(nlp):
pipes = [
pipe
for _, pipe in nlp.components
if hasattr(pipe, "is_trainable") and pipe.is_trainable
]
# We need process all models jointly to avoid wrapping callbacks twice.
models = Model(
"wrap_with_nvtx_range",
forward=lambda model, X, is_train: ...,
layers=[pipe.model for pipe in pipes],
)
for node in models.walk():
with_nvtx_range(
node, forward_color=forward_color, backprop_color=backprop_color
)
return nlp
return models_with_nvtx_range