Matthew Honnibal 2015-10-13 05:25:49 +02:00
commit f6d74b14de
21 changed files with 198 additions and 18 deletions

View File

@@ -29,5 +29,6 @@ cdef class Model:
     cdef int update(self, atom_t* context, class_t guess, class_t gold, int cost) except -1

     cdef object model_loc
+    cdef object _templates
     cdef Extractor _extractor
     cdef LinearModel _model

View File

@@ -3,6 +3,7 @@ from __future__ import unicode_literals
 from __future__ import division

 from os import path
+import tempfile
 import os
 import shutil
 import json
@@ -52,6 +53,7 @@ cdef class Model:
     def __init__(self, n_classes, templates, model_loc=None):
         if model_loc is not None and path.isdir(model_loc):
             model_loc = path.join(model_loc, 'model')
+        self._templates = templates
         self.n_classes = n_classes
         self._extractor = Extractor(templates)
         self.n_feats = self._extractor.n_templ
@@ -60,6 +62,18 @@ cdef class Model:
         if self.model_loc and path.exists(self.model_loc):
             self._model.load(self.model_loc, freq_thresh=0)

+    def __reduce__(self):
+        _, model_loc = tempfile.mkstemp()
+        # TODO: This is a potentially buggy implementation. We're not really
+        # given a good guarantee that all internal state is saved correctly here,
+        # since there are learning parameters for e.g. the model averaging in
+        # averaged perceptron, the gradient calculations in AdaGrad, etc.
+        # that aren't necessarily saved. So, if we're part way through training
+        # the model and then pickle it, we won't recover the state correctly.
+        self._model.dump(model_loc)
+        return (Model, (self.n_classes, self._templates, model_loc),
+                None, None)
+
     def predict(self, Example eg):
         self.set_scores(eg.c.scores, eg.c.atoms)
         eg.c.guess = arg_max_if_true(eg.c.scores, eg.c.is_valid, self.n_classes)
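
A note on the protocol used throughout this commit: `__reduce__` returns a tuple of a callable and its arguments, and pickle rebuilds the object by calling them. Because the heavy state here lives in C structs, each `__reduce__` stashes that state on disk and hands pickle only constructor arguments. A minimal self-contained sketch of the idiom (the `Thing` class and its dict-of-weights format are invented for illustration; this is not spaCy code):

import os
import pickle
import tempfile

class Thing(object):
    def __init__(self, n_classes, model_loc=None):
        self.n_classes = n_classes
        self.weights = {}
        if model_loc is not None and os.path.exists(model_loc):
            with open(model_loc, 'rb') as file_:
                self.weights = pickle.load(file_)

    def dump(self, loc):
        # Stand-in for LinearModel.dump(): persist the C-level state.
        with open(loc, 'wb') as file_:
            pickle.dump(self.weights, file_)

    def __reduce__(self):
        # mkstemp() returns (fd, path); only the path is wanted here.
        fd, model_loc = tempfile.mkstemp()
        os.close(fd)
        self.dump(model_loc)
        # On load, pickle calls Thing(n_classes, model_loc).
        return (Thing, (self.n_classes, model_loc), None, None)

thing = Thing(5)
assert pickle.loads(pickle.dumps(thing)).n_classes == 5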

View File

@@ -207,6 +207,12 @@ class Language(object):
         self.entity = entity
         self.matcher = matcher

+    def __reduce__(self):
+        return (self.__class__,
+                (None, self.vocab, self.tokenizer, self.tagger, self.parser,
+                 self.entity, self.matcher, None),
+                None, None)
+
     def __call__(self, text, tag=True, parse=True, entity=True):
         """Apply the pipeline to some text. The text can span multiple sentences,
         and can contain arbitrary whitespace. Alignment into the original string

View File

@@ -168,13 +168,7 @@ cdef class Matcher:
     cdef Pool mem
     cdef vector[Pattern*] patterns
     cdef readonly Vocab vocab
+    cdef object _patterns

-    def __init__(self, vocab, patterns):
-        self.vocab = vocab
-        self.mem = Pool()
-        self.vocab = vocab
-        for entity_key, (etype, attrs, specs) in sorted(patterns.items()):
-            self.add(entity_key, etype, attrs, specs)
-
     @classmethod
     def from_dir(cls, data_dir, Vocab vocab):
@@ -186,10 +180,22 @@
         else:
             return cls(vocab, {})

+    def __init__(self, vocab, patterns):
+        self.vocab = vocab
+        self.mem = Pool()
+        self.vocab = vocab
+        self._patterns = dict(patterns)
+        for entity_key, (etype, attrs, specs) in sorted(patterns.items()):
+            self.add(entity_key, etype, attrs, specs)
+
+    def __reduce__(self):
+        return (self.__class__, (self.vocab, self._patterns), None, None)
+
     property n_patterns:
         def __get__(self): return self.patterns.size()

     def add(self, entity_key, etype, attrs, specs):
+        self._patterns[entity_key] = (etype, dict(attrs), list(specs))
         if isinstance(entity_key, basestring):
             entity_key = self.vocab.strings[entity_key]
         if isinstance(etype, basestring):
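
The `_patterns` dict added above exists because the compiled patterns live in a C++ `vector[Pattern*]` that pickle cannot traverse, so the Matcher mirrors everything fed to `add()` in a plain dict and lets `__reduce__` return the original constructor arguments. The same idiom stripped down (the `Index` class is hypothetical, for illustration only):

import pickle

class Index(object):
    def __init__(self, patterns):
        self._patterns = dict(patterns)             # picklable mirror
        # Stand-in for the C-level pattern vector:
        self._compiled = sorted(self._patterns.items())

    def __reduce__(self):
        return (self.__class__, (self._patterns,), None, None)

index = Index({'GoogleNow': ('PRODUCT', {}, [])})
loaded = pickle.loads(pickle.dumps(index))
assert loaded._compiled == index._compiled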

View File

@@ -25,6 +25,7 @@ cdef class Morphology:
     cdef readonly Pool mem
     cdef readonly StringStore strings
     cdef public object lemmatizer
+    cdef readonly object tag_map
    cdef public object n_tags
     cdef public object reverse_index
     cdef public object tag_names

View File

@@ -14,6 +14,7 @@ cdef class Morphology:
     def __init__(self, StringStore string_store, tag_map, lemmatizer):
         self.mem = Pool()
         self.strings = string_store
+        self.tag_map = tag_map
         self.lemmatizer = lemmatizer
         self.n_tags = len(tag_map) + 1
         self.tag_names = tuple(sorted(tag_map.keys()))
@@ -28,6 +29,9 @@ cdef class Morphology:
             self.reverse_index[self.rich_tags[i].name] = i
         self._cache = PreshMapArray(self.n_tags)

+    def __reduce__(self):
+        return (Morphology, (self.strings, self.tag_map, self.lemmatizer), None, None)
+
     cdef int assign_tag(self, TokenC* token, tag) except -1:
         cdef int tag_id
         if isinstance(tag, basestring):

View File

@@ -25,4 +25,4 @@ IDS = {
 }

-NAMES = [key for key, value in sorted(IDS.items(), key=lambda item: item[1])]
+NAMES = {value: key for key, value in IDS.items()}
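
The `NAMES` change above is more than cosmetic: the old list only maps ID to name when the IDs are dense and zero-based (a name's position in the sorted list must equal its ID), while the dict form works for arbitrary ID values. A small demonstration of the difference (toy `IDS` values, assuming sparse IDs are the motivation):

IDS = {'ORTH': 65, 'LEMMA': 73}

# Old form: positional, so NAMES_list[65] raises IndexError here.
NAMES_list = [key for key, value in sorted(IDS.items(), key=lambda item: item[1])]
assert NAMES_list[0] == 'ORTH'

# New form: keyed directly by ID, dense or not.
NAMES = {value: key for key, value in IDS.items()}
assert NAMES[65] == 'ORTH'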

View File

@@ -69,12 +69,15 @@ cdef Utf8Str _allocate(Pool mem, const unsigned char* chars, int length) except
 cdef class StringStore:
     '''Map strings to and from integer IDs.'''

-    def __init__(self):
+    def __init__(self, strings=None):
         self.mem = Pool()
         self._map = PreshMap()
         self._resize_at = 10000
         self.c = <Utf8Str*>self.mem.alloc(self._resize_at, sizeof(Utf8Str))
         self.size = 1
+        if strings is not None:
+            for string in strings:
+                _ = self[string]

     property size:
         def __get__(self):
@@ -113,6 +116,14 @@
         for i in range(self.size):
             yield self[i]

+    def __reduce__(self):
+        strings = [""]
+        for i in range(1, self.size):
+            string = &self.c[i]
+            py_string = _decode(string)
+            strings.append(py_string)
+        return (StringStore, (strings,), None, None, None)
+
     cdef const Utf8Str* intern(self, unsigned char* chars, int length) except NULL:
         # 0 means missing, but we don't bother offsetting the index.
         key = hash64(chars, length * sizeof(char), 0)
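
The round trip works because string IDs are assigned in insertion order: `__reduce__` decodes the interned strings into a list (slot 0 is reserved for the empty/missing string), and the new `strings=` constructor argument re-interns them in the same order, so every ID maps back to the same string. A pure-Python stand-in for the Cython class (hypothetical `PyStringStore`, for illustration only):

import pickle

class PyStringStore(object):
    def __init__(self, strings=None):
        self._list = ['']                 # slot 0 reserved for "missing"
        self._map = {}
        if strings is not None:
            for string in strings:
                self[string]              # re-intern in the original order

    def __getitem__(self, key):
        if isinstance(key, int):
            return self._list[key]
        if key not in self._map:
            self._map[key] = len(self._list)
            self._list.append(key)
        return self._map[key]

    def __reduce__(self):
        return (PyStringStore, (self._list[1:],), None, None)

store = PyStringStore()
hi_id = store['Hi']
loaded = pickle.loads(pickle.dumps(store))
assert loaded[hi_id] == 'Hi'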

View File

@@ -83,7 +83,6 @@ cdef class Parser:
         model = Model(moves.n_moves, templates, model_dir)
         return cls(strings, moves, model)

-
     def __call__(self, Doc tokens):
         cdef StateClass stcls = StateClass.init(tokens.data, tokens.length)
         self.moves.initialize_state(stcls)
@@ -93,6 +92,9 @@ cdef class Parser:
         self.parse(stcls, eg.c)
         tokens.set_parse(stcls._sent)

+    def __reduce__(self):
+        return (Parser, (self.moves.strings, self.moves, self.model), None, None)
+
     cdef void predict(self, StateClass stcls, ExampleC* eg) nogil:
         memset(eg.scores, 0, eg.nr_class * sizeof(weight_t))
         self.moves.set_valid(eg.is_valid, stcls)

View File

@@ -37,6 +37,8 @@ cdef class TransitionSystem:
     cdef public int root_label
     cdef public freqs

+    cdef object _labels_by_action
+
     cdef int initialize_state(self, StateClass state) except -1
     cdef int finalize_state(self, StateClass state) nogil

View File

@@ -15,7 +15,8 @@ class OracleError(Exception):


 cdef class TransitionSystem:
-    def __init__(self, StringStore string_table, dict labels_by_action):
+    def __init__(self, StringStore string_table, dict labels_by_action, _freqs=None):
+        self._labels_by_action = labels_by_action
         self.mem = Pool()
         self.n_moves = sum(len(labels) for labels in labels_by_action.values())
         self._is_valid = <bint*>self.mem.alloc(self.n_moves, sizeof(bint))
@@ -30,7 +31,7 @@ cdef class TransitionSystem:
             i += 1
         self.c = moves
         self.root_label = self.strings['ROOT']
-        self.freqs = {}
+        self.freqs = {} if _freqs is None else _freqs
         for attr in (TAG, HEAD, DEP, ENT_TYPE, ENT_IOB):
             self.freqs[attr] = defaultdict(int)
             self.freqs[attr][0] = 1
@@ -39,6 +40,11 @@ cdef class TransitionSystem:
             self.freqs[HEAD][i] = 1
             self.freqs[HEAD][-i] = 1

+    def __reduce__(self):
+        return (self.__class__,
+                (self.strings, self._labels_by_action, self.freqs),
+                None, None)
+
     cdef int initialize_state(self, StateClass state) except -1:
         pass
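
Note how the new `_freqs` argument threads live mutable state back through the constructor: `__init__` would otherwise reset `self.freqs`, so `__reduce__` passes the current dict as a third argument and training statistics survive the round trip. The same idiom in miniature (hypothetical `Counter` class, not spaCy code):

import pickle

class Counter(object):
    def __init__(self, labels, _freqs=None):
        self.labels = labels
        self.freqs = {} if _freqs is None else _freqs

    def __reduce__(self):
        return (self.__class__, (self.labels, self.freqs), None, None)

counter = Counter(['LEFT', 'RIGHT'])
counter.freqs['LEFT'] = 10
assert pickle.loads(pickle.dumps(counter)).freqs == {'LEFT': 10}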

View File

@@ -148,6 +148,9 @@ cdef class Tagger:
         tokens.is_tagged = True
         tokens._py_tokens = [None] * tokens.length

+    def __reduce__(self):
+        return (self.__class__, (self.vocab, self.model), None, None)
+
     def tag_from_strings(self, Doc tokens, object tag_strs):
         cdef int i
         for i in range(tokens.length):

View File

@@ -25,7 +25,6 @@ cdef struct _Cached:


 cdef class Vocab:
-    cpdef public lexeme_props_getter
     cdef Pool mem
     cpdef readonly StringStore strings
     cpdef readonly Morphology morphology
@@ -33,7 +32,6 @@ cdef class Vocab:
     cdef public object _serializer
     cdef public object data_dir
     cdef public object get_lex_attr
-    cdef public object pos_tags
     cdef public object serializer_freqs

     cdef const LexemeC* get(self, Pool mem, unicode string) except NULL

View File

@@ -10,6 +10,8 @@ from os import path
 import io
 import math
 import json
+import tempfile
+import copy_reg

 from .lexeme cimport EMPTY_LEXEME
 from .lexeme cimport Lexeme
@@ -96,6 +98,20 @@ cdef class Vocab:
         """The current number of lexemes stored."""
         return self.length

+    def __reduce__(self):
+        # TODO: Dump vectors
+        tmp_dir = tempfile.mkdtemp()
+        lex_loc = path.join(tmp_dir, 'lexemes.bin')
+        str_loc = path.join(tmp_dir, 'strings.txt')
+        vec_loc = path.join(self.data_dir, 'vec.bin') if self.data_dir is not None else None
+
+        self.dump(lex_loc)
+        self.strings.dump(str_loc)
+
+        state = (str_loc, lex_loc, vec_loc, self.morphology, self.get_lex_attr,
+                 self.serializer_freqs, self.data_dir)
+        return (unpickle_vocab, state, None, None)
+
     cdef const LexemeC* get(self, Pool mem, unicode string) except NULL:
         '''Get a pointer to a LexemeC from the lexicon, creating a new Lexeme
         if necessary, using memory acquired from the given pool. If the pool
@@ -271,17 +287,17 @@ cdef class Vocab:
             i += 1
         fp.close()

-    def load_vectors(self, loc_or_file):
+    def load_vectors(self, file_):
         cdef LexemeC* lexeme
         cdef attr_t orth
         cdef int32_t vec_len = -1
-        for line_num, line in enumerate(loc_or_file):
+        for line_num, line in enumerate(file_):
             pieces = line.split()
             word_str = pieces.pop(0)
             if vec_len == -1:
                 vec_len = len(pieces)
             elif vec_len != len(pieces):
-                raise VectorReadError.mismatched_sizes(loc_or_file, line_num,
+                raise VectorReadError.mismatched_sizes(file_, line_num,
                         vec_len, len(pieces))
             orth = self.strings[word_str]
             lexeme = <LexemeC*><void*>self.get_by_orth(self.mem, orth)
@@ -339,6 +355,25 @@
         return vec_len


+def unpickle_vocab(strings_loc, lex_loc, vec_loc, morphology, get_lex_attr,
+                   serializer_freqs, data_dir):
+    cdef Vocab vocab = Vocab()
+    vocab.get_lex_attr = get_lex_attr
+    vocab.morphology = morphology
+    vocab.strings = morphology.strings
+    vocab.data_dir = data_dir
+    vocab.serializer_freqs = serializer_freqs
+    vocab.load_lexemes(strings_loc, lex_loc)
+    if vec_loc is not None:
+        vocab.load_vectors_from_bin_loc(vec_loc)
+    return vocab
+
+copy_reg.constructor(unpickle_vocab)
+
+
 def write_binary_vectors(in_loc, out_loc):
     cdef CFile out_file = CFile(out_loc, 'wb')
     cdef Address mem
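
Unlike the other classes, Vocab reconstructs through a module-level `unpickle_vocab` function rather than the class itself, because rebuilding needs several steps that `__init__` does not perform; `copy_reg.constructor` just registers the function as a valid pickle constructor (Python 2's `copy_reg`; `copyreg` in Python 3, where the registration is effectively a formality). A minimal sketch of the reconstructor pattern (hypothetical `Box`/`unpickle_box`, not spaCy code):

import copy_reg    # `import copyreg` on Python 3
import pickle

class Box(object):
    def __init__(self):
        self.items = []

    def __reduce__(self):
        return (unpickle_box, (list(self.items),), None, None)

def unpickle_box(items):
    # Multi-step reconstruction that __init__ alone can't express.
    box = Box()
    for item in items:
        box.items.append(item)
    return box

copy_reg.constructor(unpickle_box)

box = Box()
box.items.append('x')
assert pickle.loads(pickle.dumps(box)).items == ['x']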

View File

@@ -0,0 +1,17 @@
+import pytest
+import pickle
+import StringIO
+
+from spacy.morphology import Morphology
+from spacy.lemmatizer import Lemmatizer
+from spacy.strings import StringStore
+
+
+def test_pickle():
+    morphology = Morphology(StringStore(), {}, Lemmatizer({}, {}, {}))
+
+    file_ = StringIO.StringIO()
+    pickle.dump(morphology, file_)

View File

@@ -7,7 +7,8 @@ import pytest

 @pytest.fixture
 def sun_text():
-    with io.open(path.join(path.dirname(__file__), 'sun.txt'), 'r', encoding='utf8') as file_:
+    with io.open(path.join(path.dirname(__file__), '..', 'sun.txt'), 'r',
+                 encoding='utf8') as file_:
         text = file_.read()
     return text

View File

@@ -0,0 +1,16 @@
+import pytest
+import pickle
+import cloudpickle
+import StringIO
+
+
+@pytest.mark.models
+def test_pickle(EN):
+    file_ = StringIO.StringIO()
+    cloudpickle.dump(EN.parser, file_)
+
+    file_.seek(0)
+
+    loaded = pickle.load(file_)
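
The asymmetry in this test is deliberate: `cloudpickle` is needed on the dump side because the stock pickler rejects some of the objects involved, but it emits a standard pickle stream, so plain `pickle.load` can read it back as long as cloudpickle is importable at load time. The same asymmetry in isolation (Python 2 `StringIO`, matching the test's setup):

import pickle
import StringIO
import cloudpickle

file_ = StringIO.StringIO()
cloudpickle.dump(lambda x: x + 1, file_)   # stock pickle can't dump a lambda
file_.seek(0)
fn = pickle.load(file_)
assert fn(1) == 2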

View File

@@ -1,5 +1,7 @@
 # -*- coding: utf-8 -*-
 from __future__ import unicode_literals
+import StringIO
+import pickle

 from spacy.lemmatizer import Lemmatizer, read_index, read_exc
 from spacy.en import LOCAL_DATA_DIR
@@ -41,3 +43,12 @@ def test_smart_quotes(lemmatizer):
     do = lemmatizer.punct
     assert do('“') == set(['"'])
     assert do('”') == set(['"'])
+
+
+def test_pickle_lemmatizer(lemmatizer):
+    file_ = StringIO.StringIO()
+    pickle.dump(lemmatizer, file_)
+
+    file_.seek(0)
+
+    loaded = pickle.load(file_)

tests/test_pickle.py (new file)
View File

@@ -0,0 +1,15 @@
+import pytest
+import StringIO
+import cloudpickle
+import pickle
+
+
+@pytest.mark.models
+def test_pickle_english(EN):
+    file_ = StringIO.StringIO()
+    cloudpickle.dump(EN, file_)
+
+    file_.seek(0)
+
+    loaded = pickle.load(file_)

View File

@@ -1,5 +1,7 @@
 # -*- coding: utf8 -*-
 from __future__ import unicode_literals
+import pickle
+import StringIO

 from spacy.strings import StringStore
@@ -76,3 +78,18 @@ def test_massive_strings(sstore):
     s513 = '1' * 513
     orth = sstore[s513]
     assert sstore[orth] == s513
+
+
+def test_pickle_string_store(sstore):
+    hello_id = sstore[u'Hi']
+
+    string_file = StringIO.StringIO()
+    pickle.dump(sstore, string_file)
+
+    string_file.seek(0)
+
+    loaded = pickle.load(string_file)
+
+    assert loaded[hello_id] == u'Hi'

View File

@@ -1,5 +1,11 @@
 from __future__ import unicode_literals
 import pytest
+import StringIO
+import cloudpickle
+import pickle
+
+from spacy.attrs import LEMMA, ORTH, PROB, IS_ALPHA
+from spacy.parts_of_speech import NOUN, VERB

 from spacy.attrs import LEMMA, ORTH, PROB, IS_ALPHA
 from spacy.parts_of_speech import NOUN, VERB
@@ -38,3 +44,11 @@ def test_symbols(en_vocab):
     assert en_vocab.strings['ORTH'] == ORTH
     assert en_vocab.strings['PROB'] == PROB
+
+
+def test_pickle_vocab(en_vocab):
+    file_ = StringIO.StringIO()
+    cloudpickle.dump(en_vocab, file_)
+
+    file_.seek(0)
+    loaded = pickle.load(file_)
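
These pickle tests are smoke tests: they stop at `pickle.load` and only check that the round trip does not raise. A natural extension, not part of this commit, would assert on the loaded object; for example (hypothetical test, relying on the StringStore ID-preservation shown earlier):

@pytest.mark.models
def test_pickle_vocab_roundtrip(en_vocab):
    file_ = StringIO.StringIO()
    cloudpickle.dump(en_vocab, file_)
    file_.seek(0)
    loaded = pickle.load(file_)
    # String-to-ID mappings should survive the round trip.
    assert loaded.strings['ORTH'] == en_vocab.strings['ORTH']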