mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 18:56:36 +03:00
Fix serialization for tagger when tag_map has changed
This commit is contained in:
parent
c6dc2fafc0
commit
307d615c5f
|
@ -10,6 +10,7 @@ cimport numpy as np
|
||||||
import cytoolz
|
import cytoolz
|
||||||
import util
|
import util
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
import ujson
|
||||||
|
|
||||||
from thinc.api import add, layerize, chain, clone, concatenate, with_flatten
|
from thinc.api import add, layerize, chain, clone, concatenate, with_flatten
|
||||||
from thinc.neural import Model, Maxout, Softmax, Affine
|
from thinc.neural import Model, Maxout, Softmax, Affine
|
||||||
|
@ -33,6 +34,7 @@ from .gold cimport GoldParse
|
||||||
from .morphology cimport Morphology
|
from .morphology cimport Morphology
|
||||||
from .vocab cimport Vocab
|
from .vocab cimport Vocab
|
||||||
from .syntax import nonproj
|
from .syntax import nonproj
|
||||||
|
from .compat import json_dumps
|
||||||
|
|
||||||
from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP, POS
|
from .attrs import ID, LOWER, PREFIX, SUFFIX, SHAPE, TAG, DEP, POS
|
||||||
from ._ml import rebatch, Tok2Vec, flatten, get_col, doc2feats
|
from ._ml import rebatch, Tok2Vec, flatten, get_col, doc2feats
|
||||||
|
@ -308,7 +310,7 @@ class NeuralTagger(object):
|
||||||
if self.model is True:
|
if self.model is True:
|
||||||
token_vector_width = util.env_opt('token_vector_width', 128)
|
token_vector_width = util.env_opt('token_vector_width', 128)
|
||||||
self.model = self.Model(self.vocab.morphology.n_tags, token_vector_width)
|
self.model = self.Model(self.vocab.morphology.n_tags, token_vector_width)
|
||||||
self.model.from_bytes(b)
|
self.model.from_bytes(b)
|
||||||
deserialize = OrderedDict((
|
deserialize = OrderedDict((
|
||||||
('vocab', lambda b: self.vocab.from_bytes(b)),
|
('vocab', lambda b: self.vocab.from_bytes(b)),
|
||||||
('model', lambda b: load_model(b)),
|
('model', lambda b: load_model(b)),
|
||||||
|
@ -317,17 +319,33 @@ class NeuralTagger(object):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def to_disk(self, path, **exclude):
|
def to_disk(self, path, **exclude):
|
||||||
serialize = {
|
serialize = OrderedDict((
|
||||||
'model': lambda p: p.open('wb').write(self.model.to_bytes()),
|
('vocab', lambda p: self.vocab.to_disk(p)),
|
||||||
'vocab': lambda p: self.vocab.to_disk(p)
|
('tag_map', lambda p: p.open('w').write(json_dumps(
|
||||||
}
|
self.vocab.morphology.tag_map))),
|
||||||
|
('model', lambda p: p.open('wb').write(self.model.to_bytes())),
|
||||||
|
))
|
||||||
util.to_disk(path, serialize, exclude)
|
util.to_disk(path, serialize, exclude)
|
||||||
|
|
||||||
def from_disk(self, path, **exclude):
|
def from_disk(self, path, **exclude):
|
||||||
deserialize = {
|
def load_model(p):
|
||||||
'model': lambda p: self.model.from_bytes(p.open('rb').read()),
|
if self.model is True:
|
||||||
'vocab': lambda p: self.vocab.from_disk(p)
|
token_vector_width = util.env_opt('token_vector_width', 128)
|
||||||
}
|
self.model = self.Model(self.vocab.morphology.n_tags, token_vector_width)
|
||||||
|
self.model.from_bytes(p.open('rb').read())
|
||||||
|
|
||||||
|
def load_tag_map(p):
|
||||||
|
with p.open() as file_:
|
||||||
|
tag_map = ujson.loads(file_.read())
|
||||||
|
self.vocab.morphology = Morphology(
|
||||||
|
self.vocab.strings, tag_map=tag_map,
|
||||||
|
lemmatizer=self.vocab.morphology.lemmatizer)
|
||||||
|
|
||||||
|
deserialize = OrderedDict((
|
||||||
|
('vocab', lambda p: self.vocab.from_disk(p)),
|
||||||
|
('tag_map', load_tag_map),
|
||||||
|
('model', load_model),
|
||||||
|
))
|
||||||
util.from_disk(path, deserialize, exclude)
|
util.from_disk(path, deserialize, exclude)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
|
@ -315,7 +315,6 @@ cdef class Vocab:
|
||||||
getters = OrderedDict((
|
getters = OrderedDict((
|
||||||
('strings', lambda: self.strings.to_bytes()),
|
('strings', lambda: self.strings.to_bytes()),
|
||||||
('lexemes', lambda: self.lexemes_to_bytes()),
|
('lexemes', lambda: self.lexemes_to_bytes()),
|
||||||
('tag_map', lambda: self.morphology.tag_map),
|
|
||||||
))
|
))
|
||||||
return util.to_bytes(getters, exclude)
|
return util.to_bytes(getters, exclude)
|
||||||
|
|
||||||
|
@ -326,13 +325,9 @@ cdef class Vocab:
|
||||||
**exclude: Named attributes to prevent from being loaded.
|
**exclude: Named attributes to prevent from being loaded.
|
||||||
RETURNS (Vocab): The `Vocab` object.
|
RETURNS (Vocab): The `Vocab` object.
|
||||||
"""
|
"""
|
||||||
def set_tag_map(tag_map):
|
|
||||||
self.morphology = Morphology(self.strings, tag_map,
|
|
||||||
self.morphology.lemmatizer)
|
|
||||||
setters = OrderedDict((
|
setters = OrderedDict((
|
||||||
('strings', lambda b: self.strings.from_bytes(b)),
|
('strings', lambda b: self.strings.from_bytes(b)),
|
||||||
('lexemes', lambda b: self.lexemes_from_bytes(b)),
|
('lexemes', lambda b: self.lexemes_from_bytes(b)),
|
||||||
('tag_map', lambda b: set_tag_map(b))
|
|
||||||
))
|
))
|
||||||
return util.from_bytes(bytes_data, setters, exclude)
|
return util.from_bytes(bytes_data, setters, exclude)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user