mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Support option of BiLSTM in Tok2Vec (requires pytorch)
This commit is contained in:
parent
3eb9f3e2b8
commit
45032fe9e1
14
spacy/_ml.py
14
spacy/_ml.py
|
@ -11,6 +11,7 @@ from thinc.misc import LayerNorm as LN
|
|||
from thinc.api import add, layerize, chain, clone, concatenate, with_flatten
|
||||
from thinc.api import FeatureExtracter, with_getitem, flatten_add_lengths
|
||||
from thinc.api import uniqued, wrap, noop
|
||||
from thinc.api import with_square_sequences
|
||||
from thinc.linear.linear import LinearModel
|
||||
from thinc.neural.ops import NumpyOps, CupyOps
|
||||
from thinc.neural.util import get_array_module, copy_array
|
||||
|
@ -26,6 +27,10 @@ from .attrs import ID, ORTH, LOWER, NORM, PREFIX, SUFFIX, SHAPE
|
|||
from .errors import Errors
|
||||
from . import util
|
||||
|
||||
try:
|
||||
import torch.nn
|
||||
except:
|
||||
torch = None
|
||||
|
||||
VECTORS_KEY = 'spacy_pretrained_vectors'
|
||||
|
||||
|
@ -245,11 +250,19 @@ def link_vectors_to_models(vocab):
|
|||
thinc.extra.load_nlp.VECTORS[(ops.device, vectors.name)] = data
|
||||
|
||||
|
||||
def PyTorchBiLSTM(nO, nI, depth, dropout=0.2):
|
||||
if depth == 0:
|
||||
return noop()
|
||||
model = torch.nn.LSTM(nI, nO//2, depth, bidirectional=True, dropout=dropout)
|
||||
return with_square_sequences(PyTorchWrapperRNN(model))
|
||||
|
||||
|
||||
def Tok2Vec(width, embed_size, **kwargs):
|
||||
pretrained_vectors = kwargs.get('pretrained_vectors', None)
|
||||
cnn_maxout_pieces = kwargs.get('cnn_maxout_pieces', 2)
|
||||
subword_features = kwargs.get('subword_features', True)
|
||||
conv_depth = kwargs.get('conv_depth', 4)
|
||||
bilstm_depth = kwargs.get('bilstm_depth', 0)
|
||||
cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
|
||||
with Model.define_operators({'>>': chain, '|': concatenate, '**': clone,
|
||||
'+': add, '*': reapply}):
|
||||
|
@ -293,6 +306,7 @@ def Tok2Vec(width, embed_size, **kwargs):
|
|||
embed
|
||||
>> convolution ** conv_depth, pad=conv_depth
|
||||
)
|
||||
>> PyTorchBiLSTM(width, width, bilstm_depth)
|
||||
)
|
||||
# Work around thinc API limitations :(. TODO: Revise in Thinc 7
|
||||
tok2vec.nO = width
|
||||
|
|
Loading…
Reference in New Issue
Block a user