mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	* Revise greedy_parse/beam_parse ownership goof
This commit is contained in:
		
							parent
							
								
									70a7ad89ca
								
							
						
					
					
						commit
						66dfa95847
					
				| 
						 | 
					@ -14,6 +14,5 @@ cdef class Parser:
 | 
				
			||||||
    cdef readonly Model model
 | 
					    cdef readonly Model model
 | 
				
			||||||
    cdef readonly TransitionSystem moves
 | 
					    cdef readonly TransitionSystem moves
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    cdef int _greedy_parse(self, Tokens tokens) except -1
 | 
				
			||||||
    cdef State* _greedy_parse(self, Tokens tokens) except NULL
 | 
					    cdef int _beam_parse(self, Tokens tokens) except -1
 | 
				
			||||||
    cdef State* _beam_parse(self, Tokens tokens) except NULL
 | 
					 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -81,15 +81,19 @@ cdef class Parser:
 | 
				
			||||||
    def __call__(self, Tokens tokens):
 | 
					    def __call__(self, Tokens tokens):
 | 
				
			||||||
        if tokens.length == 0:
 | 
					        if tokens.length == 0:
 | 
				
			||||||
            return 0
 | 
					            return 0
 | 
				
			||||||
        cdef State* state
 | 
					 | 
				
			||||||
        if self.cfg.beam_width == 1:
 | 
					        if self.cfg.beam_width == 1:
 | 
				
			||||||
            state = self._greedy_parse(tokens)
 | 
					            self._greedy_parse(tokens)
 | 
				
			||||||
        else:
 | 
					        else:
 | 
				
			||||||
            state = self._beam_parse(tokens)
 | 
					            self._beam_parse(tokens)
 | 
				
			||||||
        self.moves.finalize_state(state)
 | 
					 | 
				
			||||||
        tokens.set_parse(state.sent)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cdef State* _greedy_parse(self, Tokens tokens) except NULL:
 | 
					    def train(self, Tokens tokens, GoldParse gold):
 | 
				
			||||||
 | 
					        self.moves.preprocess_gold(gold)
 | 
				
			||||||
 | 
					        if self.cfg.beam_width == 1:
 | 
				
			||||||
 | 
					            return self._greedy_train(tokens, gold)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return self._beam_train(tokens, gold)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    cdef int _greedy_parse(self, Tokens tokens) except -1:
 | 
				
			||||||
        cdef atom_t[CONTEXT_SIZE] context
 | 
					        cdef atom_t[CONTEXT_SIZE] context
 | 
				
			||||||
        cdef int n_feats
 | 
					        cdef int n_feats
 | 
				
			||||||
        cdef Pool mem = Pool()
 | 
					        cdef Pool mem = Pool()
 | 
				
			||||||
| 
						 | 
					@ -101,21 +105,17 @@ cdef class Parser:
 | 
				
			||||||
            scores = self.model.score(context)
 | 
					            scores = self.model.score(context)
 | 
				
			||||||
            guess = self.moves.best_valid(scores, state)
 | 
					            guess = self.moves.best_valid(scores, state)
 | 
				
			||||||
            guess.do(&guess, state)
 | 
					            guess.do(&guess, state)
 | 
				
			||||||
        return state
 | 
					        self.moves.finalize_state(state)
 | 
				
			||||||
 | 
					        tokens.set_parse(state.sent)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cdef State* _beam_parse(self, Tokens tokens) except NULL:
 | 
					    cdef int _beam_parse(self, Tokens tokens) except -1:
 | 
				
			||||||
        cdef Beam beam = Beam(self.model.n_classes, self.cfg.beam_width)
 | 
					        cdef Beam beam = Beam(self.model.n_classes, self.cfg.beam_width)
 | 
				
			||||||
        beam.initialize(_init_state, tokens.length, tokens.data)
 | 
					        beam.initialize(_init_state, tokens.length, tokens.data)
 | 
				
			||||||
        while not beam.is_done:
 | 
					        while not beam.is_done:
 | 
				
			||||||
            self._advance_beam(beam, None, False)
 | 
					            self._advance_beam(beam, None, False)
 | 
				
			||||||
        return <State*>beam.at(0)
 | 
					        state = <State*>beam.at(0)
 | 
				
			||||||
 | 
					        self.moves.finalize_state(state)
 | 
				
			||||||
    def train(self, Tokens tokens, GoldParse gold):
 | 
					        tokens.set_parse(state.sent)
 | 
				
			||||||
        self.moves.preprocess_gold(gold)
 | 
					 | 
				
			||||||
        if self.beam_width == 1:
 | 
					 | 
				
			||||||
            return self._greedy_train(tokens, gold)
 | 
					 | 
				
			||||||
        else:
 | 
					 | 
				
			||||||
            return self._beam_train(tokens, gold)
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def _greedy_train(self, Tokens tokens, GoldParse gold):
 | 
					    def _greedy_train(self, Tokens tokens, GoldParse gold):
 | 
				
			||||||
        cdef Pool mem = Pool()
 | 
					        cdef Pool mem = Pool()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user