mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
* Work on serializer design
This commit is contained in:
parent
a6f401580d
commit
fb54052ae0
|
@ -67,7 +67,6 @@ cdef class BitArray:
|
|||
|
||||
def append(self, bint bit):
|
||||
cdef uint64_t one = 1
|
||||
print 'append', bit
|
||||
if bit:
|
||||
self.byte |= one << self.bit_of_byte
|
||||
else:
|
||||
|
@ -95,10 +94,11 @@ cdef class BitArray:
|
|||
|
||||
cdef class Serializer:
|
||||
# Manage codecs, maintain consistent format for io
|
||||
def __init__(self, Vocab vocab, model_dir):
|
||||
self.vocab = vocab
|
||||
self.lex = None
|
||||
self.codecs = []
|
||||
def __init__(self, Vocab vocab, data_dir):
|
||||
model_dir = path.join(data_dir, 'bitter')
|
||||
self.vocab = vocab # Vocab owns the word codec, the big one
|
||||
self.cfg = Config.read(model_dir, 'config')
|
||||
self.codecs = tuple([CodecWrapper(attr) for attr in self.cfg.attrs])
|
||||
|
||||
def __call__(self, doc_or_bits):
|
||||
if isinstance(doc_or_bits, Doc):
|
||||
|
@ -109,24 +109,22 @@ cdef class Serializer:
|
|||
raise ValueError(doc_or_bits)
|
||||
|
||||
def train(self, doc):
|
||||
array = doc.to_array(self.attrs)
|
||||
for i, attr in enumerate(self.attrs):
|
||||
for j in range(doc.length):
|
||||
self.freqs[attr].inc(array[i, j], 1)
|
||||
self.freqs[attr].inc(self.eol, 1)
|
||||
array = doc.to_array([codec.id for codec in self.codecs])
|
||||
for i, codec in enumerate(self.codecs):
|
||||
codec.count(array[i])
|
||||
|
||||
def serialize(self, doc):
|
||||
bits = BitArray()
|
||||
array = doc.to_array(self.attrs)
|
||||
for i, attr in enumerate(self.attrs, self.codecs):
|
||||
for i, codec in enumerate(self.codecs):
|
||||
codec.encode(array[i,], bits)
|
||||
return bits
|
||||
|
||||
@cython.boundscheck(False)
|
||||
def deserialize(self, bits):
|
||||
cdef Doc doc = Doc(self.vocab)
|
||||
biterator = iter(bits)
|
||||
ids = self.codecs[0].decode(bits)
|
||||
cdef Doc doc = Doc(self.vocab)
|
||||
ids = self.vocab.codec.decode(biterator)
|
||||
cdef int id_
|
||||
cdef bint is_spacy
|
||||
for id_ in ids:
|
||||
|
@ -134,41 +132,44 @@ cdef class Serializer:
|
|||
doc.push_back(vocab.lexemes.at(id_), is_spacy)
|
||||
|
||||
cdef int length = doc.length
|
||||
cdef int i
|
||||
cdef attr_t value
|
||||
cdef attr_id_t attr_id
|
||||
cdef attr_t[:] values
|
||||
cdef TokenC* tokens = doc.data
|
||||
for codec in vocab.codecs[1:]:
|
||||
values = codec.decode(biterator)
|
||||
attr_id = codec.id
|
||||
if attr_id == HEAD:
|
||||
for i in range(length):
|
||||
tokens[i].head = values[i]
|
||||
elif attr_id == TAG:
|
||||
for i in range(length):
|
||||
tokens[i].tag = values[i]
|
||||
elif attr_id == DEP:
|
||||
for i in range(length):
|
||||
tokens[i].dep = values[i]
|
||||
elif attr_id == ENT_IOB:
|
||||
for i in range(length):
|
||||
tokens[i].ent_iob = values[i]
|
||||
elif attr_id == ENT_TYPE:
|
||||
for i in range(length):
|
||||
tokens[i].ent_type = values[i]
|
||||
array = numpy.zeros(shape=(length, len(self.codecs)), dtype=numpy.int)
|
||||
for i, codec in enumerate(self.codecs):
|
||||
array[i] = codec.decode(biterator)
|
||||
doc.from_array([c.id for c in self.codecs], array)
|
||||
return doc
|
||||
|
||||
def lex_codec(self):
|
||||
cdef Address mem
|
||||
cdef int i
|
||||
cdef float[:] cv_probs
|
||||
mem = Address(len(self), sizeof(float))
|
||||
probs = <float*>mem.ptr
|
||||
for i in range(len(self.vocab)):
|
||||
probs[i] = <float>c_exp(self.lexemes[i].prob)
|
||||
cv_probs = <float[:len(self)]>probs
|
||||
return HuffmanCodec(cv_probs, 0, id=ID)
|
||||
|
||||
cdef class AttributeEncoder:
|
||||
"""Wrapper around HuffmanCodec"""
|
||||
def __init__(self, freqs, id=0):
|
||||
cdef uint64_t key
|
||||
cdef uint64_t count
|
||||
cdef pair[uint64_t] item
|
||||
cdef priority_queue[pair[uint64_t]] items
|
||||
for key, count in freqs:
|
||||
item.first = count
|
||||
item.second = key
|
||||
items.push(item)
|
||||
|
||||
weights = array('f')
|
||||
keys = array('i')
|
||||
key_to_i = PreshMap()
|
||||
i = 0
|
||||
while not items.empty():
|
||||
item = items.top()
|
||||
weights.append(item.first)
|
||||
keys.append(item.second)
|
||||
key_to_i[item.second] = i
|
||||
i += 1
|
||||
items.pop()
|
||||
|
||||
def encode(self, symbols):
|
||||
indices = [self.table[symbol] for symbol in symbols]
|
||||
return self._codec.encode(indices)
|
||||
|
||||
def decode(self, bits):
|
||||
indices = self._codec.decode(bits)
|
||||
return [self.symbols[i] for i in indices]
|
||||
|
||||
|
||||
cdef class HuffmanCodec:
|
||||
|
@ -182,19 +183,17 @@ cdef class HuffmanCodec:
|
|||
the EOL symbol in your message.
|
||||
|
||||
Arguments:
|
||||
probs (float[:]): A descending-sorted sequence of probabilities/weights.
|
||||
weights (float[:]): A descending-sorted sequence of probabilities/weights.
|
||||
Must include a weight for an EOL symbol.
|
||||
|
||||
eol (uint32_t): The index of the weight of the EOL symbol.
|
||||
"""
|
||||
def __init__(self, float[:] probs, uint32_t eol, id=0):
|
||||
self.id = id
|
||||
self.eol = eol
|
||||
def __init__(self, float[:] weights, unt32_t eol):
|
||||
self.codes.resize(len(probs))
|
||||
for i in range(len(self.codes)):
|
||||
self.codes[i].bits = 0
|
||||
self.codes[i].length = 0
|
||||
populate_nodes(self.nodes, probs)
|
||||
populate_nodes(self.nodes, weights)
|
||||
cdef Code path
|
||||
path.bits = 0
|
||||
path.length = 0
|
||||
|
@ -270,6 +269,7 @@ cdef int populate_nodes(vector[Node]& nodes, float[:] probs) except -1:
|
|||
return 0
|
||||
|
||||
cdef int _cover_two_nodes(vector[Node]& nodes, int j) nogil:
|
||||
"""Introduce a new non-terminal, over two non-terminals)"""
|
||||
cdef Node node
|
||||
node.left = j
|
||||
node.right = j+1
|
||||
|
@ -278,6 +278,7 @@ cdef int _cover_two_nodes(vector[Node]& nodes, int j) nogil:
|
|||
|
||||
|
||||
cdef int _cover_one_word_one_node(vector[Node]& nodes, int j, int id_, float prob) nogil:
|
||||
"""Introduce a new non-terminal, over one terminal and one non-terminal."""
|
||||
cdef Node node
|
||||
# Encode leaves as negative integers, where the integer is the index of the
|
||||
# word in the vocabulary.
|
||||
|
@ -295,6 +296,7 @@ cdef int _cover_one_word_one_node(vector[Node]& nodes, int j, int id_, float pro
|
|||
|
||||
|
||||
cdef int _cover_two_words(vector[Node]& nodes, int id1, int id2, float prob) nogil:
|
||||
"""Introduce a new node, over two non-terminals."""
|
||||
cdef Node node
|
||||
node.left = -(id1+1)
|
||||
node.right = -(id2+1)
|
||||
|
@ -303,6 +305,11 @@ cdef int _cover_two_words(vector[Node]& nodes, int id1, int id2, float prob) nog
|
|||
|
||||
|
||||
cdef int assign_codes(vector[Node]& nodes, vector[Code]& codes, int i, Code path) except -1:
|
||||
"""Recursively assign paths, from the top down. At the end, the entry codes[i]
|
||||
knows the bit-address of the node[j] that points to entry i in the vocabulary.
|
||||
So, to encode i, we go to codes[i] and read its bit-string. To decode, we
|
||||
navigate nodes recursively.
|
||||
"""
|
||||
cdef Code left_path = bit_append(path, 0)
|
||||
cdef Code right_path = bit_append(path, 1)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user