Start bringing in wl-coref

This absolutely does not work yet. The first step here is getting most of
the code over, roughly in the files we want it to end up in. After the code
has been pulled over it can be restructured to match spaCy conventions and
cleaned up.
Paul O'Leary McCann 2022-03-06 20:00:15 +09:00
parent 0c15ab7ca1
commit c0cd5025e3
3 changed files with 438 additions and 2 deletions

View File

@@ -1,4 +1,4 @@
from .coref import *
from .coref import * #noqa
from .entity_linker import * # noqa
from .multi_task import * # noqa
from .parser import * # noqa

View File

@@ -448,3 +448,434 @@ def pairwise_product(bilinear, dropout, vecs: Floats2d, is_train):
return dX
return pw_prod, backward
# XXX here down is wl-coref
from typing import List, Tuple
import torch
# TODO rename this to coref_util
from . import coref_util_wl as utils
# TODO rename to plain coref
@registry.architectures("spacy.WLCoref.v1")
def build_wl_coref_model(
#TODO add other hyperparams
tok2vec: Model[List[Doc], List[Floats2d]],
):
# TODO change to use passed in values for config
config = utils._load_config("/dev/null")
with Model.define_operators({">>": chain}):
        # configure_pytorch_modules is not ported yet; construct the modules directly
        # coref_scorer, span_predictor = configure_pytorch_modules(config)
# TODO chain tok2vec with these models
coref_scorer = PyTorchWrapper(
CorefScorer(
config.device,
config.embedding_size,
config.hidden_size,
config.n_hidden_layers,
config.dropout_rate,
config.rough_k,
config.a_scoring_batch_size
),
convert_inputs=convert_coref_scorer_inputs,
convert_outputs=convert_coref_scorer_outputs
)
span_predictor = PyTorchWrapper(
SpanPredictor(
1024,
config.sp_embedding_size,
config.device
),
convert_inputs=convert_span_predictor_inputs
)
# TODO combine models so output is uniform (just one forward pass)
# It may be reasonable to have an option to disable span prediction,
# and just return words as spans.
return coref_scorer
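# A rough sketch of how this architecture might eventually be referenced from a
# pipeline config once the hyperparameters are passed through properly (the
# section and tok2vec settings below are assumptions, not finalized):
#
#   [components.coref.model]
#   @architectures = "spacy.WLCoref.v1"
#
#   [components.coref.model.tok2vec]
#   @architectures = "spacy.Tok2Vec.v2"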
def convert_coref_scorer_inputs(
model: Model,
X: Floats2d,
is_train: bool
):
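    # Move the thinc array onto torch without tracking gradients; the empty
    # list returned by the backprop callback below is a placeholder until the
    # scorer is chained with tok2vec and gradients need to flow back through
    # the embeddings.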
word_features = xp2torch(X, requires_grad=False)
return ArgsKwargs(args=(word_features, ), kwargs={}), lambda dX: []
def convert_coref_scorer_outputs(
model: Model,
inputs_outputs,
is_train: bool
):
_, outputs = inputs_outputs
scores, indices = outputs
def convert_for_torch_backward(dY: Floats2d) -> ArgsKwargs:
dY_t = xp2torch(dY)
return ArgsKwargs(
args=([scores],),
kwargs={"grad_tensors": [dY_t]},
)
scores_xp = torch2xp(scores)
indices_xp = torch2xp(indices)
return (scores_xp, indices_xp), convert_for_torch_backward
# TODO This probably belongs in the component, not the model.
def predict_span_clusters(span_predictor: Model,
sent_ids: Ints1d,
words: Floats2d,
clusters: List[Ints1d]):
"""
Predicts span clusters based on the word clusters.
Args:
        sent_ids (Ints1d): [n_words] sentence index of each word
        words (Floats2d): [n_words, emb_size] matrix containing
            embeddings for each of the words in the text
        clusters (List[Ints1d]): a list of clusters where each cluster
            is an array of word indices
    Returns:
        List[List[Tuple[int, int]]]: span clusters as (start, end) offsets
"""
if not clusters:
return []
xp = span_predictor.ops.xp
heads_ids = xp.asarray(sorted(i for cluster in clusters for i in cluster))
scores = span_predictor.predict((sent_ids, words, heads_ids))
starts = scores[:, :, 0].argmax(axis=1).tolist()
ends = (scores[:, :, 1].argmax(axis=1) + 1).tolist()
head2span = {
head: (start, end)
for head, start, end in zip(heads_ids.tolist(), starts, ends)
}
return [[head2span[head] for head in cluster]
for cluster in clusters]
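# For example, clusters == [[1, 4]] with a predicted head2span of
# {1: (0, 2), 4: (3, 5)} yields [[(0, 2), (3, 5)]]: each cluster of head word
# indices becomes a cluster of (start, end) token offsets.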
# TODO add docstring for this, maybe move to utils.
# This might belong in the component.
def _clusterize(
model,
scores: Floats2d,
top_indices: Ints2d
):
xp = model.ops.xp
antecedents = scores.argmax(axis=1) - 1
not_dummy = antecedents >= 0
coref_span_heads = xp.arange(0, len(scores))[not_dummy]
antecedents = top_indices[coref_span_heads, antecedents[not_dummy]]
n_words = scores.shape[0]
nodes = [GraphNode(i) for i in range(n_words)]
for i, j in zip(coref_span_heads.tolist(), antecedents.tolist()):
nodes[i].link(nodes[j])
assert nodes[i] is not nodes[j]
clusters = []
for node in nodes:
if len(node.links) > 0 and not node.visited:
cluster = []
stack = [node]
while stack:
current_node = stack.pop()
current_node.visited = True
cluster.append(current_node.id)
stack.extend(link for link in current_node.links if not link.visited)
assert len(cluster) > 1
clusters.append(sorted(cluster))
return sorted(clusters)
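# _clusterize treats each word as a graph node, links every word to its
# predicted antecedent (column 0 of the scores is the dummy "no antecedent"
# option, hence the -1 shift), and reads the clusters off as connected
# components via a depth-first traversal.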
class CorefScorer(torch.nn.Module):
"""Combines all coref modules together to find coreferent spans.
Attributes:
config (coref.config.Config): the model's configuration,
see config.toml for the details
epochs_trained (int): number of epochs the model has been trained for
Submodules (in the order of their usage in the pipeline):
rough_scorer (RoughScorer)
pw (PairwiseEncoder)
a_scorer (AnaphoricityScorer)
sp (SpanPredictor)
"""
def __init__(
self,
device: str,
dist_emb_size: int,
hidden_size: int,
n_layers: int,
dropout_rate: float,
roughk: int,
batch_size: int
):
        """
        Args:
            device (str): device to put the submodules on
            dist_emb_size (int): size of the distance embeddings used by
                DistancePairwiseEncoder
            hidden_size (int): size of the hidden layers in the anaphoricity FFNN
            n_layers (int): number of hidden layers in the anaphoricity FFNN
            dropout_rate (float): dropout rate used by all submodules
            roughk (int): number of candidate antecedents kept per word
                by the rough scorer
            batch_size (int): batch size for the anaphoricity scorer
        """
        super().__init__()
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate).to(device)
bert_emb = 1024
pair_emb = bert_emb * 3 + self.pw.shape
self.a_scorer = AnaphoricityScorer(
pair_emb,
hidden_size,
n_layers,
dropout_rate
).to(device)
self.lstm = torch.nn.LSTM(
input_size=bert_emb,
hidden_size=bert_emb,
batch_first=True,
)
self.dropout = torch.nn.Dropout(dropout_rate)
self.rough_scorer = RoughScorer(
bert_emb,
dropout_rate,
roughk
).to(device)
self.batch_size = batch_size
def forward(
self,
word_features: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
        This is a massive method, but it is kept in one piece so that the
        full data flow can be followed in one place.
Args:
word_features: torch.Tensor containing word encodings
Returns:
coreference scores and top indices
"""
        # word_features [n_words, emb_size]
word_features = torch.unsqueeze(word_features, dim=0)
words, _ = self.lstm(word_features)
words = words.squeeze()
words = self.dropout(words)
# Obtain bilinear scores and leave only top-k antecedents for each word
# top_rough_scores [n_words, n_ants]
# top_indices [n_words, n_ants]
top_rough_scores, top_indices = self.rough_scorer(words)
# Get pairwise features [n_words, n_ants, n_pw_features]
pw = self.pw(top_indices)
batch_size = self.batch_size
a_scores_lst: List[torch.Tensor] = []
for i in range(0, len(words), batch_size):
pw_batch = pw[i:i + batch_size]
words_batch = words[i:i + batch_size]
top_indices_batch = top_indices[i:i + batch_size]
top_rough_scores_batch = top_rough_scores[i:i + batch_size]
# a_scores_batch [batch_size, n_ants]
a_scores_batch = self.a_scorer(
all_mentions=words, mentions_batch=words_batch,
pw_batch=pw_batch, top_indices_batch=top_indices_batch,
top_rough_scores_batch=top_rough_scores_batch
)
a_scores_lst.append(a_scores_batch)
coref_scores = torch.cat(a_scores_lst, dim=0)
return coref_scores, top_indices
class AnaphoricityScorer(torch.nn.Module):
""" Calculates anaphoricity scores by passing the inputs into a FFNN """
def __init__(self,
in_features: int,
hidden_size,
n_hidden_layers,
dropout_rate):
super().__init__()
if not n_hidden_layers:
hidden_size = in_features
layers = []
for i in range(n_hidden_layers):
layers.extend([torch.nn.Linear(hidden_size if i else in_features,
hidden_size),
torch.nn.LeakyReLU(),
torch.nn.Dropout(dropout_rate)])
self.hidden = torch.nn.Sequential(*layers)
self.out = torch.nn.Linear(hidden_size, out_features=1)
def forward(self, *, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
all_mentions: torch.Tensor,
mentions_batch: torch.Tensor,
pw_batch: torch.Tensor,
top_indices_batch: torch.Tensor,
top_rough_scores_batch: torch.Tensor,
) -> torch.Tensor:
""" Builds a pairwise matrix, scores the pairs and returns the scores.
Args:
all_mentions (torch.Tensor): [n_mentions, mention_emb]
mentions_batch (torch.Tensor): [batch_size, mention_emb]
pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb]
top_indices_batch (torch.Tensor): [batch_size, n_ants]
top_rough_scores_batch (torch.Tensor): [batch_size, n_ants]
Returns:
torch.Tensor [batch_size, n_ants + 1]
anaphoricity scores for the pairs + a dummy column
"""
# [batch_size, n_ants, pair_emb]
pair_matrix = self._get_pair_matrix(
all_mentions, mentions_batch, pw_batch, top_indices_batch)
# [batch_size, n_ants]
scores = top_rough_scores_batch + self._ffnn(pair_matrix)
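        # add_dummy (ported from wl-coref's utils) prepends a column of
        # near-zero scores, giving every mention an explicit "no antecedent"
        # option at index 0.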
scores = utils.add_dummy(scores, eps=True)
return scores
def _ffnn(self, x: torch.Tensor) -> torch.Tensor:
"""
Calculates anaphoricity scores.
Args:
x: tensor of shape [batch_size, n_ants, n_features]
Returns:
tensor of shape [batch_size, n_ants]
"""
x = self.out(self.hidden(x))
return x.squeeze(2)
@staticmethod
def _get_pair_matrix(all_mentions: torch.Tensor,
mentions_batch: torch.Tensor,
pw_batch: torch.Tensor,
top_indices_batch: torch.Tensor,
) -> torch.Tensor:
"""
Builds the matrix used as input for AnaphoricityScorer.
Args:
all_mentions (torch.Tensor): [n_mentions, mention_emb],
all the valid mentions of the document,
can be on a different device
mentions_batch (torch.Tensor): [batch_size, mention_emb],
the mentions of the current batch,
is expected to be on the current device
pw_batch (torch.Tensor): [batch_size, n_ants, pw_emb],
pairwise features of the current batch,
is expected to be on the current device
top_indices_batch (torch.Tensor): [batch_size, n_ants],
indices of antecedents of each mention
Returns:
torch.Tensor: [batch_size, n_ants, pair_emb]
"""
emb_size = mentions_batch.shape[1]
n_ants = pw_batch.shape[1]
a_mentions = mentions_batch.unsqueeze(1).expand(-1, n_ants, emb_size)
b_mentions = all_mentions[top_indices_batch]
similarity = a_mentions * b_mentions
out = torch.cat((a_mentions, b_mentions, similarity, pw_batch), dim=2)
return out
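# Note: this concatenation is why pair_emb in CorefScorer is computed as
# bert_emb * 3 + self.pw.shape: mention embedding, antecedent embedding, their
# elementwise product, plus the pairwise distance features.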
class RoughScorer(torch.nn.Module):
"""
    Gives a rough estimate of the anaphoricity of two candidates; only the
    top-scoring candidates are considered in later steps to reduce
    computational complexity.
"""
def __init__(
self,
features: int,
dropout_rate: float,
        rough_k: int
):
super().__init__()
self.dropout = torch.nn.Dropout(dropout_rate)
self.bilinear = torch.nn.Linear(features, features)
self.k = rough_k
def forward(
self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
mentions: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
        Returns rough anaphoricity scores for candidates: the bilinear score
        for each pair, masked so that only preceding words can be antecedents.
"""
# [n_mentions, n_mentions]
pair_mask = torch.arange(mentions.shape[0])
pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
pair_mask = torch.log((pair_mask > 0).to(torch.float))
pair_mask = pair_mask.to(mentions.device)
bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
rough_scores = pair_mask + bilinear_scores
return self._prune(rough_scores)
def _prune(self,
rough_scores: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Selects top-k rough antecedent scores for each mention.
Args:
rough_scores: tensor of shape [n_mentions, n_mentions], containing
rough antecedent scores of each mention-antecedent pair.
Returns:
FloatTensor of shape [n_mentions, k], top rough scores
LongTensor of shape [n_mentions, k], top indices
"""
top_scores, indices = torch.topk(rough_scores,
k=min(self.k, len(rough_scores)),
dim=1, sorted=False)
return top_scores, indices
class DistancePairwiseEncoder(torch.nn.Module):
def __init__(self, embedding_size, dropout_rate):
super().__init__()
emb_size = embedding_size
self.distance_emb = torch.nn.Embedding(9, emb_size)
self.dropout = torch.nn.Dropout(dropout_rate)
self.shape = emb_size
@property
def device(self) -> torch.device:
""" A workaround to get current device (which is assumed to be the
device of the first parameter of one of the submodules) """
return next(self.distance_emb.parameters()).device
def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
top_indices: torch.Tensor
) -> torch.Tensor:
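        # Bucket the distance between each word and its candidate antecedents:
        # distances 1-4 get their own buckets (0-3), larger distances fall into
        # log2 buckets (5-7 -> 4, 8-15 -> 5, 16-31 -> 6, 32-63 -> 7, 64+ -> 8),
        # which is why the embedding table has 9 rows.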
word_ids = torch.arange(0, top_indices.size(0), device=self.device)
distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
).clamp_min_(min=1)
log_distance = distance.to(torch.float).log2().floor_()
log_distance = log_distance.clamp_max_(max=6).to(torch.long)
distance = torch.where(distance < 5, distance - 1, log_distance + 2)
distance = self.distance_emb(distance)
return self.dropout(distance)

View File

@@ -6,7 +6,7 @@ from numpy.testing import assert_array_equal, assert_array_almost_equal
import numpy
from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder
from spacy.ml.models import build_bow_text_classifier, build_simple_cnn_text_classifier
from spacy.ml.models import build_spancat_model
from spacy.ml.models import build_spancat_model, build_wl_coref_model
from spacy.ml.staticvectors import StaticVectors
from spacy.ml.extract_spans import extract_spans, _get_span_indices
from spacy.lang.en import English
@@ -269,3 +269,8 @@ def test_spancat_model_forward_backward(nO=5):
Y, backprop = model((docs, spans), is_train=True)
assert Y.shape == (spans.dataXd.shape[0], nO)
backprop(Y)
#TODO expand this
def test_coref_model_init():
tok2vec = build_Tok2Vec_model(**get_tok2vec_kwargs())
model = build_wl_coref_model(tok2vec)
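    # For now this only checks that construction does not raise; a fuller test
    # could feed a small batch of word vectors through the wrapped scorer and
    # check the shapes of the returned scores and antecedent indices.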