mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	* Begin adding stateclass to ArcEager
This commit is contained in:
		
							parent
							
								
									ba10fd8af5
								
							
						
					
					
						commit
						2b9629ed62
					
				| 
						 | 
				
			
			@ -1,6 +1,9 @@
 | 
			
		|||
# cython: profile=True
 | 
			
		||||
from __future__ import unicode_literals
 | 
			
		||||
 | 
			
		||||
import ctypes
 | 
			
		||||
import os
 | 
			
		||||
 | 
			
		||||
from ._state cimport State
 | 
			
		||||
from ._state cimport has_head, get_idx, get_s0, get_n0, get_left, get_right
 | 
			
		||||
from ._state cimport is_final, at_eol, pop_stack, push_stack, add_dep
 | 
			
		||||
| 
						 | 
				
			
			@ -15,6 +18,12 @@ from .transition_system cimport move_cost_func_t, label_cost_func_t
 | 
			
		|||
from ..gold cimport GoldParse
 | 
			
		||||
from ..gold cimport GoldParseC
 | 
			
		||||
 | 
			
		||||
from libc.stdint cimport uint32_t
 | 
			
		||||
from libc.string cimport memcpy
 | 
			
		||||
 | 
			
		||||
from cymem.cymem cimport Pool
 | 
			
		||||
from ..stateclass cimport StateClass
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
DEF NON_MONOTONIC = True
 | 
			
		||||
DEF USE_BREAK = True
 | 
			
		||||
| 
						 | 
				
			
			@ -78,7 +87,6 @@ cdef int arc_cost(const State* st, const GoldParseC* gold, int head, int child)
 | 
			
		|||
        return 0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) except -1:
 | 
			
		||||
    if gold.labels[child] == -1:
 | 
			
		||||
        return True
 | 
			
		||||
| 
						 | 
				
			
			@ -110,6 +118,11 @@ cdef class Shift:
 | 
			
		|||
    cdef bint is_valid(const State* s, int label) except -1:
 | 
			
		||||
        return not at_eol(s)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    cdef bint _new_is_valid(StateClass st, int label) except -1:
 | 
			
		||||
        return not st.eol()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    cdef int transition(State* state, int label) except -1:
 | 
			
		||||
        # Set the dep label, in case we need it after we reduce
 | 
			
		||||
| 
						 | 
				
			
			@ -133,6 +146,13 @@ cdef class Shift:
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
cdef class Reduce:
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    cdef bint _new_is_valid(StateClass st, int label) except -1:
 | 
			
		||||
        if NON_MONOTONIC:
 | 
			
		||||
            return st.stack_depth() >= 2 #and not missing_brackets(s)
 | 
			
		||||
        else:
 | 
			
		||||
            return st.stack_depth() >= 2 and st.has_head(st.S(0))
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    cdef bint is_valid(const State* s, int label) except -1:
 | 
			
		||||
        if NON_MONOTONIC:
 | 
			
		||||
| 
						 | 
				
			
			@ -165,6 +185,13 @@ cdef class Reduce:
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
cdef class LeftArc:
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    cdef bint _new_is_valid(StateClass st, int label) except -1:
 | 
			
		||||
        if NON_MONOTONIC:
 | 
			
		||||
            return st.stack_depth() >= 1 #and not missing_brackets(s)
 | 
			
		||||
        else:
 | 
			
		||||
            return st.stack_depth() >= 1 and not st.has_head(st.S(0))
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    cdef bint is_valid(const State* s, int label) except -1:
 | 
			
		||||
        if NON_MONOTONIC:
 | 
			
		||||
| 
						 | 
				
			
			@ -206,6 +233,10 @@ cdef class RightArc:
 | 
			
		|||
    cdef bint is_valid(const State* s, int label) except -1:
 | 
			
		||||
        return s.stack_len >= 1 and not at_eol(s)
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    cdef bint _new_is_valid(StateClass st, int label) except -1:
 | 
			
		||||
        return st.stack_depth() >= 1 and not st.eol()
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    cdef int transition(State* state, int label) except -1:
 | 
			
		||||
        add_dep(state, state.stack[0], state.i, label)
 | 
			
		||||
