Fix textcat model for GPU

This commit is contained in:
Matthew Honnibal 2019-03-09 17:50:08 +00:00
parent 16fa4d6b90
commit 28c26e212d

View File

@ -84,16 +84,52 @@ def _zero_init(model):
@layerize
def _preprocess_doc(docs, drop=0.0):
keys = [doc.to_array(LOWER) for doc in docs]
ops = Model.ops
# The dtype here matches what thinc is expecting -- which differs per
# platform (by int definition). This should be fixed once the problem
# is fixed on Thinc's side.
lengths = ops.asarray([arr.shape[0] for arr in keys], dtype=numpy.int_)
keys = ops.xp.concatenate(keys)
vals = ops.allocate(keys.shape) + 1.0
lengths = numpy.array([arr.shape[0] for arr in keys], dtype=numpy.int_)
keys = numpy.concatenate(keys)
vals = numpy.zeros(keys.shape, dtype='f')
return (keys, vals, lengths), None
def with_cpu(ops, model):
model.to_cpu()
def with_cpu_forward(inputs, drop=0.):
cpu_outputs, backprop = model.begin_update(_to_cpu(inputs), drop=drop)
gpu_outputs = _to_device(ops, cpu_outputs)
def with_cpu_backprop(d_outputs, sgd=None):
cpu_d_outputs = _to_cpu(d_outputs)
return backprop(cpu_d_outputs, sgd=sgd)
return gpu_outputs, with_cpu_backprop
return wrap(with_cpu_forward, model)
def _to_cpu(X):
if isinstance(X, numpy.ndarray):
return X
elif isinstance(X, tuple):
return tuple([_to_cpu(x) for x in X])
elif isinstance(X, list):
return [_to_cpu(x) for x in X]
elif hasattr(X, 'get'):
return X.get()
else:
return X
def _to_device(ops, X):
if isinstance(X, tuple):
return tuple([_to_device(ops, x) for x in X])
elif isinstance(X, list):
return [_to_device(ops, x) for x in X]
else:
return ops.asarray(X)
@layerize
def _preprocess_doc_bigrams(docs, drop=0.0):
unigrams = [doc.to_array(LOWER) for doc in docs]
@ -563,7 +599,10 @@ def build_text_classifier(nr_class, width=64, **cfg):
>> zero_init(Affine(nr_class, width, drop_factor=0.0))
)
linear_model = _preprocess_doc >> LinearModel(nr_class)
linear_model = (
_preprocess_doc
>> with_cpu(Model.ops, LinearModel(nr_class))
)
if cfg.get('exclusive_classes'):
output_layer = Softmax(nr_class, nr_class * 2)
else: