mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	* Work on refactored parser, where TransitionSystem can be easily subclassed
This commit is contained in:
		
							parent
							
								
									1cc6329b18
								
							
						
					
					
						commit
						dc986dbc0b
					
				|  | @ -53,10 +53,10 @@ cdef unicode print_state(State* s, list words): | ||||||
| def get_templates(name): | def get_templates(name): | ||||||
|     pf = _parse_features |     pf = _parse_features | ||||||
|     if name == 'zhang': |     if name == 'zhang': | ||||||
|         return pf.unigrams, pf.arc_eager |         return pf.arc_eager | ||||||
|     else: |     else: | ||||||
|         return pf.unigrams, (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s0_n1 + pf.n0_n1 + \ |         return (pf.unigrams + pf.s0_n0 + pf.s1_n0 + pf.s0_n1 + pf.n0_n1 + \ | ||||||
|                              pf.tree_shape + pf.trigrams) |                 pf.tree_shape + pf.trigrams) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| cdef class GreedyParser: | cdef class GreedyParser: | ||||||
|  | @ -64,84 +64,51 @@ cdef class GreedyParser: | ||||||
|         assert os.path.exists(model_dir) and os.path.isdir(model_dir) |         assert os.path.exists(model_dir) and os.path.isdir(model_dir) | ||||||
|         self.cfg = Config.read(model_dir, 'config') |         self.cfg = Config.read(model_dir, 'config') | ||||||
|         self.moves = TransitionSystem(self.cfg.left_labels, self.cfg.right_labels) |         self.moves = TransitionSystem(self.cfg.left_labels, self.cfg.right_labels) | ||||||
|         hasty_templ, full_templ = get_templates(self.cfg.features) |         templates = get_templates(self.cfg.features) | ||||||
|         self.model = Model(self.moves.n_moves, full_templ, model_dir) |         self.model = Model(self.moves.n_moves, templates, model_dir) | ||||||
| 
 | 
 | ||||||
|     def __call__(self, Tokens tokens): |     def __call__(self, Tokens tokens): | ||||||
|         cdef: |  | ||||||
|             Transition guess |  | ||||||
|             uint64_t state_key |  | ||||||
| 
 |  | ||||||
|         if tokens.length == 0: |         if tokens.length == 0: | ||||||
|             return 0 |             return 0 | ||||||
| 
 | 
 | ||||||
|         cdef atom_t[CONTEXT_SIZE] context |         cdef atom_t[CONTEXT_SIZE] context | ||||||
|         cdef int n_feats |         cdef int n_feats | ||||||
|         cdef Pool mem = Pool() |         cdef Pool mem = Pool() | ||||||
|         cdef State* state = init_state(mem, tokens.data, tokens.length) |         cdef State* state = init_state(mem, tokens.data, tokens.length) # TODO | ||||||
|  |         cdef Transition guess | ||||||
|         while not is_final(state): |         while not is_final(state): | ||||||
|             fill_context(context, state) |             fill_context(context, state) | ||||||
|             scores = self.model.score(context) |             scores = self.model.score(context) | ||||||
|             guess = self.moves.best_valid(scores, state) |             guess = self.moves.best_valid(scores, state) | ||||||
|             guess.do(&guess, state) |             guess.do(&guess, state) | ||||||
|         # Messily tell Tokens object the string names of the dependency labels |         tokens.set_parse(state.sent, self.moves.label_ids) # TODO | ||||||
|         dep_strings = [None] * len(self.moves.label_ids) |  | ||||||
|         for label, id_ in self.moves.label_ids.items(): |  | ||||||
|             dep_strings[id_] = label |  | ||||||
|         tokens._dep_strings = tuple(dep_strings) |  | ||||||
|         tokens.is_parsed = True |  | ||||||
|         # TODO: Clean this up. |  | ||||||
|         tokens._py_tokens = [None] * tokens.length |  | ||||||
|         return 0 |         return 0 | ||||||
| 
 | 
 | ||||||
