mirror of
https://github.com/explosion/spaCy.git
synced 2025-05-31 03:03:17 +03:00
Fix training for new labels
This commit is contained in:
parent
c729d72fc6
commit
c76cb8af35
|
@ -151,7 +151,6 @@ cdef class Parser:
|
||||||
if isinstance(labels, dict):
|
if isinstance(labels, dict):
|
||||||
labels = list(sorted(labels.keys()))
|
labels = list(sorted(labels.keys()))
|
||||||
cfg['actions'][action_name] = labels
|
cfg['actions'][action_name] = labels
|
||||||
print(cfg['actions'])
|
|
||||||
self = cls(vocab, TransitionSystem=TransitionSystem, model=None, **cfg)
|
self = cls(vocab, TransitionSystem=TransitionSystem, model=None, **cfg)
|
||||||
if (path / 'model').exists():
|
if (path / 'model').exists():
|
||||||
self.model.load(str(path / 'model'))
|
self.model.load(str(path / 'model'))
|
||||||
|
@ -187,6 +186,11 @@ cdef class Parser:
|
||||||
self.model.learn_rate = cfg.get('learn_rate', 0.001)
|
self.model.learn_rate = cfg.get('learn_rate', 0.001)
|
||||||
|
|
||||||
self.cfg = cfg
|
self.cfg = cfg
|
||||||
|
# TODO: This is a pretty hacky fix to the problem of adding more
|
||||||
|
# labels. The issue is they come in out of order, if labels are
|
||||||
|
# added during training
|
||||||
|
for label in cfg.get('extra_labels', []):
|
||||||
|
self.add_label(label)
|
||||||
|
|
||||||
def __reduce__(self):
|
def __reduce__(self):
|
||||||
return (Parser, (self.vocab, self.moves, self.model), None, None)
|
return (Parser, (self.vocab, self.moves, self.model), None, None)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user