mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	* Move serialization into Serializer class, with __call__ and train() api
This commit is contained in:
		
							parent
							
								
									e2133d990e
								
							
						
					
					
						commit
						6c99e5f4aa
					
				| 
						 | 
					@ -16,8 +16,13 @@ cdef struct Code:
 | 
				
			||||||
    char length
 | 
					    char length
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cdef class Serializer:
 | 
				
			||||||
 | 
					    cdef list codecs
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef class HuffmanCodec:
 | 
					cdef class HuffmanCodec:
 | 
				
			||||||
    cdef vector[Node] nodes
 | 
					    cdef vector[Node] nodes
 | 
				
			||||||
    cdef vector[Code] codes
 | 
					    cdef vector[Code] codes
 | 
				
			||||||
    cdef uint32_t eol
 | 
					    cdef uint32_t eol
 | 
				
			||||||
 | 
					    cdef int id
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -93,6 +93,84 @@ cdef class BitArray:
 | 
				
			||||||
                self.bit_of_byte = 0
 | 
					                self.bit_of_byte = 0
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					cdef class Serializer:
 | 
				
			||||||
 | 
					    # Manage codecs, maintain consistent format for io
 | 
				
			||||||
 | 
					    def __init__(self, Vocab vocab, model_dir):
 | 
				
			||||||
 | 
					        self.vocab = vocab
 | 
				
			||||||
 | 
					        self.lex = None
 | 
				
			||||||
 | 
					        self.codecs = []
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def __call__(self, doc_or_bits):
 | 
				
			||||||
 | 
					        if isinstance(doc_or_bits, Doc):
 | 
				
			||||||
 | 
					            return self.serialize(doc_or_bits)
 | 
				
			||||||
 | 
					        elif isinstance(doc_or_bits, BitArray):
 | 
				
			||||||
 | 
					            return self.deserialize(doc_or_bits)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            raise ValueError(doc_or_bits)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def train(self, doc):
 | 
				
			||||||
 | 
					        array = doc.to_array(self.attrs)
 | 
				
			||||||
 | 
					        for i, attr in enumerate(self.attrs):
 | 
				
			||||||
 | 
					            for j in range(doc.length):
 | 
				
			||||||
 | 
					                self.freqs[attr].inc(array[i, j], 1)
 | 
				
			||||||
 | 
					            self.freqs[attr].inc(self.eol, 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def serialize(self, doc):
 | 
				
			||||||
 | 
					        bits = BitArray()
 | 
				
			||||||
 | 
					        array = doc.to_array(self.attrs)
 | 
				
			||||||
 | 
					        for i, attr in enumerate(self.attrs, self.codecs):
 | 
				
			||||||
 | 
					            codec.encode(array[i,], bits)
 | 
				
			||||||
 | 
					        return bits
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    @cython.boundscheck(False)
 | 
				
			||||||
 | 
					    def deserialize(self, bits):
 | 
				
			||||||
 | 
					        cdef Doc doc = Doc(self.vocab)
 | 
				
			||||||
 | 
					        biterator = iter(bits)
 | 
				
			||||||
 | 
					        ids = self.codecs[0].decode(bits)
 | 
				
			||||||
 | 
					        cdef int id_
 | 
				
			||||||
 | 
					        cdef bint is_spacy
 | 
				
			||||||
 | 
					        for id_ in ids:
 | 
				
			||||||
 | 
					            is_spacy = biterator.next()
 | 
				
			||||||
 | 
					            doc.push_back(vocab.lexemes.at(id_), is_spacy)
 | 
				
			||||||
 | 
					       
 | 
				
			||||||
 | 
					        cdef int length = doc.length
 | 
				
			||||||
 | 
					        cdef int i
 | 
				
			||||||
 | 
					        cdef attr_t value
 | 
				
			||||||
 | 
					        cdef attr_id_t attr_id
 | 
				
			||||||
 | 
					        cdef attr_t[:] values
 | 
				
			||||||
 | 
					        cdef TokenC* tokens = doc.data
 | 
				
			||||||
 | 
					        for codec in vocab.codecs[1:]:
 | 
				
			||||||
 | 
					            values = codec.decode(biterator)
 | 
				
			||||||
 | 
					            attr_id = codec.id
 | 
				
			||||||
 | 
					            if attr_id == HEAD:
 | 
				
			||||||
 | 
					                for i in range(length):
 | 
				
			||||||
 | 
					                    tokens[i].head = values[i]
 | 
				
			||||||
 | 
					            elif attr_id == TAG:
 | 
				
			||||||
 | 
					                for i in range(length):
 | 
				
			||||||
 | 
					                    tokens[i].tag = values[i]
 | 
				
			||||||
 | 
					            elif attr_id == DEP:
 | 
				
			||||||
 | 
					                for i in range(length):
 | 
				
			||||||
 | 
					                    tokens[i].dep = values[i]
 | 
				
			||||||
 | 
					            elif attr_id == ENT_IOB:
 | 
				
			||||||
 | 
					                for i in range(length):
 | 
				
			||||||
 | 
					                    tokens[i].ent_iob = values[i]
 | 
				
			||||||
 | 
					            elif attr_id == ENT_TYPE:
 | 
				
			||||||
 | 
					                for i in range(length):
 | 
				
			||||||
 | 
					                    tokens[i].ent_type = values[i]
 | 
				
			||||||
 | 
					        return doc
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def lex_codec(self):
 | 
				
			||||||
 | 
					        cdef Address mem
 | 
				
			||||||
 | 
					        cdef int i
 | 
				
			||||||
 | 
					        cdef float[:] cv_probs
 | 
				
			||||||
 | 
					        mem = Address(len(self), sizeof(float))
 | 
				
			||||||
 | 
					        probs = <float*>mem.ptr
 | 
				
			||||||
 | 
					        for i in range(len(self.vocab)):
 | 
				
			||||||
 | 
					            probs[i] = <float>c_exp(self.lexemes[i].prob)
 | 
				
			||||||
 | 
					        cv_probs = <float[:len(self)]>probs
 | 
				
			||||||
 | 
					        return HuffmanCodec(cv_probs, 0, id=ID)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef class HuffmanCodec:
 | 
					cdef class HuffmanCodec:
 | 
				
			||||||
    """Create a Huffman code table, and use it to pack and unpack sequences into
 | 
					    """Create a Huffman code table, and use it to pack and unpack sequences into
 | 
				
			||||||
    byte strings. Emphasis is on efficiency, so API is quite strict:
 | 
					    byte strings. Emphasis is on efficiency, so API is quite strict:
 | 
				
			||||||
| 
						 | 
					@ -109,7 +187,8 @@ cdef class HuffmanCodec:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        eol (uint32_t): The index of the weight of the EOL symbol.
 | 
					        eol (uint32_t): The index of the weight of the EOL symbol.
 | 
				
			||||||
    """
 | 
					    """
 | 
				
			||||||
    def __init__(self, float[:] probs, uint32_t eol):
 | 
					    def __init__(self, float[:] probs, uint32_t eol, id=0):
 | 
				
			||||||
 | 
					        self.id = id
 | 
				
			||||||
        self.eol = eol
 | 
					        self.eol = eol
 | 
				
			||||||
        self.codes.resize(len(probs))
 | 
					        self.codes.resize(len(probs))
 | 
				
			||||||
        for i in range(len(self.codes)):
 | 
					        for i in range(len(self.codes)):
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user