mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 12:12:20 +03:00
Preparing span predictor for predicting from gold (#10547)
Note this is squashed because rebasing had conflicts. * remove unnecessary .device * span predictor debug start * gearing up SpanPredictor for gold-heads * merge SpanPredictor attributes * remove useless extra prefix and device from spanpredictor * make sure predicted and reference keeps aligned * handle empty head_ids * handle empty clusters * addressing suggestions by @polm * nicer restore * fix score overwriting bug * prepare for aligned heads-spans training * span accuracy score * update with eg.predited as other components * add backprop callback to spanpredictor * report start- and end-accuracies separately * fixing scorer Co-authored-by: Kádár Ákos <akos@onyx.uvt.nl>
This commit is contained in:
parent
eec00ce60d
commit
b53113e3b8
|
@ -37,14 +37,11 @@ def build_wl_coref_model(
|
|||
except ValueError:
|
||||
# happens with transformer listener
|
||||
dim = 768
|
||||
|
||||
|
||||
with Model.define_operators({">>": chain}):
|
||||
# TODO chain tok2vec with these models
|
||||
# TODO fix device - should be automatic
|
||||
device = "cuda:0"
|
||||
coref_scorer = PyTorchWrapper(
|
||||
CorefScorer(
|
||||
device,
|
||||
dim,
|
||||
embedding_size,
|
||||
hidden_size,
|
||||
|
@ -56,8 +53,16 @@ def build_wl_coref_model(
|
|||
convert_inputs=convert_coref_scorer_inputs,
|
||||
convert_outputs=convert_coref_scorer_outputs
|
||||
)
|
||||
|
||||
coref_model = tok2vec >> coref_scorer
|
||||
# XXX just ignore this until the coref scorer is integrated
|
||||
# span_predictor = PyTorchWrapper(
|
||||
# SpanPredictor(
|
||||
# TODO this was hardcoded to 1024, check
|
||||
# hidden_size,
|
||||
# sp_embedding_size,
|
||||
# ),
|
||||
# convert_inputs=convert_span_predictor_inputs
|
||||
# )
|
||||
# TODO combine models so output is uniform (just one forward pass)
|
||||
# It may be reasonable to have an option to disable span prediction,
|
||||
# and just return words as spans.
|
||||
|
@ -77,14 +82,14 @@ def build_span_predictor(
|
|||
dim = 768
|
||||
|
||||
with Model.define_operators({">>": chain, "&": tuplify}):
|
||||
# TODO fix device - should be automatic
|
||||
device = "cuda:0"
|
||||
span_predictor = PyTorchWrapper(
|
||||
SpanPredictor(dim, hidden_size, dist_emb_size, device),
|
||||
SpanPredictor(dim, hidden_size, dist_emb_size),
|
||||
convert_inputs=convert_span_predictor_inputs
|
||||
)
|
||||
# TODO use proper parameter for prefix
|
||||
head_info = build_get_head_metadata("coref_head_clusters")
|
||||
head_info = build_get_head_metadata(
|
||||
"coref_head_clusters"
|
||||
)
|
||||
model = (tok2vec & head_info) >> span_predictor
|
||||
|
||||
return model
|
||||
|
@ -99,13 +104,13 @@ def convert_coref_scorer_inputs(
|
|||
# just use the first
|
||||
# TODO real batching
|
||||
X = X[0]
|
||||
|
||||
|
||||
word_features = xp2torch(X, requires_grad=is_train)
|
||||
|
||||
def backprop(args: ArgsKwargs) -> List[Floats2d]:
|
||||
# convert to xp and wrap in list
|
||||
gradients = torch2xp(args.args[0])
|
||||
return [gradients]
|
||||
|
||||
return ArgsKwargs(args=(word_features, ), kwargs={}), backprop
|
||||
|
||||
|
||||
|
@ -128,6 +133,7 @@ def convert_coref_scorer_outputs(
|
|||
indices_xp = torch2xp(indices)
|
||||
return (scores_xp, indices_xp), convert_for_torch_backward
|
||||
|
||||
|
||||
def convert_span_predictor_inputs(
|
||||
model: Model,
|
||||
X: Tuple[Ints1d, Floats2d, Ints1d],
|
||||
|
@ -135,14 +141,23 @@ def convert_span_predictor_inputs(
|
|||
):
|
||||
tok2vec, (sent_ids, head_ids) = X
|
||||
# Normally we shoudl use the input is_train, but for these two it's not relevant
|
||||
sent_ids = xp2torch(sent_ids[0], requires_grad=False)
|
||||
head_ids = xp2torch(head_ids[0], requires_grad=False)
|
||||
|
||||
def backprop(args: ArgsKwargs) -> List[Floats2d]:
|
||||
# convert to xp and wrap in list
|
||||
gradients = torch2xp(args.args[1])
|
||||
return [[gradients], None]
|
||||
|
||||
word_features = xp2torch(tok2vec[0], requires_grad=is_train)
|
||||
sent_ids = xp2torch(sent_ids[0], requires_grad=False)
|
||||
if not head_ids[0].size:
|
||||
head_ids = torch.empty(size=(0,))
|
||||
else:
|
||||
head_ids = xp2torch(head_ids[0], requires_grad=False)
|
||||
|
||||
argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
|
||||
# TODO actually support backprop
|
||||
return argskwargs, lambda dX: []
|
||||
return argskwargs, backprop
|
||||
|
||||
|
||||
# TODO This probably belongs in the component, not the model.
|
||||
def predict_span_clusters(span_predictor: Model,
|
||||
|
@ -211,18 +226,21 @@ def _clusterize(
|
|||
clusters.append(sorted(cluster))
|
||||
return sorted(clusters)
|
||||
|
||||
|
||||
def build_get_head_metadata(prefix):
|
||||
# TODO this name is awful, fix it
|
||||
model = Model("HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward)
|
||||
model = Model("HeadDataProvider",
|
||||
attrs={'prefix': prefix},
|
||||
forward=head_data_forward)
|
||||
return model
|
||||
|
||||
|
||||
def head_data_forward(model, docs, is_train):
|
||||
"""A layer to generate the extra data needed for the span predictor.
|
||||
"""
|
||||
sent_ids = []
|
||||
head_ids = []
|
||||
prefix = model.attrs["prefix"]
|
||||
|
||||
for doc in docs:
|
||||
sids = model.ops.asarray2i(get_sentence_ids(doc))
|
||||
sent_ids.append(sids)
|
||||
|
@ -235,7 +253,6 @@ def head_data_forward(model, docs, is_train):
|
|||
heads.append(span[0].i)
|
||||
heads = model.ops.asarray2i(heads)
|
||||
head_ids.append(heads)
|
||||
|
||||
# each of these is a list with one entry per doc
|
||||
# backprop is just a placeholder
|
||||
# TODO it would probably be better to have a list of tuples than two lists of arrays
|
||||
|
@ -256,7 +273,6 @@ class CorefScorer(torch.nn.Module):
|
|||
"""
|
||||
def __init__(
|
||||
self,
|
||||
device: str,
|
||||
dim: int, # tok2vec size
|
||||
dist_emb_size: int,
|
||||
hidden_size: int,
|
||||
|
@ -273,8 +289,7 @@ class CorefScorer(torch.nn.Module):
|
|||
epochs_trained (int): the number of epochs finished
|
||||
(useful for warm start)
|
||||
"""
|
||||
# device, dist_emb_size, hidden_size, n_layers, dropout_rate
|
||||
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate).to(device)
|
||||
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate)
|
||||
#TODO clean this up
|
||||
bert_emb = dim
|
||||
pair_emb = bert_emb * 3 + self.pw.shape
|
||||
|
@ -283,7 +298,7 @@ class CorefScorer(torch.nn.Module):
|
|||
hidden_size,
|
||||
n_layers,
|
||||
dropout_rate
|
||||
).to(device)
|
||||
)
|
||||
self.lstm = torch.nn.LSTM(
|
||||
input_size=bert_emb,
|
||||
hidden_size=bert_emb,
|
||||
|
@ -294,7 +309,7 @@ class CorefScorer(torch.nn.Module):
|
|||
bert_emb,
|
||||
dropout_rate,
|
||||
roughk
|
||||
).to(device)
|
||||
)
|
||||
self.batch_size = batch_size
|
||||
|
||||
def forward(
|
||||
|
@ -443,7 +458,6 @@ class AnaphoricityScorer(torch.nn.Module):
|
|||
return out
|
||||
|
||||
|
||||
|
||||
class RoughScorer(torch.nn.Module):
|
||||
"""
|
||||
Is needed to give a roughly estimate of the anaphoricity of two candidates,
|
||||
|
@ -474,7 +488,6 @@ class RoughScorer(torch.nn.Module):
|
|||
pair_mask = torch.arange(mentions.shape[0])
|
||||
pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
|
||||
pair_mask = torch.log((pair_mask > 0).to(torch.float))
|
||||
pair_mask = pair_mask.to(mentions.device)
|
||||
bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
|
||||
rough_scores = pair_mask + bilinear_scores
|
||||
|
||||
|
@ -501,7 +514,7 @@ class RoughScorer(torch.nn.Module):
|
|||
|
||||
|
||||
class SpanPredictor(torch.nn.Module):
|
||||
def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int, device):
|
||||
def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int):
|
||||
super().__init__()
|
||||
# input size = single token size
|
||||
# 64 = probably distance emb size
|
||||
|
@ -517,7 +530,6 @@ class SpanPredictor(torch.nn.Module):
|
|||
# this use of dist_emb_size looks wrong but it was 64...?
|
||||
torch.nn.Linear(256, dist_emb_size),
|
||||
)
|
||||
self.device = device
|
||||
self.conv = torch.nn.Sequential(
|
||||
torch.nn.Conv1d(64, 4, 3, 1, 1),
|
||||
torch.nn.Conv1d(4, 2, 3, 1, 1)
|
||||
|
@ -541,17 +553,18 @@ class SpanPredictor(torch.nn.Module):
|
|||
Returns:
|
||||
torch.Tensor: span start/end scores, [n_heads, n_words, 2]
|
||||
"""
|
||||
# If we don't receive heads, return empty
|
||||
if heads_ids.nelement() == 0:
|
||||
return torch.empty(size=(0,))
|
||||
# Obtain distance embedding indices, [n_heads, n_words]
|
||||
relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0], device=words.device).unsqueeze(0))
|
||||
relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0]).unsqueeze(0))
|
||||
# make all valid distances positive
|
||||
emb_ids = relative_positions + 63
|
||||
# "too_far"
|
||||
emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
|
||||
# Obtain "same sentence" boolean mask, [n_heads, n_words]
|
||||
sent_id = torch.tensor(sent_id, device=words.device)
|
||||
heads_ids = heads_ids.long()
|
||||
same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))
|
||||
|
||||
# To save memory, only pass candidates from one sentence for each head
|
||||
# pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
|
||||
# for each candidate among the words in the same sentence as span_head
|
||||
|
@ -562,23 +575,20 @@ class SpanPredictor(torch.nn.Module):
|
|||
words[cols],
|
||||
self.emb(emb_ids[rows, cols]),
|
||||
), dim=1)
|
||||
|
||||
lengths = same_sent.sum(dim=1)
|
||||
padding_mask = torch.arange(0, lengths.max().item(), device=words.device).unsqueeze(0)
|
||||
padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0)
|
||||
padding_mask = (padding_mask < lengths.unsqueeze(1)) # [n_heads, max_sent_len]
|
||||
|
||||
# [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
|
||||
# This is necessary to allow the convolution layer to look at several
|
||||
# word scores
|
||||
padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1], device=words.device)
|
||||
padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1])
|
||||
padded_pairs[padding_mask] = pair_matrix
|
||||
|
||||
res = self.ffnn(padded_pairs) # [n_heads, n_candidates, last_layer_output]
|
||||
res = self.conv(res.permute(0, 2, 1)).permute(0, 2, 1) # [n_heads, n_candidates, 2]
|
||||
|
||||
scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'), device=words.device)
|
||||
scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'))
|
||||
scores[rows, cols] = res[padding_mask]
|
||||
|
||||
# Make sure that start <= head <= end during inference
|
||||
if not self.training:
|
||||
valid_starts = torch.log((relative_positions >= 0).to(torch.float))
|
||||
|
@ -586,6 +596,7 @@ class SpanPredictor(torch.nn.Module):
|
|||
valid_positions = torch.stack((valid_starts, valid_ends), dim=2)
|
||||
return scores + valid_positions
|
||||
return scores
|
||||
|
||||
class DistancePairwiseEncoder(torch.nn.Module):
|
||||
|
||||
def __init__(self, embedding_size, dropout_rate):
|
||||
|
@ -595,17 +606,10 @@ class DistancePairwiseEncoder(torch.nn.Module):
|
|||
self.dropout = torch.nn.Dropout(dropout_rate)
|
||||
self.shape = emb_size
|
||||
|
||||
@property
|
||||
def device(self) -> torch.device:
|
||||
""" A workaround to get current device (which is assumed to be the
|
||||
device of the first parameter of one of the submodules) """
|
||||
return next(self.distance_emb.parameters()).device
|
||||
|
||||
|
||||
def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
|
||||
top_indices: torch.Tensor
|
||||
) -> torch.Tensor:
|
||||
word_ids = torch.arange(0, top_indices.size(0), device=self.device)
|
||||
word_ids = torch.arange(0, top_indices.size(0))
|
||||
distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
|
||||
).clamp_min_(min=1)
|
||||
log_distance = distance.to(torch.float).log2().floor_()
|
||||
|
|
|
@ -3,7 +3,7 @@ import warnings
|
|||
|
||||
from thinc.types import Floats2d, Floats3d, Ints2d
|
||||
from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
|
||||
from thinc.api import set_dropout_rate
|
||||
from thinc.api import set_dropout_rate, to_categorical
|
||||
from itertools import islice
|
||||
from statistics import mean
|
||||
|
||||
|
@ -130,7 +130,6 @@ class CoreferenceResolver(TrainablePipe):
|
|||
|
||||
DOCS: https://spacy.io/api/coref#predict (TODO)
|
||||
"""
|
||||
#print("DOCS", docs)
|
||||
out = []
|
||||
for doc in docs:
|
||||
scores, idxs = self.model.predict([doc])
|
||||
|
@ -212,7 +211,6 @@ class CoreferenceResolver(TrainablePipe):
|
|||
# TODO check this causes no issues (in practice it runs)
|
||||
preds, backprop = self.model.begin_update([eg.predicted])
|
||||
score_matrix, mention_idx = preds
|
||||
|
||||
loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
|
||||
total_loss += loss
|
||||
# TODO check shape here
|
||||
|
@ -366,9 +364,7 @@ class CoreferenceResolver(TrainablePipe):
|
|||
for ex in examples:
|
||||
p_clusters = doc2clusters(ex.predicted, self.span_cluster_prefix)
|
||||
g_clusters = doc2clusters(ex.reference, self.span_cluster_prefix)
|
||||
|
||||
cluster_info = get_cluster_info(p_clusters, g_clusters)
|
||||
|
||||
evaluator.update(cluster_info)
|
||||
|
||||
score = {
|
||||
|
@ -460,27 +456,29 @@ class SpanPredictor(TrainablePipe):
|
|||
out = []
|
||||
for doc in docs:
|
||||
# TODO check shape here
|
||||
span_scores = self.model.predict(doc)
|
||||
span_scores = span_scores[0]
|
||||
# the information about clustering has to come from the input docs
|
||||
# first let's convert the scores to a list of span idxs
|
||||
start_scores = span_scores[:, :, 0]
|
||||
end_scores = span_scores[:, :, 1]
|
||||
starts = start_scores.argmax(axis=1)
|
||||
ends = end_scores.argmax(axis=1)
|
||||
span_scores = self.model.predict([doc])
|
||||
if span_scores.size:
|
||||
# the information about clustering has to come from the input docs
|
||||
# first let's convert the scores to a list of span idxs
|
||||
start_scores = span_scores[:, :, 0]
|
||||
end_scores = span_scores[:, :, 1]
|
||||
starts = start_scores.argmax(axis=1)
|
||||
ends = end_scores.argmax(axis=1)
|
||||
|
||||
# TODO check start < end
|
||||
# TODO check start < end
|
||||
|
||||
# get the old clusters (shape will be preserved)
|
||||
clusters = doc2clusters(doc, self.input_prefix)
|
||||
cidx = 0
|
||||
out_clusters = []
|
||||
for cluster in clusters:
|
||||
ncluster = []
|
||||
for mention in cluster:
|
||||
ncluster.append( (starts[cidx], ends[cidx]) )
|
||||
cidx += 1
|
||||
out_clusters.append(ncluster)
|
||||
# get the old clusters (shape will be preserved)
|
||||
clusters = doc2clusters(doc, self.input_prefix)
|
||||
cidx = 0
|
||||
out_clusters = []
|
||||
for cluster in clusters:
|
||||
ncluster = []
|
||||
for mention in cluster:
|
||||
ncluster.append((starts[cidx], ends[cidx]))
|
||||
cidx += 1
|
||||
out_clusters.append(ncluster)
|
||||
else:
|
||||
out_clusters = []
|
||||
out.append(out_clusters)
|
||||
return out
|
||||
|
||||
|
@ -505,21 +503,21 @@ class SpanPredictor(TrainablePipe):
|
|||
losses = {}
|
||||
losses.setdefault(self.name, 0.0)
|
||||
validate_examples(examples, "SpanPredictor.update")
|
||||
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
|
||||
if not any(len(eg.reference) if eg.reference else 0 for eg in examples):
|
||||
# Handle cases where there are no tokens in any docs.
|
||||
return losses
|
||||
set_dropout_rate(self.model, drop)
|
||||
|
||||
total_loss = 0
|
||||
|
||||
for eg in examples:
|
||||
preds, backprop = self.model.begin_update([eg.predicted])
|
||||
score_matrix, mention_idx = preds
|
||||
|
||||
loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
|
||||
total_loss += loss
|
||||
# TODO check shape here
|
||||
backprop((d_scores, mention_idx))
|
||||
span_scores, backprop = self.model.begin_update([eg.predicted])
|
||||
# FIXME, this only happens once in the first 1000 docs of OntoNotes
|
||||
# and I'm not sure yet why.
|
||||
if span_scores.size:
|
||||
loss, d_scores = self.get_loss([eg], span_scores)
|
||||
total_loss += loss
|
||||
# TODO check shape here
|
||||
backprop((d_scores))
|
||||
|
||||
if sgd is not None:
|
||||
self.finish_update(sgd)
|
||||
|
@ -562,19 +560,17 @@ class SpanPredictor(TrainablePipe):
|
|||
assert len(examples) == 1, "Only fake batching is supported."
|
||||
# starts and ends are gold starts and ends (Ints1d)
|
||||
# span_scores is a Floats3d. What are the axes? mention x token x start/end
|
||||
|
||||
for eg in examples:
|
||||
|
||||
# get gold data
|
||||
gold = doc2clusters(eg.reference, self.output_prefix)
|
||||
# flatten the gold data
|
||||
starts = []
|
||||
ends = []
|
||||
for cluster in gold:
|
||||
for mention in cluster:
|
||||
starts.append(mention[0])
|
||||
ends.append(mention[1])
|
||||
for key, sg in eg.reference.spans.items():
|
||||
if key.startswith(self.output_prefix):
|
||||
for mention in sg:
|
||||
starts.append(mention.start)
|
||||
ends.append(mention.end)
|
||||
|
||||
starts = self.model.ops.xp.asarray(starts)
|
||||
ends = self.model.ops.xp.asarray(ends)
|
||||
start_scores = span_scores[:, :, 0]
|
||||
end_scores = span_scores[:, :, 1]
|
||||
n_classes = start_scores.shape[1]
|
||||
|
@ -594,7 +590,7 @@ class SpanPredictor(TrainablePipe):
|
|||
*,
|
||||
nlp: Optional[Language] = None,
|
||||
) -> None:
|
||||
validate_get_examples(get_examples, "CoreferenceResolver.initialize")
|
||||
validate_get_examples(get_examples, "SpanPredictor.initialize")
|
||||
|
||||
X = []
|
||||
Y = []
|
||||
|
@ -612,31 +608,33 @@ class SpanPredictor(TrainablePipe):
|
|||
self.model.initialize(X=X, Y=Y)
|
||||
|
||||
def score(self, examples, **kwargs):
|
||||
"""Score a batch of examples."""
|
||||
# TODO This is basically the same as the main coref component - factor out?
|
||||
|
||||
"""
|
||||
Evaluate on reconstructing the correct spans around
|
||||
gold heads.
|
||||
"""
|
||||
scores = []
|
||||
for metric in (b_cubed, muc, ceafe):
|
||||
evaluator = Evaluator(metric)
|
||||
xp = self.model.ops.xp
|
||||
for eg in examples:
|
||||
starts = []
|
||||
ends = []
|
||||
pred_starts = []
|
||||
pred_ends = []
|
||||
ref = eg.reference
|
||||
pred = eg.predicted
|
||||
for key, gold_sg in ref.spans.items():
|
||||
if key.startswith(self.output_prefix):
|
||||
pred_sg = pred.spans[key]
|
||||
for gold_mention, pred_mention in zip(gold_sg, pred_sg):
|
||||
starts.append(gold_mention.start)
|
||||
ends.append(gold_mention.end)
|
||||
pred_starts.append(pred_mention.start)
|
||||
pred_ends.append(pred_mention.end)
|
||||
|
||||
for ex in examples:
|
||||
# XXX this is the only different part
|
||||
p_clusters = doc2clusters(ex.predicted, self.output_prefix)
|
||||
g_clusters = doc2clusters(ex.reference, self.output_prefix)
|
||||
|
||||
cluster_info = get_cluster_info(p_clusters, g_clusters)
|
||||
|
||||
evaluator.update(cluster_info)
|
||||
|
||||
score = {
|
||||
"coref_f": evaluator.get_f1(),
|
||||
"coref_p": evaluator.get_precision(),
|
||||
"coref_r": evaluator.get_recall(),
|
||||
}
|
||||
scores.append(score)
|
||||
|
||||
out = {}
|
||||
for field in ("f", "p", "r"):
|
||||
fname = f"coref_{field}"
|
||||
out[fname] = mean([ss[fname] for ss in scores])
|
||||
return out
|
||||
starts = xp.asarray(starts)
|
||||
ends = xp.asarray(ends)
|
||||
pred_starts = xp.asarray(pred_starts)
|
||||
pred_ends = xp.asarray(pred_ends)
|
||||
correct = (starts == pred_starts) * (ends == pred_ends)
|
||||
accuracy = correct.mean()
|
||||
scores.append(float(accuracy))
|
||||
return {"span_accuracy": mean(scores)}
|
||||
|
|
Loading…
Reference in New Issue
Block a user