* Improve serialization speed

This commit is contained in:
Matthew Honnibal 2015-07-20 03:27:59 +02:00
parent f13d5dae91
commit bb0ba1f0cd
2 changed files with 29 additions and 22 deletions

View File

@ -1,3 +1,4 @@
# cython: profile=True
cimport cython cimport cython
from libcpp.queue cimport priority_queue from libcpp.queue cimport priority_queue
from libcpp.pair cimport pair from libcpp.pair cimport pair
@ -74,6 +75,8 @@ cdef class HuffmanCodec:
node = self.root node = self.root
cdef int i = 0 cdef int i = 0
cdef int n = len(msg) cdef int n = len(msg)
cdef int branch
cdef bint bit
for bit in bits: for bit in bits:
branch = node.right if bit else node.left branch = node.right if bit else node.left
if branch >= 0: if branch >= 0:

View File

@ -106,14 +106,18 @@ cdef class Packer:
return cls(vocab, util.read_encoding_freqs(data_dir)) return cls(vocab, util.read_encoding_freqs(data_dir))
def pack(self, Doc doc): def pack(self, Doc doc):
orths = [t.orth for t in doc] orths = doc.to_array([ORTH])
chars = doc.string.encode('utf8') orths = orths[:, 0]
cdef bytes chars = doc.string.encode('utf8')
# n_bits returns nan for oov words, i.e. can't encode message. # n_bits returns nan for oov words, i.e. can't encode message.
# So, it's important to write the conditional like this. # So, it's important to write the conditional like this.
if self.orth_codec.n_bits(orths) < self.char_codec.n_bits(chars, overhead=1): if self.orth_codec.n_bits(orths) < self.char_codec.n_bits(chars, overhead=1):
bits = self._orth_encode(doc) bits = self._orth_encode(doc, orths)
else: else:
bits = self._char_encode(doc) bits = self._char_encode(doc, chars)
cdef int i
if self.attrs:
array = doc.to_array(self.attrs) array = doc.to_array(self.attrs)
for i, codec in enumerate(self._codecs): for i, codec in enumerate(self._codecs):
codec.encode(array[:, i], bits) codec.encode(array[:, i], bits)
@ -134,9 +138,8 @@ cdef class Packer:
doc.from_array(self.attrs, array) doc.from_array(self.attrs, array)
return doc return doc
def _orth_encode(self, Doc doc): def _orth_encode(self, Doc doc, attr_t[:] orths):
cdef BitArray bits = BitArray() cdef BitArray bits = BitArray()
orths = [w.orth for w in doc]
cdef int32_t length = len(doc) cdef int32_t length = len(doc)
bits.extend(length, 32) bits.extend(length, 32)
self.orth_codec.encode(orths, bits) self.orth_codec.encode(orths, bits)
@ -145,46 +148,47 @@ cdef class Packer:
return bits return bits
def _orth_decode(self, BitArray bits, n): def _orth_decode(self, BitArray bits, n):
orths = [0] * n orths = numpy.ndarray(shape=(n,), dtype=numpy.int32)
self.orth_codec.decode(bits, orths) self.orth_codec.decode(bits, orths)
orths_and_spaces = zip(orths, bits) orths_and_spaces = zip(orths, bits)
cdef Doc doc = Doc(self.vocab, orths_and_spaces) cdef Doc doc = Doc(self.vocab, orths_and_spaces)
return doc return doc
def _char_encode(self, Doc doc): def _char_encode(self, Doc doc, bytes utf8_str):
cdef BitArray bits = BitArray() cdef BitArray bits = BitArray()
cdef bytes utf8_str = doc.string.encode('utf8')
cdef int32_t length = len(utf8_str) cdef int32_t length = len(utf8_str)
# Signal chars with negative length # Signal chars with negative length
bits.extend(-length, 32) bits.extend(-length, 32)
self.char_codec.encode(utf8_str, bits) self.char_codec.encode(utf8_str, bits)
for token in doc: cdef int i, j
for i in range(len(token)-1): for i in range(doc.length):
for j in range(doc.data[i].lex.length-1):
bits.append(False) bits.append(False)
bits.append(True) bits.append(True)
if token.whitespace_: if doc.data[i].spacy:
bits.append(False) bits.append(False)
return bits return bits
def _char_decode(self, BitArray bits, n): def _char_decode(self, BitArray bits, n):
chars = [b''] * n cdef bytearray utf8_str = bytearray(n)
self.char_codec.decode(bits, chars) self.char_codec.decode(bits, utf8_str)
cdef bytes utf8_str = b''.join(chars)
cdef unicode string = utf8_str.decode('utf8') cdef unicode string = utf8_str.decode('utf8')
cdef Doc tokens = Doc(self.vocab) cdef Doc tokens = Doc(self.vocab)
cdef int i
cdef int start = 0 cdef int start = 0
cdef bint is_spacy cdef bint is_spacy
cdef UniStr span cdef UniStr span
cdef int length = len(string) cdef int length = len(string)
iter_bits = iter(bits) cdef int i = 0
for i in range(length): cdef bint is_end_token
is_end_token = iter_bits.next() for is_end_token in bits:
if is_end_token: if is_end_token:
slice_unicode(&span, string, start, i+1) slice_unicode(&span, string, start, i+1)
lex = self.vocab.get(tokens.mem, &span) lex = self.vocab.get(tokens.mem, &span)
is_spacy = (i+1) < length and string[i+1] == u' ' is_spacy = (i+1) < length and string[i+1] == u' '
tokens.push_back(lex, is_spacy) tokens.push_back(lex, is_spacy)
start = i + 1 + is_spacy start = i + 1 + is_spacy
i += 1
if i >= n:
break
return tokens return tokens