* Fix Packer API, so that it reads and writes bytes strings, instead of BitArray. Docs are always byte aligned anyway.

This commit is contained in:
Matthew Honnibal 2015-07-23 01:12:00 +02:00
parent 38ef986b29
commit 1f31d96bf9

View File

@ -10,6 +10,7 @@ from libcpp.pair cimport pair
from cymem.cymem cimport Address, Pool
from preshed.maps cimport PreshMap
from preshed.counter cimport PreshCounter
import json
from ..attrs cimport ORTH, ID, SPACY, TAG, HEAD, DEP, ENT_IOB, ENT_TYPE
from ..tokens.doc cimport Doc
@ -98,33 +99,34 @@ cdef class Packer:
self._codecs = tuple(codecs)
self.attrs = tuple(attrs)
@classmethod
def from_dir(cls, Vocab vocab, data_dir):
return cls(vocab, util.read_encoding_freqs(data_dir))
def pack(self, Doc doc):
bits = self._orth_encode(doc)
if bits is None:
bits = self._char_encode(doc)
cdef int i
if self.attrs:
array = doc.to_array(self.attrs)
for i, codec in enumerate(self._codecs):
codec.encode_int32(array[:, i], bits)
return bits
codec.encode(array[:, i], bits)
return bits.as_bytes()
def unpack(self, BitArray bits):
def unpack(self, bytes data):
doc = Doc(self.vocab)
self.unpack_into(data, doc)
return doc
def unpack_into(self, bytes byte_string, Doc doc):
bits = BitArray(byte_string)
bits.seek(0)
cdef int32_t length = bits.read32()
if length >= 0:
doc = self._orth_decode(bits, length)
self._orth_decode(bits, length, doc)
else:
doc = self._char_decode(bits, -length)
self._char_decode(bits, -length, doc)
array = numpy.zeros(shape=(len(doc), len(self._codecs)), dtype=numpy.int32)
for i, codec in enumerate(self._codecs):
codec.decode_int32(bits, array[:, i])
codec.decode(bits, array[:, i])
doc.from_array(self.attrs, array)
return doc
@ -141,13 +143,6 @@ cdef class Packer:
bits.append(bool(token.whitespace_))
return bits
def _orth_decode(self, BitArray bits, n):
orths = numpy.ndarray(shape=(n,), dtype=numpy.int32)
self.orth_codec.decode_int32(bits, orths)
orths_and_spaces = zip(orths, bits)
cdef Doc doc = Doc(self.vocab, orths_and_spaces)
return doc
def _char_encode(self, Doc doc):
cdef bytes utf8_str = doc.string.encode('utf8')
cdef BitArray bits = BitArray()
@ -164,12 +159,24 @@ cdef class Packer:
bits.append(False)
return bits
def _char_decode(self, BitArray bits, n):
def _orth_decode(self, BitArray bits, int32_t n, Doc doc):
cdef attr_t[:] orths = numpy.ndarray(shape=(n,), dtype=numpy.int32)
self.orth_codec.decode_int32(bits, orths)
cdef int i
cdef bint space
spaces = iter(bits)
for i in range(n):
orth = orths[i]
space = spaces.next()
lex = self.vocab.get_by_orth(doc.mem, orth)
doc.push_back(lex, space)
return doc
def _char_decode(self, BitArray bits, int32_t n, Doc doc):
cdef bytearray utf8_str = bytearray(n)
self.char_codec.decode(bits, utf8_str)
cdef unicode string = utf8_str.decode('utf8')
cdef Doc tokens = Doc(self.vocab)
cdef int start = 0
cdef bint is_spacy
cdef int length = len(string)
@ -178,11 +185,11 @@ cdef class Packer:
for is_end_token in bits:
if is_end_token:
span = string[start:i+1]
lex = self.vocab.get(tokens.mem, span)
lex = self.vocab.get(doc.mem, span)
is_spacy = (i+1) < length and string[i+1] == u' '
tokens.push_back(lex, is_spacy)
doc.push_back(lex, is_spacy)
start = i + 1 + is_spacy
i += 1
if i >= n:
break
return tokens
return doc