spaCy/spacy/ml/models/coref.py
Paul O'Leary McCann 8eadf3781b Training runs now
Evaluation needs fixing, and code still needs cleanup.
2022-03-14 19:02:17 +09:00

998 lines
34 KiB
Python

from dataclasses import dataclass
import warnings
from thinc.api import Model, Linear, Relu, Dropout
from thinc.api import chain, noop, Embed, add, tuplify, concatenate
from thinc.api import reduce_first, reduce_last, reduce_mean
from thinc.api import PyTorchWrapper, ArgsKwargs
from thinc.types import Floats2d, Floats1d, Ints1d, Ints2d, Ragged
from typing import List, Callable, Tuple, Any
from ...tokens import Doc
from ...util import registry
from ..extract_spans import extract_spans
from .coref_util import get_candidate_mentions, select_non_crossing_spans, topk
@registry.architectures("spacy.Coref.v1")
def build_coref(
tok2vec: Model[List[Doc], List[Floats2d]],
get_mentions: Any = get_candidate_mentions,
hidden: int = 1000,
dropout: float = 0.3,
mention_limit: int = 3900,
# TODO this needs a better name. It limits the max mentions as a ratio of
# the token count.
mention_limit_ratio: float = 0.4,
max_span_width: int = 20,
antecedent_limit: int = 50,
):
dim = tok2vec.get_dim("nO") * 3
span_embedder = build_span_embedder(get_mentions, max_span_width)
with Model.define_operators({">>": chain, "&": tuplify, "+": add}):
mention_scorer = (
Linear(nI=dim, nO=hidden)
>> Relu(nI=hidden, nO=hidden)
>> Dropout(dropout)
>> Linear(nI=hidden, nO=hidden)
>> Relu(nI=hidden, nO=hidden)
>> Dropout(dropout)
>> Linear(nI=hidden, nO=1)
)
mention_scorer.initialize()
# TODO make feature_embed_size a param
feature_embed_size = 20
width_scorer = build_width_scorer(max_span_width, hidden, feature_embed_size)
bilinear = Linear(nI=dim, nO=dim) >> Dropout(dropout)
bilinear.initialize()
ms = (build_take_vecs() >> mention_scorer) + width_scorer
model = (
(tok2vec & noop())
>> span_embedder
>> (ms & noop())
>> build_coarse_pruner(mention_limit, mention_limit_ratio)
>> build_ant_scorer(bilinear, Dropout(dropout), antecedent_limit)
)
return model
@dataclass
class SpanEmbeddings:
indices: Ints2d # Array with 2 columns (for start and end index)
vectors: Ragged # Ragged[Floats2d] # One vector per span
# NB: We assume that the indices refer to a concatenated Floats2d that
# has one row per token in the *batch* of documents. This makes it unambiguous
# which row is in which document, because if the lengths are e.g. [10, 5],
# a span starting at 11 must be starting at token 2 of doc 1. A bug could
# potentially cause you to have a span which crosses a doc boundary though,
# which would be bad.
# The lengths in the Ragged are not the tokens per doc, but the number of
# mentions per doc.
def __add__(self, right):
out = self.vectors.data + right.vectors.data
return SpanEmbeddings(self.indices, Ragged(out, self.vectors.lengths))
def __iadd__(self, right):
self.vectors.data += right.vectors.data
return self
def build_width_scorer(max_span_width, hidden_size, feature_embed_size=20):
span_width_prior = (
Embed(nV=max_span_width, nO=feature_embed_size)
>> Linear(nI=feature_embed_size, nO=hidden_size)
>> Relu(nI=hidden_size, nO=hidden_size)
>> Dropout()
>> Linear(nI=hidden_size, nO=1)
)
span_width_prior.initialize()
model = Model("WidthScorer", forward=width_score_forward, layers=[span_width_prior])
model.set_ref("width_prior", span_width_prior)
return model
def width_score_forward(
model, embeds: SpanEmbeddings, is_train
) -> Tuple[Floats1d, Callable]:
# calculate widths, subtracting 1 so it's 0-index
w_ffnn = model.get_ref("width_prior")
idxs = embeds.indices
widths = idxs[:, 1] - idxs[:, 0] - 1
wscores, width_b = w_ffnn(widths, is_train)
lens = embeds.vectors.lengths
def width_score_backward(d_score: Floats1d) -> SpanEmbeddings:
dX = width_b(d_score)
vecs = Ragged(dX, lens)
return SpanEmbeddings(idxs, vecs)
return wscores, width_score_backward
# model converting a Doc/Mention to span embeddings
# get_mentions: Callable[Doc, Pairs[int]]
def build_span_embedder(
get_mentions: Callable,
max_span_width: int = 20,
) -> Model[Tuple[List[Floats2d], List[Doc]], SpanEmbeddings]:
with Model.define_operators({">>": chain, "|": concatenate}):
span_reduce = extract_spans() >> (
reduce_first() | reduce_last() | reduce_mean()
)
model = Model(
"SpanEmbedding",
forward=span_embeddings_forward,
attrs={
"get_mentions": get_mentions,
# XXX might be better to make this an implicit parameter in the
# mention generator
"max_span_width": max_span_width,
},
layers=[span_reduce],
)
model.set_ref("span_reducer", span_reduce)
return model
def span_embeddings_forward(
model, inputs: Tuple[List[Floats2d], List[Doc]], is_train
) -> Tuple[SpanEmbeddings, Callable]:
ops = model.ops
xp = ops.xp
tokvecs, docs = inputs
# TODO fix this
dim = tokvecs[0].shape[1]
get_mentions = model.attrs["get_mentions"]
max_span_width = model.attrs["max_span_width"]
mentions = ops.alloc2i(0, 2)
docmenlens = [] # number of mentions per doc
for doc in docs:
starts, ends = get_mentions(doc, max_span_width)
docmenlens.append(len(starts))
cments = ops.asarray2i([starts, ends]).transpose()
mentions = xp.concatenate((mentions, cments))
# TODO support attention here
tokvecs = xp.concatenate(tokvecs)
doclens = [len(doc) for doc in docs]
tokvecs_r = Ragged(tokvecs, doclens)
mentions_r = Ragged(mentions, docmenlens)
span_reduce = model.get_ref("span_reducer")
spanvecs, span_reduce_back = span_reduce((tokvecs_r, mentions_r), is_train)
embeds = Ragged(spanvecs, docmenlens)
def backprop_span_embed(dY: SpanEmbeddings) -> Tuple[List[Floats2d], List[Doc]]:
grad, idxes = span_reduce_back(dY.vectors.data)
oweights = []
offset = 0
for doclen in doclens:
hi = offset + doclen
oweights.append(grad.data[offset:hi])
offset = hi
return oweights, docs
return SpanEmbeddings(mentions, embeds), backprop_span_embed
def build_coarse_pruner(
mention_limit: int,
mention_limit_ratio: float,
) -> Model[SpanEmbeddings, SpanEmbeddings]:
model = Model(
"CoarsePruner",
forward=coarse_prune,
attrs={
"mention_limit": mention_limit,
"mention_limit_ratio": mention_limit_ratio,
},
)
return model
def coarse_prune(
model, inputs: Tuple[Floats1d, SpanEmbeddings], is_train
) -> Tuple[Tuple[Floats1d, SpanEmbeddings], Callable]:
"""Given scores for mention, output the top non-crossing mentions.
Mentions can contain other mentions, but candidate mentions cannot cross each other.
"""
rawscores, spanembeds = inputs
scores = rawscores.flatten()
mention_limit = model.attrs["mention_limit"]
mention_limit_ratio = model.attrs["mention_limit_ratio"]
# XXX: Issue here. Don't need docs to find crossing spans, but might for the limits.
# In old code the limit can be:
# - hard number per doc
# - ratio of tokens in the doc
offset = 0
selected = []
sellens = []
for menlen in spanembeds.vectors.lengths:
hi = offset + menlen
cscores = scores[offset:hi]
# negate it so highest numbers come first
# This is relatively slow but can't be skipped.
tops = (model.ops.xp.argsort(-1 * cscores)).tolist()
starts = spanembeds.indices[offset:hi, 0].tolist()
ends = spanembeds.indices[offset:hi:, 1].tolist()
# calculate the doc length
doclen = ends[-1] - starts[0]
# XXX seems to make more sense to use menlen than doclen here?
# coref-hoi uses doclen (number of words).
mlimit = min(mention_limit, int(mention_limit_ratio * doclen))
# csel is a 1d integer list
csel = select_non_crossing_spans(tops, starts, ends, mlimit)
# add the offset so these indices are absolute
csel = [ii + offset for ii in csel]
# this should be constant because short choices are padded
sellens.append(len(csel))
selected += csel
offset += menlen
selected = model.ops.asarray1i(selected)
top_spans = spanembeds.indices[selected]
top_vecs = spanembeds.vectors.data[selected]
out = SpanEmbeddings(top_spans, Ragged(top_vecs, sellens))
# save some variables so the embeds can be garbage collected
idxlen = spanembeds.indices.shape[0]
vecshape = spanembeds.vectors.data.shape
indices = spanembeds.indices
veclens = out.vectors.lengths
def coarse_prune_backprop(
dY: Tuple[Floats1d, SpanEmbeddings]
) -> Tuple[Floats1d, SpanEmbeddings]:
dYscores, dYembeds = dY
dXscores = model.ops.alloc1f(idxlen)
dXscores[selected] = dYscores.flatten()
dXvecs = model.ops.alloc2f(*vecshape)
dXvecs[selected] = dYembeds.vectors.data
rout = Ragged(dXvecs, veclens)
dXembeds = SpanEmbeddings(indices, rout)
# inflate for mention scorer
dXscores = model.ops.xp.expand_dims(dXscores, 1)
return (dXscores, dXembeds)
return (scores[selected], out), coarse_prune_backprop
def build_take_vecs() -> Model[SpanEmbeddings, Floats2d]:
# this just gets vectors out of spanembeddings
# XXX Might be better to convert SpanEmbeddings to a tuple and use with_getitem
return Model("TakeVecs", forward=take_vecs_forward)
def take_vecs_forward(model, inputs: SpanEmbeddings, is_train) -> Floats2d:
idxs = inputs.indices
lens = inputs.vectors.lengths
def backprop(dY: Floats2d) -> SpanEmbeddings:
vecs = Ragged(dY, lens)
return SpanEmbeddings(idxs, vecs)
return inputs.vectors.data, backprop
def build_ant_scorer(
bilinear, dropout, ant_limit=50
) -> Model[Tuple[Floats1d, SpanEmbeddings], List[Floats2d]]:
model = Model(
"AntScorer",
forward=ant_scorer_forward,
layers=[bilinear, dropout],
attrs={
"ant_limit": ant_limit,
},
)
model.set_ref("bilinear", bilinear)
model.set_ref("dropout", dropout)
return model
def ant_scorer_forward(
model, inputs: Tuple[Floats1d, SpanEmbeddings], is_train
) -> Tuple[Tuple[List[Tuple[Floats2d, Ints2d]], Ints2d], Callable]:
ops = model.ops
xp = ops.xp
ant_limit = model.attrs["ant_limit"]
# this contains the coarse bilinear in coref-hoi
# coarse bilinear is a single layer linear network
# TODO make these proper refs
bilinear = model.get_ref("bilinear")
dropout = model.get_ref("dropout")
mscores, sembeds = inputs
vecs = sembeds.vectors # ragged
offset = 0
backprops = []
out = []
for ll in vecs.lengths:
hi = offset + ll
# each iteration is one doc
# first calculate the pairwise product scores
cvecs = vecs.data[offset:hi]
pw_prod, prod_back = pairwise_product(bilinear, dropout, cvecs, is_train)
# now calculate the pairwise mention scores
ms = mscores[offset:hi].flatten()
pw_sum, pw_sum_back = pairwise_sum(ops, ms)
# make a mask so antecedents precede referrents
ant_range = xp.arange(0, cvecs.shape[0])
# This will take the log of 0, which causes a warning, but we're doing
# it on purpose so we can just ignore the warning.
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=RuntimeWarning)
mask = xp.log(
(xp.expand_dims(ant_range, 1) - xp.expand_dims(ant_range, 0)) >= 1
).astype("f")
scores = pw_prod + pw_sum + mask
top_limit = min(ant_limit, len(scores))
top_scores, top_scores_idx = topk(xp, scores, top_limit)
# now add the placeholder
placeholder = ops.alloc2f(scores.shape[0], 1)
top_scores = xp.concatenate((placeholder, top_scores), 1)
out.append((top_scores, top_scores_idx))
# In the full model these scores can be further refined. In the current
# state of this model we're done here, so this pruning is less important,
# but it's still helpful for reducing memory usage (since scores can be
# garbage collected when the loop exits).
offset += ll
backprops.append((prod_back, pw_sum_back))
# save vars for gc
vecshape = vecs.data.shape
veclens = vecs.lengths
scoreshape = mscores.shape
idxes = sembeds.indices
def backprop(
dYs: Tuple[List[Tuple[Floats2d, Ints2d]], Ints2d]
) -> Tuple[Floats2d, SpanEmbeddings]:
dYscores, dYembeds = dYs
dXembeds = Ragged(ops.alloc2f(*vecshape), veclens)
dXscores = ops.alloc1f(*scoreshape)
offset = 0
for dy, (prod_back, pw_sum_back), ll in zip(dYscores, backprops, veclens):
hi = offset + ll
dyscore, dyidx = dy
# remove the placeholder
dyscore = dyscore[:, 1:]
# the full score grid is square
fullscore = ops.alloc2f(ll, ll)
for ii, (ridx, rscores) in enumerate(zip(dyidx, dyscore)):
fullscore[ii][ridx] = rscores
dXembeds.data[offset:hi] = prod_back(fullscore)
dXscores[offset:hi] = pw_sum_back(fullscore)
offset = hi
# make it fit back into the linear
dXscores = xp.expand_dims(dXscores, 1)
return (dXscores, SpanEmbeddings(idxes, dXembeds))
return (out, sembeds.indices), backprop
def pairwise_sum(ops, mention_scores: Floats1d) -> Tuple[Floats2d, Callable]:
"""Find the most likely mention-antecedent pairs."""
# This doesn't use multiplication because two items with low mention scores
# don't make a good candidate pair.
pw_sum = ops.xp.expand_dims(mention_scores, 1) + ops.xp.expand_dims(
mention_scores, 0
)
def backward(d_pwsum: Floats2d) -> Floats1d:
# For the backward pass, the gradient is distributed over the whole row and
# column, so pull it all in.
out = d_pwsum.sum(axis=0) + d_pwsum.sum(axis=1)
return out
return pw_sum, backward
def pairwise_product(bilinear, dropout, vecs: Floats2d, is_train):
# A neat side effect of this is that we don't have to pass the backprops
# around separately because the closure handles them.
source, source_b = bilinear(vecs, is_train)
target, target_b = dropout(vecs.T, is_train)
pw_prod = source @ target
def backward(d_prod: Floats2d) -> Floats2d:
dS = source_b(d_prod @ target.T)
dT = target_b(source.T @ d_prod)
dX = dS + dT.T
return dX
return pw_prod, backward
# XXX here down is wl-coref
from typing import List, Tuple
import torch
from thinc.util import xp2torch, torch2xp
# TODO rename this to coref_util
from .coref_util_wl import add_dummy
# TODO rename to plain coref
@registry.architectures("spacy.WLCoref.v1")
def build_wl_coref_model(
tok2vec: Model[List[Doc], List[Floats2d]],
embedding_size: int = 20,
hidden_size: int = 1024,
n_hidden_layers: int = 1, # TODO rename to "depth"?
dropout: float = 0.3,
# pairs to keep per mention after rough scoring
# TODO change to meaningful name
rough_k: int = 50,
# TODO is this not a training loop setting?
a_scoring_batch_size: int = 512,
# span predictor embeddings
sp_embedding_size: int = 64,
):
dim = tok2vec.get_dim("nO")
with Model.define_operators({">>": chain}):
# TODO chain tok2vec with these models
# TODO fix device - should be automatic
device = "cuda:0"
coref_scorer = PyTorchWrapper(
CorefScorer(
device,
dim,
embedding_size,
hidden_size,
n_hidden_layers,
dropout,
rough_k,
a_scoring_batch_size
),
convert_inputs=convert_coref_scorer_inputs,
convert_outputs=convert_coref_scorer_outputs
)
coref_model = tok2vec >> coref_scorer
# XXX just ignore this until the coref scorer is integrated
span_predictor = PyTorchWrapper(
SpanPredictor(
# TODO this was hardcoded to 1024, check
hidden_size,
sp_embedding_size,
device
),
convert_inputs=convert_span_predictor_inputs
)
# TODO combine models so output is uniform (just one forward pass)
# It may be reasonable to have an option to disable span prediction,
# and just return words as spans.
return coref_model
def convert_coref_scorer_inputs(
model: Model,
X: List[Floats2d],
is_train: bool
):
# The input here is List[Floats2d], one for each doc
# just use the first
# TODO real batching
X = X[0]
word_features = xp2torch(X, requires_grad=is_train)
def backprop(args: ArgsKwargs) -> List[Floats2d]:
# convert to xp and wrap in list
gradients = torch2xp(args.args[0])
return [gradients]
return ArgsKwargs(args=(word_features, ), kwargs={}), backprop
def convert_coref_scorer_outputs(
model: Model,
inputs_outputs,
is_train: bool
):
_, outputs = inputs_outputs
scores, indices = outputs
def convert_for_torch_backward(dY: Floats2d) -> ArgsKwargs:
dY_t = xp2torch(dY[0])
return ArgsKwargs(
args=([scores],),
kwargs={"grad_tensors": [dY_t]},
)
scores_xp = torch2xp(scores)
indices_xp = torch2xp(indices)
return (scores_xp, indices_xp), convert_for_torch_backward
def convert_span_predictor_inputs(
model: Model,
X: Tuple[Ints1d, Floats2d, Ints1d],
is_train: bool
):
sent_id = xp2torch(X[0], requires_grad=False)
word_features = xp2torch(X[1], requires_grad=False)
head_ids = xp2torch(X[2], requires_grad=False)
argskwargs = ArgsKwargs(args=(sent_id, word_features, head_ids), kwargs={})
return argskwargs, lambda dX: []
# TODO This probably belongs in the component, not the model.
def predict_span_clusters(span_predictor: Model,
sent_ids: Ints1d,
words: Floats2d,
clusters: List[Ints1d]):
"""
Predicts span clusters based on the word clusters.
Args:
doc (Doc): the document data
words (torch.Tensor): [n_words, emb_size] matrix containing
embeddings for each of the words in the text
clusters (List[List[int]]): a list of clusters where each cluster
is a list of word indices
Returns:
List[List[Span]]: span clusters
"""
if not clusters:
return []
xp = span_predictor.ops.xp
heads_ids = xp.asarray(sorted(i for cluster in clusters for i in cluster))
scores = span_predictor.predict((sent_ids, words, heads_ids))
starts = scores[:, :, 0].argmax(axis=1).tolist()
ends = (scores[:, :, 1].argmax(axis=1) + 1).tolist()
head2span = {
head: (start, end)
for head, start, end in zip(heads_ids.tolist(), starts, ends)
}
return [[head2span[head] for head in cluster]
for cluster in clusters]
# TODO add docstring for this, maybe move to utils.
# This might belong in the component.
def _clusterize(
model,
scores: Floats2d,
top_indices: Ints2d
):
xp = model.ops.xp
antecedents = scores.argmax(axis=1) - 1
not_dummy = antecedents >= 0
coref_span_heads = xp.arange(0, len(scores))[not_dummy]
antecedents = top_indices[coref_span_heads, antecedents[not_dummy]]
n_words = scores.shape[0]
nodes = [GraphNode(i) for i in range(n_words)]
for i, j in zip(coref_span_heads.tolist(), antecedents.tolist()):
nodes[i].link(nodes[j])
assert nodes[i] is not nodes[j]
clusters = []
for node in nodes:
if len(node.links) > 0 and not node.visited:
cluster = []
stack = [node]
while stack:
current_node = stack.pop()
current_node.visited = True
cluster.append(current_node.id)
stack.extend(link for link in current_node.links if not link.visited)
assert len(cluster) > 1
clusters.append(sorted(cluster))
return sorted(clusters)
class CorefScorer(torch.nn.Module):
"""Combines all coref modules together to find coreferent spans.
Attributes:
epochs_trained (int): number of epochs the model has been trained for
Submodules (in the order of their usage in the pipeline):
rough_scorer (RoughScorer)
pw (PairwiseEncoder)
a_scorer (AnaphoricityScorer)
sp (SpanPredictor)
"""
def __init__(
self,
device: str,
dim: int, # tok2vec size
dist_emb_size: int,
hidden_size: int,
n_layers: int,
dropout_rate: float,
roughk: int,
batch_size: int
):
super().__init__()
"""
A newly created model is set to evaluation mode.
Args:
epochs_trained (int): the number of epochs finished
(useful for warm start)
"""
# device, dist_emb_size, hidden_size, n_layers, dropout_rate
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate).to(device)
#TODO clean this up
bert_emb = dim
pair_emb = bert_emb * 3 + self.pw.shape
self.a_scorer = AnaphoricityScorer(
pair_emb,
hidden_size,
n_layers,
dropout_rate
).to(device)
self.lstm = torch.nn.LSTM(
input_size=bert_emb,
hidden_size=bert_emb,
batch_first=True,
)
self.dropout = torch.nn.Dropout(dropout_rate)
self.rough_scorer = RoughScorer(
bert_emb,
dropout_rate,
roughk
).to(device)
self.batch_size = batch_size
def forward(
self,
word_features: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
This is a massive method, but it made sense to me to not split it into
several ones to let one see the data flow.
Args:
word_features: torch.Tensor containing word encodings
Returns:
coreference scores and top indices
"""
# words [n_words, span_emb]
# cluster_ids [n_words]
self.lstm.flatten_parameters() # XXX without this there's a warning
word_features = torch.unsqueeze(word_features, dim=0)
words, _ = self.lstm(word_features)
words = words.squeeze()
words = self.dropout(words)
# Obtain bilinear scores and leave only top-k antecedents for each word
# top_rough_scores [n_words, n_ants]
# top_indices [n_words, n_ants]
top_rough_scores, top_indices = self.rough_scorer(words)
# Get pairwise features [n_words, n_ants, n_pw_features]
pw = self.pw(top_indices)
batch_size = self.batch_size
a_scores_lst: List[torch.Tensor] = []
for i in range(0, len(words), batch_size):
pw_batch = pw[i:i + batch_size]
words_batch = words[i:i + batch_size]
top_indices_batch = top_indices[i:i + batch_size]
top_rough_scores_batch = top_rough_scores[i:i + batch_size]
# a_scores_batch [batch_size, n_ants]
a_scores_batch = self.a_scorer(
all_mentions=words, mentions_batch=words_batch,
pw_batch=pw_batch, top_indices_batch=top_indices_batch,
top_rough_scores_batch=top_rough_scores_batch
)
a_scores_lst.append(a_scores_batch)
coref_scores = torch.cat(a_scores_lst, dim=0)
return coref_scores, top_indices
class AnaphoricityScorer(torch.nn.Module):
""" Calculates anaphoricity scores by passing the inputs into a FFNN """
def __init__(self,
in_features: int,
hidden_size,
n_hidden_layers,
dropout_rate):
super().__init__()
hidden_size = hidden_size
if not n_hidden_layers:
hidden_size = in_features
layers = []
for i in range(n_hidden_layers):
layers.extend([torch.nn.Linear(hidden_size if i else in_features,
hidden_size),
torch.nn.LeakyReLU(),
torch.nn.Dropout(dropout_rate)])
self.hidden = torch.nn.Sequential(*layers)
self.out = torch.nn.Linear(hidden_size, out_features=1)
def forward(self, *, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
all_mentions: torch.Tensor,
mentions_batch: torch.Tensor,
pw_batch: torch.Tensor,
top_indices_batch: torch.Tensor,
top_rough_scores_batch: torch.Tensor,
) -> torch.Tensor:
""" Builds a pairwise matrix, scores the pairs and returns the scores.
Args:
all_mentions (torch.Tensor): [n_mentions, mention_emb]
mentions_batch (torch.Tensor): [batch_size, mention_emb]
pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb]
top_indices_batch (torch.Tensor): [batch_size, n_ants]
top_rough_scores_batch (torch.Tensor): [batch_size, n_ants]
Returns:
torch.Tensor [batch_size, n_ants + 1]
anaphoricity scores for the pairs + a dummy column
"""
# [batch_size, n_ants, pair_emb]
pair_matrix = self._get_pair_matrix(
all_mentions, mentions_batch, pw_batch, top_indices_batch)
# [batch_size, n_ants]
scores = top_rough_scores_batch + self._ffnn(pair_matrix)
scores = add_dummy(scores, eps=True)
return scores
def _ffnn(self, x: torch.Tensor) -> torch.Tensor:
"""
Calculates anaphoricity scores.
Args:
x: tensor of shape [batch_size, n_ants, n_features]
Returns:
tensor of shape [batch_size, n_ants]
"""
x = self.out(self.hidden(x))
return x.squeeze(2)
@staticmethod
def _get_pair_matrix(all_mentions: torch.Tensor,
mentions_batch: torch.Tensor,
pw_batch: torch.Tensor,
top_indices_batch: torch.Tensor,
) -> torch.Tensor:
"""
Builds the matrix used as input for AnaphoricityScorer.
Args:
all_mentions (torch.Tensor): [n_mentions, mention_emb],
all the valid mentions of the document,
can be on a different device
mentions_batch (torch.Tensor): [batch_size, mention_emb],
the mentions of the current batch,
is expected to be on the current device
pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb],
pairwise features of the current batch,
is expected to be on the current device
top_indices_batch (torch.Tensor): [batch_size, n_ants],
indices of antecedents of each mention
Returns:
torch.Tensor: [batch_size, n_ants, pair_emb]
"""
emb_size = mentions_batch.shape[1]
n_ants = pw_batch.shape[1]
a_mentions = mentions_batch.unsqueeze(1).expand(-1, n_ants, emb_size)
b_mentions = all_mentions[top_indices_batch]
similarity = a_mentions * b_mentions
out = torch.cat((a_mentions, b_mentions, similarity, pw_batch), dim=2)
return out
class RoughScorer(torch.nn.Module):
"""
Is needed to give a roughly estimate of the anaphoricity of two candidates,
only top scoring candidates are considered on later steps to reduce
computational complexity.
"""
def __init__(
self,
features: int,
dropout_rate: float,
rough_k: float
):
super().__init__()
self.dropout = torch.nn.Dropout(dropout_rate)
self.bilinear = torch.nn.Linear(features, features)
self.k = rough_k
def forward(
self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
mentions: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Returns rough anaphoricity scores for candidates, which consist of
the bilinear output of the current model summed with mention scores.
"""
# [n_mentions, n_mentions]
pair_mask = torch.arange(mentions.shape[0])
pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
pair_mask = torch.log((pair_mask > 0).to(torch.float))
pair_mask = pair_mask.to(mentions.device)
bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
rough_scores = pair_mask + bilinear_scores
return self._prune(rough_scores)
def _prune(self,
rough_scores: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Selects top-k rough antecedent scores for each mention.
Args:
rough_scores: tensor of shape [n_mentions, n_mentions], containing
rough antecedent scores of each mention-antecedent pair.
Returns:
FloatTensor of shape [n_mentions, k], top rough scores
LongTensor of shape [n_mentions, k], top indices
"""
top_scores, indices = torch.topk(rough_scores,
k=min(self.k, len(rough_scores)),
dim=1, sorted=False)
return top_scores, indices
class SpanPredictor(torch.nn.Module):
def __init__(self, input_size: int, distance_emb_size: int, device):
super().__init__()
self.ffnn = torch.nn.Sequential(
torch.nn.Linear(input_size * 2 + 64, input_size),
torch.nn.ReLU(),
torch.nn.Dropout(0.3),
torch.nn.Linear(input_size, 256),
torch.nn.ReLU(),
torch.nn.Dropout(0.3),
torch.nn.Linear(256, 64),
)
self.device = device
self.conv = torch.nn.Sequential(
torch.nn.Conv1d(64, 4, 3, 1, 1),
torch.nn.Conv1d(4, 2, 3, 1, 1)
)
self.emb = torch.nn.Embedding(128, distance_emb_size) # [-63, 63] + too_far
def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
sent_id,
words: torch.Tensor,
heads_ids: torch.Tensor) -> torch.Tensor:
"""
Calculates span start/end scores of words for each span head in
heads_ids
Args:
doc (Doc): the document data
words (torch.Tensor): contextual embeddings for each word in the
document, [n_words, emb_size]
heads_ids (torch.Tensor): word indices of span heads
Returns:
torch.Tensor: span start/end scores, [n_heads, n_words, 2]
"""
# Obtain distance embedding indices, [n_heads, n_words]
relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0], device=words.device).unsqueeze(0))
# make all valid distances positive
emb_ids = relative_positions + 63
# "too_far"
emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
# Obtain "same sentence" boolean mask, [n_heads, n_words]
sent_id = torch.tensor(sent_id, device=words.device)
same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))
# To save memory, only pass candidates from one sentence for each head
# pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
# for each candidate among the words in the same sentence as span_head
# [n_heads, input_size * 2 + distance_emb_size]
rows, cols = same_sent.nonzero(as_tuple=True)
pair_matrix = torch.cat((
words[heads_ids[rows]],
words[cols],
self.emb(emb_ids[rows, cols]),
), dim=1)
lengths = same_sent.sum(dim=1)
padding_mask = torch.arange(0, lengths.max(), device=words.device).unsqueeze(0)
padding_mask = (padding_mask < lengths.unsqueeze(1)) # [n_heads, max_sent_len]
# [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
# This is necessary to allow the convolution layer to look at several
# word scores
padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1], device=words.device)
padded_pairs[padding_mask] = pair_matrix
res = self.ffnn(padded_pairs) # [n_heads, n_candidates, last_layer_output]
res = self.conv(res.permute(0, 2, 1)).permute(0, 2, 1) # [n_heads, n_candidates, 2]
scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'), device=words.device)
scores[rows, cols] = res[padding_mask]
# Make sure that start <= head <= end during inference
if not self.training:
valid_starts = torch.log((relative_positions >= 0).to(torch.float))
valid_ends = torch.log((relative_positions <= 0).to(torch.float))
valid_positions = torch.stack((valid_starts, valid_ends), dim=2)
return scores + valid_positions
return scores
class DistancePairwiseEncoder(torch.nn.Module):
def __init__(self, embedding_size, dropout_rate):
super().__init__()
emb_size = embedding_size
self.distance_emb = torch.nn.Embedding(9, emb_size)
self.dropout = torch.nn.Dropout(dropout_rate)
self.shape = emb_size
@property
def device(self) -> torch.device:
""" A workaround to get current device (which is assumed to be the
device of the first parameter of one of the submodules) """
return next(self.distance_emb.parameters()).device
def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
top_indices: torch.Tensor
) -> torch.Tensor:
word_ids = torch.arange(0, top_indices.size(0), device=self.device)
distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
).clamp_min_(min=1)
log_distance = distance.to(torch.float).log2().floor_()
log_distance = log_distance.clamp_max_(max=6).to(torch.long)
distance = torch.where(distance < 5, distance - 1, log_distance + 2)
distance = self.distance_emb(distance)
return self.dropout(distance)