mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-03 21:24:11 +03:00
Fix PyTorch BiLSTM
This commit is contained in:
parent
a26fe8e7bb
commit
afeddfff26
|
@ -29,6 +29,7 @@ from . import util
|
|||
|
||||
try:
|
||||
import torch.nn
|
||||
from thinc.extra.wrappers import PyTorchWrapperRNN
|
||||
except:
|
||||
torch = None
|
||||
|
||||
|
@ -252,7 +253,7 @@ def link_vectors_to_models(vocab):
|
|||
|
||||
def PyTorchBiLSTM(nO, nI, depth, dropout=0.2):
|
||||
if depth == 0:
|
||||
return noop()
|
||||
return layerize(noop())
|
||||
model = torch.nn.LSTM(nI, nO//2, depth, bidirectional=True, dropout=dropout)
|
||||
return with_square_sequences(PyTorchWrapperRNN(model))
|
||||
|
||||
|
@ -299,7 +300,6 @@ def Tok2Vec(width, embed_size, **kwargs):
|
|||
ExtractWindow(nW=1)
|
||||
>> LN(Maxout(width, width*3, pieces=cnn_maxout_pieces))
|
||||
)
|
||||
|
||||
tok2vec = (
|
||||
FeatureExtracter(cols)
|
||||
>> with_flatten(
|
||||
|
|
Loading…
Reference in New Issue
Block a user