mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-10-31 16:07:41 +03:00 
			
		
		
		
	* Begin adding stateclass to ArcEager
This commit is contained in:
		
							parent
							
								
									ba10fd8af5
								
							
						
					
					
						commit
						2b9629ed62
					
				|  | @ -1,6 +1,9 @@ | ||||||
| # cython: profile=True | # cython: profile=True | ||||||
| from __future__ import unicode_literals | from __future__ import unicode_literals | ||||||
| 
 | 
 | ||||||
|  | import ctypes | ||||||
|  | import os | ||||||
|  | 
 | ||||||
| from ._state cimport State | from ._state cimport State | ||||||
| from ._state cimport has_head, get_idx, get_s0, get_n0, get_left, get_right | from ._state cimport has_head, get_idx, get_s0, get_n0, get_left, get_right | ||||||
| from ._state cimport is_final, at_eol, pop_stack, push_stack, add_dep | from ._state cimport is_final, at_eol, pop_stack, push_stack, add_dep | ||||||
|  | @ -15,6 +18,12 @@ from .transition_system cimport move_cost_func_t, label_cost_func_t | ||||||
| from ..gold cimport GoldParse | from ..gold cimport GoldParse | ||||||
| from ..gold cimport GoldParseC | from ..gold cimport GoldParseC | ||||||
| 
 | 
 | ||||||
|  | from libc.stdint cimport uint32_t | ||||||
|  | from libc.string cimport memcpy | ||||||
|  | 
 | ||||||
|  | from cymem.cymem cimport Pool | ||||||
|  | from ..stateclass cimport StateClass | ||||||
|  | 
 | ||||||
| 
 | 
 | ||||||
| DEF NON_MONOTONIC = True | DEF NON_MONOTONIC = True | ||||||
| DEF USE_BREAK = True | DEF USE_BREAK = True | ||||||
|  | @ -78,7 +87,6 @@ cdef int arc_cost(const State* st, const GoldParseC* gold, int head, int child) | ||||||
|         return 0 |         return 0 | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| 
 |  | ||||||
| cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) except -1: | cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) except -1: | ||||||
|     if gold.labels[child] == -1: |     if gold.labels[child] == -1: | ||||||
|         return True |         return True | ||||||
|  | @ -110,6 +118,11 @@ cdef class Shift: | ||||||
|     cdef bint is_valid(const State* s, int label) except -1: |     cdef bint is_valid(const State* s, int label) except -1: | ||||||
|         return not at_eol(s) |         return not at_eol(s) | ||||||
| 
 | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     cdef bint _new_is_valid(StateClass st, int label) except -1: | ||||||
|  |         return not st.eol() | ||||||
|  | 
 | ||||||
|  | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     cdef int transition(State* state, int label) except -1: |     cdef int transition(State* state, int label) except -1: | ||||||
|         # Set the dep label, in case we need it after we reduce |         # Set the dep label, in case we need it after we reduce | ||||||
|  | @ -133,6 +146,13 @@ cdef class Shift: | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| cdef class Reduce: | cdef class Reduce: | ||||||
|  |     @staticmethod | ||||||
|  |     cdef bint _new_is_valid(StateClass st, int label) except -1: | ||||||
|  |         if NON_MONOTONIC: | ||||||
|  |             return st.stack_depth() >= 2 #and not missing_brackets(s) | ||||||
|  |         else: | ||||||
|  |             return st.stack_depth() >= 2 and st.has_head(st.S(0)) | ||||||
|  | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     cdef bint is_valid(const State* s, int label) except -1: |     cdef bint is_valid(const State* s, int label) except -1: | ||||||
|         if NON_MONOTONIC: |         if NON_MONOTONIC: | ||||||
|  | @ -165,6 +185,13 @@ cdef class Reduce: | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| cdef class LeftArc: | cdef class LeftArc: | ||||||
|  |     @staticmethod | ||||||
|  |     cdef bint _new_is_valid(StateClass st, int label) except -1: | ||||||
|  |         if NON_MONOTONIC: | ||||||
|  |             return st.stack_depth() >= 1 #and not missing_brackets(s) | ||||||
|  |         else: | ||||||
|  |             return st.stack_depth() >= 1 and not st.has_head(st.S(0)) | ||||||
|  | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     cdef bint is_valid(const State* s, int label) except -1: |     cdef bint is_valid(const State* s, int label) except -1: | ||||||
|         if NON_MONOTONIC: |         if NON_MONOTONIC: | ||||||
|  | @ -206,6 +233,10 @@ cdef class RightArc: | ||||||
|     cdef bint is_valid(const State* s, int label) except -1: |     cdef bint is_valid(const State* s, int label) except -1: | ||||||
|         return s.stack_len >= 1 and not at_eol(s) |         return s.stack_len >= 1 and not at_eol(s) | ||||||
| 
 | 
 | ||||||
