* Fix Packer API, so that it reads and writes byte strings instead of BitArray. Docs are always byte-aligned anyway.

Matthew Honnibal 2015-07-23 01:12:00 +02:00
parent 38ef986b29
commit 1f31d96bf9

@@ -10,6 +10,7 @@ from libcpp.pair cimport pair
 from cymem.cymem cimport Address, Pool
 from preshed.maps cimport PreshMap
 from preshed.counter cimport PreshCounter
+import json
 
 from ..attrs cimport ORTH, ID, SPACY, TAG, HEAD, DEP, ENT_IOB, ENT_TYPE
 from ..tokens.doc cimport Doc
@@ -98,33 +99,34 @@ cdef class Packer:
         self._codecs = tuple(codecs)
         self.attrs = tuple(attrs)
 
-    @classmethod
-    def from_dir(cls, Vocab vocab, data_dir):
-        return cls(vocab, util.read_encoding_freqs(data_dir))
-
     def pack(self, Doc doc):
         bits = self._orth_encode(doc)
         if bits is None:
             bits = self._char_encode(doc)
         cdef int i
         if self.attrs:
             array = doc.to_array(self.attrs)
             for i, codec in enumerate(self._codecs):
-                codec.encode_int32(array[:, i], bits)
-        return bits
+                codec.encode(array[:, i], bits)
+        return bits.as_bytes()
 
-    def unpack(self, BitArray bits):
+    def unpack(self, bytes data):
+        doc = Doc(self.vocab)
+        self.unpack_into(data, doc)
+        return doc
+
+    def unpack_into(self, bytes byte_string, Doc doc):
+        bits = BitArray(byte_string)
         bits.seek(0)
         cdef int32_t length = bits.read32()
         if length >= 0:
-            doc = self._orth_decode(bits, length)
+            self._orth_decode(bits, length, doc)
         else:
-            doc = self._char_decode(bits, -length)
+            self._char_decode(bits, -length, doc)
         array = numpy.zeros(shape=(len(doc), len(self._codecs)), dtype=numpy.int32)
         for i, codec in enumerate(self._codecs):
-            codec.decode_int32(bits, array[:, i])
+            codec.decode(bits, array[:, i])
         doc.from_array(self.attrs, array)
         return doc
@@ -141,13 +143,6 @@ cdef class Packer:
             bits.append(bool(token.whitespace_))
         return bits
 
-    def _orth_decode(self, BitArray bits, n):
-        orths = numpy.ndarray(shape=(n,), dtype=numpy.int32)
-        self.orth_codec.decode_int32(bits, orths)
-        orths_and_spaces = zip(orths, bits)
-        cdef Doc doc = Doc(self.vocab, orths_and_spaces)
-        return doc
-
     def _char_encode(self, Doc doc):
         cdef bytes utf8_str = doc.string.encode('utf8')
         cdef BitArray bits = BitArray()
@@ -164,12 +159,24 @@ cdef class Packer:
             bits.append(False)
         return bits
 
-    def _char_decode(self, BitArray bits, n):
+    def _orth_decode(self, BitArray bits, int32_t n, Doc doc):
+        cdef attr_t[:] orths = numpy.ndarray(shape=(n,), dtype=numpy.int32)
+        self.orth_codec.decode_int32(bits, orths)
+        cdef int i
+        cdef bint space
+        spaces = iter(bits)
+        for i in range(n):
+            orth = orths[i]
+            space = spaces.next()
+            lex = self.vocab.get_by_orth(doc.mem, orth)
+            doc.push_back(lex, space)
+        return doc
+
+    def _char_decode(self, BitArray bits, int32_t n, Doc doc):
         cdef bytearray utf8_str = bytearray(n)
         self.char_codec.decode(bits, utf8_str)
         cdef unicode string = utf8_str.decode('utf8')
-        cdef Doc tokens = Doc(self.vocab)
         cdef int start = 0
         cdef bint is_spacy
         cdef int length = len(string)
@@ -178,11 +185,11 @@ cdef class Packer:
         for is_end_token in bits:
             if is_end_token:
                 span = string[start:i+1]
-                lex = self.vocab.get(tokens.mem, span)
+                lex = self.vocab.get(doc.mem, span)
                 is_spacy = (i+1) < length and string[i+1] == u' '
-                tokens.push_back(lex, is_spacy)
+                doc.push_back(lex, is_spacy)
                 start = i + 1 + is_spacy
             i += 1
             if i >= n:
                 break
-        return tokens
+        return doc
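
After this change a caller works with plain byte strings rather than BitArray objects: pack() returns bytes, unpack() accepts bytes and builds a fresh Doc, and unpack_into() decodes into an existing Doc. A minimal usage sketch, assuming a loaded Vocab, its attribute-frequency table, and a parsed Doc already exist; the variable names and import paths below are illustrative, not taken from this commit:

    # Sketch only: `vocab`, `attr_freqs`, and `doc` are assumed to be set up elsewhere;
    # module paths are inferred from the relative imports shown in the diff above.
    from spacy.serialize.packer import Packer
    from spacy.tokens.doc import Doc

    packer = Packer(vocab, attr_freqs)            # constructor signature unchanged
    byte_string = packer.pack(doc)                # now returns a bytes object, not a BitArray
    assert isinstance(byte_string, bytes)

    new_doc = packer.unpack(byte_string)          # builds and returns a fresh Doc
    packer.unpack_into(byte_string, Doc(vocab))   # or decode into an existing Doc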