Preparing span predictor for predicting from gold (#10547)

Note this is squashed because rebasing had conflicts.

* remove unnecessary .device

* span predictor debug start

* gearing up SpanPredictor for gold-heads

* merge SpanPredictor attributes

* remove useless extra prefix and device from spanpredictor

* make sure predicted and reference keeps aligned

* handle empty head_ids

* handle empty clusters

* addressing suggestions by @polm

* nicer restore

* fix score overwriting bug

* prepare for aligned heads-spans training

* span accuracy score

* update with eg.predited as other components

* add backprop callback to spanpredictor

* report start- and end-accuracies separately

* fixing scorer

Co-authored-by: Kádár Ákos <akos@onyx.uvt.nl>
This commit is contained in:
kadarakos 2022-04-13 12:42:49 +02:00 committed by GitHub
parent eec00ce60d
commit b53113e3b8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 114 additions and 112 deletions

View File

@ -37,14 +37,11 @@ def build_wl_coref_model(
except ValueError: except ValueError:
# happens with transformer listener # happens with transformer listener
dim = 768 dim = 768
with Model.define_operators({">>": chain}): with Model.define_operators({">>": chain}):
# TODO chain tok2vec with these models # TODO chain tok2vec with these models
# TODO fix device - should be automatic
device = "cuda:0"
coref_scorer = PyTorchWrapper( coref_scorer = PyTorchWrapper(
CorefScorer( CorefScorer(
device,
dim, dim,
embedding_size, embedding_size,
hidden_size, hidden_size,
@ -56,8 +53,16 @@ def build_wl_coref_model(
convert_inputs=convert_coref_scorer_inputs, convert_inputs=convert_coref_scorer_inputs,
convert_outputs=convert_coref_scorer_outputs convert_outputs=convert_coref_scorer_outputs
) )
coref_model = tok2vec >> coref_scorer coref_model = tok2vec >> coref_scorer
# XXX just ignore this until the coref scorer is integrated
# span_predictor = PyTorchWrapper(
# SpanPredictor(
# TODO this was hardcoded to 1024, check
# hidden_size,
# sp_embedding_size,
# ),
# convert_inputs=convert_span_predictor_inputs
# )
# TODO combine models so output is uniform (just one forward pass) # TODO combine models so output is uniform (just one forward pass)
# It may be reasonable to have an option to disable span prediction, # It may be reasonable to have an option to disable span prediction,
# and just return words as spans. # and just return words as spans.
@ -77,14 +82,14 @@ def build_span_predictor(
dim = 768 dim = 768
with Model.define_operators({">>": chain, "&": tuplify}): with Model.define_operators({">>": chain, "&": tuplify}):
# TODO fix device - should be automatic
device = "cuda:0"
span_predictor = PyTorchWrapper( span_predictor = PyTorchWrapper(
SpanPredictor(dim, hidden_size, dist_emb_size, device), SpanPredictor(dim, hidden_size, dist_emb_size),
convert_inputs=convert_span_predictor_inputs convert_inputs=convert_span_predictor_inputs
) )
# TODO use proper parameter for prefix # TODO use proper parameter for prefix
head_info = build_get_head_metadata("coref_head_clusters") head_info = build_get_head_metadata(
"coref_head_clusters"
)
model = (tok2vec & head_info) >> span_predictor model = (tok2vec & head_info) >> span_predictor
return model return model
@ -99,13 +104,13 @@ def convert_coref_scorer_inputs(
# just use the first # just use the first
# TODO real batching # TODO real batching
X = X[0] X = X[0]
word_features = xp2torch(X, requires_grad=is_train) word_features = xp2torch(X, requires_grad=is_train)
def backprop(args: ArgsKwargs) -> List[Floats2d]: def backprop(args: ArgsKwargs) -> List[Floats2d]:
# convert to xp and wrap in list # convert to xp and wrap in list
gradients = torch2xp(args.args[0]) gradients = torch2xp(args.args[0])
return [gradients] return [gradients]
return ArgsKwargs(args=(word_features, ), kwargs={}), backprop return ArgsKwargs(args=(word_features, ), kwargs={}), backprop
@ -128,6 +133,7 @@ def convert_coref_scorer_outputs(
indices_xp = torch2xp(indices) indices_xp = torch2xp(indices)
return (scores_xp, indices_xp), convert_for_torch_backward return (scores_xp, indices_xp), convert_for_torch_backward
def convert_span_predictor_inputs( def convert_span_predictor_inputs(
model: Model, model: Model,
X: Tuple[Ints1d, Floats2d, Ints1d], X: Tuple[Ints1d, Floats2d, Ints1d],
@ -135,14 +141,23 @@ def convert_span_predictor_inputs(
): ):
tok2vec, (sent_ids, head_ids) = X tok2vec, (sent_ids, head_ids) = X
# Normally we shoudl use the input is_train, but for these two it's not relevant # Normally we shoudl use the input is_train, but for these two it's not relevant
sent_ids = xp2torch(sent_ids[0], requires_grad=False)
head_ids = xp2torch(head_ids[0], requires_grad=False) def backprop(args: ArgsKwargs) -> List[Floats2d]:
# convert to xp and wrap in list
gradients = torch2xp(args.args[1])
return [[gradients], None]
word_features = xp2torch(tok2vec[0], requires_grad=is_train) word_features = xp2torch(tok2vec[0], requires_grad=is_train)
sent_ids = xp2torch(sent_ids[0], requires_grad=False)
if not head_ids[0].size:
head_ids = torch.empty(size=(0,))
else:
head_ids = xp2torch(head_ids[0], requires_grad=False)
argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={}) argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
# TODO actually support backprop # TODO actually support backprop
return argskwargs, lambda dX: [] return argskwargs, backprop
# TODO This probably belongs in the component, not the model. # TODO This probably belongs in the component, not the model.
def predict_span_clusters(span_predictor: Model, def predict_span_clusters(span_predictor: Model,
@ -211,18 +226,21 @@ def _clusterize(
clusters.append(sorted(cluster)) clusters.append(sorted(cluster))
return sorted(clusters) return sorted(clusters)
def build_get_head_metadata(prefix): def build_get_head_metadata(prefix):
# TODO this name is awful, fix it # TODO this name is awful, fix it
model = Model("HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward) model = Model("HeadDataProvider",
attrs={'prefix': prefix},
forward=head_data_forward)
return model return model
def head_data_forward(model, docs, is_train): def head_data_forward(model, docs, is_train):
"""A layer to generate the extra data needed for the span predictor. """A layer to generate the extra data needed for the span predictor.
""" """
sent_ids = [] sent_ids = []
head_ids = [] head_ids = []
prefix = model.attrs["prefix"] prefix = model.attrs["prefix"]
for doc in docs: for doc in docs:
sids = model.ops.asarray2i(get_sentence_ids(doc)) sids = model.ops.asarray2i(get_sentence_ids(doc))
sent_ids.append(sids) sent_ids.append(sids)
@ -235,7 +253,6 @@ def head_data_forward(model, docs, is_train):
heads.append(span[0].i) heads.append(span[0].i)
heads = model.ops.asarray2i(heads) heads = model.ops.asarray2i(heads)
head_ids.append(heads) head_ids.append(heads)
# each of these is a list with one entry per doc # each of these is a list with one entry per doc
# backprop is just a placeholder # backprop is just a placeholder
# TODO it would probably be better to have a list of tuples than two lists of arrays # TODO it would probably be better to have a list of tuples than two lists of arrays
@ -256,7 +273,6 @@ class CorefScorer(torch.nn.Module):
""" """
def __init__( def __init__(
self, self,
device: str,
dim: int, # tok2vec size dim: int, # tok2vec size
dist_emb_size: int, dist_emb_size: int,
hidden_size: int, hidden_size: int,
@ -273,8 +289,7 @@ class CorefScorer(torch.nn.Module):
epochs_trained (int): the number of epochs finished epochs_trained (int): the number of epochs finished
(useful for warm start) (useful for warm start)
""" """
# device, dist_emb_size, hidden_size, n_layers, dropout_rate self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate)
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate).to(device)
#TODO clean this up #TODO clean this up
bert_emb = dim bert_emb = dim
pair_emb = bert_emb * 3 + self.pw.shape pair_emb = bert_emb * 3 + self.pw.shape
@ -283,7 +298,7 @@ class CorefScorer(torch.nn.Module):
hidden_size, hidden_size,
n_layers, n_layers,
dropout_rate dropout_rate
).to(device) )
self.lstm = torch.nn.LSTM( self.lstm = torch.nn.LSTM(
input_size=bert_emb, input_size=bert_emb,
hidden_size=bert_emb, hidden_size=bert_emb,
@ -294,7 +309,7 @@ class CorefScorer(torch.nn.Module):
bert_emb, bert_emb,
dropout_rate, dropout_rate,
roughk roughk
).to(device) )
self.batch_size = batch_size self.batch_size = batch_size
def forward( def forward(
@ -443,7 +458,6 @@ class AnaphoricityScorer(torch.nn.Module):
return out return out
class RoughScorer(torch.nn.Module): class RoughScorer(torch.nn.Module):
""" """
Is needed to give a roughly estimate of the anaphoricity of two candidates, Is needed to give a roughly estimate of the anaphoricity of two candidates,
@ -474,7 +488,6 @@ class RoughScorer(torch.nn.Module):
pair_mask = torch.arange(mentions.shape[0]) pair_mask = torch.arange(mentions.shape[0])
pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0) pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
pair_mask = torch.log((pair_mask > 0).to(torch.float)) pair_mask = torch.log((pair_mask > 0).to(torch.float))
pair_mask = pair_mask.to(mentions.device)
bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T) bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
rough_scores = pair_mask + bilinear_scores rough_scores = pair_mask + bilinear_scores
@ -501,7 +514,7 @@ class RoughScorer(torch.nn.Module):
class SpanPredictor(torch.nn.Module): class SpanPredictor(torch.nn.Module):
def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int, device): def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int):
super().__init__() super().__init__()
# input size = single token size # input size = single token size
# 64 = probably distance emb size # 64 = probably distance emb size
@ -517,7 +530,6 @@ class SpanPredictor(torch.nn.Module):
# this use of dist_emb_size looks wrong but it was 64...? # this use of dist_emb_size looks wrong but it was 64...?
torch.nn.Linear(256, dist_emb_size), torch.nn.Linear(256, dist_emb_size),
) )
self.device = device
self.conv = torch.nn.Sequential( self.conv = torch.nn.Sequential(
torch.nn.Conv1d(64, 4, 3, 1, 1), torch.nn.Conv1d(64, 4, 3, 1, 1),
torch.nn.Conv1d(4, 2, 3, 1, 1) torch.nn.Conv1d(4, 2, 3, 1, 1)
@ -541,17 +553,18 @@ class SpanPredictor(torch.nn.Module):
Returns: Returns:
torch.Tensor: span start/end scores, [n_heads, n_words, 2] torch.Tensor: span start/end scores, [n_heads, n_words, 2]
""" """
# If we don't receive heads, return empty
if heads_ids.nelement() == 0:
return torch.empty(size=(0,))
# Obtain distance embedding indices, [n_heads, n_words] # Obtain distance embedding indices, [n_heads, n_words]
relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0], device=words.device).unsqueeze(0)) relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0]).unsqueeze(0))
# make all valid distances positive # make all valid distances positive
emb_ids = relative_positions + 63 emb_ids = relative_positions + 63
# "too_far" # "too_far"
emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127 emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
# Obtain "same sentence" boolean mask, [n_heads, n_words] # Obtain "same sentence" boolean mask, [n_heads, n_words]
sent_id = torch.tensor(sent_id, device=words.device)
heads_ids = heads_ids.long() heads_ids = heads_ids.long()
same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0)) same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))
# To save memory, only pass candidates from one sentence for each head # To save memory, only pass candidates from one sentence for each head
# pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb # pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
# for each candidate among the words in the same sentence as span_head # for each candidate among the words in the same sentence as span_head
@ -562,23 +575,20 @@ class SpanPredictor(torch.nn.Module):
words[cols], words[cols],
self.emb(emb_ids[rows, cols]), self.emb(emb_ids[rows, cols]),
), dim=1) ), dim=1)
lengths = same_sent.sum(dim=1) lengths = same_sent.sum(dim=1)
padding_mask = torch.arange(0, lengths.max().item(), device=words.device).unsqueeze(0) padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0)
padding_mask = (padding_mask < lengths.unsqueeze(1)) # [n_heads, max_sent_len] padding_mask = (padding_mask < lengths.unsqueeze(1)) # [n_heads, max_sent_len]
# [n_heads, max_sent_len, input_size * 2 + distance_emb_size] # [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
# This is necessary to allow the convolution layer to look at several # This is necessary to allow the convolution layer to look at several
# word scores # word scores
padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1], device=words.device) padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1])
padded_pairs[padding_mask] = pair_matrix padded_pairs[padding_mask] = pair_matrix
res = self.ffnn(padded_pairs) # [n_heads, n_candidates, last_layer_output] res = self.ffnn(padded_pairs) # [n_heads, n_candidates, last_layer_output]
res = self.conv(res.permute(0, 2, 1)).permute(0, 2, 1) # [n_heads, n_candidates, 2] res = self.conv(res.permute(0, 2, 1)).permute(0, 2, 1) # [n_heads, n_candidates, 2]
scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'), device=words.device) scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'))
scores[rows, cols] = res[padding_mask] scores[rows, cols] = res[padding_mask]
# Make sure that start <= head <= end during inference # Make sure that start <= head <= end during inference
if not self.training: if not self.training:
valid_starts = torch.log((relative_positions >= 0).to(torch.float)) valid_starts = torch.log((relative_positions >= 0).to(torch.float))
@ -586,6 +596,7 @@ class SpanPredictor(torch.nn.Module):
valid_positions = torch.stack((valid_starts, valid_ends), dim=2) valid_positions = torch.stack((valid_starts, valid_ends), dim=2)
return scores + valid_positions return scores + valid_positions
return scores return scores
class DistancePairwiseEncoder(torch.nn.Module): class DistancePairwiseEncoder(torch.nn.Module):
def __init__(self, embedding_size, dropout_rate): def __init__(self, embedding_size, dropout_rate):
@ -595,17 +606,10 @@ class DistancePairwiseEncoder(torch.nn.Module):
self.dropout = torch.nn.Dropout(dropout_rate) self.dropout = torch.nn.Dropout(dropout_rate)
self.shape = emb_size self.shape = emb_size
@property
def device(self) -> torch.device:
""" A workaround to get current device (which is assumed to be the
device of the first parameter of one of the submodules) """
return next(self.distance_emb.parameters()).device
def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
top_indices: torch.Tensor top_indices: torch.Tensor
) -> torch.Tensor: ) -> torch.Tensor:
word_ids = torch.arange(0, top_indices.size(0), device=self.device) word_ids = torch.arange(0, top_indices.size(0))
distance = (word_ids.unsqueeze(1) - word_ids[top_indices] distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
).clamp_min_(min=1) ).clamp_min_(min=1)
log_distance = distance.to(torch.float).log2().floor_() log_distance = distance.to(torch.float).log2().floor_()

