Add guards around torch import

Torch is required for the coref/spanpred models but shouldn't be
required for spaCy in general.

The one tricky part of this is that one function in coref_util relied on
torch, but that file was imported in several places. Since the function
was only used in one place I moved it there.
This commit is contained in:
Paul O'Leary McCann 2022-05-24 15:16:25 +09:00
parent e38e84a677
commit 9da16df96e
3 changed files with 25 additions and 18 deletions

View File

@ -1,5 +1,3 @@
from .coref import * #noqa
from .span_predictor import * #noqa
from .entity_linker import * # noqa
from .multi_task import * # noqa
from .parser import * # noqa
@ -7,3 +5,12 @@ from .spancat import * # noqa
from .tagger import * # noqa
from .textcat import * # noqa
from .tok2vec import * # noqa
# some models require Torch
try:
import torch
from .coref import * #noqa
from .span_predictor import * #noqa
except ImportError:
pass

View File

@ -8,7 +8,6 @@ from thinc.util import xp2torch, torch2xp
from ...tokens import Doc
from ...util import registry
from .coref_util import add_dummy
@registry.architectures("spacy.Coref.v1")
@ -186,6 +185,22 @@ class CorefScorer(torch.nn.Module):
return coref_scores, top_indices
# Note this function is kept here to keep a torch dep out of coref_util.
def add_dummy(tensor: torch.Tensor, eps: bool = False):
"""Prepends zeros (or a very small value if eps is True)
to the first (not zeroth) dimension of tensor.
"""
kwargs = dict(device=tensor.device, dtype=tensor.dtype)
shape: List[int] = list(tensor.shape)
shape[1] = 1
if not eps:
dummy = torch.zeros(shape, **kwargs) # type: ignore
else:
dummy = torch.full(shape, EPSILON, **kwargs) # type: ignore
output = torch.cat((dummy, tensor), dim=1)
return output
class AnaphoricityScorer(torch.nn.Module):
"""Calculates anaphoricity scores by passing the inputs into a FFNN"""

View File

@ -2,7 +2,6 @@ from thinc.types import Ints2d
from spacy.tokens import Doc
from typing import List, Tuple, Callable, Any, Set, Dict
from ...util import registry
import torch
# type alias to make writing this less tedious
MentionClusters = List[List[Tuple[int, int]]]
@ -25,20 +24,6 @@ class GraphNode:
return str(self.id)
def add_dummy(tensor: torch.Tensor, eps: bool = False):
""" Prepends zeros (or a very small value if eps is True)
to the first (not zeroth) dimension of tensor.
"""
kwargs = dict(device=tensor.device, dtype=tensor.dtype)
shape: List[int] = list(tensor.shape)
shape[1] = 1
if not eps:
dummy = torch.zeros(shape, **kwargs) # type: ignore
else:
dummy = torch.full(shape, EPSILON, **kwargs) # type: ignore
output = torch.cat((dummy, tensor), dim=1)
return output
def get_sentence_ids(doc):
out = []
sent_id = -1