mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-10 16:22:29 +03:00
Add config for beam density
This commit is contained in:
parent
532fa36c13
commit
de82552a13
|
@@ -70,20 +70,23 @@ def get_templates(name):
|
||||||
pf.tree_shape + pf.trigrams)
|
pf.tree_shape + pf.trigrams)
|
||||||
|
|
||||||
|
|
||||||
cdef int BEAM_WIDTH = 8
|
cdef int BEAM_WIDTH = 16
|
||||||
|
cdef weight_t BEAM_DENSITY = 0.01
|
||||||
|
|
||||||
cdef class BeamParser(Parser):
|
cdef class BeamParser(Parser):
|
||||||
cdef public int beam_width
|
cdef public int beam_width
|
||||||
|
cdef public weight_t beam_density
|
||||||
|
|
||||||
def __init__(self, *args, **kwargs):
|
def __init__(self, *args, **kwargs):
|
||||||
self.beam_width = kwargs.get('beam_width', BEAM_WIDTH)
|
self.beam_width = kwargs.get('beam_width', BEAM_WIDTH)
|
||||||
|
self.beam_density = kwargs.get('beam_density', BEAM_DENSITY)
|
||||||
Parser.__init__(self, *args, **kwargs)
|
Parser.__init__(self, *args, **kwargs)
|
||||||
|
|
||||||
cdef int parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) with gil:
|
cdef int parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) with gil:
|
||||||
self._parseC(tokens, length, nr_feat, nr_class)
|
self._parseC(tokens, length, nr_feat, nr_class)
|
||||||
|
|
||||||
cdef int _parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) except -1:
|
cdef int _parseC(self, TokenC* tokens, int length, int nr_feat, int nr_class) except -1:
|
||||||
cdef Beam beam = Beam(self.moves.n_moves, self.beam_width)
|
cdef Beam beam = Beam(self.moves.n_moves, self.beam_width, min_density=self.beam_density)
|
||||||
beam.initialize(_init_state, length, tokens)
|
beam.initialize(_init_state, length, tokens)
|
||||||
beam.check_done(_check_final_state, NULL)
|
beam.check_done(_check_final_state, NULL)
|
||||||
while not beam.is_done:
|
while not beam.is_done:
|
||||||
|
@@ -96,11 +99,11 @@ cdef class BeamParser(Parser):
|
||||||
|
|
||||||
def train(self, Doc tokens, GoldParse gold_parse, itn=0):
|
def train(self, Doc tokens, GoldParse gold_parse, itn=0):
|
||||||
self.moves.preprocess_gold(gold_parse)
|
self.moves.preprocess_gold(gold_parse)
|
||||||
cdef Beam pred = Beam(self.moves.n_moves, self.beam_width)
|
cdef Beam pred = Beam(self.moves.n_moves, self.beam_width, min_density=self.beam_density)
|
||||||
pred.initialize(_init_state, tokens.length, tokens.c)
|
pred.initialize(_init_state, tokens.length, tokens.c)
|
||||||
pred.check_done(_check_final_state, NULL)
|
pred.check_done(_check_final_state, NULL)
|
||||||
|
|
||||||
cdef Beam gold = Beam(self.moves.n_moves, self.beam_width)
|
cdef Beam gold = Beam(self.moves.n_moves, self.beam_width, min_density=self.beam_density)
|
||||||
gold.initialize(_init_state, tokens.length, tokens.c)
|
gold.initialize(_init_state, tokens.length, tokens.c)
|
||||||
gold.check_done(_check_final_state, NULL)
|
gold.check_done(_check_final_state, NULL)
|
||||||
violn = MaxViolation()
|
violn = MaxViolation()
|
||||||
|
|
Loading…
Reference in New Issue
Block a user