Fix serialization of tokenizer

Matthew Honnibal 2017-05-31 11:43:40 +02:00
parent e98eff275d
commit 66af019d5d


@@ -355,14 +355,13 @@ cdef class Tokenizer:
         **exclude: Named attributes to prevent from being serialized.
         RETURNS (bytes): The serialized form of the `Tokenizer` object.
         """
-        # TODO: Improve this so it doesn't need pickle
         serializers = {
             'vocab': lambda: self.vocab.to_bytes(),
-            'prefix': lambda: dill.dumps(self.prefix_search),
-            'suffix_search': lambda: dill.dumps(self.suffix_search),
-            'infix_finditer': lambda: dill.dumps(self.infix_finditer),
-            'token_match': lambda: dill.dumps(self.token_match),
-            'exceptions': lambda: dill.dumps(self._rules)
+            'prefix': lambda: self.prefix_search.__self__.pattern,
+            'suffix_search': lambda: self.suffix_search.__self__.pattern,
+            'infix_finditer': lambda: self.infix_finditer.__self__.pattern,
+            'token_match': lambda: self.token_match.__self__.pattern,
+            'exceptions': lambda: self._rules
         }
         return util.to_bytes(serializers, exclude)
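
The new serializers assume that prefix_search, suffix_search, infix_finditer and token_match are bound methods of compiled regexes (e.g. compiled_regex.search), so the pattern source string can be read back through __self__ and stored instead of a pickled function. A minimal standalone sketch of that idea, with an illustrative pattern rather than spaCy's real prefix rules:

    import re

    # Sketch only: the Tokenizer holds a bound method such as compiled_regex.search,
    # so __self__ gives the compiled pattern object and .pattern its source string.
    prefix_re = re.compile(r'^[\[\("\']')
    prefix_search = prefix_re.search          # what the Tokenizer stores as prefix_search

    assert prefix_search.__self__ is prefix_re
    pattern_string = prefix_search.__self__.pattern   # plain str, safe to serialize
    print(pattern_string)                     # ^[\[\("\']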
@@ -373,26 +372,23 @@ cdef class Tokenizer:
         **exclude: Named attributes to prevent from being loaded.
         RETURNS (Tokenizer): The `Tokenizer` object.
         """
-        # TODO: Improve this so it doesn't need pickle
         data = {}
         deserializers = {
             'vocab': lambda b: self.vocab.from_bytes(b),
-            'prefix': lambda b: data.setdefault('prefix', dill.loads(b)),
-            'suffix_search': lambda b: data.setdefault('suffix_search', dill.loads(b)),
-            'infix_finditer': lambda b: data.setdefault('infix_finditer', dill.loads(b)),
-            'token_match': lambda b: data.setdefault('token_match', dill.loads(b)),
-            'exceptions': lambda b: data.setdefault('rules', dill.loads(b))
+            'prefix': lambda b: data.setdefault('prefix', b),
+            'suffix_search': lambda b: data.setdefault('suffix_search', b),
+            'infix_finditer': lambda b: data.setdefault('infix_finditer', b),
+            'token_match': lambda b: data.setdefault('token_match', b),
+            'exceptions': lambda b: data.setdefault('rules', b)
         }
         msg = util.from_bytes(bytes_data, deserializers, exclude)
         if 'prefix' in data:
-            self.prefix_search = data['prefix']
+            self.prefix_search = re.compile(data['prefix'])
         if 'suffix' in data:
-            self.suffix_search = data['suffix']
+            self.suffix_search = re.compile(data['suffix'])
         if 'infix' in data:
-            self.infix_finditer = data['infix']
+            self.infix_finditer = re.compile(data['infix'])
         if 'token_match' in data:
-            self.token_match = data['token_match']
+            self.token_match = re.compile(data['token_match'])
         for string, substrings in data.get('rules', {}).items():
             self.add_special_case(string, substrings)
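
On the load side, the bytes now carry plain pattern strings that get rebuilt with re.compile rather than unpickled with dill. A small hypothetical round trip outside the Tokenizer, with made-up patterns for illustration:

    import re

    # Hypothetical serialized payload: pattern strings instead of pickled functions.
    # These patterns are illustrative, not spaCy's real prefix/suffix rules.
    stored = {
        'prefix': r'^[\[\("\']',
        'suffix': r'[\]\)"\']$',
    }

    # The load path simply recompiles the stored strings, as in from_bytes above.
    prefix_re = re.compile(stored['prefix'])
    suffix_re = re.compile(stored['suffix'])

    print(bool(prefix_re.search('"Hello')))   # True: leading quote matches the prefix rule
    print(bool(suffix_re.search('Hello"')))   # True: trailing quote matches the suffix rule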