mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	More GoldParse excise
This commit is contained in:
		
							parent
							
								
									60d4e5a9e0
								
							
						
					
					
						commit
						9296d71a54
					
				| 
						 | 
				
			
			@ -9,7 +9,6 @@ import numpy
 | 
			
		|||
 | 
			
		||||
from ..typedefs cimport hash_t, class_t
 | 
			
		||||
from .transition_system cimport TransitionSystem, Transition
 | 
			
		||||
from ..gold cimport GoldParse
 | 
			
		||||
from .stateclass cimport StateC, StateClass
 | 
			
		||||
 | 
			
		||||
from ..errors import Errors
 | 
			
		||||
| 
						 | 
				
			
			@ -126,12 +125,12 @@ cdef class ParserBeam(object):
 | 
			
		|||
                    beam.scores[i][j] = 0
 | 
			
		||||
                    beam.costs[i][j] = 0
 | 
			
		||||
 | 
			
		||||
    def _set_costs(self, Beam beam, GoldParse gold, int follow_gold=False):
 | 
			
		||||
    def _set_costs(self, Beam beam, NewExample example, int follow_gold=False):
 | 
			
		||||
        for i in range(beam.size):
 | 
			
		||||
            state = StateClass.borrow(<StateC*>beam.at(i))
 | 
			
		||||
            if not state.is_final():
 | 
			
		||||
                self.moves.set_costs(beam.is_valid[i], beam.costs[i],
 | 
			
		||||
                                     state, gold)
 | 
			
		||||
                                     state, example)
 | 
			
		||||
                if follow_gold:
 | 
			
		||||
                    min_cost = 0
 | 
			
		||||
                    for j in range(beam.nr_class):
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -20,7 +20,6 @@ import numpy
 | 
			
		|||
import warnings
 | 
			
		||||
 | 
			
		||||
from ..tokens.doc cimport Doc
 | 
			
		||||
from ..gold cimport GoldParse
 | 
			
		||||
from ..typedefs cimport weight_t, class_t, hash_t
 | 
			
		||||
from ._parser_model cimport alloc_activations, free_activations
 | 
			
		||||
from ._parser_model cimport predict_states, arg_max_if_valid
 | 
			
		||||
| 
						 | 
				
			
			@ -567,9 +566,9 @@ cdef class Parser:
 | 
			
		|||
            max_moves = max(max_moves, len(oracle_actions))
 | 
			
		||||
        return states, golds, max_moves
 | 
			
		||||
 | 
			
		||||
    def get_batch_loss(self, states, golds, float[:, ::1] scores, losses):
 | 
			
		||||
    def get_batch_loss(self, states, examples, float[:, ::1] scores, losses):
 | 
			
		||||
        cdef StateClass state
 | 
			
		||||
        cdef GoldParse gold
 | 
			
		||||
        cdef NewExample example
 | 
			
		||||
        cdef Pool mem = Pool()
 | 
			
		||||
        cdef int i
 | 
			
		||||
 | 
			
		||||
| 
						 | 
				
			
			@ -582,10 +581,10 @@ cdef class Parser:
 | 
			
		|||
                                        dtype='f', order='C')
 | 
			
		||||
        c_d_scores = <float*>d_scores.data
 | 
			
		||||
        unseen_classes = self.model.attrs["unseen_classes"]
 | 
			
		||||
        for i, (state, gold) in enumerate(zip(states, golds)):
 | 
			
		||||
        for i, (state, eg) in enumerate(zip(states, examples)):
 | 
			
		||||
            memset(is_valid, 0, self.moves.n_moves * sizeof(int))
 | 
			
		||||
            memset(costs, 0, self.moves.n_moves * sizeof(float))
 | 
			
		||||
            self.moves.set_costs(is_valid, costs, state, gold)
 | 
			
		||||
            self.moves.set_costs(is_valid, costs, state, eg)
 | 
			
		||||
            for j in range(self.moves.n_moves):
 | 
			
		||||
                if costs[j] <= 0.0 and j in unseen_classes:
 | 
			
		||||
                    unseen_classes.remove(j)
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user