Mirror of https://github.com/explosion/spaCy.git
Update _ml, for textcat model

commit 6ffec9dfea
parent d6a5c2c85a

spacy/_ml.py: 109 lines changed
@@ -3,6 +3,7 @@ from thinc.api import add, layerize, chain, clone, concatenate, with_flatten
 from thinc.neural import Model, Maxout, Softmax, Affine
 from thinc.neural._classes.hash_embed import HashEmbed
 from thinc.neural.ops import NumpyOps, CupyOps
+from thinc.neural.util import get_array_module
 
 from thinc.neural._classes.convolution import ExtractWindow
 from thinc.neural._classes.static_vectors import StaticVectors
@@ -12,14 +13,61 @@ from thinc.neural import ReLu
 from thinc import describe
 from thinc.describe import Dimension, Synapses, Biases, Gradient
 from thinc.neural._classes.affine import _set_dimensions_if_needed
+from thinc.api import FeatureExtracter, with_getitem
+from thinc.neural.pooling import Pooling, max_pool, mean_pool
 
-from .attrs import ID, NORM, PREFIX, SUFFIX, SHAPE, TAG, DEP
+from .attrs import ID, ORTH, LOWER, NORM, PREFIX, SUFFIX, SHAPE, TAG, DEP
 from .tokens.doc import Doc
 
 import numpy
 import io
 
 
+@layerize
+def _flatten_add_lengths(seqs, pad=0, drop=0.):
+    ops = Model.ops
+    lengths = ops.asarray([len(seq) for seq in seqs], dtype='i')
+    def finish_update(d_X, sgd=None):
+        return ops.unflatten(d_X, lengths, pad=pad)
+    X = ops.flatten(seqs, pad=pad)
+    return (X, lengths), finish_update
+
+
+@layerize
+def _logistic(X, drop=0.):
+    xp = get_array_module(X)
+    if not isinstance(X, xp.ndarray):
+        X = xp.asarray(X)
+    # Clip to range (-10, 10)
+    X = xp.minimum(X, 10., X)
+    X = xp.maximum(X, -10., X)
+    Y = 1. / (1. + xp.exp(-X))
+    def logistic_bwd(dY, sgd=None):
+        dX = dY * (Y * (1-Y))
+        return dX
+    return Y, logistic_bwd
+
+
+def _zero_init(model):
+    def _zero_init_impl(self, X, y):
+        self.W.fill(0)
+    model.on_data_hooks.append(_zero_init_impl)
+    if model.W is not None:
+        model.W.fill(0.)
+    return model
+
+
+@layerize
+def _preprocess_doc(docs, drop=0.):
+    keys = [doc.to_array([LOWER]) for doc in docs]
+    keys = [a[:, 0] for a in keys]
+    ops = Model.ops
+    lengths = ops.asarray([arr.shape[0] for arr in keys])
+    keys = ops.xp.concatenate(keys)
+    vals = ops.allocate(keys.shape[0]) + 1
+    return (keys, vals, lengths), None
+
+
 def _init_for_precomputed(W, ops):
     if (W**2).sum() != 0.:
         return
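The `_flatten_add_lengths` helper concatenates a batch of variable-length sequences into one contiguous array and records the sequence lengths, so downstream layers can operate on a single matrix; its backward pass splits the flat gradient back into per-sequence pieces. A minimal numpy sketch of that flatten/unflatten contract, assuming `pad=0` (the `flatten`/`unflatten` functions below are illustrative stand-ins for thinc's `Model.ops` methods, not the real implementations):

import numpy

# Stand-in for ops.flatten: concatenate sequences along the first axis.
def flatten(seqs):
    return numpy.concatenate(seqs, axis=0)

# Stand-in for ops.unflatten: split a flat array back up by lengths.
def unflatten(flat, lengths):
    out, start = [], 0
    for n in lengths:
        out.append(flat[start:start + n])
        start += n
    return out

seqs = [numpy.ones((2, 4)), numpy.ones((3, 4))]           # two "docs", 4 features each
lengths = numpy.asarray([len(s) for s in seqs], dtype='i')
X = flatten(seqs)                                         # shape (5, 4)
assert [len(s) for s in unflatten(X, lengths)] == [2, 3]  # the gradient path inverts this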
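`_logistic` clips its input to (-10, 10) before the sigmoid so that `exp` stays in a safe range, and the backward pass uses the identity d/dx sigmoid(x) = y * (1 - y) on the saved output. A self-contained numpy check of the same math, using `numpy.clip` in place of the diff's in-place `xp.minimum`/`xp.maximum` calls:

import numpy

def logistic(X):
    X = numpy.clip(X, -10., 10.)       # same clipping as _logistic
    Y = 1. / (1. + numpy.exp(-X))
    def backward(dY):
        return dY * (Y * (1. - Y))     # sigmoid'(x) = y * (1 - y)
    return Y, backward

Y, backward = logistic(numpy.asarray([-50., 0., 2.]))
print(Y)                        # ~[4.54e-05, 0.5, 0.881]; -50 behaves like -10
print(backward(numpy.ones(3)))  # ~[4.54e-05, 0.25, 0.105]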
@@ -317,19 +365,66 @@ def preprocess_doc(docs, drop=0.):
     return (keys, vals, lengths), None
 
 
+# This belongs in thinc
+def wrap(func, *child_layers):
+    model = layerize(func)
+    model._layers.extend(child_layers)
+    def on_data(self, X, y):
+        for child in self._layers:
+            for hook in child.on_data_hooks:
+                hook(child, X, y)
+    model.on_data_hooks.append(on_data)
+    return model
+
+
+# This belongs in thinc
+def uniqued(layer, column=0):
+    '''Group inputs to a layer, so that the layer only has to compute
+    for the unique values. The data is transformed back before output, and
+    the same transformation is applied for the gradient. Effectively, this
+    is a cache local to each minibatch.
+
+    The uniqued wrapper is useful for word inputs, because common words are
+    seen often, but we may want to compute complicated features for the
+    words, using e.g. character LSTM.
+    '''
+    def uniqued_fwd(X, drop=0.):
+        keys = X[:, column]
+        if not isinstance(keys, numpy.ndarray):
+            keys = keys.get()
+        uniq_keys, ind, inv, counts = numpy.unique(keys, return_index=True,
+                                                   return_inverse=True,
+                                                   return_counts=True)
+        Y_uniq, bp_Y_uniq = layer.begin_update(X[ind], drop=drop)
+        Y = Y_uniq[inv].reshape((X.shape[0],) + Y_uniq.shape[1:])
+        def uniqued_bwd(dY, sgd=None):
+            dY_uniq = layer.ops.allocate(Y_uniq.shape, dtype='f')
+            layer.ops.scatter_add(dY_uniq, inv, dY)
+            d_uniques = bp_Y_uniq(dY_uniq, sgd=sgd)
+            if d_uniques is not None:
+                dX = (d_uniques / counts)[inv]
+                return dX
+            else:
+                return None
+        return Y, uniqued_bwd
+    model = wrap(uniqued_fwd, layer)
+    return model
+
+
 def build_text_classifier(nr_class, width=64, **cfg):
     with Model.define_operators({'>>': chain, '+': add, '|': concatenate, '**': clone}):
-        embed_lower = HashEmbed(width, 300, column=1)
-        embed_prefix = HashEmbed(width//2, 300, column=2)
-        embed_suffix = HashEmbed(width//2, 300, column=3)
-        embed_shape = HashEmbed(width//2, 300, column=4)
+        embed_lower = HashEmbed(width, 1000, column=1)
+        embed_prefix = HashEmbed(width//2, 1000, column=2)
+        embed_suffix = HashEmbed(width//2, 1000, column=3)
+        embed_shape = HashEmbed(width//2, 1000, column=4)
 
         model = (
             FeatureExtracter([ORTH, LOWER, PREFIX, SUFFIX, SHAPE])
             >> _flatten_add_lengths
-            >> with_getitem(0,
-                (embed_lower | embed_prefix | embed_suffix | embed_shape)
-                >> Maxout(width, width+(width//2)*3)
+            >> uniqued(
+                (embed_lower | embed_prefix | embed_suffix | embed_shape)
+                >> Maxout(width, width+(width//2)*3)
             )
             >> Residual(ExtractWindow(nW=1) >> ReLu(width, width*3))
             >> Residual(ExtractWindow(nW=1) >> ReLu(width, width*3))
             >> Residual(ExtractWindow(nW=1) >> ReLu(width, width*3))
[diff truncated: the remainder of the spacy/_ml.py changes did not load]
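The docstring on `uniqued` describes a cache that is local to each minibatch: the wrapped layer runs only on the distinct rows, and the outputs (and gradients) are mapped back to every original position. All of the index bookkeeping comes from `numpy.unique`; a short sketch of the round trip, with `numpy.add.at` standing in for `ops.scatter_add`:

import numpy

keys = numpy.asarray([7, 3, 7, 7, 3])     # e.g. token IDs; common words repeat
uniq, ind, inv, counts = numpy.unique(
    keys, return_index=True, return_inverse=True, return_counts=True)

# Forward: compute on the unique rows only, then broadcast back with inv.
assert (uniq[inv] == keys).all()

# Backward: gradients from duplicate positions are summed into their unique
# slot (what scatter_add does in uniqued_bwd), then averaged via counts.
dY = numpy.ones((len(keys), 1))
dY_uniq = numpy.zeros((len(uniq), 1))
numpy.add.at(dY_uniq, inv, dY)
print(dY_uniq.ravel())                    # [2. 3.] -- one sum per unique key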
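A small sanity check on the dimensions in `build_text_classifier`: the `|` operator concatenates the four hash-embedding outputs, so the Maxout's input width is `width + (width//2)*3`, with `embed_lower` contributing `width` columns and the prefix/suffix/shape tables `width//2` each. With the default `width=64`:

width = 64
n_in = width + (width // 2) * 3   # 64 from embed_lower + 3 * 32
print(n_in)                       # 160 -> matches Maxout(width, width+(width//2)*3)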