mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Add config for beam density
This commit is contained in:
parent
532fa36c13
commit
de82552a13
|
@ -70,20 +70,23 @@ def get_templates(name):
|
|||
pf.tree_shape + pf.trigrams)
|
||||
|
||||
|
||||
cdef int BEAM_WIDTH = 8
|
||||
cdef int BEAM_WIDTH = 16
|
||||
cdef weight_t BEAM_DENSITY = 0.01
|
||||
|
||||
cdef class BeamParser(Parser):
|
||||
cdef public int beam_width
|
||||
cdef public weight_t beam_density
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.beam_width = kwargs.get('beam_width', BEAM_WIDTH)
|
||||
self.beam_density = kwargs.get('beam_density', BEAM_DENSITY)
|
||||
Parser.__init__(self, *args, **kwargs)
|
||||
|
||||
cdef int parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) with gil:
|
||||
self._parseC(tokens, length, nr_feat, nr_class)
|
||||
|
||||
cdef int _parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) except -1:
|
||||
cdef Beam beam = Beam(self.moves.n_moves, self.beam_width)
|
||||
cdef Beam beam = Beam(self.moves.n_moves, self.beam_width, min_density=self.beam_density)
|
||||
beam.initialize(_init_state, length, tokens)
|
||||
beam.check_done(_check_final_state, NULL)
|
||||
while not beam.is_done:
|
||||
|
@ -96,11 +99,11 @@ cdef class BeamParser(Parser):
|
|||
|
||||
def train(self, Doc tokens, GoldParse gold_parse, itn=0):
|
||||
self.moves.preprocess_gold(gold_parse)
|
||||
cdef Beam pred = Beam(self.moves.n_moves, self.beam_width)
|
||||
cdef Beam pred = Beam(self.moves.n_moves, self.beam_width, min_density=self.beam_density)
|
||||
pred.initialize(_init_state, tokens.length, tokens.c)
|
||||
pred.check_done(_check_final_state, NULL)
|
||||
|
||||
cdef Beam gold = Beam(self.moves.n_moves, self.beam_width)
|
||||
cdef Beam gold = Beam(self.moves.n_moves, self.beam_width, min_density=self.beam_density)
|
||||
gold.initialize(_init_state, tokens.length, tokens.c)
|
||||
gold.check_done(_check_final_state, NULL)
|
||||
violn = MaxViolation()
|
||||
|
|
Loading…
Reference in New Issue
Block a user