Merge pull request #10844 from polm/feature/coref-torch-guard

Add guards around torch import for coref
This commit is contained in:
Paul O'Leary McCann 2022-05-25 13:50:46 +09:00 committed by GitHub
commit 3807a1ba74
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 26 additions and 24 deletions

View File

@ -1,5 +1,3 @@
from .coref import * #noqa
from .span_predictor import * #noqa
from .entity_linker import * # noqa from .entity_linker import * # noqa
from .multi_task import * # noqa from .multi_task import * # noqa
from .parser import * # noqa from .parser import * # noqa
@ -7,3 +5,10 @@ from .spancat import * # noqa
from .tagger import * # noqa from .tagger import * # noqa
from .textcat import * # noqa from .textcat import * # noqa
from .tok2vec import * # noqa from .tok2vec import * # noqa
# some models require Torch
from thinc.util import has_torch
if has_torch:
from .coref import * #noqa
from .span_predictor import * #noqa

View File

@ -1,14 +1,12 @@
from typing import List, Tuple from typing import List, Tuple
import torch
from thinc.api import Model, chain from thinc.api import Model, chain
from thinc.api import PyTorchWrapper, ArgsKwargs from thinc.api import PyTorchWrapper, ArgsKwargs
from thinc.types import Floats2d, Ints2d, Ints1d from thinc.types import Floats2d, Ints2d, Ints1d
from thinc.util import xp2torch, torch2xp from thinc.util import torch, xp2torch, torch2xp
from ...tokens import Doc from ...tokens import Doc
from ...util import registry from ...util import registry
from .coref_util import add_dummy
@registry.architectures("spacy.Coref.v1") @registry.architectures("spacy.Coref.v1")
@ -186,6 +184,23 @@ class CorefScorer(torch.nn.Module):
return coref_scores, top_indices return coref_scores, top_indices
EPSILON = 1e-7
# Note this function is kept here to keep a torch dep out of coref_util.
def add_dummy(tensor: torch.Tensor, eps: bool = False):
"""Prepends zeros (or a very small value if eps is True)
to the first (not zeroth) dimension of tensor.
"""
kwargs = dict(device=tensor.device, dtype=tensor.dtype)
shape: List[int] = list(tensor.shape)
shape[1] = 1
if not eps:
dummy = torch.zeros(shape, **kwargs) # type: ignore
else:
dummy = torch.full(shape, EPSILON, **kwargs) # type: ignore
output = torch.cat((dummy, tensor), dim=1)
return output
class AnaphoricityScorer(torch.nn.Module): class AnaphoricityScorer(torch.nn.Module):
"""Calculates anaphoricity scores by passing the inputs into a FFNN""" """Calculates anaphoricity scores by passing the inputs into a FFNN"""

View File

@ -2,15 +2,12 @@ from thinc.types import Ints2d
from spacy.tokens import Doc from spacy.tokens import Doc
from typing import List, Tuple, Callable, Any, Set, Dict from typing import List, Tuple, Callable, Any, Set, Dict
from ...util import registry from ...util import registry
import torch
# type alias to make writing this less tedious # type alias to make writing this less tedious
MentionClusters = List[List[Tuple[int, int]]] MentionClusters = List[List[Tuple[int, int]]]
DEFAULT_CLUSTER_PREFIX = "coref_clusters" DEFAULT_CLUSTER_PREFIX = "coref_clusters"
EPSILON = 1e-7
class GraphNode: class GraphNode:
def __init__(self, node_id: int): def __init__(self, node_id: int):
self.id = node_id self.id = node_id
@ -25,20 +22,6 @@ class GraphNode:
return str(self.id) return str(self.id)
def add_dummy(tensor: torch.Tensor, eps: bool = False):
""" Prepends zeros (or a very small value if eps is True)
to the first (not zeroth) dimension of tensor.
"""
kwargs = dict(device=tensor.device, dtype=tensor.dtype)
shape: List[int] = list(tensor.shape)
shape[1] = 1
if not eps:
dummy = torch.zeros(shape, **kwargs) # type: ignore
else:
dummy = torch.full(shape, EPSILON, **kwargs) # type: ignore
output = torch.cat((dummy, tensor), dim=1)
return output
def get_sentence_ids(doc): def get_sentence_ids(doc):
out = [] out = []
sent_id = -1 sent_id = -1

View File

@ -1,10 +1,9 @@
from typing import List, Tuple from typing import List, Tuple
import torch
from thinc.api import Model, chain, tuplify from thinc.api import Model, chain, tuplify
from thinc.api import PyTorchWrapper, ArgsKwargs from thinc.api import PyTorchWrapper, ArgsKwargs
from thinc.types import Floats2d, Ints1d from thinc.types import Floats2d, Ints1d
from thinc.util import xp2torch, torch2xp from thinc.util import torch, xp2torch, torch2xp
from ...tokens import Doc from ...tokens import Doc
from ...util import registry from ...util import registry