mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-28 06:31:12 +03:00 
			
		
		
		
	* Very scrappy, likely buggy first-cut pickle implementation, to work on Issue #125: allow pickle for Apache Spark. The current implementation sends stuff to temp files, and does almost nothing to ensure all modifiable state is actually preserved. The Language() instance is a deep tree of extension objects, and if pickling during training, some of the C-data state is hard to preserve.
This commit is contained in:
		
							parent
							
								
									f8de403483
								
							
						
					
					
						commit
						20fd36a0f7
					
				|  | @ -29,5 +29,6 @@ cdef class Model: | |||
|     cdef int update(self, atom_t* context, class_t guess, class_t gold, int cost) except -1 | ||||
|      | ||||
|     cdef object model_loc | ||||
|     cdef object _templates | ||||
|     cdef Extractor _extractor | ||||
|     cdef LinearModel _model | ||||
|  |  | |||
|  | @ -3,6 +3,7 @@ from __future__ import unicode_literals | |||
| from __future__ import division | ||||
| 
 | ||||
| from os import path | ||||
| import tempfile | ||||
| import os | ||||
| import shutil | ||||
| import json | ||||
|  | @ -52,6 +53,7 @@ cdef class Model: | |||
|     def __init__(self, n_classes, templates, model_loc=None): | ||||
|         if model_loc is not None and path.isdir(model_loc): | ||||
|             model_loc = path.join(model_loc, 'model') | ||||
|         self._templates = templates | ||||
|         self.n_classes = n_classes | ||||
|         self._extractor = Extractor(templates) | ||||
|         self.n_feats = self._extractor.n_templ | ||||
|  | @ -60,6 +62,18 @@ cdef class Model: | |||
|         if self.model_loc and path.exists(self.model_loc): | ||||
|             self._model.load(self.model_loc, freq_thresh=0) | ||||
| 
 | ||||
|     def __reduce__(self): | ||||
|         model_loc = tempfile.mkstemp() | ||||
|         # TODO: This is a potentially buggy implementation. We're not really | ||||
|         # given a good guarantee that all internal state is saved correctly here, | ||||
|         # since there are learning parameters for e.g. the model averaging in | ||||
|         # averaged perceptron, the gradient calculations in AdaGrad, etc | ||||
|         # that aren't necessarily saved. So, if we're part way through training | ||||
|         # the model, and then we pickle it, we won't recover the state correctly. | ||||
|         self._model.dump(model_loc) | ||||
|         return (Model, (self.n_classes, self.templates, model_loc), | ||||
|                 None, None) | ||||
| 
 | ||||
|     def predict(self, Example eg): | ||||
|         self.set_scores(eg.c.scores, eg.c.atoms) | ||||
|         eg.c.guess = arg_max_if_true(eg.c.scores, eg.c.is_valid, self.n_classes) | ||||
|  |  | |||
|  | @ -207,6 +207,12 @@ class Language(object): | |||
|         self.entity = entity | ||||
|         self.matcher = matcher | ||||
| 
 | ||||
|     def __reduce__(self): | ||||
|         return (self.__class__, | ||||
|                   (None, self.vocab, self.tokenizer, self.tagger, self.parser, | ||||
|                    self.entity, self.matcher, None), | ||||
|                 None, None) | ||||
| 
 | ||||
