mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-28 17:03:04 +03:00
Fix vector loading
This commit is contained in:
parent
49a615e7d9
commit
93fb8b64e9
|
@ -4,16 +4,12 @@ from collections import OrderedDict
|
||||||
import msgpack
|
import msgpack
|
||||||
import msgpack_numpy
|
import msgpack_numpy
|
||||||
msgpack_numpy.patch()
|
msgpack_numpy.patch()
|
||||||
from cymem.cymem cimport Pool
|
|
||||||
cimport numpy as np
|
cimport numpy as np
|
||||||
from libcpp.vector cimport vector
|
|
||||||
|
|
||||||
from .typedefs cimport attr_t
|
from .typedefs cimport attr_t
|
||||||
from .strings cimport StringStore
|
from .strings cimport StringStore
|
||||||
from . import util
|
from . import util
|
||||||
from ._cfile cimport CFile
|
from .compat import basestring_
|
||||||
|
|
||||||
MAX_VEC_SIZE = 10000
|
|
||||||
|
|
||||||
|
|
||||||
cdef class Vectors:
|
cdef class Vectors:
|
||||||
|
@ -60,7 +56,21 @@ cdef class Vectors:
|
||||||
yield from self.data
|
yield from self.data
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
return len(self.strings)
|
# TODO: Fix the quadratic behaviour here!
|
||||||
|
return max(self.key2row.values())
|
||||||
|
|
||||||
|
def __contains__(self, key):
|
||||||
|
if isinstance(key, basestring_):
|
||||||
|
key = self.strings[key]
|
||||||
|
return key in self.key2row
|
||||||
|
|
||||||
|
def add_key(self, string, vector=None):
|
||||||
|
key = self.strings.add(string)
|
||||||
|
next_i = len(self) + 1
|
||||||
|
self.keys[next_i] = key
|
||||||
|
self.key2row[key] = next_i
|
||||||
|
if vector is not None:
|
||||||
|
self.data[next_i] = vector
|
||||||
|
|
||||||
def items(self):
|
def items(self):
|
||||||
for i, string in enumerate(self.strings):
|
for i, string in enumerate(self.strings):
|
||||||
|
@ -75,9 +85,9 @@ cdef class Vectors:
|
||||||
|
|
||||||
def to_disk(self, path, **exclude):
|
def to_disk(self, path, **exclude):
|
||||||
serializers = OrderedDict((
|
serializers = OrderedDict((
|
||||||
('vectors', lambda p: numpy.save(p.open('wb'), self.data)),
|
('vectors', lambda p: numpy.save(p.open('wb'), self.data, allow_pickle=False)),
|
||||||
('strings.json', self.strings.to_disk),
|
('strings.json', self.strings.to_disk),
|
||||||
('keys', lambda p: numpy.save(p.open('wb'), self.keys)),
|
('keys', lambda p: numpy.save(p.open('wb'), self.keys, allow_pickle=False)),
|
||||||
))
|
))
|
||||||
return util.to_disk(path, serializers, exclude)
|
return util.to_disk(path, serializers, exclude)
|
||||||
|
|
||||||
|
|
|
@ -19,7 +19,7 @@ from .tokens.token cimport Token
|
||||||
from .attrs cimport PROB, LANG
|
from .attrs cimport PROB, LANG
|
||||||
from .structs cimport SerializedLexemeC
|
from .structs cimport SerializedLexemeC
|
||||||
|
|
||||||
from .compat import copy_reg, pickle
|
from .compat import copy_reg, pickle, basestring_
|
||||||
from .lemmatizer import Lemmatizer
|
from .lemmatizer import Lemmatizer
|
||||||
from .attrs import intify_attrs
|
from .attrs import intify_attrs
|
||||||
from .vectors import Vectors
|
from .vectors import Vectors
|
||||||
|
@ -244,7 +244,7 @@ cdef class Vocab:
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def vectors_length(self):
|
def vectors_length(self):
|
||||||
raise NotImplementedError
|
return len(self.vectors)
|
||||||
|
|
||||||
def clear_vectors(self):
|
def clear_vectors(self):
|
||||||
"""Drop the current vector table. Because all vectors must be the same
|
"""Drop the current vector table. Because all vectors must be the same
|
||||||
|
@ -264,7 +264,9 @@ cdef class Vocab:
|
||||||
|
|
||||||
RAISES: If no vectors data is loaded, ValueError is raised.
|
RAISES: If no vectors data is loaded, ValueError is raised.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
if isinstance(orth, basestring_):
|
||||||
|
orth = self.strings.add(orth)
|
||||||
|
return self.vectors[orth]
|
||||||
|
|
||||||
def set_vector(self, orth, vector):
|
def set_vector(self, orth, vector):
|
||||||
"""Set a vector for a word in the vocabulary.
|
"""Set a vector for a word in the vocabulary.
|
||||||
|
@ -274,13 +276,17 @@ cdef class Vocab:
|
||||||
RETURNS:
|
RETURNS:
|
||||||
None
|
None
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
if not isinstance(orth, basestring_):
|
||||||
|
orth = self.strings[orth]
|
||||||
|
self.vectors.add_key(orth, vector=vector)
|
||||||
|
|
||||||
def has_vector(self, orth):
|
def has_vector(self, orth):
|
||||||
"""Check whether a word has a vector. Returns False if no
|
"""Check whether a word has a vector. Returns False if no
|
||||||
vectors have been loaded. Words can be looked up by string
|
vectors have been loaded. Words can be looked up by string
|
||||||
or int ID."""
|
or int ID."""
|
||||||
return False
|
if isinstance(orth, basestring_):
|
||||||
|
orth = self.strings.add(orth)
|
||||||
|
return orth in self.vectors
|
||||||
|
|
||||||
def to_disk(self, path, **exclude):
|
def to_disk(self, path, **exclude):
|
||||||
"""Save the current state to a directory.
|
"""Save the current state to a directory.
|
||||||
|
|
Loading…
Reference in New Issue
Block a user