mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-26 13:41:21 +03:00 
			
		
		
		
	* Merge
This commit is contained in:
		
						commit
						f6d74b14de
					
				|  | @ -29,5 +29,6 @@ cdef class Model: | |||
|     cdef int update(self, atom_t* context, class_t guess, class_t gold, int cost) except -1 | ||||
|      | ||||
|     cdef object model_loc | ||||
|     cdef object _templates | ||||
|     cdef Extractor _extractor | ||||
|     cdef LinearModel _model | ||||
|  |  | |||
|  | @ -3,6 +3,7 @@ from __future__ import unicode_literals | |||
| from __future__ import division | ||||
| 
 | ||||
| from os import path | ||||
| import tempfile | ||||
| import os | ||||
| import shutil | ||||
| import json | ||||
|  | @ -52,6 +53,7 @@ cdef class Model: | |||
|     def __init__(self, n_classes, templates, model_loc=None): | ||||
|         if model_loc is not None and path.isdir(model_loc): | ||||
|             model_loc = path.join(model_loc, 'model') | ||||
|         self._templates = templates | ||||
|         self.n_classes = n_classes | ||||
|         self._extractor = Extractor(templates) | ||||
|         self.n_feats = self._extractor.n_templ | ||||
|  | @ -60,6 +62,18 @@ cdef class Model: | |||
|         if self.model_loc and path.exists(self.model_loc): | ||||
|             self._model.load(self.model_loc, freq_thresh=0) | ||||
| 
 | ||||
|     def __reduce__(self): | ||||
|         """Pickle support: dump the model to a temporary file and rebuild from it. | ||||
|         Returns the reduce tuple for pickle; the unpickled object is created as | ||||
|         Model(n_classes, templates, model_loc), with model_loc the temporary file. | ||||
|         """ | ||||
|         # mkstemp() returns an (fd, path) pair.  Only the path is a valid | ||||
|         # location; close the descriptor right away so it is not leaked. | ||||
|         fd, model_loc = tempfile.mkstemp() | ||||
|         os.close(fd) | ||||
|         # TODO: This is a potentially buggy implementation. We're not really | ||||
|         # given a good guarantee that all internal state is saved correctly here, | ||||
|         # since there are learning parameters for e.g. the model averaging in | ||||
|         # averaged perceptron, the gradient calculations in AdaGrad, etc | ||||
|         # that aren't necessarily saved. So, if we're part way through training | ||||
|         # the model, and then we pickle it, we won't recover the state correctly. | ||||
|         self._model.dump(model_loc) | ||||
|         # __init__ stores the templates on the private `_templates` attribute. | ||||
|         return (Model, (self.n_classes, self._templates, model_loc), | ||||
|                 None, None) | ||||
| 
 | ||||
|     def predict(self, Example eg): | ||||
|         self.set_scores(eg.c.scores, eg.c.atoms) | ||||
|         eg.c.guess = arg_max_if_true(eg.c.scores, eg.c.is_valid, self.n_classes) | ||||
|  |  | |||
|  | @ -207,6 +207,12 @@ class Language(object): | |||
|         self.entity = entity | ||||
|         self.matcher = matcher | ||||
| 
 | ||||
|     def __reduce__(self): | ||||
|         return (self.__class__, | ||||
|                   (None, self.vocab, self.tokenizer, self.tagger, self.parser, | ||||
|                    self.entity, self.matcher, None), | ||||
|                 None, None) | ||||
| 
 | ||||
