mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-14 11:36:24 +03:00
* Switch between the orth and char codecs depending on which is shorter for that message. Mostly orth is shorter, except if there are OOV words.
This commit is contained in:
parent
5a042ee0d3
commit
5a7d060d9c
|
@ -1,7 +1,7 @@
|
||||||
# cython: profile=True
|
# cython: profile=True
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
from libc.stdint cimport uint32_t
|
from libc.stdint cimport uint32_t, int32_t
|
||||||
from libc.stdint cimport uint64_t
|
from libc.stdint cimport uint64_t
|
||||||
from libc.math cimport exp as c_exp
|
from libc.math cimport exp as c_exp
|
||||||
from libcpp.queue cimport priority_queue
|
from libcpp.queue cimport priority_queue
|
||||||
|
@ -68,7 +68,7 @@ def _gen_orths(Vocab vocab):
|
||||||
def _gen_chars(Vocab vocab):
|
def _gen_chars(Vocab vocab):
|
||||||
cdef attr_t orth
|
cdef attr_t orth
|
||||||
cdef size_t addr
|
cdef size_t addr
|
||||||
char_weights = {u' ': 0.0}
|
char_weights = {chr(i): 1e-20 for i in range(256)}
|
||||||
cdef unicode string
|
cdef unicode string
|
||||||
cdef bytes char
|
cdef bytes char
|
||||||
cdef bytes utf8_str
|
cdef bytes utf8_str
|
||||||
|
@ -79,15 +79,17 @@ def _gen_chars(Vocab vocab):
|
||||||
for char in utf8_str:
|
for char in utf8_str:
|
||||||
char_weights.setdefault(char, 0.0)
|
char_weights.setdefault(char, 0.0)
|
||||||
char_weights[char] += c_exp(lex.prob)
|
char_weights[char] += c_exp(lex.prob)
|
||||||
char_weights[u' '] += c_exp(lex.prob)
|
char_weights[b' '] += c_exp(lex.prob)
|
||||||
return char_weights.items()
|
return char_weights.items()
|
||||||
|
|
||||||
|
|
||||||
cdef class Packer:
|
cdef class Packer:
|
||||||
def __init__(self, Vocab vocab, attr_freqs):
|
def __init__(self, Vocab vocab, attr_freqs, char_freqs=None):
|
||||||
|
if char_freqs is None:
|
||||||
|
char_freqs = _gen_chars(vocab)
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self.orth_codec = HuffmanCodec(_gen_orths(vocab))
|
self.orth_codec = HuffmanCodec(_gen_orths(vocab))
|
||||||
self.char_codec = HuffmanCodec(_gen_chars(vocab))
|
self.char_codec = HuffmanCodec(char_freqs)
|
||||||
|
|
||||||
codecs = []
|
codecs = []
|
||||||
attrs = []
|
attrs = []
|
||||||
|
@ -104,7 +106,14 @@ cdef class Packer:
|
||||||
return cls(vocab, util.read_encoding_freqs(data_dir))
|
return cls(vocab, util.read_encoding_freqs(data_dir))
|
||||||
|
|
||||||
def pack(self, Doc doc):
|
def pack(self, Doc doc):
|
||||||
bits = self._orth_encode(doc)
|
orths = [t.orth for t in doc]
|
||||||
|
chars = doc.string.encode('utf8')
|
||||||
|
# n_bits returns nan for oov words, i.e. can't encode message.
|
||||||
|
# So, it's important to write the conditional like this.
|
||||||
|
if self.orth_codec.n_bits(orths) < self.char_codec.n_bits(chars, overhead=1):
|
||||||
|
bits = self._orth_encode(doc)
|
||||||
|
else:
|
||||||
|
bits = self._char_encode(doc)
|
||||||
array = doc.to_array(self.attrs)
|
array = doc.to_array(self.attrs)
|
||||||
for i, codec in enumerate(self._codecs):
|
for i, codec in enumerate(self._codecs):
|
||||||
codec.encode(array[:, i], bits)
|
codec.encode(array[:, i], bits)
|
||||||
|
@ -112,8 +121,11 @@ cdef class Packer:
|
||||||
|
|
||||||
def unpack(self, BitArray bits):
|
def unpack(self, BitArray bits):
|
||||||
bits.seek(0)
|
bits.seek(0)
|
||||||
cdef uint32_t length = bits.read32()
|
cdef int32_t length = bits.read32()
|
||||||
doc = self._orth_decode(bits, length)
|
if length >= 0:
|
||||||
|
doc = self._orth_decode(bits, length)
|
||||||
|
else:
|
||||||
|
doc = self._char_decode(bits, -length)
|
||||||
|
|
||||||
array = numpy.zeros(shape=(len(doc), len(self._codecs)), dtype=numpy.int32)
|
array = numpy.zeros(shape=(len(doc), len(self._codecs)), dtype=numpy.int32)
|
||||||
for i, codec in enumerate(self._codecs):
|
for i, codec in enumerate(self._codecs):
|
||||||
|
@ -125,7 +137,7 @@ cdef class Packer:
|
||||||
def _orth_encode(self, Doc doc):
|
def _orth_encode(self, Doc doc):
|
||||||
cdef BitArray bits = BitArray()
|
cdef BitArray bits = BitArray()
|
||||||
orths = [w.orth for w in doc]
|
orths = [w.orth for w in doc]
|
||||||
cdef uint32_t length = len(doc)
|
cdef int32_t length = len(doc)
|
||||||
bits.extend(length, 32)
|
bits.extend(length, 32)
|
||||||
self.orth_codec.encode(orths, bits)
|
self.orth_codec.encode(orths, bits)
|
||||||
for token in doc:
|
for token in doc:
|
||||||
|
@ -142,11 +154,10 @@ cdef class Packer:
|
||||||
def _char_encode(self, Doc doc):
|
def _char_encode(self, Doc doc):
|
||||||
cdef BitArray bits = BitArray()
|
cdef BitArray bits = BitArray()
|
||||||
cdef bytes utf8_str = doc.string.encode('utf8')
|
cdef bytes utf8_str = doc.string.encode('utf8')
|
||||||
cdef uint32_t length = len(utf8_str)
|
cdef int32_t length = len(utf8_str)
|
||||||
bits.extend(length, 32)
|
# Signal chars with negative length
|
||||||
|
bits.extend(-length, 32)
|
||||||
cdef bytes utf8_string = doc.string.encode('utf8')
|
self.char_codec.encode(utf8_str, bits)
|
||||||
self.char_codec.encode(utf8_string, bits)
|
|
||||||
for token in doc:
|
for token in doc:
|
||||||
for i in range(len(token)-1):
|
for i in range(len(token)-1):
|
||||||
bits.append(False)
|
bits.append(False)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user