View File

@ -3,7 +3,7 @@ import warnings
from thinc.types import Floats2d, Floats3d, Ints2d from thinc.types import Floats2d, Floats3d, Ints2d
from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
from thinc.api import set_dropout_rate from thinc.api import set_dropout_rate, to_categorical
from itertools import islice from itertools import islice
from statistics import mean from statistics import mean
@ -130,7 +130,6 @@ class CoreferenceResolver(TrainablePipe):
DOCS: https://spacy.io/api/coref#predict (TODO) DOCS: https://spacy.io/api/coref#predict (TODO)
""" """
#print("DOCS", docs)
out = [] out = []
for doc in docs: for doc in docs:
scores, idxs = self.model.predict([doc]) scores, idxs = self.model.predict([doc])
@ -212,7 +211,6 @@ class CoreferenceResolver(TrainablePipe):
# TODO check this causes no issues (in practice it runs) # TODO check this causes no issues (in practice it runs)
preds, backprop = self.model.begin_update([eg.predicted]) preds, backprop = self.model.begin_update([eg.predicted])
score_matrix, mention_idx = preds score_matrix, mention_idx = preds
loss, d_scores = self.get_loss([eg], score_matrix, mention_idx) loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
total_loss += loss total_loss += loss
# TODO check shape here # TODO check shape here
@ -366,9 +364,7 @@ class CoreferenceResolver(TrainablePipe):
for ex in examples: for ex in examples:
p_clusters = doc2clusters(ex.predicted, self.span_cluster_prefix) p_clusters = doc2clusters(ex.predicted, self.span_cluster_prefix)
g_clusters = doc2clusters(ex.reference, self.span_cluster_prefix) g_clusters = doc2clusters(ex.reference, self.span_cluster_prefix)
cluster_info = get_cluster_info(p_clusters, g_clusters) cluster_info = get_cluster_info(p_clusters, g_clusters)
evaluator.update(cluster_info) evaluator.update(cluster_info)
score = { score = {
@ -460,27 +456,29 @@ class SpanPredictor(TrainablePipe):
out = [] out = []
for doc in docs: for doc in docs:
# TODO check shape here # TODO check shape here
span_scores = self.model.predict(doc) span_scores = self.model.predict([doc])
span_scores = span_scores[0] if span_scores.size:
# the information about clustering has to come from the input docs # the information about clustering has to come from the input docs
# first let's convert the scores to a list of span idxs # first let's convert the scores to a list of span idxs
start_scores = span_scores[:, :, 0] start_scores = span_scores[:, :, 0]
end_scores = span_scores[:, :, 1] end_scores = span_scores[:, :, 1]
starts = start_scores.argmax(axis=1) starts = start_scores.argmax(axis=1)
ends = end_scores.argmax(axis=1) ends = end_scores.argmax(axis=1)
# TODO check start < end # TODO check start < end
# get the old clusters (shape will be preserved) # get the old clusters (shape will be preserved)
clusters = doc2clusters(doc, self.input_prefix) clusters = doc2clusters(doc, self.input_prefix)
cidx = 0 cidx = 0
out_clusters = [] out_clusters = []
for cluster in clusters: for cluster in clusters:
ncluster = [] ncluster = []
for mention in cluster: for mention in cluster:
ncluster.append( (starts[cidx], ends[cidx]) ) ncluster.append((starts[cidx], ends[cidx]))
cidx += 1 cidx += 1
out_clusters.append(ncluster) out_clusters.append(ncluster)
else:
out_clusters = []
out.append(out_clusters) out.append(out_clusters)
return out return out
@ -505,21 +503,21 @@ class SpanPredictor(TrainablePipe):
losses = {} losses = {}
losses.setdefault(self.name, 0.0) losses.setdefault(self.name, 0.0)
validate_examples(examples, "SpanPredictor.update") validate_examples(examples, "SpanPredictor.update")
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples): if not any(len(eg.reference) if eg.reference else 0 for eg in examples):
# Handle cases where there are no tokens in any docs. # Handle cases where there are no tokens in any docs.
return losses return losses
set_dropout_rate(self.model, drop) set_dropout_rate(self.model, drop)
total_loss = 0 total_loss = 0
for eg in examples: for eg in examples:
preds, backprop = self.model.begin_update([eg.predicted]) span_scores, backprop = self.model.begin_update([eg.predicted])
score_matrix, mention_idx = preds # FIXME, this only happens once in the first 1000 docs of OntoNotes
# and I'm not sure yet why.
loss, d_scores = self.get_loss([eg], score_matrix, mention_idx) if span_scores.size:
total_loss += loss loss, d_scores = self.get_loss([eg], span_scores)
# TODO check shape here total_loss += loss
backprop((d_scores, mention_idx)) # TODO check shape here
backprop((d_scores))
if sgd is not None: if sgd is not None:
self.finish_update(sgd) self.finish_update(sgd)
@ -562,19 +560,17 @@ class SpanPredictor(TrainablePipe):
assert len(examples) == 1, "Only fake batching is supported." assert len(examples) == 1, "Only fake batching is supported."
# starts and ends are gold starts and ends (Ints1d) # starts and ends are gold starts and ends (Ints1d)
# span_scores is a Floats3d. What are the axes? mention x token x start/end # span_scores is a Floats3d. What are the axes? mention x token x start/end
for eg in examples: for eg in examples:
# get gold data
gold = doc2clusters(eg.reference, self.output_prefix)
# flatten the gold data
starts = [] starts = []
ends = [] ends = []
for cluster in gold: for key, sg in eg.reference.spans.items():
for mention in cluster: if key.startswith(self.output_prefix):
starts.append(mention[0]) for mention in sg:
ends.append(mention[1]) starts.append(mention.start)
ends.append(mention.end)
starts = self.model.ops.xp.asarray(starts)
ends = self.model.ops.xp.asarray(ends)
start_scores = span_scores[:, :, 0] start_scores = span_scores[:, :, 0]
end_scores = span_scores[:, :, 1] end_scores = span_scores[:, :, 1]
n_classes = start_scores.shape[1] n_classes = start_scores.shape[1]
@ -594,7 +590,7 @@ class SpanPredictor(TrainablePipe):
*, *,
nlp: Optional[Language] = None, nlp: Optional[Language] = None,
) -> None: ) -> None:
validate_get_examples(get_examples, "CoreferenceResolver.initialize") validate_get_examples(get_examples, "SpanPredictor.initialize")
X = [] X = []
Y = [] Y = []
@ -612,31 +608,33 @@ class SpanPredictor(TrainablePipe):
self.model.initialize(X=X, Y=Y) self.model.initialize(X=X, Y=Y)
def score(self, examples, **kwargs): def score(self, examples, **kwargs):
"""Score a batch of examples.""" """
# TODO This is basically the same as the main coref component - factor out? Evaluate on reconstructing the correct spans around
gold heads.
"""
scores = [] scores = []
for metric in (b_cubed, muc, ceafe): xp = self.model.ops.xp
evaluator = Evaluator(metric) for eg in examples:
starts = []
ends = []
pred_starts = []
pred_ends = []
ref = eg.reference
pred = eg.predicted
for key, gold_sg in ref.spans.items():
if key.startswith(self.output_prefix):
pred_sg = pred.spans[key]
for gold_mention, pred_mention in zip(gold_sg, pred_sg):
starts.append(gold_mention.start)
ends.append(gold_mention.end)
pred_starts.append(pred_mention.start)
pred_ends.append(pred_mention.end)
for ex in examples: starts = xp.asarray(starts)
# XXX this is the only different part ends = xp.asarray(ends)
p_clusters = doc2clusters(ex.predicted, self.output_prefix) pred_starts = xp.asarray(pred_starts)
g_clusters = doc2clusters(ex.reference, self.output_prefix) pred_ends = xp.asarray(pred_ends)
correct = (starts == pred_starts) * (ends == pred_ends)
cluster_info = get_cluster_info(p_clusters, g_clusters) accuracy = correct.mean()
scores.append(float(accuracy))
evaluator.update(cluster_info) return {"span_accuracy": mean(scores)}
score = {
"coref_f": evaluator.get_f1(),
"coref_p": evaluator.get_precision(),
"coref_r": evaluator.get_recall(),
}
scores.append(score)
out = {}
for field in ("f", "p", "r"):
fname = f"coref_{field}"
out[fname] = mean([ss[fname] for ss in scores])
return out