mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-11 17:56:30 +03:00
Expose early_update option
Also fix a small bug -- C fields of Beam could not be accessed because it was passed as a Python object.
This commit is contained in:
parent
c04ae74268
commit
968e6c7bf5
|
@ -207,7 +207,7 @@ def update_beam(TransitionSystem moves, states, golds, model, int width, beam_de
|
|||
return loss
|
||||
|
||||
|
||||
def _lowest_score_has_cost(beam: Beam) -> bool:
|
||||
def _lowest_score_has_cost(Beam beam) -> bool:
|
||||
"""Check whether the lowest-scoring candidate
|
||||
in a parse is marked as non-gold, i.e. it has a cost
|
||||
> 0."""
|
||||
|
|
|
@ -130,6 +130,7 @@ def make_parser(
|
|||
"beam_width": 8,
|
||||
"beam_density": 0.01,
|
||||
"beam_update_prob": 0.5,
|
||||
"early_update": False,
|
||||
"moves": None,
|
||||
"update_with_oracle_cut_size": 100,
|
||||
"learn_tokens": False,
|
||||
|
@ -157,6 +158,7 @@ def make_beam_parser(
|
|||
beam_width: int,
|
||||
beam_density: float,
|
||||
beam_update_prob: float,
|
||||
early_update: bool,
|
||||
scorer: Optional[Callable],
|
||||
):
|
||||
"""Create a transition-based DependencyParser component that uses beam-search.
|
||||
|
@ -209,6 +211,7 @@ def make_beam_parser(
|
|||
beam_width=beam_width,
|
||||
beam_density=beam_density,
|
||||
beam_update_prob=beam_update_prob,
|
||||
early_update=early_update,
|
||||
multitasks=[],
|
||||
learn_tokens=learn_tokens,
|
||||
min_action_freq=min_action_freq,
|
||||
|
@ -269,6 +272,7 @@ cdef class DependencyParser(Parser):
|
|||
beam_width=1,
|
||||
beam_density=0.0,
|
||||
beam_update_prob=0.0,
|
||||
early_update=False,
|
||||
multitasks=tuple(),
|
||||
incorrect_spans_key=None,
|
||||
scorer=parser_score,
|
||||
|
@ -286,6 +290,7 @@ cdef class DependencyParser(Parser):
|
|||
beam_width=beam_width,
|
||||
beam_density=beam_density,
|
||||
beam_update_prob=beam_update_prob,
|
||||
early_update=early_update,
|
||||
multitasks=multitasks,
|
||||
incorrect_spans_key=incorrect_spans_key,
|
||||
scorer=scorer,
|
||||
|
|
|
@ -51,6 +51,7 @@ cdef class Parser(TrainablePipe):
|
|||
beam_width=1,
|
||||
beam_density=0.0,
|
||||
beam_update_prob=0.0,
|
||||
early_update=False,
|
||||
multitasks=tuple(),
|
||||
incorrect_spans_key=None,
|
||||
scorer=None,
|
||||
|
@ -103,6 +104,7 @@ cdef class Parser(TrainablePipe):
|
|||
"beam_width": beam_width,
|
||||
"beam_density": beam_density,
|
||||
"beam_update_prob": beam_update_prob,
|
||||
"early_update": early_update,
|
||||
"incorrect_spans_key": incorrect_spans_key
|
||||
}
|
||||
if moves is None:
|
||||
|
@ -387,7 +389,8 @@ cdef class Parser(TrainablePipe):
|
|||
beam_width=self.cfg["beam_width"],
|
||||
sgd=sgd,
|
||||
losses=losses,
|
||||
beam_density=self.cfg["beam_density"]
|
||||
beam_density=self.cfg["beam_density"],
|
||||
early_update=self.cfg["early_update"]
|
||||
)
|
||||
max_moves = self.cfg["update_with_oracle_cut_size"]
|
||||
if max_moves >= 1:
|
||||
|
@ -485,7 +488,8 @@ cdef class Parser(TrainablePipe):
|
|||
return losses
|
||||
|
||||
def update_beam(self, examples, *, beam_width,
|
||||
drop=0., sgd=None, losses=None, beam_density=0.0):
|
||||
drop=0., sgd=None, losses=None, beam_density=0.0,
|
||||
early_update=False):
|
||||
states, golds, _ = self.moves.init_gold_batch(examples)
|
||||
if not states:
|
||||
return losses
|
||||
|
@ -499,6 +503,7 @@ cdef class Parser(TrainablePipe):
|
|||
model,
|
||||
beam_width,
|
||||
beam_density=beam_density,
|
||||
early_update=early_update,
|
||||
)
|
||||
losses[self.name] += loss
|
||||
backprop_tok2vec(golds)
|
||||
|
|
Loading…
Reference in New Issue
Block a user