mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-23 15:54:13 +03:00
Support having word vectors data on GPU
This commit is contained in:
parent
95bca20c17
commit
e0a2aa9289
|
@ -6,6 +6,8 @@ import msgpack
|
|||
import msgpack_numpy
|
||||
msgpack_numpy.patch()
|
||||
cimport numpy as np
|
||||
from thinc.neural.util import get_array_module
|
||||
from thinc.neural._classes.model import Model
|
||||
|
||||
from .typedefs cimport attr_t
|
||||
from .strings cimport StringStore
|
||||
|
@ -118,9 +120,14 @@ cdef class Vectors:
|
|||
self.data
|
||||
|
||||
def to_disk(self, path, **exclude):
|
||||
xp = get_array_module(self.data)
|
||||
if xp is numpy:
|
||||
save_array = lambda arr, file_: xp.save(file_, arr, allow_pickle=False)
|
||||
else:
|
||||
save_array = lambda arr, file_: xp.save(file_, arr)
|
||||
serializers = OrderedDict((
|
||||
('vectors', lambda p: numpy.save(p.open('wb'), self.data, allow_pickle=False)),
|
||||
('keys', lambda p: numpy.save(p.open('wb'), self.keys, allow_pickle=False)),
|
||||
('vectors', lambda p: save_array(self.data, p.open('wb'))),
|
||||
('keys', lambda p: xp.save(p.open('wb'), self.keys))
|
||||
))
|
||||
return util.to_disk(path, serializers, exclude)
|
||||
|
||||
|
@ -133,8 +140,9 @@ cdef class Vectors:
|
|||
self.key2row[key] = i
|
||||
|
||||
def load_vectors(path):
|
||||
xp = Model.ops.xp
|
||||
if path.exists():
|
||||
self.data = numpy.load(path)
|
||||
self.data = xp.load(path)
|
||||
|
||||
serializers = OrderedDict((
|
||||
('keys', load_keys),
|
||||
|
|
Loading…
Reference in New Issue
Block a user