diff --git a/spacy/_ml.pxd b/spacy/_ml.pxd
index c2c7ffded..b9a190b67 100644
--- a/spacy/_ml.pxd
+++ b/spacy/_ml.pxd
@@ -29,5 +29,6 @@ cdef class Model:
     cdef int update(self, atom_t* context, class_t guess, class_t gold, int cost) except -1
 
     cdef object model_loc
+    cdef object _templates
     cdef Extractor _extractor
     cdef LinearModel _model
diff --git a/spacy/_ml.pyx b/spacy/_ml.pyx
index 56c080fa6..bc789e7d6 100644
--- a/spacy/_ml.pyx
+++ b/spacy/_ml.pyx
@@ -3,6 +3,7 @@ from __future__ import unicode_literals
 from __future__ import division
 
 from os import path
+import tempfile
 import os
 import shutil
 import json
@@ -52,6 +53,7 @@ cdef class Model:
     def __init__(self, n_classes, templates, model_loc=None):
         if model_loc is not None and path.isdir(model_loc):
             model_loc = path.join(model_loc, 'model')
+        self._templates = templates
         self.n_classes = n_classes
         self._extractor = Extractor(templates)
         self.n_feats = self._extractor.n_templ
@@ -60,6 +62,19 @@
         if self.model_loc and path.exists(self.model_loc):
             self._model.load(self.model_loc, freq_thresh=0)
 
+    def __reduce__(self):
+        fd, model_loc = tempfile.mkstemp()
+        os.close(fd)
+        # TODO: This is a potentially buggy implementation. We're not really
+        # given a good guarantee that all internal state is saved correctly here,
+        # since there are learning parameters for e.g. the model averaging in
+        # the averaged perceptron, the gradient calculations in AdaGrad, etc.
+        # that aren't necessarily saved. So, if we're part way through training
+        # the model and then we pickle it, we won't recover the state correctly.
+        self._model.dump(model_loc)
+        return (Model, (self.n_classes, self._templates, model_loc),
+                None, None)
+
     def predict(self, Example eg):
         self.set_scores(eg.c.scores, eg.c.atoms)
         eg.c.guess = arg_max_if_true(eg.c.scores, eg.c.is_valid, self.n_classes)
diff --git a/spacy/language.py b/spacy/language.py
index ba4c048d7..65425bc45 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -207,6 +207,12 @@ class Language(object):
         self.entity = entity
         self.matcher = matcher
 
+    def __reduce__(self):
+        return (self.__class__,
+            (None, self.vocab, self.tokenizer, self.tagger, self.parser,
+             self.entity, self.matcher, None),
+            None, None)
+
     def __call__(self, text, tag=True, parse=True, entity=True):
         """Apply the pipeline to some text.  The text can span multiple sentences,
         and can contain arbtrary whitespace.  Alignment into the original string
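Every __reduce__ added in this patch follows the same contract: return a (callable, args, state, listitems) tuple so that unpickling simply calls callable(*args), replaying __init__ with stashed constructor arguments. A minimal pure-Python sketch of that contract; the Pipeline class is an invented stand-in for Language, not spaCy API:

    import pickle

    class Pipeline(object):
        """Hypothetical stand-in: rebuilt by re-running __init__ on unpickle."""
        def __init__(self, vocab, tagger):
            self.vocab = vocab
            self.tagger = tagger

        def __reduce__(self):
            # pickle records the callable and its argument tuple; loading
            # calls Pipeline(self.vocab, self.tagger), so no __setstate__
            # or __dict__ copying is needed.
            return (self.__class__, (self.vocab, self.tagger), None, None)

    nlp = Pipeline(vocab={'hello': 1}, tagger=None)
    restored = pickle.loads(pickle.dumps(nlp))
    assert restored.vocab == {'hello': 1}

The two None slots passed by Language.__reduce__ above presumably stand in for the data-directory and vectors arguments, which aren't needed when every component arrives fully constructed and pickles itself via its own __reduce__.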
diff --git a/spacy/matcher.pyx b/spacy/matcher.pyx
index 3ee825932..2bf8370b5 100644
--- a/spacy/matcher.pyx
+++ b/spacy/matcher.pyx
@@ -168,13 +168,7 @@ cdef class Matcher:
     cdef Pool mem
     cdef vector[Pattern*] patterns
     cdef readonly Vocab vocab
-
-    def __init__(self, vocab, patterns):
-        self.vocab = vocab
-        self.mem = Pool()
-        self.vocab = vocab
-        for entity_key, (etype, attrs, specs) in sorted(patterns.items()):
-            self.add(entity_key, etype, attrs, specs)
+    cdef object _patterns
 
     @classmethod
     def from_dir(cls, data_dir, Vocab vocab):
@@ -186,10 +180,22 @@
         else:
             return cls(vocab, {})
 
+    def __init__(self, vocab, patterns):
+        self.vocab = vocab
+        self.mem = Pool()
+        self.vocab = vocab
+        self._patterns = dict(patterns)
+        for entity_key, (etype, attrs, specs) in sorted(patterns.items()):
+            self.add(entity_key, etype, attrs, specs)
+
+    def __reduce__(self):
+        return (self.__class__, (self.vocab, self._patterns), None, None)
+
     property n_patterns:
         def __get__(self): return self.patterns.size()
 
     def add(self, entity_key, etype, attrs, specs):
+        self._patterns[entity_key] = (etype, dict(attrs), list(specs))
         if isinstance(entity_key, basestring):
             entity_key = self.vocab.strings[entity_key]
         if isinstance(etype, basestring):
diff --git a/spacy/morphology.pxd b/spacy/morphology.pxd
index 62d3fccc1..847626158 100644
--- a/spacy/morphology.pxd
+++ b/spacy/morphology.pxd
@@ -25,6 +25,7 @@ cdef class Morphology:
     cdef readonly Pool mem
    cdef readonly StringStore strings
     cdef public object lemmatizer
+    cdef readonly object tag_map
     cdef public object n_tags
     cdef public object reverse_index
     cdef public object tag_names
diff --git a/spacy/morphology.pyx b/spacy/morphology.pyx
index c53e5f478..e8b1f3520 100644
--- a/spacy/morphology.pyx
+++ b/spacy/morphology.pyx
@@ -14,6 +14,7 @@ cdef class Morphology:
     def __init__(self, StringStore string_store, tag_map, lemmatizer):
         self.mem = Pool()
         self.strings = string_store
+        self.tag_map = tag_map
         self.lemmatizer = lemmatizer
         self.n_tags = len(tag_map) + 1
         self.tag_names = tuple(sorted(tag_map.keys()))
@@ -28,6 +29,9 @@
             self.reverse_index[self.rich_tags[i].name] = i
         self._cache = PreshMapArray(self.n_tags)
 
+    def __reduce__(self):
+        return (Morphology, (self.strings, self.tag_map, self.lemmatizer), None, None)
+
     cdef int assign_tag(self, TokenC* token, tag) except -1:
         cdef int tag_id
         if isinstance(tag, basestring):
diff --git a/spacy/parts_of_speech.pyx b/spacy/parts_of_speech.pyx
index 14933480c..006a1f006 100644
--- a/spacy/parts_of_speech.pyx
+++ b/spacy/parts_of_speech.pyx
@@ -25,4 +25,4 @@
 }
 
 
-NAMES = [key for key, value in sorted(IDS.items(), key=lambda item: item[1])]
+NAMES = {value: key for key, value in IDS.items()}
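The parts_of_speech.pyx change swaps a position-indexed list for a dict keyed by the ID value itself, so reverse lookup stays correct even when the IDs are sparse or don't start at zero. A small sketch; the gap before 97 in this IDS is invented for illustration, the real values live in spacy/parts_of_speech.pxd:

    IDS = {'NO_TAG': 0, 'ADJ': 1, 'VERB': 97}   # hypothetical gap before 97

    # Old form: a list, so lookup is by position, not by ID value.
    names_list = [key for key, value in sorted(IDS.items(), key=lambda item: item[1])]
    # names_list[97] raises IndexError even though 97 is a valid ID.

    # New form: keyed by the ID itself, so any assigned ID resolves directly.
    NAMES = {value: key for key, value in IDS.items()}
    assert NAMES[97] == 'VERB'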
diff --git a/spacy/strings.pyx b/spacy/strings.pyx
index a4a470158..2208d3bdf 100644
--- a/spacy/strings.pyx
+++ b/spacy/strings.pyx
@@ -69,12 +69,15 @@ cdef Utf8Str _allocate(Pool mem, const unsigned char* chars, int length) except
 cdef class StringStore:
     '''Map strings to and from integer IDs.'''
 
-    def __init__(self):
+    def __init__(self, strings=None):
         self.mem = Pool()
         self._map = PreshMap()
         self._resize_at = 10000
         self.c = <Utf8Str*>self.mem.alloc(self._resize_at, sizeof(Utf8Str))
         self.size = 1
+        if strings is not None:
+            for string in strings:
+                _ = self[string]
 
     property size:
         def __get__(self):
@@ -113,6 +116,14 @@
         for i in range(self.size):
             yield self[i]
 
+    def __reduce__(self):
+        strings = [""]
+        for i in range(1, self.size):
+            string = &self.c[i]
+            py_string = _decode(string)
+            strings.append(py_string)
+        return (StringStore, (strings,), None, None, None)
+
     cdef const Utf8Str* intern(self, unsigned char* chars, int length) except NULL:
         # 0 means missing, but we don't bother offsetting the index.
         key = hash64(chars, length * sizeof(char), 0)
diff --git a/spacy/syntax/parser.pyx b/spacy/syntax/parser.pyx
index cf61647b9..25932a0a4 100644
--- a/spacy/syntax/parser.pyx
+++ b/spacy/syntax/parser.pyx
@@ -83,7 +83,6 @@
             model = Model(moves.n_moves, templates, model_dir)
         return cls(strings, moves, model)
 
-
     def __call__(self, Doc tokens):
         cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
         self.moves.initialize_state(stcls)
@@ -93,6 +92,9 @@
             self.parse(stcls, eg.c)
         tokens.set_parse(stcls._sent)
 
+    def __reduce__(self):
+        return (Parser, (self.moves.strings, self.moves, self.model), None, None)
+
     cdef void predict(self, StateClass stcls, ExampleC* eg) nogil:
         memset(eg.scores, 0, eg.nr_class * sizeof(weight_t))
         self.moves.set_valid(eg.is_valid, stcls)
diff --git a/spacy/syntax/transition_system.pxd b/spacy/syntax/transition_system.pxd
index 4cf9aae7e..38bc91605 100644
--- a/spacy/syntax/transition_system.pxd
+++ b/spacy/syntax/transition_system.pxd
@@ -37,6 +37,8 @@
     cdef public int root_label
     cdef public freqs
 
+    cdef object _labels_by_action
+
     cdef int initialize_state(self, StateClass state) except -1
     cdef int finalize_state(self, StateClass state) nogil
 
diff --git a/spacy/syntax/transition_system.pyx b/spacy/syntax/transition_system.pyx
index 86aef1fbc..5de3513e0 100644
--- a/spacy/syntax/transition_system.pyx
+++ b/spacy/syntax/transition_system.pyx
@@ -15,7 +15,8 @@ class OracleError(Exception):
 
 
 cdef class TransitionSystem:
-    def __init__(self, StringStore string_table, dict labels_by_action):
+    def __init__(self, StringStore string_table, dict labels_by_action, _freqs=None):
+        self._labels_by_action = labels_by_action
         self.mem = Pool()
         self.n_moves = sum(len(labels) for labels in labels_by_action.values())
         self._is_valid = <bint*>self.mem.alloc(self.n_moves, sizeof(bint))
@@ -30,7 +31,7 @@
             i += 1
         self.c = moves
         self.root_label = self.strings['ROOT']
-        self.freqs = {}
+        self.freqs = {} if _freqs is None else _freqs
         for attr in (TAG, HEAD, DEP, ENT_TYPE, ENT_IOB):
             self.freqs[attr] = defaultdict(int)
             self.freqs[attr][0] = 1
@@ -39,6 +40,11 @@
             self.freqs[HEAD][i] = 1
             self.freqs[HEAD][-i] = 1
 
+    def __reduce__(self):
+        return (self.__class__,
+                (self.strings, self._labels_by_action, self.freqs),
+                None, None)
+
     cdef int initialize_state(self, StateClass state) except -1:
         pass
 
diff --git a/spacy/tagger.pyx b/spacy/tagger.pyx
index 756bb7ea4..69925ff89 100644
--- a/spacy/tagger.pyx
+++ b/spacy/tagger.pyx
@@ -148,6 +148,9 @@
         tokens.is_tagged = True
         tokens._py_tokens = [None] * tokens.length
 
+    def __reduce__(self):
+        return (self.__class__, (self.vocab, self.model), None, None)
+
     def tag_from_strings(self, Doc tokens, object tag_strs):
         cdef int i
         for i in range(tokens.length):
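Matcher, TransitionSystem, and Model all use the same trick: __init__ now keeps a plain-Python copy of its arguments (_patterns, _labels_by_action, _templates) purely so __reduce__ can hand them back later. A hypothetical pure-Python mirror of the TransitionSystem version, including the _freqs round trip; TransitionSystemSketch is an invented stand-in, not spaCy API:

    import pickle

    class TransitionSystemSketch(object):
        """Invented stand-in for the Cython class."""
        def __init__(self, labels_by_action, _freqs=None):
            # Stashed verbatim so __reduce__ can replay __init__ later.
            self._labels_by_action = labels_by_action
            self.n_moves = sum(len(labels) for labels in labels_by_action.values())
            self.freqs = {} if _freqs is None else _freqs

        def __reduce__(self):
            # Passing freqs back through _freqs means counts accumulated
            # during training survive instead of being reset by __init__.
            return (self.__class__, (self._labels_by_action, self.freqs),
                    None, None)

    ts = TransitionSystemSketch({'SHIFT': [''], 'LEFT': ['nsubj', 'dobj']})
    ts.freqs['dep'] = {0: 1}
    clone = pickle.loads(pickle.dumps(ts))
    assert clone.n_moves == 3 and clone.freqs == ts.freqs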
-    cdef public object pos_tags
     cdef public object serializer_freqs
 
     cdef const LexemeC* get(self, Pool mem, unicode string) except NULL
diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx
index 0f43967bb..023d0bd89 100644
--- a/spacy/vocab.pyx
+++ b/spacy/vocab.pyx
@@ -10,6 +10,8 @@ from os import path
 import io
 import math
 import json
+import tempfile
+import copy_reg
 
 from .lexeme cimport EMPTY_LEXEME
 from .lexeme cimport Lexeme
@@ -96,6 +98,20 @@ cdef class Vocab:
         """The current number of lexemes stored."""
         return self.length
 
+    def __reduce__(self):
+        # TODO: Dump vectors
+        tmp_dir = tempfile.mkdtemp()
+        lex_loc = path.join(tmp_dir, 'lexemes.bin')
+        str_loc = path.join(tmp_dir, 'strings.txt')
+        vec_loc = path.join(self.data_dir, 'vec.bin') if self.data_dir is not None else None
+
+        self.dump(lex_loc)
+        self.strings.dump(str_loc)
+
+        state = (str_loc, lex_loc, vec_loc, self.morphology, self.get_lex_attr,
+                 self.serializer_freqs, self.data_dir)
+        return (unpickle_vocab, state, None, None)
+
     cdef const LexemeC* get(self, Pool mem, unicode string) except NULL:
         '''Get a pointer to a LexemeC from the lexicon, creating a new Lexeme
         if necessary, using memory acquired from the given pool.  If the pool
@@ -271,17 +287,17 @@
                 i += 1
             fp.close()
 
-    def load_vectors(self, loc_or_file):
+    def load_vectors(self, file_):
         cdef LexemeC* lexeme
         cdef attr_t orth
         cdef int32_t vec_len = -1
-        for line_num, line in enumerate(loc_or_file):
+        for line_num, line in enumerate(file_):
             pieces = line.split()
             word_str = pieces.pop(0)
             if vec_len == -1:
                 vec_len = len(pieces)
             elif vec_len != len(pieces):
-                raise VectorReadError.mismatched_sizes(loc_or_file, line_num,
+                raise VectorReadError.mismatched_sizes(file_, line_num,
                                                        vec_len, len(pieces))
             orth = self.strings[word_str]
             lexeme = self.get_by_orth(self.mem, orth)
@@ -339,6 +355,25 @@
         return vec_len
 
 
+def unpickle_vocab(strings_loc, lex_loc, vec_loc, morphology, get_lex_attr,
+                   serializer_freqs, data_dir):
+    cdef Vocab vocab = Vocab()
+
+    vocab.get_lex_attr = get_lex_attr
+    vocab.morphology = morphology
+    vocab.strings = morphology.strings
+    vocab.data_dir = data_dir
+    vocab.serializer_freqs = serializer_freqs
+
+    vocab.load_lexemes(strings_loc, lex_loc)
+    if vec_loc is not None:
+        vocab.load_vectors_from_bin_loc(vec_loc)
+    return vocab
+
+
+copy_reg.constructor(unpickle_vocab)
+
+
 def write_binary_vectors(in_loc, out_loc):
     cdef CFile out_file = CFile(out_loc, 'wb')
     cdef Address mem
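Vocab is the one class that can't be rebuilt by replaying __init__: its lexemes and strings live in C structs, so __reduce__ dumps them to disk and returns a module-level reconstruction function, registered with copy_reg.constructor. A self-contained sketch of that pattern under invented names (Store and unpickle_store are not spaCy API; Python 3 spells the module copyreg):

    import copy_reg
    import os
    import pickle
    import tempfile

    class Store(object):
        """Hypothetical stand-in: state that lives outside normal attributes."""
        def __init__(self):
            self.data = {}

        def dump(self, loc):
            with open(loc, 'w') as f:
                for key, value in self.data.items():
                    f.write('%s %s\n' % (key, value))

        def __reduce__(self):
            # Dump the hard-to-pickle state to disk, then hand pickle a
            # module-level function that rebuilds from the file, mirroring
            # Vocab.__reduce__ / unpickle_vocab above.
            fd, loc = tempfile.mkstemp()
            os.close(fd)
            self.dump(loc)
            return (unpickle_store, (loc,), None, None)

    def unpickle_store(loc):
        store = Store()
        with open(loc) as f:
            for line in f:
                key, value = line.split()
                store.data[key] = value
        return store

    copy_reg.constructor(unpickle_store)  # declare it safe to call on unpickle

    store = Store()
    store.data['hello'] = '1'
    assert pickle.loads(pickle.dumps(store)).data == {'hello': '1'}

The same caveat applies to this sketch as to Vocab.__reduce__ itself: the pickle stream records only a temp-file path, so it can only be loaded while that file still exists on the same machine.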
diff --git a/tests/morphology/test_pickle.py b/tests/morphology/test_pickle.py
new file mode 100644
index 000000000..f1b5bcd4c
--- /dev/null
+++ b/tests/morphology/test_pickle.py
@@ -0,0 +1,17 @@
+import pytest
+
+import pickle
+import StringIO
+
+
+from spacy.morphology import Morphology
+from spacy.lemmatizer import Lemmatizer
+from spacy.strings import StringStore
+
+
+def test_pickle():
+    morphology = Morphology(StringStore(), {}, Lemmatizer({}, {}, {}))
+
+    file_ = StringIO.StringIO()
+    pickle.dump(morphology, file_)
+
diff --git a/tests/parser/test_parse_navigate.py b/tests/parser/test_parse_navigate.py
index 1771dbeba..eac57a5cd 100644
--- a/tests/parser/test_parse_navigate.py
+++ b/tests/parser/test_parse_navigate.py
@@ -7,7 +7,8 @@ import pytest
 
 
 @pytest.fixture
 def sun_text():
-    with io.open(path.join(path.dirname(__file__), 'sun.txt'), 'r', encoding='utf8') as file_:
+    with io.open(path.join(path.dirname(__file__), '..', 'sun.txt'), 'r',
+                 encoding='utf8') as file_:
         text = file_.read()
     return text
diff --git a/tests/parser/test_pickle.py b/tests/parser/test_pickle.py
new file mode 100644
index 000000000..b1b768650
--- /dev/null
+++ b/tests/parser/test_pickle.py
@@ -0,0 +1,16 @@
+import pytest
+
+import pickle
+import cloudpickle
+import StringIO
+
+
+@pytest.mark.models
+def test_pickle(EN):
+    file_ = StringIO.StringIO()
+    cloudpickle.dump(EN.parser, file_)
+
+    file_.seek(0)
+
+    loaded = pickle.load(file_)
+
diff --git a/tests/tagger/test_lemmatizer.py b/tests/tagger/test_lemmatizer.py
index ff10b6573..5dfdaabb1 100644
--- a/tests/tagger/test_lemmatizer.py
+++ b/tests/tagger/test_lemmatizer.py
@@ -1,5 +1,7 @@
 # -*- coding: utf-8 -*-
 from __future__ import unicode_literals
+import StringIO
+import pickle
 
 from spacy.lemmatizer import Lemmatizer, read_index, read_exc
 from spacy.en import LOCAL_DATA_DIR
@@ -41,3 +43,12 @@ def test_smart_quotes(lemmatizer):
     do = lemmatizer.punct
     assert do('“') == set(['"'])
     assert do('”') == set(['"'])
+
+
+def test_pickle_lemmatizer(lemmatizer):
+    file_ = StringIO.StringIO()
+    pickle.dump(lemmatizer, file_)
+
+    file_.seek(0)
+
+    loaded = pickle.load(file_)
diff --git a/tests/test_pickle.py b/tests/test_pickle.py
new file mode 100644
index 000000000..02d908b0d
--- /dev/null
+++ b/tests/test_pickle.py
@@ -0,0 +1,15 @@
+import pytest
+import StringIO
+import cloudpickle
+import pickle
+
+
+@pytest.mark.models
+def test_pickle_english(EN):
+    file_ = StringIO.StringIO()
+    cloudpickle.dump(EN, file_)
+
+    file_.seek(0)
+
+    loaded = pickle.load(file_)
+
diff --git a/tests/vocab/test_intern.py b/tests/vocab/test_intern.py
index 6e007c645..256706c6f 100644
--- a/tests/vocab/test_intern.py
+++ b/tests/vocab/test_intern.py
@@ -1,5 +1,7 @@
 # -*- coding: utf8 -*-
 from __future__ import unicode_literals
+import pickle
+import StringIO
 
 from spacy.strings import StringStore
 
@@ -76,3 +78,18 @@ def test_massive_strings(sstore):
     s513 = '1' * 513
     orth = sstore[s513]
     assert sstore[orth] == s513
+
+
+def test_pickle_string_store(sstore):
+    hello_id = sstore[u'Hi']
+    string_file = StringIO.StringIO()
+    pickle.dump(sstore, string_file)
+
+    string_file.seek(0)
+
+    loaded = pickle.load(string_file)
+
+    assert loaded[hello_id] == u'Hi'
+
+
+
diff --git a/tests/vocab/test_vocab.py b/tests/vocab/test_vocab.py
index 153e0d546..5981f30e7 100644
--- a/tests/vocab/test_vocab.py
+++ b/tests/vocab/test_vocab.py
@@ -1,5 +1,11 @@
 from __future__ import unicode_literals
 import pytest
+import StringIO
+import cloudpickle
+import pickle
+
+from spacy.attrs import LEMMA, ORTH, PROB, IS_ALPHA
+from spacy.parts_of_speech import NOUN, VERB
 
 from spacy.attrs import LEMMA, ORTH, PROB, IS_ALPHA
 from spacy.parts_of_speech import NOUN, VERB
@@ -38,3 +44,11 @@ def test_symbols(en_vocab):
     assert en_vocab.strings['ORTH'] == ORTH
     assert en_vocab.strings['PROB'] == PROB
 
+
+def test_pickle_vocab(en_vocab):
+    file_ = StringIO.StringIO()
+    cloudpickle.dump(en_vocab, file_)
+
+    file_.seek(0)
+
+    loaded = pickle.load(file_)
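Taken together, the new tests exercise a round trip of the whole pipeline: cloudpickle on the way out (presumably because Language holds the lexeme-attribute functions that stock pickle refuses to serialize), plain pickle on the way back in. A sketch of the intended usage, assuming the 2015-era spacy.en API and installed English model data:

    import pickle
    import StringIO

    import cloudpickle
    from spacy.en import English

    nlp = English()

    file_ = StringIO.StringIO()
    cloudpickle.dump(nlp, file_)   # dump with cloudpickle...
    file_.seek(0)
    nlp2 = pickle.load(file_)      # ...load with the standard library

    doc = nlp2(u'The pipeline survives a pickle round trip.')
    assert doc[0].orth_ == u'The'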