From a318f0cae1dbca401d259b1cb059d82518c4beca Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Mon, 29 May 2017 12:24:41 +0200
Subject: [PATCH] Add to/from disk/bytes methods for tokenizer

---
 spacy/tokenizer.pyx | 49 +++++++++++++++++++++++++++++++++++++++------
 1 file changed, 43 insertions(+), 6 deletions(-)

diff --git a/spacy/tokenizer.pyx b/spacy/tokenizer.pyx
index 9aa897444..c2671d785 100644
--- a/spacy/tokenizer.pyx
+++ b/spacy/tokenizer.pyx
@@ -7,7 +7,9 @@ from cython.operator cimport preincrement as preinc
 from cymem.cymem cimport Pool
 from preshed.maps cimport PreshMap
 
+import dill
 from .strings cimport hash_string
+from . import util
 cimport cython
 
 from .tokens.doc cimport Doc
@@ -325,15 +327,16 @@ cdef class Tokenizer:
             self._cache.set(key, cached)
         self._rules[string] = substrings
 
-    def to_disk(self, path):
+    def to_disk(self, path, **exclude):
         """Save the current state to a directory.
         path (unicode or Path): A path to a directory, which will be created
             if it doesn't exist. Paths may be either strings or `Path`-like
            objects.
         """
-        raise NotImplementedError()
+        with path.open('wb') as file_:
+            file_.write(self.to_bytes(**exclude))
 
-    def from_disk(self, path):
+    def from_disk(self, path, **exclude):
         """Loads state from a directory. Modifies the object in place and
         returns it.
 
@@ -341,7 +344,10 @@ cdef class Tokenizer:
            strings or `Path`-like objects.
         RETURNS (Tokenizer): The modified `Tokenizer` object.
         """
-        raise NotImplementedError()
+        with path.open('rb') as file_:
+            bytes_data = file_.read()
+        self.from_bytes(bytes_data, **exclude)
+        return self
 
     def to_bytes(self, **exclude):
         """Serialize the current state to a binary string.
@@ -349,7 +355,16 @@ cdef class Tokenizer:
         **exclude: Named attributes to prevent from being serialized.
         RETURNS (bytes): The serialized form of the `Tokenizer` object.
         """
-        raise NotImplementedError()
+        # TODO: Improve this so it doesn't need pickle
+        serializers = {
+            'vocab': lambda: self.vocab.to_bytes(),
+            'prefix': lambda: dill.dumps(self.prefix_search),
+            'suffix_search': lambda: dill.dumps(self.suffix_search),
+            'infix_finditer': lambda: dill.dumps(self.infix_finditer),
+            'token_match': lambda: dill.dumps(self.token_match),
+            'exceptions': lambda: dill.dumps(self._rules)
+        }
+        return util.to_bytes(serializers, exclude)
 
     def from_bytes(self, bytes_data, **exclude):
         """Load state from a binary string.
@@ -358,4 +373,26 @@ cdef class Tokenizer:
        **exclude: Named attributes to prevent from being loaded.
        RETURNS (Tokenizer): The `Tokenizer` object.
        """
-        raise NotImplementedError()
+        # TODO: Improve this so it doesn't need pickle
+        data = {}
+        deserializers = {
+            'vocab': lambda b: self.vocab.from_bytes(b),
+            'prefix': lambda b: data.setdefault('prefix', dill.loads(b)),
+            'suffix_search': lambda b: data.setdefault('suffix_search', dill.loads(b)),
+            'infix_finditer': lambda b: data.setdefault('infix_finditer', dill.loads(b)),
+            'token_match': lambda b: data.setdefault('token_match', dill.loads(b)),
+            'exceptions': lambda b: data.setdefault('rules', dill.loads(b))
+        }
+        msg = util.from_bytes(bytes_data, deserializers, exclude)
+        if 'prefix' in data:
+            self.prefix_search = data['prefix']
+        if 'suffix_search' in data:
+            self.suffix_search = data['suffix_search']
+        if 'infix_finditer' in data:
+            self.infix_finditer = data['infix_finditer']
+        if 'token_match' in data:
+            self.token_match = data['token_match']
+        for string, substrings in data.get('rules', {}).items():
+            self.add_special_case(string, substrings)
+
+
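
Not part of the patch: a rough usage sketch of the round trip the new methods are meant to support. It assumes a spaCy v2-style checkout where spacy.lang.en.English and the dill package are available; the temporary path and variable names are illustrative only.

    from pathlib import Path
    import tempfile

    from spacy.lang.en import English   # blank pipeline; no model download needed

    nlp = English()
    tokenizer = nlp.tokenizer

    # Round-trip through a byte string: the special-case rules, the
    # prefix/suffix/infix patterns and token_match should all survive.
    data = tokenizer.to_bytes()
    tokenizer.from_bytes(data)

    # Round-trip through disk. to_disk() writes a single file, so pass a
    # file path as a Path object (the new methods call path.open()).
    path = Path(tempfile.mkdtemp()) / 'tokenizer'
    tokenizer.to_disk(path)
    tokenizer.from_disk(path)

    print([t.text for t in tokenizer(u'Hello world!')])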