Kádár Ákos 2022-03-24 16:10:04 +01:00
commit a872c69ffb

@@ -511,22 +511,31 @@ class RoughScorer(torch.nn.Module):
 class SpanPredictor(torch.nn.Module):
+<<<<<<< HEAD
     def __init__(self, input_size: int, distance_emb_size: int):
+=======
+    def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int, device):
+>>>>>>> eec00ce60d83f500e18f2da7d9feafa7143440f2
         super().__init__()
+        # input size = single token size
+        # 64 = probably distance emb size
+        # TODO check that dist_emb_size use is correct
         self.ffnn = torch.nn.Sequential(
-            torch.nn.Linear(input_size * 2 + 64, input_size),
+            torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
             torch.nn.ReLU(),
             torch.nn.Dropout(0.3),
-            torch.nn.Linear(input_size, 256),
+            #TODO seems weird the 256 isn't a parameter???
+            torch.nn.Linear(hidden_size, 256),
             torch.nn.ReLU(),
             torch.nn.Dropout(0.3),
-            torch.nn.Linear(256, 64),
+            # this use of dist_emb_size looks wrong but it was 64...?
+            torch.nn.Linear(256, dist_emb_size),
         )
         self.conv = torch.nn.Sequential(
             torch.nn.Conv1d(64, 4, 3, 1, 1),
             torch.nn.Conv1d(4, 2, 3, 1, 1)
         )
-        self.emb = torch.nn.Embedding(128, distance_emb_size)  # [-63, 63] + too_far
+        self.emb = torch.nn.Embedding(128, dist_emb_size)  # [-63, 63] + too_far
 
     def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                 sent_id,
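
Note that the hunk is committed with its merge conflict unresolved; since the rewritten body already refers to hidden_size and dist_emb_size, only the signature from eec00ce6 (the `>>>>>>>` side) actually matches the code. The two TODO comments also answer each other: the final Linear has to emit 64 channels because self.conv opens with Conv1d(64, 4, 3, 1, 1), so writing dist_emb_size there is only safe while dist_emb_size == 64. For self.emb, the `[-63, 63] + too_far` comment suggests that signed relative distances are shifted into a 127-slot window, with everything outside sharing one overflow bucket. Below is a minimal sketch of that bucketing, assuming a shift-by-63 scheme and dist_emb_size = 64; the helper name distance_bucket is hypothetical and does not appear in the commit.

import torch

# Hypothetical helper, not from the commit: maps signed token-to-head
# distances onto the 128 rows of self.emb.
def distance_bucket(head_pos: torch.Tensor, token_pos: torch.Tensor) -> torch.Tensor:
    rel = token_pos - head_pos             # signed distance, any magnitude
    idx = rel + 63                         # [-63, 63] -> [0, 126]
    idx[(idx < 0) | (idx > 126)] = 127     # everything else -> "too_far" bucket
    return idx

emb = torch.nn.Embedding(128, 64)          # 64 = assumed dist_emb_size
positions = torch.arange(10)
dist_emb = emb(distance_bucket(torch.tensor(4), positions))  # shape (10, 64)

Under that assumption every index lands in [0, 127], so the embedding lookup never goes out of range regardless of how far a token sits from the head.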