Add to/from disk/bytes methods for tokenizer

This commit is contained in:
Matthew Honnibal 2017-05-29 12:24:41 +02:00
parent ff26aa6c37
commit a318f0cae1

View File

@ -7,7 +7,9 @@ from cython.operator cimport preincrement as preinc
from cymem.cymem cimport Pool
from preshed.maps cimport PreshMap
import dill
from .strings cimport hash_string
from . import util
cimport cython
from .tokens.doc cimport Doc
@ -325,15 +327,16 @@ cdef class Tokenizer:
self._cache.set(key, cached)
self._rules[string] = substrings
def to_disk(self, path):
def to_disk(self, path, **exclude):
"""Save the current state to a directory.
path (unicode or Path): A path to a directory, which will be created if
it doesn't exist. Paths may be either strings or `Path`-like objects.
"""
raise NotImplementedError()
with path.open('wb') as file_:
file_.write(self.to_bytes(**exclude))
def from_disk(self, path):
def from_disk(self, path, **exclude):
"""Loads state from a directory. Modifies the object in place and
returns it.
@ -341,7 +344,10 @@ cdef class Tokenizer:
strings or `Path`-like objects.
RETURNS (Tokenizer): The modified `Tokenizer` object.
"""
raise NotImplementedError()
with path.open('wb') as file_:
bytes_data = file_.read(path)
self.from_bytes(bytes_data, **exclude)
return self
def to_bytes(self, **exclude):
"""Serialize the current state to a binary string.
@ -349,7 +355,16 @@ cdef class Tokenizer:
**exclude: Named attributes to prevent from being serialized.
RETURNS (bytes): The serialized form of the `Tokenizer` object.
"""
raise NotImplementedError()
# TODO: Improve this so it doesn't need pickle
serializers = {
'vocab': lambda: self.vocab.to_bytes(),
'prefix': lambda: dill.dumps(self.prefix_search),
'suffix_search': lambda: dill.dumps(self.suffix_search),
'infix_finditer': lambda: dill.dumps(self.infix_finditer),
'token_match': lambda: dill.dumps(self.token_match),
'exceptions': lambda: dill.dumps(self._rules)
}
return util.to_bytes(serializers, exclude)
def from_bytes(self, bytes_data, **exclude):
"""Load state from a binary string.
@ -358,4 +373,26 @@ cdef class Tokenizer:
**exclude: Named attributes to prevent from being loaded.
RETURNS (Tokenizer): The `Tokenizer` object.
"""
raise NotImplementedError()
# TODO: Improve this so it doesn't need pickle
data = {}
deserializers = {
'vocab': lambda b: self.vocab.from_bytes(b),
'prefix': lambda b: data.setdefault('prefix', dill.loads(b)),
'suffix_search': lambda b: data.setdefault('suffix_search', dill.loads(b)),
'infix_finditer': lambda b: data.setdefault('infix_finditer', dill.loads(b)),
'token_match': lambda b: data.setdefault('token_match', dill.loads(b)),
'exceptions': lambda b: data.setdefault('rules', dill.loads(b))
}
msg = util.from_bytes(bytes_data, deserializers, exclude)
if 'prefix' in data:
self.prefix_search = data['prefix']
if 'suffix' in data:
self.suffix_search = data['suffix']
if 'infix' in data:
self.infix_finditer = data['infix']
if 'token_match' in data:
self.token_match = data['token_match']
for string, substrings in data.get('rules', {}).items():
self.add_special_case(string, substrings)