mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 18:06:29 +03:00
* Improve serialization speed
This commit is contained in:
parent
f13d5dae91
commit
bb0ba1f0cd
|
@ -1,3 +1,4 @@
|
|||
# cython: profile=True
|
||||
cimport cython
|
||||
from libcpp.queue cimport priority_queue
|
||||
from libcpp.pair cimport pair
|
||||
|
@ -74,6 +75,8 @@ cdef class HuffmanCodec:
|
|||
node = self.root
|
||||
cdef int i = 0
|
||||
cdef int n = len(msg)
|
||||
cdef int branch
|
||||
cdef bint bit
|
||||
for bit in bits:
|
||||
branch = node.right if bit else node.left
|
||||
if branch >= 0:
|
||||
|
|
|
@ -106,17 +106,21 @@ cdef class Packer:
|
|||
return cls(vocab, util.read_encoding_freqs(data_dir))
|
||||
|
||||
def pack(self, Doc doc):
|
||||
orths = [t.orth for t in doc]
|
||||
chars = doc.string.encode('utf8')
|
||||
orths = doc.to_array([ORTH])
|
||||
orths = orths[:, 0]
|
||||
cdef bytes chars = doc.string.encode('utf8')
|
||||
# n_bits returns nan for oov words, i.e. can't encode message.
|
||||
# So, it's important to write the conditional like this.
|
||||
if self.orth_codec.n_bits(orths) < self.char_codec.n_bits(chars, overhead=1):
|
||||
bits = self._orth_encode(doc)
|
||||
bits = self._orth_encode(doc, orths)
|
||||
else:
|
||||
bits = self._char_encode(doc)
|
||||
array = doc.to_array(self.attrs)
|
||||
for i, codec in enumerate(self._codecs):
|
||||
codec.encode(array[:, i], bits)
|
||||
bits = self._char_encode(doc, chars)
|
||||
|
||||
cdef int i
|
||||
if self.attrs:
|
||||
array = doc.to_array(self.attrs)
|
||||
for i, codec in enumerate(self._codecs):
|
||||
codec.encode(array[:, i], bits)
|
||||
return bits
|
||||
|
||||
def unpack(self, BitArray bits):
|
||||
|
@ -134,9 +138,8 @@ cdef class Packer:
|
|||
doc.from_array(self.attrs, array)
|
||||
return doc
|
||||
|
||||
def _orth_encode(self, Doc doc):
|
||||
def _orth_encode(self, Doc doc, attr_t[:] orths):
|
||||
cdef BitArray bits = BitArray()
|
||||
orths = [w.orth for w in doc]
|
||||
cdef int32_t length = len(doc)
|
||||
bits.extend(length, 32)
|
||||
self.orth_codec.encode(orths, bits)
|
||||
|
@ -145,46 +148,47 @@ cdef class Packer:
|
|||
return bits
|
||||
|
||||
def _orth_decode(self, BitArray bits, n):
|
||||
orths = [0] * n
|
||||
orths = numpy.ndarray(shape=(n,), dtype=numpy.int32)
|
||||
self.orth_codec.decode(bits, orths)
|
||||
orths_and_spaces = zip(orths, bits)
|
||||
cdef Doc doc = Doc(self.vocab, orths_and_spaces)
|
||||
return doc
|
||||
|
||||
def _char_encode(self, Doc doc):
|
||||
def _char_encode(self, Doc doc, bytes utf8_str):
|
||||
cdef BitArray bits = BitArray()
|
||||
cdef bytes utf8_str = doc.string.encode('utf8')
|
||||
cdef int32_t length = len(utf8_str)
|
||||
# Signal chars with negative length
|
||||
bits.extend(-length, 32)
|
||||
self.char_codec.encode(utf8_str, bits)
|
||||
for token in doc:
|
||||
for i in range(len(token)-1):
|
||||
cdef int i, j
|
||||
for i in range(doc.length):
|
||||
for j in range(doc.data[i].lex.length-1):
|
||||
bits.append(False)
|
||||
bits.append(True)
|
||||
if token.whitespace_:
|
||||
if doc.data[i].spacy:
|
||||
bits.append(False)
|
||||
return bits
|
||||
|
||||
def _char_decode(self, BitArray bits, n):
|
||||
chars = [b''] * n
|
||||
self.char_codec.decode(bits, chars)
|
||||
cdef bytes utf8_str = b''.join(chars)
|
||||
cdef bytearray utf8_str = bytearray(n)
|
||||
self.char_codec.decode(bits, utf8_str)
|
||||
|
||||
cdef unicode string = utf8_str.decode('utf8')
|
||||
cdef Doc tokens = Doc(self.vocab)
|
||||
cdef int i
|
||||
cdef int start = 0
|
||||
cdef bint is_spacy
|
||||
cdef UniStr span
|
||||
cdef int length = len(string)
|
||||
iter_bits = iter(bits)
|
||||
for i in range(length):
|
||||
is_end_token = iter_bits.next()
|
||||
cdef int i = 0
|
||||
cdef bint is_end_token
|
||||
for is_end_token in bits:
|
||||
if is_end_token:
|
||||
slice_unicode(&span, string, start, i+1)
|
||||
lex = self.vocab.get(tokens.mem, &span)
|
||||
is_spacy = (i+1) < length and string[i+1] == u' '
|
||||
tokens.push_back(lex, is_spacy)
|
||||
start = i + 1 + is_spacy
|
||||
i += 1
|
||||
if i >= n:
|
||||
break
|
||||
return tokens
|
||||
|
|
Loading…
Reference in New Issue
Block a user