mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-01 00:17:44 +03:00 
			
		
		
		
	Update to/from bytes methods
This commit is contained in:
		
							parent
							
								
									c91b121aeb
								
							
						
					
					
						commit
						6b019b0540
					
				|  | @ -9,7 +9,6 @@ import numpy | |||
| cimport numpy as np | ||||
| import cytoolz | ||||
| import util | ||||
| import ujson | ||||
| 
 | ||||
| from thinc.api import add, layerize, chain, clone, concatenate, with_flatten | ||||
| from thinc.neural import Model, Maxout, Softmax, Affine | ||||
|  | @ -160,18 +159,18 @@ class TokenVectorEncoder(object): | |||
|             yield | ||||
| 
 | ||||
|     def to_bytes(self, **exclude): | ||||
|         data = { | ||||
|             'model': self.model, | ||||
|             'vocab': self.vocab | ||||
|         serialize = { | ||||
|             'model': lambda: model_to_bytes(self.model), | ||||
|             'vocab': lambda: self.vocab.to_bytes() | ||||
|         } | ||||
|         return util.to_bytes(data, exclude) | ||||
|         return util.to_bytes(serialize, exclude) | ||||
| 
 | ||||
|     def from_bytes(self, bytes_data, **exclude): | ||||
|         data = ujson.loads(bytes_data) | ||||
|         if 'model' not in exclude: | ||||
|             util.model_from_bytes(self.model, data['model']) | ||||
|         if 'vocab' not in exclude: | ||||
|             self.vocab.from_bytes(data['vocab']) | ||||
|         deserialize = { | ||||
|             'model': lambda b: model_from_bytes(self.model, b), | ||||
|             'vocab': lambda b: self.vocab.from_bytes(b) | ||||
|         } | ||||
|         util.from_bytes(deserialize, exclude) | ||||
|         return self | ||||
| 
 | ||||
|     def to_disk(self, path, **exclude): | ||||
|  | @ -290,6 +289,23 @@ class NeuralTagger(object): | |||
|         with self.model.use_params(params): | ||||
|             yield | ||||
| 
 | ||||
|     def to_bytes(self, **exclude): | ||||
|         serialize = { | ||||
|             'model': lambda: model_to_bytes(self.model), | ||||
|             'vocab': lambda: self.vocab.to_bytes() | ||||
|         } | ||||
|         return util.to_bytes(serialize, exclude) | ||||
| 
 | ||||
|     def from_bytes(self, bytes_data, **exclude): | ||||
|         deserialize = { | ||||
|             'model': lambda b: model_from_bytes(self.model, b), | ||||
|             'vocab': lambda b: self.vocab.from_bytes(b) | ||||
|         } | ||||
|         util.from_bytes(deserialize, exclude) | ||||
|         return self | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| class NeuralLabeller(NeuralTagger): | ||||
|     name = 'nn_labeller' | ||||
|     def __init__(self, vocab, model=True): | ||||
|  |  | |||
|  | @ -260,7 +260,14 @@ cdef class Parser: | |||
|         # Used to set input dimensions in network. | ||||
|         lower.begin_training(lower.ops.allocate((500, token_vector_width))) | ||||
|         upper.begin_training(upper.ops.allocate((500, hidden_width))) | ||||
|         return lower, upper | ||||
|         cfg = { | ||||
|             'nr_class': nr_class, | ||||
|             'depth': depth, | ||||
|             'token_vector_width': token_vector_width, | ||||
|             'hidden_width': hidden_width, | ||||
|             'maxout_pieces': parser_maxout_pieces | ||||
|         } | ||||
|         return (lower, upper), cfg | ||||
| 
 | ||||
|     def __init__(self, Vocab vocab, moves=True, model=True, **cfg): | ||||
|         """ | ||||
|  | @ -611,7 +618,8 @@ cdef class Parser: | |||
|             for label in labels: | ||||
|                 self.moves.add_action(action, label) | ||||
|         if self.model is True: | ||||
|             self.model = self.Model(self.moves.n_moves, **cfg) | ||||
|             self.model, cfg = self.Model(self.moves.n_moves, **cfg) | ||||
|             self.cfg.update(cfg) | ||||
| 
 | ||||
|     def preprocess_gold(self, docs_golds): | ||||
|         for doc, gold in docs_golds: | ||||
|  | @ -633,11 +641,28 @@ cdef class Parser: | |||
|         with (path / 'model.bin').open('wb') as file_: | ||||
|             self.model = dill.load(file_) | ||||
| 
 | ||||
|     def to_bytes(self): | ||||
|         dill.dumps(self.model) | ||||
|     def to_bytes(self, **exclude): | ||||
|         serialize = { | ||||
|             'model': lambda: util.model_to_bytes(self.model), | ||||
|             'vocab': lambda: self.vocab.to_bytes(), | ||||
|             'moves': lambda: self.moves.to_bytes(), | ||||
|             'cfg': lambda: ujson.dumps(self.cfg) | ||||
|         } | ||||
|         return util.to_bytes(serialize, exclude) | ||||
| 
 | ||||
|     def from_bytes(self, data): | ||||
|         self.model = dill.loads(data) | ||||
|     def from_bytes(self, bytes_data, **exclude): | ||||
|         deserialize = { | ||||
|             'vocab': lambda b: self.vocab.from_bytes(b), | ||||
|             'moves': lambda b: self.moves.from_bytes(b), | ||||
|             'cfg': lambda b: self.cfg.update(ujson.loads(b)), | ||||
|             'model': lambda b: None | ||||
|         } | ||||
|         msg = util.from_bytes(deserialize, exclude) | ||||
|         if 'model' not in exclude: | ||||
|             if self.model is True: | ||||
|                 self.model = self.Model(**msg['cfg']) | ||||
|             util.model_from_disk(self.model, msg['model']) | ||||
|         return self | ||||
| 
 | ||||
| 
 | ||||
| class ParserStateError(ValueError): | ||||
|  |  | |||
|  | @ -291,12 +291,11 @@ cdef class Vocab: | |||
|         **exclude: Named attributes to prevent from being serialized. | ||||
|         RETURNS (bytes): The serialized form of the `Vocab` object. | ||||
|         """ | ||||
|         data = {} | ||||
|         if 'strings' not in exclude: | ||||
|             data['strings'] = self.strings.to_bytes() | ||||
|         if 'lexemes' not in exclude: | ||||
|             data['lexemes'] = self.lexemes_to_bytes | ||||
|         return ujson.dumps(data) | ||||
|         getters = { | ||||
|             'strings': lambda: self.strings.to_bytes(), | ||||
|             'lexemes': lambda: self.lexemes_to_bytes() | ||||
|         } | ||||
|         return util.to_bytes(getters, exclude) | ||||
| 
 | ||||
|     def from_bytes(self, bytes_data, **exclude): | ||||
|         """Load state from a binary string. | ||||
|  | @ -305,12 +304,11 @@ cdef class Vocab: | |||
|         **exclude: Named attributes to prevent from being loaded. | ||||
|         RETURNS (Vocab): The `Vocab` object. | ||||
|         """ | ||||
|         data = ujson.loads(bytes_data) | ||||
|         if 'strings' not in exclude: | ||||
|             self.strings.from_bytes(data['strings']) | ||||
|         if 'lexemes' not in exclude: | ||||
|             self.lexemes_from_bytes(data['lexemes']) | ||||
|         return self | ||||
|         setters = { | ||||
|             'strings': lambda b: self.strings.from_bytes(b), | ||||
|             'lexemes': lambda b: self.lexemes_from_bytes(b) | ||||
|         } | ||||
|         return util.from_bytes(bytes_data, setters, exclude) | ||||
| 
 | ||||
|     def lexemes_to_bytes(self): | ||||
|         cdef hash_t key | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user