Preparing span predictor for predicting from gold (#10547)

Note this is squashed because rebasing had conflicts.

* remove unnecessary .device

* span predictor debug start

* gearing up SpanPredictor for gold-heads

* merge SpanPredictor attributes

* remove useless extra prefix and device from spanpredictor

* make sure predicted and reference stay aligned

* handle empty head_ids

* handle empty clusters

* addressing suggestions by @polm

* nicer restore

* fix score overwriting bug

* prepare for aligned heads-spans training

* span accuracy score

* update with eg.predicted as in other components

* add backprop callback to spanpredictor

* report start- and end-accuracies separately

* fixing scorer

Co-authored-by: Kádár Ákos <akos@onyx.uvt.nl>
kadarakos 2022-04-13 12:42:49 +02:00 committed by GitHub
parent eec00ce60d
commit b53113e3b8
2 changed files with 114 additions and 112 deletions


@@ -40,11 +40,8 @@ def build_wl_coref_model(
with Model.define_operators({">>": chain}):
# TODO chain tok2vec with these models
# TODO fix device - should be automatic
device = "cuda:0"
coref_scorer = PyTorchWrapper(
CorefScorer(
device,
dim,
embedding_size,
hidden_size,
@@ -56,8 +53,16 @@ def build_wl_coref_model(
convert_inputs=convert_coref_scorer_inputs,
convert_outputs=convert_coref_scorer_outputs
)
coref_model = tok2vec >> coref_scorer
# XXX just ignore this until the coref scorer is integrated
# span_predictor = PyTorchWrapper(
# SpanPredictor(
# TODO this was hardcoded to 1024, check
# hidden_size,
# sp_embedding_size,
# ),
# convert_inputs=convert_span_predictor_inputs
# )
# TODO combine models so output is uniform (just one forward pass)
# It may be reasonable to have an option to disable span prediction,
# and just return words as spans.
@@ -77,14 +82,14 @@ def build_span_predictor(
dim = 768
with Model.define_operators({">>": chain, "&": tuplify}):
# TODO fix device - should be automatic
device = "cuda:0"
span_predictor = PyTorchWrapper(
SpanPredictor(dim, hidden_size, dist_emb_size, device),
SpanPredictor(dim, hidden_size, dist_emb_size),
convert_inputs=convert_span_predictor_inputs
)
# TODO use proper parameter for prefix
head_info = build_get_head_metadata("coref_head_clusters")
head_info = build_get_head_metadata(
"coref_head_clusters"
)
model = (tok2vec & head_info) >> span_predictor
return model
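
Note on the composition above: `(tok2vec & head_info) >> span_predictor` uses thinc's `tuplify` and `chain` combinators, so the wrapped PyTorch `SpanPredictor` gets both the token vectors and the head metadata in a single forward pass. A minimal sketch of that composition, with `noop()` placeholders standing in for the real layers (the placeholders are assumptions, not the actual spaCy config):

```python
from thinc.api import chain, tuplify, noop

tok2vec = noop()          # stand-in for the real tok2vec layer
head_info = noop()        # stand-in for build_get_head_metadata("coref_head_clusters")
span_predictor = noop()   # stand-in for the PyTorchWrapper around SpanPredictor

# `&` feeds the same input docs to both branches and returns a tuple of their
# outputs; `>>` then pipes that tuple into the span predictor.
model = chain(tuplify(tok2vec, head_info), span_predictor)
```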
@@ -99,13 +104,13 @@ def convert_coref_scorer_inputs(
# just use the first
# TODO real batching
X = X[0]
word_features = xp2torch(X, requires_grad=is_train)
def backprop(args: ArgsKwargs) -> List[Floats2d]:
# convert to xp and wrap in list
gradients = torch2xp(args.args[0])
return [gradients]
return ArgsKwargs(args=(word_features, ), kwargs={}), backprop
@@ -128,6 +133,7 @@ def convert_coref_scorer_outputs(
indices_xp = torch2xp(indices)
return (scores_xp, indices_xp), convert_for_torch_backward
def convert_span_predictor_inputs(
model: Model,
X: Tuple[Ints1d, Floats2d, Ints1d],
@@ -135,14 +141,23 @@ def convert_span_predictor_inputs(
):
tok2vec, (sent_ids, head_ids) = X
# Normally we should use the input is_train, but for these two it's not relevant
sent_ids = xp2torch(sent_ids[0], requires_grad=False)
head_ids = xp2torch(head_ids[0], requires_grad=False)
def backprop(args: ArgsKwargs) -> List[Floats2d]:
# convert to xp and wrap in list
gradients = torch2xp(args.args[1])
return [[gradients], None]
word_features = xp2torch(tok2vec[0], requires_grad=is_train)
sent_ids = xp2torch(sent_ids[0], requires_grad=False)
if not head_ids[0].size:
head_ids = torch.empty(size=(0,))
else:
head_ids = xp2torch(head_ids[0], requires_grad=False)
argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
# TODO actually support backprop
return argskwargs, lambda dX: []
return argskwargs, backprop
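
The new `backprop` callback mirrors the structure of the tuplified input `(tok2vec_output, (sent_ids, head_ids))`: only the word features carry a gradient, while the head-metadata branch gets `None`. A tiny, hypothetical illustration of that gradient layout (the array shape is made up):

```python
import numpy

# dummy gradient for the tok2vec branch, e.g. [n_tokens, emb_dim]
d_word_features = numpy.zeros((10, 768), dtype="float32")
# the gradient mirrors (tok2vec_output_list, head_metadata): a one-element
# list for the differentiable branch, None for the non-differentiable metadata
dX = [[d_word_features], None]
assert dX[1] is None and dX[0][0].shape == (10, 768)
```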
# TODO This probably belongs in the component, not the model.
def predict_span_clusters(span_predictor: Model,
@@ -211,18 +226,21 @@ def _clusterize(
clusters.append(sorted(cluster))
return sorted(clusters)
def build_get_head_metadata(prefix):
# TODO this name is awful, fix it
model = Model("HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward)
model = Model("HeadDataProvider",
attrs={'prefix': prefix},
forward=head_data_forward)
return model
def head_data_forward(model, docs, is_train):
"""A layer to generate the extra data needed for the span predictor.
"""
sent_ids = []
head_ids = []
prefix = model.attrs["prefix"]
for doc in docs:
sids = model.ops.asarray2i(get_sentence_ids(doc))
sent_ids.append(sids)
@@ -235,7 +253,6 @@ def head_data_forward(model, docs, is_train):
heads.append(span[0].i)
heads = model.ops.asarray2i(heads)
head_ids.append(heads)
# each of these is a list with one entry per doc
# backprop is just a placeholder
# TODO it would probably be better to have a list of tuples than two lists of arrays
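
For reference, a hedged sketch (not the actual layer) of the per-doc arrays that `head_data_forward` assembles: one sentence id per token, plus the index of the first token of each gold head span found under the configured prefix:

```python
import numpy

def head_metadata_sketch(token_sent_ids, head_spans):
    # token_sent_ids: sentence id per token; head_spans: (start, end) pairs
    # taken from doc.spans keys starting with e.g. "coref_head_clusters"
    sent_ids = numpy.asarray(token_sent_ids, dtype="int32")
    heads = numpy.asarray([start for start, end in head_spans], dtype="int32")
    return sent_ids, heads

sent_ids, heads = head_metadata_sketch([0, 0, 0, 1, 1], [(1, 2), (4, 5)])
assert heads.tolist() == [1, 4]
```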
@@ -256,7 +273,6 @@ class CorefScorer(torch.nn.Module):
"""
def __init__(
self,
device: str,
dim: int, # tok2vec size
dist_emb_size: int,
hidden_size: int,
@@ -273,8 +289,7 @@ class CorefScorer(torch.nn.Module):
epochs_trained (int): the number of epochs finished
(useful for warm start)
"""
# device, dist_emb_size, hidden_size, n_layers, dropout_rate
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate).to(device)
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate)
#TODO clean this up
bert_emb = dim
pair_emb = bert_emb * 3 + self.pw.shape
@@ -283,7 +298,7 @@ class CorefScorer(torch.nn.Module):
hidden_size,
n_layers,
dropout_rate
).to(device)
)
self.lstm = torch.nn.LSTM(
input_size=bert_emb,
hidden_size=bert_emb,
@@ -294,7 +309,7 @@ class CorefScorer(torch.nn.Module):
bert_emb,
dropout_rate,
roughk
).to(device)
)
self.batch_size = batch_size
def forward(
@@ -443,7 +458,6 @@ class AnaphoricityScorer(torch.nn.Module):
return out
class RoughScorer(torch.nn.Module):
"""
Gives a rough estimate of the anaphoricity of two candidates,
@@ -474,7 +488,6 @@ class RoughScorer(torch.nn.Module):
pair_mask = torch.arange(mentions.shape[0])
pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
pair_mask = torch.log((pair_mask > 0).to(torch.float))
pair_mask = pair_mask.to(mentions.device)
bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
rough_scores = pair_mask + bilinear_scores
@@ -501,7 +514,7 @@ class RoughScorer(torch.nn.Module):
class SpanPredictor(torch.nn.Module):
def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int, device):
def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int):
super().__init__()
# input size = single token size
# 64 = probably distance emb size
@@ -517,7 +530,6 @@ class SpanPredictor(torch.nn.Module):
# this use of dist_emb_size looks wrong but it was 64...?
torch.nn.Linear(256, dist_emb_size),
)
self.device = device
self.conv = torch.nn.Sequential(
torch.nn.Conv1d(64, 4, 3, 1, 1),
torch.nn.Conv1d(4, 2, 3, 1, 1)
@@ -541,17 +553,18 @@ class SpanPredictor(torch.nn.Module):
Returns:
torch.Tensor: span start/end scores, [n_heads, n_words, 2]
"""
# If we don't receive heads, return empty
if heads_ids.nelement() == 0:
return torch.empty(size=(0,))
# Obtain distance embedding indices, [n_heads, n_words]
relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0], device=words.device).unsqueeze(0))
relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0]).unsqueeze(0))
# make all valid distances positive
emb_ids = relative_positions + 63
# "too_far"
emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
# Obtain "same sentence" boolean mask, [n_heads, n_words]
sent_id = torch.tensor(sent_id, device=words.device)
heads_ids = heads_ids.long()
same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))
# To save memory, only pass candidates from one sentence for each head
# pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
# for each candidate among the words in the same sentence as span_head
@@ -562,23 +575,20 @@ class SpanPredictor(torch.nn.Module):
words[cols],
self.emb(emb_ids[rows, cols]),
), dim=1)
lengths = same_sent.sum(dim=1)
padding_mask = torch.arange(0, lengths.max().item(), device=words.device).unsqueeze(0)
padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0)
padding_mask = (padding_mask < lengths.unsqueeze(1)) # [n_heads, max_sent_len]
# [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
# This is necessary to allow the convolution layer to look at several
# word scores
padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1], device=words.device)
padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1])
padded_pairs[padding_mask] = pair_matrix
res = self.ffnn(padded_pairs) # [n_heads, n_candidates, last_layer_output]
res = self.conv(res.permute(0, 2, 1)).permute(0, 2, 1) # [n_heads, n_candidates, 2]
scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'), device=words.device)
scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'))
scores[rows, cols] = res[padding_mask]
# Make sure that start <= head <= end during inference
if not self.training:
valid_starts = torch.log((relative_positions >= 0).to(torch.float))
@@ -586,6 +596,7 @@ class SpanPredictor(torch.nn.Module):
valid_positions = torch.stack((valid_starts, valid_ends), dim=2)
return scores + valid_positions
return scores
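
A quick numeric check of the distance-embedding indexing above (assuming the 128-bucket layout implied by the +63 shift and the 0..126 window): relative positions within ±63 tokens get their own embedding row, and everything farther away shares the "too far" bucket 127.

```python
import torch

heads_ids = torch.tensor([0, 150])
n_words = 200
relative_positions = heads_ids.unsqueeze(1) - torch.arange(n_words).unsqueeze(0)
emb_ids = relative_positions + 63
# `|` is used here for clarity; the model code adds the two boolean masks
emb_ids[(emb_ids < 0) | (emb_ids > 126)] = 127
assert emb_ids.min() == 0 and emb_ids.max() == 127
```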
class DistancePairwiseEncoder(torch.nn.Module):
def __init__(self, embedding_size, dropout_rate):
@@ -595,17 +606,10 @@ class DistancePairwiseEncoder(torch.nn.Module):
self.dropout = torch.nn.Dropout(dropout_rate)
self.shape = emb_size
@property
def device(self) -> torch.device:
""" A workaround to get current device (which is assumed to be the
device of the first parameter of one of the submodules) """
return next(self.distance_emb.parameters()).device
def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
top_indices: torch.Tensor
) -> torch.Tensor:
word_ids = torch.arange(0, top_indices.size(0), device=self.device)
word_ids = torch.arange(0, top_indices.size(0))
distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
).clamp_min_(min=1)
log_distance = distance.to(torch.float).log2().floor_()
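
The pairwise encoder above buckets token distances logarithmically: distances are clamped to at least 1 and floor(log2) then collapses them into exponentially wider bins. A small worked example (the concrete distances are illustrative only):

```python
import torch

distance = torch.tensor([1, 2, 3, 4, 7, 8, 15, 16, 64]).clamp_min(1)
buckets = distance.to(torch.float).log2().floor_()
# 1 -> 0, 2-3 -> 1, 4-7 -> 2, 8-15 -> 3, 16-31 -> 4, 64 -> 6
assert buckets.tolist() == [0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 6.0]
```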


@@ -3,7 +3,7 @@ import warnings
from thinc.types import Floats2d, Floats3d, Ints2d
from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
from thinc.api import set_dropout_rate
from thinc.api import set_dropout_rate, to_categorical
from itertools import islice
from statistics import mean
@@ -130,7 +130,6 @@ class CoreferenceResolver(TrainablePipe):
DOCS: https://spacy.io/api/coref#predict (TODO)
"""
#print("DOCS", docs)
out = []
for doc in docs:
scores, idxs = self.model.predict([doc])
@@ -212,7 +211,6 @@ class CoreferenceResolver(TrainablePipe):
# TODO check this causes no issues (in practice it runs)
preds, backprop = self.model.begin_update([eg.predicted])
score_matrix, mention_idx = preds
loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
total_loss += loss
# TODO check shape here
@@ -366,9 +364,7 @@ class CoreferenceResolver(TrainablePipe):
for ex in examples:
p_clusters = doc2clusters(ex.predicted, self.span_cluster_prefix)
g_clusters = doc2clusters(ex.reference, self.span_cluster_prefix)
cluster_info = get_cluster_info(p_clusters, g_clusters)
evaluator.update(cluster_info)
score = {
@@ -460,8 +456,8 @@ class SpanPredictor(TrainablePipe):
out = []
for doc in docs:
# TODO check shape here
span_scores = self.model.predict(doc)
span_scores = span_scores[0]
span_scores = self.model.predict([doc])
if span_scores.size:
# the information about clustering has to come from the input docs
# first let's convert the scores to a list of span idxs
start_scores = span_scores[:, :, 0]
@@ -478,9 +474,11 @@ class SpanPredictor(TrainablePipe):
for cluster in clusters:
ncluster = []
for mention in cluster:
ncluster.append( (starts[cidx], ends[cidx]) )
ncluster.append((starts[cidx], ends[cidx]))
cidx += 1
out_clusters.append(ncluster)
else:
out_clusters = []
out.append(out_clusters)
return out
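
For context, a hedged sketch of how the `[n_heads, n_words, 2]` score tensor is turned back into one span per head during prediction: an argmax over the token axis, separately for the start and end channels (the exclusive-end offset here is an assumption, not taken from the diff):

```python
import numpy

span_scores = numpy.random.rand(3, 10, 2)        # 3 heads, 10 tokens
starts = span_scores[:, :, 0].argmax(axis=1)     # best start token per head
ends = span_scores[:, :, 1].argmax(axis=1) + 1   # assumed exclusive end
spans = list(zip(starts.tolist(), ends.tolist()))
# e.g. [(4, 8), (0, 3), (7, 10)] -- these pairs are then regrouped by the
# head clusters read from the input doc, as in the loop above
```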
@@ -505,21 +503,21 @@ class SpanPredictor(TrainablePipe):
losses = {}
losses.setdefault(self.name, 0.0)
validate_examples(examples, "SpanPredictor.update")
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
if not any(len(eg.reference) if eg.reference else 0 for eg in examples):
# Handle cases where there are no tokens in any docs.
return losses
set_dropout_rate(self.model, drop)
total_loss = 0
for eg in examples:
preds, backprop = self.model.begin_update([eg.predicted])
score_matrix, mention_idx = preds
loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
span_scores, backprop = self.model.begin_update([eg.predicted])
# FIXME, this only happens once in the first 1000 docs of OntoNotes
# and I'm not sure yet why.
if span_scores.size:
loss, d_scores = self.get_loss([eg], span_scores)
total_loss += loss
# TODO check shape here
backprop((d_scores, mention_idx))
backprop((d_scores))
if sgd is not None:
self.finish_update(sgd)
@@ -562,19 +560,17 @@ class SpanPredictor(TrainablePipe):
assert len(examples) == 1, "Only fake batching is supported."
# starts and ends are gold starts and ends (Ints1d)
# span_scores is a Floats3d. What are the axes? mention x token x start/end
for eg in examples:
# get gold data
gold = doc2clusters(eg.reference, self.output_prefix)
# flatten the gold data
starts = []
ends = []
for cluster in gold:
for mention in cluster:
starts.append(mention[0])
ends.append(mention[1])
for key, sg in eg.reference.spans.items():
if key.startswith(self.output_prefix):
for mention in sg:
starts.append(mention.start)
ends.append(mention.end)
starts = self.model.ops.xp.asarray(starts)
ends = self.model.ops.xp.asarray(ends)
start_scores = span_scores[:, :, 0]
end_scores = span_scores[:, :, 1]
n_classes = start_scores.shape[1]
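
The `to_categorical` import added at the top is what turns these gold indices into targets: each gold start/end becomes a one-hot row over the `n_classes` candidate tokens and is compared to the corresponding score channel, e.g. with the already-imported `CategoricalCrossentropy`. A rough sketch under that assumption (the dummy scores are kept in [0, 1] only so the example runs standalone):

```python
import numpy
from thinc.api import CategoricalCrossentropy, to_categorical

start_scores = numpy.random.rand(2, 5).astype("float32")  # [n_heads, n_words]
end_scores = numpy.random.rand(2, 5).astype("float32")
gold_starts = numpy.asarray([1, 3])
gold_ends = numpy.asarray([2, 4])
n_classes = start_scores.shape[1]

loss_fn = CategoricalCrossentropy(normalize=False)
d_start, start_loss = loss_fn(start_scores, to_categorical(gold_starts, n_classes))
d_end, end_loss = loss_fn(end_scores, to_categorical(gold_ends, n_classes))
total_loss = start_loss + end_loss
```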
@@ -594,7 +590,7 @@ class SpanPredictor(TrainablePipe):
*,
nlp: Optional[Language] = None,
) -> None:
validate_get_examples(get_examples, "CoreferenceResolver.initialize")
validate_get_examples(get_examples, "SpanPredictor.initialize")
X = []
Y = []
@@ -612,31 +608,33 @@ class SpanPredictor(TrainablePipe):
self.model.initialize(X=X, Y=Y)
def score(self, examples, **kwargs):
"""Score a batch of examples."""
# TODO This is basically the same as the main coref component - factor out?
"""
Evaluate on reconstructing the correct spans around
gold heads.
"""
scores = []
for metric in (b_cubed, muc, ceafe):
evaluator = Evaluator(metric)
xp = self.model.ops.xp
for eg in examples:
starts = []
ends = []
pred_starts = []
pred_ends = []
ref = eg.reference
pred = eg.predicted
for key, gold_sg in ref.spans.items():
if key.startswith(self.output_prefix):
pred_sg = pred.spans[key]
for gold_mention, pred_mention in zip(gold_sg, pred_sg):
starts.append(gold_mention.start)
ends.append(gold_mention.end)
pred_starts.append(pred_mention.start)
pred_ends.append(pred_mention.end)
for ex in examples:
# XXX this is the only different part
p_clusters = doc2clusters(ex.predicted, self.output_prefix)
g_clusters = doc2clusters(ex.reference, self.output_prefix)
cluster_info = get_cluster_info(p_clusters, g_clusters)
evaluator.update(cluster_info)
score = {
"coref_f": evaluator.get_f1(),
"coref_p": evaluator.get_precision(),
"coref_r": evaluator.get_recall(),
}
scores.append(score)
out = {}
for field in ("f", "p", "r"):
fname = f"coref_{field}"
out[fname] = mean([ss[fname] for ss in scores])
return out
starts = xp.asarray(starts)
ends = xp.asarray(ends)
pred_starts = xp.asarray(pred_starts)
pred_ends = xp.asarray(pred_ends)
correct = (starts == pred_starts) * (ends == pred_ends)
accuracy = correct.mean()
scores.append(float(accuracy))
return {"span_accuracy": mean(scores)}