	Work on serialization for models

Commit 6dad4117ad (parent b007b0e5a0), mirrored from https://github.com/explosion/spaCy.git
spacy/_ml.py (39 lines changed)

@@ -1,3 +1,4 @@
+import ujson
 from thinc.api import add, layerize, chain, clone, concatenate, with_flatten
 from thinc.neural import Model, Maxout, Softmax, Affine
 from thinc.neural._classes.hash_embed import HashEmbed
@@ -15,9 +16,47 @@ from thinc.neural._classes.affine import _set_dimensions_if_needed
 from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP
 from .tokens.doc import Doc
 
+import dill
+import numpy
+import io
+
+
+def model_to_bytes(model):
+    # Walk the model breadth-first, collecting each layer's parameter
+    # memory and the offsets that describe how it is laid out.
+    weights = []
+    metas = []
+    queue = [model]
+    for layer in queue:
+        if hasattr(layer, '_mem'):
+            weights.append(layer._mem.weights)
+            metas.append(layer._mem._offsets)
+        if hasattr(layer, '_layers'):
+            queue.extend(layer._layers)
+    data = {'metas': metas, 'weights': weights}
+    # TODO: Replace the pickle here with something else
+    return dill.dumps(data)
+
+
+def model_from_bytes(model, bytes_data):
+    # TODO: Replace the pickle here with something else
+    data = dill.loads(bytes_data)
+    metas = data['metas']
+    weights = data['weights']
+    # Walk the layers in the same order as model_to_bytes, copying each
+    # saved parameter block back into the layer's memory in place.
+    queue = [model]
+    i = 0
+    for layer in queue:
+        if hasattr(layer, '_mem'):
+            params = weights[i]
+            flat_mem = layer._mem._mem.ravel()
+            flat_params = params.ravel()
+            flat_mem[:flat_params.size] = flat_params
+            layer._mem._offsets.update(metas[i])
+            i += 1
+        if hasattr(layer, '_layers'):
+            queue.extend(layer._layers)
+
+
 def _init_for_precomputed(W, ops):
     if (W**2).sum() != 0.:
         return
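
Taken together, model_to_bytes and model_from_bytes give a minimal byte-level
round trip for thinc models. A usage sketch, mirroring the tests added further
down (illustrative only, not part of the commit):

    from thinc.neural import Maxout
    from spacy._ml import model_to_bytes, model_from_bytes

    model = Maxout(5, 10, pieces=2)   # any thinc layer with a _mem block
    blob = model_to_bytes(model)      # dill-pickled weights and offsets
    model_from_bytes(model, blob)     # copies the weights back in place

The pipeline module below wires these helpers into TokenVectorEncoder:
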
@@ -9,6 +9,7 @@ import numpy
 cimport numpy as np
 import cytoolz
 import util
+import ujson
 
 from thinc.api import add, layerize, chain, clone, concatenate, with_flatten
 from thinc.neural import Model, Maxout, Softmax, Affine
@@ -35,6 +36,7 @@ from .syntax import nonproj
 
 from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP, POS
 from ._ml import rebatch, Tok2Vec, flatten, get_col, doc2feats
+from ._ml import model_to_bytes, model_from_bytes
 from .parts_of_speech import X
 
 
@@ -148,7 +150,6 @@ class TokenVectorEncoder(object):
         if self.model is True:
             self.model = self.Model()
 
-
     def use_params(self, params):
         """Replace weights of models in the pipeline with those provided in the
         params dictionary.
@@ -158,6 +159,39 @@ class TokenVectorEncoder(object):
         with self.model.use_params(params):
             yield
 
+    def to_bytes(self, **exclude):
+        data = {
+            'model': self.model,
+            'vocab': self.vocab
+        }
+        return util.to_bytes(data, exclude)
+
+    def from_bytes(self, bytes_data, **exclude):
+        data = ujson.loads(bytes_data)
+        if 'model' not in exclude:
+            model_from_bytes(self.model, data['model'])
+        if 'vocab' not in exclude:
+            self.vocab.from_bytes(data['vocab'])
+        return self
+
+    def to_disk(self, path, **exclude):
+        path = util.ensure_path(path)
+        if not path.exists():
+            path.mkdir()
+        if 'vocab' not in exclude:
+            self.vocab.to_disk(path / 'vocab')
+        if 'model' not in exclude:
+            with (path / 'model.bin').open('wb') as file_:
+                file_.write(model_to_bytes(self.model))
+
+    def from_disk(self, path, **exclude):
+        path = util.ensure_path(path)
+        if 'vocab' not in exclude:
+            self.vocab.from_disk(path / 'vocab')
+        if 'model' not in exclude:
+            with (path / 'model.bin').open('rb') as file_:
+                model_from_bytes(self.model, file_.read())
+
+
 class NeuralTagger(object):
     name = 'nn_tagger'
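
For illustration, a disk round trip with the new TokenVectorEncoder methods
would look roughly like this (the construction and path are assumptions, not
part of the commit):

    tok2vec = TokenVectorEncoder(nlp.vocab)   # hypothetical setup
    tok2vec.to_disk('/tmp/tok2vec')           # writes vocab/ plus model.bin
    tok2vec.from_disk('/tmp/tok2vec')         # restores the weights in place

The commit also adds round-trip tests for the byte-level helpers:
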
@@ -2,12 +2,38 @@
 from __future__ import unicode_literals
 
 from ..util import ensure_path
+from .._ml import model_to_bytes, model_from_bytes
 
 from pathlib import Path
 import pytest
+from thinc.neural import Maxout, Softmax
+from thinc.api import chain
 
 
 @pytest.mark.parametrize('text', ['hello/world', 'hello world'])
 def test_util_ensure_path_succeeds(text):
     path = ensure_path(text)
     assert isinstance(path, Path)
 
 
+def test_simple_model_roundtrip_bytes():
+    model = Maxout(5, 10, pieces=2)
+    model.b += 1
+    data = model_to_bytes(model)
+    model.b -= 1
+    model_from_bytes(model, data)
+    assert model.b[0, 0] == 1
+
+
+def test_multi_model_roundtrip_bytes():
+    model = chain(Maxout(5, 10, pieces=2), Maxout(2, 3))
+    model._layers[0].b += 1
+    model._layers[1].b += 2
+    data = model_to_bytes(model)
+    model._layers[0].b -= 1
+    model._layers[1].b -= 2
+    model_from_bytes(model, data)
+    assert model._layers[0].b[0, 0] == 1
+    assert model._layers[1].b[0, 0] == 2
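
The component serialization above leans on a small generic helper, to_bytes,
added to the util module in the next hunk. A sketch of its exclude semantics
(the dict contents are invented for illustration):

    msg = to_bytes({'cfg': {'depth': 4}, 'scores': [0.1, 0.2]}, ('scores',))
    # 'scores' is excluded; 'cfg' has no to_bytes method, so it falls
    # through to ujson.dumps and is nested inside the outer JSON message.
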
@@ -408,6 +408,18 @@ def get_raw_input(description, default=False):
     return user_input
 
 
+def to_bytes(unserialized, exclude):
+    serialized = {}
+    for key, value in unserialized.items():
+        if key in exclude:
+            continue
+        elif hasattr(value, 'to_bytes'):
+            serialized[key] = value.to_bytes()
+        else:
+            serialized[key] = ujson.dumps(value)
+    return ujson.dumps(serialized)
+
+
 def print_table(data, title=None):
     """Print data in table format.
 
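
The last hunk simplifies the Vocab constructor: it drops the special
symbol-ordering scheme described in the comment being removed, interning only
the tag map's names up front, since pre-interning all of symbols.NAMES
invalidated saved data whenever a new symbol was added.
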
@@ -56,15 +56,7 @@ cdef class Vocab:
         if strings:
             for string in strings:
                 self.strings.add(string)
-        # Load strings in a special order, so that we have an onset number for
-        # the vocabulary. This way, when words are added in order, the orth ID
-        # is the frequency rank of the word, plus a certain offset. The structural
-        # strings are loaded first, because the vocab is open-class, and these
-        # symbols are closed class.
-        # TODO: Actually this has turned out to be a pain in the ass...
-        # It means the data is invalidated when we add a symbol :(
-        # Need to rethink this.
-        for name in symbols.NAMES + list(sorted(tag_map.keys())):
+        for name in tag_map.keys():
             if name:
                 self.strings.add(name)
         self.lex_attr_getters = lex_attr_getters