Fix various sizes in SpanPredictor FFNN

Paul O'Leary McCann 2022-03-23 16:20:31 +09:00
parent 2190cbc0e6
commit eec00ce60d

@@ -58,17 +58,6 @@ def build_wl_coref_model(
     )
     coref_model = tok2vec >> coref_scorer
-    # XXX just ignore this until the coref scorer is integrated
-    span_predictor = PyTorchWrapper(
-        SpanPredictor(
-            # TODO this was hardcoded to 1024, check
-            hidden_size,
-            sp_embedding_size,
-            device
-        ),
-        convert_inputs=convert_span_predictor_inputs
-    )
     # TODO combine models so output is uniform (just one forward pass)
     # It may be reasonable to have an option to disable span prediction,
     # and just return words as spans.
@@ -91,7 +80,7 @@ def build_span_predictor(
     # TODO fix device - should be automatic
     device = "cuda:0"
     span_predictor = PyTorchWrapper(
-        SpanPredictor(hidden_size, dist_emb_size, device),
+        SpanPredictor(dim, hidden_size, dist_emb_size, device),
         convert_inputs=convert_span_predictor_inputs
     )
     # TODO use proper parameter for prefix
@@ -512,23 +501,28 @@ class RoughScorer(torch.nn.Module):

 class SpanPredictor(torch.nn.Module):
-    def __init__(self, input_size: int, distance_emb_size: int, device):
+    def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int, device):
         super().__init__()
+        # input size = single token size
+        # 64 = probably distance emb size
+        # TODO check that dist_emb_size use is correct
         self.ffnn = torch.nn.Sequential(
-            torch.nn.Linear(input_size * 2 + 64, input_size),
+            torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
             torch.nn.ReLU(),
             torch.nn.Dropout(0.3),
-            torch.nn.Linear(input_size, 256),
+            #TODO seems weird the 256 isn't a parameter???
+            torch.nn.Linear(hidden_size, 256),
             torch.nn.ReLU(),
             torch.nn.Dropout(0.3),
-            torch.nn.Linear(256, 64),
+            # this use of dist_emb_size looks wrong but it was 64...?
+            torch.nn.Linear(256, dist_emb_size),
         )
         self.device = device
         self.conv = torch.nn.Sequential(
             torch.nn.Conv1d(64, 4, 3, 1, 1),
             torch.nn.Conv1d(4, 2, 3, 1, 1)
         )
-        self.emb = torch.nn.Embedding(128, distance_emb_size)  # [-63, 63] + too_far
+        self.emb = torch.nn.Embedding(128, dist_emb_size)  # [-63, 63] + too_far

     def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                 sent_id,
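
For reference, a minimal sketch (not part of the commit) that rebuilds the layers as they stand after this change and runs a dummy forward pass to check that the sizes line up. The concrete values input_size=768, hidden_size=1024, dist_emb_size=64 and the pair-tensor shapes are assumptions for illustration; only the layer definitions mirror the diff.

    import torch

    input_size, hidden_size, dist_emb_size = 768, 1024, 64  # assumed example sizes

    # FFNN as defined after this commit
    ffnn = torch.nn.Sequential(
        torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
        torch.nn.ReLU(),
        torch.nn.Dropout(0.3),
        torch.nn.Linear(hidden_size, 256),  # 256 still hardcoded, per the in-diff TODO
        torch.nn.ReLU(),
        torch.nn.Dropout(0.3),
        torch.nn.Linear(256, dist_emb_size),
    )
    # distance embedding: buckets [-63, 63] plus one "too far" bucket
    emb = torch.nn.Embedding(128, dist_emb_size)
    # the Conv1d input channels are still a hardcoded 64, so the stack only
    # lines up when dist_emb_size == 64 (the "looks wrong but it was 64" note)
    conv = torch.nn.Sequential(
        torch.nn.Conv1d(64, 4, 3, 1, 1),
        torch.nn.Conv1d(4, 2, 3, 1, 1),
    )

    # assumed shapes: one pair representation per (head, candidate token),
    # i.e. two token vectors concatenated with a distance embedding
    pairs = torch.randn(5, 10, input_size * 2 + dist_emb_size)
    res = ffnn(pairs)                   # -> (5, 10, dist_emb_size)
    res = conv(res.permute(0, 2, 1))    # Conv1d over the candidate-token axis
    print(res.permute(0, 2, 1).shape)   # -> (5, 10, 2): start/end scores

With dist_emb_size now threaded through the FFNN, the remaining hardcoded sizes are the inner 256 and the conv's 64 input channels, both flagged by the TODOs in the diff above.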