mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-03 05:04:09 +03:00
TrainablePipe
: Add NVTX range decorator
This commit is contained in:
parent
126d1db123
commit
26536eb6b8
|
@ -1,7 +1,9 @@
|
|||
# cython: infer_types=True, profile=True
|
||||
import functools
|
||||
from typing import Iterable, Iterator, Optional, Dict, Tuple, Callable
|
||||
import srsly
|
||||
from thinc.api import set_dropout_rate, Model, Optimizer
|
||||
from thinc.util import use_nvtx_range
|
||||
|
||||
from ..tokens.doc cimport Doc
|
||||
|
||||
|
@ -14,6 +16,19 @@ from ..language import Language
|
|||
from ..training import Example
|
||||
|
||||
|
||||
def trainable_pipe_nvtx_range(func):
|
||||
"""Decorator meant to used on `TrainablePipe` methods to annotate
|
||||
them in nVidia Nsight profile using the NVTX API.
|
||||
"""
|
||||
|
||||
@functools.wraps(func)
|
||||
def inner(self, *args, **kwargs):
|
||||
with use_nvtx_range(f"{self.name} {func.__name__}"):
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return inner
|
||||
|
||||
|
||||
cdef class TrainablePipe(Pipe):
|
||||
"""This class is a base class and not instantiated directly. Trainable
|
||||
pipeline components like the EntityRecognizer or TextCategorizer inherit
|
||||
|
@ -55,6 +70,7 @@ cdef class TrainablePipe(Pipe):
|
|||
except Exception as e:
|
||||
error_handler(self.name, self, [doc], e)
|
||||
|
||||
@trainable_pipe_nvtx_range
|
||||
def pipe(self, stream: Iterable[Doc], *, batch_size: int=128) -> Iterator[Doc]:
|
||||
"""Apply the pipe to a stream of documents. This usually happens under
|
||||
the hood when the nlp object is called on a text and all components are
|
||||
|
@ -99,6 +115,7 @@ cdef class TrainablePipe(Pipe):
|
|||
"""
|
||||
raise NotImplementedError(Errors.E931.format(parent="TrainablePipe", method="set_annotations", name=self.name))
|
||||
|
||||
@trainable_pipe_nvtx_range
|
||||
def update(self,
|
||||
examples: Iterable["Example"],
|
||||
*,
|
||||
|
@ -135,6 +152,7 @@ cdef class TrainablePipe(Pipe):
|
|||
losses[self.name] += loss
|
||||
return losses
|
||||
|
||||
@trainable_pipe_nvtx_range
|
||||
def rehearse(self,
|
||||
examples: Iterable[Example],
|
||||
*,
|
||||
|
@ -240,6 +258,7 @@ cdef class TrainablePipe(Pipe):
|
|||
with self.model.use_params(params):
|
||||
yield
|
||||
|
||||
@trainable_pipe_nvtx_range
|
||||
def finish_update(self, sgd: Optimizer) -> None:
|
||||
"""Update parameters using the current parameter gradients.
|
||||
The Optimizer instance contains the functionality to perform
|
||||
|
|
Loading…
Reference in New Issue
Block a user