merge misery

kadarakos 2022-05-10 17:19:16 +00:00
commit 7cf6bcca0e
3 changed files with 47 additions and 37 deletions


@@ -3,7 +3,7 @@ import torch
 from thinc.api import Model, chain
 from thinc.api import PyTorchWrapper, ArgsKwargs
-from thinc.types import Floats2d, Ints2d
+from thinc.types import Floats2d, Ints2d, Ints1d
 from thinc.util import xp2torch, torch2xp
 from ...tokens import Doc
@@ -50,7 +50,11 @@ def build_wl_coref_model(
     return coref_model
-def convert_coref_scorer_inputs(model: Model, X: List[Floats2d], is_train: bool):
+def convert_coref_scorer_inputs(
+    model: Model,
+    X: List[Floats2d],
+    is_train: bool
+):
     # The input here is List[Floats2d], one for each doc
     # just use the first
     # TODO real batching
@@ -62,7 +66,7 @@ def convert_coref_scorer_inputs(model: Model, X: List[Floats2d], is_train: bool)
         gradients = torch2xp(args.args[0])
         return [gradients]
-    return ArgsKwargs(args=(word_features,), kwargs={}), backprop
+    return ArgsKwargs(args=(word_features, ), kwargs={}), backprop
 def convert_coref_scorer_outputs(model: Model, inputs_outputs, is_train: bool):
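
Aside, not part of the diff: the converter above follows thinc's PyTorchWrapper contract, where convert_inputs returns an ArgsKwargs for the wrapped torch module plus a callback that routes the torch gradient back to the thinc input. A minimal self-contained sketch of that contract with a toy torch layer (the names and sizes here are illustrative, not from this commit):

from typing import List

import numpy
import torch
from thinc.api import ArgsKwargs, Model, PyTorchWrapper
from thinc.types import Floats2d
from thinc.util import torch2xp, xp2torch


def toy_convert_inputs(model: Model, X: List[Floats2d], is_train: bool):
    # Like convert_coref_scorer_inputs: use only the first doc's features.
    word_features = xp2torch(X[0], requires_grad=is_train)

    def backprop(args: ArgsKwargs) -> List[Floats2d]:
        # Hand the torch gradient back as the gradient of the single input array.
        return [torch2xp(args.args[0])]

    return ArgsKwargs(args=(word_features,), kwargs={}), backprop


toy_layer = PyTorchWrapper(torch.nn.Linear(4, 2), convert_inputs=toy_convert_inputs)
Y, backprop_Y = toy_layer([numpy.zeros((3, 4), dtype="f")], is_train=False)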
@@ -108,11 +112,19 @@ class CorefScorer(torch.nn.Module):
         n_layers: Numbers of layers in the AnaphoricityScorer.
         dropout_rate: Dropout probability to apply across all modules.
         roughk: Number of candidates the RoughScorer returns.
-        batch_size: Internal batch-size for the more expensive AnaphoricityScorer.
+        batch_size: Internal batch-size for the more expensive scorer.
         """
         self.dropout = torch.nn.Dropout(dropout_rate)
         self.batch_size = batch_size
         # Modules
+        self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate)
+        pair_emb = dim * 3 + self.pw.shape
+        self.a_scorer = AnaphoricityScorer(
+            pair_emb,
+            hidden_size,
+            n_layers,
+            dropout_rate
+        )
         self.lstm = torch.nn.LSTM(
             input_size=dim,
             hidden_size=dim,
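
A note outside the diff: pair_emb = dim * 3 + self.pw.shape suggests the pair representation concatenates the head embedding, the candidate embedding, their elementwise product, and the distance features, which is how the upstream wl-coref model builds it. That reading is an assumption here, since the AnaphoricityScorer body is not shown in this commit. A toy shape check under that assumption:

import torch

# Illustrative sizes only; "dim" and the distance-feature width are made up.
dim, pw_shape, n_pairs = 8, 5, 3
head = torch.randn(n_pairs, dim)       # embedding of each span head
cand = torch.randn(n_pairs, dim)       # embedding of each candidate antecedent
dist = torch.randn(n_pairs, pw_shape)  # DistancePairwiseEncoder features
pair = torch.cat((head, cand, head * cand, dist), dim=1)
assert pair.shape[1] == dim * 3 + pw_shape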
@@ -125,11 +137,13 @@ class CorefScorer(torch.nn.Module):
             pair_emb, hidden_size, n_layers, dropout_rate
         )
-    def forward(self, word_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(
+        self, word_features: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         """
         1. LSTM encodes the incoming word_features.
         2. The RoughScorer scores and prunes the candidates.
-        3. The DistancePairwiseEncoder embeds the distance between remaning pairs.
+        3. The DistancePairwiseEncoder embeds the distances between pairs.
         4. The AnaphoricityScorer scores all pairs in mini-batches.
         word_features: torch.Tensor containing word encodings
@@ -299,6 +313,7 @@ class RoughScorer(torch.nn.Module):
         top_scores, indices = torch.topk(
             rough_scores, k=min(self.k, len(rough_scores)), dim=1, sorted=False
         )
         return top_scores, indices
@ -324,10 +339,11 @@ class DistancePairwiseEncoder(torch.nn.Module):
def forward( def forward(
self, self,
top_indices: torch.Tensor, top_indices: torch.Tensor
) -> torch.Tensor: ) -> torch.Tensor:
word_ids = torch.arange(0, top_indices.size(0)) word_ids = torch.arange(0, top_indices.size(0))
distance = (word_ids.unsqueeze(1) - word_ids[top_indices]).clamp_min_(min=1) distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
).clamp_min_(min=1)
log_distance = distance.to(torch.float).log2().floor_() log_distance = distance.to(torch.float).log2().floor_()
log_distance = log_distance.clamp_max_(max=6).to(torch.long) log_distance = log_distance.clamp_max_(max=6).to(torch.long)
distance = torch.where(distance < 5, distance - 1, log_distance + 2) distance = torch.where(distance < 5, distance - 1, log_distance + 2)
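
Outside the diff, for readers skimming: the lines above bucket raw token distances into a small set of embedding indices, with distances 1-4 keeping their own bucket and larger distances falling into log2-sized buckets capped at 6. A stand-alone illustration (the example distances are chosen here, not taken from the commit):

import torch

distance = torch.tensor([1, 2, 4, 5, 8, 16, 64, 200])
log_distance = distance.to(torch.float).log2().floor_()
log_distance = log_distance.clamp_max_(max=6).to(torch.long)
bucket = torch.where(distance < 5, distance - 1, log_distance + 2)
print(bucket)  # tensor([0, 1, 3, 4, 5, 6, 8, 8])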


@@ -3,7 +3,7 @@ import torch
 from thinc.api import Model, chain, tuplify
 from thinc.api import PyTorchWrapper, ArgsKwargs
-from thinc.types import Floats2d, Ints1d, Ints2d
+from thinc.types import Floats2d, Ints1d
 from thinc.util import xp2torch, torch2xp
 from ...tokens import Doc
@@ -40,10 +40,9 @@ def convert_span_predictor_inputs(
     model: Model, X: Tuple[Ints1d, Floats2d, Ints1d], is_train: bool
 ):
     tok2vec, (sent_ids, head_ids) = X
-    # Normally we shoudl use the input is_train, but for these two it's not relevant
+    # Normally we should use the input is_train, but for these two it's not relevant
     def backprop(args: ArgsKwargs) -> List[Floats2d]:
-        # convert to xp and wrap in list
         gradients = torch2xp(args.args[1])
         return [[gradients], None]
@@ -55,7 +54,6 @@ def convert_span_predictor_inputs(
     head_ids = xp2torch(head_ids[0], requires_grad=False)
     argskwargs = ArgsKwargs(args=(sent_ids, word_features, head_ids), kwargs={})
-    # TODO actually support backprop
     return argskwargs, backprop
@@ -66,15 +64,13 @@ def predict_span_clusters(
     """
     Predicts span clusters based on the word clusters.
-    Args:
-        doc (Doc): the document data
-        words (torch.Tensor): [n_words, emb_size] matrix containing
-            embeddings for each of the words in the text
-        clusters (List[List[int]]): a list of clusters where each cluster
-            is a list of word indices
+    span_predictor: a SpanPredictor instance
+    sent_ids: For each word indicates, which sentence it appears in.
+    words: Features for words.
+    clusters: Clusters inferred by the CorefScorer.
     Returns:
-        List[List[Span]]: span clusters
+        List[List[Tuple[int, int]]: span clusters
     """
     if not clusters:
         return []
@@ -141,29 +137,29 @@ class SpanPredictor(torch.nn.Module):
             # this use of dist_emb_size looks wrong but it was 64...?
             torch.nn.Linear(256, dist_emb_size),
         )
+        # TODO make the Convs also parametrizeable
         self.conv = torch.nn.Sequential(
             torch.nn.Conv1d(64, 4, 3, 1, 1), torch.nn.Conv1d(4, 2, 3, 1, 1)
         )
+        # TODO make embeddings size a parameter
         self.emb = torch.nn.Embedding(128, dist_emb_size) # [-63, 63] + too_far
     def forward(
-        self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
+        self,
         sent_id,
         words: torch.Tensor,
         heads_ids: torch.Tensor,
     ) -> torch.Tensor:
         """
-        Calculates span start/end scores of words for each span head in
-        heads_ids
-        Args:
-            doc (Doc): the document data
-            words (torch.Tensor): contextual embeddings for each word in the
-                document, [n_words, emb_size]
-            heads_ids (torch.Tensor): word indices of span heads
+        Calculates span start/end scores of words for each span
+        for each head.
+        sent_id: Sentence id of each word.
+        words: features for each word in the document.
+        heads_ids: word indices of span heads
         Returns:
-            torch.Tensor: span start/end scores, [n_heads, n_words, 2]
+            torch.Tensor: span start/end scores, (n_heads x n_words x 2)
         """
         # If we don't receive heads, return empty
         if heads_ids.nelement() == 0:
@@ -176,13 +172,13 @@ class SpanPredictor(torch.nn.Module):
         emb_ids = relative_positions + 63
         # "too_far"
         emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
-        # Obtain "same sentence" boolean mask, [n_heads, n_words]
+        # Obtain "same sentence" boolean mask: (n_heads x n_words)
         heads_ids = heads_ids.long()
         same_sent = sent_id[heads_ids].unsqueeze(1) == sent_id.unsqueeze(0)
         # To save memory, only pass candidates from one sentence for each head
         # pair_matrix contains concatenated span_head_emb + candidate_emb + distance_emb
         # for each candidate among the words in the same sentence as span_head
-        # [n_heads, input_size * 2 + distance_emb_size]
+        # (n_heads x input_size * 2 x distance_emb_size)
         rows, cols = same_sent.nonzero(as_tuple=True)
         pair_matrix = torch.cat(
             (
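
As an aside (not part of the diff): the emb_ids lines above clip relative positions into the distance embedding table, with offsets in [-63, 63] mapping to rows 0-126 and anything farther sharing the "too_far" row 127. A stand-alone illustration with made-up offsets:

import torch

relative_positions = torch.tensor([-100, -63, 0, 63, 100])
emb_ids = relative_positions + 63
emb_ids[(emb_ids < 0) + (emb_ids > 126)] = 127
print(emb_ids)  # tensor([127, 0, 63, 126, 127])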
@@ -194,17 +190,17 @@ class SpanPredictor(torch.nn.Module):
         )
         lengths = same_sent.sum(dim=1)
         padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0)
-        padding_mask = padding_mask < lengths.unsqueeze(1) # [n_heads, max_sent_len]
-        # [n_heads, max_sent_len, input_size * 2 + distance_emb_size]
+        # (n_heads x max_sent_len)
+        padding_mask = padding_mask < lengths.unsqueeze(1)
+        # (n_heads x max_sent_len x input_size * 2 + distance_emb_size)
         # This is necessary to allow the convolution layer to look at several
         # word scores
         padded_pairs = torch.zeros(*padding_mask.shape, pair_matrix.shape[-1])
         padded_pairs[padding_mask] = pair_matrix
-        res = self.ffnn(padded_pairs) # [n_heads, n_candidates, last_layer_output]
+        res = self.ffnn(padded_pairs) # (n_heads x n_candidates x last_layer_output)
         res = self.conv(res.permute(0, 2, 1)).permute(
             0, 2, 1
-        ) # [n_heads, n_candidates, 2]
+        ) # (n_heads x n_candidates, 2)
         scores = torch.full((heads_ids.shape[0], words.shape[0], 2), float("-inf"))
         scores[rows, cols] = res[padding_mask]
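
Outside the diff: the padding trick above scatters a flat per-candidate matrix into a (n_heads x max_sent_len x features) tensor so the Conv1d can slide over each head's candidates, and the same boolean mask later pulls the valid rows back out. A minimal sketch with toy sizes (all values made up for illustration):

import torch

lengths = torch.tensor([2, 3])                       # candidates per head
feat = 4
pair_matrix = torch.randn(int(lengths.sum()), feat)  # flat (5 x 4) candidate rows
padding_mask = torch.arange(0, lengths.max().item()).unsqueeze(0)
padding_mask = padding_mask < lengths.unsqueeze(1)   # (2 x 3) boolean mask
padded_pairs = torch.zeros(*padding_mask.shape, feat)
padded_pairs[padding_mask] = pair_matrix             # rows land in the valid slots
flat_again = padded_pairs[padding_mask]              # recovers the original rows
assert torch.equal(flat_again, pair_matrix)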


@@ -350,9 +350,7 @@ class CoreferenceResolver(TrainablePipe):
     def score(self, examples, **kwargs):
         """Score a batch of examples using LEA.
         For details on how LEA works and why to use it see the paper:
         Which Coreference Evaluation Metric Do You Trust? A Proposal for a Link-based Entity Aware Metric
         Moosavi and Strube, 2016
         https://api.semanticscholar.org/CorpusID:17606580