Add hacky support for StringCFile, to make pickling easier.

This commit is contained in:
Matthew Honnibal 2017-03-07 20:24:37 +01:00
parent 3edb8ae207
commit 26614e028f
2 changed files with 56 additions and 0 deletions

View File

@ -4,6 +4,20 @@ from cymem.cymem cimport Pool
cdef class CFile:
    # Handle of the underlying C stdio stream.
    cdef FILE* fp
    # Flag recording whether the file object is currently open.
    cdef bint is_open
    # cymem memory pool owned by this object.
    cdef Pool mem
    # The next two fields are declared on the base class for compatibility
    # with the subclass.
    cdef int size
    cdef int _capacity

    # Low-level I/O entry points; each takes an element count and a
    # per-element size in bytes.
    cdef int read_into(self, void* dest, size_t number, size_t elem_size) except -1

    cdef int write_from(self, void* src, size_t number, size_t elem_size) except -1

    cdef void* alloc_read(self, Pool mem, size_t number, size_t elem_size) except *
cdef class StringCFile(CFile):
cdef unsigned char* data
cdef int read_into(self, void* dest, size_t number, size_t elem_size) except -1 cdef int read_into(self, void* dest, size_t number, size_t elem_size) except -1

View File

@ -1,4 +1,5 @@
from libc.stdio cimport fopen, fclose, fread, fwrite, FILE from libc.stdio cimport fopen, fclose, fread, fwrite, FILE
from libc.string cimport memcpy
cdef class CFile: cdef class CFile:
@ -9,6 +10,7 @@ cdef class CFile:
mode_str = mode mode_str = mode
if hasattr(loc, 'as_posix'): if hasattr(loc, 'as_posix'):
loc = loc.as_posix() loc = loc.as_posix()
self.mem = Pool()
cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc cdef bytes bytes_loc = loc.encode('utf8') if type(loc) == unicode else loc
self.fp = fopen(<char*>bytes_loc, mode_str) self.fp = fopen(<char*>bytes_loc, mode_str)
if self.fp == NULL: if self.fp == NULL:
@ -45,3 +47,43 @@ cdef class CFile:
cdef bytes py_bytes = value.encode('utf8') cdef bytes py_bytes = value.encode('utf8')
cdef char* chars = <char*>py_bytes cdef char* chars = <char*>py_bytes
self.write(sizeof(char), len(py_bytes), chars) self.write(sizeof(char), len(py_bytes), chars)
# NOTE(review): the accompanying declaration file lists CFile as the base
# class of StringCFile, while this header names no base -- confirm the two
# agree.
cdef class StringCFile:
    """In-memory byte buffer offering the read_into / write_from /
    alloc_read interface, backed by a cymem Pool instead of a file.

    Invariants maintained by every method:

    * ``self.data`` points at the first unread byte of the live buffer.
    * ``self.size`` is the number of live (written but not yet read) bytes
      starting at ``self.data``.
    * ``self._capacity`` is the number of bytes available from ``self.data``
      to the end of the current allocation.

    Reads consume bytes from the front; writes append after the live bytes.
    """
    def __init__(self, mode, bytes data=b'', on_open_error=None):
        """Create the buffer, optionally seeded with ``data``.

        mode: Mode string; the object is flagged open only if it contains 'w'.
        data: Initial contents, copied into pool-owned memory.
        on_open_error: Unused here; presumably accepted to mirror the
            file-backed constructor -- TODO confirm.
        """
        cdef Py_ssize_t n_initial = len(data)
        self.mem = Pool()
        self.is_open = 'w' in mode
        # Keep a small minimum so an empty buffer still has room to grow from.
        self._capacity = max(n_initial, 8)
        self.size = n_initial
        self.data = <unsigned char*>self.mem.alloc(1, self._capacity)
        if n_initial > 0:
            # Copy the payload bytes themselves (the previous loop assigned
            # the whole bytes object to each slot).
            memcpy(self.data, <char*>data, n_initial)

    def close(self):
        """Mark the buffer as closed. The memory stays owned by the pool."""
        self.is_open = False

    def string_data(self):
        """Return a bytes copy of the live contents: everything written (or
        supplied at construction) that has not yet been consumed by a read.
        """
        # Slice with an explicit length so embedded NUL bytes are preserved.
        return (<char*>self.data)[:self.size]

    cdef int read_into(self, void* dest, size_t number, size_t elem_size) except -1:
        """Copy ``number * elem_size`` bytes from the front of the buffer into
        ``dest`` and consume them.

        Raises IOError if fewer bytes are available than requested, instead
        of reading past the end of the buffer.
        """
        cdef size_t nbytes = number * elem_size
        if nbytes > <size_t>self.size:
            raise IOError(
                "Cannot read %d bytes: only %d remain in buffer"
                % (nbytes, self.size))
        memcpy(dest, self.data, nbytes)
        # Advance the cursor and shrink the live window accordingly.
        self.data += nbytes
        self.size -= <int>nbytes
        self._capacity -= <int>nbytes
        return 0

    cdef int write_from(self, void* src, size_t number, size_t elem_size) except -1:
        """Append ``number * elem_size`` bytes from ``src`` after the live
        bytes, growing the buffer when needed.
        """
        cdef size_t nbytes = number * elem_size
        cdef size_t needed
        cdef size_t new_capacity
        cdef unsigned char* grown
        if nbytes == 0:
            return 0
        needed = <size_t>self.size + nbytes
        if needed >= <size_t>self._capacity:
            # Grow by allocating a fresh block and copying the live bytes.
            # self.data may have been advanced by reads, so it is not
            # necessarily a pointer the pool handed out and must not be
            # passed to Pool.realloc. The superseded block remains owned by
            # self.mem; it is never explicitly freed here.
            new_capacity = needed * 2
            grown = <unsigned char*>self.mem.alloc(1, new_capacity)
            memcpy(grown, self.data, self.size)
            self.data = grown
            self._capacity = <int>new_capacity
        # Append at the end of the live region; the base pointer stays put.
        memcpy(self.data + self.size, src, nbytes)
        self.size += <int>nbytes
        return 0

    cdef void* alloc_read(self, Pool mem, size_t number, size_t elem_size) except *:
        """Allocate ``number`` elements of ``elem_size`` bytes from ``mem``,
        fill them from the buffer via read_into, and return the new block.
        """
        cdef void* dest = mem.alloc(number, elem_size)
        self.read_into(dest, number, elem_size)
        return dest

    def write_unicode(self, unicode value):
        """Append the UTF-8 encoding of ``value`` to the buffer."""
        cdef bytes py_bytes = value.encode('utf8')
        cdef char* chars = <char*>py_bytes
        # Route through the in-memory writer; this class defines no write().
        self.write_from(chars, len(py_bytes), sizeof(char))