Expose early_update option

Also fix a small bug: the C fields of Beam could not be accessed in
`_lowest_score_has_cost` because the beam was passed as a Python object.
This commit is contained in:
Daniël de Kok 2022-07-20 10:52:31 +02:00
parent c04ae74268
commit 968e6c7bf5
3 changed files with 13 additions and 3 deletions

View File

@ -207,7 +207,7 @@ def update_beam(TransitionSystem moves, states, golds, model, int width, beam_de
return loss return loss
def _lowest_score_has_cost(beam: Beam) -> bool: def _lowest_score_has_cost(Beam beam) -> bool:
"""Check whether the lowest-scoring candidate """Check whether the lowest-scoring candidate
in a parse is marked as non-gold, i.e. it has a cost in a parse is marked as non-gold, i.e. it has a cost
> 0.""" > 0."""

View File

@ -130,6 +130,7 @@ def make_parser(
"beam_width": 8, "beam_width": 8,
"beam_density": 0.01, "beam_density": 0.01,
"beam_update_prob": 0.5, "beam_update_prob": 0.5,
"early_update": False,
"moves": None, "moves": None,
"update_with_oracle_cut_size": 100, "update_with_oracle_cut_size": 100,
"learn_tokens": False, "learn_tokens": False,
@ -157,6 +158,7 @@ def make_beam_parser(
beam_width: int, beam_width: int,
beam_density: float, beam_density: float,
beam_update_prob: float, beam_update_prob: float,
early_update: bool,
scorer: Optional[Callable], scorer: Optional[Callable],
): ):
"""Create a transition-based DependencyParser component that uses beam-search. """Create a transition-based DependencyParser component that uses beam-search.
@ -209,6 +211,7 @@ def make_beam_parser(
beam_width=beam_width, beam_width=beam_width,
beam_density=beam_density, beam_density=beam_density,
beam_update_prob=beam_update_prob, beam_update_prob=beam_update_prob,
early_update=early_update,
multitasks=[], multitasks=[],
learn_tokens=learn_tokens, learn_tokens=learn_tokens,
min_action_freq=min_action_freq, min_action_freq=min_action_freq,
@ -269,6 +272,7 @@ cdef class DependencyParser(Parser):
beam_width=1, beam_width=1,
beam_density=0.0, beam_density=0.0,
beam_update_prob=0.0, beam_update_prob=0.0,
early_update=False,
multitasks=tuple(), multitasks=tuple(),
incorrect_spans_key=None, incorrect_spans_key=None,
scorer=parser_score, scorer=parser_score,
@ -286,6 +290,7 @@ cdef class DependencyParser(Parser):
beam_width=beam_width, beam_width=beam_width,
beam_density=beam_density, beam_density=beam_density,
beam_update_prob=beam_update_prob, beam_update_prob=beam_update_prob,
early_update=early_update,
multitasks=multitasks, multitasks=multitasks,
incorrect_spans_key=incorrect_spans_key, incorrect_spans_key=incorrect_spans_key,
scorer=scorer, scorer=scorer,

View File

@ -51,6 +51,7 @@ cdef class Parser(TrainablePipe):
beam_width=1, beam_width=1,
beam_density=0.0, beam_density=0.0,
beam_update_prob=0.0, beam_update_prob=0.0,
early_update=False,
multitasks=tuple(), multitasks=tuple(),
incorrect_spans_key=None, incorrect_spans_key=None,
scorer=None, scorer=None,
@ -103,6 +104,7 @@ cdef class Parser(TrainablePipe):
"beam_width": beam_width, "beam_width": beam_width,
"beam_density": beam_density, "beam_density": beam_density,
"beam_update_prob": beam_update_prob, "beam_update_prob": beam_update_prob,
"early_update": early_update,
"incorrect_spans_key": incorrect_spans_key "incorrect_spans_key": incorrect_spans_key
} }
if moves is None: if moves is None:
@ -387,7 +389,8 @@ cdef class Parser(TrainablePipe):
beam_width=self.cfg["beam_width"], beam_width=self.cfg["beam_width"],
sgd=sgd, sgd=sgd,
losses=losses, losses=losses,
beam_density=self.cfg["beam_density"] beam_density=self.cfg["beam_density"],
early_update=self.cfg["early_update"]
) )
max_moves = self.cfg["update_with_oracle_cut_size"] max_moves = self.cfg["update_with_oracle_cut_size"]
if max_moves >= 1: if max_moves >= 1:
@ -485,7 +488,8 @@ cdef class Parser(TrainablePipe):
return losses return losses
def update_beam(self, examples, *, beam_width, def update_beam(self, examples, *, beam_width,
drop=0., sgd=None, losses=None, beam_density=0.0): drop=0., sgd=None, losses=None, beam_density=0.0,
early_update=False):
states, golds, _ = self.moves.init_gold_batch(examples) states, golds, _ = self.moves.init_gold_batch(examples)
if not states: if not states:
return losses return losses
@ -499,6 +503,7 @@ cdef class Parser(TrainablePipe):
model, model,
beam_width, beam_width,
beam_density=beam_density, beam_density=beam_density,
early_update=early_update,
) )
losses[self.name] += loss losses[self.name] += loss
backprop_tok2vec(golds) backprop_tok2vec(golds)