mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-11 04:08:09 +03:00
Fix save/load of textcat config
This commit is contained in:
parent
c27fdaef6f
commit
a88a7deffe
|
@@ -109,7 +109,8 @@ class BaseThincComponent(object):
|
||||||
def to_disk(self, path, **exclude):
|
def to_disk(self, path, **exclude):
|
||||||
serialize = OrderedDict((
|
serialize = OrderedDict((
|
||||||
('model', lambda p: p.open('wb').write(self.model.to_bytes())),
|
('model', lambda p: p.open('wb').write(self.model.to_bytes())),
|
||||||
('vocab', lambda p: self.vocab.to_disk(p))
|
('vocab', lambda p: self.vocab.to_disk(p)),
|
||||||
|
('cfg', lambda p: p.open('w').write(json_dumps(self.cfg)))
|
||||||
))
|
))
|
||||||
util.to_disk(path, serialize, exclude)
|
util.to_disk(path, serialize, exclude)
|
||||||
|
|
||||||
|
@@ -118,7 +119,8 @@ class BaseThincComponent(object):
|
||||||
self.model = self.Model()
|
self.model = self.Model()
|
||||||
deserialize = OrderedDict((
|
deserialize = OrderedDict((
|
||||||
('model', lambda p: self.model.from_bytes(p.open('rb').read())),
|
('model', lambda p: self.model.from_bytes(p.open('rb').read())),
|
||||||
('vocab', lambda p: self.vocab.from_disk(p))
|
('vocab', lambda p: self.vocab.from_disk(p)),
|
||||||
|
('cfg', lambda p: self.cfg.update(ujson.load(p.open()))),
|
||||||
))
|
))
|
||||||
util.from_disk(path, deserialize, exclude)
|
util.from_disk(path, deserialize, exclude)
|
||||||
return self
|
return self
|
||||||
|
@@ -383,6 +385,7 @@ class NeuralTagger(BaseThincComponent):
|
||||||
use_bin_type=True,
|
use_bin_type=True,
|
||||||
encoding='utf8'))),
|
encoding='utf8'))),
|
||||||
('model', lambda p: p.open('wb').write(self.model.to_bytes())),
|
('model', lambda p: p.open('wb').write(self.model.to_bytes())),
|
||||||
|
('cfg', lambda p: p.open('w').write(json_dumps(self.cfg)))
|
||||||
))
|
))
|
||||||
util.to_disk(path, serialize, exclude)
|
util.to_disk(path, serialize, exclude)
|
||||||
|
|
||||||
|
@@ -405,6 +408,7 @@ class NeuralTagger(BaseThincComponent):
|
||||||
('vocab', lambda p: self.vocab.from_disk(p)),
|
('vocab', lambda p: self.vocab.from_disk(p)),
|
||||||
('tag_map', load_tag_map),
|
('tag_map', load_tag_map),
|
||||||
('model', load_model),
|
('model', load_model),
|
||||||
|
('cfg', lambda p: self.cfg.update(ujson.load(p.open()))),
|
||||||
))
|
))
|
||||||
util.from_disk(path, deserialize, exclude)
|
util.from_disk(path, deserialize, exclude)
|
||||||
return self
|
return self
|
||||||
|
@@ -523,7 +527,15 @@ class TextCategorizer(BaseThincComponent):
|
||||||
def __init__(self, vocab, model=True, **cfg):
|
def __init__(self, vocab, model=True, **cfg):
|
||||||
self.vocab = vocab
|
self.vocab = vocab
|
||||||
self.model = model
|
self.model = model
|
||||||
self.labels = cfg.get('labels', ['LABEL'])
|
self.cfg = cfg
|
||||||
|
|
||||||
|
@property
|
||||||
|
def labels(self):
|
||||||
|
return self.cfg.get('labels', ['LABEL'])
|
||||||
|
|
||||||
|
@labels.setter
|
||||||
|
def labels(self, value):
|
||||||
|
self.cfg['labels'] = value
|
||||||
|
|
||||||
def __call__(self, doc):
|
def __call__(self, doc):
|
||||||
scores = self.predict([doc])
|
scores = self.predict([doc])
|
||||||
|
|
Loading…
Reference in New Issue
Block a user