Allow reasonably efficient pickling of Language class, using to_bytes() and from_bytes().

This commit is contained in:
Matthew Honnibal 2017-10-17 18:18:10 +02:00
parent 0d57b9748a
commit 1cc85a89ef

View File

@ -9,6 +9,7 @@ import ujson
from collections import OrderedDict
import itertools
import weakref
import functools
from .tokenizer import Tokenizer
from .vocab import Vocab
@ -19,14 +20,14 @@ from .syntax.parser import get_templates
from .pipeline import NeuralDependencyParser, TokenVectorEncoder, NeuralTagger
from .pipeline import NeuralEntityRecognizer, SimilarityHook, TextCategorizer
from .compat import json_dumps, izip
from .compat import json_dumps, izip, copy_reg
from .scorer import Scorer
from ._ml import link_vectors_to_models
from .attrs import IS_STOP
from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES, TOKENIZER_INFIXES
from .lang.tokenizer_exceptions import TOKEN_MATCH
from .lang.tag_map import TAG_MAP
from .lang.lex_attrs import LEX_ATTRS
from .lang.lex_attrs import LEX_ATTRS, is_stop
from . import util
from . import about
@ -42,7 +43,8 @@ class BaseDefaults(object):
lemmatizer = cls.create_lemmatizer(nlp)
lex_attr_getters = dict(cls.lex_attr_getters)
# This is messy, but it's the minimal working fix to Issue #639.
lex_attr_getters[IS_STOP] = lambda string: string.lower() in cls.stop_words
lex_attr_getters[IS_STOP] = functools.partial(is_stop,
stops=cls.stop_words)
vocab = Vocab(lex_attr_getters=lex_attr_getters, tag_map=cls.tag_map,
lemmatizer=lemmatizer)
for tag_str, exc in cls.morph_rules.items():
@ -135,6 +137,10 @@ class Language(object):
self.pipeline = []
self._optimizer = None
def __reduce__(self):
    """Tell pickle how to rebuild this object.

    Returns a ``(callable, args)`` pair: the module-level
    ``unpickle_language`` function, and a tuple holding the vocab, the
    meta and the byte string produced by ``to_bytes(vocab=False)``.
    """
    # NOTE(review): vocab=False presumably keeps the vocab out of the byte
    # string, since the vocab object is handed over separately below --
    # confirm against util.to_bytes.
    serialized = self.to_bytes(vocab=False)
    rebuild_args = (self.vocab, self.meta, serialized)
    return (unpickle_language, rebuild_args)
@property
def meta(self):
self._meta.setdefault('lang', self.vocab.lang)
@ -608,7 +614,7 @@ class Language(object):
util.from_disk(path, deserializers, exclude)
return self
def to_bytes(self, disable=[]):
def to_bytes(self, disable=[], **exclude):
"""Serialize the current state to a binary string.
disable (list): Names of pipeline components to disable and prevent
@ -626,7 +632,7 @@ class Language(object):
if not hasattr(proc, 'to_bytes'):
continue
serializers[i] = lambda proc=proc: proc.to_bytes(vocab=False)
return util.to_bytes(serializers, {})
return util.to_bytes(serializers, exclude)
def from_bytes(self, bytes_data, disable=[]):
"""Load state from a binary string.
@ -650,6 +656,12 @@ class Language(object):
return self
def unpickle_language(vocab, meta, bytes_data):
    """Rebuild a Language object from the arguments Language.__reduce__ emits.

    vocab: Passed to the Language constructor as ``vocab``.
    meta: Accepted as part of the pickle argument tuple.
    bytes_data: Byte string loaded into the new object via ``from_bytes``.
    RETURNS (Language): The newly constructed object.
    """
    # NOTE(review): `meta` is not used in this function. Presumably the meta
    # travels inside `bytes_data` as well -- confirm, or pass it through to
    # the constructor if it accepts one.
    # NOTE(review): the base Language class is always instantiated here, so
    # a pickled subclass instance would appear to come back as a plain
    # Language -- verify whether that is intended.
    restored = Language(vocab=vocab)
    restored.from_bytes(bytes_data)
    return restored
def _pipe(func, docs):
for doc in docs:
func(doc)