|     def __call__(self, text, tag=True, parse=True, entity=True): | ||||
|         """Apply the pipeline to some text.  The text can span multiple sentences, | ||||
|         and can contain arbitrary whitespace.  Alignment into the original string | ||||
|  |  | |||
|  | @ -168,13 +168,7 @@ cdef class Matcher: | |||
|     cdef Pool mem | ||||
|     cdef vector[Pattern*] patterns | ||||
|     cdef readonly Vocab vocab | ||||
| 
 | ||||
|     def __init__(self, vocab, patterns): | ||||
|         self.vocab = vocab | ||||
|         self.mem = Pool() | ||||
|         self.vocab = vocab | ||||
|         for entity_key, (etype, attrs, specs) in sorted(patterns.items()): | ||||
|             self.add(entity_key, etype, attrs, specs) | ||||
|     cdef object _patterns | ||||
| 
 | ||||
|     @classmethod | ||||
|     def from_dir(cls, data_dir, Vocab vocab): | ||||
|  | @ -186,10 +180,22 @@ cdef class Matcher: | |||
|         else: | ||||
|             return cls(vocab, {}) | ||||
| 
 | ||||
|     def __init__(self, vocab, patterns): | ||||
|         self.vocab = vocab | ||||
|         self.mem = Pool() | ||||
|         self.vocab = vocab | ||||
|         self._patterns = dict(patterns) | ||||
|         for entity_key, (etype, attrs, specs) in sorted(patterns.items()): | ||||
|             self.add(entity_key, etype, attrs, specs) | ||||
| 
 | ||||
|     def __reduce__(self): | ||||
|         return (self.__class__, (self.vocab, self._patterns), None, None) | ||||
|      | ||||
|     property n_patterns: | ||||
|         def __get__(self): return self.patterns.size() | ||||
| 
 | ||||
|     def add(self, entity_key, etype, attrs, specs): | ||||
|         self._patterns[entity_key] = (etype, dict(attrs), list(specs)) | ||||
|         if isinstance(entity_key, basestring): | ||||
|             entity_key = self.vocab.strings[entity_key] | ||||
|         if isinstance(etype, basestring): | ||||
|  |  | |||
|  | @ -83,7 +83,6 @@ cdef class Parser: | |||
|         model = Model(moves.n_moves, templates, model_dir) | ||||
|         return cls(strings, moves, model) | ||||
| 
 | ||||
| 
 | ||||
|     def __call__(self, Doc tokens): | ||||
|         cdef StateClass stcls = StateClass.init(tokens.data, tokens.length) | ||||
|         self.moves.initialize_state(stcls) | ||||
|  | @ -93,6 +92,9 @@ cdef class Parser: | |||
|         self.parse(stcls, eg.c) | ||||
|         tokens.set_parse(stcls._sent) | ||||
| 
 | ||||
|     def __reduce__(self): | ||||
|         return (Parser, (self.moves.strings, self.moves, self.model), None, None) | ||||
| 
 | ||||
|     cdef void predict(self, StateClass stcls, ExampleC* eg) nogil: | ||||
|         memset(eg.scores, 0, eg.nr_class * sizeof(weight_t)) | ||||
|         self.moves.set_valid(eg.is_valid, stcls) | ||||
|  |  | |||
|  | @ -37,6 +37,8 @@ cdef class TransitionSystem: | |||
|     cdef public int root_label | ||||
|     cdef public freqs | ||||
| 
 | ||||
|     cdef object _labels_by_action | ||||
| 
 | ||||
|     cdef int initialize_state(self, StateClass state) except -1 | ||||
|     cdef int finalize_state(self, StateClass state) nogil | ||||
| 
 | ||||
|  |  | |||
|  | @ -15,7 +15,8 @@ class OracleError(Exception): | |||
| 
 | ||||
| 
 | ||||
| cdef class TransitionSystem: | ||||
|     def __init__(self, StringStore string_table, dict labels_by_action): | ||||
|     def __init__(self, StringStore string_table, dict labels_by_action, _freqs=None): | ||||
|         self._labels_by_action = labels_by_action | ||||
|         self.mem = Pool() | ||||
|         self.n_moves = sum(len(labels) for labels in labels_by_action.values()) | ||||
|         self._is_valid = <bint*>self.mem.alloc(self.n_moves, sizeof(bint)) | ||||
|  | @ -30,7 +31,7 @@ cdef class TransitionSystem: | |||
|                 i += 1 | ||||
|         self.c = moves | ||||
|         self.root_label = self.strings['ROOT'] | ||||
|         self.freqs = {} | ||||
|         self.freqs = {} if _freqs is None else _freqs | ||||
|         for attr in (TAG, HEAD, DEP, ENT_TYPE, ENT_IOB): | ||||
|             self.freqs[attr] = defaultdict(int) | ||||
|             self.freqs[attr][0] = 1 | ||||
|  | @ -39,6 +40,11 @@ cdef class TransitionSystem: | |||
|             self.freqs[HEAD][i] = 1 | ||||
|             self.freqs[HEAD][-i] = 1 | ||||
| 
 | ||||
|     def __reduce__(self): | ||||
|         return (self.__class__, | ||||
|                 (self.strings, self._labels_by_action, self.freqs), | ||||
|                 None, None) | ||||
| 
 | ||||
|     cdef int initialize_state(self, StateClass state) except -1: | ||||
|         pass | ||||
| 
 | ||||
|  |  | |||
|  | @ -148,6 +148,9 @@ cdef class Tagger: | |||
|         tokens.is_tagged = True | ||||
|         tokens._py_tokens = [None] * tokens.length | ||||
| 
 | ||||
|     def __reduce__(self): | ||||
|         return (self.__class__, (self.vocab, self.model), None, None) | ||||
| 
 | ||||
|     def tag_from_strings(self, Doc tokens, object tag_strs): | ||||
|         cdef int i | ||||
|         for i in range(tokens.length): | ||||
|  |  | |||
|  | @ -99,16 +99,18 @@ cdef class Vocab: | |||
|         return self.length | ||||
| 
 | ||||
|     def __reduce__(self): | ||||
|         # TODO: Dump vectors | ||||
|         tmp_dir = tempfile.mkdtemp() | ||||
|         lex_loc = path.join(tmp_dir, 'lexemes.bin') | ||||
|         str_loc = path.join(tmp_dir, 'strings.txt') | ||||
|         map_loc = path.join(tmp_dir, 'tag_map.json') | ||||
|         vec_loc = path.join(self.data_dir, 'vec.bin') if self.data_dir is not None else None | ||||
| 
 | ||||
|         self.dump(lex_loc) | ||||
|         self.strings.dump(str_loc) | ||||
|         json.dump(self.morphology.tag_map, open(map_loc, 'w')) | ||||
| 
 | ||||
|         return (unpickle_vocab, (tmp_dir,), None, None) | ||||
|          | ||||
|         state = (str_loc, lex_loc, vec_loc, self.morphology, self.get_lex_attr, | ||||
|                  self.serializer_freqs, self.data_dir) | ||||
|         return (unpickle_vocab, state, None, None) | ||||
| 
 | ||||
|     cdef const LexemeC* get(self, Pool mem, unicode string) except NULL: | ||||
|         '''Get a pointer to a LexemeC from the lexicon, creating a new Lexeme | ||||
|  | @ -353,11 +355,21 @@ cdef class Vocab: | |||
|         return vec_len | ||||
| 
 | ||||
| 
 | ||||
| def unpickle_vocab(data_dir): | ||||
|     # TODO: This needs fixing --- the trouble is, we can't pickle staticmethods, | ||||
|     # so we need to fiddle with the design of Language a little bit. | ||||
|     from .language import Language | ||||
|     return Vocab.from_dir(data_dir, Language.default_lex_attrs()) | ||||
| def unpickle_vocab(strings_loc, lex_loc, vec_loc, morphology, get_lex_attr, | ||||
|                    serializer_freqs, data_dir): | ||||
|     cdef Vocab vocab = Vocab() | ||||
| 
 | ||||
|     vocab.get_lex_attr = get_lex_attr | ||||
|     vocab.morphology = morphology | ||||
|     vocab.strings = morphology.strings | ||||
|     vocab.data_dir = data_dir | ||||
|     vocab.serializer_freqs = serializer_freqs | ||||
| 
 | ||||
|     vocab.load_lexemes(strings_loc, lex_loc) | ||||
|     if vec_loc is not None: | ||||
|         vocab.load_vectors_from_bin_loc(vec_loc) | ||||
|     return vocab | ||||
|   | ||||
| 
 | ||||
| copy_reg.constructor(unpickle_vocab) | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										16
									
								
								tests/parser/test_pickle.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								tests/parser/test_pickle.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,16 @@ | |||
| import pytest | ||||
| 
 | ||||
| import pickle | ||||
| import cloudpickle | ||||
| import StringIO | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.models | ||||
| def test_pickle(EN): | ||||
|     file_ = StringIO.StringIO() | ||||
|     cloudpickle.dump(EN.parser, file_) | ||||
| 
 | ||||
|     file_.seek(0) | ||||
| 
 | ||||
|     loaded = pickle.load(file_) | ||||
| 
 | ||||
							
								
								
									
										15
									
								
								tests/test_pickle.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								tests/test_pickle.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,15 @@ | |||
| import pytest | ||||
| import StringIO | ||||
| import cloudpickle | ||||
| import pickle | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.models | ||||
| def test_pickle_english(EN): | ||||
|     file_ = StringIO.StringIO() | ||||
|     cloudpickle.dump(EN, file_) | ||||
| 
 | ||||
|     file_.seek(0) | ||||
| 
 | ||||
|     loaded = pickle.load(file_) | ||||
| 
 | ||||
|  | @ -1,13 +1,13 @@ | |||
| from __future__ import unicode_literals | ||||
| import pytest | ||||
| import StringIO | ||||
| import cloudpickle | ||||
| import pickle | ||||
| 
 | ||||
| from spacy.attrs import LEMMA, ORTH, PROB, IS_ALPHA | ||||
| from spacy.parts_of_speech import NOUN, VERB | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
| def test_neq(en_vocab): | ||||
|     addr = en_vocab['Hello'] | ||||
|     assert en_vocab['bye'].orth != addr.orth | ||||
|  | @ -44,7 +44,7 @@ def test_symbols(en_vocab): | |||
| 
 | ||||
| def test_pickle_vocab(en_vocab): | ||||
|     file_ = StringIO.StringIO() | ||||
|     pickle.dump(en_vocab, file_) | ||||
|     cloudpickle.dump(en_vocab, file_) | ||||
| 
 | ||||
|     file_.seek(0) | ||||
| 
 | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user