diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py
index 3350a8dd9..5fbb64a29 100644
--- a/spacy/ml/models/coref.py
+++ b/spacy/ml/models/coref.py
@@ -511,22 +511,27 @@ class RoughScorer(torch.nn.Module):
 
 
 class SpanPredictor(torch.nn.Module):
-    def __init__(self, input_size: int, distance_emb_size: int):
+    def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int, device):
         super().__init__()
+        # input size = single token size
+        # 64 = probably distance emb size
+        # TODO check that dist_emb_size use is correct
         self.ffnn = torch.nn.Sequential(
-            torch.nn.Linear(input_size * 2 + 64, input_size),
+            torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
             torch.nn.ReLU(),
             torch.nn.Dropout(0.3),
-            torch.nn.Linear(input_size, 256),
+            #TODO seems weird the 256 isn't a parameter???
+            torch.nn.Linear(hidden_size, 256),
             torch.nn.ReLU(),
             torch.nn.Dropout(0.3),
-            torch.nn.Linear(256, 64),
+            # this use of dist_emb_size looks wrong but it was 64...?
+            torch.nn.Linear(256, dist_emb_size),
         )
         self.conv = torch.nn.Sequential(
             torch.nn.Conv1d(64, 4, 3, 1, 1),
             torch.nn.Conv1d(4, 2, 3, 1, 1)
         )
-        self.emb = torch.nn.Embedding(128, distance_emb_size) # [-63, 63] + too_far
+        self.emb = torch.nn.Embedding(128, dist_emb_size) # [-63, 63] + too_far
 
     def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
                 sent_id,