remove unused extra prefix and device argument from SpanPredictor

Author: Kádár Ákos
Date:   2022-03-24 16:44:50 +01:00
Parent: 1c5dabcb47
Commit: 83ac0477c8
2 changed files with 20 additions and 31 deletions

File 1 of 2:

@@ -55,14 +55,14 @@ def build_wl_coref_model(
     )
     coref_model = tok2vec >> coref_scorer
     # XXX just ignore this until the coref scorer is integrated
-    span_predictor = PyTorchWrapper(
-        SpanPredictor(
+    # span_predictor = PyTorchWrapper(
+    #     SpanPredictor(
     # TODO this was hardcoded to 1024, check
-            hidden_size,
-            sp_embedding_size,
-        ),
-        convert_inputs=convert_span_predictor_inputs
-    )
+    #         hidden_size,
+    #         sp_embedding_size,
+    #     ),
+    #     convert_inputs=convert_span_predictor_inputs
+    # )
     # TODO combine models so output is uniform (just one forward pass)
     # It may be reasonable to have an option to disable span prediction,
     # and just return words as spans.
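
(For readers skimming this hunk: the span predictor wiring is being commented out until the coref scorer is integrated. As a rough sketch of what that wiring does when active, assuming only Thinc's public API, with a toy torch.nn.Linear standing in for SpanPredictor:)

    import torch
    from thinc.api import PyTorchWrapper

    # Wrap a torch module as a Thinc model. In the real wiring a
    # convert_inputs callback (convert_span_predictor_inputs above)
    # maps Thinc inputs to torch tensors before the forward pass.
    wrapped = PyTorchWrapper(torch.nn.Linear(64, 2))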
@@ -88,7 +88,6 @@ def build_span_predictor(
     )
     # TODO use proper parameter for prefix
     head_info = build_get_head_metadata(
-        "span_coref_head_clusters",
         "coref_head_clusters"
     )
     model = (tok2vec & head_info) >> span_predictor
@@ -217,13 +216,10 @@ def _clusterize(
     return sorted(clusters)
 
 
-def build_get_head_metadata(update_prefix, predict_prefix):
+def build_get_head_metadata(prefix):
     # TODO this name is awful, fix it
     model = Model("HeadDataProvider",
-                  attrs={
-                      "update_prefix": update_prefix,
-                      "predict_prefix": predict_prefix
-                  },
+                  attrs={'prefix': prefix},
                   forward=head_data_forward)
     return model
@@ -233,10 +229,7 @@ def head_data_forward(model, docs, is_train):
     """
     sent_ids = []
     head_ids = []
-    if is_train:
-        prefix = model.attrs["update_prefix"]
-    else:
-        prefix = model.attrs["predict_prefix"]
+    prefix = model.attrs["prefix"]
     for doc in docs:
         sids = model.ops.asarray2i(get_sentence_ids(doc))
         sent_ids.append(sids)
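
(The two hunks above collapse the separate update/predict prefixes into a single attr. A minimal sketch of the pattern, using only Thinc's public Model API; the forward body here is a toy, not the real head_data_forward:)

    from thinc.api import Model

    def _forward(model, docs, is_train):
        # Configuration travels on the model itself via attrs.
        prefix = model.attrs["prefix"]
        keys = [[k for k in doc.spans if k.startswith(prefix)] for doc in docs]
        return keys, lambda dY: []

    def build_get_head_metadata(prefix):
        return Model("HeadDataProvider", _forward, attrs={"prefix": prefix})

    model = build_get_head_metadata("coref_head_clusters")
    assert model.attrs["prefix"] == "coref_head_clusters"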
@@ -511,7 +504,7 @@ class RoughScorer(torch.nn.Module):
 
 
 class SpanPredictor(torch.nn.Module):
-    def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int, device):
+    def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int):
         super().__init__()
         # input size = single token size
         # 64 = probably distance emb size
@@ -551,13 +544,13 @@ class SpanPredictor(torch.nn.Module):
             torch.Tensor: span start/end scores, [n_heads, n_words, 2]
         """
         # Obtain distance embedding indices, [n_heads, n_words]
-        relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0], device=words.device).unsqueeze(0))
+        relative_positions = (heads_ids.unsqueeze(1) - torch.arange(words.shape[0]).unsqueeze(0))
         # make all valid distances positive
         emb_ids = relative_positions + 63
         # "too_far"
         emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
         # Obtain "same sentence" boolean mask, [n_heads, n_words]
-        sent_id = torch.tensor(sent_id, device=words.device)
+        sent_id = torch.tensor(sent_id)
         heads_ids = heads_ids.long()
         same_sent = (sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0))
         # To save memory, only pass candidates from one sentence for each head
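
(The distance-embedding indexing in this hunk is easier to see with small numbers. A self-contained sketch with illustrative values; `|` is used here in place of the bool-tensor `+` trick in the source, which PyTorch treats as logical OR:)

    import torch

    heads_ids = torch.tensor([1, 3])   # two head tokens
    n_words = 5                        # in a five-token document
    # Signed distance from each head to every word, [n_heads, n_words]
    relative_positions = heads_ids.unsqueeze(1) - torch.arange(n_words).unsqueeze(0)
    # Shift into [0, 126] so every valid distance indexes the embedding table
    emb_ids = relative_positions + 63
    # Anything outside that window shares the "too_far" bucket, index 127
    emb_ids[(emb_ids < 0) | (emb_ids > 126)] = 127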
@@ -571,18 +564,18 @@
             self.emb(emb_ids[rows, cols]),
         ), dim=1)
         lengths = same_sent.sum(dim=1)
-        padding_mask = torch.arange(0, lengths.max().item(), device=words.device).unsqueeze(0)
+        padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0)
         padding_mask = (padding_mask < lengths.unsqueeze(1))  # [n_heads, max_sent_len]
         # [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
         # This is necessary to allow the convolution layer to look at several
         # word scores
-        padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1], device=words.device)
+        padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1])
         padded_pairs[padding_mask] = pair_matrix
         res = self.ffnn(padded_pairs)  # [n_heads, n_candidates, last_layer_output]
         res = self.conv(res.permute(0, 2, 1)).permute(0, 2, 1)  # [n_heads, n_candidates, 2]
-        scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'), device=words.device)
+        scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float('-inf'))
         scores[rows, cols] = res[padding_mask]
         # Make sure that start <= head <= end during inference
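
(Every removal in this hunk drops an explicit device= argument. Factory calls such as torch.arange, torch.zeros and torch.full allocate on the default device, CPU unless it has been changed, when device= is omitted, which is what the simplified code relies on. A small check of both that and the padding-mask construction, with made-up lengths:)

    import torch

    lengths = torch.tensor([2, 3])     # candidate counts per head
    padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0)
    padding_mask = padding_mask < lengths.unsqueeze(1)  # [n_heads, max_sent_len]
    # tensor([[ True,  True, False],
    #         [ True,  True,  True]])
    assert padding_mask.device.type == "cpu"  # no device= -> default device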

File 2 of 2:

@@ -417,7 +417,6 @@ DEFAULT_SPAN_PREDICTOR_MODEL = Config().from_str(default_span_predictor_config)[
     default_config={
         "model": DEFAULT_SPAN_PREDICTOR_MODEL,
         "input_prefix": "coref_head_clusters",
-        "target_prefix": "span_head_target_clusters",
         "output_prefix": "coref_clusters",
     },
     default_score_weights={"span_predictor_f": 1.0, "span_predictor_p": None, "span_predictor_r": None},
@@ -427,7 +426,6 @@ def make_span_predictor(
     name: str,
     model,
     input_prefix: str = "coref_head_clusters",
-    target_prefix: str = "span_head_target_clusters",
     output_prefix: str = "coref_clusters",
 ) -> "SpanPredictor":
     """Create a SpanPredictor component."""
@@ -446,14 +444,12 @@ class SpanPredictor(TrainablePipe):
         name: str = "span_predictor",
         *,
         input_prefix: str = "coref_head_clusters",
-        target_prefix: str = "span_coref_head_clusters",
         output_prefix: str = "coref_clusters",
     ) -> None:
         self.vocab = vocab
         self.model = model
         self.name = name
         self.input_prefix = input_prefix
-        self.target_prefix = target_prefix
         self.output_prefix = output_prefix
         self.cfg = {}
@@ -519,8 +515,9 @@ class SpanPredictor(TrainablePipe):
         for doc, eg in zip(docs, examples):
             # replicates the EntityLinker's behaviour and
             # copies annotations over https://bit.ly/3iweDcW
+            # takes 'coref_head_clusters' from the reference.
             for key, sg in eg.reference.spans.items():
-                if key.startswith(self.target_prefix):
+                if key.startswith(self.input_prefix):
                     doc.spans[key] = [doc[span.start:span.end] for span in sg]
             span_scores, backprop = self.model.begin_update([doc])
             loss, d_scores = self.get_loss([eg], span_scores)
@@ -573,7 +570,7 @@
         for eg in examples:
             # get gold data
-            gold = doc2clusters(eg.predicted, self.target_prefix)
+            gold = doc2clusters(eg.predicted, self.input_prefix)
             # flatten the gold data
             starts = []
             ends = []
@@ -614,7 +611,6 @@
             doc = ex.predicted
             assert len(doc) > 2, "Coreference requires at least two tokens"
             doc.spans[f"{self.input_prefix}_0"] = [doc[0:1], doc[1:2]]
-            doc.spans[f"{self.target_prefix}_0"] = [doc[0:1], doc[1:2]]
             X.append(ex.predicted)
             Y.append(ex.reference)
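
(A minimal sketch of the single-prefix convention the component settles on across these hunks: head clusters live in numbered span groups under one shared key prefix, and the pipe matches keys by that prefix. The text and span indices below are made up for illustration:)

    import spacy

    nlp = spacy.blank("en")
    doc = nlp("John saw Mary. He waved.")
    # Clusters are numbered span groups under a common prefix.
    doc.spans["coref_head_clusters_0"] = [doc[0:1], doc[4:5]]  # John ... He
    prefix = "coref_head_clusters"
    clusters = {k: sg for k, sg in doc.spans.items() if k.startswith(prefix)}
    assert list(clusters) == ["coref_head_clusters_0"]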