diff --git a/spacy/language.py b/spacy/language.py
index 5a8c3e90c..283b19899 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -9,6 +9,7 @@ import ujson
 from collections import OrderedDict
 import itertools
 import weakref
+import functools
 
 from .tokenizer import Tokenizer
 from .vocab import Vocab
@@ -19,14 +20,14 @@ from .syntax.parser import get_templates
 from .pipeline import NeuralDependencyParser, TokenVectorEncoder, NeuralTagger
 from .pipeline import NeuralEntityRecognizer, SimilarityHook, TextCategorizer
-from .compat import json_dumps, izip
+from .compat import json_dumps, izip, copy_reg
 from .scorer import Scorer
 from ._ml import link_vectors_to_models
 from .attrs import IS_STOP
 from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES, TOKENIZER_INFIXES
 from .lang.tokenizer_exceptions import TOKEN_MATCH
 from .lang.tag_map import TAG_MAP
-from .lang.lex_attrs import LEX_ATTRS
+from .lang.lex_attrs import LEX_ATTRS, is_stop
 from . import util
 from . import about
@@ -42,7 +43,8 @@ class BaseDefaults(object):
         lemmatizer = cls.create_lemmatizer(nlp)
         lex_attr_getters = dict(cls.lex_attr_getters)
         # This is messy, but it's the minimal working fix to Issue #639.
-        lex_attr_getters[IS_STOP] = lambda string: string.lower() in cls.stop_words
+        lex_attr_getters[IS_STOP] = functools.partial(is_stop,
+                                                      stops=cls.stop_words)
         vocab = Vocab(lex_attr_getters=lex_attr_getters, tag_map=cls.tag_map,
                       lemmatizer=lemmatizer)
         for tag_str, exc in cls.morph_rules.items():
@@ -135,6 +137,10 @@ class Language(object):
         self.pipeline = []
         self._optimizer = None
 
+    def __reduce__(self):
+        bytes_data = self.to_bytes(vocab=False)
+        return (unpickle_language, (self.vocab, self.meta, bytes_data))
+
     @property
     def meta(self):
         self._meta.setdefault('lang', self.vocab.lang)
@@ -608,7 +614,7 @@ class Language(object):
             util.from_disk(path, deserializers, exclude)
         return self
 
-    def to_bytes(self, disable=[]):
+    def to_bytes(self, disable=[], **exclude):
         """Serialize the current state to a binary string.
 
         disable (list): Nameds of pipeline components to disable and prevent
@@ -626,7 +632,7 @@ class Language(object):
             if not hasattr(proc, 'to_bytes'):
                 continue
             serializers[i] = lambda proc=proc: proc.to_bytes(vocab=False)
-        return util.to_bytes(serializers, {})
+        return util.to_bytes(serializers, exclude)
 
     def from_bytes(self, bytes_data, disable=[]):
         """Load state from a binary string.
@@ -650,6 +656,12 @@ class Language(object):
         return self
 
 
+def unpickle_language(vocab, meta, bytes_data):
+    lang = Language(vocab=vocab)
+    lang.from_bytes(bytes_data)
+    return lang
+
+
 def _pipe(func, docs):
     for doc in docs:
         func(doc)
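
For reference, a minimal sketch of what this patch enables, not code from the diff itself. pickle serializes module-level functions by qualified name, which is why the IS_STOP lambda (unpicklable) is swapped for functools.partial(is_stop, stops=...), and why unpickle_language must live at module level for __reduce__ to reference it. The round-trip below is an assumption about how the new hook is exercised on a spaCy 2.x-era install; constructing a bare Language(Vocab()) follows the same pattern unpickle_language uses.

    import pickle

    from spacy.vocab import Vocab
    from spacy.language import Language

    # A bare, pipeline-less Language; a loaded model would work the same way.
    nlp = Language(Vocab())

    # pickle.dumps() invokes Language.__reduce__, which returns
    # (unpickle_language, (self.vocab, self.meta, self.to_bytes(vocab=False))).
    data = pickle.dumps(nlp)

    # pickle.loads() calls unpickle_language(), which builds a fresh Language
    # around the (separately pickled) vocab and restores the remaining state
    # via from_bytes().
    nlp2 = pickle.loads(data)
    assert isinstance(nlp2, Language)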