Load vectors in vocab

This commit is contained in:
Matthew Honnibal 2017-08-18 20:46:56 +02:00
parent de7f3509d2
commit 2993b54fff

View File

@@ -280,7 +280,7 @@ cdef class Vocab:
or int ID.""" or int ID."""
return False return False
def to_disk(self, path): def to_disk(self, path, **exclude):
"""Save the current state to a directory. """Save the current state to a directory.
path (unicode or Path): A path to a directory, which will be created if path (unicode or Path): A path to a directory, which will be created if
@@ -292,8 +292,10 @@ cdef class Vocab:
self.strings.to_disk(path / 'strings.json') self.strings.to_disk(path / 'strings.json')
with (path / 'lexemes.bin').open('wb') as file_: with (path / 'lexemes.bin').open('wb') as file_:
file_.write(self.lexemes_to_bytes()) file_.write(self.lexemes_to_bytes())
if self.vectors is not None:
self.vectors.to_disk(path, exclude='strings.json')
def from_disk(self, path): def from_disk(self, path, **exclude):
"""Loads state from a directory. Modifies the object in place and """Loads state from a directory. Modifies the object in place and
returns it. returns it.
@@ -305,6 +307,8 @@ cdef class Vocab:
self.strings.from_disk(path / 'strings.json') self.strings.from_disk(path / 'strings.json')
with (path / 'lexemes.bin').open('rb') as file_: with (path / 'lexemes.bin').open('rb') as file_:
self.lexemes_from_bytes(file_.read()) self.lexemes_from_bytes(file_.read())
if self.vectors is not None:
    # The string table is stored as 'strings.json' (see the
    # self.strings.from_disk call above, and the matching exclude used by
    # to_disk). The previous value 'string.json' was a typo and did not
    # name that file.
    self.vectors.from_disk(path, exclude='strings.json')
return self return self
def to_bytes(self, **exclude): def to_bytes(self, **exclude):
@@ -313,9 +317,16 @@ cdef class Vocab:
**exclude: Named attributes to prevent from being serialized. **exclude: Named attributes to prevent from being serialized.
RETURNS (bytes): The serialized form of the `Vocab` object. RETURNS (bytes): The serialized form of the `Vocab` object.
""" """
def deserialize_vectors():
    # NOTE: despite "deserialize" in the name, this closure is the getter
    # for the 'vectors' entry handed to util.to_bytes: it returns the
    # result of self.vectors.to_bytes, or None when no vectors are attached.
    return (None if self.vectors is None
            else self.vectors.to_bytes(exclude='strings'))
getters = OrderedDict(( getters = OrderedDict((
('strings', lambda: self.strings.to_bytes()), ('strings', lambda: self.strings.to_bytes()),
('lexemes', lambda: self.lexemes_to_bytes()), ('lexemes', lambda: self.lexemes_to_bytes()),
('vectors', deserialize_vectors)
)) ))
return util.to_bytes(getters, exclude) return util.to_bytes(getters, exclude)
@@ -326,9 +337,15 @@ cdef class Vocab:
**exclude: Named attributes to prevent from being loaded. **exclude: Named attributes to prevent from being loaded.
RETURNS (Vocab): The `Vocab` object. RETURNS (Vocab): The `Vocab` object.
""" """
def serialize_vectors(b):
    # NOTE: despite "serialize" in the name, this closure is the setter
    # for the 'vectors' entry handed to util.from_bytes: it passes the
    # bytes `b` to self.vectors.from_bytes and returns that call's result.
    if self.vectors is not None:
        return self.vectors.from_bytes(b, exclude='strings')
    # No vectors object attached: the payload is ignored.
    return None
setters = OrderedDict(( setters = OrderedDict((
('strings', lambda b: self.strings.from_bytes(b)), ('strings', lambda b: self.strings.from_bytes(b)),
('lexemes', lambda b: self.lexemes_from_bytes(b)), ('lexemes', lambda b: self.lexemes_from_bytes(b)),
('vectors', lambda b: serialize_vectors(b))
)) ))
util.from_bytes(bytes_data, setters, exclude) util.from_bytes(bytes_data, setters, exclude)
return self return self