mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
* Very scrappy, likely buggy first-cut pickle implementation, to work on Issue #125: allow pickle for Apache Spark. The current implementation sends stuff to temp files, and does almost nothing to ensure all modifiable state is actually preserved. The Language() instance is a deep tree of extension objects, and if pickling during training, some of the C-data state is hard to preserve.
This commit is contained in:
parent
f8de403483
commit
20fd36a0f7
|
@ -29,5 +29,6 @@ cdef class Model:
|
|||
cdef int update(self, atom_t* context, class_t guess, class_t gold, int cost) except -1
|
||||
|
||||
cdef object model_loc
|
||||
cdef object _templates
|
||||
cdef Extractor _extractor
|
||||
cdef LinearModel _model
|
||||
|
|
|
@ -3,6 +3,7 @@ from __future__ import unicode_literals
|
|||
from __future__ import division
|
||||
|
||||
from os import path
|
||||
import tempfile
|
||||
import os
|
||||
import shutil
|
||||
import json
|
||||
|
@ -52,6 +53,7 @@ cdef class Model:
|
|||
def __init__(self, n_classes, templates, model_loc=None):
|
||||
if model_loc is not None and path.isdir(model_loc):
|
||||
model_loc = path.join(model_loc, 'model')
|
||||
self._templates = templates
|
||||
self.n_classes = n_classes
|
||||
self._extractor = Extractor(templates)
|
||||
self.n_feats = self._extractor.n_templ
|
||||
|
@ -60,6 +62,18 @@ cdef class Model:
|
|||
if self.model_loc and path.exists(self.model_loc):
|
||||
self._model.load(self.model_loc, freq_thresh=0)
|
||||
|
||||
def __reduce__(self):
|
||||
model_loc = tempfile.mkstemp()
|
||||
# TODO: This is a potentially buggy implementation. We're not really
|
||||
# given a good guarantee that all internal state is saved correctly here,
|
||||
# since there are learning parameters for e.g. the model averaging in
|
||||
# averaged perceptron, the gradient calculations in AdaGrad, etc
|
||||
# that aren't necessarily saved. So, if we're part way through training
|
||||
# the model, and then we pickle it, we won't recover the state correctly.
|
||||
self._model.dump(model_loc)
|
||||
return (Model, (self.n_classes, self.templates, model_loc),
|
||||
None, None)
|
||||
|
||||
def predict(self, Example eg):
|
||||
self.set_scores(eg.c.scores, eg.c.atoms)
|
||||
eg.c.guess = arg_max_if_true(eg.c.scores, eg.c.is_valid, self.n_classes)
|
||||
|
|
|
@ -207,6 +207,12 @@ class Language(object):
|
|||
self.entity = entity
|
||||
self.matcher = matcher
|
||||
|
||||
def __reduce__(self):
|
||||
return (self.__class__,
|
||||
(None, self.vocab, self.tokenizer, self.tagger, self.parser,
|
||||
self.entity, self.matcher, None),
|
||||
None, None)
|
||||
|
||||
def __call__(self, text, tag=True, parse=True, entity=True):
|
||||
"""Apply the pipeline to some text. The text can span multiple sentences,
|
||||
and can contain arbtrary whitespace. Alignment into the original string
|
||||
|
|
|
@ -168,13 +168,7 @@ cdef class Matcher:
|
|||
cdef Pool mem
|
||||
cdef vector[Pattern*] patterns
|
||||
cdef readonly Vocab vocab
|
||||
|
||||
def __init__(self, vocab, patterns):
|
||||
self.vocab = vocab
|
||||
self.mem = Pool()
|
||||
self.vocab = vocab
|
||||
for entity_key, (etype, attrs, specs) in sorted(patterns.items()):
|
||||
self.add(entity_key, etype, attrs, specs)
|
||||
cdef object _patterns
|
||||
|
||||
@classmethod
|
||||
def from_dir(cls, data_dir, Vocab vocab):
|
||||
|
@ -186,10 +180,22 @@ cdef class Matcher:
|
|||
else:
|
||||
return cls(vocab, {})
|
||||
|
||||
def __init__(self, vocab, patterns):
|
||||
self.vocab = vocab
|
||||
self.mem = Pool()
|
||||
self.vocab = vocab
|
||||
self._patterns = dict(patterns)
|
||||
for entity_key, (etype, attrs, specs) in sorted(patterns.items()):
|
||||
self.add(entity_key, etype, attrs, specs)
|
||||
|
||||
def __reduce__(self):
|
||||
return (self.__class__, (self.vocab, self._patterns), None, None)
|
||||
|
||||
property n_patterns:
|
||||
def __get__(self): return self.patterns.size()
|
||||
|
||||
def add(self, entity_key, etype, attrs, specs):
|
||||
self._patterns[entity_key] = (etype, dict(attrs), list(specs))
|
||||
if isinstance(entity_key, basestring):
|
||||
entity_key = self.vocab.strings[entity_key]
|
||||
if isinstance(etype, basestring):
|
||||
|
|
|
@ -83,7 +83,6 @@ cdef class Parser:
|
|||
model = Model(moves.n_moves, templates, model_dir)
|
||||
return cls(strings, moves, model)
|
||||
|
||||
|
||||
def __call__(self, Doc tokens):
|
||||
cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
|
||||
self.moves.initialize_state(stcls)
|
||||
|
@ -93,6 +92,9 @@ cdef class Parser:
|
|||
self.parse(stcls, eg.c)
|
||||
tokens.set_parse(stcls._sent)
|
||||
|
||||
def __reduce__(self):
|
||||
return (Parser, (self.moves.strings, self.moves, self.model), None, None)
|
||||
|
||||
cdef void predict(self, StateClass stcls, ExampleC* eg) nogil:
|
||||
memset(eg.scores, 0, eg.nr_class * sizeof(weight_t))
|
||||
self.moves.set_valid(eg.is_valid, stcls)
|
||||
|
|
|
@ -37,6 +37,8 @@ cdef class TransitionSystem:
|
|||
cdef public int root_label
|
||||
cdef public freqs
|
||||
|
||||
cdef object _labels_by_action
|
||||
|
||||
cdef int initialize_state(self, StateClass state) except -1
|
||||
cdef int finalize_state(self, StateClass state) nogil
|
||||
|
||||
|
|
|
@ -15,7 +15,8 @@ class OracleError(Exception):
|
|||
|
||||
|
||||
cdef class TransitionSystem:
|
||||
def __init__(self, StringStore string_table, dict labels_by_action):
|
||||
def __init__(self, StringStore string_table, dict labels_by_action, _freqs=None):
|
||||
self._labels_by_action = labels_by_action
|
||||
self.mem = Pool()
|
||||
self.n_moves = sum(len(labels) for labels in labels_by_action.values())
|
||||
self._is_valid = <bint*>self.mem.alloc(self.n_moves, sizeof(bint))
|
||||
|
@ -30,7 +31,7 @@ cdef class TransitionSystem:
|
|||
i += 1
|
||||
self.c = moves
|
||||
self.root_label = self.strings['ROOT']
|
||||
self.freqs = {}
|
||||
self.freqs = {} if _freqs is None else _freqs
|
||||
for attr in (TAG, HEAD, DEP, ENT_TYPE, ENT_IOB):
|
||||
self.freqs[attr] = defaultdict(int)
|
||||
self.freqs[attr][0] = 1
|
||||
|
@ -39,6 +40,11 @@ cdef class TransitionSystem:
|
|||
self.freqs[HEAD][i] = 1
|
||||
self.freqs[HEAD][-i] = 1
|
||||
|
||||
def __reduce__(self):
|
||||
return (self.__class__,
|
||||
(self.strings, self._labels_by_action, self.freqs),
|
||||
None, None)
|
||||
|
||||
cdef int initialize_state(self, StateClass state) except -1:
|
||||
pass
|
||||
|
||||
|
|
|
@ -148,6 +148,9 @@ cdef class Tagger:
|
|||
tokens.is_tagged = True
|
||||
tokens._py_tokens = [None] * tokens.length
|
||||
|
||||
def __reduce__(self):
|
||||
return (self.__class__, (self.vocab, self.model), None, None)
|
||||
|
||||
def tag_from_strings(self, Doc tokens, object tag_strs):
|
||||
cdef int i
|
||||
for i in range(tokens.length):
|
||||
|
|
|
@ -99,16 +99,18 @@ cdef class Vocab:
|
|||
return self.length
|
||||
|
||||
def __reduce__(self):
|
||||
# TODO: Dump vectors
|
||||
tmp_dir = tempfile.mkdtemp()
|
||||
lex_loc = path.join(tmp_dir, 'lexemes.bin')
|
||||
str_loc = path.join(tmp_dir, 'strings.txt')
|
||||
map_loc = path.join(tmp_dir, 'tag_map.json')
|
||||
vec_loc = path.join(self.data_dir, 'vec.bin') if self.data_dir is not None else None
|
||||
|
||||
self.dump(lex_loc)
|
||||
self.strings.dump(str_loc)
|
||||
json.dump(self.morphology.tag_map, open(map_loc, 'w'))
|
||||
|
||||
return (unpickle_vocab, (tmp_dir,), None, None)
|
||||
state = (str_loc, lex_loc, vec_loc, self.morphology, self.get_lex_attr,
|
||||
self.serializer_freqs, self.data_dir)
|
||||
return (unpickle_vocab, state, None, None)
|
||||
|
||||
cdef const LexemeC* get(self, Pool mem, unicode string) except NULL:
|
||||
'''Get a pointer to a LexemeC from the lexicon, creating a new Lexeme
|
||||
|
@ -353,11 +355,21 @@ cdef class Vocab:
|
|||
return vec_len
|
||||
|
||||
|
||||
def unpickle_vocab(data_dir):
|
||||
# TODO: This needs fixing --- the trouble is, we can't pickle staticmethods,
|
||||
# so we need to fiddle with the design of Language a little bit.
|
||||
from .language import Language
|
||||
return Vocab.from_dir(data_dir, Language.default_lex_attrs())
|
||||
def unpickle_vocab(strings_loc, lex_loc, vec_loc, morphology, get_lex_attr,
|
||||
serializer_freqs, data_dir):
|
||||
cdef Vocab vocab = Vocab()
|
||||
|
||||
vocab.get_lex_attr = get_lex_attr
|
||||
vocab.morphology = morphology
|
||||
vocab.strings = morphology.strings
|
||||
vocab.data_dir = data_dir
|
||||
vocab.serializer_freqs = serializer_freqs
|
||||
|
||||
vocab.load_lexemes(strings_loc, lex_loc)
|
||||
if vec_loc is not None:
|
||||
vocab.load_vectors_from_bin_loc(vec_loc)
|
||||
return vocab
|
||||
|
||||
|
||||
copy_reg.constructor(unpickle_vocab)
|
||||
|
||||
|
|
16
tests/parser/test_pickle.py
Normal file
16
tests/parser/test_pickle.py
Normal file
|
@ -0,0 +1,16 @@
|
|||
import pytest
|
||||
|
||||
import pickle
|
||||
import cloudpickle
|
||||
import StringIO
|
||||
|
||||
|
||||
@pytest.mark.models
|
||||
def test_pickle(EN):
|
||||
file_ = StringIO.StringIO()
|
||||
cloudpickle.dump(EN.parser, file_)
|
||||
|
||||
file_.seek(0)
|
||||
|
||||
loaded = pickle.load(file_)
|
||||
|
15
tests/test_pickle.py
Normal file
15
tests/test_pickle.py
Normal file
|
@ -0,0 +1,15 @@
|
|||
import pytest
|
||||
import StringIO
|
||||
import cloudpickle
|
||||
import pickle
|
||||
|
||||
|
||||
@pytest.mark.models
|
||||
def test_pickle_english(EN):
|
||||
file_ = StringIO.StringIO()
|
||||
cloudpickle.dump(EN, file_)
|
||||
|
||||
file_.seek(0)
|
||||
|
||||
loaded = pickle.load(file_)
|
||||
|
|
@ -1,13 +1,13 @@
|
|||
from __future__ import unicode_literals
|
||||
import pytest
|
||||
import StringIO
|
||||
import cloudpickle
|
||||
import pickle
|
||||
|
||||
from spacy.attrs import LEMMA, ORTH, PROB, IS_ALPHA
|
||||
from spacy.parts_of_speech import NOUN, VERB
|
||||
|
||||
|
||||
|
||||
def test_neq(en_vocab):
|
||||
addr = en_vocab['Hello']
|
||||
assert en_vocab['bye'].orth != addr.orth
|
||||
|
@ -44,7 +44,7 @@ def test_symbols(en_vocab):
|
|||
|
||||
def test_pickle_vocab(en_vocab):
|
||||
file_ = StringIO.StringIO()
|
||||
pickle.dump(en_vocab, file_)
|
||||
cloudpickle.dump(en_vocab, file_)
|
||||
|
||||
file_.seek(0)
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user