mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-27 02:16:32 +03:00
* Improve serialization speed
This commit is contained in:
parent
f13d5dae91
commit
bb0ba1f0cd
|
@ -1,3 +1,4 @@
|
||||||
|
# cython: profile=True
|
||||||
cimport cython
|
cimport cython
|
||||||
from libcpp.queue cimport priority_queue
|
from libcpp.queue cimport priority_queue
|
||||||
from libcpp.pair cimport pair
|
from libcpp.pair cimport pair
|
||||||
|
@ -74,6 +75,8 @@ cdef class HuffmanCodec:
|
||||||
node = self.root
|
node = self.root
|
||||||
cdef int i = 0
|
cdef int i = 0
|
||||||
cdef int n = len(msg)
|
cdef int n = len(msg)
|
||||||
|
cdef int branch
|
||||||
|
cdef bint bit
|
||||||
for bit in bits:
|
for bit in bits:
|
||||||
branch = node.right if bit else node.left
|
branch = node.right if bit else node.left
|
||||||
if branch >= 0:
|
if branch >= 0:
|
||||||
|
|
|
@ -106,14 +106,18 @@ cdef class Packer:
|
||||||
return cls(vocab, util.read_encoding_freqs(data_dir))
|
return cls(vocab, util.read_encoding_freqs(data_dir))
|
||||||
|
|
||||||
def pack(self, Doc doc):
|
def pack(self, Doc doc):
|
||||||
orths = [t.orth for t in doc]
|
orths = doc.to_array([ORTH])
|
||||||
chars = doc.string.encode('utf8')
|
orths = orths[:, 0]
|
||||||
|
cdef bytes chars = doc.string.encode('utf8')
|
||||||
# n_bits returns nan for oov words, i.e. can't encode message.
|
# n_bits returns nan for oov words, i.e. can't encode message.
|
||||||
# So, it's important to write the conditional like this.
|
# So, it's important to write the conditional like this.
|
||||||
if self.orth_codec.n_bits(orths) < self.char_codec.n_bits(chars, overhead=1):
|
if self.orth_codec.n_bits(orths) < self.char_codec.n_bits(chars, overhead=1):
|
||||||
bits = self._orth_encode(doc)
|
bits = self._orth_encode(doc, orths)
|
||||||
else:
|
else:
|
||||||
bits = self._char_encode(doc)
|
bits = self._char_encode(doc, chars)
|
||||||
|
|
||||||
|
cdef int i
|
||||||
|
if self.attrs:
|
||||||
array = doc.to_array(self.attrs)
|
array = doc.to_array(self.attrs)
|
||||||
for i, codec in enumerate(self._codecs):
|
for i, codec in enumerate(self._codecs):
|
||||||
codec.encode(array[:, i], bits)
|
codec.encode(array[:, i], bits)
|
||||||
|
@ -134,9 +138,8 @@ cdef class Packer:
|
||||||
doc.from_array(self.attrs, array)
|
doc.from_array(self.attrs, array)
|
||||||
return doc
|
return doc
|
||||||
|
|
||||||
def _orth_encode(self, Doc doc):
|
def _orth_encode(self, Doc doc, attr_t[:] orths):
|
||||||
cdef BitArray bits = BitArray()
|
cdef BitArray bits = BitArray()
|
||||||
orths = [w.orth for w in doc]
|
|
||||||
cdef int32_t length = len(doc)
|
cdef int32_t length = len(doc)
|
||||||
bits.extend(length, 32)
|
bits.extend(length, 32)
|
||||||
self.orth_codec.encode(orths, bits)
|
self.orth_codec.encode(orths, bits)
|
||||||
|
@ -145,46 +148,47 @@ cdef class Packer:
|
||||||
return bits
|
return bits
|
||||||
|
|
||||||
def _orth_decode(self, BitArray bits, n):
|
def _orth_decode(self, BitArray bits, n):
|
||||||
orths = [0] * n
|
orths = numpy.ndarray(shape=(n,), dtype=numpy.int32)
|
||||||
self.orth_codec.decode(bits, orths)
|
self.orth_codec.decode(bits, orths)
|
||||||
orths_and_spaces = zip(orths, bits)
|
orths_and_spaces = zip(orths, bits)
|
||||||
cdef Doc doc = Doc(self.vocab, orths_and_spaces)
|
cdef Doc doc = Doc(self.vocab, orths_and_spaces)
|
||||||
return doc
|
return doc
|
||||||
|
|
||||||
def _char_encode(self, Doc doc):
|
def _char_encode(self, Doc doc, bytes utf8_str):
|
||||||
cdef BitArray bits = BitArray()
|
cdef BitArray bits = BitArray()
|
||||||
cdef bytes utf8_str = doc.string.encode('utf8')
|
|
||||||
cdef int32_t length = len(utf8_str)
|
cdef int32_t length = len(utf8_str)
|
||||||
# Signal chars with negative length
|
# Signal chars with negative length
|
||||||
bits.extend(-length, 32)
|
bits.extend(-length, 32)
|
||||||
self.char_codec.encode(utf8_str, bits)
|
self.char_codec.encode(utf8_str, bits)
|
||||||
for token in doc:
|
cdef int i, j
|
||||||
for i in range(len(token)-1):
|
for i in range(doc.length):
|
||||||
|
for j in range(doc.data[i].lex.length-1):
|
||||||
bits.append(False)
|
bits.append(False)
|
||||||
bits.append(True)
|
bits.append(True)
|
||||||
if token.whitespace_:
|
if doc.data[i].spacy:
|
||||||
bits.append(False)
|
bits.append(False)
|
||||||
return bits
|
return bits
|
||||||
|
|
||||||
def _char_decode(self, BitArray bits, n):
|
def _char_decode(self, BitArray bits, n):
|
||||||
chars = [b''] * n
|
cdef bytearray utf8_str = bytearray(n)
|
||||||
self.char_codec.decode(bits, chars)
|
self.char_codec.decode(bits, utf8_str)
|
||||||
cdef bytes utf8_str = b''.join(chars)
|
|
||||||
|
|
||||||
cdef unicode string = utf8_str.decode('utf8')
|
cdef unicode string = utf8_str.decode('utf8')
|
||||||
cdef Doc tokens = Doc(self.vocab)
|
cdef Doc tokens = Doc(self.vocab)
|
||||||
cdef int i
|
|
||||||
cdef int start = 0
|
cdef int start = 0
|
||||||
cdef bint is_spacy
|
cdef bint is_spacy
|
||||||
cdef UniStr span
|
cdef UniStr span
|
||||||
cdef int length = len(string)
|
cdef int length = len(string)
|
||||||
iter_bits = iter(bits)
|
cdef int i = 0
|
||||||
for i in range(length):
|
cdef bint is_end_token
|
||||||
is_end_token = iter_bits.next()
|
for is_end_token in bits:
|
||||||
if is_end_token:
|
if is_end_token:
|
||||||
slice_unicode(&span, string, start, i+1)
|
slice_unicode(&span, string, start, i+1)
|
||||||
lex = self.vocab.get(tokens.mem, &span)
|
lex = self.vocab.get(tokens.mem, &span)
|
||||||
is_spacy = (i+1) < length and string[i+1] == u' '
|
is_spacy = (i+1) < length and string[i+1] == u' '
|
||||||
tokens.push_back(lex, is_spacy)
|
tokens.push_back(lex, is_spacy)
|
||||||
start = i + 1 + is_spacy
|
start = i + 1 + is_spacy
|
||||||
|
i += 1
|
||||||
|
if i >= n:
|
||||||
|
break
|
||||||
return tokens
|
return tokens
|
||||||
|
|
Loading…
Reference in New Issue
Block a user