mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
Add argument for CNN maxout pieces
This commit is contained in:
parent
78301b2d29
commit
f5144f04be
11
spacy/_ml.py
11
spacy/_ml.py
|
@ -226,9 +226,9 @@ def drop_layer(layer, factor=2.):
|
|||
return model
|
||||
|
||||
|
||||
def Tok2Vec(width, embed_size, pretrained_dims=0):
|
||||
if pretrained_dims is None:
|
||||
pretrained_dims = 0
|
||||
def Tok2Vec(width, embed_size, pretrained_dims=0, **kwargs):
|
||||
assert pretrained_dims is not None
|
||||
cnn_maxout_pieces = kwargs.get('cnn_maxout_pieces', 3)
|
||||
cols = [ID, NORM, PREFIX, SUFFIX, SHAPE, ORTH]
|
||||
with Model.define_operators({'>>': chain, '|': concatenate, '**': clone, '+': add}):
|
||||
norm = HashEmbed(width, embed_size, column=cols.index(NORM), name='embed_norm')
|
||||
|
@ -244,7 +244,10 @@ def Tok2Vec(width, embed_size, pretrained_dims=0):
|
|||
>> LN(Maxout(width, width*4, pieces=3)), column=5)
|
||||
)
|
||||
)
|
||||
convolution = Residual(ExtractWindow(nW=1) >> LN(Maxout(width, width*3, pieces=3)))
|
||||
convolution = Residual(
|
||||
ExtractWindow(nW=1)
|
||||
>> LN(Maxout(width, width*3, pieces=cnn_maxout_pieces))
|
||||
)
|
||||
|
||||
if pretrained_dims >= 1:
|
||||
embed = concatenate_lists(trained_vectors, SpacyVectors)
|
||||
|
|
Loading…
Reference in New Issue
Block a user