mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
* Move serialization into Serializer class, with __call__ and train() api
This commit is contained in:
parent
e2133d990e
commit
6c99e5f4aa
|
@ -16,8 +16,13 @@ cdef struct Code:
|
|||
char length
|
||||
|
||||
|
||||
cdef class Serializer:
|
||||
cdef list codecs
|
||||
|
||||
|
||||
cdef class HuffmanCodec:
|
||||
cdef vector[Node] nodes
|
||||
cdef vector[Code] codes
|
||||
cdef uint32_t eol
|
||||
cdef int id
|
||||
|
||||
|
|
|
@ -93,6 +93,84 @@ cdef class BitArray:
|
|||
self.bit_of_byte = 0
|
||||
|
||||
|
||||
cdef class Serializer:
|
||||
# Manage codecs, maintain consistent format for io
|
||||
def __init__(self, Vocab vocab, model_dir):
|
||||
self.vocab = vocab
|
||||
self.lex = None
|
||||
self.codecs = []
|
||||
|
||||
def __call__(self, doc_or_bits):
|
||||
if isinstance(doc_or_bits, Doc):
|
||||
return self.serialize(doc_or_bits)
|
||||
elif isinstance(doc_or_bits, BitArray):
|
||||
return self.deserialize(doc_or_bits)
|
||||
else:
|
||||
raise ValueError(doc_or_bits)
|
||||
|
||||
def train(self, doc):
|
||||
array = doc.to_array(self.attrs)
|
||||
for i, attr in enumerate(self.attrs):
|
||||
for j in range(doc.length):
|
||||
self.freqs[attr].inc(array[i, j], 1)
|
||||
self.freqs[attr].inc(self.eol, 1)
|
||||
|
||||
def serialize(self, doc):
|
||||
bits = BitArray()
|
||||
array = doc.to_array(self.attrs)
|
||||
for i, attr in enumerate(self.attrs, self.codecs):
|
||||
codec.encode(array[i,], bits)
|
||||
return bits
|
||||
|
||||
@cython.boundscheck(False)
|
||||
def deserialize(self, bits):
|
||||
cdef Doc doc = Doc(self.vocab)
|
||||
biterator = iter(bits)
|
||||
ids = self.codecs[0].decode(bits)
|
||||
cdef int id_
|
||||
cdef bint is_spacy
|
||||
for id_ in ids:
|
||||
is_spacy = biterator.next()
|
||||
doc.push_back(vocab.lexemes.at(id_), is_spacy)
|
||||
|
||||
cdef int length = doc.length
|
||||
cdef int i
|
||||
cdef attr_t value
|
||||
cdef attr_id_t attr_id
|
||||
cdef attr_t[:] values
|
||||
cdef TokenC* tokens = doc.data
|
||||
for codec in vocab.codecs[1:]:
|
||||
values = codec.decode(biterator)
|
||||
attr_id = codec.id
|
||||
if attr_id == HEAD:
|
||||
for i in range(length):
|
||||
tokens[i].head = values[i]
|
||||
elif attr_id == TAG:
|
||||
for i in range(length):
|
||||
tokens[i].tag = values[i]
|
||||
elif attr_id == DEP:
|
||||
for i in range(length):
|
||||
tokens[i].dep = values[i]
|
||||
elif attr_id == ENT_IOB:
|
||||
for i in range(length):
|
||||
tokens[i].ent_iob = values[i]
|
||||
elif attr_id == ENT_TYPE:
|
||||
for i in range(length):
|
||||
tokens[i].ent_type = values[i]
|
||||
return doc
|
||||
|
||||
def lex_codec(self):
|
||||
cdef Address mem
|
||||
cdef int i
|
||||
cdef float[:] cv_probs
|
||||
mem = Address(len(self), sizeof(float))
|
||||
probs = <float*>mem.ptr
|
||||
for i in range(len(self.vocab)):
|
||||
probs[i] = <float>c_exp(self.lexemes[i].prob)
|
||||
cv_probs = <float[:len(self)]>probs
|
||||
return HuffmanCodec(cv_probs, 0, id=ID)
|
||||
|
||||
|
||||
cdef class HuffmanCodec:
|
||||
"""Create a Huffman code table, and use it to pack and unpack sequences into
|
||||
byte strings. Emphasis is on efficiency, so API is quite strict:
|
||||
|
@ -109,7 +187,8 @@ cdef class HuffmanCodec:
|
|||
|
||||
eol (uint32_t): The index of the weight of the EOL symbol.
|
||||
"""
|
||||
def __init__(self, float[:] probs, uint32_t eol):
|
||||
def __init__(self, float[:] probs, uint32_t eol, id=0):
|
||||
self.id = id
|
||||
self.eol = eol
|
||||
self.codes.resize(len(probs))
|
||||
for i in range(len(self.codes)):
|
||||
|
|
Loading…
Reference in New Issue
Block a user