Five save/load of textcat config

This commit is contained in:
Matthew Honnibal 2017-07-23 00:33:43 +02:00
parent c27fdaef6f
commit a88a7deffe

View File

@ -109,7 +109,8 @@ class BaseThincComponent(object):
def to_disk(self, path, **exclude): def to_disk(self, path, **exclude):
serialize = OrderedDict(( serialize = OrderedDict((
('model', lambda p: p.open('wb').write(self.model.to_bytes())), ('model', lambda p: p.open('wb').write(self.model.to_bytes())),
('vocab', lambda p: self.vocab.to_disk(p)) ('vocab', lambda p: self.vocab.to_disk(p)),
('cfg', lambda p: p.open('w').write(json_dumps(self.cfg)))
)) ))
util.to_disk(path, serialize, exclude) util.to_disk(path, serialize, exclude)
@ -118,7 +119,8 @@ class BaseThincComponent(object):
self.model = self.Model() self.model = self.Model()
deserialize = OrderedDict(( deserialize = OrderedDict((
('model', lambda p: self.model.from_bytes(p.open('rb').read())), ('model', lambda p: self.model.from_bytes(p.open('rb').read())),
('vocab', lambda p: self.vocab.from_disk(p)) ('vocab', lambda p: self.vocab.from_disk(p)),
('cfg', lambda p: self.cfg.update(ujson.load(p.open()))),
)) ))
util.from_disk(path, deserialize, exclude) util.from_disk(path, deserialize, exclude)
return self return self
@ -383,6 +385,7 @@ class NeuralTagger(BaseThincComponent):
use_bin_type=True, use_bin_type=True,
encoding='utf8'))), encoding='utf8'))),
('model', lambda p: p.open('wb').write(self.model.to_bytes())), ('model', lambda p: p.open('wb').write(self.model.to_bytes())),
('cfg', lambda p: p.open('w').write(json_dumps(self.cfg)))
)) ))
util.to_disk(path, serialize, exclude) util.to_disk(path, serialize, exclude)
@ -405,6 +408,7 @@ class NeuralTagger(BaseThincComponent):
('vocab', lambda p: self.vocab.from_disk(p)), ('vocab', lambda p: self.vocab.from_disk(p)),
('tag_map', load_tag_map), ('tag_map', load_tag_map),
('model', load_model), ('model', load_model),
('cfg', lambda p: self.cfg.update(ujson.load(p.open()))),
)) ))
util.from_disk(path, deserialize, exclude) util.from_disk(path, deserialize, exclude)
return self return self
@ -523,7 +527,15 @@ class TextCategorizer(BaseThincComponent):
def __init__(self, vocab, model=True, **cfg): def __init__(self, vocab, model=True, **cfg):
self.vocab = vocab self.vocab = vocab
self.model = model self.model = model
self.labels = cfg.get('labels', ['LABEL']) self.cfg = cfg
@property
def labels(self):
return self.cfg.get('labels', ['LABEL'])
@labels.setter
def labels(self, value):
self.cfg['labels'] = value
def __call__(self, doc): def __call__(self, doc):
scores = self.predict([doc]) scores = self.predict([doc])