|     def train_sent(self, Tokens tokens, list gold_heads, list gold_labels, |     def train_sent(self, Tokens tokens, GoldParse gold, force_gold=False): | ||||||
|                    force_gold=False): |  | ||||||
|         cdef: |         cdef: | ||||||
|  |             int n_feats | ||||||
|             const Feature* feats |             const Feature* feats | ||||||
|             const weight_t* scores |             const weight_t* scores | ||||||
|             Transition guess |             Transition guess | ||||||
|             Transition gold |             Transition gold | ||||||
|              |              | ||||||
|         cdef int n_feats |             atom_t[CONTEXT_SIZE] context | ||||||
|         cdef atom_t[CONTEXT_SIZE] context |  | ||||||
|         cdef Pool mem = Pool() |  | ||||||
|         cdef int* heads_array = <int*>mem.alloc(tokens.length, sizeof(int)) |  | ||||||
|         cdef int* labels_array = <int*>mem.alloc(tokens.length, sizeof(int)) |  | ||||||
|         cdef int i |  | ||||||
|         for i in range(tokens.length): |  | ||||||
|             if gold_heads[i] is None: |  | ||||||
|                 heads_array[i] = -1 |  | ||||||
|                 labels_array[i] = -1 |  | ||||||
|             else: |  | ||||||
|                 heads_array[i] = gold_heads[i] |  | ||||||
|                 labels_array[i] = self.moves.label_ids[gold_labels[i]] |  | ||||||
| 
 | 
 | ||||||
|         py_words = [t.orth_ for t in tokens] |         cdef Pool mem = Pool() | ||||||
|         py_moves = ['S', 'D', 'L', 'R', 'BS', 'BR'] |  | ||||||
|         history = [] |  | ||||||
|         #print py_words |  | ||||||
|         cdef State* state = init_state(mem, tokens.data, tokens.length) |         cdef State* state = init_state(mem, tokens.data, tokens.length) | ||||||
|  | 
 | ||||||
|         while not is_final(state): |         while not is_final(state): | ||||||
|             fill_context(context, state) |             fill_context(context, state) | ||||||
|             scores = self.model.score(context) |             scores = self.model.score(context) | ||||||
|             guess = self.moves.best_valid(scores, state) |             guess = self.moves.best_valid(scores, state) | ||||||
|             best = self.moves.best_gold(&guess, scores, state, heads_array, labels_array) |             gold = self.moves.best_gold(scores, state, gold) | ||||||
|             history.append((py_moves[best.move], print_state(state, py_words))) |             cost = guess.get_cost(&guess, state, gold) | ||||||
|             self.model.update(context, guess.clas, best.clas, guess.cost) |             self.model.update(context, guess.clas, best.clas, cost) | ||||||
|             if force_gold: |             if force_gold: | ||||||
|                 best.do(&best, state) |                 best.do(&best, state) | ||||||
|             else: |             else: | ||||||
|                 guess.do(&guess, state) |                 guess.do(&guess, state) | ||||||
|         cdef int n_corr = 0 |         n_corr = gold.heads_correct(state.sent, score_punct=True) # TODO | ||||||
|         for i in range(tokens.length): |  | ||||||
|             if gold_heads[i] != -1: |  | ||||||
|                 n_corr += (i + state.sent[i].head) == gold_heads[i] |  | ||||||
|         if force_gold and n_corr != tokens.length: |         if force_gold and n_corr != tokens.length: | ||||||
|             #print py_words |  | ||||||
|             #print gold_heads |  | ||||||
|             #for move, state_str in history: |  | ||||||
|             #    print move, state_str |  | ||||||
|             #for i in range(tokens.length): |  | ||||||
|             #    print py_words[i], py_words[i + state.sent[i].head], py_words[gold_heads[i]] |  | ||||||
|             raise OracleError |             raise OracleError | ||||||
|         return n_corr |         return n_corr | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user