* Very scrappy, likely buggy first-cut pickle implementation for Issue #125: allow pickle for Apache Spark. The current implementation writes state to temp files, and does almost nothing to ensure that all modifiable state is actually preserved. The Language() instance is a deep tree of extension objects, and when pickling during training, some of the C-data state is hard to preserve.

Matthew Honnibal 2015-10-12 19:33:11 +11:00
parent f8de403483
commit 20fd36a0f7
12 changed files with 104 additions and 21 deletions
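For context on the mechanism used throughout this commit: Python's pickle protocol lets a class define `__reduce__`, returning a `(callable, args, ...)` tuple, and unpickling simply calls `callable(*args)`. The hunks below apply that pattern to each pipeline component, writing C-backed state to temp files where needed. A minimal, self-contained sketch of the temp-file variant (the `Weights` class is illustrative, not part of spaCy):

    import pickle
    import tempfile

    class Weights(object):
        '''Toy stand-in for a C-backed model that can only save via disk.'''
        def __init__(self, n_classes, model_loc=None):
            self.n_classes = n_classes
            self.model_loc = model_loc  # real code would load binary state from here

        def __reduce__(self):
            # Write the hard-to-pickle state to a temp file, then tell the
            # unpickler to rebuild the object from that path. The file must
            # still exist on the unpickling side, and nothing cleans it up --
            # one reason this is a scrappy first cut.
            _, model_loc = tempfile.mkstemp()
            # ... dump binary state to model_loc here ...
            return (Weights, (self.n_classes, model_loc))

    restored = pickle.loads(pickle.dumps(Weights(10)))
    assert restored.n_classes == 10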

View File

@@ -29,5 +29,6 @@ cdef class Model:
     cdef int update(self, atom_t* context, class_t guess, class_t gold, int cost) except -1
     cdef object model_loc
+    cdef object _templates
     cdef Extractor _extractor
     cdef LinearModel _model

View File

@@ -3,6 +3,7 @@ from __future__ import unicode_literals
 from __future__ import division
 from os import path
+import tempfile
 import os
 import shutil
 import json
@@ -52,6 +53,7 @@ cdef class Model:
     def __init__(self, n_classes, templates, model_loc=None):
         if model_loc is not None and path.isdir(model_loc):
             model_loc = path.join(model_loc, 'model')
+        self._templates = templates
         self.n_classes = n_classes
         self._extractor = Extractor(templates)
         self.n_feats = self._extractor.n_templ
@@ -60,6 +62,18 @@ cdef class Model:
         if self.model_loc and path.exists(self.model_loc):
             self._model.load(self.model_loc, freq_thresh=0)
 
+    def __reduce__(self):
+        _, model_loc = tempfile.mkstemp()
+        # TODO: This is a potentially buggy implementation. We're not really
+        # given a good guarantee that all internal state is saved correctly here,
+        # since there are learning parameters for e.g. the model averaging in
+        # averaged perceptron, the gradient calculations in AdaGrad, etc.
+        # that aren't necessarily saved. So, if we're partway through training
+        # the model and then we pickle it, we won't recover the state correctly.
+        self._model.dump(model_loc)
+        return (Model, (self.n_classes, self._templates, model_loc),
+                None, None)
+
     def predict(self, Example eg):
         self.set_scores(eg.c.scores, eg.c.atoms)
         eg.c.guess = arg_max_if_true(eg.c.scores, eg.c.is_valid, self.n_classes)
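Two stdlib details matter in the `__reduce__` above: `tempfile.mkstemp()` returns an `(fd, path)` tuple rather than a bare path, so the path must be unpacked before being handed to `dump()`, and ideally the OS-level descriptor should be closed once the path is captured:

    import os
    import tempfile

    fd, model_loc = tempfile.mkstemp()
    os.close(fd)  # we only need the path; the dump call reopens the file

Note also that the reduce tuple references the private `_templates` attribute introduced by this commit, since `Model` exposes no public `templates` property.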

View File

@@ -207,6 +207,12 @@ class Language(object):
         self.entity = entity
         self.matcher = matcher
 
+    def __reduce__(self):
+        return (self.__class__,
+            (None, self.vocab, self.tokenizer, self.tagger, self.parser,
+             self.entity, self.matcher, None),
+            None, None)
+
     def __call__(self, text, tag=True, parse=True, entity=True):
         """Apply the pipeline to some text. The text can span multiple sentences,
         and can contain arbitrary whitespace. Alignment into the original string
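Pickle recurses into every element of the reduce tuple, so this only works if each pipeline component (vocab, tokenizer, tagger, parser, entity, matcher) defines its own pickle support, which the remaining hunks add. The two `None` arguments stand in for `Language.__init__` parameters that should not be re-resolved on unpickling (presumably the data directory and a trailing optional argument; the exact signature is assumed here). A toy illustration of the cascading behaviour:

    import pickle

    class Piece(object):
        def __reduce__(self):
            return (Piece, ())

    class Pipeline(object):
        def __init__(self, data_dir, piece):
            self.piece = piece  # data_dir deliberately unused when rebuilt

        def __reduce__(self):
            # Passing None for data_dir keeps unpickling off the filesystem;
            # pickling self.piece triggers Piece.__reduce__ in turn.
            return (Pipeline, (None, self.piece))

    p2 = pickle.loads(pickle.dumps(Pipeline('/data', Piece())))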

View File

@@ -168,13 +168,7 @@ cdef class Matcher:
     cdef Pool mem
     cdef vector[Pattern*] patterns
     cdef readonly Vocab vocab
+    cdef object _patterns
 
-    def __init__(self, vocab, patterns):
-        self.vocab = vocab
-        self.mem = Pool()
-        self.vocab = vocab
-        for entity_key, (etype, attrs, specs) in sorted(patterns.items()):
-            self.add(entity_key, etype, attrs, specs)
 
     @classmethod
     def from_dir(cls, data_dir, Vocab vocab):
@@ -186,10 +180,22 @@ cdef class Matcher:
         else:
             return cls(vocab, {})
 
+    def __init__(self, vocab, patterns):
+        self.vocab = vocab
+        self.mem = Pool()
+        self.vocab = vocab
+        self._patterns = dict(patterns)
+        for entity_key, (etype, attrs, specs) in sorted(patterns.items()):
+            self.add(entity_key, etype, attrs, specs)
+
+    def __reduce__(self):
+        return (self.__class__, (self.vocab, self._patterns), None, None)
+
     property n_patterns:
         def __get__(self): return self.patterns.size()
 
     def add(self, entity_key, etype, attrs, specs):
+        self._patterns[entity_key] = (etype, dict(attrs), list(specs))
         if isinstance(entity_key, basestring):
             entity_key = self.vocab.strings[entity_key]
         if isinstance(etype, basestring):
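The C-level `vector[Pattern*]` storage can't be pickled, so the commit keeps a Python-side shadow dict: `__init__` copies the incoming patterns into `_patterns`, `add()` records every later addition, and `__reduce__` replays construction from the shadow. A hedged round-trip sketch, assuming an existing picklable `vocab`:

    import pickle

    matcher = Matcher(vocab, {})   # start with no patterns
    m2 = pickle.loads(pickle.dumps(matcher))
    assert m2.n_patterns == matcher.n_patterns == 0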

View File

@@ -83,7 +83,6 @@ cdef class Parser:
         model = Model(moves.n_moves, templates, model_dir)
         return cls(strings, moves, model)
-
     def __call__(self, Doc tokens):
         cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
         self.moves.initialize_state(stcls)
@@ -93,6 +92,9 @@ cdef class Parser:
         self.parse(stcls, eg.c)
         tokens.set_parse(stcls._sent)
 
+    def __reduce__(self):
+        return (Parser, (self.moves.strings, self.moves, self.model), None, None)
+
     cdef void predict(self, StateClass stcls, ExampleC* eg) nogil:
         memset(eg.scores, 0, eg.nr_class * sizeof(weight_t))
         self.moves.set_valid(eg.is_valid, stcls)
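`Parser.__reduce__` composes with the earlier hunks: pickling the parser pickles `self.model` (which dumps its weights to a temp file via `Model.__reduce__`) and `self.moves` (a `TransitionSystem`, handled below). A hedged sketch of the round trip, assuming a loaded `parser`:

    import pickle

    data = pickle.dumps(parser)   # cascades into Model and TransitionSystem
    parser2 = pickle.loads(data)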

View File

@@ -37,6 +37,8 @@ cdef class TransitionSystem:
     cdef public int root_label
     cdef public freqs
+    cdef object _labels_by_action
 
     cdef int initialize_state(self, StateClass state) except -1
     cdef int finalize_state(self, StateClass state) nogil

View File

@@ -15,7 +15,8 @@ class OracleError(Exception):
 cdef class TransitionSystem:
-    def __init__(self, StringStore string_table, dict labels_by_action):
+    def __init__(self, StringStore string_table, dict labels_by_action, _freqs=None):
+        self._labels_by_action = labels_by_action
         self.mem = Pool()
         self.n_moves = sum(len(labels) for labels in labels_by_action.values())
         self._is_valid = <bint*>self.mem.alloc(self.n_moves, sizeof(bint))
@@ -30,7 +31,7 @@ cdef class TransitionSystem:
             i += 1
         self.c = moves
         self.root_label = self.strings['ROOT']
-        self.freqs = {}
+        self.freqs = {} if _freqs is None else _freqs
         for attr in (TAG, HEAD, DEP, ENT_TYPE, ENT_IOB):
             self.freqs[attr] = defaultdict(int)
             self.freqs[attr][0] = 1
@@ -39,6 +40,11 @@ cdef class TransitionSystem:
             self.freqs[HEAD][i] = 1
             self.freqs[HEAD][-i] = 1
 
+    def __reduce__(self):
+        return (self.__class__,
+                (self.strings, self._labels_by_action, self.freqs),
+                None, None)
+
     cdef int initialize_state(self, StateClass state) except -1:
         pass
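The `_freqs=None` keyword exists purely so `__reduce__` can thread the mutable frequency counts back through the constructor; normal callers never pass it. The idiom in isolation, with an illustrative class:

    import pickle

    class Counter(object):
        def __init__(self, name, _counts=None):
            # _counts is only for unpickling; regular callers omit it
            self.name = name
            self.counts = {} if _counts is None else _counts

        def __reduce__(self):
            # current mutable state rides along in the constructor args
            return (Counter, (self.name, self.counts))

    c = Counter('tags')
    c.counts['NN'] = 3
    assert pickle.loads(pickle.dumps(c)).counts == {'NN': 3}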

View File

@@ -148,6 +148,9 @@ cdef class Tagger:
         tokens.is_tagged = True
         tokens._py_tokens = [None] * tokens.length
 
+    def __reduce__(self):
+        return (self.__class__, (self.vocab, self.model), None, None)
+
     def tag_from_strings(self, Doc tokens, object tag_strs):
         cdef int i
         for i in range(tokens.length):

View File

@@ -99,16 +99,18 @@ cdef class Vocab:
         return self.length
 
     def __reduce__(self):
+        # TODO: Dump vectors
         tmp_dir = tempfile.mkdtemp()
         lex_loc = path.join(tmp_dir, 'lexemes.bin')
         str_loc = path.join(tmp_dir, 'strings.txt')
-        map_loc = path.join(tmp_dir, 'tag_map.json')
+        vec_loc = path.join(self.data_dir, 'vec.bin') if self.data_dir is not None else None
         self.dump(lex_loc)
         self.strings.dump(str_loc)
-        json.dump(self.morphology.tag_map, open(map_loc, 'w'))
-        return (unpickle_vocab, (tmp_dir,), None, None)
+        state = (str_loc, lex_loc, vec_loc, self.morphology, self.get_lex_attr,
+                 self.serializer_freqs, self.data_dir)
+        return (unpickle_vocab, state, None, None)
 
     cdef const LexemeC* get(self, Pool mem, unicode string) except NULL:
         '''Get a pointer to a LexemeC from the lexicon, creating a new Lexeme
 
@@ -353,11 +355,21 @@ cdef class Vocab:
             return vec_len
 
-def unpickle_vocab(data_dir):
-    # TODO: This needs fixing --- the trouble is, we can't pickle staticmethods,
-    # so we need to fiddle with the design of Language a little bit.
-    from .language import Language
-    return Vocab.from_dir(data_dir, Language.default_lex_attrs())
+def unpickle_vocab(strings_loc, lex_loc, vec_loc, morphology, get_lex_attr,
+                   serializer_freqs, data_dir):
+    cdef Vocab vocab = Vocab()
+    vocab.get_lex_attr = get_lex_attr
+    vocab.morphology = morphology
+    vocab.strings = morphology.strings
+    vocab.data_dir = data_dir
+    vocab.serializer_freqs = serializer_freqs
+    vocab.load_lexemes(strings_loc, lex_loc)
+    if vec_loc is not None:
+        vocab.load_vectors_from_bin_loc(vec_loc)
+    return vocab
 
 copy_reg.constructor(unpickle_vocab)
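`copy_reg.constructor` (spelled `copyreg` on Python 3) declares `unpickle_vocab` a valid reconstructor for the pickle protocol; the practical requirement is that the function live at module level, because the pickle stream stores only its importable name. The shape of the pattern, reduced to essentials:

    import copy_reg  # Python 2 spelling; use `import copyreg` on Python 3

    class Thing(object):
        def __init__(self, value):
            self.value = value

        def __reduce__(self):
            return (unpickle_thing, (self.value,))

    def unpickle_thing(value):
        # must be importable by name when the pickle is loaded
        return Thing(value)

    copy_reg.constructor(unpickle_thing)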

View File

@@ -0,0 +1,16 @@
+import pytest
+import pickle
+import cloudpickle
+import StringIO
+
+
+@pytest.mark.models
+def test_pickle(EN):
+    file_ = StringIO.StringIO()
+    cloudpickle.dump(EN.parser, file_)
+    file_.seek(0)
+    loaded = pickle.load(file_)
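These tests dump with cloudpickle but load with the stdlib `pickle`. That works because cloudpickle emits an ordinary pickle stream, just with by-value support for objects the stdlib pickler rejects (lambdas, locally defined functions, and the like); only the writing side needs the extra dependency, though cloudpickle must still be importable at load time since the stream can reference its helpers:

    import pickle
    import cloudpickle

    f = lambda x: x + 1   # pickle.dumps(f) would raise PicklingError
    g = pickle.loads(cloudpickle.dumps(f))
    assert g(1) == 2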

tests/test_pickle.py (new file)
View File

@@ -0,0 +1,15 @@
+import pytest
+import StringIO
+import cloudpickle
+import pickle
+
+
+@pytest.mark.models
+def test_pickle_english(EN):
+    file_ = StringIO.StringIO()
+    cloudpickle.dump(EN, file_)
+    file_.seek(0)
+    loaded = pickle.load(file_)

View File

@@ -1,13 +1,13 @@
 from __future__ import unicode_literals
 import pytest
 import StringIO
+import cloudpickle
 import pickle
 
 from spacy.attrs import LEMMA, ORTH, PROB, IS_ALPHA
 from spacy.parts_of_speech import NOUN, VERB
 
 def test_neq(en_vocab):
     addr = en_vocab['Hello']
     assert en_vocab['bye'].orth != addr.orth
 
@@ -44,7 +44,7 @@ def test_symbols(en_vocab):
 def test_pickle_vocab(en_vocab):
     file_ = StringIO.StringIO()
-    pickle.dump(en_vocab, file_)
+    cloudpickle.dump(en_vocab, file_)
     file_.seek(0)