mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
Restore tok2vec function
This commit is contained in:
parent
efe9630e1c
commit
fa7c1990b6
36
spacy/_ml.py
36
spacy/_ml.py
|
@ -1,7 +1,11 @@
|
|||
from thinc.api import layerize, chain, clone
|
||||
from thinc.api import layerize, chain, clone, concatenate
|
||||
from thinc.neural import Model, Maxout, Softmax
|
||||
from thinc.neural._classes.hash_embed import HashEmbed
|
||||
from .attrs import TAG, DEP
|
||||
|
||||
from thinc.neural._classes.convolution import ExtractWindow
|
||||
from thinc.neural._classes.static_vectors import StaticVectors
|
||||
|
||||
from .attrs import ID, PREFIX, SUFFIX, SHAPE, TAG, DEP
|
||||
|
||||
|
||||
def get_col(idx):
|
||||
|
@ -79,19 +83,15 @@ def _reshape(layer):
|
|||
model._layers.append(layer)
|
||||
return model
|
||||
|
||||
#from thinc.api import layerize, chain, clone, concatenate, add
|
||||
# from thinc.neural._classes.convolution import ExtractWindow
|
||||
# from thinc.neural._classes.static_vectors import StaticVectors
|
||||
|
||||
#def build_tok2vec(lang, width, depth, embed_size, cols):
|
||||
# with Model.define_operators({'>>': chain, '|': concatenate, '**': clone}):
|
||||
# static = get_col(cols.index(ID)) >> StaticVectors(lang, width)
|
||||
# prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width, embed_size)
|
||||
# suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width, embed_size)
|
||||
# shape = get_col(cols.index(SHAPE)) >> HashEmbed(width, embed_size)
|
||||
# tok2vec = (
|
||||
# (static | prefix | suffix | shape)
|
||||
# >> Maxout(width, width*4)
|
||||
# >> (ExtractWindow(nW=1) >> Maxout(width, width*3)) ** depth
|
||||
# )
|
||||
# return tok2vec
|
||||
def build_tok2vec(lang, width, depth, embed_size, cols):
|
||||
with Model.define_operators({'>>': chain, '|': concatenate, '**': clone}):
|
||||
static = get_col(cols.index(ID)) >> StaticVectors(lang, width)
|
||||
prefix = get_col(cols.index(PREFIX)) >> HashEmbed(width, embed_size)
|
||||
suffix = get_col(cols.index(SUFFIX)) >> HashEmbed(width, embed_size)
|
||||
shape = get_col(cols.index(SHAPE)) >> HashEmbed(width, embed_size)
|
||||
tok2vec = (
|
||||
(static | prefix | suffix | shape)
|
||||
>> Maxout(width, width*4)
|
||||
>> (ExtractWindow(nW=1) >> Maxout(width, width*3)) ** depth
|
||||
)
|
||||
return tok2vec
|
||||
|
|
Loading…
Reference in New Issue
Block a user