| 
						 | 
				
			
			@ -230,6 +261,32 @@ cdef class RightArc:
 | 
			
		|||
 | 
			
		||||
 | 
			
		||||
cdef class Break:
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    cdef bint _new_is_valid(StateClass st, int label) except -1:
 | 
			
		||||
        cdef int i
 | 
			
		||||
        if not USE_BREAK:
 | 
			
		||||
            return False
 | 
			
		||||
        elif st.eol():
 | 
			
		||||
            return False
 | 
			
		||||
        elif st.stack_depth() < 1:
 | 
			
		||||
            return False
 | 
			
		||||
        elif NON_MONOTONIC:
 | 
			
		||||
            return True
 | 
			
		||||
        else:
 | 
			
		||||
            # In the Break transition paper, they have this constraint that prevents
 | 
			
		||||
            # Break if stack is disconnected. But, if we're doing non-monotonic parsing,
 | 
			
		||||
            # we prefer to relax this constraint. This is helpful in parsing whole
 | 
			
		||||
            # documents, because then we don't get stuck with words on the stack.
 | 
			
		||||
            seen_headless = False
 | 
			
		||||
            for i in range(st.stack_depth()):
 | 
			
		||||
                if not st.has_head(st.S(i)):
 | 
			
		||||
                    if seen_headless:
 | 
			
		||||
                        return False
 | 
			
		||||
                    else:
 | 
			
		||||
                        seen_headless = True
 | 
			
		||||
            # TODO: Constituency constraints
 | 
			
		||||
            return True
 | 
			
		||||
 | 
			
		||||
    @staticmethod
 | 
			
		||||
    cdef bint is_valid(const State* s, int label) except -1:
 | 
			
		||||
        cdef int i
 | 
			
		||||
| 
						 | 
				
			
			@ -584,14 +641,17 @@ cdef class ArcEager(TransitionSystem):
 | 
			
		|||
            output[i] = move_costs[move] + label_cost_funcs[move](s, &gold.c, label)
 | 
			
		||||
 | 
			
		||||
    cdef Transition best_valid(self, const weight_t* scores, const State* s) except *:
 | 
			
		||||
        cdef Pool mem = Pool()
 | 
			
		||||
        cdef StateClass stcls = StateClass.from_struct(mem, s)
 | 
			
		||||
        cdef bint[N_MOVES] is_valid
 | 
			
		||||
        is_valid[SHIFT] = Shift.is_valid(s, -1)
 | 
			
		||||
        is_valid[REDUCE] = Reduce.is_valid(s, -1)
 | 
			
		||||
        is_valid[LEFT] = LeftArc.is_valid(s, -1)
 | 
			
		||||
        is_valid[RIGHT] = RightArc.is_valid(s, -1)
 | 
			
		||||
        is_valid[BREAK] = Break.is_valid(s, -1)
 | 
			
		||||
        is_valid[CONSTITUENT] = Constituent.is_valid(s, -1)
 | 
			
		||||
        is_valid[ADJUST] = Adjust.is_valid(s, -1)
 | 
			
		||||
        #is_valid[SHIFT] = Shift.is_valid(s, -1)
 | 
			
		||||
        is_valid[SHIFT] = Shift._new_is_valid(stcls, -1)
 | 
			
		||||
        is_valid[REDUCE] = Reduce._new_is_valid(stcls, -1)
 | 
			
		||||
        is_valid[LEFT] = LeftArc._new_is_valid(stcls, -1)
 | 
			
		||||
        is_valid[RIGHT] = RightArc._new_is_valid(stcls, -1)
 | 
			
		||||
        is_valid[BREAK] = Break._new_is_valid(stcls, -1)
 | 
			
		||||
        is_valid[CONSTITUENT] = False # Constituent._new_is_valid(s, -1)
 | 
			
		||||
        is_valid[ADJUST] = False # Adjust._new_is_valid(s, -1)
 | 
			
		||||
        cdef Transition best
 | 
			
		||||
        cdef weight_t score = MIN_SCORE
 | 
			
		||||
        cdef int i
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue
	
	Block a user