mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Fix vector loading
This commit is contained in:
		
							parent
							
								
									49a615e7d9
								
							
						
					
					
						commit
						93fb8b64e9
					
				| 
						 | 
					@ -4,16 +4,12 @@ from collections import OrderedDict
 | 
				
			||||||
import msgpack
 | 
					import msgpack
 | 
				
			||||||
import msgpack_numpy
 | 
					import msgpack_numpy
 | 
				
			||||||
msgpack_numpy.patch()
 | 
					msgpack_numpy.patch()
 | 
				
			||||||
from cymem.cymem cimport Pool
 | 
					 | 
				
			||||||
cimport numpy as np
 | 
					cimport numpy as np
 | 
				
			||||||
from libcpp.vector cimport vector
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .typedefs cimport attr_t
 | 
					from .typedefs cimport attr_t
 | 
				
			||||||
from .strings cimport StringStore
 | 
					from .strings cimport StringStore
 | 
				
			||||||
from . import util
 | 
					from . import util
 | 
				
			||||||
from ._cfile cimport CFile
 | 
					from .compat import basestring_
 | 
				
			||||||
 | 
					 | 
				
			||||||
MAX_VEC_SIZE = 10000
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef class Vectors:
 | 
					cdef class Vectors:
 | 
				
			||||||
| 
						 | 
					@ -60,7 +56,21 @@ cdef class Vectors:
 | 
				
			||||||
        yield from self.data
 | 
					        yield from self.data
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def __len__(self):
 | 
					    def __len__(self):
 | 
				
			||||||
        return len(self.strings)
 | 
					        # TODO: Fix the quadratic behaviour here!
 | 
				
			||||||
 | 
					        return max(self.key2row.values())
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __contains__(self, key):
 | 
				
			||||||
 | 
					        if isinstance(key, basestring_):
 | 
				
			||||||
 | 
					            key = self.strings[key]
 | 
				
			||||||
 | 
					        return key in self.key2row
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def add_key(self, string, vector=None):
 | 
				
			||||||
 | 
					        key = self.strings.add(string)
 | 
				
			||||||
 | 
					        next_i = len(self) + 1
 | 
				
			||||||
 | 
					        self.keys[next_i] = key
 | 
				
			||||||
 | 
					        self.key2row[key] = next_i
 | 
				
			||||||
 | 
					        if vector is not None:
 | 
				
			||||||
 | 
					            self.data[next_i] = vector
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def items(self):
 | 
					    def items(self):
 | 
				
			||||||
        for i, string in enumerate(self.strings):
 | 
					        for i, string in enumerate(self.strings):
 | 
				
			||||||
| 
						 | 
					@ -75,9 +85,9 @@ cdef class Vectors:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def to_disk(self, path, **exclude):
 | 
					    def to_disk(self, path, **exclude):
 | 
				
			||||||
        serializers = OrderedDict((
 | 
					        serializers = OrderedDict((
 | 
				
			||||||
            ('vectors', lambda p: numpy.save(p.open('wb'), self.data)),
 | 
					            ('vectors', lambda p: numpy.save(p.open('wb'), self.data, allow_pickle=False)),
 | 
				
			||||||
            ('strings.json', self.strings.to_disk),
 | 
					            ('strings.json', self.strings.to_disk),
 | 
				
			||||||
            ('keys', lambda p: numpy.save(p.open('wb'), self.keys)),
 | 
					            ('keys', lambda p: numpy.save(p.open('wb'), self.keys, allow_pickle=False)),
 | 
				
			||||||
        ))
 | 
					        ))
 | 
				
			||||||
        return util.to_disk(path, serializers, exclude)
 | 
					        return util.to_disk(path, serializers, exclude)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -19,7 +19,7 @@ from .tokens.token cimport Token
 | 
				
			||||||
from .attrs cimport PROB, LANG
 | 
					from .attrs cimport PROB, LANG
 | 
				
			||||||
from .structs cimport SerializedLexemeC
 | 
					from .structs cimport SerializedLexemeC
 | 
				
			||||||
 | 
					
 | 
				
			||||||
from .compat import copy_reg, pickle
 | 
					from .compat import copy_reg, pickle, basestring_
 | 
				
			||||||
from .lemmatizer import Lemmatizer
 | 
					from .lemmatizer import Lemmatizer
 | 
				
			||||||
from .attrs import intify_attrs
 | 
					from .attrs import intify_attrs
 | 
				
			||||||
from .vectors import Vectors
 | 
					from .vectors import Vectors
 | 
				
			||||||
| 
						 | 
					@ -244,7 +244,7 @@ cdef class Vocab:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    @property
 | 
					    @property
 | 
				
			||||||
    def vectors_length(self):
 | 
					    def vectors_length(self):
 | 
				
			||||||
        raise NotImplementedError
 | 
					        return len(self.vectors)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def clear_vectors(self):
 | 
					    def clear_vectors(self):
 | 
				
			||||||
        """Drop the current vector table. Because all vectors must be the same
 | 
					        """Drop the current vector table. Because all vectors must be the same
 | 
				
			||||||
| 
						 | 
					@ -264,7 +264,9 @@ cdef class Vocab:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        RAISES: If no vectors data is loaded, ValueError is raised.
 | 
					        RAISES: If no vectors data is loaded, ValueError is raised.
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        raise NotImplementedError
 | 
					        if isinstance(orth, basestring_):
 | 
				
			||||||
 | 
					            orth = self.strings.add(orth)
 | 
				
			||||||
 | 
					        return self.vectors[orth]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def set_vector(self, orth, vector):
 | 
					    def set_vector(self, orth, vector):
 | 
				
			||||||
        """Set a vector for a word in the vocabulary.
 | 
					        """Set a vector for a word in the vocabulary.
 | 
				
			||||||
| 
						 | 
					@ -274,13 +276,17 @@ cdef class Vocab:
 | 
				
			||||||
        RETURNS:
 | 
					        RETURNS:
 | 
				
			||||||
            None
 | 
					            None
 | 
				
			||||||
        """
 | 
					        """
 | 
				
			||||||
        raise NotImplementedError
 | 
					        if not isinstance(orth, basestring_):
 | 
				
			||||||
 | 
					            orth = self.strings[orth]
 | 
				
			||||||
 | 
					        self.vectors.add_key(orth, vector=vector)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def has_vector(self, orth):
 | 
					    def has_vector(self, orth):
 | 
				
			||||||
        """Check whether a word has a vector. Returns False if no
 | 
					        """Check whether a word has a vector. Returns False if no
 | 
				
			||||||
        vectors have been loaded. Words can be looked up by string
 | 
					        vectors have been loaded. Words can be looked up by string
 | 
				
			||||||
        or int ID."""
 | 
					        or int ID."""
 | 
				
			||||||
        return False
 | 
					        if isinstance(orth, basestring_):
 | 
				
			||||||
 | 
					            orth = self.strings.add(orth)
 | 
				
			||||||
 | 
					        return orth in self.vectors
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def to_disk(self, path, **exclude):
 | 
					    def to_disk(self, path, **exclude):
 | 
				
			||||||
        """Save the current state to a directory.
 | 
					        """Save the current state to a directory.
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user