Support option of BiLSTM in Tok2Vec (requires pytorch)

Matthew Honnibal 2018-09-13 19:28:35 +02:00
parent 3eb9f3e2b8
commit 45032fe9e1


@@ -11,6 +11,7 @@ from thinc.misc import LayerNorm as LN
from thinc.api import add, layerize, chain, clone, concatenate, with_flatten
from thinc.api import FeatureExtracter, with_getitem, flatten_add_lengths
from thinc.api import uniqued, wrap, noop
from thinc.api import with_square_sequences
from thinc.linear.linear import LinearModel
from thinc.neural.ops import NumpyOps, CupyOps
from thinc.neural.util import get_array_module, copy_array
@@ -26,6 +27,10 @@ from .attrs import ID, ORTH, LOWER, NORM, PREFIX, SUFFIX, SHAPE
from .errors import Errors
from . import util

try:
    import torch.nn
except:
    torch = None

VECTORS_KEY = 'spacy_pretrained_vectors'
@@ -245,11 +250,19 @@ def link_vectors_to_models(vocab):
    thinc.extra.load_nlp.VECTORS[(ops.device, vectors.name)] = data


def PyTorchBiLSTM(nO, nI, depth, dropout=0.2):
    if depth == 0:
        return noop()
    model = torch.nn.LSTM(nI, nO//2, depth, bidirectional=True, dropout=dropout)
    return with_square_sequences(PyTorchWrapperRNN(model))


def Tok2Vec(width, embed_size, **kwargs):
    pretrained_vectors = kwargs.get('pretrained_vectors', None)
    cnn_maxout_pieces = kwargs.get('cnn_maxout_pieces', 2)
    subword_features = kwargs.get('subword_features', True)
    conv_depth = kwargs.get('conv_depth', 4)
    bilstm_depth = kwargs.get('bilstm_depth', 0)
    cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
    with Model.define_operators({'>>': chain, '|': concatenate, '**': clone,
                                 '+': add, '*': reapply}):
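
For reference, the new PyTorchBiLSTM factory added above, as a self-contained, commented sketch. The import location of PyTorchWrapperRNN (thinc.extra.wrappers) is an assumption, since the hunks shown here do not include that import; everything else follows the diff.

# Standalone sketch of the factory added above; not part of the commit itself.
# Assumes thinc with PyTorch installed, and that PyTorchWrapperRNN is
# importable from thinc.extra.wrappers (the diff does not show this import).
from thinc.api import noop, with_square_sequences

try:
    import torch.nn
    from thinc.extra.wrappers import PyTorchWrapperRNN
except ImportError:
    torch = None


def PyTorchBiLSTM(nO, nI, depth, dropout=0.2):
    # depth == 0 keeps the layer a no-op, so the existing CNN-only
    # pipeline is untouched unless a BiLSTM is explicitly requested.
    if depth == 0:
        return noop()
    # Each direction gets nO // 2 units, so the concatenated output is nO wide.
    model = torch.nn.LSTM(nI, nO // 2, depth, bidirectional=True, dropout=dropout)
    # PyTorchWrapperRNN exposes the torch module as a thinc model, and
    # with_square_sequences pads the variable-length token sequences into a
    # rectangular batch for the RNN, then unpads the outputs again.
    return with_square_sequences(PyTorchWrapperRNN(model))
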
@@ -293,6 +306,7 @@ def Tok2Vec(width, embed_size, **kwargs):
                embed
                >> convolution ** conv_depth, pad=conv_depth
            )
            >> PyTorchBiLSTM(width, width, bilstm_depth)
        )
        # Work around thinc API limitations :(. TODO: Revise in Thinc 7
        tok2vec.nO = width
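
As a usage sketch, assuming this patch applies to spacy._ml (the module where Tok2Vec lives), the BiLSTM stays opt-in: bilstm_depth defaults to 0, so the CNN-only behaviour is unchanged. The widths below are example values only.

# Usage sketch, not part of the diff; a nonzero bilstm_depth requires PyTorch.
from spacy._ml import Tok2Vec

# CNN-only tok2vec, identical to the behaviour before this commit.
cnn_only = Tok2Vec(width=96, embed_size=2000)

# The same CNN followed by a 2-layer BiLSTM over width-96 vectors.
with_lstm = Tok2Vec(width=96, embed_size=2000, bilstm_depth=2)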