mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-18 12:12:20 +03:00
Preparing span predictor for predicting from gold (#10547)
Note this is squashed because rebasing had conflicts. * remove unnecessary .device * span predictor debug start * gearing up SpanPredictor for gold-heads * merge SpanPredictor attributes * remove useless extra prefix and device from spanpredictor * make sure predicted and reference keeps aligned * handle empty head_ids * handle empty clusters * addressing suggestions by @polm * nicer restore * fix score overwriting bug * prepare for aligned heads-spans training * span accuracy score * update with eg.predited as other components * add backprop callback to spanpredictor * report start- and end-accuracies separately * fixing scorer Co-authored-by: Kádár Ákos <akos@onyx.uvt.nl>
This commit is contained in:
parent
eec00ce60d
commit
b53113e3b8
|
@ -40,11 +40,8 @@ def build_wl_coref_model(
|
||||||
|
|
||||||
with Model.define_operators({">>": chain}):
|
with Model.define_operators({">>": chain}):
|
||||||
# TODO chain tok2vec with these models
|
# TODO chain tok2vec with these models
|
||||||
# TODO fix device - should be automatic
|
|
||||||
device = "cuda:0"
|
|
||||||
coref_scorer = PyTorchWrapper(
|
coref_scorer = PyTorchWrapper(
|
||||||
CorefScorer(
|
CorefScorer(
|
||||||
device,
|
|
||||||
dim,
|
dim,
|
||||||
embedding_size,
|
embedding_size,
|
||||||
hidden_size,
|
hidden_size,
|
||||||
|
@ -56,8 +53,16 @@ def build_wl_coref_model(
|
||||||
convert_inputs=convert_coref_scorer_inputs,
|
convert_inputs=convert_coref_scorer_inputs,
|
||||||
convert_outputs=convert_coref_scorer_outputs
|
convert_outputs=convert_coref_scorer_outputs
|
||||||
)
|
)
|
||||||
|
|
||||||
coref_model = tok2vec >> coref_scorer
|
coref_model = tok2vec >> coref_scorer
|
||||||
|
# XXX just ignore this until the coref scorer is integrated
|
||||||
|
# span_predictor = PyTorchWrapper(
|
||||||
|
# SpanPredictor(
|
||||||
|
# TODO this was hardcoded to 1024, check
|
||||||
|
# hidden_size,
|
||||||
|
# sp_embedding_size,
|
||||||
|
# ),
|
||||||
|
# convert_inputs=convert_span_predictor_inputs
|
||||||
|
# )
|
||||||
# TODO combine models so output is uniform (just one forward pass)
|
# TODO combine models so output is uniform (just one forward pass)
|
||||||
# It may be reasonable to have an option to disable span prediction,
|
# It may be reasonable to have an option to disable span prediction,
|
||||||
# and just return words as spans.
|
# and just return words as spans.
|
||||||
|
@ -77,14 +82,14 @@ def build_span_predictor(
|
||||||
dim = 768
|
dim = 768
|
||||||
|
|
||||||
with Model.define_operators({">>": chain, "&": tuplify}):
|
with Model.define_operators({">>": chain, "&": tuplify}):
|
||||||
# TODO fix device - should be automatic
|
|
||||||
device = "cuda:0"
|
|
||||||
span_predictor = PyTorchWrapper(
|
span_predictor = PyTorchWrapper(
|
||||||
SpanPredictor(dim, hidden_size, dist_emb_size, device),
|
SpanPredictor(dim, hidden_size, dist_emb_size),
|
||||||
convert_inputs=convert_span_predictor_inputs
|
convert_inputs=convert_span_predictor_inputs
|
||||||
)
|
)
|
||||||
# TODO use proper parameter for prefix
|
# TODO use proper parameter for prefix
|
||||||
head_info = build_get_head_metadata("coref_head_clusters")
|
head_info = build_get_head_metadata(
|
||||||
|
"coref_head_clusters"
|
||||||
|
)
|
||||||
model = (tok2vec & head_info) >> span_predictor
|
model = (tok2vec & head_info) >> span_predictor
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
@ -99,13 +104,13 @@ def convert_coref_scorer_inputs(
|
||||||
# just use the first
|
# just use the first
|
||||||
# TODO real batching
|
# TODO real batching
|
||||||
X = X[0]
|
X = X[0]
|
||||||
|
|
||||||
|
|
||||||
word_features = xp2torch(X, requires_grad=is_train)
|
word_features = xp2torch(X, requires_grad=is_train)
|
||||||
|
|
||||||
def backprop(args: ArgsKwargs) -> List[Floats2d]:
|
def backprop(args: ArgsKwargs) -> List[Floats2d]:
|
||||||
# convert to xp and wrap in list
|
# convert to xp and wrap in list
|
||||||
gradients = torch2xp(args.args[0])
|
gradients = torch2xp(args.args[0])
|
||||||
return [gradients]
|
return [gradients]
|
||||||
|
|
||||||
return ArgsKwargs(args=(word_features, ), kwargs={}), backprop
|
return ArgsKwargs(args=(word_features, ), kwargs={}), backprop
|
||||||
|
|
||||||
|
|
||||||
|
@ -128,6 +133,7 @@ def convert_coref_scorer_outputs(
|
||||||
indices_xp = torch2xp(indices)
|
indices_xp = torch2xp(indices)
|
||||||
return (scores_xp, indices_xp), convert_for_torch_backward
|
return (scores_xp, indices_xp), convert_for_torch_backward
|
||||||
|
|
||||||
|
|
||||||
def convert_span_predictor_inputs(
|
def convert_span_predictor_inputs(
|
||||||
model: Model,
|
model: Model,
|
||||||
X: Tuple[Ints1d, Floats2d, Ints1d],
|
X: Tuple[Ints1d, Floats2d, Ints1d],
|
||||||
|
@ -135,14 +141,23 @@ def convert_span_predictor_inputs(
|
||||||
):
|
):
|
||||||
tok2vec, (sent_ids, head_ids) = X
|
tok2vec, (sent_ids, head_ids) = X
|
||||||
# Normally we shoudl use the input is_train, but for these two it's not relevant
|
# Normally we shoudl use the input is_train, but for these two it's not relevant
|
||||||
sent_ids = xp2torch(sent_ids[0], requires_grad=False)
|
|
||||||
head_ids = xp2torch(head_ids[0], requires_grad=False)
|
def backprop(args: ArgsKwargs) -> List[Floats2d]:
|
||||||
|
# convert to xp and wrap in list
|
||||||
|
gradients = torch2xp(args.args[1])
|
||||||
|
return [[gradients], None]
|
||||||
|
|
||||||
word_features = xp2torch(tok2vec[0], requires_grad=is_train)
|
word_features = xp2torch(tok2vec[0], requires_grad=is_train)
|
||||||
|
sent_ids = xp2torch(sent_ids[0], requires_grad=False)
|
||||||
|
if not head_ids[0].size:
|
||||||
|
head_ids = torch.empty(size=(0,))
|
||||||
|
else:
|
||||||
|
head_ids = xp2torch(head_ids[0], requires_grad=False)
|
||||||
|
|
||||||
argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
|
argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
|
||||||
# TODO actually support backprop
|
# TODO actually support backprop
|
||||||
return argskwargs, lambda dX: []
|
return argskwargs, backprop
|
||||||
|
|
||||||
|
|
||||||
# TODO This probably belongs in the component, not the model.
|
# TODO This probably belongs in the component, not the model.
|
||||||
def predict_span_clusters(span_predictor: Model,
|
def predict_span_clusters(span_predictor: Model,
|
||||||
|
@ -211,18 +226,21 @@ def _clusterize(
|
||||||
clusters.append(sorted(cluster))
|
clusters.append(sorted(cluster))
|
||||||
return sorted(clusters)
|
return sorted(clusters)
|
||||||
|
|
||||||
|
|
||||||
def build_get_head_metadata(prefix):
|
def build_get_head_metadata(prefix):
|
||||||
# TODO this name is awful, fix it
|
# TODO this name is awful, fix it
|
||||||
model = Model("HeadDataProvider", attrs={"prefix": prefix}, forward=head_data_forward)
|
model = Model("HeadDataProvider",
|
||||||
|
attrs={'prefix': prefix},
|
||||||
|
forward=head_data_forward)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
def head_data_forward(model, docs, is_train):
|
def head_data_forward(model, docs, is_train):
|
||||||
"""A layer to generate the extra data needed for the span predictor.
|
"""A layer to generate the extra data needed for the span predictor.
|
||||||
"""
|
"""
|
||||||
sent_ids = []
|
sent_ids = []
|
||||||
head_ids = []
|
head_ids = []
|
||||||
prefix = model.attrs["prefix"]
|
prefix = model.attrs["prefix"]
|
||||||
|
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
sids = model.ops.asarray2i(get_sentence_ids(doc))
|
sids = model.ops.asarray2i(get_sentence_ids(doc))
|
||||||
sent_ids.append(sids)
|
sent_ids.append(sids)
|
||||||
|
@ -235,7 +253,6 @@ def head_data_forward(model, docs, is_train):
|
||||||
heads.append(span[0].i)
|
heads.append(span[0].i)
|
||||||
heads = model.ops.asarray2i(heads)
|
heads = model.ops.asarray2i(heads)
|
||||||
head_ids.append(heads)
|
head_ids.append(heads)
|
||||||
|
|
||||||
# each of these is a list with one entry per doc
|
# each of these is a list with one entry per doc
|
||||||
# backprop is just a placeholder
|
# backprop is just a placeholder
|
||||||
# TODO it would probably be better to have a list of tuples than two lists of arrays
|
# TODO it would probably be better to have a list of tuples than two lists of arrays
|
||||||
|
@ -256,7 +273,6 @@ class CorefScorer(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
device: str,
|
|
||||||
dim: int, # tok2vec size
|
dim: int, # tok2vec size
|
||||||
dist_emb_size: int,
|
dist_emb_size: int,
|
||||||
hidden_size: int,
|
hidden_size: int,
|
||||||
|
@ -273,8 +289,7 @@ class CorefScorer(torch.nn.Module):
|
||||||
epochs_trained (int): the number of epochs finished
|
epochs_trained (int): the number of epochs finished
|
||||||
(useful for warm start)
|
(useful for warm start)
|
||||||
"""
|
"""
|
||||||
# device, dist_emb_size, hidden_size, n_layers, dropout_rate
|
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate)
|
||||||
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate).to(device)
|
|
||||||
#TODO clean this up
|
#TODO clean this up
|
||||||
bert_emb = dim
|
bert_emb = dim
|
||||||
pair_emb = bert_emb * 3 + self.pw.shape
|
pair_emb = bert_emb * 3 + self.pw.shape
|
||||||
|
@ -283,7 +298,7 @@ class CorefScorer(torch.nn.Module):
|
||||||
hidden_size,
|
hidden_size,
|
||||||
n_layers,
|
n_layers,
|
||||||
dropout_rate
|
dropout_rate
|
||||||
).to(device)
|
)
|
||||||
self.lstm = torch.nn.LSTM(
|
self.lstm = torch.nn.LSTM(
|
||||||
input_size=bert_emb,
|
input_size=bert_emb,
|
||||||
hidden_size=bert_emb,
|
hidden_size=bert_emb,
|
||||||
|
@ -294,7 +309,7 @@ class CorefScorer(torch.nn.Module):
|
||||||
bert_emb,
|
bert_emb,
|
||||||
dropout_rate,
|
dropout_rate,
|
||||||
roughk
|
roughk
|
||||||
).to(device)
|
)
|
||||||
self.batch_size = batch_size
|
self.batch_size = batch_size
|
||||||
|
|
||||||
def forward(
|
def forward(
|
||||||
|
@ -443,7 +458,6 @@ class AnaphoricityScorer(torch.nn.Module):
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
class RoughScorer(torch.nn.Module):
|
class RoughScorer(torch.nn.Module):
|
||||||
"""
|
"""
|
||||||
Is needed to give a roughly estimate of the anaphoricity of two candidates,
|
Is needed to give a roughly estimate of the anaphoricity of two candidates,
|
||||||
|
@ -474,7 +488,6 @@ class RoughScorer(torch.nn.Module):
|
||||||
pair_mask = torch.arange(mentions.shape[0])
|
pair_mask = torch.arange(mentions.shape[0])
|
||||||
pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
|
pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
|
||||||
pair_mask = torch.log((pair_mask > 0).to(torch.float))
|
pair_mask = torch.log((pair_mask > 0).to(torch.float))
|
||||||
pair_mask = pair_mask.to(mentions.device)
|
|
||||||
bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
|
bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)
|
||||||
rough_scores = pair_mask + bilinear_scores
|
rough_scores = pair_mask + bilinear_scores
|
||||||
|
|
||||||
|
@ -501,7 +514,7 @@ class RoughScorer(torch.nn.Module):
|
||||||
|
|
||||||
|
|
||||||
class SpanPredictor(torch.nn.Module):
|
class SpanPredictor(torch.nn.Module):
|
||||||
def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int, device):
|
def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
# input size = single token size
|
# input size = single token size
|
||||||
# 64 = probably distance emb size
|
# 64 = probably distance emb size
|
||||||
|
@ -517,7 +530,6 @@ class SpanPredictor(torch.nn.Module):
|
||||||
# this use of dist_emb_size looks wrong but it was 64...?
|
# this use of dist_emb_size looks wrong but it was 64...?
|
||||||
torch.nn.Linear(256, dist_emb_size),
|
torch.nn.Linear(256, dist_emb_size),
|
||||||
)
|
)
|
||||||
self.device = device
|
|
||||||
self.conv = torch.nn.Sequential(
|
self.conv = torch.nn.Sequential(
|
||||||
torch.nn.Conv1d(64, 4, 3, 1, 1),
|
torch.nn.Conv1d(64, 4, 3, 1, 1),
|
||||||
torch.nn.Conv1d(4, 2, 3, 1, 1)
|
torch.nn.Conv1d(4, 2, 3, 1, 1)
|
||||||
|
@ -541,17 +553,18 @@ class SpanPredictor(torch.nn.Module):
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: span start/end scores, [n_heads, n_words, 2]
|
torch.Tensor: span start/end scores, [n_heads, n_words, 2]
|
||||||
"""
|
"""
|
||||||
|
# If we don't receive heads, return empty
|
||||||
|
if heads_ids.nelement() == 0:
|
||||||
|
return torch.empty(size=(0,))
|
||||||
# Obtain distance embedding indices, [n_heads, n_words]
|
# Obtain distance embedding indices, [n_heads, n_words]
|
||||||
relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0], device=words.device).unsqueeze(0))
|
relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0]).unsqueeze(0))
|
||||||
# make all valid distances positive
|
# make all valid distances positive
|
||||||
emb_ids = relative_positions + 63
|
emb_ids = relative_positions + 63
|
||||||
# "too_far"
|
# "too_far"
|
||||||
emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
|
emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
|
||||||
# Obtain "same sentence" boolean mask, [n_heads, n_words]
|
# Obtain "same sentence" boolean mask, [n_heads, n_words]
|
||||||
sent_id = torch.tensor(sent_id, device=words.device)
|
|
||||||
heads_ids = heads_ids.long()
|
heads_ids = heads_ids.long()
|
||||||
same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))
|
same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))
|
||||||
|
|
||||||
# To save memory, only pass candidates from one sentence for each head
|
# To save memory, only pass candidates from one sentence for each head
|
||||||
# pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
|
# pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
|
||||||
# for each candidate among the words in the same sentence as span_head
|
# for each candidate among the words in the same sentence as span_head
|
||||||
|
@ -562,23 +575,20 @@ class SpanPredictor(torch.nn.Module):
|
||||||
words[cols],
|
words[cols],
|
||||||
self.emb(emb_ids[rows, cols]),
|
self.emb(emb_ids[rows, cols]),
|
||||||
), dim=1)
|
), dim=1)
|
||||||
|
|
||||||
lengths = same_sent.sum(dim=1)
|
lengths = same_sent.sum(dim=1)
|
||||||
padding_mask = torch.arange(0, lengths.max().item(), device=words.device).unsqueeze(0)
|
padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0)
|
||||||
padding_mask = (padding_mask < lengths.unsqueeze(1)) # [n_heads, max_sent_len]
|
padding_mask = (padding_mask < lengths.unsqueeze(1)) # [n_heads, max_sent_len]
|
||||||
|
|
||||||
# [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
|
# [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
|
||||||
# This is necessary to allow the convolution layer to look at several
|
# This is necessary to allow the convolution layer to look at several
|
||||||
# word scores
|
# word scores
|
||||||
padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1], device=words.device)
|
padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1])
|
||||||
padded_pairs[padding_mask] = pair_matrix
|
padded_pairs[padding_mask] = pair_matrix
|
||||||
|
|
||||||
res = self.ffnn(padded_pairs) # [n_heads, n_candidates, last_layer_output]
|
res = self.ffnn(padded_pairs) # [n_heads, n_candidates, last_layer_output]
|
||||||
res = self.conv(res.permute(0, 2, 1)).permute(0, 2, 1) # [n_heads, n_candidates, 2]
|
res = self.conv(res.permute(0, 2, 1)).permute(0, 2, 1) # [n_heads, n_candidates, 2]
|
||||||
|
|
||||||
scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'), device=words.device)
|
scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'))
|
||||||
scores[rows, cols] = res[padding_mask]
|
scores[rows, cols] = res[padding_mask]
|
||||||
|
|
||||||
# Make sure that start <= head <= end during inference
|
# Make sure that start <= head <= end during inference
|
||||||
if not self.training:
|
if not self.training:
|
||||||
valid_starts = torch.log((relative_positions >= 0).to(torch.float))
|
valid_starts = torch.log((relative_positions >= 0).to(torch.float))
|
||||||
|
@ -586,6 +596,7 @@ class SpanPredictor(torch.nn.Module):
|
||||||
valid_positions = torch.stack((valid_starts, valid_ends), dim=2)
|
valid_positions = torch.stack((valid_starts, valid_ends), dim=2)
|
||||||
return scores + valid_positions
|
return scores + valid_positions
|
||||||
return scores
|
return scores
|
||||||
|
|
||||||
class DistancePairwiseEncoder(torch.nn.Module):
|
class DistancePairwiseEncoder(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, embedding_size, dropout_rate):
|
def __init__(self, embedding_size, dropout_rate):
|
||||||
|
@ -595,17 +606,10 @@ class DistancePairwiseEncoder(torch.nn.Module):
|
||||||
self.dropout = torch.nn.Dropout(dropout_rate)
|
self.dropout = torch.nn.Dropout(dropout_rate)
|
||||||
self.shape = emb_size
|
self.shape = emb_size
|
||||||
|
|
||||||
@property
|
|
||||||
def device(self) -> torch.device:
|
|
||||||
""" A workaround to get current device (which is assumed to be the
|
|
||||||
device of the first parameter of one of the submodules) """
|
|
||||||
return next(self.distance_emb.parameters()).device
|
|
||||||
|
|
||||||
|
|
||||||
def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
|
def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
|
||||||
top_indices: torch.Tensor
|
top_indices: torch.Tensor
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
word_ids = torch.arange(0, top_indices.size(0), device=self.device)
|
word_ids = torch.arange(0, top_indices.size(0))
|
||||||
distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
|
distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
|
||||||
).clamp_min_(min=1)
|
).clamp_min_(min=1)
|
||||||
log_distance = distance.to(torch.float).log2().floor_()
|
log_distance = distance.to(torch.float).log2().floor_()
|
||||||
|
|
|
@ -3,7 +3,7 @@ import warnings
|
||||||
|
|
||||||
from thinc.types import Floats2d, Floats3d, Ints2d
|
from thinc.types import Floats2d, Floats3d, Ints2d
|
||||||
from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
|
from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
|
||||||
from thinc.api import set_dropout_rate
|
from thinc.api import set_dropout_rate, to_categorical
|
||||||
from itertools import islice
|
from itertools import islice
|
||||||
from statistics import mean
|
from statistics import mean
|
||||||
|
|
||||||
|
@ -130,7 +130,6 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/coref#predict (TODO)
|
DOCS: https://spacy.io/api/coref#predict (TODO)
|
||||||
"""
|
"""
|
||||||
#print("DOCS", docs)
|
|
||||||
out = []
|
out = []
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
scores, idxs = self.model.predict([doc])
|
scores, idxs = self.model.predict([doc])
|
||||||
|
@ -212,7 +211,6 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
# TODO check this causes no issues (in practice it runs)
|
# TODO check this causes no issues (in practice it runs)
|
||||||
preds, backprop = self.model.begin_update([eg.predicted])
|
preds, backprop = self.model.begin_update([eg.predicted])
|
||||||
score_matrix, mention_idx = preds
|
score_matrix, mention_idx = preds
|
||||||
|
|
||||||
loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
|
loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
|
||||||
total_loss += loss
|
total_loss += loss
|
||||||
# TODO check shape here
|
# TODO check shape here
|
||||||
|
@ -366,9 +364,7 @@ class CoreferenceResolver(TrainablePipe):
|
||||||
for ex in examples:
|
for ex in examples:
|
||||||
p_clusters = doc2clusters(ex.predicted, self.span_cluster_prefix)
|
p_clusters = doc2clusters(ex.predicted, self.span_cluster_prefix)
|
||||||
g_clusters = doc2clusters(ex.reference, self.span_cluster_prefix)
|
g_clusters = doc2clusters(ex.reference, self.span_cluster_prefix)
|
||||||
|
|
||||||
cluster_info = get_cluster_info(p_clusters, g_clusters)
|
cluster_info = get_cluster_info(p_clusters, g_clusters)
|
||||||
|
|
||||||
evaluator.update(cluster_info)
|
evaluator.update(cluster_info)
|
||||||
|
|
||||||
score = {
|
score = {
|
||||||
|
@ -460,8 +456,8 @@ class SpanPredictor(TrainablePipe):
|
||||||
out = []
|
out = []
|
||||||
for doc in docs:
|
for doc in docs:
|
||||||
# TODO check shape here
|
# TODO check shape here
|
||||||
span_scores = self.model.predict(doc)
|
span_scores = self.model.predict([doc])
|
||||||
span_scores = span_scores[0]
|
if span_scores.size:
|
||||||
# the information about clustering has to come from the input docs
|
# the information about clustering has to come from the input docs
|
||||||
# first let's convert the scores to a list of span idxs
|
# first let's convert the scores to a list of span idxs
|
||||||
start_scores = span_scores[:, :, 0]
|
start_scores = span_scores[:, :, 0]
|
||||||
|
@ -478,9 +474,11 @@ class SpanPredictor(TrainablePipe):
|
||||||
for cluster in clusters:
|
for cluster in clusters:
|
||||||
ncluster = []
|
ncluster = []
|
||||||
for mention in cluster:
|
for mention in cluster:
|
||||||
ncluster.append( (starts[cidx], ends[cidx]) )
|
ncluster.append((starts[cidx], ends[cidx]))
|
||||||
cidx += 1
|
cidx += 1
|
||||||
out_clusters.append(ncluster)
|
out_clusters.append(ncluster)
|
||||||
|
else:
|
||||||
|
out_clusters = []
|
||||||
out.append(out_clusters)
|
out.append(out_clusters)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
@ -505,21 +503,21 @@ class SpanPredictor(TrainablePipe):
|
||||||
losses = {}
|
losses = {}
|
||||||
losses.setdefault(self.name, 0.0)
|
losses.setdefault(self.name, 0.0)
|
||||||
validate_examples(examples, "SpanPredictor.update")
|
validate_examples(examples, "SpanPredictor.update")
|
||||||
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
|
if not any(len(eg.reference) if eg.reference else 0 for eg in examples):
|
||||||
# Handle cases where there are no tokens in any docs.
|
# Handle cases where there are no tokens in any docs.
|
||||||
return losses
|
return losses
|
||||||
set_dropout_rate(self.model, drop)
|
set_dropout_rate(self.model, drop)
|
||||||
|
|
||||||
total_loss = 0
|
total_loss = 0
|
||||||
|
|
||||||
for eg in examples:
|
for eg in examples:
|
||||||
preds, backprop = self.model.begin_update([eg.predicted])
|
span_scores, backprop = self.model.begin_update([eg.predicted])
|
||||||
score_matrix, mention_idx = preds
|
# FIXME, this only happens once in the first 1000 docs of OntoNotes
|
||||||
|
# and I'm not sure yet why.
|
||||||
loss, d_scores = self.get_loss([eg], score_matrix, mention_idx)
|
if span_scores.size:
|
||||||
|
loss, d_scores = self.get_loss([eg], span_scores)
|
||||||
total_loss += loss
|
total_loss += loss
|
||||||
# TODO check shape here
|
# TODO check shape here
|
||||||
backprop((d_scores, mention_idx))
|
backprop((d_scores))
|
||||||
|
|
||||||
if sgd is not None:
|
if sgd is not None:
|
||||||
self.finish_update(sgd)
|
self.finish_update(sgd)
|
||||||
|
@ -562,19 +560,17 @@ class SpanPredictor(TrainablePipe):
|
||||||
assert len(examples) == 1, "Only fake batching is supported."
|
assert len(examples) == 1, "Only fake batching is supported."
|
||||||
# starts and ends are gold starts and ends (Ints1d)
|
# starts and ends are gold starts and ends (Ints1d)
|
||||||
# span_scores is a Floats3d. What are the axes? mention x token x start/end
|
# span_scores is a Floats3d. What are the axes? mention x token x start/end
|
||||||
|
|
||||||
for eg in examples:
|
for eg in examples:
|
||||||
|
|
||||||
# get gold data
|
|
||||||
gold = doc2clusters(eg.reference, self.output_prefix)
|
|
||||||
# flatten the gold data
|
|
||||||
starts = []
|
starts = []
|
||||||
ends = []
|
ends = []
|
||||||
for cluster in gold:
|
for key, sg in eg.reference.spans.items():
|
||||||
for mention in cluster:
|
if key.startswith(self.output_prefix):
|
||||||
starts.append(mention[0])
|
for mention in sg:
|
||||||
ends.append(mention[1])
|
starts.append(mention.start)
|
||||||
|
ends.append(mention.end)
|
||||||
|
|
||||||
|
starts = self.model.ops.xp.asarray(starts)
|
||||||
|
ends = self.model.ops.xp.asarray(ends)
|
||||||
start_scores = span_scores[:, :, 0]
|
start_scores = span_scores[:, :, 0]
|
||||||
end_scores = span_scores[:, :, 1]
|
end_scores = span_scores[:, :, 1]
|
||||||
n_classes = start_scores.shape[1]
|
n_classes = start_scores.shape[1]
|
||||||
|
@ -594,7 +590,7 @@ class SpanPredictor(TrainablePipe):
|
||||||
*,
|
*,
|
||||||
nlp: Optional[Language] = None,
|
nlp: Optional[Language] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
validate_get_examples(get_examples, "CoreferenceResolver.initialize")
|
validate_get_examples(get_examples, "SpanPredictor.initialize")
|
||||||
|
|
||||||
X = []
|
X = []
|
||||||
Y = []
|
Y = []
|
||||||
|
@ -612,31 +608,33 @@ class SpanPredictor(TrainablePipe):
|
||||||
self.model.initialize(X=X, Y=Y)
|
self.model.initialize(X=X, Y=Y)
|
||||||
|
|
||||||
def score(self, examples, **kwargs):
|
def score(self, examples, **kwargs):
|
||||||
"""Score a batch of examples."""
|
"""
|
||||||
# TODO This is basically the same as the main coref component - factor out?
|
Evaluate on reconstructing the correct spans around
|
||||||
|
gold heads.
|
||||||
|
"""
|
||||||
scores = []
|
scores = []
|
||||||
for metric in (b_cubed, muc, ceafe):
|
xp = self.model.ops.xp
|
||||||
evaluator = Evaluator(metric)
|
for eg in examples:
|
||||||
|
starts = []
|
||||||
|
ends = []
|
||||||
|
pred_starts = []
|
||||||
|
pred_ends = []
|
||||||
|
ref = eg.reference
|
||||||
|
pred = eg.predicted
|
||||||
|
for key, gold_sg in ref.spans.items():
|
||||||
|
if key.startswith(self.output_prefix):
|
||||||
|
pred_sg = pred.spans[key]
|
||||||
|
for gold_mention, pred_mention in zip(gold_sg, pred_sg):
|
||||||
|
starts.append(gold_mention.start)
|
||||||
|
ends.append(gold_mention.end)
|
||||||
|
pred_starts.append(pred_mention.start)
|
||||||
|
pred_ends.append(pred_mention.end)
|
||||||
|
|
||||||
for ex in examples:
|
starts = xp.asarray(starts)
|
||||||
# XXX this is the only different part
|
ends = xp.asarray(ends)
|
||||||
p_clusters = doc2clusters(ex.predicted, self.output_prefix)
|
pred_starts = xp.asarray(pred_starts)
|
||||||
g_clusters = doc2clusters(ex.reference, self.output_prefix)
|
pred_ends = xp.asarray(pred_ends)
|
||||||
|
correct = (starts == pred_starts) * (ends == pred_ends)
|
||||||
cluster_info = get_cluster_info(p_clusters, g_clusters)
|
accuracy = correct.mean()
|
||||||
|
scores.append(float(accuracy))
|
||||||
evaluator.update(cluster_info)
|
return {"span_accuracy": mean(scores)}
|
||||||
|
|
||||||
score = {
|
|
||||||
"coref_f": evaluator.get_f1(),
|
|
||||||
"coref_p": evaluator.get_precision(),
|
|
||||||
"coref_r": evaluator.get_recall(),
|
|
||||||
}
|
|
||||||
scores.append(score)
|
|
||||||
|
|
||||||
out = {}
|
|
||||||
for field in ("f", "p", "r"):
|
|
||||||
fname = f"coref_{field}"
|
|
||||||
out[fname] = mean([ss[fname] for ss in scores])
|
|
||||||
return out
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user