Fix save/load of textcat config

This commit is contained in:
Matthew Honnibal 2017-07-23 00:33:43 +02:00
parent c27fdaef6f
commit a88a7deffe

View File

@ -109,7 +109,8 @@ class BaseThincComponent(object):
def to_disk(self, path, **exclude):
serialize = OrderedDict((
('model', lambda p: p.open('wb').write(self.model.to_bytes())),
('vocab', lambda p: self.vocab.to_disk(p))
('vocab', lambda p: self.vocab.to_disk(p)),
('cfg', lambda p: p.open('w').write(json_dumps(self.cfg)))
))
util.to_disk(path, serialize, exclude)
@ -118,7 +119,8 @@ class BaseThincComponent(object):
self.model = self.Model()
deserialize = OrderedDict((
('model', lambda p: self.model.from_bytes(p.open('rb').read())),
('vocab', lambda p: self.vocab.from_disk(p))
('vocab', lambda p: self.vocab.from_disk(p)),
('cfg', lambda p: self.cfg.update(ujson.load(p.open()))),
))
util.from_disk(path, deserialize, exclude)
return self
@ -383,6 +385,7 @@ class NeuralTagger(BaseThincComponent):
use_bin_type=True,
encoding='utf8'))),
('model', lambda p: p.open('wb').write(self.model.to_bytes())),
('cfg', lambda p: p.open('w').write(json_dumps(self.cfg)))
))
util.to_disk(path, serialize, exclude)
@ -405,6 +408,7 @@ class NeuralTagger(BaseThincComponent):
('vocab', lambda p: self.vocab.from_disk(p)),
('tag_map', load_tag_map),
('model', load_model),
('cfg', lambda p: self.cfg.update(ujson.load(p.open()))),
))
util.from_disk(path, deserialize, exclude)
return self
@ -523,7 +527,15 @@ class TextCategorizer(BaseThincComponent):
def __init__(self, vocab, model=True, **cfg):
self.vocab = vocab
self.model = model
self.labels = cfg.get('labels', ['LABEL'])
self.cfg = cfg
@property
def labels(self):
return self.cfg.get('labels', ['LABEL'])
@labels.setter
def labels(self, value):
self.cfg['labels'] = value
def __call__(self, doc):
scores = self.predict([doc])