Mirror of https://github.com/explosion/spaCy.git (synced 2025-10-30 23:47:31 +03:00)
			
		
		
		
	* Refactor TransitionSystem, adding set_valid method
This commit is contained in:
		
Parent: bd82a49994
Commit: 0786d9b3c7
					
				|  | @ -44,10 +44,6 @@ MOVE_NAMES[CONSTITUENT] = 'C' | ||||||
| MOVE_NAMES[ADJUST] = 'A' | MOVE_NAMES[ADJUST] = 'A' | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| cdef do_func_t[N_MOVES] do_funcs |  | ||||||
| cdef get_cost_func_t[N_MOVES] get_cost_funcs |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| cdef class ArcEager(TransitionSystem): | cdef class ArcEager(TransitionSystem): | ||||||
|     @classmethod |     @classmethod | ||||||
|     def get_labels(cls, gold_parses): |     def get_labels(cls, gold_parses): | ||||||
|  | @ -107,8 +103,27 @@ cdef class ArcEager(TransitionSystem): | ||||||
|         t.clas = clas |         t.clas = clas | ||||||
|         t.move = move |         t.move = move | ||||||
|         t.label = label |         t.label = label | ||||||
|         t.do = do_funcs[move] |         if move == SHIFT: | ||||||
|         t.get_cost = get_cost_funcs[move] |             t.do = _do_shift | ||||||
|  |             t.get_cost = _shift_cost | ||||||
|  |         elif move == REDUCE: | ||||||
|  |             t.do = _do_reduce | ||||||
|  |             t.get_cost = _reduce_cost | ||||||
|  |         elif move == LEFT: | ||||||
|  |             t.do = _do_left | ||||||
|  |             t.get_cost = _left_cost | ||||||
|  |         elif move == RIGHT: | ||||||
|  |             t.do = _do_right | ||||||
|  |             t.get_cost = _right_cost | ||||||
|  |         elif move == BREAK: | ||||||
|  |             t.get_cost = _break_cost | ||||||
|  |         elif move == CONSTITUENT: | ||||||
|  |             t.get_cost = _constituent_cost | ||||||
|  |         elif move == ADJUST: | ||||||
|  |             t.do = _do_adjust | ||||||
|  |             t.get_cost = _adjust_cost | ||||||
|  |         else: | ||||||
|  |             raise Exception(move) | ||||||
|         return t |         return t | ||||||
| 
 | 
 | ||||||
|     cdef int initialize_state(self, State* state) except -1: |     cdef int initialize_state(self, State* state) except -1: | ||||||
|  | @ -120,7 +135,7 @@ cdef class ArcEager(TransitionSystem): | ||||||
|             if state.sent[i].head == 0 and state.sent[i].dep == 0: |             if state.sent[i].head == 0 and state.sent[i].dep == 0: | ||||||
|                 state.sent[i].dep = root_label |                 state.sent[i].dep = root_label | ||||||
| 
 | 
 | ||||||
|     cdef bint* get_valid(self, const State* s) except NULL: |     cdef int set_valid(self, bint* output, const State* s) except -1: | ||||||
|         cdef bint[N_MOVES] is_valid |         cdef bint[N_MOVES] is_valid | ||||||
|         is_valid[SHIFT] = _can_shift(s) |         is_valid[SHIFT] = _can_shift(s) | ||||||
|         is_valid[REDUCE] = _can_reduce(s) |         is_valid[REDUCE] = _can_reduce(s) | ||||||
|  | @ -131,8 +146,7 @@ cdef class ArcEager(TransitionSystem): | ||||||
|         is_valid[ADJUST] = _can_adjust(s) |         is_valid[ADJUST] = _can_adjust(s) | ||||||
|         cdef int i |         cdef int i | ||||||
|         for i in range(self.n_moves): |         for i in range(self.n_moves): | ||||||
|             self._is_valid[i] = is_valid[self.c[i].move] |             output[i] = is_valid[self.c[i].move] | ||||||
|         return self._is_valid |  | ||||||
| 
 | 
 | ||||||
