mirror of https://github.com/explosion/spaCy.git
synced 2025-01-26 17:24:41 +03:00

Hackish pickle support for Vocab.

parent 26614e028f
commit d814892805

spacy/vocab.pyx (116 changed lines)
@@ -9,11 +9,16 @@ import bz2
 import ujson as json
 import re
 
+try:
+    import cPickle as pickle
+except ImportError:
+    import pickle
+
 from .lexeme cimport EMPTY_LEXEME
 from .lexeme cimport Lexeme
 from .strings cimport hash_string
 from .typedefs cimport attr_t
-from .cfile cimport CFile
+from .cfile cimport CFile, StringCFile
 from .lemmatizer import Lemmatizer
 from .attrs import intify_attrs
 from .tokens.token cimport Token
@@ -346,17 +351,18 @@ cdef class Vocab:
             Token.set_struct_attr(token, attr_id, value)
         return tokens
 
-    def dump(self, loc):
-        """Save the lexemes binary data to the given location.
+    def dump(self, loc=None):
+        """Save the lexemes binary data to the given location, or
+        return a byte-string with the data if loc is None.
 
         Arguments:
-            loc (Path): The path to save to.
+            loc (Path or None): The path to save to, or None.
         """
         if hasattr(loc, 'as_posix'):
             loc = loc.as_posix()
-        cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
-
-        cdef CFile fp = CFile(bytes_loc, 'wb')
+        cdef CFile fp
+        if loc is None:
+            fp = StringCFile('wb')
+        else:
+            fp = CFile(loc, 'wb')
         cdef size_t st
         cdef size_t addr
         cdef hash_t key
@@ -378,6 +384,8 @@ cdef class Vocab:
         fp.write_from(&lexeme.l2_norm, sizeof(lexeme.l2_norm), 1)
         fp.write_from(&lexeme.lang, sizeof(lexeme.lang), 1)
         fp.close()
+        if loc is None:
+            return fp.string_data()
 
     def load_lexemes(self, loc):
         '''Load the binary vocabulary data from the given location.
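Note: taken together, the two hunks above make dump() dual-mode: given a path it writes the packed lexeme records to disk as before, while with loc=None it writes them into an in-memory StringCFile and returns the bytes, which is the hook the pickle support below relies on. A minimal usage sketch against the patched API (the file path is hypothetical):

    from spacy.vocab import Vocab

    vocab = Vocab()

    # As before: write the binary lexeme records to a file.
    vocab.dump('/tmp/lexemes.bin')   # hypothetical path

    # New: no location means "serialize to an in-memory byte-string".
    lexemes_data = vocab.dump()
    assert isinstance(lexemes_data, bytes)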
@@ -427,6 +435,60 @@ cdef class Vocab:
             i += 1
         fp.close()
 
+    def _deserialize_lexemes(self, CFile fp):
+        '''Load the binary vocabulary data from the given CFile.
+        '''
+        cdef LexemeC* lexeme
+        cdef hash_t key
+        cdef unicode py_str
+        cdef attr_t orth
+        assert sizeof(orth) == sizeof(lexeme.orth)
+        i = 0
+        cdef int todo = fp.size
+        cdef int lex_size = sizeof(lexeme.flags)
+        lex_size += sizeof(lexeme.id)
+        lex_size += sizeof(lexeme.length)
+        lex_size += sizeof(lexeme.orth)
+        lex_size += sizeof(lexeme.lower)
+        lex_size += sizeof(lexeme.norm)
+        lex_size += sizeof(lexeme.shape)
+        lex_size += sizeof(lexeme.prefix)
+        lex_size += sizeof(lexeme.suffix)
+        lex_size += sizeof(lexeme.cluster)
+        lex_size += sizeof(lexeme.prob)
+        lex_size += sizeof(lexeme.sentiment)
+        lex_size += sizeof(lexeme.l2_norm)
+        lex_size += sizeof(lexeme.lang)
+        while True:
+            if todo < lex_size:
+                break
+            todo -= lex_size
+            lexeme = <LexemeC*>self.mem.alloc(sizeof(LexemeC), 1)
+            # Copy data from the file into the lexeme
+            fp.read_into(&lexeme.flags, 1, sizeof(lexeme.flags))
+            fp.read_into(&lexeme.id, 1, sizeof(lexeme.id))
+            fp.read_into(&lexeme.length, 1, sizeof(lexeme.length))
+            fp.read_into(&lexeme.orth, 1, sizeof(lexeme.orth))
+            fp.read_into(&lexeme.lower, 1, sizeof(lexeme.lower))
+            fp.read_into(&lexeme.norm, 1, sizeof(lexeme.norm))
+            fp.read_into(&lexeme.shape, 1, sizeof(lexeme.shape))
+            fp.read_into(&lexeme.prefix, 1, sizeof(lexeme.prefix))
+            fp.read_into(&lexeme.suffix, 1, sizeof(lexeme.suffix))
+            fp.read_into(&lexeme.cluster, 1, sizeof(lexeme.cluster))
+            fp.read_into(&lexeme.prob, 1, sizeof(lexeme.prob))
+            fp.read_into(&lexeme.sentiment, 1, sizeof(lexeme.sentiment))
+            fp.read_into(&lexeme.l2_norm, 1, sizeof(lexeme.l2_norm))
+            fp.read_into(&lexeme.lang, 1, sizeof(lexeme.lang))
+
+            lexeme.vector = EMPTY_VEC
+            py_str = self.strings[lexeme.orth]
+            key = hash_string(py_str)
+            self._by_hash.set(key, lexeme)
+            self._by_orth.set(lexeme.orth, lexeme)
+            self.length += 1
+            i += 1
+        fp.close()
+
     def dump_vectors(self, out_loc):
         '''Save the word vectors to a binary file.
 
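Note: _deserialize_lexemes reads fixed-width records. lex_size is computed once as the sum of the serialized field sizes, then the loop keeps consuming full records out of fp.size bytes, allocating a LexemeC for each and registering it in both the hash-keyed and orth-keyed tables; any trailing partial record is ignored. The same pattern in pure Python over an in-memory buffer — the three-field layout here is illustrative, not the real LexemeC struct:

    import struct
    from io import BytesIO

    # Illustrative fixed-width record: flags (uint64), orth (uint32),
    # prob (float32); little-endian, no padding. Not the real layout.
    RECORD = struct.Struct('<QIf')

    def read_records(data):
        fp = BytesIO(data)
        todo = len(data)
        records = []
        # Mirror of the `if todo < lex_size: break` loop: stop as soon
        # as fewer bytes remain than one full record needs.
        while todo >= RECORD.size:
            todo -= RECORD.size
            records.append(RECORD.unpack(fp.read(RECORD.size)))
        return records

Because every record has the same width, no per-record length prefix or delimiter is needed; the reader only has to know the field layout the writer used.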
@@ -553,6 +615,42 @@ cdef class Vocab:
         return vec_len
 
 
+def pickle_vocab(vocab):
+    sstore = vocab.strings
+    morph = vocab.morphology
+    length = vocab.length
+    serializer = vocab._serializer
+    data_dir = vocab.data_dir
+    lex_attr_getters = vocab.lex_attr_getters
+
+    lexemes_data = vocab.dump()
+    vectors_length = vocab.vectors_length
+
+    return (unpickle_vocab,
+        (sstore, morph, serializer, data_dir, lex_attr_getters,
+            lexemes_data, length, vectors_length))
+
+
+def unpickle_vocab(sstore, morphology, serializer, data_dir,
+        lex_attr_getters, bytes lexemes_data, int length, int vectors_length):
+    cdef Vocab vocab = Vocab()
+    vocab.length = length
+    vocab.vectors_length = vectors_length
+    vocab.strings = sstore
+    cdef CFile fp = StringCFile('r', data=lexemes_data)
+    vocab.morphology = morphology
+    vocab._serializer = serializer
+    vocab.data_dir = data_dir
+    vocab.lex_attr_getters = lex_attr_getters
+    vocab._deserialize_lexemes(fp)
+    vocab.length = length
+    vocab.vectors_length = vectors_length
+    return vocab
+
+
+copy_reg.pickle(Vocab, pickle_vocab, unpickle_vocab)
+
+
 def write_binary_vectors(in_loc, out_loc):
     cdef CFile out_file = CFile(out_loc, 'wb')
     cdef Address mem
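Note: copy_reg.pickle (copyreg on Python 3) registers pickle_vocab as the reduce function for the extension type, since a cdef class holding C-level state cannot be pickled by default. pickle_vocab flattens the Vocab into picklable pieces, with the lexeme table captured as the byte-string from dump(), and unpickle_vocab rebuilds a fresh Vocab from them — hence "hackish". The same registration pattern on Python 3, with a stand-in class in place of the real Vocab:

    import copyreg
    import pickle

    class Table(object):
        '''Stand-in for a type whose state lives outside __dict__.'''
        def __init__(self):
            self._data = b''

        def dump(self):
            return self._data

        def load(self, data):
            self._data = data

    def pickle_table(table):
        # Reduce to (reconstructor, args), like pickle_vocab above.
        return (unpickle_table, (table.dump(),))

    def unpickle_table(data):
        table = Table()
        table.load(data)
        return table

    copyreg.pickle(Table, pickle_table)

    t = Table()
    t.load(b'\x01\x02')
    assert pickle.loads(pickle.dumps(t)).dump() == b'\x01\x02'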