Mirror of https://github.com/explosion/spaCy.git
Add text-classifer thinc models
This commit is contained in:
parent f014138c11
commit 727481377e
spacy/_ml.py | 60
@@ -130,6 +130,7 @@ class PrecomputableMaxouts(Model):
            return dXf
        return Yfp, backward


def Tok2Vec(width, embed_size, preprocess=None):
    cols = [ID, NORM, PREFIX, SUFFIX, SHAPE]
    with Model.define_operators({'>>': chain, '|': concatenate, '**': clone, '+': add}):
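This hunk's context shows Tok2Vec's operator table, which maps '>>' to chain, '|' to concatenate, '**' to clone and '+' to add; build_text_classifier below relies on the same overloads. The snippet below is only a toy illustration of what such an operator table means, using plain Python closures rather than thinc layers (chain_, concatenate_ and add_ are made-up names):

# Toy illustration of the combinators behind the operator table (not thinc's code).
import numpy as np

def chain_(f, g):
    return lambda X: g(f(X))                   # '>>': feed one layer into the next

def concatenate_(f, g):
    return lambda X: np.hstack([f(X), g(X)])   # '|': run both, join the features

def add_(f, g):
    return lambda X: f(X) + g(X)               # '+': run both, sum the outputs

double = lambda X: X * 2
square = lambda X: X ** 2
model = chain_(concatenate_(double, square), lambda X: X.sum(axis=-1))
print(model(np.array([[1., 2.]])))             # -> [11.]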
@@ -282,3 +283,62 @@ def flatten(seqs, drop=0.):
        return ops.unflatten(d_X, lengths)
    X = ops.xp.vstack(seqs)
    return X, finish_update


@layerize
def logistic(X, drop=0.):
    xp = get_array_module(X)
    if not isinstance(X, xp.ndarray):
        X = xp.asarray(X)
    # Clip to range (-10, 10)
    X = xp.minimum(X, 10., X)
    X = xp.maximum(X, -10., X)
    Y = 1. / (1. + xp.exp(-X))
    def logistic_bwd(dY, sgd=None):
        dX = dY * (Y * (1-Y))
        return dX
    return Y, logistic_bwd

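For reference, a minimal NumPy-only sketch of the same forward/backward math as the logistic layer above: inputs are clipped to [-10, 10], the forward pass computes Y = 1 / (1 + exp(-X)), and the backward pass scales the incoming gradient by Y * (1 - Y). In the layer above, the third argument to xp.minimum/xp.maximum is the out array, so the clipping happens in place. This is an illustration of the math, not a drop-in replacement for the thinc layer:

import numpy as np

def logistic_fwd(x):
    x = np.clip(x, -10., 10.)          # same clipping as the layer above
    y = 1. / (1. + np.exp(-x))
    def bwd(d_y):
        return d_y * y * (1. - y)      # d/dx sigmoid(x) = y * (1 - y)
    return y, bwd

x = np.array([-12., -1., 0., 3.])
y, bwd = logistic_fwd(x)
grad = bwd(np.ones_like(x))

# Finite-difference check; away from the clip boundary the two should agree.
eps = 1e-5
num = (logistic_fwd(x + eps)[0] - logistic_fwd(x - eps)[0]) / (2 * eps)
print(np.allclose(grad[1:], num[1:], atol=1e-6))   # True; index 0 sits in the clipped region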
def zero_init(model):
    def _zero_init_impl(self, X, y):
        self.W.fill(0)
    model.on_data_hooks.append(_zero_init_impl)
    return model

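zero_init attaches a hook to the wrapped layer's on_data_hooks list; when the model is given example data (for instance at the start of training), the hook zeroes the layer's weight matrix, so the zero-initialised output Affine starts from constant scores rather than random ones. A toy stand-in showing the hook mechanism (ToyAffine and begin_training are illustrative names, not thinc's API):

import numpy as np

class ToyAffine(object):
    def __init__(self, n_out, n_in):
        self.W = np.random.randn(n_out, n_in)
        self.on_data_hooks = []
    def begin_training(self, X, y):
        for hook in self.on_data_hooks:
            hook(self, X, y)           # hooks fire when example data is first seen

def zero_init(model):
    def _zero_init_impl(self, X, y):
        self.W.fill(0)
    model.on_data_hooks.append(_zero_init_impl)
    return model

layer = zero_init(ToyAffine(3, 5))
layer.begin_training(np.zeros((2, 5)), None)
print(layer.W.sum())                   # 0.0, the weights were zeroed by the hook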
@layerize
def preprocess_doc(docs, drop=0.):
    keys = [doc.to_array([LOWER]) for doc in docs]
    keys = [a[:, 0] for a in keys]
    ops = Model.ops
    lengths = ops.asarray([arr.shape[0] for arr in keys])
    keys = ops.xp.concatenate(keys)
    vals = ops.allocate(keys.shape[0]) + 1
    return (keys, vals, lengths), None

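preprocess_doc packs a batch of Docs into a flat (keys, vals, lengths) triple: the LOWER id of every token in the batch, a matching array of counts (all 1.0 here), and the number of tokens per document. A small NumPy mock of that packing, with made-up token ids standing in for real Docs:

import numpy as np

# Pretend these are doc.to_array([LOWER])[:, 0] for two documents.
doc1_ids = np.array([11, 42, 42, 7], dtype='uint64')
doc2_ids = np.array([3, 11], dtype='uint64')

keys = np.concatenate([doc1_ids, doc2_ids])        # flat ids for the whole batch
vals = np.ones(keys.shape[0], dtype='f')           # one count per token
lengths = np.asarray([len(doc1_ids), len(doc2_ids)], dtype='int32')

print(keys)     # [11 42 42  7  3 11]
print(vals)     # [1. 1. 1. 1. 1. 1.]
print(lengths)  # [4 2]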
def build_text_classifier(nr_class, width=64, **cfg):
    with Model.define_operators({'>>': chain, '+': add, '|': concatenate, '**': clone}):
        embed_lower = HashEmbed(width, 300, column=1)
        embed_prefix = HashEmbed(width//2, 300, column=2)
        embed_suffix = HashEmbed(width//2, 300, column=3)
        embed_shape = HashEmbed(width//2, 300, column=4)

        model = (
            FeatureExtracter([ORTH, LOWER, PREFIX, SUFFIX, SHAPE])
            >> _flatten_add_lengths
            >> with_getitem(0,
                (embed_lower | embed_prefix | embed_suffix | embed_shape)
                >> Maxout(width, width+(width//2)*3)
                >> Residual(ExtractWindow(nW=1) >> ReLu(width, width*3))
                >> Residual(ExtractWindow(nW=1) >> ReLu(width, width*3))
                >> Residual(ExtractWindow(nW=1) >> ReLu(width, width*3))
            )
            >> Pooling(mean_pool, max_pool)
            >> Residual(ReLu(width*2, width*2))
            >> zero_init(Affine(nr_class, width*2, drop_factor=0.0))
            >> logistic
        )
    model.lsuv = False
    return model
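Reading the pipeline top to bottom, the widths line up as follows (assuming the default width=64): the four hash embeddings concatenate to width + 3*(width//2) = 160 features per token, which the Maxout projects back to width; ExtractWindow(nW=1) concatenates each token with its left and right neighbour, giving the width*3 inputs each ReLu(width, width*3) expects; Pooling(mean_pool, max_pool) concatenates a mean and a max vector, so the document vector has width*2 features; and the zero-initialised Affine maps width*2 to nr_class scores, which logistic squashes into (0, 1). The snippet below just re-derives those numbers (nr_class=5 is an arbitrary example):

# Width bookkeeping for build_text_classifier (arithmetic only, no thinc needed).
width, nr_class = 64, 5

embedded = width + 3 * (width // 2)   # lower | prefix | suffix | shape
maxout_out = width                    # Maxout(width, embedded)
window_in = width * 3                 # ExtractWindow(nW=1): token + two neighbours
pooled = width * 2                    # mean_pool | max_pool
scores = nr_class                     # Affine(nr_class, width*2) >> logistic

print(embedded, maxout_out, window_in, pooled, scores)   # 160 64 192 128 5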