Kádár Ákos 2022-03-23 11:27:02 +01:00
commit 150e7c46d7


@@ -40,11 +40,8 @@ def build_wl_coref_model(
     with Model.define_operators({">>": chain}):
         # TODO chain tok2vec with these models
-        # TODO fix device - should be automatic
-        device = "cuda:0"
         coref_scorer = PyTorchWrapper(
             CorefScorer(
-                device,
                 dim,
                 embedding_size,
                 hidden_size,
@@ -64,7 +61,6 @@ def build_wl_coref_model(
                 # TODO this was hardcoded to 1024, check
                 hidden_size,
                 sp_embedding_size,
-                device
             ),
             convert_inputs=convert_span_predictor_inputs
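
The hunks above drop the hardcoded device = "cuda:0" and stop threading a
device argument through the constructors. A minimal sketch of the pattern this
moves toward, assuming device placement is left to thinc's GPU setup around
PyTorchWrapper rather than baked into the torch module (TinyScorer is a
hypothetical stand-in for CorefScorer):

import torch
from thinc.api import PyTorchWrapper

class TinyScorer(torch.nn.Module):
    # Hypothetical stand-in: no device parameter, no .to(device) at build time.
    def __init__(self, dim: int):
        super().__init__()
        self.linear = torch.nn.Linear(dim, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.linear(x)

# Assumption: placement is handled by thinc (e.g. thinc.api.require_gpu() on a
# CUDA machine) or by later changes; the commit itself only removes the
# explicit device handling from the module.
coref_scorer = PyTorchWrapper(TinyScorer(64))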
@@ -266,7 +262,6 @@ class CorefScorer(torch.nn.Module):
     """
     def __init__(
         self,
-        device: str,
         dim: int,  # tok2vec size
         dist_emb_size: int,
         hidden_size: int,
@@ -283,8 +278,7 @@ class CorefScorer(torch.nn.Module):
         epochs_trained (int): the number of epochs finished
             (useful for warm start)
         """
-        # device, dist_emb_size, hidden_size, n_layers, dropout_rate
-        self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate).to(device)
+        self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate)
         #TODO clean this up
         bert_emb = dim
         pair_emb = bert_emb * 3 + self.pw.shape
@@ -293,7 +287,7 @@ class CorefScorer(torch.nn.Module):
             hidden_size,
             n_layers,
             dropout_rate
-        ).to(device)
+        )
         self.lstm = torch.nn.LSTM(
             input_size=bert_emb,
             hidden_size=bert_emb,
@@ -304,7 +298,7 @@ class CorefScorer(torch.nn.Module):
             bert_emb,
             dropout_rate,
             roughk
-        ).to(device)
+        )
         self.batch_size = batch_size

     def forward(
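
A note on why dropping the per-submodule .to(device) calls above is safe once
the top-level module is moved as a whole: torch.nn.Module.to() is recursive,
so registered submodules travel with their parent. A quick check:

import torch

class Parent(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.child = torch.nn.Linear(4, 4)  # registered submodule

parent = Parent()
if torch.cuda.is_available():
    parent.to("cuda:0")
    # The child's parameters moved with the parent: prints "cuda:0".
    print(next(parent.child.parameters()).device)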
@@ -453,7 +447,6 @@ class AnaphoricityScorer(torch.nn.Module):
         return out

-
 class RoughScorer(torch.nn.Module):
     """
     Is needed to give a roughly estimate of the anaphoricity of two candidates,
@@ -484,7 +477,6 @@ class RoughScorer(torch.nn.Module):
         pair_mask = torch.arange(mentions.shape[0])
         pair_mask = pair_mask.unsqueeze(1) - pair_mask.unsqueeze(0)
         pair_mask = torch.log((pair_mask > 0).to(torch.float))
-        pair_mask = pair_mask.to(mentions.device)
         bilinear_scores = self.dropout(self.bilinear(mentions)).mm(mentions.T)

         rough_scores = pair_mask + bilinear_scores
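
For reference, the pair_mask built above is a log-space antecedent mask:
entry [i, j] is 0 when mention j precedes mention i and -inf otherwise, so
adding it to the bilinear scores rules out same-or-later mentions as
antecedents. A toy example:

import torch

n = 4
idx = torch.arange(n)
pair_mask = idx.unsqueeze(1) - idx.unsqueeze(0)  # pair_mask[i, j] = i - j
pair_mask = torch.log((pair_mask > 0).to(torch.float))
print(pair_mask)
# tensor([[-inf, -inf, -inf, -inf],
#         [0., -inf, -inf, -inf],
#         [0., 0., -inf, -inf],
#         [0., 0., 0., -inf]])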
@@ -511,7 +503,7 @@ class RoughScorer(torch.nn.Module):
 class SpanPredictor(torch.nn.Module):
-    def __init__(self, input_size: int, distance_emb_size: int, device):
+    def __init__(self, input_size: int, distance_emb_size: int):
         super().__init__()
         self.ffnn = torch.nn.Sequential(
             torch.nn.Linear(input_size * 2 + 64, input_size),
@@ -522,7 +514,6 @@ class SpanPredictor(torch.nn.Module):
             torch.nn.Dropout(0.3),
             torch.nn.Linear(256, 64),
         )
-        self.device = device
         self.conv = torch.nn.Sequential(
             torch.nn.Conv1d(64, 4, 3, 1, 1),
             torch.nn.Conv1d(4, 2, 3, 1, 1)
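
The positional Conv1d arguments above read as (in_channels, out_channels,
kernel_size, stride, padding); kernel size 3 with padding 1 keeps the
sequence length unchanged, as a quick shape check shows:

import torch

conv = torch.nn.Sequential(
    torch.nn.Conv1d(64, 4, 3, 1, 1),
    torch.nn.Conv1d(4, 2, 3, 1, 1),
)
x = torch.randn(1, 64, 10)  # (batch, channels, sequence length)
print(conv(x).shape)        # torch.Size([1, 2, 10])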
@@ -600,17 +591,10 @@ class DistancePairwiseEncoder(torch.nn.Module):
         self.dropout = torch.nn.Dropout(dropout_rate)
         self.shape = emb_size

-    @property
-    def device(self) -> torch.device:
-        """ A workaround to get current device (which is assumed to be the
-        device of the first parameter of one of the submodules) """
-        return next(self.distance_emb.parameters()).device
-
     def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                 top_indices: torch.Tensor
                 ) -> torch.Tensor:
-        word_ids = torch.arange(0, top_indices.size(0), device=self.device)
+        word_ids = torch.arange(0, top_indices.size(0))
         distance = (word_ids.unsqueeze(1) - word_ids[top_indices]
                     ).clamp_min_(min=1)
         log_distance = distance.to(torch.float).log2().floor_()
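
The distance computation at the end of this hunk is log-bucketed: each
antecedent distance is clamped to at least 1 and mapped to floor(log2(d)), so
nearby mentions get fine-grained buckets while distant ones share coarse
buckets. For example:

import torch

distance = torch.tensor([1, 2, 3, 4, 7, 8, 16, 100]).to(torch.float)
log_distance = distance.log2().floor_()
print(log_distance)  # tensor([0., 1., 1., 2., 2., 3., 4., 6.])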