Expose early_update option

Also fix a small bug: the C fields of Beam could not be accessed because
the beam was passed as a Python object.
This commit is contained in:
Daniël de Kok 2022-07-20 10:52:31 +02:00
parent c04ae74268
commit 968e6c7bf5
3 changed files with 13 additions and 3 deletions

View File

@ -207,7 +207,7 @@ def update_beam(TransitionSystem moves, states, golds, model, int width, beam_de
return loss
def _lowest_score_has_cost(beam: Beam) -> bool:
def _lowest_score_has_cost(Beam beam) -> bool:
"""Check whether the lowest-scoring candidate
in a parse is marked as non-gold, i.e. it has a cost
> 0."""

View File

@ -130,6 +130,7 @@ def make_parser(
"beam_width": 8,
"beam_density": 0.01,
"beam_update_prob": 0.5,
"early_update": False,
"moves": None,
"update_with_oracle_cut_size": 100,
"learn_tokens": False,
@ -157,6 +158,7 @@ def make_beam_parser(
beam_width: int,
beam_density: float,
beam_update_prob: float,
early_update: bool,
scorer: Optional[Callable],
):
"""Create a transition-based DependencyParser component that uses beam-search.
@ -209,6 +211,7 @@ def make_beam_parser(
beam_width=beam_width,
beam_density=beam_density,
beam_update_prob=beam_update_prob,
early_update=early_update,
multitasks=[],
learn_tokens=learn_tokens,
min_action_freq=min_action_freq,
@ -269,6 +272,7 @@ cdef class DependencyParser(Parser):
beam_width=1,
beam_density=0.0,
beam_update_prob=0.0,
early_update=False,
multitasks=tuple(),
incorrect_spans_key=None,
scorer=parser_score,
@ -286,6 +290,7 @@ cdef class DependencyParser(Parser):
beam_width=beam_width,
beam_density=beam_density,
beam_update_prob=beam_update_prob,
early_update=early_update,
multitasks=multitasks,
incorrect_spans_key=incorrect_spans_key,
scorer=scorer,

View File

@ -51,6 +51,7 @@ cdef class Parser(TrainablePipe):
beam_width=1,
beam_density=0.0,
beam_update_prob=0.0,
early_update=False,
multitasks=tuple(),
incorrect_spans_key=None,
scorer=None,
@ -103,6 +104,7 @@ cdef class Parser(TrainablePipe):
"beam_width": beam_width,
"beam_density": beam_density,
"beam_update_prob": beam_update_prob,
"early_update": early_update,
"incorrect_spans_key": incorrect_spans_key
}
if moves is None:
@ -387,7 +389,8 @@ cdef class Parser(TrainablePipe):
beam_width=self.cfg["beam_width"],
sgd=sgd,
losses=losses,
beam_density=self.cfg["beam_density"]
beam_density=self.cfg["beam_density"],
early_update=self.cfg["early_update"]
)
max_moves = self.cfg["update_with_oracle_cut_size"]
if max_moves >= 1:
@ -485,7 +488,8 @@ cdef class Parser(TrainablePipe):
return losses
def update_beam(self, examples, *, beam_width,
drop=0., sgd=None, losses=None, beam_density=0.0):
drop=0., sgd=None, losses=None, beam_density=0.0,
early_update=False):
states, golds, _ = self.moves.init_gold_batch(examples)
if not states:
return losses
@ -499,6 +503,7 @@ cdef class Parser(TrainablePipe):
model,
beam_width,
beam_density=beam_density,
early_update=early_update,
)
losses[self.name] += loss
backprop_tok2vec(golds)