Mirror of https://github.com/explosion/spaCy.git
Synced 2024-12-24 17:06:29 +03:00

Fix vector loading

This commit is contained in:
parent 49a615e7d9
commit 93fb8b64e9
@@ -4,16 +4,12 @@ from collections import OrderedDict
import msgpack
import msgpack_numpy
msgpack_numpy.patch()
from cymem.cymem cimport Pool
cimport numpy as np
from libcpp.vector cimport vector

from .typedefs cimport attr_t
from .strings cimport StringStore
from . import util
from ._cfile cimport CFile

MAX_VEC_SIZE = 10000
from .compat import basestring_


cdef class Vectors:
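As background on the msgpack imports at the top of the module: msgpack_numpy.patch() registers NumPy-aware encode/decode hooks with msgpack, which is what a to_bytes-style serializer needs in order to pack vector arrays directly. A minimal stand-alone illustration (not part of this commit; the array shape and dtype are arbitrary):

    import numpy
    import msgpack
    import msgpack_numpy

    msgpack_numpy.patch()  # register NumPy-aware encode/decode hooks with msgpack

    table = numpy.zeros((3, 4), dtype='float32')  # toy vector table
    packed = msgpack.packb(table)                 # plain bytes, safe to store
    restored = msgpack.unpackb(packed)
    assert (restored == table).all()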
@@ -60,7 +56,21 @@ cdef class Vectors:
         yield from self.data
 
     def __len__(self):
-        return len(self.strings)
+        # TODO: Fix the quadratic behaviour here!
+        return max(self.key2row.values())
+
+    def __contains__(self, key):
+        if isinstance(key, basestring_):
+            key = self.strings[key]
+        return key in self.key2row
+
+    def add_key(self, string, vector=None):
+        key = self.strings.add(string)
+        next_i = len(self) + 1
+        self.keys[next_i] = key
+        self.key2row[key] = next_i
+        if vector is not None:
+            self.data[next_i] = vector
 
     def items(self):
         for i, string in enumerate(self.strings):
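The new __contains__ and add_key methods give the vector table a dict-like surface: strings are interned to integer keys, and key2row maps each key to a row of the data matrix. A simplified, pure-Python sketch of that bookkeeping (names mirror the diff, but this is an illustration, not spaCy's class; hash() stands in for the StringStore):

    import numpy

    class ToyVectors(object):
        # Illustrative re-implementation of the key2row bookkeeping only.
        def __init__(self, width, capacity=1000):
            self.key2row = {}
            self.data = numpy.zeros((capacity, width), dtype='float32')

        def add_key(self, string, vector=None):
            key = hash(string)            # stands in for StringStore.add()
            next_i = len(self.key2row)    # next free row in the data table
            self.key2row[key] = next_i
            if vector is not None:
                self.data[next_i] = vector

        def __contains__(self, string):
            return hash(string) in self.key2row

    vectors = ToyVectors(300)
    vectors.add_key(u'apple', vector=numpy.ones((300,), dtype='float32'))
    assert u'apple' in vectors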
@@ -75,9 +85,9 @@ cdef class Vectors:
 
     def to_disk(self, path, **exclude):
         serializers = OrderedDict((
-            ('vectors', lambda p: numpy.save(p.open('wb'), self.data)),
+            ('vectors', lambda p: numpy.save(p.open('wb'), self.data, allow_pickle=False)),
             ('strings.json', self.strings.to_disk),
-            ('keys', lambda p: numpy.save(p.open('wb'), self.keys)),
+            ('keys', lambda p: numpy.save(p.open('wb'), self.keys, allow_pickle=False)),
         ))
         return util.to_disk(path, serializers, exclude)
 
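One part of the change in to_disk is the allow_pickle=False flag: the saved .npy files then contain only raw array data, never pickled Python objects, so they load safely and portably. A small stand-alone numpy illustration (file names here are arbitrary):

    import numpy

    data = numpy.zeros((5, 300), dtype='float32')

    # Plain numeric arrays round-trip fine without pickle support.
    with open('vectors.npy', 'wb') as file_:
        numpy.save(file_, data, allow_pickle=False)
    with open('vectors.npy', 'rb') as file_:
        loaded = numpy.load(file_)
    assert loaded.shape == (5, 300)

    # An object-dtype array would require pickle, so the same call refuses it.
    objects = numpy.array([{'a': 1}], dtype=object)
    try:
        numpy.save('objects.npy', objects, allow_pickle=False)
    except ValueError:
        print('refused to save an object array without pickle')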
@@ -19,7 +19,7 @@ from .tokens.token cimport Token
 from .attrs cimport PROB, LANG
 from .structs cimport SerializedLexemeC
 
-from .compat import copy_reg, pickle
+from .compat import copy_reg, pickle, basestring_
 from .lemmatizer import Lemmatizer
 from .attrs import intify_attrs
 from .vectors import Vectors
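basestring_ is spaCy's Python 2/3 compatibility alias for "text type", so the isinstance checks below can accept either a unicode string or an integer hash ID with the same code on both Python versions. Roughly, the alias amounts to the following (a sketch of the idea, not the exact compat module):

    import sys

    # On Python 2 any string subclass should match; on Python 3 str is the
    # only text type.
    if sys.version_info[0] == 2:
        basestring_ = basestring  # noqa: F821 (only defined on Python 2)
    else:
        basestring_ = str

    assert isinstance(u'apple', basestring_)
    assert not isinstance(1979, basestring_)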
@@ -244,7 +244,7 @@ cdef class Vocab:
 
     @property
     def vectors_length(self):
-        raise NotImplementedError
+        return len(self.vectors)
 
     def clear_vectors(self):
         """Drop the current vector table. Because all vectors must be the same
@@ -264,7 +264,9 @@ cdef class Vocab:
 
         RAISES: If no vectors data is loaded, ValueError is raised.
         """
-        raise NotImplementedError
+        if isinstance(orth, basestring_):
+            orth = self.strings.add(orth)
+        return self.vectors[orth]
 
     def set_vector(self, orth, vector):
         """Set a vector for a word in the vocabulary.
@@ -274,13 +276,17 @@ cdef class Vocab:
         RETURNS:
             None
         """
-        raise NotImplementedError
+        if not isinstance(orth, basestring_):
+            orth = self.strings[orth]
+        self.vectors.add_key(orth, vector=vector)
 
     def has_vector(self, orth):
         """Check whether a word has a vector. Returns False if no
         vectors have been loaded. Words can be looked up by string
         or int ID."""
-        return False
+        if isinstance(orth, basestring_):
+            orth = self.strings.add(orth)
+        return orth in self.vectors
 
     def to_disk(self, path, **exclude):
         """Save the current state to a directory.
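With these stubs filled in, set_vector, get_vector and has_vector become the vocabulary's word-vector entry points, and each accepts either a string or its hash thanks to the basestring_ checks. A rough usage sketch (hypothetical: it assumes a Vocab whose underlying vector table has already been allocated at the right width, which this diff does not show):

    import numpy
    from spacy.vocab import Vocab

    vocab = Vocab()

    # Assumed setup: store and look up a 300-dimensional vector for a word.
    vector = numpy.ones((300,), dtype='float32')
    vocab.set_vector(u'apple', vector)

    if vocab.has_vector(u'apple'):
        retrieved = vocab.get_vector(u'apple')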