From 9da16df96edb5ecb2dbe261b002a750d02efadf4 Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Tue, 24 May 2022 15:16:25 +0900 Subject: [PATCH 1/4] Add guards around torch import Torch is required for the coref/spanpred models but shouldn't be required for spaCy in general. The one tricky part of this is that one function in coref_util relied on torch, but that file was imported in several places. Since the function was only used in one place I moved it there. --- spacy/ml/models/__init__.py | 11 +++++++++-- spacy/ml/models/coref.py | 17 ++++++++++++++++- spacy/ml/models/coref_util.py | 15 --------------- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/spacy/ml/models/__init__.py b/spacy/ml/models/__init__.py index 9ae5b5104..b01721964 100644 --- a/spacy/ml/models/__init__.py +++ b/spacy/ml/models/__init__.py @@ -1,5 +1,3 @@ -from .coref import * #noqa -from .span_predictor import * #noqa from .entity_linker import * # noqa from .multi_task import * # noqa from .parser import * # noqa @@ -7,3 +5,12 @@ from .spancat import * # noqa from .tagger import * # noqa from .textcat import * # noqa from .tok2vec import * # noqa + +# some models require Torch +try: + import torch + from .coref import * #noqa + from .span_predictor import * #noqa +except ImportError: + pass + diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index 299abdc6b..0667053c6 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -8,7 +8,6 @@ from thinc.util import xp2torch, torch2xp from ...tokens import Doc from ...util import registry -from .coref_util import add_dummy @registry.architectures("spacy.Coref.v1") @@ -186,6 +185,22 @@ class CorefScorer(torch.nn.Module): return coref_scores, top_indices +# Note this function is kept here to keep a torch dep out of coref_util. +def add_dummy(tensor: torch.Tensor, eps: bool = False): + """Prepends zeros (or a very small value if eps is True) + to the first (not zeroth) dimension of tensor. + """ + kwargs = dict(device=tensor.device, dtype=tensor.dtype) + shape: List[int] = list(tensor.shape) + shape[1] = 1 + if not eps: + dummy = torch.zeros(shape, **kwargs) # type: ignore + else: + dummy = torch.full(shape, EPSILON, **kwargs) # type: ignore + output = torch.cat((dummy, tensor), dim=1) + return output + + class AnaphoricityScorer(torch.nn.Module): """Calculates anaphoricity scores by passing the inputs into a FFNN""" diff --git a/spacy/ml/models/coref_util.py b/spacy/ml/models/coref_util.py index e8de1e0ac..05f83189a 100644 --- a/spacy/ml/models/coref_util.py +++ b/spacy/ml/models/coref_util.py @@ -2,7 +2,6 @@ from thinc.types import Ints2d from spacy.tokens import Doc from typing import List, Tuple, Callable, Any, Set, Dict from ...util import registry -import torch # type alias to make writing this less tedious MentionClusters = List[List[Tuple[int, int]]] @@ -25,20 +24,6 @@ class GraphNode: return str(self.id) -def add_dummy(tensor: torch.Tensor, eps: bool = False): - """ Prepends zeros (or a very small value if eps is True) - to the first (not zeroth) dimension of tensor. - """ - kwargs = dict(device=tensor.device, dtype=tensor.dtype) - shape: List[int] = list(tensor.shape) - shape[1] = 1 - if not eps: - dummy = torch.zeros(shape, **kwargs) # type: ignore - else: - dummy = torch.full(shape, EPSILON, **kwargs) # type: ignore - output = torch.cat((dummy, tensor), dim=1) - return output - def get_sentence_ids(doc): out = [] sent_id = -1 From b1118cee584d8a1a4eb40dfd6d9660388807cf12 Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Tue, 24 May 2022 15:59:08 +0900 Subject: [PATCH 2/4] Move epsilon --- spacy/ml/models/coref.py | 1 + spacy/ml/models/coref_util.py | 2 -- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index 0667053c6..b4d8030e8 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -185,6 +185,7 @@ class CorefScorer(torch.nn.Module): return coref_scores, top_indices +EPSILON = 1e-7 # Note this function is kept here to keep a torch dep out of coref_util. def add_dummy(tensor: torch.Tensor, eps: bool = False): """Prepends zeros (or a very small value if eps is True) diff --git a/spacy/ml/models/coref_util.py b/spacy/ml/models/coref_util.py index 05f83189a..86dd0df4b 100644 --- a/spacy/ml/models/coref_util.py +++ b/spacy/ml/models/coref_util.py @@ -8,8 +8,6 @@ MentionClusters = List[List[Tuple[int, int]]] DEFAULT_CLUSTER_PREFIX = "coref_clusters" -EPSILON = 1e-7 - class GraphNode: def __init__(self, node_id: int): self.id = node_id From 5cbc9f4573686857cd5b2cfb8cc38fb51ff2a5a2 Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Tue, 24 May 2022 16:02:39 +0900 Subject: [PATCH 3/4] Use thinc.util.has_torch --- spacy/ml/models/__init__.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/spacy/ml/models/__init__.py b/spacy/ml/models/__init__.py index b01721964..4368a556d 100644 --- a/spacy/ml/models/__init__.py +++ b/spacy/ml/models/__init__.py @@ -7,10 +7,8 @@ from .textcat import * # noqa from .tok2vec import * # noqa # some models require Torch -try: - import torch +from thinc.util import has_torch +if has_torch: from .coref import * #noqa from .span_predictor import * #noqa -except ImportError: - pass From c9233a5a1f34b2e8397d40abfde283412bc79b6e Mon Sep 17 00:00:00 2001 From: Paul O'Leary McCann Date: Tue, 24 May 2022 17:28:27 +0900 Subject: [PATCH 4/4] Import torch from thinc --- spacy/ml/models/coref.py | 3 +-- spacy/ml/models/span_predictor.py | 3 +-- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index b4d8030e8..ca9011577 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -1,10 +1,9 @@ from typing import List, Tuple -import torch from thinc.api import Model, chain from thinc.api import PyTorchWrapper, ArgsKwargs from thinc.types import Floats2d, Ints2d, Ints1d -from thinc.util import xp2torch, torch2xp +from thinc.util import torch, xp2torch, torch2xp from ...tokens import Doc from ...util import registry diff --git a/spacy/ml/models/span_predictor.py b/spacy/ml/models/span_predictor.py index 7375c2153..1ded9c3c7 100644 --- a/spacy/ml/models/span_predictor.py +++ b/spacy/ml/models/span_predictor.py @@ -1,10 +1,9 @@ from typing import List, Tuple -import torch from thinc.api import Model, chain, tuplify from thinc.api import PyTorchWrapper, ArgsKwargs from thinc.types import Floats2d, Ints1d -from thinc.util import xp2torch, torch2xp +from thinc.util import torch, xp2torch, torch2xp from ...tokens import Doc from ...util import registry