mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 01:04:34 +03:00
* Move serialization into Serializer class, with __call__ and train() api
This commit is contained in:
parent
e2133d990e
commit
6c99e5f4aa
|
@ -16,8 +16,13 @@ cdef struct Code:
|
||||||
char length
|
char length
|
||||||
|
|
||||||
|
|
||||||
|
cdef class Serializer:
|
||||||
|
cdef list codecs
|
||||||
|
|
||||||
|
|
||||||
cdef class HuffmanCodec:
|
cdef class HuffmanCodec:
|
||||||
cdef vector[Node] nodes
|
cdef vector[Node] nodes
|
||||||
cdef vector[Code] codes
|
cdef vector[Code] codes
|
||||||
cdef uint32_t eol
|
cdef uint32_t eol
|
||||||
|
cdef int id
|
||||||
|
|
||||||
|
|
|
@ -93,6 +93,84 @@ cdef class BitArray:
|
||||||
self.bit_of_byte = 0
|
self.bit_of_byte = 0
|
||||||
|
|
||||||
|
|
||||||
|
cdef class Serializer:
|
||||||
|
# Manage codecs, maintain consistent format for io
|
||||||
|
def __init__(self, Vocab vocab, model_dir):
|
||||||
|
self.vocab = vocab
|
||||||
|
self.lex = None
|
||||||
|
self.codecs = []
|
||||||
|
|
||||||
|
def __call__(self, doc_or_bits):
|
||||||
|
if isinstance(doc_or_bits, Doc):
|
||||||
|
return self.serialize(doc_or_bits)
|
||||||
|
elif isinstance(doc_or_bits, BitArray):
|
||||||
|
return self.deserialize(doc_or_bits)
|
||||||
|
else:
|
||||||
|
raise ValueError(doc_or_bits)
|
||||||
|
|
||||||
|
def train(self, doc):
|
||||||
|
array = doc.to_array(self.attrs)
|
||||||
|
for i, attr in enumerate(self.attrs):
|
||||||
|
for j in range(doc.length):
|
||||||
|
self.freqs[attr].inc(array[i, j], 1)
|
||||||
|
self.freqs[attr].inc(self.eol, 1)
|
||||||
|
|
||||||
|
def serialize(self, doc):
|
||||||
|
bits = BitArray()
|
||||||
|
array = doc.to_array(self.attrs)
|
||||||
|
for i, attr in enumerate(self.attrs, self.codecs):
|
||||||
|
codec.encode(array[i,], bits)
|
||||||
|
return bits
|
||||||
|
|
||||||
|
@cython.boundscheck(False)
|
||||||
|
def deserialize(self, bits):
|
||||||
|
cdef Doc doc = Doc(self.vocab)
|
||||||
|
biterator = iter(bits)
|
||||||
|
ids = self.codecs[0].decode(bits)
|
||||||
|
cdef int id_
|
||||||
|
cdef bint is_spacy
|
||||||
|
for id_ in ids:
|
||||||
|
is_spacy = biterator.next()
|
||||||
|
doc.push_back(vocab.lexemes.at(id_), is_spacy)
|
||||||
|
|
||||||
|
cdef int length = doc.length
|
||||||
|
cdef int i
|
||||||
|
cdef attr_t value
|
||||||
|
cdef attr_id_t attr_id
|
||||||
|
cdef attr_t[:] values
|
||||||
|
cdef TokenC* tokens = doc.data
|
||||||
|
for codec in vocab.codecs[1:]:
|
||||||
|
values = codec.decode(biterator)
|
||||||
|
attr_id = codec.id
|
||||||
|
if attr_id == HEAD:
|
||||||
|
for i in range(length):
|
||||||
|
tokens[i].head = values[i]
|
||||||
|
elif attr_id == TAG:
|
||||||
|
for i in range(length):
|
||||||
|
tokens[i].tag = values[i]
|
||||||
|
elif attr_id == DEP:
|
||||||
|
for i in range(length):
|
||||||
|
tokens[i].dep = values[i]
|
||||||
|
elif attr_id == ENT_IOB:
|
||||||
|
for i in range(length):
|
||||||
|
tokens[i].ent_iob = values[i]
|
||||||
|
elif attr_id == ENT_TYPE:
|
||||||
|
for i in range(length):
|
||||||
|
tokens[i].ent_type = values[i]
|
||||||
|
return doc
|
||||||
|
|
||||||
|
def lex_codec(self):
|
||||||
|
cdef Address mem
|
||||||
|
cdef int i
|
||||||
|
cdef float[:] cv_probs
|
||||||
|
mem = Address(len(self), sizeof(float))
|
||||||
|
probs = <float*>mem.ptr
|
||||||
|
for i in range(len(self.vocab)):
|
||||||
|
probs[i] = <float>c_exp(self.lexemes[i].prob)
|
||||||
|
cv_probs = <float[:len(self)]>probs
|
||||||
|
return HuffmanCodec(cv_probs, 0, id=ID)
|
||||||
|
|
||||||
|
|
||||||
cdef class HuffmanCodec:
|
cdef class HuffmanCodec:
|
||||||
"""Create a Huffman code table, and use it to pack and unpack sequences into
|
"""Create a Huffman code table, and use it to pack and unpack sequences into
|
||||||
byte strings. Emphasis is on efficiency, so API is quite strict:
|
byte strings. Emphasis is on efficiency, so API is quite strict:
|
||||||
|
@ -109,7 +187,8 @@ cdef class HuffmanCodec:
|
||||||
|
|
||||||
eol (uint32_t): The index of the weight of the EOL symbol.
|
eol (uint32_t): The index of the weight of the EOL symbol.
|
||||||
"""
|
"""
|
||||||
def __init__(self, float[:] probs, uint32_t eol):
|
def __init__(self, float[:] probs, uint32_t eol, id=0):
|
||||||
|
self.id = id
|
||||||
self.eol = eol
|
self.eol = eol
|
||||||
self.codes.resize(len(probs))
|
self.codes.resize(len(probs))
|
||||||
for i in range(len(self.codes)):
|
for i in range(len(self.codes)):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user