mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	Fix history features
This commit is contained in:
		
							parent
							
								
									b770f4e108
								
							
						
					
					
						commit
						278a4c17c6
					
				| 
						 | 
				
			
			@ -70,6 +70,8 @@ from ..attrs cimport ID, TAG, DEP, ORTH, NORM, PREFIX, SUFFIX, TAG
 | 
			
		|||
from . import _beam_utils
 | 
			
		||||
 | 
			
		||||
USE_HISTORY = True
 | 
			
		||||
HIST_SIZE = 2
 | 
			
		||||
HIST_DIMS = 16
 | 
			
		||||
 | 
			
		||||
def get_templates(*args, **kwargs):
 | 
			
		||||
    return []
 | 
			
		||||
| 
						 | 
				
			
			@ -262,13 +264,11 @@ cdef class Parser:
 | 
			
		|||
 | 
			
		||||
        with Model.use_device('cpu'):
 | 
			
		||||
            if depth == 0:
 | 
			
		||||
                hist_size = 8
 | 
			
		||||
                nr_dim = 8
 | 
			
		||||
                if USE_HISTORY:
 | 
			
		||||
                    upper = chain(
 | 
			
		||||
                        HistoryFeatures(nr_class=nr_class, hist_size=hist_size,
 | 
			
		||||
                                        nr_dim=nr_dim),
 | 
			
		||||
                        zero_init(Affine(nr_class, nr_class+hist_size*nr_dim,
 | 
			
		||||
                        HistoryFeatures(nr_class=nr_class, hist_size=HIST_SIZE,
 | 
			
		||||
                                        nr_dim=HIST_DIMS),
 | 
			
		||||
                        zero_init(Affine(nr_class, nr_class+HIST_SIZE*HIST_DIMS,
 | 
			
		||||
                                          drop_factor=0.0)))
 | 
			
		||||
                    upper.is_noop = False
 | 
			
		||||
                else:
 | 
			
		||||
| 
						 | 
				
			
			@ -736,15 +736,13 @@ cdef class Parser:
 | 
			
		|||
        cdef StateClass state
 | 
			
		||||
        cdef int[500] is_valid # TODO: Unhack
 | 
			
		||||
        cdef float* c_scores = &scores[0, 0]
 | 
			
		||||
        hists = []
 | 
			
		||||
        for state in states:
 | 
			
		||||
            self.moves.set_valid(is_valid, state.c)
 | 
			
		||||
            guess = arg_max_if_valid(c_scores, is_valid, scores.shape[1])
 | 
			
		||||
            action = self.moves.c[guess]
 | 
			
		||||
            action.do(state.c, action.label)
 | 
			
		||||
            c_scores += scores.shape[1]
 | 
			
		||||
            hists.append(guess)
 | 
			
		||||
        return hists
 | 
			
		||||
            state.c.push_hist(guess)
 | 
			
		||||
 | 
			
		||||
    def get_batch_loss(self, states, golds, float[:, ::1] scores):
 | 
			
		||||
        cdef StateClass state
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user