From eec00ce60d83f500e18f2da7d9feafa7143440f2 Mon Sep 17 00:00:00 2001
From: Paul O'Leary McCann
Date: Wed, 23 Mar 2022 16:20:31 +0900
Subject: [PATCH] Fix various sizes in SpanPredictor FFNN

---
 spacy/ml/models/coref.py | 28 +++++++++++-----------------
 1 file changed, 11 insertions(+), 17 deletions(-)

diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py
index 382d7a98b..0f1614ef5 100644
--- a/spacy/ml/models/coref.py
+++ b/spacy/ml/models/coref.py
@@ -58,17 +58,6 @@ def build_wl_coref_model(
     )
     coref_model = tok2vec >> coref_scorer
 
-    # XXX just ignore this until the coref scorer is integrated
-    span_predictor = PyTorchWrapper(
-        SpanPredictor(
-            # TODO this was hardcoded to 1024, check
-            hidden_size,
-            sp_embedding_size,
-            device
-        ),
-
-        convert_inputs=convert_span_predictor_inputs
-    )
     # TODO combine models so output is uniform (just one forward pass)
     # It may be reasonable to have an option to disable span prediction,
     # and just return words as spans.
@@ -91,7 +80,7 @@ def build_span_predictor(
     # TODO fix device - should be automatic
     device = "cuda:0"
     span_predictor = PyTorchWrapper(
-        SpanPredictor(hidden_size, dist_emb_size, device),
+        SpanPredictor(dim, hidden_size, dist_emb_size, device),
         convert_inputs=convert_span_predictor_inputs
     )
     # TODO use proper parameter for prefix
@@ -512,23 +501,28 @@ class RoughScorer(torch.nn.Module):
 
 
 class SpanPredictor(torch.nn.Module):
-    def __init__(self, input_size: int, distance_emb_size: int, device):
+    def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int, device):
         super().__init__()
+        # input size = single token size
+        # 64 = probably distance emb size
+        # TODO check that dist_emb_size use is correct
         self.ffnn = torch.nn.Sequential(
-            torch.nn.Linear(input_size * 2 + 64, input_size),
+            torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
             torch.nn.ReLU(),
             torch.nn.Dropout(0.3),
-            torch.nn.Linear(input_size, 256),
+            #TODO seems weird the 256 isn't a parameter???
+            torch.nn.Linear(hidden_size, 256),
             torch.nn.ReLU(),
             torch.nn.Dropout(0.3),
-            torch.nn.Linear(256, 64),
+            # this use of dist_emb_size looks wrong but it was 64...?
+            torch.nn.Linear(256, dist_emb_size),
         )
         self.device = device
         self.conv = torch.nn.Sequential(
             torch.nn.Conv1d(64, 4, 3, 1, 1),
             torch.nn.Conv1d(4, 2, 3, 1, 1)
         )
-        self.emb = torch.nn.Embedding(128, distance_emb_size) # [-63, 63] + too_far
+        self.emb = torch.nn.Embedding(128, dist_emb_size) # [-63, 63] + too_far
 
     def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
         sent_id,
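
For review purposes, here is a minimal standalone sketch (not part of the patch) of how the
layer sizes chain together after this change. The concrete widths are illustrative
assumptions, not values taken from the patch: input_size=768 for a single token vector,
hidden_size=1024 (the value the old code reportedly hardcoded), and dist_emb_size=64. Note
that the Conv1d input channel count is still hardcoded to 64, so the final
Linear(256, dist_emb_size) only lines up when dist_emb_size is 64, which is exactly the
concern flagged in the "looks wrong but it was 64...?" comment above.

    import torch

    # Illustrative assumptions, not values from the patch:
    input_size = 768     # width of a single token vector
    hidden_size = 1024   # the old code hardcoded 1024 here
    dist_emb_size = 64   # distance embedding width; [-63, 63] + too_far = 128 buckets

    ffnn = torch.nn.Sequential(
        torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
        torch.nn.ReLU(),
        torch.nn.Dropout(0.3),
        torch.nn.Linear(hidden_size, 256),
        torch.nn.ReLU(),
        torch.nn.Dropout(0.3),
        torch.nn.Linear(256, dist_emb_size),
    )
    conv = torch.nn.Sequential(
        torch.nn.Conv1d(64, 4, 3, 1, 1),  # in_channels still hardcoded to 64
        torch.nn.Conv1d(4, 2, 3, 1, 1),
    )

    # One row per candidate token: two token vectors plus a distance embedding.
    pairs = torch.randn(10, input_size * 2 + dist_emb_size)
    out = ffnn(pairs)                  # -> (10, dist_emb_size) == (10, 64)
    scores = conv(out.T.unsqueeze(0))  # (1, 64, 10) -> (1, 2, 10)
    print(scores.shape)                # two output channels per token

This runs without shape errors with the values above, but only because dist_emb_size == 64;
making the 256 and the Conv1d's 64 into parameters would remove the remaining hidden coupling.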