mirror of https://github.com/explosion/spaCy.git
Fix various sizes in SpanPredictor FFNN
parent 2190cbc0e6
commit eec00ce60d

@@ -58,17 +58,6 @@ def build_wl_coref_model(
     )

     coref_model = tok2vec >> coref_scorer
-    # XXX just ignore this until the coref scorer is integrated
-    span_predictor = PyTorchWrapper(
-        SpanPredictor(
-            # TODO this was hardcoded to 1024, check
-            hidden_size,
-            sp_embedding_size,
-            device
-        ),
-
-        convert_inputs=convert_span_predictor_inputs
-    )
     # TODO combine models so output is uniform (just one forward pass)
     # It may be reasonable to have an option to disable span prediction,
     # and just return words as spans.
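
Aside: the ">>" in coref_model = tok2vec >> coref_scorer is Thinc's chain
operator, which only works once it has been bound via Model.define_operators.
A minimal sketch with stand-in layers and made-up sizes (not the layers
actually built in this file):

from thinc.api import Model, Relu, Linear, chain

# Stand-ins for the real tok2vec and coref scorer layers (hypothetical sizes).
tok2vec = Relu(nO=64, nI=64)
coref_scorer = Linear(nO=2, nI=64)

# Explicit composition...
coref_model = chain(tok2vec, coref_scorer)

# ...and the same pipeline via the >> operator, once it is bound to chain.
with Model.define_operators({">>": chain}):
    coref_model = tok2vec >> coref_scorer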

@@ -91,7 +80,7 @@ def build_span_predictor(
     # TODO fix device - should be automatic
     device = "cuda:0"
     span_predictor = PyTorchWrapper(
-        SpanPredictor(hidden_size, dist_emb_size, device),
+        SpanPredictor(dim, hidden_size, dist_emb_size, device),
         convert_inputs=convert_span_predictor_inputs
     )
     # TODO use proper parameter for prefix
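
Aside: PyTorchWrapper here is Thinc's generic wrapper for a torch module; its
convert_inputs hook turns the Thinc-side input into an ArgsKwargs for the torch
forward pass. A rough standalone sketch of that pattern with a dummy module and
converter (this is not spaCy's actual convert_span_predictor_inputs, and the
gradient conversion is stubbed out):

import torch
from thinc.api import ArgsKwargs, PyTorchWrapper
from thinc.util import xp2torch

torch_module = torch.nn.Linear(4, 2)  # dummy module standing in for SpanPredictor

def convert_demo_inputs(model, X, is_train):
    # Convert the incoming array to a torch tensor for forward();
    # the backward conversion is omitted in this sketch.
    tensor = xp2torch(model.ops.asarray2f(X), requires_grad=is_train)
    return ArgsKwargs(args=(tensor,), kwargs={}), lambda d_args: []

wrapped = PyTorchWrapper(torch_module, convert_inputs=convert_demo_inputs)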

@@ -512,23 +501,28 @@ class RoughScorer(torch.nn.Module):


 class SpanPredictor(torch.nn.Module):
-    def __init__(self, input_size: int, distance_emb_size: int, device):
+    def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int, device):
         super().__init__()
+        # input size = single token size
+        # 64 = probably distance emb size
+        # TODO check that dist_emb_size use is correct
         self.ffnn = torch.nn.Sequential(
-            torch.nn.Linear(input_size * 2 + 64, input_size),
+            torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
             torch.nn.ReLU(),
             torch.nn.Dropout(0.3),
-            torch.nn.Linear(input_size, 256),
+            #TODO seems weird the 256 isn't a parameter???
+            torch.nn.Linear(hidden_size, 256),
             torch.nn.ReLU(),
             torch.nn.Dropout(0.3),
-            torch.nn.Linear(256, 64),
+            # this use of dist_emb_size looks wrong but it was 64...?
+            torch.nn.Linear(256, dist_emb_size),
         )
         self.device = device
         self.conv = torch.nn.Sequential(
             torch.nn.Conv1d(64, 4, 3, 1, 1),
             torch.nn.Conv1d(4, 2, 3, 1, 1)
         )
-        self.emb = torch.nn.Embedding(128, distance_emb_size) # [-63, 63] + too_far
+        self.emb = torch.nn.Embedding(128, dist_emb_size) # [-63, 63] + too_far

     def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                 sent_id,
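
To sanity-check that the adjusted sizes line up, here is a small standalone
sketch of the same FFNN stack with made-up values for input_size, hidden_size
and dist_emb_size (the real values come from the model config):

import torch

input_size, hidden_size, dist_emb_size = 768, 1024, 64  # illustrative values only

ffnn = torch.nn.Sequential(
    torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.3),
    torch.nn.Linear(hidden_size, 256),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.3),
    torch.nn.Linear(256, dist_emb_size),
)

# Each row is a pair of token vectors concatenated with a distance embedding;
# the output has one score per distance-embedding dimension.
pairs = torch.randn(10, input_size * 2 + dist_emb_size)
print(ffnn(pairs).shape)  # torch.Size([10, 64])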