mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 10:46:29 +03:00
Fixes to hacky vocab pickling
This commit is contained in:
parent
d814892805
commit
a89c3500f6
|
@ -57,7 +57,7 @@ cdef class StringCFile:
|
||||||
self.size = len(data)
|
self.size = len(data)
|
||||||
self.data = <unsigned char*>self.mem.alloc(1, self._capacity)
|
self.data = <unsigned char*>self.mem.alloc(1, self._capacity)
|
||||||
for i in range(len(data)):
|
for i in range(len(data)):
|
||||||
self.data[i] = data
|
self.data[i] = data[i]
|
||||||
|
|
||||||
def close(self):
|
def close(self):
|
||||||
self.is_open = False
|
self.is_open = False
|
||||||
|
@ -69,13 +69,12 @@ cdef class StringCFile:
|
||||||
memcpy(dest, self.data, elem_size * number)
|
memcpy(dest, self.data, elem_size * number)
|
||||||
self.data += elem_size * number
|
self.data += elem_size * number
|
||||||
|
|
||||||
cdef int write_from(self, void* src, size_t number, size_t elem_size) except -1:
|
cdef int write_from(self, void* src, size_t elem_size, size_t number) except -1:
|
||||||
write_size = number * elem_size
|
write_size = number * elem_size
|
||||||
if (self.size + write_size) >= self._capacity:
|
if (self.size + write_size) >= self._capacity:
|
||||||
self._capacity = (self.size + write_size) * 2
|
self._capacity = (self.size + write_size) * 2
|
||||||
self.data = <unsigned char*>self.mem.realloc(self.data, self._capacity)
|
self.data = <unsigned char*>self.mem.realloc(self.data, self._capacity)
|
||||||
memcpy(self.data, src, elem_size * number)
|
memcpy(&self.data[self.size], src, elem_size * number)
|
||||||
self.data += write_size
|
|
||||||
self.size += write_size
|
self.size += write_size
|
||||||
|
|
||||||
cdef void* alloc_read(self, Pool mem, size_t number, size_t elem_size) except *:
|
cdef void* alloc_read(self, Pool mem, size_t number, size_t elem_size) except *:
|
||||||
|
|
|
@ -1,9 +1,12 @@
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import io
|
import io
|
||||||
import pickle
|
import pytest
|
||||||
|
import dill as pickle
|
||||||
|
|
||||||
from ..strings import StringStore
|
from ..strings import StringStore
|
||||||
|
from ..vocab import Vocab
|
||||||
|
from ..attrs import NORM
|
||||||
|
|
||||||
|
|
||||||
def test_pickle_string_store():
|
def test_pickle_string_store():
|
||||||
|
@ -14,4 +17,23 @@ def test_pickle_string_store():
|
||||||
unpickled = pickle.loads(bdata)
|
unpickled = pickle.loads(bdata)
|
||||||
assert unpickled['hello'] == hello
|
assert unpickled['hello'] == hello
|
||||||
assert unpickled['bye'] == bye
|
assert unpickled['bye'] == bye
|
||||||
|
assert len(sstore) == len(unpickled)
|
||||||
|
|
||||||
|
|
||||||
|
def test_pickle_vocab():
|
||||||
|
vocab = Vocab(lex_attr_getters={int(NORM): lambda string: string[:-1]})
|
||||||
|
dog = vocab[u'dog']
|
||||||
|
cat = vocab[u'cat']
|
||||||
|
assert dog.norm_ == 'do'
|
||||||
|
assert cat.norm_ == 'ca'
|
||||||
|
|
||||||
|
bdata = pickle.dumps(vocab)
|
||||||
|
unpickled = pickle.loads(bdata)
|
||||||
|
|
||||||
|
assert unpickled[u'dog'].orth == dog.orth
|
||||||
|
assert unpickled[u'cat'].orth == cat.orth
|
||||||
|
assert unpickled[u'dog'].norm == dog.norm
|
||||||
|
assert unpickled[u'cat'].norm == cat.norm
|
||||||
|
dog_ = unpickled[u'dog']
|
||||||
|
cat_ = unpickled[u'cat']
|
||||||
|
assert dog_.norm != cat_.norm
|
||||||
|
|
Loading…
Reference in New Issue
Block a user