* Round-trip for serialization finally working. Needs a lot of optimization.

Matthew Honnibal 2015-07-13 18:39:38 +02:00
parent edd371246c
commit 5b0a7190c9
4 changed files with 86 additions and 19 deletions


@@ -51,13 +51,13 @@ cdef class BitArray:
         start_byte = self.i // 8
         if (self.i % 8) != 0:
             for i in range(self.i % 8):
-                yield (self.data[start_byte] & (one << i))
+                yield 1 if (self.data[start_byte] & (one << i)) else 0
             start_byte += 1
         for byte in self.data[start_byte:]:
             for i in range(8):
-                yield byte & (one << i)
+                yield 1 if byte & (one << i) else 0
         for i in range(self.bit_of_byte):
-            yield self.byte & (one << i)
+            yield 1 if self.byte & (one << i) else 0
 
     def as_bytes(self):
         if self.bit_of_byte != 0:
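These three changes normalize the iterator's output: `byte & (one << i)` evaluates to the raw mask value (e.g. 4 or 128), not a 0/1 bit, so anything that stores or compares the yielded values, as the new deserialization code does, would see arbitrary integers. A minimal pure-Python sketch of the same LSB-first traversal:

    def iter_bits(data):
        # Iterate a byte string bit by bit, least significant bit first,
        # normalized to 0/1 as in BitArray.__iter__ after this change.
        for byte in bytearray(data):
            for i in range(8):
                yield 1 if byte & (1 << i) else 0

    assert list(iter_bits(b'\x05'))[:4] == [1, 0, 1, 0]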
@@ -67,6 +67,7 @@ cdef class BitArray:
     def append(self, bint bit):
         cdef uint64_t one = 1
+        print 'append', bit
         if bit:
             self.byte |= one << self.bit_of_byte
         else:
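`append` packs bits LSB-first into a staging byte (`self.byte`), mirroring the iteration order above; the added `print` is a leftover debug trace. The diff cuts off before the counter handling, so here is a hedged pure-Python sketch of the presumed flush logic:

    class Bits(object):
        # Field names mirror the Cython class; the flush step below is an
        # assumption, since the diff does not show it.
        def __init__(self):
            self.data = bytearray()
            self.byte = 0
            self.bit_of_byte = 0

        def append(self, bit):
            if bit:
                self.byte |= 1 << self.bit_of_byte
            self.bit_of_byte += 1
            if self.bit_of_byte == 8:  # staging byte full: flush it
                self.data.append(self.byte)
                self.byte = 0
                self.bit_of_byte = 0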
@@ -128,9 +129,9 @@ cdef class HuffmanCodec:
         bits.extend(self.codes[self.eol].bits, self.codes[self.eol].length)
         return bits
 
-    def decode(self, BitArray bits):
+    def decode(self, bits):
         node = self.nodes.back()
         symbols = []
         for bit in bits:
             branch = node.right if bit else node.left
             if branch >= 0:
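Dropping the `BitArray` type from `decode` is what makes the round trip work: `Doc.deserialize` passes a plain iterator over the bits, the codec consumes only as many bits as the message needs, and the caller reads the trailing-space flags from the same iterator afterwards. A hypothetical sketch of the table walk, assuming non-negative branch values index child nodes and negative values encode leaves as -(symbol + 1); the diff truncates before the leaf case, and the real decoder presumably stops at the codec's EOL symbol, while this sketch takes an explicit count:

    def huffman_decode(nodes, bits, n_symbols):
        # nodes: list of (left, right) pairs, root stored last.
        node = nodes[-1]
        symbols = []
        for bit in bits:
            branch = node[1] if bit else node[0]
            if branch >= 0:
                node = nodes[branch]           # internal node: keep walking
            else:
                symbols.append(-(branch + 1))  # assumed leaf encoding
                if len(symbols) == n_symbols:
                    break
                node = nodes[-1]               # restart at the root
        return symbols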


@@ -16,6 +16,8 @@ from .lexeme cimport check_flag
 from .spans import Span
 from .structs cimport UniStr
+from .serialize import BitArray
+
 from unidecode import unidecode
 
 # Compiler crashes on memory view coercion without this. Should report bug.
 from cython.view cimport array as cvarray
@@ -373,12 +375,55 @@ cdef class Doc:
         # Return the merged Python object
         return self[start]
 
+    def _has_trailing_space(self, int i):
+        cdef int end_idx = self.data[i].idx + self.data[i].lex.length
+        if end_idx >= len(self._string):
+            return False
+        else:
+            return self._string[end_idx] == u' '
+
+    def serialize(self, bits=None):
+        if bits is None:
+            bits = BitArray()
+        codec = self.vocab.codec
+        ids = numpy.zeros(shape=(len(self),), dtype=numpy.uint32)
+        cdef int i
+        for i in range(self.length):
+            ids[i] = self.data[i].lex.id
+        bits = codec.encode(ids, bits=bits)
+        for i in range(self.length):
+            bits.append(self._has_trailing_space(i))
+        return bits
+
+    @staticmethod
+    def deserialize(Vocab vocab, bits):
+        biterator = iter(bits)
+        ids = vocab.codec.decode(biterator)
+        spaces = []
+        for bit in biterator:
+            spaces.append(bit)
+            if len(spaces) == len(ids):
+                break
+        string = u''
+        cdef const LexemeC* lex
+        for id_, space in zip(ids, spaces):
+            lex = vocab.lexemes[id_]
+            string += vocab.strings[lex.orth]
+            if space:
+                string += u' '
+        cdef Doc doc = Doc(vocab, string)
+        cdef int idx = 0
+        for i, id_ in enumerate(ids):
+            doc.push_back(idx, vocab.lexemes[id_])
+            idx += vocab.lexemes[id_].length
+            if spaces[i]:
+                idx += 1
+        return doc
+
 
 # Enhance backwards compatibility by aliasing Doc to Tokens, for now
 Tokens = Doc
 
 cdef class Token:
     """An individual token --- i.e. a word, a punctuation symbol, etc. Created
     via Doc.__getitem__ and Doc.__iter__.
@@ -412,6 +457,10 @@ cdef class Token:
                           self.c, self.i, self.array_len,
                           self._seq)
 
+    property lex_id:
+        def __get__(self):
+            return self.c.lex.id
+
     property string:
         def __get__(self):
             if (self.i+1) == self._seq.length:
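Taken together, the wire format is: the Huffman-coded lexeme ids, terminated by the codec's EOL symbol (see the `bits.extend(self.codes[self.eol]...)` line in the codec diff above), followed by one raw bit per token flagging a single trailing space. `deserialize` decodes the ids from a shared bit iterator, reads exactly `len(ids)` space bits from the remainder, rebuilds the text, and pushes the lexemes back into a fresh `Doc`. A usage sketch mirroring `test_serialize` in the next file, where `EN` is the English pipeline fixture:

    tokens = EN(u'Give it back! He pleaded.')
    packed = tokens.serialize()                # BitArray: ids, then space bits
    new_tokens = Doc.deserialize(EN.vocab, packed)
    assert tokens.string == new_tokens.string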


@@ -1,5 +1,7 @@
 from __future__ import unicode_literals
 
+from spacy.tokens import Doc
+
 import pytest
@@ -9,3 +11,26 @@ def test_getitem(EN):
     assert tokens[-1].orth_ == '.'
     with pytest.raises(IndexError):
         tokens[len(tokens)]
+
+
+def test_trailing_spaces(EN):
+    tokens = EN(u' Give it back! He pleaded. ')
+    assert tokens[0].orth_ == ' '
+    assert not tokens._has_trailing_space(0)
+    assert tokens._has_trailing_space(1)
+    assert tokens._has_trailing_space(2)
+    assert not tokens._has_trailing_space(3)
+    assert tokens._has_trailing_space(4)
+    assert tokens._has_trailing_space(5)
+    assert not tokens._has_trailing_space(6)
+    assert tokens._has_trailing_space(7)
+
+
+def test_serialize(EN):
+    tokens = EN(u' Give it back! He pleaded. ')
+    packed = tokens.serialize()
+    new_tokens = Doc.deserialize(EN.vocab, packed)
+    assert tokens.string == new_tokens.string
+    assert [t.orth_ for t in tokens] == [t.orth_ for t in new_tokens]
+    assert [t.orth for t in tokens] == [t.orth for t in new_tokens]
+    assert [tokens._has_trailing_space(t.i) for t in tokens] == [new_tokens._has_trailing_space(t.i) for t in new_tokens]


@@ -26,7 +26,8 @@ class Vocab(object):
         return self.codec.encode(numpy.array(seq, dtype=numpy.uint32))
 
     def unpack(self, packed):
-        return [self.symbols[i] for i in self.codec.decode(packed)]
+        ids = self.codec.decode(packed)
+        return [self.symbols[i] for i in ids]
 
 
 def py_encode(symb2freq):
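Splitting `unpack` into two statements keeps behavior identical and presumably just made the decode step easier to inspect while debugging the round trip. A sketch of the round trip this test helper supports, where `vocab` is the fixture built from symbol frequencies elsewhere in this file:

    message = ['the', 'quick', 'brown', 'fox', '.']
    packed = vocab.pack(message)              # BitArray of Huffman codes
    assert vocab.unpack(packed) == message    # decode back to the symbols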
@@ -75,12 +76,9 @@ def test_round_trip():
     message = ['the', 'quick', 'brown', 'fox', 'jumped', 'over', 'the',
                'the', 'lazy', 'dog', '.']
     strings = list(vocab.codec.strings)
-    for i in range(len(vocab.symbols)):
-        print vocab.symbols[i], strings[i]
+    codes = {vocab.symbols[i]: strings[i] for i in range(len(vocab.symbols))}
     packed = vocab.pack(message)
-    string = b''.join(b'{0:b}'.format(ord(c)).rjust(8, b'0')[::-1] for c in packed)
-    print string
+    string = b''.join(b'{0:b}'.format(ord(c)).rjust(8, b'0')[::-1] for c in packed.as_bytes())
     for word in message:
         code = codes[word]
         assert string[:len(code)] == code
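The `b'{0:b}'.format(ord(c)).rjust(8, b'0')[::-1]` expression renders each packed byte as its eight bits, least significant bit first, matching the order in which `BitArray.append` fills bytes, so the concatenated bit string can be compared prefix-by-prefix against the codec's code strings; `pack` now returns a `BitArray`, hence the added `.as_bytes()` call. The transformation in isolation:

    def byte_to_lsb_bits(c):
        # '{0:b}' formats MSB-first; pad to 8 bits, then reverse for LSB-first.
        return '{0:b}'.format(ord(c)).rjust(8, '0')[::-1]

    assert byte_to_lsb_bits(b'\x05') == '10100000'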
@@ -115,16 +113,10 @@ def test_rosetta():
 def test_vocab(EN):
-    probs = numpy.ndarray(shape=(len(EN.vocab), 2), dtype=numpy.float32)
-    for word in EN.vocab:
-        probs[word.id, 0] = numpy.exp(word.prob)
-        probs[word.id, 1] = word.id
-    probs.sort()
-    probs[:,::-1]
-    codec = HuffmanCodec(probs[:, 0], 0)
+    codec = EN.vocab.codec
     expected_length = 0
     for i, code in enumerate(codec.strings):
-        expected_length += len(code) * probs[i, 0]
+        expected_length += len(code) * numpy.exp(EN.vocab[i].prob)
     assert 8 < expected_length < 15
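Switching to `EN.vocab.codec` also sidesteps bugs in the deleted setup: `probs.sort()` sorts along the last axis, scrambling each (prob, id) pair, and the bare `probs[:,::-1]` expression built a reversed view and threw it away. The assertion itself bounds the codec's expected code length, the sum over symbols of p(symbol) * len(code), i.e. the average bits per token under the vocabulary's unigram distribution (spaCy stores log probabilities, hence `numpy.exp`). The same quantity on a toy distribution:

    import math

    # Toy dyadic distribution with matching Huffman code lengths.
    probs = [0.5, 0.25, 0.125, 0.125]
    code_lengths = [1, 2, 3, 3]

    expected_length = sum(p * n for p, n in zip(probs, code_lengths))
    entropy = -sum(p * math.log(p, 2) for p in probs)
    # Huffman guarantee: entropy <= expected_length < entropy + 1
    assert entropy <= expected_length < entropy + 1   # both 1.75 bits here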