Fix PyTorch BiLSTM

This commit is contained in:
Matthew Honnibal 2018-09-13 22:54:34 +00:00
parent a26fe8e7bb
commit afeddfff26

View File

@ -29,6 +29,7 @@ from . import util
try:
import torch.nn
from thinc.extra.wrappers import PyTorchWrapperRNN
except:
torch = None
@ -252,7 +253,7 @@ def link_vectors_to_models(vocab):
def PyTorchBiLSTM(nO, nI, depth, dropout=0.2):
if depth == 0:
return noop()
return layerize(noop())
model = torch.nn.LSTM(nI, nO//2, depth, bidirectional=True, dropout=dropout)
return with_square_sequences(PyTorchWrapperRNN(model))
@ -299,7 +300,6 @@ def Tok2Vec(width, embed_size, **kwargs):
ExtractWindow(nW=1)
>> LN(Maxout(width, width*3, pieces=cnn_maxout_pieces))
)
tok2vec = (
FeatureExtracter(cols)
>> with_flatten(