Fix serialization for tagger when tag_map has changed

Matthew Honnibal 2017-06-01 12:18:36 -05:00
parent c6dc2fafc0
commit 307d615c5f
2 changed files with 27 additions and 14 deletions
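
In outline, the fix moves the tag map out of the vocab's serialized data and gives the tagger its own tag_map file, written with json_dumps and read back before the model. The diffs below rely on a filename-to-callback serialization pattern; here is a minimal, self-contained sketch of that pattern (hypothetical to_disk/from_disk helpers standing in for spaCy's util.to_disk and util.from_disk, stdlib json standing in for ujson):

import json
from collections import OrderedDict
from pathlib import Path

def to_disk(path, serializers):
    # Each entry maps a filename to a callback that writes that file.
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    for name, write in serializers.items():
        write(path / name)

def from_disk(path, deserializers):
    # Callbacks run in insertion order, so entries that others depend on
    # (like 'tag_map') can be restored first.
    path = Path(path)
    for name, read in deserializers.items():
        read(path / name)

tag_map = {'NN': {'pos': 'NOUN'}, 'VB': {'pos': 'VERB'}}  # hypothetical custom map

serialize = OrderedDict((
    ('tag_map', lambda p: p.open('w').write(json.dumps(tag_map))),
))
to_disk('/tmp/tagger', serialize)

restored = {}
deserialize = OrderedDict((
    ('tag_map', lambda p: restored.update(json.loads(p.open().read()))),
))
from_disk('/tmp/tagger', deserialize)
assert restored == tag_map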

View File

@@ -10,6 +10,7 @@ cimport numpy as np
 import cytoolz
 import util
 from collections import OrderedDict
+import ujson
 from thinc.api import add, layerize, chain, clone, concatenate, with_flatten
 from thinc.neural import Model, Maxout, Softmax, Affine
@@ -33,6 +34,7 @@ from .gold cimport GoldParse
 from .morphology cimport Morphology
 from .vocab cimport Vocab
 from .syntax import nonproj
+from .compat import json_dumps
 from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP, POS
 from ._ml import rebatch, Tok2Vec, flatten, get_col, doc2feats
@@ -308,7 +310,7 @@ class NeuralTagger(object):
             if self.model is True:
                 token_vector_width = util.env_opt('token_vector_width', 128)
                 self.model = self.Model(self.vocab.morphology.n_tags, token_vector_width)
-                self.model.from_bytes(b)
+            self.model.from_bytes(b)
         deserialize = OrderedDict((
             ('vocab', lambda b: self.vocab.from_bytes(b)),
             ('model', lambda b: load_model(b)),
@@ -317,17 +319,33 @@ class NeuralTagger(object):
         return self

     def to_disk(self, path, **exclude):
-        serialize = {
-            'model': lambda p: p.open('wb').write(self.model.to_bytes()),
-            'vocab': lambda p: self.vocab.to_disk(p)
-        }
+        serialize = OrderedDict((
+            ('vocab', lambda p: self.vocab.to_disk(p)),
+            ('tag_map', lambda p: p.open('w').write(json_dumps(
+                self.vocab.morphology.tag_map))),
+            ('model', lambda p: p.open('wb').write(self.model.to_bytes())),
+        ))
         util.to_disk(path, serialize, exclude)

     def from_disk(self, path, **exclude):
-        deserialize = {
-            'model': lambda p: self.model.from_bytes(p.open('rb').read()),
-            'vocab': lambda p: self.vocab.from_disk(p)
-        }
+        def load_model(p):
+            if self.model is True:
+                token_vector_width = util.env_opt('token_vector_width', 128)
+                self.model = self.Model(self.vocab.morphology.n_tags, token_vector_width)
+            self.model.from_bytes(p.open('rb').read())
+
+        def load_tag_map(p):
+            with p.open() as file_:
+                tag_map = ujson.loads(file_.read())
+            self.vocab.morphology = Morphology(
+                self.vocab.strings, tag_map=tag_map,
+                lemmatizer=self.vocab.morphology.lemmatizer)
+
+        deserialize = OrderedDict((
+            ('vocab', lambda p: self.vocab.from_disk(p)),
+            ('tag_map', load_tag_map),
+            ('model', load_model),
+        ))
         util.from_disk(path, deserialize, exclude)
         return self
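
Note the ordering in from_disk above: 'tag_map' is deliberately restored before 'model', because load_model sizes the network from self.vocab.morphology.n_tags. A toy illustration of why that order matters (all names hypothetical, not spaCy's API):

from collections import OrderedDict

class FakeMorphology:
    # Stand-in for spaCy's Morphology: n_tags is derived from the tag map.
    def __init__(self, tag_map):
        self.tag_map = tag_map
        self.n_tags = len(tag_map)

state = {'morphology': FakeMorphology({'NN': {}})}  # default map: 1 tag

def load_tag_map(path):
    # Restoring the saved tag map changes n_tags (here: 3 instead of 1).
    state['morphology'] = FakeMorphology({'NN': {}, 'VB': {}, 'JJ': {}})

def load_model(path):
    # The model's output layer is built with n_tags units; loading it
    # after the tag map means its size agrees with the saved weights.
    assert state['morphology'].n_tags == 3

deserialize = OrderedDict((
    ('tag_map', load_tag_map),
    ('model', load_model),
))
for name, loader in deserialize.items():
    loader(None)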

View File

@@ -315,7 +315,6 @@ cdef class Vocab:
         getters = OrderedDict((
             ('strings', lambda: self.strings.to_bytes()),
             ('lexemes', lambda: self.lexemes_to_bytes()),
-            ('tag_map', lambda: self.morphology.tag_map),
         ))
         return util.to_bytes(getters, exclude)
@@ -326,13 +325,9 @@ cdef class Vocab:
         **exclude: Named attributes to prevent from being loaded.
         RETURNS (Vocab): The `Vocab` object.
         """
-        def set_tag_map(tag_map):
-            self.morphology = Morphology(self.strings, tag_map,
-                                         self.morphology.lemmatizer)
         setters = OrderedDict((
             ('strings', lambda b: self.strings.from_bytes(b)),
             ('lexemes', lambda b: self.lexemes_from_bytes(b)),
-            ('tag_map', lambda b: set_tag_map(b))
         ))
         return util.from_bytes(bytes_data, setters, exclude)