prevent updating cfg if the Model was already defined (#5078)

This commit is contained in:
Sofie Van Landeghem 2020-03-03 13:58:56 +01:00 committed by GitHub
parent d307e9ca58
commit a0998868ff
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -606,7 +606,6 @@ cdef class Parser:
if not hasattr(get_gold_tuples, '__call__'): if not hasattr(get_gold_tuples, '__call__'):
gold_tuples = get_gold_tuples gold_tuples = get_gold_tuples
get_gold_tuples = lambda: gold_tuples get_gold_tuples = lambda: gold_tuples
cfg.setdefault('min_action_freq', 30)
actions = self.moves.get_actions(gold_parses=get_gold_tuples(), actions = self.moves.get_actions(gold_parses=get_gold_tuples(),
min_freq=cfg.get('min_action_freq', 30), min_freq=cfg.get('min_action_freq', 30),
learn_tokens=self.cfg.get("learn_tokens", False)) learn_tokens=self.cfg.get("learn_tokens", False))
@ -616,8 +615,9 @@ cdef class Parser:
if label not in actions[action]: if label not in actions[action]:
actions[action][label] = freq actions[action][label] = freq
self.moves.initialize_actions(actions) self.moves.initialize_actions(actions)
cfg.setdefault('token_vector_width', 96)
if self.model is True: if self.model is True:
cfg.setdefault('min_action_freq', 30)
cfg.setdefault('token_vector_width', 96)
self.model, cfg = self.Model(self.moves.n_moves, **cfg) self.model, cfg = self.Model(self.moves.n_moves, **cfg)
if sgd is None: if sgd is None:
sgd = self.create_optimizer() sgd = self.create_optimizer()
@ -633,11 +633,11 @@ cdef class Parser:
if pipeline is not None: if pipeline is not None:
self.init_multitask_objectives(get_gold_tuples, pipeline, sgd=sgd, **cfg) self.init_multitask_objectives(get_gold_tuples, pipeline, sgd=sgd, **cfg)
link_vectors_to_models(self.vocab) link_vectors_to_models(self.vocab)
self.cfg.update(cfg)
else: else:
if sgd is None: if sgd is None:
sgd = self.create_optimizer() sgd = self.create_optimizer()
self.model.begin_training([]) self.model.begin_training([])
self.cfg.update(cfg)
return sgd return sgd
def to_disk(self, path, exclude=tuple(), **kwargs): def to_disk(self, path, exclude=tuple(), **kwargs):