mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Improve efficiency of get_oracle_sequences
This commit is contained in:
		
							parent
							
								
									233945bfe0
								
							
						
					
					
						commit
						57e09747dc
					
				| 
						 | 
					@ -742,21 +742,14 @@ cdef class ArcEager(TransitionSystem):
 | 
				
			||||||
        if n_gold < 1:
 | 
					        if n_gold < 1:
 | 
				
			||||||
            raise ValueError
 | 
					            raise ValueError
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    def get_oracle_sequence(self, Example example):
 | 
					    def get_oracle_sequence_from_state(self, StateClass state, ArcEagerGold gold, _debug=None):
 | 
				
			||||||
        cdef StateClass state
 | 
					        cdef int i
 | 
				
			||||||
        cdef ArcEagerGold gold
 | 
					 | 
				
			||||||
        states, golds, n_steps = self.init_gold_batch([example])
 | 
					 | 
				
			||||||
        if not golds:
 | 
					 | 
				
			||||||
            return []
 | 
					 | 
				
			||||||
 | 
					 | 
				
			||||||
        cdef Pool mem = Pool()
 | 
					        cdef Pool mem = Pool()
 | 
				
			||||||
        # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
 | 
					        # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
 | 
				
			||||||
        assert self.n_moves > 0
 | 
					        assert self.n_moves > 0
 | 
				
			||||||
        costs = <float*>mem.alloc(self.n_moves, sizeof(float))
 | 
					        costs = <float*>mem.alloc(self.n_moves, sizeof(float))
 | 
				
			||||||
        is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
 | 
					        is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        state = states[0]
 | 
					 | 
				
			||||||
        gold = golds[0]
 | 
					 | 
				
			||||||
        history = []
 | 
					        history = []
 | 
				
			||||||
        debug_log = []
 | 
					        debug_log = []
 | 
				
			||||||
        failed = False
 | 
					        failed = False
 | 
				
			||||||
