Commit a872c69ffb ("merge")
Mirror of https://github.com/explosion/spaCy.git
@@ -511,22 +511,31 @@ class RoughScorer(torch.nn.Module):
 
 class SpanPredictor(torch.nn.Module):
+<<<<<<< HEAD
     def __init__(self, input_size: int, distance_emb_size: int):
+=======
+    def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int, device):
+>>>>>>> eec00ce60d83f500e18f2da7d9feafa7143440f2
         super().__init__()
+        # input size = single token size
+        # 64 = probably distance emb size
+        # TODO check that dist_emb_size use is correct
         self.ffnn = torch.nn.Sequential(
-            torch.nn.Linear(input_size * 2 + 64, input_size),
+            torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
             torch.nn.ReLU(),
             torch.nn.Dropout(0.3),
-            torch.nn.Linear(input_size, 256),
+            # TODO seems weird the 256 isn't a parameter???
+            torch.nn.Linear(hidden_size, 256),
             torch.nn.ReLU(),
             torch.nn.Dropout(0.3),
-            torch.nn.Linear(256, 64),
+            # this use of dist_emb_size looks wrong but it was 64...?
+            torch.nn.Linear(256, dist_emb_size),
         )
         self.conv = torch.nn.Sequential(
             torch.nn.Conv1d(64, 4, 3, 1, 1),
             torch.nn.Conv1d(4, 2, 3, 1, 1)
         )
-        self.emb = torch.nn.Embedding(128, distance_emb_size) # [-63, 63] + too_far
+        self.emb = torch.nn.Embedding(128, dist_emb_size) # [-63, 63] + too_far
 
     def forward(self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
                 sent_id,
 
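As committed, the file still contains the unresolved conflict markers around __init__, so this revision will not even parse. For reference, here is a minimal runnable sketch of the class with the conflict resolved in favor of the incoming eec00ce side, which is the signature the rest of the hunk already assumes; the device argument is kept only for signature compatibility, since this hunk never shows it being used:

import torch

class SpanPredictor(torch.nn.Module):
    def __init__(self, input_size: int, hidden_size: int, dist_emb_size: int, device):
        super().__init__()
        self.device = device  # assumption: retained but unused in this hunk
        # Pair representation: two token vectors plus one distance embedding.
        self.ffnn = torch.nn.Sequential(
            torch.nn.Linear(input_size * 2 + dist_emb_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(hidden_size, 256),  # 256 is hard-coded (see the TODO in the diff)
            torch.nn.ReLU(),
            torch.nn.Dropout(0.3),
            torch.nn.Linear(256, dist_emb_size),
        )
        # Channels 64 -> 4 -> 2; kernel 3, stride 1, padding 1 preserves length.
        # The 64 input channels only line up with the FFNN output when
        # dist_emb_size == 64, which is what the diff's TODO comments flag.
        self.conv = torch.nn.Sequential(
            torch.nn.Conv1d(64, 4, 3, 1, 1),
            torch.nn.Conv1d(4, 2, 3, 1, 1),
        )
        # 128 buckets: distances clamped to [-63, 63] plus a "too_far" bucket.
        self.emb = torch.nn.Embedding(128, dist_emb_size)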
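Continuing that sketch with illustrative sizes, a quick shape check makes the arithmetic concrete: the first Linear expects input_size * 2 + dist_emb_size features because each candidate pair concatenates two token vectors with a distance embedding, and the conv stack maps 64 channels down to 2 without changing sequence length:

model = SpanPredictor(input_size=768, hidden_size=1024, dist_emb_size=64, device="cpu")

heads = torch.randn(5, 768)                       # head-token vectors
cands = torch.randn(5, 768)                       # candidate-token vectors
dists = model.emb(torch.randint(0, 128, (5,)))    # (5, 64) distance embeddings
pairs = torch.cat((heads, cands, dists), dim=-1)  # (5, 768 * 2 + 64) = (5, 1600)

print(model.ffnn(pairs).shape)                    # torch.Size([5, 64])
print(model.conv(torch.randn(1, 64, 10)).shape)   # torch.Size([1, 2, 10])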
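The # [-63, 63] + too_far comment implies how the 128-row embedding table is indexed; the bucketing code itself is outside this hunk, so the helper below is an assumption about that scheme: signed distances in [-63, 63] map onto buckets 0 through 126, and anything farther in either direction lands in the overflow bucket 127.

def distance_bucket(dist: torch.Tensor) -> torch.Tensor:
    # Map [-63, 63] onto [0, 126]; overflow in either direction goes to 127.
    in_range = dist.clamp(-63, 63) + 63
    return torch.where(dist.abs() > 63, torch.full_like(dist, 127), in_range)

buckets = distance_bucket(torch.tensor([-100, -63, 0, 63, 100]))
print(buckets)                   # tensor([127,   0,  63, 126, 127])
print(model.emb(buckets).shape)  # torch.Size([5, 64])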