mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-25 00:34:20 +03:00
Fix serialization of tag_map in NeuralTagger
This commit is contained in:
parent
acc47e2673
commit
5f4d328e2c
|
@ -11,6 +11,7 @@ import cytoolz
|
|||
import util
|
||||
from collections import OrderedDict
|
||||
import ujson
|
||||
import msgpack
|
||||
|
||||
from thinc.api import add, layerize, chain, clone, concatenate, with_flatten
|
||||
from thinc.neural import Model, Maxout, Softmax, Affine
|
||||
|
@ -301,7 +302,8 @@ class NeuralTagger(object):
|
|||
def to_bytes(self, **exclude):
|
||||
serialize = OrderedDict((
|
||||
('model', lambda: self.model.to_bytes()),
|
||||
('vocab', lambda: self.vocab.to_bytes())
|
||||
('vocab', lambda: self.vocab.to_bytes()),
|
||||
('tag_map', lambda: msgpack.dumps(self.vocab.morphology.tag_map))
|
||||
))
|
||||
return util.to_bytes(serialize, exclude)
|
||||
|
||||
|
@ -311,8 +313,15 @@ class NeuralTagger(object):
|
|||
token_vector_width = util.env_opt('token_vector_width', 128)
|
||||
self.model = self.Model(self.vocab.morphology.n_tags, token_vector_width)
|
||||
self.model.from_bytes(b)
|
||||
|
||||
def load_tag_map(b):
|
||||
tag_map = msgpack.loads(b)
|
||||
self.vocab.morphology = Morphology(
|
||||
self.vocab.strings, tag_map=tag_map,
|
||||
lemmatizer=self.vocab.morphology.lemmatizer)
|
||||
deserialize = OrderedDict((
|
||||
('vocab', lambda b: self.vocab.from_bytes(b)),
|
||||
('tag_map', load_tag_map),
|
||||
('model', lambda b: load_model(b)),
|
||||
))
|
||||
util.from_bytes(bytes_data, deserialize, exclude)
|
||||
|
@ -321,7 +330,7 @@ class NeuralTagger(object):
|
|||
def to_disk(self, path, **exclude):
|
||||
serialize = OrderedDict((
|
||||
('vocab', lambda p: self.vocab.to_disk(p)),
|
||||
('tag_map', lambda p: p.open('w').write(json_dumps(
|
||||
('tag_map', lambda p: p.open('w').write(msgpack.dumps(
|
||||
self.vocab.morphology.tag_map))),
|
||||
('model', lambda p: p.open('wb').write(self.model.to_bytes())),
|
||||
))
|
||||
|
@ -336,7 +345,7 @@ class NeuralTagger(object):
|
|||
|
||||
def load_tag_map(p):
|
||||
with p.open() as file_:
|
||||
tag_map = ujson.loads(file_.read())
|
||||
tag_map = msgpack.loads(file_.read())
|
||||
self.vocab.morphology = Morphology(
|
||||
self.vocab.strings, tag_map=tag_map,
|
||||
lemmatizer=self.vocab.morphology.lemmatizer)
|
||||
|
|
Loading…
Reference in New Issue
Block a user