merge SpanPredictor attributes

This commit is contained in:
Kádár Ákos 2022-03-24 16:23:12 +01:00
parent a872c69ffb
commit 1c5dabcb47

View File

@ -83,7 +83,7 @@ def build_span_predictor(
with Model.define_operators({">>": chain, "&": tuplify}): with Model.define_operators({">>": chain, "&": tuplify}):
span_predictor = PyTorchWrapper( span_predictor = PyTorchWrapper(
SpanPredictor(dim, dist_emb_size), SpanPredictor(dim, hidden_size, dist_emb_size),
convert_inputs=convert_span_predictor_inputs convert_inputs=convert_span_predictor_inputs
) )
# TODO use proper parameter for prefix # TODO use proper parameter for prefix
@ -511,11 +511,7 @@ class RoughScorer(torch.nn.Module):
class SpanPredictor(torch.nn.Module): class SpanPredictor(torch.nn.Module):
<<<<<<< HEAD
def __init__(self, input_size: int, distance_emb_size: int):
=======
def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int, device): def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int, device):
>>>>>>> eec00ce60d83f500e18f2da7d9feafa7143440f2
super().__init__() super().__init__()
# input size = single token size # input size = single token size
# 64 = probably distance emb size # 64 = probably distance emb size