* Switch between the orth and char codecs depending on which is shorter for that message. Mostly orth is shorter, except if there are OOV words.

This commit is contained in:
Matthew Honnibal 2015-07-20 01:36:22 +02:00
parent 5a042ee0d3
commit 5a7d060d9c

View File

@ -1,7 +1,7 @@
# cython: profile=True # cython: profile=True
from __future__ import unicode_literals from __future__ import unicode_literals
from libc.stdint cimport uint32_t from libc.stdint cimport uint32_t, int32_t
from libc.stdint cimport uint64_t from libc.stdint cimport uint64_t
from libc.math cimport exp as c_exp from libc.math cimport exp as c_exp
from libcpp.queue cimport priority_queue from libcpp.queue cimport priority_queue
@ -68,7 +68,7 @@ def _gen_orths(Vocab vocab):
def _gen_chars(Vocab vocab): def _gen_chars(Vocab vocab):
cdef attr_t orth cdef attr_t orth
cdef size_t addr cdef size_t addr
char_weights = {u' ': 0.0} char_weights = {chr(i): 1e-20 for i in range(256)}
cdef unicode string cdef unicode string
cdef bytes char cdef bytes char
cdef bytes utf8_str cdef bytes utf8_str
@ -79,15 +79,17 @@ def _gen_chars(Vocab vocab):
for char in utf8_str: for char in utf8_str:
char_weights.setdefault(char, 0.0) char_weights.setdefault(char, 0.0)
char_weights[char] += c_exp(lex.prob) char_weights[char] += c_exp(lex.prob)
char_weights[u' '] += c_exp(lex.prob) char_weights[b' '] += c_exp(lex.prob)
return char_weights.items() return char_weights.items()
cdef class Packer: cdef class Packer:
def __init__(self, Vocab vocab, attr_freqs): def __init__(self, Vocab vocab, attr_freqs, char_freqs=None):
if char_freqs is None:
char_freqs = _gen_chars(vocab)
self.vocab = vocab self.vocab = vocab
self.orth_codec = HuffmanCodec(_gen_orths(vocab)) self.orth_codec = HuffmanCodec(_gen_orths(vocab))
self.char_codec = HuffmanCodec(_gen_chars(vocab)) self.char_codec = HuffmanCodec(char_freqs)
codecs = [] codecs = []
attrs = [] attrs = []
@ -104,7 +106,14 @@ cdef class Packer:
return cls(vocab, util.read_encoding_freqs(data_dir)) return cls(vocab, util.read_encoding_freqs(data_dir))
def pack(self, Doc doc): def pack(self, Doc doc):
bits = self._orth_encode(doc) orths = [t.orth for t in doc]
chars = doc.string.encode('utf8')
# n_bits returns nan for oov words, i.e. can't encode message.
# So, it's important to write the conditional like this.
if self.orth_codec.n_bits(orths) < self.char_codec.n_bits(chars, overhead=1):
bits = self._orth_encode(doc)
else:
bits = self._char_encode(doc)
array = doc.to_array(self.attrs) array = doc.to_array(self.attrs)
for i, codec in enumerate(self._codecs): for i, codec in enumerate(self._codecs):
codec.encode(array[:, i], bits) codec.encode(array[:, i], bits)
@ -112,8 +121,11 @@ cdef class Packer:
def unpack(self, BitArray bits): def unpack(self, BitArray bits):
bits.seek(0) bits.seek(0)
cdef uint32_t length = bits.read32() cdef int32_t length = bits.read32()
doc = self._orth_decode(bits, length) if length >= 0:
doc = self._orth_decode(bits, length)
else:
doc = self._char_decode(bits, -length)
array = numpy.zeros(shape=(len(doc), len(self._codecs)), dtype=numpy.int32) array = numpy.zeros(shape=(len(doc), len(self._codecs)), dtype=numpy.int32)
for i, codec in enumerate(self._codecs): for i, codec in enumerate(self._codecs):
@ -125,7 +137,7 @@ cdef class Packer:
def _orth_encode(self, Doc doc): def _orth_encode(self, Doc doc):
cdef BitArray bits = BitArray() cdef BitArray bits = BitArray()
orths = [w.orth for w in doc] orths = [w.orth for w in doc]
cdef uint32_t length = len(doc) cdef int32_t length = len(doc)
bits.extend(length, 32) bits.extend(length, 32)
self.orth_codec.encode(orths, bits) self.orth_codec.encode(orths, bits)
for token in doc: for token in doc:
@ -142,11 +154,10 @@ cdef class Packer:
def _char_encode(self, Doc doc): def _char_encode(self, Doc doc):
cdef BitArray bits = BitArray() cdef BitArray bits = BitArray()
cdef bytes utf8_str = doc.string.encode('utf8') cdef bytes utf8_str = doc.string.encode('utf8')
cdef uint32_t length = len(utf8_str) cdef int32_t length = len(utf8_str)
bits.extend(length, 32) # Signal chars with negative length
bits.extend(-length, 32)
cdef bytes utf8_string = doc.string.encode('utf8') self.char_codec.encode(utf8_str, bits)
self.char_codec.encode(utf8_string, bits)
for token in doc: for token in doc:
for i in range(len(token)-1): for i in range(len(token)-1):
bits.append(False) bits.append(False)