Clean up util code

Moved everything into coref_util.py, deleted wl-specific file.

parent 55039a66ad
commit abdc7d87af
@@ -458,7 +458,7 @@ import torch
 from thinc.util import xp2torch, torch2xp
 
 # TODO rename this to coref_util
-from .coref_util_wl import add_dummy
+from .coref_util import add_dummy
 
 # TODO rename to plain coref
 @registry.architectures("spacy.WLCoref.v1")
@@ -1,13 +1,43 @@
 from thinc.types import Ints2d
 from spacy.tokens import Doc
-from typing import List, Tuple, Callable, Any
+from typing import List, Tuple, Callable, Any, Set, Dict
 from ...util import registry
+import torch
 
 # type alias to make writing this less tedious
 MentionClusters = List[List[Tuple[int, int]]]
 
 DEFAULT_CLUSTER_PREFIX = "coref_clusters"
 
+EPSILON = 1e-7
+
+class GraphNode:
+    def __init__(self, node_id: int):
+        self.id = node_id
+        self.links: Set[GraphNode] = set()
+        self.visited = False
+
+    def link(self, another: "GraphNode"):
+        self.links.add(another)
+        another.links.add(self)
+
+    def __repr__(self) -> str:
+        return str(self.id)
+
+
+def add_dummy(tensor: torch.Tensor, eps: bool = False):
+    """ Prepends zeros (or a very small value if eps is True)
+    to the first (not zeroth) dimension of tensor.
+    """
+    kwargs = dict(device=tensor.device, dtype=tensor.dtype)
+    shape: List[int] = list(tensor.shape)
+    shape[1] = 1
+    if not eps:
+        dummy = torch.zeros(shape, **kwargs)  # type: ignore
+    else:
+        dummy = torch.full(shape, EPSILON, **kwargs)  # type: ignore
+    output = torch.cat((dummy, tensor), dim=1)
+    return output
 
 def doc2clusters(doc: Doc, prefix=DEFAULT_CLUSTER_PREFIX) -> MentionClusters:
     """Given a doc, give the mention clusters.
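The GraphNode and add_dummy helpers moved into the merged coref_util.py behave as their code above describes; a minimal usage sketch follows (the import path, shapes, and values here are assumptions for illustration, not part of the commit):

import torch
from spacy.ml.models.coref_util import GraphNode, add_dummy  # assumed module path

# GraphNode.link is symmetric: each node ends up in the other's links set
a, b = GraphNode(0), GraphNode(1)
a.link(b)
assert a in b.links and b in a.links

# add_dummy prepends one column along dim 1: zeros by default, EPSILON when eps=True
scores = torch.randn(5, 3)                 # made-up shape, e.g. 5 mentions x 3 antecedents
assert add_dummy(scores).shape == (5, 4)
assert torch.all(add_dummy(scores, eps=True)[:, 0] == 1e-7)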
@@ -1,134 +0,0 @@
-""" Contains functions not directly linked to coreference resolution """
-
-from typing import List, Set, Dict, Tuple
-from thinc.types import Ints1d
-from dataclasses import dataclass
-from ...tokens import Doc
-from ...language import Language
-
-import torch
-
-EPSILON = 1e-7
-
-class GraphNode:
-    def __init__(self, node_id: int):
-        self.id = node_id
-        self.links: Set[GraphNode] = set()
-        self.visited = False
-
-    def link(self, another: "GraphNode"):
-        self.links.add(another)
-        another.links.add(self)
-
-    def __repr__(self) -> str:
-        return str(self.id)
-
-
-def add_dummy(tensor: torch.Tensor, eps: bool = False):
-    """ Prepends zeros (or a very small value if eps is True)
-    to the first (not zeroth) dimension of tensor.
-    """
-    kwargs = dict(device=tensor.device, dtype=tensor.dtype)
-    shape: List[int] = list(tensor.shape)
-    shape[1] = 1
-    if not eps:
-        dummy = torch.zeros(shape, **kwargs)  # type: ignore
-    else:
-        dummy = torch.full(shape, EPSILON, **kwargs)  # type: ignore
-    output = torch.cat((dummy, tensor), dim=1)
-    return output
-
-
-
-# TODO replace with spaCy config
-@dataclass
-class CorefConfig:  # pylint: disable=too-many-instance-attributes, too-few-public-methods
-    """ Contains values needed to set up the coreference model. """
-    section: str
-
-    data_dir: str
-
-    train_data: str
-    dev_data: str
-    test_data: str
-
-    device: str
-
-    bert_model: str
-    bert_window_size: int
-
-    embedding_size: int
-    sp_embedding_size: int
-    a_scoring_batch_size: int
-    hidden_size: int
-    n_hidden_layers: int
-
-    max_span_len: int
-
-    rough_k: int
-
-    bert_finetune: bool
-    bert_mini_finetune: bool
-    dropout_rate: float
-    learning_rate: float
-    bert_learning_rate: float
-    train_epochs: int
-    bce_loss_weight: float
-
-    tokenizer_kwargs: Dict[str, dict]
-    conll_log_dir: str
-
-
-def get_sent_ids(doc):
-    sid = 0
-    sids = []
-    for sent in doc.sents:
-        for tok in sent:
-            sids.append(sid)
-        sid += 1
-    return sids
-
-def get_cluster_ids(doc):
-    """Get the cluster ids of head tokens."""
-
-    out = [0] * len(doc)
-    head_spangroups = [doc.spans[sk] for sk in doc.spans if sk.startswith("coref_word_clusters")]
-    for ii, group in enumerate(head_spangroups, start=1):
-        for span in group:
-            out[span[0].i] = ii
-
-    return out
-
-def get_head2span(doc):
-    out = []
-    for sk in doc.spans:
-        if not sk.startswith("coref_clusters"):
-            continue
-
-        if len(doc.spans[sk]) == 1:
-            print("===== UNARY MENTION ====")
-
-        for span in doc.spans[sk]:
-            out.append( (span.root.i, span.start, span.end) )
-    return out
-
-
-def doc2tensors(
-    xp,
-    doc: Doc
-) -> Tuple[Ints1d, Ints1d, Ints1d, Ints1d, Ints1d]:
-    sent_ids = get_sent_ids(doc)
-    cluster_ids = get_cluster_ids(doc)
-    head2span = get_head2span(doc)
-
-
-    if not head2span:
-        heads, starts, ends = [], [], []
-    else:
-        heads, starts, ends = zip(*head2span)
-    sent_ids = xp.asarray(sent_ids)
-    cluster_ids = xp.asarray(cluster_ids)
-    heads = xp.asarray(heads)
-    starts = xp.asarray(starts)
-    ends = xp.asarray(ends) - 1
-    return sent_ids, cluster_ids, heads, starts, ends
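For reference, the removed get_cluster_ids simply tags each cluster head token with a 1-based cluster id (0 for tokens that head no cluster). A self-contained sketch of that loop using plain lists instead of a spaCy Doc (the token indices and cluster groupings below are invented):

# stand-in for the "coref_word_clusters" span groups: head-token indices per cluster
doc_len = 6
head_clusters = [[0, 4], [2]]    # invented: cluster 1 heads at tokens 0 and 4, cluster 2 at token 2

out = [0] * doc_len              # 0 means the token heads no cluster
for ii, group in enumerate(head_clusters, start=1):
    for head in group:
        out[head] = ii

print(out)  # [1, 0, 2, 0, 1, 0]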