Fix various sizes in SpanPredictor FFNN

commit eec00ce60d
parent 2190cbc0e6
@@ -58,17 +58,6 @@ def build_wl_coref_model(
     )

     coref_model = tok2vec >> coref_scorer
-    # XXX just ignore this until the coref scorer is integrated
-    span_predictor = PyTorchWrapper(
-        SpanPredictor(
-            # TODO this was hardcoded to 1024, check
-            hidden_size,
-            sp_embedding_size,
-            device
-        ),
-
-        convert_inputs=convert_span_predictor_inputs
-    )
     # TODO combine models so output is uniform (just one forward pass)
     # It may be reasonable to have an option to disable span prediction,
     # and just return words as spans.
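The block removed above follows Thinc's PyTorchWrapper pattern: a raw torch module plus a convert_inputs hook that maps Thinc arrays onto torch tensors before the forward pass. A minimal sketch of that hook, assuming the inputs arrive as (sent_ids, word_features) and that no gradient needs to flow back through them; the real convert_span_predictor_inputs in this file may differ:

import torch
from thinc.api import PyTorchWrapper, ArgsKwargs, xp2torch

def convert_span_predictor_inputs(model, X, is_train):
    # Illustrative body only. Thinc hands us (inputs, is_train); we
    # return the torch-side args plus a callback that converts
    # gradients back on the way down.
    sent_ids, word_features = X
    args = ArgsKwargs(
        args=(xp2torch(sent_ids), xp2torch(word_features, requires_grad=is_train)),
        kwargs={},
    )

    def backprop(d_args: ArgsKwargs):
        # Sketch only: no gradient is routed back to the inputs here.
        return []

    return args, backprop

# Usage, mirroring the removed block:
# span_predictor = PyTorchWrapper(
#     SpanPredictor(...),
#     convert_inputs=convert_span_predictor_inputs,
# )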
@@ -91,7 +80,7 @@ def build_span_predictor(
     # TODO fix device - should be automatic
     device = "cuda:0"
     span_predictor = PyTorchWrapper(
-        SpanPredictor(hidden_size, dist_emb_size, device),
+        SpanPredictor(dim, hidden_size, dist_emb_size, device),
         convert_inputs=convert_span_predictor_inputs
     )
     # TODO use proper parameter for prefix
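This one-liner is the call-site half of the signature change below: under the old three-argument signature, hidden_size was being passed where input_size belonged. With hypothetical sizes (dim would be the tok2vec output width), the corrected call lines up as:

# Old: SpanPredictor(input_size, distance_emb_size, device)
#      called as SpanPredictor(hidden_size, dist_emb_size, device)
# New: SpanPredictor(input_size, hidden_size, dist_emb_size, device)
dim, hidden_size, dist_emb_size = 768, 1024, 64  # hypothetical values
span_predictor_module = SpanPredictor(dim, hidden_size, dist_emb_size, "cpu")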
@@ -512,23 +501,28 @@ class RoughScorer(torch.nn.Module):


 class SpanPredictor(torch.nn.Module):
-    def __init__(self, input_size: int, distance_emb_size: int, device):
+    def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int, device):
         super().__init__()
+        # input size = single token size
+        # 64 = probably distance emb size
+        # TODO check that dist_emb_size use is correct
         self.ffnn = torch.nn.Sequential(
-            torch.nn.Linear(input_size * 2 + 64, input_size),
+            torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
             torch.nn.ReLU(),
             torch.nn.Dropout(0.3),
-            torch.nn.Linear(input_size, 256),
+            #TODO seems weird the 256 isn't a parameter???
+            torch.nn.Linear(hidden_size, 256),
             torch.nn.ReLU(),
             torch.nn.Dropout(0.3),
-            torch.nn.Linear(256, 64),
+            # this use of dist_emb_size looks wrong but it was 64...?
+            torch.nn.Linear(256, dist_emb_size),
         )
         self.device = device
         self.conv = torch.nn.Sequential(
             torch.nn.Conv1d(64, 4, 3, 1, 1),
             torch.nn.Conv1d(4, 2, 3, 1, 1)
         )
-        self.emb = torch.nn.Embedding(128, distance_emb_size)  # [-63, 63] + too_far
+        self.emb = torch.nn.Embedding(128, dist_emb_size)  # [-63, 63] + too_far

     def forward(self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
                 sent_id,
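As a sanity check on the new sizes, here is a standalone sketch of the re-parameterised FFNN and the distance embedding, with hypothetical values; the clamp-and-shift bucketing at the end is one guess at how "[-63, 63] + too_far" maps onto 128 rows, and the real lookup logic lives elsewhere in this file. Note the final Linear still emits 64 features when dist_emb_size is 64, which is presumably what the Conv1d(64, 4, ...) stack consumes:

import torch

# Hypothetical sizes: input_size from tok2vec, hidden_size from the
# config, dist_emb_size = 64 as in the previously hardcoded values.
input_size, hidden_size, dist_emb_size = 768, 1024, 64

ffnn = torch.nn.Sequential(
    torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.3),
    torch.nn.Linear(hidden_size, 256),
    torch.nn.ReLU(),
    torch.nn.Dropout(0.3),
    torch.nn.Linear(256, dist_emb_size),
)

# Each candidate pair: two token vectors concatenated with a distance embedding.
pairs = torch.randn(5, input_size * 2 + dist_emb_size)
assert ffnn(pairs).shape == (5, dist_emb_size)

# 128 embedding rows: [-63, 63] gives 127 buckets, plus one "too_far" bucket.
emb = torch.nn.Embedding(128, dist_emb_size)
offsets = torch.tensor([-70, -63, 0, 63, 70])     # signed token distances
buckets = torch.clamp(offsets, -63, 63) + 63      # -> [0, 126]
buckets = torch.where(offsets.abs() > 63, torch.tensor(127), buckets)
assert emb(buckets).shape == (5, dist_emb_size)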