|     cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: |     cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: | ||||||
|         cdef bint[N_MOVES] is_valid |         cdef bint[N_MOVES] is_valid | ||||||
|  | @ -200,52 +214,6 @@ cdef int _do_break(const Transition* self, State* state) except -1: | ||||||
|     if not at_eol(state): |     if not at_eol(state): | ||||||
|         push_stack(state) |         push_stack(state) | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| cdef int _do_constituent(const Transition* self, State* state) except -1: |  | ||||||
|     return False |  | ||||||
|     #cdef Constituent* bracket = new_bracket(state.ctnts) |  | ||||||
| 
 |  | ||||||
|     #bracket.parent = NULL |  | ||||||
|     #bracket.label = self.label |  | ||||||
|     #bracket.head = get_s0(state) |  | ||||||
|     #bracket.length = 0 |  | ||||||
| 
 |  | ||||||
|     #attach(bracket, state.ctnts.stack) |  | ||||||
|     # Attach rightward children. They're in the brackets array somewhere |  | ||||||
|     # between here and B0. |  | ||||||
|     #cdef Constituent* node |  | ||||||
|     #cdef const TokenC* node_gov |  | ||||||
|     #for i in range(1, bracket - state.ctnts.stack): |  | ||||||
|     #    node = bracket - i |  | ||||||
|     #    node_gov = node.head + node.head.head |  | ||||||
|     #    if node_gov == bracket.head: |  | ||||||
|     #        attach(bracket, node) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| cdef int _do_adjust(const Transition* self, State* state) except -1: |  | ||||||
|     return False |  | ||||||
|     #cdef Constituent* b0 = state.ctnts.stack[0] |  | ||||||
|     #cdef Constituent* b1 = state.ctnts.stack[1] |  | ||||||
| 
 |  | ||||||
|     #assert (b1.head + b1.head.head) == b0.head |  | ||||||
|     #assert b0.head < b1.head |  | ||||||
|     #assert b0 < b1 |  | ||||||
| 
 |  | ||||||
|     #attach(b0, b1) |  | ||||||
|     ## Pop B1 from stack, but keep B0 on top |  | ||||||
|     #state.ctnts.stack -= 1 |  | ||||||
|     #state.ctnts.stack[0] = b0 |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| do_funcs[SHIFT] = _do_shift |  | ||||||
| do_funcs[REDUCE] = _do_reduce |  | ||||||
| do_funcs[LEFT] = _do_left |  | ||||||
| do_funcs[RIGHT] = _do_right |  | ||||||
| do_funcs[BREAK] = _do_break |  | ||||||
| do_funcs[CONSTITUENT] = _do_constituent |  | ||||||
| do_funcs[ADJUST] = _do_adjust |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) except -1: | cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) except -1: | ||||||
|     if not _can_shift(s): |     if not _can_shift(s): | ||||||
|         return 9000 |         return 9000 | ||||||
|  | @ -257,7 +225,6 @@ cdef int _shift_cost(const Transition* self, const State* s, GoldParse gold) exc | ||||||
|         cost += 1 |         cost += 1 | ||||||
|     return cost |     return cost | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| cdef int _right_cost(const Transition* self, const State* s, GoldParse gold) except -1: | cdef int _right_cost(const Transition* self, const State* s, GoldParse gold) except -1: | ||||||
|     if not _can_right(s): |     if not _can_right(s): | ||||||
|         return 9000 |         return 9000 | ||||||
|  | @ -322,6 +289,77 @@ cdef int _break_cost(const Transition* self, const State* s, GoldParse gold) exc | ||||||
|     return cost |     return cost | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
|  | cdef inline bint _can_shift(const State* s) nogil: | ||||||
|  |     return not at_eol(s) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | cdef inline bint _can_right(const State* s) nogil: | ||||||
|  |     return s.stack_len >= 1 and not at_eol(s) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | cdef inline bint _can_left(const State* s) nogil: | ||||||
|  |     if NON_MONOTONIC: | ||||||
|  |         return s.stack_len >= 1 #and not missing_brackets(s) | ||||||
|  |     else: | ||||||
|  |         return s.stack_len >= 1 and not has_head(get_s0(s)) | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|  | cdef inline bint _can_reduce(const State* s) nogil: | ||||||
|  |     if NON_MONOTONIC: | ||||||
|  |         return s.stack_len >= 2 #and not missing_brackets(s) | ||||||
|  |     else: | ||||||
|  |         return s.stack_len >= 2 and has_head(get_s0(s)) | ||||||
|  | 
 | ||||||
