from typing import List, Tuple, Callable, cast

from thinc.api import Model, chain, get_width
from thinc.api import PyTorchWrapper, ArgsKwargs
from thinc.types import Floats2d, Ints2d
from thinc.util import torch, xp2torch, torch2xp

from ...tokens import Doc
from ...util import registry


EPSILON = 1e-7
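
# The architecture registered below is normally referenced from the pipeline
# config rather than called directly. A rough sketch of such a config block
# (illustrative only; the [model.tok2vec] section shown is an assumption and
# would be filled in with a real tok2vec architecture):
#
#     [model]
#     @architectures = "spacy.Coref.v1"
#     distance_embedding_size = 20
#     hidden_size = 1024
#     depth = 1
#     dropout = 0.3
#     antecedent_limit = 50
#     antecedent_batch_size = 512
#
#     [model.tok2vec]
#     @architectures = "spacy.Tok2Vec.v2"
#     # ... embed/encode sub-blocks go here ...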

@registry.architectures("spacy.Coref.v1")
def build_wl_coref_model(
    tok2vec: Model[List[Doc], List[Floats2d]],
    distance_embedding_size: int = 20,
    hidden_size: int = 1024,
    depth: int = 1,
    dropout: float = 0.3,
    # pairs to keep per mention after rough scoring
    antecedent_limit: int = 50,
    antecedent_batch_size: int = 512,
    nI=None,
) -> Model[List[Doc], Tuple[Floats2d, Ints2d]]:

    with Model.define_operators({">>": chain}):
        coref_clusterer: Model[List[Floats2d], Tuple[Floats2d, Ints2d]] = Model(
            "coref_clusterer",
            forward=coref_forward,
            init=coref_init,
            dims={"nI": nI},
            attrs={
                "distance_embedding_size": distance_embedding_size,
                "hidden_size": hidden_size,
                "depth": depth,
                "dropout": dropout,
                "antecedent_limit": antecedent_limit,
                "antecedent_batch_size": antecedent_batch_size,
            },
        )

        model = tok2vec >> coref_clusterer
        model.set_ref("coref_clusterer", coref_clusterer)
    return model
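
# Usage sketch (illustrative, not from this module): assumes `my_tok2vec` is
# any Thinc model mapping List[Doc] -> List[Floats2d] and `nlp` is a loaded
# pipeline providing make_doc:
#
#     doc = nlp.make_doc("Sarah told Ann she would be late.")
#     model = build_wl_coref_model(tok2vec=my_tok2vec)
#     model.initialize(X=[doc])
#     (scores, indices), backprop = model([doc], is_train=False)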


def coref_init(model: Model, X=None, Y=None):
    if model.layers:
        return

    if X is not None and model.has_dim("nI") is None:
        model.set_dim("nI", get_width(X))

    hidden_size = model.attrs["hidden_size"]
    depth = model.attrs["depth"]
    dropout = model.attrs["dropout"]
    antecedent_limit = model.attrs["antecedent_limit"]
    antecedent_batch_size = model.attrs["antecedent_batch_size"]
    distance_embedding_size = model.attrs["distance_embedding_size"]

    model._layers = [
        PyTorchWrapper(
            CorefClusterer(
                model.get_dim("nI"),
                distance_embedding_size,
                hidden_size,
                depth,
                dropout,
                antecedent_limit,
                antecedent_batch_size,
            ),
            convert_inputs=convert_coref_clusterer_inputs,
            convert_outputs=convert_coref_clusterer_outputs,
        )
        # TODO maybe we need mixed precision and grad scaling?
    ]
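
# The PyTorch module above is constructed lazily: coref_init only runs during
# model.initialize(), after "nI" (the tok2vec output width) is either passed
# explicitly or inferred from the sample batch X via get_width.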


def coref_forward(model: Model, X, is_train: bool):
    return model.layers[0](X, is_train)


def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
    # The input here is List[Floats2d], one for each doc
    # just use the first
    # TODO real batching
    X = X[0]
    word_features = xp2torch(X, requires_grad=is_train)

    def backprop(args: ArgsKwargs) -> List[Floats2d]:
        # convert to xp and wrap in list
        gradients = cast(Floats2d, torch2xp(args.args[0]))
        return [gradients]

    return ArgsKwargs(args=(word_features,), kwargs={}), backprop
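
# Shape note: for a Doc of n_words tokens, `word_features` is an
# (n_words, nI) torch tensor. The backprop callback converts the gradient of
# that tensor back to the array library and wraps it in a one-element list,
# mirroring the List[Floats2d] input.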


def convert_coref_clusterer_outputs(
    model: Model, inputs_outputs, is_train: bool
) -> Tuple[Tuple[Floats2d, Ints2d], Callable]:
    _, outputs = inputs_outputs
    scores, indices = outputs

    def convert_for_torch_backward(dY: Floats2d) -> ArgsKwargs:
        dY_t = xp2torch(dY[0])
        return ArgsKwargs(
            args=([scores],),
            kwargs={"grad_tensors": [dY_t]},
        )

    scores_xp = cast(Floats2d, torch2xp(scores))
    indices_xp = cast(Ints2d, torch2xp(indices))
    return (scores_xp, indices_xp), convert_for_torch_backward
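
# Gradient flow sketch: only the scores carry gradients (the indices are
# integer outputs), so the backward conversion pairs the torch `scores`
# tensor with the incoming gradient dY[0] as its grad_tensors for the
# wrapped backward pass.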


class CorefClusterer(torch.nn.Module):
    """
    Combines all coref modules together to find coreferent token pairs.
    Submodules (in the order of their usage in the pipeline):
    - rough_scorer (RoughScorer) that prunes candidate pairs
    - pw (DistancePairwiseEncoder) that computes pairwise features
    - a_scorer (AnaphoricityScorer) that produces the final scores
    """

    def __init__(
        self,
        dim: int,
        dist_emb_size: int,
        hidden_size: int,
        n_layers: int,
        dropout: float,
        roughk: int,
        batch_size: int,
    ):
        """
        dim: Size of the input features.
        dist_emb_size: Size of the distance embeddings.
        hidden_size: Size of the coreference candidate embeddings.
        n_layers: Number of layers in the AnaphoricityScorer.
        dropout: Dropout probability to apply across all modules.
        roughk: Number of candidates the RoughScorer returns.
        batch_size: Internal batch size for the more expensive scorer.
        """
        super().__init__()
        self.dropout = torch.nn.Dropout(dropout)
        self.batch_size = batch_size
        self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)

        pair_emb = dim * 3 + self.pw.shape
        self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
        self.lstm = torch.nn.LSTM(
            input_size=dim,
            hidden_size=dim,
            batch_first=True,
        )

        self.rough_scorer = RoughScorer(dim, dropout, roughk)

    def forward(self, word_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        1. LSTM encodes the incoming word_features.
        2. The RoughScorer scores and prunes the candidates.
        3. The DistancePairwiseEncoder embeds the distances between pairs.
        4. The AnaphoricityScorer scores all pairs in mini-batches.

        word_features: torch.Tensor containing word encodings

        returns:
            coref_scores: n_words x (roughk + 1) floats (the extra column is
                the dummy "no antecedent" score).
            top_indices: n_words x roughk integers.
        """
        self.lstm.flatten_parameters()  # XXX without this there's a warning
        word_features = torch.unsqueeze(word_features, dim=0)
        words, _ = self.lstm(word_features)
        words = words.squeeze()
        # words: n_words x dim
        words = self.dropout(words)
        # Obtain bilinear scores and leave only top-k antecedents for each word
        # top_rough_scores: (n_words x roughk)
        # top_indices: (n_words x roughk)
        top_rough_scores, top_indices = self.rough_scorer(words)
        # Get pairwise features
        # (n_words x roughk x n_pw_features)
        pw = self.pw(top_indices)
        batch_size = self.batch_size
        a_scores_lst: List[torch.Tensor] = []

        for i in range(0, len(words), batch_size):
            pw_batch = pw[i : i + batch_size]
            words_batch = words[i : i + batch_size]
            top_indices_batch = top_indices[i : i + batch_size]
            top_rough_scores_batch = top_rough_scores[i : i + batch_size]

            # a_scores_batch [batch_size, n_ants]
            a_scores_batch = self.a_scorer(
                all_mentions=words,
                mentions_batch=words_batch,
                pw_batch=pw_batch,
                top_indices_batch=top_indices_batch,
                top_rough_scores_batch=top_rough_scores_batch,
            )
            a_scores_lst.append(a_scores_batch)

        coref_scores = torch.cat(a_scores_lst, dim=0)
        return coref_scores, top_indices
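
    # Batching note: the anaphoricity scorer materialises a
    # (batch_size x roughk x pair_emb) pair tensor per step, so looping over
    # words in chunks of `batch_size` keeps peak memory bounded for long docs.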


# Note this function is kept here to keep a torch dep out of coref_util.
def add_dummy(tensor: torch.Tensor, eps: bool = False):
    """Prepends zeros (or a very small value if eps is True)
    to the first (not zeroth) dimension of tensor.
    """
    kwargs = dict(device=tensor.device, dtype=tensor.dtype)
    shape: List[int] = list(tensor.shape)
    shape[1] = 1
    if not eps:
        dummy = torch.zeros(shape, **kwargs)  # type: ignore
    else:
        dummy = torch.full(shape, EPSILON, **kwargs)  # type: ignore
    output = torch.cat((dummy, tensor), dim=1)
    return output
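
# Worked example: given a (n_words, roughk) score matrix, add_dummy returns a
# (n_words, roughk + 1) matrix whose first column holds the dummy
# "no antecedent" score (0, or EPSILON when eps=True).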


class AnaphoricityScorer(torch.nn.Module):
    """Calculates anaphoricity scores by passing the inputs into a FFNN"""

    def __init__(self, in_features: int, hidden_size, depth, dropout):
        super().__init__()
        if not depth:
            hidden_size = in_features
        layers = []
        for i in range(depth):
            layers.extend(
                [
                    torch.nn.Linear(hidden_size if i else in_features, hidden_size),
                    torch.nn.LeakyReLU(),
                    torch.nn.Dropout(dropout),
                ]
            )
        self.hidden = torch.nn.Sequential(*layers)
        self.out = torch.nn.Linear(hidden_size, out_features=1)

    def forward(
        self,
        *,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
        all_mentions: torch.Tensor,
        mentions_batch: torch.Tensor,
        pw_batch: torch.Tensor,
        top_indices_batch: torch.Tensor,
        top_rough_scores_batch: torch.Tensor,
    ) -> torch.Tensor:
        """Builds a pairwise matrix, scores the pairs and returns the scores.

        Args:
            all_mentions (torch.Tensor): [n_mentions, mention_emb]
            mentions_batch (torch.Tensor): [batch_size, mention_emb]
            pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb]
            top_indices_batch (torch.Tensor): [batch_size, n_ants]
            top_rough_scores_batch (torch.Tensor): [batch_size, n_ants]

        Returns:
            torch.Tensor [batch_size, n_ants + 1]
                anaphoricity scores for the pairs + a dummy column
        """
        # [batch_size, n_ants, pair_emb]
        pair_matrix = self._get_pair_matrix(
            all_mentions, mentions_batch, pw_batch, top_indices_batch
        )

        # [batch_size, n_ants]
        scores = top_rough_scores_batch + self._ffnn(pair_matrix)
        scores = add_dummy(scores, eps=True)

        return scores

    def _ffnn(self, x: torch.Tensor) -> torch.Tensor:
        """
        x: tensor of shape (batch_size x roughk x n_features)
        returns: tensor of shape (batch_size x antecedent_limit)
        """
        x = self.out(self.hidden(x))
        return x.squeeze(2)

    @staticmethod
    def _get_pair_matrix(
        all_mentions: torch.Tensor,
        mentions_batch: torch.Tensor,
        pw_batch: torch.Tensor,
        top_indices_batch: torch.Tensor,
    ) -> torch.Tensor:
        """
        Builds the matrix used as input for AnaphoricityScorer.

        all_mentions: (n_mentions x mention_emb),
            all the valid mentions of the document,
            can be on a different device
        mentions_batch: (batch_size x mention_emb),
            the mentions of the current batch.
        pw_batch: (batch_size x roughk x pw_emb),
            pairwise distance features of the current batch.
        top_indices_batch: (batch_size x n_ants),
            indices of antecedents of each mention

        Returns:
            out: pairwise features (batch_size x n_ants x pair_emb)
        """
        emb_size = mentions_batch.shape[1]
        n_ants = pw_batch.shape[1]

        a_mentions = mentions_batch.unsqueeze(1).expand(-1, n_ants, emb_size)
        b_mentions = all_mentions[top_indices_batch]
        similarity = a_mentions * b_mentions

        out = torch.cat((a_mentions, b_mentions, similarity, pw_batch), dim=2)
        return out
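
    # Composition note: each pair representation is [mention ; antecedent ;
    # elementwise product ; distance features], which is where the
    # `pair_emb = dim * 3 + self.pw.shape` input size in CorefClusterer comes from.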


class RoughScorer(torch.nn.Module):
    """
    Cheaper module that gives a rough estimate of the anaphoricity of two
    candidates; only the top-scoring candidates are considered in later
    steps to reduce computational cost.
    """

    def __init__(self, features: int, dropout: float, antecedent_limit: int):
        super().__init__()
        self.dropout = torch.nn.Dropout(dropout)
        self.bilinear = torch.nn.Linear(features, features)
        self.k = antecedent_limit

    def forward(
        self,  # type: ignore
        mentions: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns rough anaphoricity scores for candidates, which consist of
        the bilinear output of the current model summed with mention scores.
        """
        # [n_mentions, n_mentions]
        pair_mask = torch.arange(mentions.shape[0])
        pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
        pair_mask = torch.log((pair_mask > 0).to(torch.float))
        pair_mask = pair_mask.to(mentions.device)
        bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
        rough_scores = pair_mask + bilinear_scores
        top_scores, indices = torch.topk(
            rough_scores, k=min(self.k, len(rough_scores)), dim=1, sorted=False
        )

        return top_scores, indices
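
    # Mask sketch: for three mentions, (i - j) > 0 only below the diagonal, so
    # torch.log of the 0/1 mask yields
    #     [[-inf, -inf, -inf],
    #      [   0, -inf, -inf],
    #      [   0,    0, -inf]]
    # i.e. a word can only select an earlier word as a candidate antecedent.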


class DistancePairwiseEncoder(torch.nn.Module):
    def __init__(self, distance_embedding_size, dropout):
        """
        Takes top_indices, which is a ranked list of the most likely
        antecedent candidates for each word. For each of these pairs it
        looks up a distance embedding from a table, where the distance
        bucket corresponds to the log-distance.

        distance_embedding_size: int,
            Dimensionality of the distance-embeddings table.
        dropout: float,
            Dropout probability.
        """
        super().__init__()
        emb_size = distance_embedding_size
        self.distance_emb = torch.nn.Embedding(9, emb_size)
        self.dropout = torch.nn.Dropout(dropout)
        self.shape = emb_size

    def forward(self, top_indices: torch.Tensor) -> torch.Tensor:
        word_ids = torch.arange(0, top_indices.size(0))
        distance = (word_ids.unsqueeze(1) - word_ids[top_indices]).clamp_min_(min=1)
        log_distance = distance.to(torch.float).log2().floor_()
        log_distance = log_distance.clamp_max_(max=6).to(torch.long)
        distance = torch.where(distance < 5, distance - 1, log_distance + 2)
        distance = distance.to(top_indices.device)
        distance = self.distance_emb(distance)
        return self.dropout(distance)
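
    # Bucketing example: distances 1-4 map to buckets 0-3 directly; larger
    # distances use log2 buckets: 5-7 -> 4, 8-15 -> 5, 16-31 -> 6, 32-63 -> 7,
    # 64+ -> 8, matching the 9-row embedding table above.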