mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 20:22:25 +03:00
Merge pull request #10844 from polm/feature/coref-torch-guard
Add guards around torch import for coref
This commit is contained in:
commit
3807a1ba74
|
@ -1,5 +1,3 @@
|
||||||
from .coref import * #noqa
|
|
||||||
from .span_predictor import * #noqa
|
|
||||||
from .entity_linker import * # noqa
|
from .entity_linker import * # noqa
|
||||||
from .multi_task import * # noqa
|
from .multi_task import * # noqa
|
||||||
from .parser import * # noqa
|
from .parser import * # noqa
|
||||||
|
@ -7,3 +5,10 @@ from .spancat import * # noqa
|
||||||
from .tagger import * # noqa
|
from .tagger import * # noqa
|
||||||
from .textcat import * # noqa
|
from .textcat import * # noqa
|
||||||
from .tok2vec import * # noqa
|
from .tok2vec import * # noqa
|
||||||
|
|
||||||
|
# some models require Torch
|
||||||
|
from thinc.util import has_torch
|
||||||
|
if has_torch:
|
||||||
|
from .coref import * #noqa
|
||||||
|
from .span_predictor import * #noqa
|
||||||
|
|
||||||
|
|
|
@ -1,14 +1,12 @@
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
import torch
|
|
||||||
|
|
||||||
from thinc.api import Model, chain
|
from thinc.api import Model, chain
|
||||||
from thinc.api import PyTorchWrapper, ArgsKwargs
|
from thinc.api import PyTorchWrapper, ArgsKwargs
|
||||||
from thinc.types import Floats2d, Ints2d, Ints1d
|
from thinc.types import Floats2d, Ints2d, Ints1d
|
||||||
from thinc.util import xp2torch, torch2xp
|
from thinc.util import torch, xp2torch, torch2xp
|
||||||
|
|
||||||
from ...tokens import Doc
|
from ...tokens import Doc
|
||||||
from ...util import registry
|
from ...util import registry
|
||||||
from .coref_util import add_dummy
|
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures("spacy.Coref.v1")
|
@registry.architectures("spacy.Coref.v1")
|
||||||
|
@ -186,6 +184,23 @@ class CorefScorer(torch.nn.Module):
|
||||||
return coref_scores, top_indices
|
return coref_scores, top_indices
|
||||||
|
|
||||||
|
|
||||||
|
EPSILON = 1e-7
|
||||||
|
# Note this function is kept here to keep a torch dep out of coref_util.
|
||||||
|
def add_dummy(tensor: torch.Tensor, eps: bool = False):
|
||||||
|
"""Prepends zeros (or a very small value if eps is True)
|
||||||
|
to the first (not zeroth) dimension of tensor.
|
||||||
|
"""
|
||||||
|
kwargs = dict(device=tensor.device, dtype=tensor.dtype)
|
||||||
|
shape: List[int] = list(tensor.shape)
|
||||||
|
shape[1] = 1
|
||||||
|
if not eps:
|
||||||
|
dummy = torch.zeros(shape, **kwargs) # type: ignore
|
||||||
|
else:
|
||||||
|
dummy = torch.full(shape, EPSILON, **kwargs) # type: ignore
|
||||||
|
output = torch.cat((dummy, tensor), dim=1)
|
||||||
|
return output
|
||||||
|
|
||||||
|
|
||||||
class AnaphoricityScorer(torch.nn.Module):
|
class AnaphoricityScorer(torch.nn.Module):
|
||||||
"""Calculates anaphoricity scores by passing the inputs into a FFNN"""
|
"""Calculates anaphoricity scores by passing the inputs into a FFNN"""
|
||||||
|
|
||||||
|
|
|
@ -2,15 +2,12 @@ from thinc.types import Ints2d
|
||||||
from spacy.tokens import Doc
|
from spacy.tokens import Doc
|
||||||
from typing import List, Tuple, Callable, Any, Set, Dict
|
from typing import List, Tuple, Callable, Any, Set, Dict
|
||||||
from ...util import registry
|
from ...util import registry
|
||||||
import torch
|
|
||||||
|
|
||||||
# type alias to make writing this less tedious
|
# type alias to make writing this less tedious
|
||||||
MentionClusters = List[List[Tuple[int, int]]]
|
MentionClusters = List[List[Tuple[int, int]]]
|
||||||
|
|
||||||
DEFAULT_CLUSTER_PREFIX = "coref_clusters"
|
DEFAULT_CLUSTER_PREFIX = "coref_clusters"
|
||||||
|
|
||||||
EPSILON = 1e-7
|
|
||||||
|
|
||||||
class GraphNode:
|
class GraphNode:
|
||||||
def __init__(self, node_id: int):
|
def __init__(self, node_id: int):
|
||||||
self.id = node_id
|
self.id = node_id
|
||||||
|
@ -25,20 +22,6 @@ class GraphNode:
|
||||||
return str(self.id)
|
return str(self.id)
|
||||||
|
|
||||||
|
|
||||||
def add_dummy(tensor: torch.Tensor, eps: bool = False):
|
|
||||||
""" Prepends zeros (or a very small value if eps is True)
|
|
||||||
to the first (not zeroth) dimension of tensor.
|
|
||||||
"""
|
|
||||||
kwargs = dict(device=tensor.device, dtype=tensor.dtype)
|
|
||||||
shape: List[int] = list(tensor.shape)
|
|
||||||
shape[1] = 1
|
|
||||||
if not eps:
|
|
||||||
dummy = torch.zeros(shape, **kwargs) # type: ignore
|
|
||||||
else:
|
|
||||||
dummy = torch.full(shape, EPSILON, **kwargs) # type: ignore
|
|
||||||
output = torch.cat((dummy, tensor), dim=1)
|
|
||||||
return output
|
|
||||||
|
|
||||||
def get_sentence_ids(doc):
|
def get_sentence_ids(doc):
|
||||||
out = []
|
out = []
|
||||||
sent_id = -1
|
sent_id = -1
|
||||||
|
|
|
@ -1,10 +1,9 @@
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
import torch
|
|
||||||
|
|
||||||
from thinc.api import Model, chain, tuplify
|
from thinc.api import Model, chain, tuplify
|
||||||
from thinc.api import PyTorchWrapper, ArgsKwargs
|
from thinc.api import PyTorchWrapper, ArgsKwargs
|
||||||
from thinc.types import Floats2d, Ints1d
|
from thinc.types import Floats2d, Ints1d
|
||||||
from thinc.util import xp2torch, torch2xp
|
from thinc.util import torch, xp2torch, torch2xp
|
||||||
|
|
||||||
from ...tokens import Doc
|
from ...tokens import Doc
|
||||||
from ...util import registry
|
from ...util import registry
|
||||||
|
|
Loading…
Reference in New Issue
Block a user