mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-26 13:41:21 +03:00 
			
		
		
		
	* Work on refactored parser, where TransitionSystem can be easily subclassed
This commit is contained in:
		
							parent
							
								
									1cc6329b18
								
							
						
					
					
						commit
						dc986dbc0b
					
				|  | @ -53,9 +53,9 @@ cdef unicode print_state(State* s, list words): | |||
| def get_templates(name): | ||||
|     pf = _parse_features | ||||
|     if name == 'zhang': | ||||
|         return pf.unigrams, pf.arc_eager | ||||
|         return pf.arc_eager | ||||
|     else: | ||||
|         return pf.unigrams, (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s0_n1 + pf.n0_n1 + \ | ||||
|         return (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s0_n1 + pf.n0_n1 + \ | ||||
|                 pf.tree_shape + pf.trigrams) | ||||
| 
 | ||||
| 
 | ||||
|  | @ -64,84 +64,51 @@ cdef class GreedyParser: | |||
|         assert os.path.exists(model_dir) and os.path.isdir(model_dir) | ||||
|         self.cfg = Config.read(model_dir, 'config') | ||||
|         self.moves = TransitionSystem(self.cfg.left_labels, self.cfg.right_labels) | ||||
|         hasty_templ, full_templ = get_templates(self.cfg.features) | ||||
|         self.model = Model(self.moves.n_moves, full_templ, model_dir) | ||||
|         templates = get_templates(self.cfg.features) | ||||
|         self.model = Model(self.moves.n_moves, templates, model_dir) | ||||
| 
 | ||||
|     def __call__(self, Tokens tokens): | ||||
|         cdef: | ||||
|             Transition guess | ||||
|             uint64_t state_key | ||||
| 
 | ||||
|         if tokens.length == 0: | ||||
|             return 0 | ||||
| 
 | ||||
|         cdef atom_t[CONTEXT_SIZE] context | ||||
|         cdef int n_feats | ||||
|         cdef Pool mem = Pool() | ||||
|         cdef State* state = init_state(mem, tokens.data, tokens.length) | ||||
|         cdef State* state = init_state(mem, tokens.data, tokens.length) # TODO | ||||
|         cdef Transition guess | ||||
|         while not is_final(state): | ||||
|             fill_context(context, state) | ||||
|             scores = self.model.score(context) | ||||
|             guess = self.moves.best_valid(scores, state) | ||||
|             guess.do(&guess, state) | ||||
|         # Messily tell Tokens object the string names of the dependency labels | ||||
|         dep_strings = [None] * len(self.moves.label_ids) | ||||
|         for label, id_ in self.moves.label_ids.items(): | ||||
|             dep_strings[id_] = label | ||||
|         tokens._dep_strings = tuple(dep_strings) | ||||
|         tokens.is_parsed = True | ||||
|         # TODO: Clean this up. | ||||
|         tokens._py_tokens = [None] * tokens.length | ||||
|         tokens.set_parse(state.sent, self.moves.label_ids) # TODO | ||||
|         return 0 | ||||
| 
 | ||||
|     def train_sent(self, Tokens tokens, list gold_heads, list gold_labels, | ||||
|                    force_gold=False): | ||||
|     def train_sent(self, Tokens tokens, GoldParse gold, force_gold=False): | ||||
|         cdef: | ||||
|             int n_feats | ||||
|             const Feature* feats | ||||
|             const weight_t* scores | ||||
|             Transition guess | ||||
|             Transition gold | ||||
|              | ||||
|         cdef int n_feats | ||||
|         cdef atom_t[CONTEXT_SIZE] context | ||||
|         cdef Pool mem = Pool() | ||||
|         cdef int* heads_array = <int*>mem.alloc(tokens.length, sizeof(int)) | ||||
|         cdef int* labels_array = <int*>mem.alloc(tokens.length, sizeof(int)) | ||||
|         cdef int i | ||||
|         for i in range(tokens.length): | ||||
|             if gold_heads[i] is None: | ||||
|                 heads_array[i] = -1 | ||||
|                 labels_array[i] = -1 | ||||
|             else: | ||||
|                 heads_array[i] = gold_heads[i] | ||||
|                 labels_array[i] = self.moves.label_ids[gold_labels[i]] | ||||
|             atom_t[CONTEXT_SIZE] context | ||||
| 
 | ||||
|         py_words = [t.orth_ for t in tokens] | ||||
|         py_moves = ['S', 'D', 'L', 'R', 'BS', 'BR'] | ||||
|         history = [] | ||||
|         #print py_words | ||||
|         cdef Pool mem = Pool() | ||||
|         cdef State* state = init_state(mem, tokens.data, tokens.length) | ||||
| 
 | ||||
|         while not is_final(state): | ||||
|             fill_context(context, state) | ||||
|             scores = self.model.score(context) | ||||
|             guess = self.moves.best_valid(scores, state) | ||||
|             best = self.moves.best_gold(&guess, scores, state, heads_array, labels_array) | ||||
|             history.append((py_moves[best.move], print_state(state, py_words))) | ||||
|             self.model.update(context, guess.clas, best.clas, guess.cost) | ||||
|             gold = self.moves.best_gold(scores, state, gold) | ||||
|             cost = guess.get_cost(&guess, state, gold) | ||||
|             self.model.update(context, guess.clas, best.clas, cost) | ||||
|             if force_gold: | ||||
|                 best.do(&best, state) | ||||
|             else: | ||||
|                 guess.do(&guess, state) | ||||
|         cdef int n_corr = 0 | ||||
|         for i in range(tokens.length): | ||||
|             if gold_heads[i] != -1: | ||||
|                 n_corr += (i + state.sent[i].head) == gold_heads[i] | ||||
|         n_corr = gold.heads_correct(state.sent, score_punct=True) # TODO | ||||
|         if force_gold and n_corr != tokens.length: | ||||
|             #print py_words | ||||
|             #print gold_heads | ||||
|             #for move, state_str in history: | ||||
|             #    print move, state_str | ||||
|             #for i in range(tokens.length): | ||||
|             #    print py_words[i], py_words[i + state.sent[i].head], py_words[gold_heads[i]] | ||||
|             raise OracleError | ||||
|         return n_corr | ||||
|  |  | |||
		Loading…
	
		Reference in New Issue
	
	Block a user