|  | cdef inline bint _can_break(const State* s) nogil: | ||||||
|  |     cdef int i | ||||||
|  |     if not USE_BREAK: | ||||||
|  |         return False | ||||||
|  |     elif at_eol(s): | ||||||
|  |         return False | ||||||
|  |     #elif NON_MONOTONIC: | ||||||
|  |     #    return True | ||||||
|  |     else: | ||||||
|  |         # In the Break transition paper, they have this constraint that prevents | ||||||
|  |         # Break if stack is disconnected. But, if we're doing non-monotonic parsing, | ||||||
|  |         # we prefer to relax this constraint. This is helpful in parsing whole | ||||||
|  |         # documents, because then we don't get stuck with words on the stack. | ||||||
|  |         seen_headless = False | ||||||
|  |         for i in range(s.stack_len): | ||||||
|  |             if s.sent[s.stack[-i]].head == 0: | ||||||
|  |                 if seen_headless: | ||||||
|  |                     return False | ||||||
|  |                 else: | ||||||
|  |                     seen_headless = True | ||||||
|  |         # TODO: Constituency constraints | ||||||
|  |         return True | ||||||
|  | 
 | ||||||
|  | cdef inline bint _can_constituent(const State* s) nogil: | ||||||
|  |     if s.stack_len < 1: | ||||||
|  |         return False | ||||||
|  |     return False | ||||||
|  |     #else: | ||||||
|  |     #    # If all stack elements are popped, can't constituent | ||||||
|  |     #    for i in range(s.ctnts.stack_len): | ||||||
|  |     #        if not s.ctnts.is_popped[-i]: | ||||||
|  |     #            return True | ||||||
|  |     #    else: | ||||||
|  |     #        return False | ||||||
|  | 
 | ||||||
|  | cdef inline bint _can_adjust(const State* s) nogil: | ||||||
|  |     return False | ||||||
|  |     #if s.ctnts.stack_len < 2: | ||||||
|  |     #    return False | ||||||
|  | 
 | ||||||
|  |     #cdef const Constituent* b1 = s.ctnts.stack[-1] | ||||||
|  |     #cdef const Constituent* b0 = s.ctnts.stack[0] | ||||||
|  | 
 | ||||||
|  |     #if (b1.head + b1.head.head) != b0.head: | ||||||
|  |     #    return False | ||||||
|  |     #elif b0.head >= b1.head: | ||||||
|  |     #    return False | ||||||
|  |     #elif b0 >= b1: | ||||||
|  |     #    return False | ||||||
|  | 
 | ||||||
