Allow reasonably efficient pickling of Language class, using to_bytes() and from_bytes().

Matthew Honnibal 2017-10-17 18:18:10 +02:00
parent 0d57b9748a
commit 1cc85a89ef


@@ -9,6 +9,7 @@ import ujson
 from collections import OrderedDict
 import itertools
 import weakref
+import functools
 from .tokenizer import Tokenizer
 from .vocab import Vocab
@@ -19,14 +20,14 @@ from .syntax.parser import get_templates
 from .pipeline import NeuralDependencyParser, TokenVectorEncoder, NeuralTagger
 from .pipeline import NeuralEntityRecognizer, SimilarityHook, TextCategorizer
-from .compat import json_dumps, izip
+from .compat import json_dumps, izip, copy_reg
 from .scorer import Scorer
 from ._ml import link_vectors_to_models
 from .attrs import IS_STOP
 from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES, TOKENIZER_INFIXES
 from .lang.tokenizer_exceptions import TOKEN_MATCH
 from .lang.tag_map import TAG_MAP
-from .lang.lex_attrs import LEX_ATTRS
+from .lang.lex_attrs import LEX_ATTRS, is_stop
 from . import util
 from . import about
@@ -42,7 +43,8 @@ class BaseDefaults(object):
         lemmatizer = cls.create_lemmatizer(nlp)
         lex_attr_getters = dict(cls.lex_attr_getters)
         # This is messy, but it's the minimal working fix to Issue #639.
-        lex_attr_getters[IS_STOP] = lambda string: string.lower() in cls.stop_words
+        lex_attr_getters[IS_STOP] = functools.partial(is_stop,
+                                                      stops=cls.stop_words)
         vocab = Vocab(lex_attr_getters=lex_attr_getters, tag_map=cls.tag_map,
                       lemmatizer=lemmatizer)
         for tag_str, exc in cls.morph_rules.items():
@@ -135,6 +137,10 @@ class Language(object):
         self.pipeline = []
         self._optimizer = None
+    def __reduce__(self):
+        bytes_data = self.to_bytes(vocab=False)
+        return (unpickle_language, (self.vocab, self.meta, bytes_data))
     @property
     def meta(self):
         self._meta.setdefault('lang', self.vocab.lang)
@@ -608,7 +614,7 @@ class Language(object):
         util.from_disk(path, deserializers, exclude)
         return self
-    def to_bytes(self, disable=[]):
+    def to_bytes(self, disable=[], **exclude):
         """Serialize the current state to a binary string.
         disable (list): Nameds of pipeline components to disable and prevent
@@ -626,7 +632,7 @@ class Language(object):
             if not hasattr(proc, 'to_bytes'):
                 continue
             serializers[i] = lambda proc=proc: proc.to_bytes(vocab=False)
-        return util.to_bytes(serializers, {})
+        return util.to_bytes(serializers, exclude)
     def from_bytes(self, bytes_data, disable=[]):
         """Load state from a binary string.
@@ -650,6 +656,12 @@ class Language(object):
         return self
+def unpickle_language(vocab, meta, bytes_data):
+    lang = Language(vocab=vocab)
+    lang.from_bytes(bytes_data)
+    return lang
 def _pipe(func, docs):
     for doc in docs:
         func(doc)
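
Usage note (not part of the diff): the IS_STOP lambda is swapped for functools.partial(is_stop, stops=cls.stop_words) because the standard pickle module cannot serialize lambdas, while a partial over the module-level is_stop function pickles cleanly. With __reduce__ returning (unpickle_language, (self.vocab, self.meta, bytes_data)), pickling a Language serializes the pipeline once via to_bytes(vocab=False) and ships the vocab object alongside; unpickling rebuilds the object through from_bytes(). Below is a minimal sketch of the round-trip this commit enables; constructing a bare Language() and the final assert are illustrative assumptions, not taken from the commit:

    import pickle

    from spacy.language import Language

    # Build a bare pipeline; any Language subclass would be pickled the same way.
    nlp = Language()

    # __reduce__ makes pickle store (vocab, meta, to_bytes(vocab=False) payload)
    # instead of trying to walk the whole object graph.
    data = pickle.dumps(nlp)

    # pickle.loads() invokes unpickle_language(), which constructs
    # Language(vocab=vocab) and restores the remaining state with from_bytes().
    nlp2 = pickle.loads(data)
    assert isinstance(nlp2, Language)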