Mirror of https://github.com/explosion/spaCy.git, synced 2025-02-03 05:04:09 +03:00
Allow reasonably efficient pickling of Language class, using to_bytes() and from_bytes().
parent 0d57b9748a
commit 1cc85a89ef
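With this change a Language object can round-trip through the standard pickle module: __reduce__() packs the pipeline into a to_bytes() payload and hands the vocab over by reference, and unpickle_language() rebuilds the object with from_bytes(). A minimal usage sketch (assuming a blank pipeline via spacy.lang.en.English; note that unpickling yields a base Language carrying the shared vocab, not the original subclass):

    import pickle

    from spacy.lang.en import English  # any Language subclass

    nlp = English()
    data = pickle.dumps(nlp)    # invokes Language.__reduce__
    nlp2 = pickle.loads(data)   # pickle calls unpickle_language(...)
    doc = nlp2('Check that the restored pipeline still runs.')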
@@ -9,6 +9,7 @@ import ujson
 from collections import OrderedDict
 import itertools
 import weakref
+import functools
 
 from .tokenizer import Tokenizer
 from .vocab import Vocab
@@ -19,14 +20,14 @@ from .syntax.parser import get_templates
 from .pipeline import NeuralDependencyParser, TokenVectorEncoder, NeuralTagger
 from .pipeline import NeuralEntityRecognizer, SimilarityHook, TextCategorizer
 
-from .compat import json_dumps, izip
+from .compat import json_dumps, izip, copy_reg
 from .scorer import Scorer
 from ._ml import link_vectors_to_models
 from .attrs import IS_STOP
 from .lang.punctuation import TOKENIZER_PREFIXES, TOKENIZER_SUFFIXES, TOKENIZER_INFIXES
 from .lang.tokenizer_exceptions import TOKEN_MATCH
 from .lang.tag_map import TAG_MAP
-from .lang.lex_attrs import LEX_ATTRS
+from .lang.lex_attrs import LEX_ATTRS, is_stop
 from . import util
 from . import about
 
@@ -42,7 +43,8 @@ class BaseDefaults(object):
         lemmatizer = cls.create_lemmatizer(nlp)
         lex_attr_getters = dict(cls.lex_attr_getters)
         # This is messy, but it's the minimal working fix to Issue #639.
-        lex_attr_getters[IS_STOP] = lambda string: string.lower() in cls.stop_words
+        lex_attr_getters[IS_STOP] = functools.partial(is_stop,
+                                                      stops=cls.stop_words)
         vocab = Vocab(lex_attr_getters=lex_attr_getters, tag_map=cls.tag_map,
                       lemmatizer=lemmatizer)
         for tag_str, exc in cls.morph_rules.items():
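The lambda above had to go because pickle serializes functions by importable name: a module-level function such as is_stop pickles fine, and so does a functools.partial wrapping it, but a lambda has no name to look up at load time and raises PicklingError. A self-contained sketch of the difference (this is_stop is a simplified stand-in for the one imported from .lang.lex_attrs):

    import functools
    import pickle


    def is_stop(string, stops=set()):
        # simplified stand-in for spacy.lang.lex_attrs.is_stop
        return string.lower() in stops


    getter = functools.partial(is_stop, stops={'the', 'a'})
    restored = pickle.loads(pickle.dumps(getter))
    assert restored('The') is True

    # The old form fails:
    #     pickle.dumps(lambda s: s.lower() in {'the', 'a'})
    # raises pickle.PicklingError, because a lambda cannot be located
    # by name when loading.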
@@ -135,6 +137,10 @@ class Language(object):
         self.pipeline = []
         self._optimizer = None
 
+    def __reduce__(self):
+        bytes_data = self.to_bytes(vocab=False)
+        return (unpickle_language, (self.vocab, self.meta, bytes_data))
+
     @property
     def meta(self):
         self._meta.setdefault('lang', self.vocab.lang)
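__reduce__() is the standard pickle hook: it returns a (callable, args) pair, pickle stores both, and it calls callable(*args) at load time. Here the callable is the module-level unpickle_language() added at the bottom of this diff, so the heavy state travels as a single to_bytes() blob while the vocab is shared by reference. The protocol in miniature (Widget and rebuild_widget are hypothetical names, not spaCy API):

    import pickle


    def rebuild_widget(state):
        # pickle calls this with the stored args at load time
        obj = Widget()
        obj.state = state
        return obj


    class Widget(object):
        def __init__(self):
            self.state = b''

        def __reduce__(self):
            # tell pickle: rebuild me by calling rebuild_widget(self.state)
            return (rebuild_widget, (self.state,))


    w = Widget()
    w.state = b'payload'
    assert pickle.loads(pickle.dumps(w)).state == b'payload'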
@@ -608,7 +614,7 @@ class Language(object):
         util.from_disk(path, deserializers, exclude)
         return self
 
-    def to_bytes(self, disable=[]):
+    def to_bytes(self, disable=[], **exclude):
         """Serialize the current state to a binary string.
 
         disable (list): Names of pipeline components to disable and prevent
@@ -626,7 +632,7 @@ class Language(object):
             if not hasattr(proc, 'to_bytes'):
                 continue
             serializers[i] = lambda proc=proc: proc.to_bytes(vocab=False)
-        return util.to_bytes(serializers, {})
+        return util.to_bytes(serializers, exclude)
 
     def from_bytes(self, bytes_data, disable=[]):
         """Load state from a binary string.
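The new **exclude keyword is forwarded straight to util.to_bytes(), letting callers drop named sections from the payload. __reduce__() uses exactly this, passing vocab=False so the vocab, which reaches unpickle_language() by reference, is not serialized twice. A hedged usage sketch (assuming util.to_bytes() skips sections named in exclude):

    from spacy.lang.en import English

    nlp = English()
    full = nlp.to_bytes()             # includes the 'vocab' section
    slim = nlp.to_bytes(vocab=False)  # the call __reduce__ makes; the
                                      # vocab travels separately instead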
@@ -650,6 +656,12 @@ class Language(object):
         return self
 
 
+def unpickle_language(vocab, meta, bytes_data):
+    lang = Language(vocab=vocab)
+    lang.from_bytes(bytes_data)
+    return lang
+
+
 def _pipe(func, docs):
     for doc in docs:
         func(doc)