| 
						 | 
					@ -772,6 +765,8 @@ cdef class ArcEager(TransitionSystem):
 | 
				
			||||||
                    history.append(i)
 | 
					                    history.append(i)
 | 
				
			||||||
                    s0 = state.S(0)
 | 
					                    s0 = state.S(0)
 | 
				
			||||||
                    b0 = state.B(0)
 | 
					                    b0 = state.B(0)
 | 
				
			||||||
 | 
					                    if _debug:
 | 
				
			||||||
 | 
					                        example = _debug
 | 
				
			||||||
                        debug_log.append(" ".join((
 | 
					                        debug_log.append(" ".join((
 | 
				
			||||||
                            self.get_class_name(i),
 | 
					                            self.get_class_name(i),
 | 
				
			||||||
                            "S0=", (example.x[s0].text if s0 >= 0 else "__"),
 | 
					                            "S0=", (example.x[s0].text if s0 >= 0 else "__"),
 | 
				
			||||||
| 
						 | 
					@ -784,6 +779,7 @@ cdef class ArcEager(TransitionSystem):
 | 
				
			||||||
                failed = False
 | 
					                failed = False
 | 
				
			||||||
                break
 | 
					                break
 | 
				
			||||||
        if failed:
 | 
					        if failed:
 | 
				
			||||||
 | 
					            example = _debug
 | 
				
			||||||
            print("Actions")
 | 
					            print("Actions")
 | 
				
			||||||
            for i in range(self.n_moves):
 | 
					            for i in range(self.n_moves):
 | 
				
			||||||
                print(self.get_class_name(i))
 | 
					                print(self.get_class_name(i))
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -63,7 +63,9 @@ cdef class Parser:
 | 
				
			||||||
        self.model = model
 | 
					        self.model = model
 | 
				
			||||||
        if self.moves.n_moves != 0:
 | 
					        if self.moves.n_moves != 0:
 | 
				
			||||||
            self.set_output(self.moves.n_moves)
 | 
					            self.set_output(self.moves.n_moves)
 | 
				
			||||||
        self.cfg = cfg
 | 
					        self.cfg = dict(cfg)
 | 
				
			||||||
 | 
					        self.cfg.setdefault("update_with_oracle_cut_size", 100)
 | 
				
			||||||
 | 
					        self.cfg.setdefault("normalize_gradients_with_batch_size", True)
 | 
				
			||||||
        self._multitasks = []
 | 
					        self._multitasks = []
 | 
				
			||||||
        for multitask in cfg.get("multitasks", []):
 | 
					        for multitask in cfg.get("multitasks", []):
 | 
				
			||||||
            self.add_multitask_objective(multitask)
 | 
					            self.add_multitask_objective(multitask)
 | 
				
			||||||
| 
						 | 
					@ -272,13 +274,16 @@ cdef class Parser:
 | 
				
			||||||
        # Prepare the stepwise model, and get the callback for finishing the batch
 | 
					        # Prepare the stepwise model, and get the callback for finishing the batch
 | 
				
			||||||
        model, backprop_tok2vec = self.model.begin_update(
 | 
					        model, backprop_tok2vec = self.model.begin_update(
 | 
				
			||||||
            [eg.predicted for eg in examples])
 | 
					            [eg.predicted for eg in examples])
 | 
				
			||||||
 | 
					        if self.cfg["update_with_oracle_cut_size"] >= 1:
 | 
				
			||||||
            # Chop sequences into lengths of this many transitions, to make the
 | 
					            # Chop sequences into lengths of this many transitions, to make the
 | 
				
			||||||
            # batch uniform length. We randomize this to overfit less.
 | 
					            # batch uniform length. We randomize this to overfit less.
 | 
				
			||||||
        cut_gold = numpy.random.choice(range(20, 100))
 | 
					            cut_size = self.cfg["update_with_oracle_cut_size"]
 | 
				
			||||||
            states, golds, max_steps = self._init_gold_batch(
 | 
					            states, golds, max_steps = self._init_gold_batch(
 | 
				
			||||||
                examples,
 | 
					                examples,
 | 
				
			||||||
            max_length=cut_gold
 | 
					                max_length=numpy.random.choice(range(20, cut_size))
 | 
				
			||||||
            )
 | 
					            )
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            states, golds, max_steps = self.moves.init_gold_batch(examples)
 | 
				
			||||||
        all_states = list(states)
 | 
					        all_states = list(states)
 | 
				
			||||||
        states_golds = zip(states, golds)
 | 
					        states_golds = zip(states, golds)
 | 
				
			||||||
        for _ in range(max_steps):
 | 
					        for _ in range(max_steps):
 | 
				
			||||||
| 
						 | 
					@ -384,7 +389,7 @@ cdef class Parser:
 | 
				
			||||||
            cpu_log_loss(c_d_scores,
 | 
					            cpu_log_loss(c_d_scores,
 | 
				
			||||||
                costs, is_valid, &scores[i, 0], d_scores.shape[1])
 | 
					                costs, is_valid, &scores[i, 0], d_scores.shape[1])
 | 
				
			||||||
            c_d_scores += d_scores.shape[1]
 | 
					            c_d_scores += d_scores.shape[1]
 | 
				
			||||||
        if len(states):
 | 
					        if len(states) and self.cfg["normalize_gradients_with_batch_size"]:
 | 
				
			||||||
            d_scores /= len(states)
 | 
					            d_scores /= len(states)
 | 
				
			||||||
        if losses is not None:
 | 
					        if losses is not None:
 | 
				
			||||||
            losses.setdefault(self.name, 0.)
 | 
					            losses.setdefault(self.name, 0.)
 | 
				
			||||||
| 
						 | 
					@ -516,7 +521,8 @@ cdef class Parser:
 | 
				
			||||||
        states = []
 | 
					        states = []
 | 
				
			||||||
        golds = []
 | 
					        golds = []
 | 
				
			||||||
        for eg, state, gold in kept:
 | 
					        for eg, state, gold in kept:
 | 
				
			||||||
            oracle_actions = self.moves.get_oracle_sequence(eg)
 | 
					            oracle_actions = self.moves.get_oracle_sequence_from_state(
 | 
				
			||||||
 | 
					                state, gold)
 | 
				
			||||||
            start = 0
 | 
					            start = 0
 | 
				
			||||||
            while start < len(eg.predicted):
 | 
					            while start < len(eg.predicted):
 | 
				
			||||||
                state = state.copy()
 | 
					                state = state.copy()
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -62,18 +62,23 @@ cdef class TransitionSystem:
 | 
				
			||||||
        return states
 | 
					        return states
 | 
				
			||||||
    
 | 
					    
 | 
				
			||||||
    def get_oracle_sequence(self, Example example, _debug=False):
 | 
					    def get_oracle_sequence(self, Example example, _debug=False):
 | 
				
			||||||
 | 
					        states, golds, _ = self.init_gold_batch([example])
 | 
				
			||||||
 | 
					        if not states:
 | 
				
			||||||
 | 
					            return []
 | 
				
			||||||
 | 
					        state = states[0]
 | 
				
			||||||
 | 
					        gold = golds[0]
 | 
				
			||||||
 | 
					        if _debug:
 | 
				
			||||||
 | 
					            return self.get_oracle_sequence_from_state(state, gold, _debug=example)
 | 
				
			||||||
 | 
					        else:
 | 
				
			||||||
 | 
					            return self.get_oracle_sequence_from_state(state, gold)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    def get_oracle_sequence_from_state(self, StateClass state, gold, _debug=None):
 | 
				
			||||||
        cdef Pool mem = Pool()
 | 
					        cdef Pool mem = Pool()
 | 
				
			||||||
        # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
 | 
					        # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
 | 
				
			||||||
        assert self.n_moves > 0
 | 
					        assert self.n_moves > 0
 | 
				
			||||||
        costs = <float*>mem.alloc(self.n_moves, sizeof(float))
 | 
					        costs = <float*>mem.alloc(self.n_moves, sizeof(float))
 | 
				
			||||||
        is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
 | 
					        is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        cdef StateClass state
 | 
					 | 
				
			||||||
        states, golds, n_steps = self.init_gold_batch([example])
 | 
					 | 
				
			||||||
        if not states:
 | 
					 | 
				
			||||||
            return []
 | 
					 | 
				
			||||||
        state = states[0]
 | 
					 | 
				
			||||||
        gold = golds[0]
 | 
					 | 
				
			||||||
        history = []
 | 
					        history = []
 | 
				
			||||||
        debug_log = []
 | 
					        debug_log = []
 | 
				
			||||||
        while not state.is_final():
 | 
					        while not state.is_final():
 | 
				
			||||||
| 
						 | 
					@ -85,6 +90,7 @@ cdef class TransitionSystem:
 | 
				
			||||||
                    s0 = state.S(0)
 | 
					                    s0 = state.S(0)
 | 
				
			||||||
                    b0 = state.B(0)
 | 
					                    b0 = state.B(0)
 | 
				
			||||||
                    if _debug:
 | 
					                    if _debug:
 | 
				
			||||||
 | 
					                        example = _debug
 | 
				
			||||||
                        debug_log.append(" ".join((
 | 
					                        debug_log.append(" ".join((
 | 
				
			||||||
                            self.get_class_name(i),
 | 
					                            self.get_class_name(i),
 | 
				
			||||||
                            "S0=", (example.x[s0].text if s0 >= 0 else "__"),
 | 
					                            "S0=", (example.x[s0].text if s0 >= 0 else "__"),
 | 
				
			||||||
| 
						 | 
					@ -95,6 +101,7 @@ cdef class TransitionSystem:
 | 
				
			||||||
                    break
 | 
					                    break
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                if _debug:
 | 
					                if _debug:
 | 
				
			||||||
 | 
					                    example = _debug
 | 
				
			||||||
                    print("Actions")
 | 
					                    print("Actions")
 | 
				
			||||||
                    for i in range(self.n_moves):
 | 
					                    for i in range(self.n_moves):
 | 
				
			||||||
                        print(self.get_class_name(i))
 | 
					                        print(self.get_class_name(i))
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user