mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
* Revise greedy_parse/beam_parse ownership goof
This commit is contained in:
parent
70a7ad89ca
commit
66dfa95847
|
@ -14,6 +14,5 @@ cdef class Parser:
|
|||
cdef readonly Model model
|
||||
cdef readonly TransitionSystem moves
|
||||
|
||||
|
||||
cdef State* _greedy_parse(self, Tokens tokens) except NULL
|
||||
cdef State* _beam_parse(self, Tokens tokens) except NULL
|
||||
cdef int _greedy_parse(self, Tokens tokens) except -1
|
||||
cdef int _beam_parse(self, Tokens tokens) except -1
|
||||
|
|
|
@ -81,15 +81,19 @@ cdef class Parser:
|
|||
def __call__(self, Tokens tokens):
|
||||
if tokens.length == 0:
|
||||
return 0
|
||||
cdef State* state
|
||||
if self.cfg.beam_width == 1:
|
||||
state = self._greedy_parse(tokens)
|
||||
self._greedy_parse(tokens)
|
||||
else:
|
||||
state = self._beam_parse(tokens)
|
||||
self.moves.finalize_state(state)
|
||||
tokens.set_parse(state.sent)
|
||||
self._beam_parse(tokens)
|
||||
|
||||
cdef State* _greedy_parse(self, Tokens tokens) except NULL:
|
||||
def train(self, Tokens tokens, GoldParse gold):
|
||||
self.moves.preprocess_gold(gold)
|
||||
if self.cfg.beam_width == 1:
|
||||
return self._greedy_train(tokens, gold)
|
||||
else:
|
||||
return self._beam_train(tokens, gold)
|
||||
|
||||
cdef int _greedy_parse(self, Tokens tokens) except -1:
|
||||
cdef atom_t[CONTEXT_SIZE] context
|
||||
cdef int n_feats
|
||||
cdef Pool mem = Pool()
|
||||
|
@ -101,21 +105,17 @@ cdef class Parser:
|
|||
scores = self.model.score(context)
|
||||
guess = self.moves.best_valid(scores, state)
|
||||
guess.do(&guess, state)
|
||||
return state
|
||||
self.moves.finalize_state(state)
|
||||
tokens.set_parse(state.sent)
|
||||
|
||||
cdef State* _beam_parse(self, Tokens tokens) except NULL:
|
||||
cdef int _beam_parse(self, Tokens tokens) except -1:
|
||||
cdef Beam beam = Beam(self.model.n_classes, self.cfg.beam_width)
|
||||
beam.initialize(_init_state, tokens.length, tokens.data)
|
||||
while not beam.is_done:
|
||||
self._advance_beam(beam, None, False)
|
||||
return <State*>beam.at(0)
|
||||
|
||||
def train(self, Tokens tokens, GoldParse gold):
|
||||
self.moves.preprocess_gold(gold)
|
||||
if self.beam_width == 1:
|
||||
return self._greedy_train(tokens, gold)
|
||||
else:
|
||||
return self._beam_train(tokens, gold)
|
||||
state = <State*>beam.at(0)
|
||||
self.moves.finalize_state(state)
|
||||
tokens.set_parse(state.sent)
|
||||
|
||||
def _greedy_train(self, Tokens tokens, GoldParse gold):
|
||||
cdef Pool mem = Pool()
|
||||
|
|
Loading…
Reference in New Issue
Block a user