Fix dtype

This commit is contained in:
Matthew Honnibal 2017-11-08 12:18:32 +01:00
parent fa7fdd0d9b
commit 1d5599cd28

View File

@ -94,7 +94,7 @@ def _zero_init(model):
def _preprocess_doc(docs, drop=0.): def _preprocess_doc(docs, drop=0.):
keys = [doc.to_array([LOWER]) for doc in docs] keys = [doc.to_array([LOWER]) for doc in docs]
ops = Model.ops ops = Model.ops
lengths = ops.asarray([arr.shape[0] for arr in keys], dtype='int32') lengths = ops.asarray([arr.shape[0] for arr in keys], dtype='int64')
keys = ops.xp.concatenate(keys) keys = ops.xp.concatenate(keys)
vals = ops.allocate(keys.shape[0]) + 1 vals = ops.allocate(keys.shape[0]) + 1
return (keys, vals, lengths), None return (keys, vals, lengths), None