| cdef int _constituent_cost(const Transition* self, const State* s, GoldParse gold) except -1: | cdef int _constituent_cost(const Transition* self, const State* s, GoldParse gold) except -1: | ||||||
|     if not _can_constituent(s): |     if not _can_constituent(s): | ||||||
|         return 9000 |         return 9000 | ||||||
|  | @ -350,7 +388,6 @@ cdef int _constituent_cost(const Transition* self, const State* s, GoldParse gol | ||||||
|     #            loss = 1 # If we see the start position, set loss to 1 |     #            loss = 1 # If we see the start position, set loss to 1 | ||||||
|     #return loss |     #return loss | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| cdef int _adjust_cost(const Transition* self, const State* s, GoldParse gold) except -1: | cdef int _adjust_cost(const Transition* self, const State* s, GoldParse gold) except -1: | ||||||
|     if not _can_adjust(s): |     if not _can_adjust(s): | ||||||
|         return 9000 |         return 9000 | ||||||
|  | @ -383,85 +420,37 @@ cdef int _adjust_cost(const Transition* self, const State* s, GoldParse gold) ex | ||||||
|     #return loss |     #return loss | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| get_cost_funcs[SHIFT] = _shift_cost | cdef int _do_constituent(const Transition* self, State* state) except -1: | ||||||
| get_cost_funcs[REDUCE] = _reduce_cost |  | ||||||
| get_cost_funcs[LEFT] = _left_cost |  | ||||||
| get_cost_funcs[RIGHT] = _right_cost |  | ||||||
| get_cost_funcs[BREAK] = _break_cost |  | ||||||
| get_cost_funcs[CONSTITUENT] = _constituent_cost |  | ||||||
| get_cost_funcs[ADJUST] = _adjust_cost |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| cdef inline bint _can_shift(const State* s) nogil: |  | ||||||
|     return not at_eol(s) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| cdef inline bint _can_right(const State* s) nogil: |  | ||||||
|     return s.stack_len >= 1 and not at_eol(s) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| cdef inline bint _can_left(const State* s) nogil: |  | ||||||
|     if NON_MONOTONIC: |  | ||||||
|         return s.stack_len >= 1 #and not missing_brackets(s) |  | ||||||
|     else: |  | ||||||
|         return s.stack_len >= 1 and not has_head(get_s0(s)) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| cdef inline bint _can_reduce(const State* s) nogil: |  | ||||||
|     if NON_MONOTONIC: |  | ||||||
|         return s.stack_len >= 2 #and not missing_brackets(s) |  | ||||||
|     else: |  | ||||||
|         return s.stack_len >= 2 and has_head(get_s0(s)) |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| cdef inline bint _can_break(const State* s) nogil: |  | ||||||
|     cdef int i |  | ||||||
|     if not USE_BREAK: |  | ||||||
|         return False |  | ||||||
|     elif at_eol(s): |  | ||||||
|         return False |  | ||||||
|     #elif NON_MONOTONIC: |  | ||||||
|     #    return True |  | ||||||
|     else: |  | ||||||
|         # In the Break transition paper, they have this constraint that prevents |  | ||||||
|         # Break if stack is disconnected. But, if we're doing non-monotonic parsing, |  | ||||||
|         # we prefer to relax this constraint. This is helpful in parsing whole |  | ||||||
|         # documents, because then we don't get stuck with words on the stack. |  | ||||||
|         seen_headless = False |  | ||||||
|         for i in range(s.stack_len): |  | ||||||
|             if s.sent[s.stack[-i]].head == 0: |  | ||||||
|                 if seen_headless: |  | ||||||
|                     return False |  | ||||||
|                 else: |  | ||||||
|                     seen_headless = True |  | ||||||
|         # TODO: Constituency constraints |  | ||||||
|         return True |  | ||||||
| 
 |  | ||||||
| 
 |  | ||||||
| cdef inline bint _can_constituent(const State* s) nogil: |  | ||||||
|     if s.stack_len < 1: |  | ||||||
|         return False |  | ||||||
|     return False |     return False | ||||||
|     #else: |     #cdef Constituent* bracket = new_bracket(state.ctnts) | ||||||
|     #    # If all stack elements are popped, can't constituent | 
 | ||||||
|     #    for i in range(s.ctnts.stack_len): |     #bracket.parent = NULL | ||||||
|     #        if not s.ctnts.is_popped[-i]: |     #bracket.label = self.label | ||||||
|     #            return True |     #bracket.head = get_s0(state) | ||||||
|     #    else: |     #bracket.length = 0 | ||||||
|     #        return False | 
 | ||||||
|  |     #attach(bracket, state.ctnts.stack) | ||||||
|  |     # Attach rightward children. They're in the brackets array somewhere | ||||||
|  |     # between here and B0. | ||||||
|  |     #cdef Constituent* node | ||||||
|  |     #cdef const TokenC* node_gov | ||||||
|  |     #for i in range(1, bracket - state.ctnts.stack): | ||||||
|  |     #    node = bracket - i | ||||||
|  |     #    node_gov = node.head + node.head.head | ||||||
|  |     #    if node_gov == bracket.head: | ||||||
|  |     #        attach(bracket, node) | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| cdef inline bint _can_adjust(const State* s) nogil: | cdef int _do_adjust(const Transition* self, State* state) except -1: | ||||||
|     return False |     return False | ||||||
|     #if s.ctnts.stack_len < 2: |     #cdef Constituent* b0 = state.ctnts.stack[0] | ||||||
|     #    return False |     #cdef Constituent* b1 = state.ctnts.stack[1] | ||||||
| 
 | 
 | ||||||
|     #cdef const Constituent* b1 = s.ctnts.stack[-1] |     #assert (b1.head + b1.head.head) == b0.head | ||||||
|     #cdef const Constituent* b0 = s.ctnts.stack[0] |     #assert b0.head < b1.head | ||||||
|  |     #assert b0 < b1 | ||||||
| 
 | 
 | ||||||
|     #if (b1.head + b1.head.head) != b0.head: |     #attach(b0, b1) | ||||||
|     #    return False |     ## Pop B1 from stack, but keep B0 on top | ||||||
|     #elif b0.head >= b1.head: |     #state.ctnts.stack -= 1 | ||||||
|     #    return False |     #state.ctnts.stack[0] = b0 | ||||||
|     #elif b0 >= b1: |  | ||||||
|     #    return False |  | ||||||
|  |  | ||||||
|  | @ -140,12 +140,11 @@ cdef class BiluoPushDown(TransitionSystem): | ||||||
|         t.score = score |         t.score = score | ||||||
|         return t |         return t | ||||||
| 
 | 
 | ||||||
|     cdef bint* get_valid(self, const State* s) except NULL: |     cdef int set_valid(self, bint* output, const State* s) except -1: | ||||||
|         cdef int i |         cdef int i | ||||||
|         for i in range(self.n_moves): |         for i in range(self.n_moves): | ||||||
|             m = &self.c[i] |             m = &self.c[i] | ||||||
|             self._is_valid[i] = _is_valid(m.move, m.label, s) |             output[i] = _is_valid(m.move, m.label, s) | ||||||
|         return self._is_valid |  | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| cdef int _get_cost(const Transition* self, const State* s, GoldParse gold) except -1: | cdef int _get_cost(const Transition* self, const State* s, GoldParse gold) except -1: | ||||||
|  |  | ||||||
|  | @ -40,7 +40,7 @@ cdef class TransitionSystem: | ||||||
| 
 | 
 | ||||||
|     cdef Transition init_transition(self, int clas, int move, int label) except * |     cdef Transition init_transition(self, int clas, int move, int label) except * | ||||||
| 
 | 
 | ||||||
|     cdef bint* get_valid(self, const State* state) except NULL |     cdef int set_valid(self, bint* output, const State* state) except -1 | ||||||
| 
 | 
 | ||||||
|     cdef Transition best_valid(self, const weight_t* scores, const State* state) except * |     cdef Transition best_valid(self, const weight_t* scores, const State* state) except * | ||||||
| 
 | 
 | ||||||
|  |  | ||||||
|  | @ -45,7 +45,7 @@ cdef class TransitionSystem: | ||||||
|     cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: |     cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
|      |      | ||||||
|     cdef bint* get_valid(self, const State* state) except NULL: |     cdef int set_valid(self, bint* output, const State* state) except -1: | ||||||
|         raise NotImplementedError |         raise NotImplementedError | ||||||
| 
 | 
 | ||||||
|     cdef Transition best_gold(self, const weight_t* scores, const State* s, |     cdef Transition best_gold(self, const weight_t* scores, const State* s, | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user