* Improve serialization speed

This commit is contained in:
Matthew Honnibal 2015-07-20 03:27:59 +02:00
parent f13d5dae91
commit bb0ba1f0cd
2 changed files with 29 additions and 22 deletions

View File

@ -1,3 +1,4 @@
# cython: profile=True
cimport cython
from libcpp.queue cimport priority_queue
from libcpp.pair cimport pair
@ -74,6 +75,8 @@ cdef class HuffmanCodec:
node = self.root
cdef int i = 0
cdef int n = len(msg)
cdef int branch
cdef bint bit
for bit in bits:
branch = node.right if bit else node.left
if branch >= 0:

View File

@ -106,17 +106,21 @@ cdef class Packer:
return cls(vocab, util.read_encoding_freqs(data_dir))
def pack(self, Doc doc):
orths = [t.orth for t in doc]
chars = doc.string.encode('utf8')
orths = doc.to_array([ORTH])
orths = orths[:, 0]
cdef bytes chars = doc.string.encode('utf8')
# n_bits returns nan for oov words, i.e. can't encode message.
# So, it's important to write the conditional like this.
if self.orth_codec.n_bits(orths) < self.char_codec.n_bits(chars, overhead=1):
bits = self._orth_encode(doc)
bits = self._orth_encode(doc, orths)
else:
bits = self._char_encode(doc)
array = doc.to_array(self.attrs)
for i, codec in enumerate(self._codecs):
codec.encode(array[:, i], bits)
bits = self._char_encode(doc, chars)
cdef int i
if self.attrs:
array = doc.to_array(self.attrs)
for i, codec in enumerate(self._codecs):
codec.encode(array[:, i], bits)
return bits
def unpack(self, BitArray bits):
@ -134,9 +138,8 @@ cdef class Packer:
doc.from_array(self.attrs, array)
return doc
def _orth_encode(self, Doc doc):
def _orth_encode(self, Doc doc, attr_t[:] orths):
cdef BitArray bits = BitArray()
orths = [w.orth for w in doc]
cdef int32_t length = len(doc)
bits.extend(length, 32)
self.orth_codec.encode(orths, bits)
@ -145,46 +148,47 @@ cdef class Packer:
return bits
def _orth_decode(self, BitArray bits, n):
orths = [0] * n
orths = numpy.ndarray(shape=(n,), dtype=numpy.int32)
self.orth_codec.decode(bits, orths)
orths_and_spaces = zip(orths, bits)
cdef Doc doc = Doc(self.vocab, orths_and_spaces)
return doc
def _char_encode(self, Doc doc):
def _char_encode(self, Doc doc, bytes utf8_str):
cdef BitArray bits = BitArray()
cdef bytes utf8_str = doc.string.encode('utf8')
cdef int32_t length = len(utf8_str)
# Signal chars with negative length
bits.extend(-length, 32)
self.char_codec.encode(utf8_str, bits)
for token in doc:
for i in range(len(token)-1):
cdef int i, j
for i in range(doc.length):
for j in range(doc.data[i].lex.length-1):
bits.append(False)
bits.append(True)
if token.whitespace_:
if doc.data[i].spacy:
bits.append(False)
return bits
def _char_decode(self, BitArray bits, n):
chars = [b''] * n
self.char_codec.decode(bits, chars)
cdef bytes utf8_str = b''.join(chars)
cdef bytearray utf8_str = bytearray(n)
self.char_codec.decode(bits, utf8_str)
cdef unicode string = utf8_str.decode('utf8')
cdef Doc tokens = Doc(self.vocab)
cdef int i
cdef int start = 0
cdef bint is_spacy
cdef UniStr span
cdef int length = len(string)
iter_bits = iter(bits)
for i in range(length):
is_end_token = iter_bits.next()
cdef int i = 0
cdef bint is_end_token
for is_end_token in bits:
if is_end_token:
slice_unicode(&span, string, start, i+1)
lex = self.vocab.get(tokens.mem, &span)
is_spacy = (i+1) < length and string[i+1] == u' '
tokens.push_back(lex, is_spacy)
start = i + 1 + is_spacy
i += 1
if i >= n:
break
return tokens