Add guards around torch import

Torch is required for the coref/spanpred models but shouldn't be
required for spaCy in general.

The one tricky part of this is that one function in coref_util (add_dummy)
relied on torch, but that file was imported in several places. Since the
function was only used in one place, I moved it there.
This commit is contained in:
Paul O'Leary McCann 2022-05-24 15:16:25 +09:00
parent e38e84a677
commit 9da16df96e
3 changed files with 25 additions and 18 deletions

View File

@ -1,5 +1,3 @@
from .coref import * #noqa
from .span_predictor import * #noqa
from .entity_linker import * # noqa from .entity_linker import * # noqa
from .multi_task import * # noqa from .multi_task import * # noqa
from .parser import * # noqa from .parser import * # noqa
@ -7,3 +5,12 @@ from .spancat import * # noqa
from .tagger import * # noqa from .tagger import * # noqa
from .textcat import * # noqa from .textcat import * # noqa
from .tok2vec import * # noqa from .tok2vec import * # noqa
# some models require Torch
try:
import torch
from .coref import * #noqa
from .span_predictor import * #noqa
except ImportError:
pass

View File

@ -8,7 +8,6 @@ from thinc.util import xp2torch, torch2xp
from ...tokens import Doc from ...tokens import Doc
from ...util import registry from ...util import registry
from .coref_util import add_dummy
@registry.architectures("spacy.Coref.v1") @registry.architectures("spacy.Coref.v1")
@ -186,6 +185,22 @@ class CorefScorer(torch.nn.Module):
return coref_scores, top_indices return coref_scores, top_indices
# Note this function is kept here to keep a torch dep out of coref_util.
def add_dummy(tensor: torch.Tensor, eps: bool = False) -> torch.Tensor:
    """Prepend one dummy slice along dimension 1 (not dimension 0) of `tensor`.

    The dummy slice is filled with zeros, or with EPSILON when `eps` is True.
    EPSILON is looked up as a global at call time; it is not defined in this
    function. The slice shares `tensor`'s dtype and device and matches it in
    every dimension other than dimension 1.

    tensor: Input tensor; it must have at least two dimensions.
    eps: If True, fill the dummy slice with EPSILON instead of zeros.
    Returns a new tensor whose size along dimension 1 is one larger than
    the input's.
    Raises IndexError if `tensor` has fewer than two dimensions.
    """
    shape = list(tensor.shape)
    # The dummy has exactly one entry along dimension 1.
    shape[1] = 1
    # new_zeros / new_full inherit dtype and device from `tensor`, so no
    # explicit keyword arguments (and no type-checker suppressions) are needed.
    if eps:
        dummy = tensor.new_full(shape, EPSILON)
    else:
        dummy = tensor.new_zeros(shape)
    return torch.cat((dummy, tensor), dim=1)
class AnaphoricityScorer(torch.nn.Module): class AnaphoricityScorer(torch.nn.Module):
"""Calculates anaphoricity scores by passing the inputs into a FFNN""" """Calculates anaphoricity scores by passing the inputs into a FFNN"""

View File

@ -2,7 +2,6 @@ from thinc.types import Ints2d
from spacy.tokens import Doc from spacy.tokens import Doc
from typing import List, Tuple, Callable, Any, Set, Dict from typing import List, Tuple, Callable, Any, Set, Dict
from ...util import registry from ...util import registry
import torch
# type alias to make writing this less tedious # type alias to make writing this less tedious
MentionClusters = List[List[Tuple[int, int]]] MentionClusters = List[List[Tuple[int, int]]]
@ -25,20 +24,6 @@ class GraphNode:
return str(self.id) return str(self.id)
def add_dummy(tensor: torch.Tensor, eps: bool = False) -> torch.Tensor:
    """Prepend zeros (or EPSILON if `eps` is True) along dimension 1.

    A slice of size one is concatenated in front of `tensor` along the first
    (not zeroth) dimension. The slice has the same dtype and device as
    `tensor`. EPSILON is resolved as a global when `eps` is True; it is not
    defined inside this function.

    tensor: Input tensor with at least two dimensions.
    eps: Use EPSILON rather than zero as the fill value.
    Returns the concatenated tensor; dimension 1 grows by one.
    Raises IndexError if `tensor` has fewer than two dimensions.
    """
    dummy_shape = list(tensor.shape)
    dummy_shape[1] = 1
    # Tensor.new_* constructors copy dtype/device from `tensor`, which avoids
    # building a kwargs dict and silencing the type checker.
    if eps:
        dummy = tensor.new_full(dummy_shape, EPSILON)
    else:
        dummy = tensor.new_zeros(dummy_shape)
    return torch.cat((dummy, tensor), dim=1)
def get_sentence_ids(doc): def get_sentence_ids(doc):
out = [] out = []
sent_id = -1 sent_id = -1