mirror of
https://github.com/explosion/spaCy.git
synced 2025-05-30 02:33:07 +03:00
Fix PyTorch BiLSTM
This commit is contained in:
parent
a26fe8e7bb
commit
afeddfff26
|
@ -29,6 +29,7 @@ from . import util
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import torch.nn
|
import torch.nn
|
||||||
|
from thinc.extra.wrappers import PyTorchWrapperRNN
|
||||||
except:
|
except:
|
||||||
torch = None
|
torch = None
|
||||||
|
|
||||||
|
@ -252,7 +253,7 @@ def link_vectors_to_models(vocab):
|
||||||
|
|
||||||
def PyTorchBiLSTM(nO, nI, depth, dropout=0.2):
|
def PyTorchBiLSTM(nO, nI, depth, dropout=0.2):
|
||||||
if depth == 0:
|
if depth == 0:
|
||||||
return noop()
|
return layerize(noop())
|
||||||
model = torch.nn.LSTM(nI, nO//2, depth, bidirectional=True, dropout=dropout)
|
model = torch.nn.LSTM(nI, nO//2, depth, bidirectional=True, dropout=dropout)
|
||||||
return with_square_sequences(PyTorchWrapperRNN(model))
|
return with_square_sequences(PyTorchWrapperRNN(model))
|
||||||
|
|
||||||
|
@ -299,7 +300,6 @@ def Tok2Vec(width, embed_size, **kwargs):
|
||||||
ExtractWindow(nW=1)
|
ExtractWindow(nW=1)
|
||||||
>> LN(Maxout(width, width*3, pieces=cnn_maxout_pieces))
|
>> LN(Maxout(width, width*3, pieces=cnn_maxout_pieces))
|
||||||
)
|
)
|
||||||
|
|
||||||
tok2vec = (
|
tok2vec = (
|
||||||
FeatureExtracter(cols)
|
FeatureExtracter(cols)
|
||||||
>> with_flatten(
|
>> with_flatten(
|
||||||
|
|
Loading…
Reference in New Issue
Block a user