diff --git a/spacy/ml/models/__init__.py b/spacy/ml/models/__init__.py index 9ae5b5104..4368a556d 100644 --- a/spacy/ml/models/__init__.py +++ b/spacy/ml/models/__init__.py @@ -1,5 +1,3 @@ -from .coref import * #noqa -from .span_predictor import * #noqa from .entity_linker import * # noqa from .multi_task import * # noqa from .parser import * # noqa @@ -7,3 +5,10 @@ from .spancat import * # noqa from .tagger import * # noqa from .textcat import * # noqa from .tok2vec import * # noqa + +# some models require Torch +from thinc.util import has_torch +if has_torch: + from .coref import * #noqa + from .span_predictor import * #noqa + diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index 299abdc6b..ca9011577 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -1,14 +1,12 @@ from typing import List, Tuple -import torch from thinc.api import Model, chain from thinc.api import PyTorchWrapper, ArgsKwargs from thinc.types import Floats2d, Ints2d, Ints1d -from thinc.util import xp2torch, torch2xp +from thinc.util import torch, xp2torch, torch2xp from ...tokens import Doc from ...util import registry -from .coref_util import add_dummy @registry.architectures("spacy.Coref.v1") @@ -186,6 +184,23 @@ class CorefScorer(torch.nn.Module): return coref_scores, top_indices +EPSILON = 1e-7 +# Note this function is kept here to keep a torch dep out of coref_util. +def add_dummy(tensor: torch.Tensor, eps: bool = False): + """Prepends zeros (or a very small value if eps is True) + to the first (not zeroth) dimension of tensor. + """ + kwargs = dict(device=tensor.device, dtype=tensor.dtype) + shape: List[int] = list(tensor.shape) + shape[1] = 1 + if not eps: + dummy = torch.zeros(shape, **kwargs) # type: ignore + else: + dummy = torch.full(shape, EPSILON, **kwargs) # type: ignore + output = torch.cat((dummy, tensor), dim=1) + return output + + class AnaphoricityScorer(torch.nn.Module): """Calculates anaphoricity scores by passing the inputs into a FFNN""" diff --git a/spacy/ml/models/coref_util.py b/spacy/ml/models/coref_util.py index e8de1e0ac..86dd0df4b 100644 --- a/spacy/ml/models/coref_util.py +++ b/spacy/ml/models/coref_util.py @@ -2,15 +2,12 @@ from thinc.types import Ints2d from spacy.tokens import Doc from typing import List, Tuple, Callable, Any, Set, Dict from ...util import registry -import torch # type alias to make writing this less tedious MentionClusters = List[List[Tuple[int, int]]] DEFAULT_CLUSTER_PREFIX = "coref_clusters" -EPSILON = 1e-7 - class GraphNode: def __init__(self, node_id: int): self.id = node_id @@ -25,20 +22,6 @@ class GraphNode: return str(self.id) -def add_dummy(tensor: torch.Tensor, eps: bool = False): - """ Prepends zeros (or a very small value if eps is True) - to the first (not zeroth) dimension of tensor. - """ - kwargs = dict(device=tensor.device, dtype=tensor.dtype) - shape: List[int] = list(tensor.shape) - shape[1] = 1 - if not eps: - dummy = torch.zeros(shape, **kwargs) # type: ignore - else: - dummy = torch.full(shape, EPSILON, **kwargs) # type: ignore - output = torch.cat((dummy, tensor), dim=1) - return output - def get_sentence_ids(doc): out = [] sent_id = -1 diff --git a/spacy/ml/models/span_predictor.py b/spacy/ml/models/span_predictor.py index 7375c2153..1ded9c3c7 100644 --- a/spacy/ml/models/span_predictor.py +++ b/spacy/ml/models/span_predictor.py @@ -1,10 +1,9 @@ from typing import List, Tuple -import torch from thinc.api import Model, chain, tuplify from thinc.api import PyTorchWrapper, ArgsKwargs from thinc.types import Floats2d, Ints1d -from thinc.util import xp2torch, torch2xp +from thinc.util import torch, xp2torch, torch2xp from ...tokens import Doc from ...util import registry