diff --git a/spacy/serialize.pyx b/spacy/serialize.pyx
index 15053b4d4..3e78bb59e 100644
--- a/spacy/serialize.pyx
+++ b/spacy/serialize.pyx
@@ -51,13 +51,13 @@ cdef class BitArray:
         start_byte = self.i // 8
         if (self.i % 8) != 0:
             for i in range(self.i % 8):
-                yield (self.data[start_byte] & (one << i))
+                yield 1 if (self.data[start_byte] & (one << i)) else 0
             start_byte += 1
         for byte in self.data[start_byte:]:
             for i in range(8):
-                yield byte & (one << i)
+                yield 1 if byte & (one << i) else 0
         for i in range(self.bit_of_byte):
-            yield self.byte & (one << i)
+            yield 1 if self.byte & (one << i) else 0
 
     def as_bytes(self):
         if self.bit_of_byte != 0:
@@ -128,9 +128,9 @@ cdef class HuffmanCodec:
         bits.extend(self.codes[self.eol].bits, self.codes[self.eol].length)
         return bits
 
-    def decode(self, BitArray bits):
+    def decode(self, bits):
         node = self.nodes.back()
         symbols = []
         for bit in bits:
             branch = node.right if bit else node.left
             if branch >= 0:
diff --git a/spacy/tokens.pyx b/spacy/tokens.pyx
index c15b92366..9e18d058f 100644
--- a/spacy/tokens.pyx
+++ b/spacy/tokens.pyx
@@ -16,6 +16,8 @@ from .lexeme cimport check_flag
 from .spans import Span
 from .structs cimport UniStr
 
+from .serialize import BitArray
+
 from unidecode import unidecode
 # Compiler crashes on memory view coercion without this. Should report bug.
 from cython.view cimport array as cvarray
@@ -373,12 +375,55 @@ cdef class Doc:
         # Return the merged Python object
         return self[start]
 
+    def _has_trailing_space(self, int i):
+        cdef int end_idx = self.data[i].idx + self.data[i].lex.length
+        if end_idx >= len(self._string):
+            return False
+        else:
+            return self._string[end_idx] == u' '
+
+    def serialize(self, bits=None):
+        if bits is None:
+            bits = BitArray()
+        codec = self.vocab.codec
+        ids = numpy.zeros(shape=(len(self),), dtype=numpy.uint32)
+        cdef int i
+        for i in range(self.length):
+            ids[i] = self.data[i].lex.id
+        bits = codec.encode(ids, bits=bits)
+        for i in range(self.length):
+            bits.append(self._has_trailing_space(i))
+        return bits
+
+    @staticmethod
+    def deserialize(Vocab vocab, bits):
+        biterator = iter(bits)
+        ids = vocab.codec.decode(biterator)
+        spaces = []
+        for bit in biterator:
+            spaces.append(bit)
+            if len(spaces) == len(ids):
+                break
+        string = u''
+        cdef const LexemeC* lex
+        for id_, space in zip(ids, spaces):
+            lex = vocab.lexemes[id_]
+            string += vocab.strings[lex.orth]
+            if space:
+                string += u' '
+        cdef Doc doc = Doc(vocab, string)
+        cdef int idx = 0
+        for i, id_ in enumerate(ids):
+            doc.push_back(idx, vocab.lexemes[id_])
+            idx += vocab.lexemes[id_].length
+            if spaces[i]:
+                idx += 1
+        return doc
 
 # Enhance backwards compatibility by aliasing Doc to Tokens, for now
 Tokens = Doc
 
-
 cdef class Token:
     """An individual token --- i.e. a word, a punctuation symbol, etc.
 
     Created via Doc.__getitem__ and Doc.__iter__.
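
# --- Aside: how the two-phase read in Doc.deserialize works. serialize()
# writes the Huffman-coded lexeme ids first, then one trailing-space bit per
# token, onto the same BitArray. decode()'s signature is loosened above from
# `BitArray bits` to plain `bits` precisely so both phases can share a single
# iterator: the codec consumes bits only up to its end-of-list symbol, leaving
# the space flags behind. A minimal pure-Python sketch of that protocol, with
# a toy codebook and hypothetical names (not the Cython API):

def toy_decode(biterator, codebook, eol):
    # Prefix-free decode until the EOL symbol, like HuffmanCodec.decode.
    symbols, code = [], ''
    for bit in biterator:            # consumes only the bits it needs
        code += '1' if bit else '0'  # relies on __iter__ yielding clean 0/1
        if code in codebook:
            if codebook[code] == eol:
                return symbols
            symbols.append(codebook[code])
            code = ''
    raise ValueError('ran out of bits before the EOL symbol')

codebook = {'0': 'the', '10': 'dog', '110': 'EOL'}
# the(0) dog(10) the(0) EOL(110), then one space bit per decoded token
bits = iter([0, 1, 0, 0, 1, 1, 0, 1, 1, 0])
ids = toy_decode(bits, codebook, 'EOL')
spaces = [next(bits) for _ in ids]   # leftover bits are the space flags
assert ids == ['the', 'dog', 'the']
assert spaces == [1, 1, 0]
# --- End aside.
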
@@ -412,6 +457,10 @@ cdef class Token:
                           self.c, self.i, self.array_len,
                           self._seq)
 
+    property lex_id:
+        def __get__(self):
+            return self.c.lex.id
+
     property string:
         def __get__(self):
             if (self.i+1) == self._seq.length:
diff --git a/tests/tokens/test_tokens_api.py b/tests/tokens/test_tokens_api.py
index 48fba66fc..85a3d93d6 100644
--- a/tests/tokens/test_tokens_api.py
+++ b/tests/tokens/test_tokens_api.py
@@ -1,5 +1,7 @@
 from __future__ import unicode_literals
 
+from spacy.tokens import Doc
+
 import pytest
 
 
@@ -9,3 +11,26 @@ def test_getitem(EN):
     assert tokens[-1].orth_ == '.'
     with pytest.raises(IndexError):
         tokens[len(tokens)]
+
+
+def test_trailing_spaces(EN):
+    tokens = EN(u' Give it back! He pleaded. ')
+    assert tokens[0].orth_ == ' '
+    assert not tokens._has_trailing_space(0)
+    assert tokens._has_trailing_space(1)
+    assert tokens._has_trailing_space(2)
+    assert not tokens._has_trailing_space(3)
+    assert tokens._has_trailing_space(4)
+    assert tokens._has_trailing_space(5)
+    assert not tokens._has_trailing_space(6)
+    assert tokens._has_trailing_space(7)
+
+
+def test_serialize(EN):
+    tokens = EN(u' Give it back! He pleaded. ')
+    packed = tokens.serialize()
+    new_tokens = Doc.deserialize(EN.vocab, packed)
+    assert tokens.string == new_tokens.string
+    assert [t.orth_ for t in tokens] == [t.orth_ for t in new_tokens]
+    assert [t.orth for t in tokens] == [t.orth for t in new_tokens]
+    assert [tokens._has_trailing_space(t.i) for t in tokens] == [new_tokens._has_trailing_space(t.i) for t in new_tokens]
diff --git a/tests/vocab/test_huffman.py b/tests/vocab/test_huffman.py
index 4c386af71..124431a66 100644
--- a/tests/vocab/test_huffman.py
+++ b/tests/vocab/test_huffman.py
@@ -26,7 +26,8 @@ class Vocab(object):
         return self.codec.encode(numpy.array(seq, dtype=numpy.uint32))
 
     def unpack(self, packed):
-        return [self.symbols[i] for i in self.codec.decode(packed)]
+        ids = self.codec.decode(packed)
+        return [self.symbols[i] for i in ids]
 
 
 def py_encode(symb2freq):
@@ -75,12 +76,9 @@ def test_round_trip():
     message = ['the', 'quick', 'brown', 'fox', 'jumped', 'over', 'the',
                'the', 'lazy', 'dog', '.']
     strings = list(vocab.codec.strings)
-    for i in range(len(vocab.symbols)):
-        print vocab.symbols[i], strings[i]
     codes = {vocab.symbols[i]: strings[i] for i in range(len(vocab.symbols))}
     packed = vocab.pack(message)
-    string = b''.join(b'{0:b}'.format(ord(c)).rjust(8, b'0')[::-1] for c in packed)
-    print string
+    string = b''.join(b'{0:b}'.format(ord(c)).rjust(8, b'0')[::-1] for c in packed.as_bytes())
     for word in message:
         code = codes[word]
         assert string[:len(code)] == code
@@ -115,16 +113,10 @@ def test_rosetta():
 
 
 def test_vocab(EN):
-    probs = numpy.ndarray(shape=(len(EN.vocab), 2), dtype=numpy.float32)
-    for word in EN.vocab:
-        probs[word.id, 0] = numpy.exp(word.prob)
-        probs[word.id, 1] = word.id
-    probs.sort()
-    probs[:,::-1]
-    codec = HuffmanCodec(probs[:, 0], 0)
+    codec = EN.vocab.codec
     expected_length = 0
     for i, code in enumerate(codec.strings):
-        expected_length += len(code) * probs[i, 0]
+        expected_length += len(code) * numpy.exp(EN.vocab[i].prob)
     assert 8 < expected_length < 15
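
# --- Aside: end-to-end usage of the round trip these tests exercise. A sketch
# assuming an English() pipeline instance, just like the EN fixture above;
# serialize/deserialize/as_bytes and the lex_id property come from this diff.

from spacy.en import English
from spacy.tokens import Doc

nlp = English()
doc = nlp(u' Give it back! He pleaded. ')
packed = doc.serialize()                   # BitArray: Huffman ids + space bits
data = packed.as_bytes()                   # compact byte string for storage
doc2 = Doc.deserialize(nlp.vocab, packed)
assert doc.string == doc2.string
assert [t.lex_id for t in doc] == [t.lex_id for t in doc2]
# Per test_vocab, the codec should average between 8 and 15 bits per token,
# i.e. roughly one to two bytes of `data` per word.
# --- End aside.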