|  |     @staticmethod | ||||||
|  |     cdef bint _new_is_valid(StateClass st, int label) except -1: | ||||||
|  |         return st.stack_depth() >= 1 and not st.eol() | ||||||
|  | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     cdef int transition(State* state, int label) except -1: |     cdef int transition(State* state, int label) except -1: | ||||||
|         add_dep(state, state.stack[0], state.i, label) |         add_dep(state, state.stack[0], state.i, label) | ||||||
|  | @ -230,6 +261,32 @@ cdef class RightArc: | ||||||
| 
 | 
 | ||||||
| 
 | 
 | ||||||
| cdef class Break: | cdef class Break: | ||||||
|  |     @staticmethod | ||||||
|  |     cdef bint _new_is_valid(StateClass st, int label) except -1: | ||||||
|  |         cdef int i | ||||||
|  |         if not USE_BREAK: | ||||||
|  |             return False | ||||||
|  |         elif st.eol(): | ||||||
|  |             return False | ||||||
|  |         elif st.stack_depth() < 1: | ||||||
|  |             return False | ||||||
|  |         elif NON_MONOTONIC: | ||||||
|  |             return True | ||||||
|  |         else: | ||||||
|  |             # In the Break transition paper, they have this constraint that prevents | ||||||
|  |             # Break if stack is disconnected. But, if we're doing non-monotonic parsing, | ||||||
|  |             # we prefer to relax this constraint. This is helpful in parsing whole | ||||||
|  |             # documents, because then we don't get stuck with words on the stack. | ||||||
|  |             seen_headless = False | ||||||
|  |             for i in range(st.stack_depth()): | ||||||
|  |                 if not st.has_head(st.S(i)): | ||||||
|  |                     if seen_headless: | ||||||
|  |                         return False | ||||||
|  |                     else: | ||||||
|  |                         seen_headless = True | ||||||
|  |             # TODO: Constituency constraints | ||||||
|  |             return True | ||||||
|  | 
 | ||||||
|     @staticmethod |     @staticmethod | ||||||
|     cdef bint is_valid(const State* s, int label) except -1: |     cdef bint is_valid(const State* s, int label) except -1: | ||||||
|         cdef int i |         cdef int i | ||||||
|  | @ -584,14 +641,17 @@ cdef class ArcEager(TransitionSystem): | ||||||
|             output[i] = move_costs[move] + label_cost_funcs[move](s, &gold.c, label) |             output[i] = move_costs[move] + label_cost_funcs[move](s, &gold.c, label) | ||||||
| 
 | 
 | ||||||
|     cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: |     cdef Transition best_valid(self, const weight_t* scores, const State* s) except *: | ||||||
|  |         cdef Pool mem = Pool() | ||||||
|  |         cdef StateClass stcls = StateClass.from_struct(mem, s) | ||||||
|         cdef bint[N_MOVES] is_valid |         cdef bint[N_MOVES] is_valid | ||||||
|         is_valid[SHIFT] = Shift.is_valid(s, -1) |         #is_valid[SHIFT] = Shift.is_valid(s, -1) | ||||||
|         is_valid[REDUCE] = Reduce.is_valid(s, -1) |         is_valid[SHIFT] = Shift._new_is_valid(stcls, -1) | ||||||
|         is_valid[LEFT] = LeftArc.is_valid(s, -1) |         is_valid[REDUCE] = Reduce._new_is_valid(stcls, -1) | ||||||
|         is_valid[RIGHT] = RightArc.is_valid(s, -1) |         is_valid[LEFT] = LeftArc._new_is_valid(stcls, -1) | ||||||
|         is_valid[BREAK] = Break.is_valid(s, -1) |         is_valid[RIGHT] = RightArc._new_is_valid(stcls, -1) | ||||||
|         is_valid[CONSTITUENT] = Constituent.is_valid(s, -1) |         is_valid[BREAK] = Break._new_is_valid(stcls, -1) | ||||||
|         is_valid[ADJUST] = Adjust.is_valid(s, -1) |         is_valid[CONSTITUENT] = False # Constituent._new_is_valid(s, -1) | ||||||
|  |         is_valid[ADJUST] = False # Adjust._new_is_valid(s, -1) | ||||||
|         cdef Transition best |         cdef Transition best | ||||||
|         cdef weight_t score = MIN_SCORE |         cdef weight_t score = MIN_SCORE | ||||||
|         cdef int i |         cdef int i | ||||||
|  |  | ||||||
		Loading…
	
		Reference in New Issue
	
	Block a user