|     def __call__(self, text, tag=True, parse=True, entity=True): | ||||
|         """Apply the pipeline to some text.  The text can span multiple sentences, | ||||
|         and can contain arbitrary whitespace.  Alignment into the original string | ||||
|  |  | |||
|  | @ -168,13 +168,7 @@ cdef class Matcher: | |||
|     cdef Pool mem | ||||
|     cdef vector[Pattern*] patterns | ||||
|     cdef readonly Vocab vocab | ||||
| 
 | ||||
|     def __init__(self, vocab, patterns): | ||||
|         self.vocab = vocab | ||||
|         self.mem = Pool() | ||||
|         self.vocab = vocab | ||||
|         for entity_key, (etype, attrs, specs) in sorted(patterns.items()): | ||||
|             self.add(entity_key, etype, attrs, specs) | ||||
|     cdef object _patterns | ||||
| 
 | ||||
|     @classmethod | ||||
|     def from_dir(cls, data_dir, Vocab vocab): | ||||
|  | @ -186,10 +180,22 @@ cdef class Matcher: | |||
|         else: | ||||
|             return cls(vocab, {}) | ||||
| 
 | ||||
|     def __init__(self, vocab, patterns): | ||||
|         """Create a matcher over `vocab` and register the given patterns. | ||||
|         `patterns` maps an entity key to an (etype, attrs, specs) triple; the | ||||
|         entries are added in sorted key order. | ||||
|         """ | ||||
|         self.vocab = vocab | ||||
|         self.mem = Pool() | ||||
|         # Keep a plain-dict copy of the patterns: __reduce__ hands it back to the | ||||
|         # constructor when the matcher is unpickled. | ||||
|         self._patterns = dict(patterns) | ||||
|         for entity_key, (etype, attrs, specs) in sorted(patterns.items()): | ||||
|             self.add(entity_key, etype, attrs, specs) | ||||
| 
 | ||||
|     def __reduce__(self): | ||||
|         # Pickle support: reconstruct from the vocab and the stored patterns. | ||||
|         init_args = (self.vocab, self._patterns) | ||||
|         return (self.__class__, init_args, None, None) | ||||
|      | ||||
|     property n_patterns: | ||||
|         # Number of entries currently held in the `patterns` vector. | ||||
|         def __get__(self): | ||||
|             return self.patterns.size() | ||||
| 
 | ||||
|     def add(self, entity_key, etype, attrs, specs): | ||||
|         self._patterns[entity_key] = (etype, dict(attrs), list(specs)) | ||||
|         if isinstance(entity_key, basestring): | ||||
|             entity_key = self.vocab.strings[entity_key] | ||||
|         if isinstance(etype, basestring): | ||||
|  |  | |||
|  | @ -25,6 +25,7 @@ cdef class Morphology: | |||
|     cdef readonly Pool mem | ||||
|     cdef readonly StringStore strings | ||||
|     cdef public object lemmatizer | ||||
|     cdef readonly object tag_map | ||||
|     cdef public object n_tags | ||||
|     cdef public object reverse_index | ||||
|     cdef public object tag_names | ||||
|  |  | |||
|  | @ -14,6 +14,7 @@ cdef class Morphology: | |||
|     def __init__(self, StringStore string_store, tag_map, lemmatizer): | ||||
|         self.mem = Pool() | ||||
|         self.strings = string_store | ||||
|         self.tag_map = tag_map | ||||
|         self.lemmatizer = lemmatizer | ||||
|         self.n_tags = len(tag_map) + 1 | ||||
|         self.tag_names = tuple(sorted(tag_map.keys())) | ||||
|  | @ -28,6 +29,9 @@ cdef class Morphology: | |||
|             self.reverse_index[self.rich_tags[i].name] = i | ||||
|         self._cache = PreshMapArray(self.n_tags) | ||||
| 
 | ||||
|     def __reduce__(self): | ||||
|         # Pickle support: rebuild from the same three constructor arguments. | ||||
|         init_args = (self.strings, self.tag_map, self.lemmatizer) | ||||
|         return (Morphology, init_args, None, None) | ||||
| 
 | ||||
|     cdef int assign_tag(self, TokenC* token, tag) except -1: | ||||
|         cdef int tag_id | ||||
|         if isinstance(tag, basestring): | ||||
|  |  | |||
|  | @ -25,4 +25,4 @@ IDS = { | |||
| } | ||||
| 
 | ||||
| 
 | ||||
| NAMES = [key for key, value in sorted(IDS.items(), key=lambda item: item[1])] | ||||
| NAMES = {value: key for key, value in IDS.items()} | ||||
|  |  | |||
|  | @ -69,12 +69,15 @@ cdef Utf8Str _allocate(Pool mem, const unsigned char* chars, int length) except | |||
| 
 | ||||
| cdef class StringStore: | ||||
|     '''Map strings to and from integer IDs.''' | ||||
|     def __init__(self): | ||||
|     def __init__(self, strings=None): | ||||
|         self.mem = Pool() | ||||
|         self._map = PreshMap() | ||||
|         self._resize_at = 10000 | ||||
|         self.c = <Utf8Str*>self.mem.alloc(self._resize_at, sizeof(Utf8Str)) | ||||
|         self.size = 1 | ||||
|         if strings is not None: | ||||
|             for string in strings: | ||||
|                 _ = self[string] | ||||
| 
 | ||||
|     property size: | ||||
|         def __get__(self): | ||||
|  | @ -113,6 +116,14 @@ cdef class StringStore: | |||
|         for i in range(self.size): | ||||
|             yield self[i] | ||||
| 
 | ||||
|     def __reduce__(self): | ||||
|         # Pickle support: collect the decoded strings in ID order.  Position 0 is | ||||
|         # filled with the empty string; IDs 1..size-1 are decoded from `self.c`. | ||||
|         decoded = [""] | ||||
|         for idx in range(1, self.size): | ||||
|             decoded.append(_decode(&self.c[idx])) | ||||
|         return (StringStore, (decoded,), None, None, None) | ||||
| 
 | ||||
|     cdef const Utf8Str* intern(self, unsigned char* chars, int length) except NULL: | ||||
|         # 0 means missing, but we don't bother offsetting the index. | ||||
|         key = hash64(chars, length * sizeof(char), 0) | ||||
|  |  | |||
|  | @ -83,7 +83,6 @@ cdef class Parser: | |||
|         model = Model(moves.n_moves, templates, model_dir) | ||||
|         return cls(strings, moves, model) | ||||
| 
 | ||||
| 
 | ||||
|     def __call__(self, Doc tokens): | ||||
|         cdef StateClass stcls = StateClass.init(tokens.data, tokens.length) | ||||
|         self.moves.initialize_state(stcls) | ||||
|  | @ -93,6 +92,9 @@ cdef class Parser: | |||
|         self.parse(stcls, eg.c) | ||||
|         tokens.set_parse(stcls._sent) | ||||
| 
 | ||||
|     def __reduce__(self): | ||||
|         # Pickle support: the string store is taken from the transition system. | ||||
|         init_args = (self.moves.strings, self.moves, self.model) | ||||
|         return (Parser, init_args, None, None) | ||||
| 
 | ||||
|     cdef void predict(self, StateClass stcls, ExampleC* eg) nogil: | ||||
|         memset(eg.scores, 0, eg.nr_class * sizeof(weight_t)) | ||||
|         self.moves.set_valid(eg.is_valid, stcls) | ||||
|  |  | |||
|  | @ -37,6 +37,8 @@ cdef class TransitionSystem: | |||
|     cdef public int root_label | ||||
|     cdef public freqs | ||||
| 
 | ||||
|     cdef object _labels_by_action | ||||
| 
 | ||||
|     cdef int initialize_state(self, StateClass state) except -1 | ||||
|     cdef int finalize_state(self, StateClass state) nogil | ||||
| 
 | ||||
|  |  | |||
|  | @ -15,7 +15,8 @@ class OracleError(Exception): | |||
| 
 | ||||
| 
 | ||||
| cdef class TransitionSystem: | ||||
|     def __init__(self, StringStore string_table, dict labels_by_action): | ||||
|     def __init__(self, StringStore string_table, dict labels_by_action, _freqs=None): | ||||
|         self._labels_by_action = labels_by_action | ||||
|         self.mem = Pool() | ||||
|         self.n_moves = sum(len(labels) for labels in labels_by_action.values()) | ||||
|         self._is_valid = <bint*>self.mem.alloc(self.n_moves, sizeof(bint)) | ||||
|  | @ -30,7 +31,7 @@ cdef class TransitionSystem: | |||
|                 i += 1 | ||||
|         self.c = moves | ||||
|         self.root_label = self.strings['ROOT'] | ||||
|         self.freqs = {} | ||||
|         self.freqs = {} if _freqs is None else _freqs | ||||
|         for attr in (TAG, HEAD, DEP, ENT_TYPE, ENT_IOB): | ||||
|             self.freqs[attr] = defaultdict(int) | ||||
|             self.freqs[attr][0] = 1 | ||||
|  | @ -39,6 +40,11 @@ cdef class TransitionSystem: | |||
|             self.freqs[HEAD][i] = 1 | ||||
|             self.freqs[HEAD][-i] = 1 | ||||
| 
 | ||||
|     def __reduce__(self): | ||||
|         # Pickle support: pass the constructor its original label table and the | ||||
|         # current frequency table (the optional `_freqs` argument). | ||||
|         init_args = (self.strings, self._labels_by_action, self.freqs) | ||||
|         return (self.__class__, init_args, None, None) | ||||
| 
 | ||||
|     cdef int initialize_state(self, StateClass state) except -1: | ||||
|         pass | ||||
| 
 | ||||
|  |  | |||
|  | @ -148,6 +148,9 @@ cdef class Tagger: | |||
|         tokens.is_tagged = True | ||||
|         tokens._py_tokens = [None] * tokens.length | ||||
| 
 | ||||
|     def __reduce__(self): | ||||
|         # Pickle support: rebuild from the vocab and the statistical model. | ||||
|         init_args = (self.vocab, self.model) | ||||
|         return (self.__class__, init_args, None, None) | ||||
| 
 | ||||
|     def tag_from_strings(self, Doc tokens, object tag_strs): | ||||
|         cdef int i | ||||
|         for i in range(tokens.length): | ||||
|  |  | |||
|  | @ -25,7 +25,6 @@ cdef struct _Cached: | |||
| 
 | ||||
| 
 | ||||
| cdef class Vocab: | ||||
|     cpdef public lexeme_props_getter | ||||
|     cdef Pool mem | ||||
|     cpdef readonly StringStore strings | ||||
|     cpdef readonly Morphology morphology | ||||
|  | @ -33,7 +32,6 @@ cdef class Vocab: | |||
|     cdef public object _serializer | ||||
|     cdef public object data_dir | ||||
|     cdef public object get_lex_attr | ||||
|     cdef public object pos_tags | ||||
|     cdef public object serializer_freqs | ||||
| 
 | ||||
|     cdef const LexemeC* get(self, Pool mem, unicode string) except NULL | ||||
|  |  | |||
|  | @ -10,6 +10,8 @@ from os import path | |||
| import io | ||||
| import math | ||||
| import json | ||||
| import tempfile | ||||
| import copy_reg | ||||
| 
 | ||||
| from .lexeme cimport EMPTY_LEXEME | ||||
| from .lexeme cimport Lexeme | ||||
|  | @ -96,6 +98,20 @@ cdef class Vocab: | |||
|         """The current number of lexemes stored.""" | ||||
|         return self.length | ||||
| 
 | ||||
|     def __reduce__(self): | ||||
|         # Pickle support: lexemes and strings are written to a fresh temporary | ||||
|         # directory, and the resulting file paths are handed to unpickle_vocab(). | ||||
|         # TODO: Dump vectors | ||||
|         scratch_dir = tempfile.mkdtemp() | ||||
|         lex_loc = path.join(scratch_dir, 'lexemes.bin') | ||||
|         str_loc = path.join(scratch_dir, 'strings.txt') | ||||
|         # Vectors are not dumped; only a path under data_dir is recorded. | ||||
|         if self.data_dir is not None: | ||||
|             vec_loc = path.join(self.data_dir, 'vec.bin') | ||||
|         else: | ||||
|             vec_loc = None | ||||
|         self.dump(lex_loc) | ||||
|         self.strings.dump(str_loc) | ||||
|         args = (str_loc, lex_loc, vec_loc, self.morphology, self.get_lex_attr, | ||||
|                 self.serializer_freqs, self.data_dir) | ||||
|         return (unpickle_vocab, args, None, None) | ||||
| 
 | ||||
|     cdef const LexemeC* get(self, Pool mem, unicode string) except NULL: | ||||
|         '''Get a pointer to a LexemeC from the lexicon, creating a new Lexeme | ||||
|         if necessary, using memory acquired from the given pool.  If the pool | ||||
|  | @ -271,17 +287,17 @@ cdef class Vocab: | |||
|             i += 1 | ||||
|         fp.close() | ||||
| 
 | ||||
|     def load_vectors(self, loc_or_file): | ||||
|     def load_vectors(self, file_): | ||||
|         cdef LexemeC* lexeme | ||||
|         cdef attr_t orth | ||||
|         cdef int32_t vec_len = -1 | ||||
|         for line_num, line in enumerate(loc_or_file): | ||||
|         for line_num, line in enumerate(file_): | ||||
|             pieces = line.split() | ||||
|             word_str = pieces.pop(0) | ||||
|             if vec_len == -1: | ||||
|                 vec_len = len(pieces) | ||||
|             elif vec_len != len(pieces): | ||||
|                 raise VectorReadError.mismatched_sizes(loc_or_file, line_num, | ||||
|                 raise VectorReadError.mismatched_sizes(file_, line_num, | ||||
|                                                         vec_len, len(pieces)) | ||||
|             orth = self.strings[word_str] | ||||
|             lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth) | ||||
|  | @ -339,6 +355,25 @@ cdef class Vocab: | |||
|         return vec_len | ||||
| 
 | ||||
| 
 | ||||
| def unpickle_vocab(strings_loc, lex_loc, vec_loc, morphology, get_lex_attr, | ||||
|                    serializer_freqs, data_dir): | ||||
|     # Rebuild a Vocab from the pieces emitted by Vocab.__reduce__. | ||||
|     cdef Vocab vocab = Vocab() | ||||
|     # The string store is the one owned by the morphology object. | ||||
|     vocab.morphology = morphology | ||||
|     vocab.strings = morphology.strings | ||||
|     vocab.get_lex_attr = get_lex_attr | ||||
|     vocab.serializer_freqs = serializer_freqs | ||||
|     vocab.data_dir = data_dir | ||||
|     vocab.load_lexemes(strings_loc, lex_loc) | ||||
|     # Vectors are loaded only when a location was recorded at pickling time. | ||||
|     if vec_loc is not None: | ||||
|         vocab.load_vectors_from_bin_loc(vec_loc) | ||||
|     return vocab | ||||
|   | ||||
| 
 | ||||
| copy_reg.constructor(unpickle_vocab) | ||||
| 
 | ||||
| 
 | ||||
| def write_binary_vectors(in_loc, out_loc): | ||||
|     cdef CFile out_file = CFile(out_loc, 'wb') | ||||
|     cdef Address mem | ||||
|  |  | |||
							
								
								
									
										17
									
								
								tests/morphology/test_pickle.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										17
									
								
								tests/morphology/test_pickle.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,17 @@ | |||
| import pytest | ||||
| 
 | ||||
| import pickle | ||||
| import StringIO | ||||
| 
 | ||||
| 
 | ||||
| from spacy.morphology import Morphology | ||||
| from spacy.lemmatizer import Lemmatizer | ||||
| from spacy.strings import StringStore | ||||
| 
 | ||||
| 
 | ||||
| def test_pickle(): | ||||
|     morphology = Morphology(StringStore(), {}, Lemmatizer({}, {}, {}))  | ||||
| 
 | ||||
|     file_ = StringIO.StringIO() | ||||
|     pickle.dump(morphology, file_) | ||||
| 
 | ||||
|  | @ -7,7 +7,8 @@ import pytest | |||
| 
 | ||||
| @pytest.fixture | ||||
| def sun_text(): | ||||
|     with io.open(path.join(path.dirname(__file__), 'sun.txt'), 'r', encoding='utf8') as file_: | ||||
|     with io.open(path.join(path.dirname(__file__), '..', 'sun.txt'), 'r', | ||||
|                  encoding='utf8') as file_: | ||||
|         text = file_.read() | ||||
|     return text | ||||
| 
 | ||||
|  |  | |||
							
								
								
									
										16
									
								
								tests/parser/test_pickle.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								tests/parser/test_pickle.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,16 @@ | |||
| import pytest | ||||
| 
 | ||||
| import pickle | ||||
| import cloudpickle | ||||
| import StringIO | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.models | ||||
| def test_pickle(EN): | ||||
|     file_ = StringIO.StringIO() | ||||
|     cloudpickle.dump(EN.parser, file_) | ||||
| 
 | ||||
|     file_.seek(0) | ||||
| 
 | ||||
|     loaded = pickle.load(file_) | ||||
| 
 | ||||
|  | @ -1,5 +1,7 @@ | |||
| # -*- coding: utf-8 -*- | ||||
| from __future__ import unicode_literals | ||||
| import StringIO | ||||
| import pickle | ||||
| 
 | ||||
| from spacy.lemmatizer import Lemmatizer, read_index, read_exc | ||||
| from spacy.en import LOCAL_DATA_DIR | ||||
|  | @ -41,3 +43,12 @@ def test_smart_quotes(lemmatizer): | |||
|     do = lemmatizer.punct | ||||
|     assert do('“') == set(['"']) | ||||
|     assert do('“') == set(['"']) | ||||
| 
 | ||||
| 
 | ||||
| def test_pickle_lemmatizer(lemmatizer): | ||||
|     file_ = StringIO.StringIO() | ||||
|     pickle.dump(lemmatizer, file_) | ||||
| 
 | ||||
|     file_.seek(0) | ||||
|      | ||||
|     loaded = pickle.load(file_) | ||||
|  |  | |||
							
								
								
									
										15
									
								
								tests/test_pickle.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								tests/test_pickle.py
									
									
									
									
									
										Normal file
									
								
							|  | @ -0,0 +1,15 @@ | |||
| import pytest | ||||
| import StringIO | ||||
| import cloudpickle | ||||
| import pickle | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.models | ||||
| def test_pickle_english(EN): | ||||
|     file_ = StringIO.StringIO() | ||||
|     cloudpickle.dump(EN, file_) | ||||
| 
 | ||||
|     file_.seek(0) | ||||
| 
 | ||||
|     loaded = pickle.load(file_) | ||||
| 
 | ||||
|  | @ -1,5 +1,7 @@ | |||
| # -*- coding: utf8 -*- | ||||
| from __future__ import unicode_literals | ||||
| import pickle | ||||
| import StringIO | ||||
| 
 | ||||
| from spacy.strings import StringStore | ||||
| 
 | ||||
|  | @ -76,3 +78,18 @@ def test_massive_strings(sstore): | |||
|     s513 = '1' * 513 | ||||
|     orth = sstore[s513] | ||||
|     assert sstore[orth] == s513 | ||||
| 
 | ||||
| 
 | ||||
| def test_pickle_string_store(sstore): | ||||
|     hello_id = sstore[u'Hi'] | ||||
|     string_file = StringIO.StringIO() | ||||
|     pickle.dump(sstore, string_file) | ||||
| 
 | ||||
|     string_file.seek(0) | ||||
|      | ||||
|     loaded = pickle.load(string_file) | ||||
| 
 | ||||
|     assert loaded[hello_id] == u'Hi' | ||||
| 
 | ||||
| 
 | ||||
| 
 | ||||
|  |  | |||
|  | @ -1,5 +1,11 @@ | |||
| from __future__ import unicode_literals | ||||
| import pytest | ||||
| import StringIO | ||||
| import cloudpickle | ||||
| import pickle | ||||
| 
 | ||||
| from spacy.attrs import LEMMA, ORTH, PROB, IS_ALPHA | ||||
| from spacy.parts_of_speech import NOUN, VERB | ||||
| 
 | ||||
| from spacy.attrs import LEMMA, ORTH, PROB, IS_ALPHA | ||||
| from spacy.parts_of_speech import NOUN, VERB | ||||
|  | @ -38,3 +44,11 @@ def test_symbols(en_vocab): | |||
|     assert en_vocab.strings['ORTH'] == ORTH | ||||
|     assert en_vocab.strings['PROB'] == PROB | ||||
|      | ||||
| 
 | ||||
| def test_pickle_vocab(en_vocab): | ||||
|     file_ = StringIO.StringIO() | ||||
|     cloudpickle.dump(en_vocab, file_) | ||||
| 
 | ||||
|     file_.seek(0) | ||||
| 
 | ||||
|     loaded = pickle.load(file_) | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user