mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 01:48:04 +03:00 
			
		
		
		
	Fix arc eager label costs for uint64
This commit is contained in:
		
							parent
							
								
									b127645afc
								
							
						
					
					
						commit
						be4a640f0c
					
				| 
						 | 
					@ -8,6 +8,7 @@ from .syntax.transition_system cimport Transition
 | 
				
			||||||
cdef struct GoldParseC:
 | 
					cdef struct GoldParseC:
 | 
				
			||||||
    int* tags
 | 
					    int* tags
 | 
				
			||||||
    int* heads
 | 
					    int* heads
 | 
				
			||||||
 | 
					    int* has_dep
 | 
				
			||||||
    attr_t* labels
 | 
					    attr_t* labels
 | 
				
			||||||
    int** brackets
 | 
					    int** brackets
 | 
				
			||||||
    Transition* ner
 | 
					    Transition* ner
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -385,6 +385,7 @@ cdef class GoldParse:
 | 
				
			||||||
        self.c.tags = <int*>self.mem.alloc(len(doc), sizeof(int))
 | 
					        self.c.tags = <int*>self.mem.alloc(len(doc), sizeof(int))
 | 
				
			||||||
        self.c.heads = <int*>self.mem.alloc(len(doc), sizeof(int))
 | 
					        self.c.heads = <int*>self.mem.alloc(len(doc), sizeof(int))
 | 
				
			||||||
        self.c.labels = <attr_t*>self.mem.alloc(len(doc), sizeof(attr_t))
 | 
					        self.c.labels = <attr_t*>self.mem.alloc(len(doc), sizeof(attr_t))
 | 
				
			||||||
 | 
					        self.c.has_dep = <int*>self.mem.alloc(len(doc), sizeof(int))
 | 
				
			||||||
        self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition))
 | 
					        self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
        self.words = [None] * len(doc)
 | 
					        self.words = [None] * len(doc)
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
| 
						 | 
					@ -60,7 +60,7 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no
 | 
				
			||||||
            cost += 1
 | 
					            cost += 1
 | 
				
			||||||
        if gold.heads[S_i] == target and (NON_MONOTONIC or not stcls.has_head(S_i)):
 | 
					        if gold.heads[S_i] == target and (NON_MONOTONIC or not stcls.has_head(S_i)):
 | 
				
			||||||
            cost += 1
 | 
					            cost += 1
 | 
				
			||||||
    cost += Break.is_valid(stcls.c, -1) and Break.move_cost(stcls, gold) == 0
 | 
					    cost += Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0
 | 
				
			||||||
    return cost
 | 
					    return cost
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -73,7 +73,7 @@ cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nog
 | 
				
			||||||
        cost += gold.heads[target] == B_i
 | 
					        cost += gold.heads[target] == B_i
 | 
				
			||||||
        if gold.heads[B_i] == B_i or gold.heads[B_i] < target:
 | 
					        if gold.heads[B_i] == B_i or gold.heads[B_i] < target:
 | 
				
			||||||
            break
 | 
					            break
 | 
				
			||||||
    if Break.is_valid(stcls.c, -1) and Break.move_cost(stcls, gold) == 0:
 | 
					    if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0:
 | 
				
			||||||
        cost += 1
 | 
					        cost += 1
 | 
				
			||||||
    return cost
 | 
					    return cost
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -84,14 +84,14 @@ cdef weight_t arc_cost(StateClass stcls, const GoldParseC* gold, int head, int c
 | 
				
			||||||
    elif stcls.H(child) == gold.heads[child]:
 | 
					    elif stcls.H(child) == gold.heads[child]:
 | 
				
			||||||
        return 1
 | 
					        return 1
 | 
				
			||||||
    # Head in buffer
 | 
					    # Head in buffer
 | 
				
			||||||
    elif gold.heads[child] >= stcls.B(0) and stcls.B(1) != -1:
 | 
					    elif gold.heads[child] >= stcls.B(0) and stcls.B(1) != 0:
 | 
				
			||||||
        return 1
 | 
					        return 1
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        return 0
 | 
					        return 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) nogil:
 | 
					cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) nogil:
 | 
				
			||||||
    if gold.labels[child] == -1:
 | 
					    if not gold.has_dep[child]:
 | 
				
			||||||
        return True
 | 
					        return True
 | 
				
			||||||
    elif gold.heads[child] == head:
 | 
					    elif gold.heads[child] == head:
 | 
				
			||||||
        return True
 | 
					        return True
 | 
				
			||||||
| 
						 | 
					@ -100,9 +100,9 @@ cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) nogil:
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef bint label_is_gold(const GoldParseC* gold, int head, int child, attr_t label) nogil:
 | 
					cdef bint label_is_gold(const GoldParseC* gold, int head, int child, attr_t label) nogil:
 | 
				
			||||||
    if gold.labels[child] == -1:
 | 
					    if not gold.has_dep[child]:
 | 
				
			||||||
        return True
 | 
					        return True
 | 
				
			||||||
    elif label == -1:
 | 
					    elif label == 0:
 | 
				
			||||||
        return True
 | 
					        return True
 | 
				
			||||||
    elif gold.labels[child] == label:
 | 
					    elif gold.labels[child] == label:
 | 
				
			||||||
        return True
 | 
					        return True
 | 
				
			||||||
| 
						 | 
					@ -111,8 +111,7 @@ cdef bint label_is_gold(const GoldParseC* gold, int head, int child, attr_t labe
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
 | 
					cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
 | 
				
			||||||
    return gold.labels[word] == -1 or gold.heads[word] == word
 | 
					    return gold.heads[word] == word or not gold.has_dep[word]
 | 
				
			||||||
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef class Shift:
 | 
					cdef class Shift:
 | 
				
			||||||
    @staticmethod
 | 
					    @staticmethod
 | 
				
			||||||
| 
						 | 
					@ -165,7 +164,7 @@ cdef class Reduce:
 | 
				
			||||||
                    cost -= 1
 | 
					                    cost -= 1
 | 
				
			||||||
                if gold.heads[S_i] == st.S(0):
 | 
					                if gold.heads[S_i] == st.S(0):
 | 
				
			||||||
                    cost -= 1
 | 
					                    cost -= 1
 | 
				
			||||||
            if Break.is_valid(st.c, -1) and Break.move_cost(st, gold) == 0:
 | 
					            if Break.is_valid(st.c, 0) and Break.move_cost(st, gold) == 0:
 | 
				
			||||||
                cost -= 1
 | 
					                cost -= 1
 | 
				
			||||||
        return cost
 | 
					        return cost
 | 
				
			||||||
 | 
					
 | 
				
			||||||
| 
						 | 
					@ -285,9 +284,9 @@ cdef class Break:
 | 
				
			||||||
        return 0
 | 
					        return 0
 | 
				
			||||||
 | 
					
 | 
				
			||||||
cdef int _get_root(int word, const GoldParseC* gold) nogil:
 | 
					cdef int _get_root(int word, const GoldParseC* gold) nogil:
 | 
				
			||||||
    while gold.heads[word] != word and gold.labels[word] != -1 and word >= 0:
 | 
					    while gold.heads[word] != word and not gold.has_dep[word] and word >= 0:
 | 
				
			||||||
        word = gold.heads[word]
 | 
					        word = gold.heads[word]
 | 
				
			||||||
    if gold.labels[word] == -1:
 | 
					    if not gold.has_dep[word]:
 | 
				
			||||||
        return -1
 | 
					        return -1
 | 
				
			||||||
    else:
 | 
					    else:
 | 
				
			||||||
        return word
 | 
					        return word
 | 
				
			||||||
| 
						 | 
					@ -363,9 +362,10 @@ cdef class ArcEager(TransitionSystem):
 | 
				
			||||||
        for i in range(gold.length):
 | 
					        for i in range(gold.length):
 | 
				
			||||||
            if gold.heads[i] is None: # Missing values
 | 
					            if gold.heads[i] is None: # Missing values
 | 
				
			||||||
                gold.c.heads[i] = i
 | 
					                gold.c.heads[i] = i
 | 
				
			||||||
                gold.c.labels[i] = -1
 | 
					                gold.c.has_dep[i] = False
 | 
				
			||||||
            else:
 | 
					            else:
 | 
				
			||||||
                label = gold.labels[i]
 | 
					                label = gold.labels[i]
 | 
				
			||||||
 | 
					                gold.c.has_dep[i] = True
 | 
				
			||||||
                if label.upper() == 'ROOT':
 | 
					                if label.upper() == 'ROOT':
 | 
				
			||||||
                    label = 'ROOT'
 | 
					                    label = 'ROOT'
 | 
				
			||||||
                gold.c.heads[i] = gold.heads[i]
 | 
					                gold.c.heads[i] = gold.heads[i]
 | 
				
			||||||
| 
						 | 
					@ -440,18 +440,19 @@ cdef class ArcEager(TransitionSystem):
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cdef int set_valid(self, int* output, const StateC* st) nogil:
 | 
					    cdef int set_valid(self, int* output, const StateC* st) nogil:
 | 
				
			||||||
        cdef bint[N_MOVES] is_valid
 | 
					        cdef bint[N_MOVES] is_valid
 | 
				
			||||||
        is_valid[SHIFT] = Shift.is_valid(st, -1)
 | 
					        is_valid[SHIFT] = Shift.is_valid(st, 0)
 | 
				
			||||||
        is_valid[REDUCE] = Reduce.is_valid(st, -1)
 | 
					        is_valid[REDUCE] = Reduce.is_valid(st, 0)
 | 
				
			||||||
        is_valid[LEFT] = LeftArc.is_valid(st, -1)
 | 
					        is_valid[LEFT] = LeftArc.is_valid(st, 0)
 | 
				
			||||||
        is_valid[RIGHT] = RightArc.is_valid(st, -1)
 | 
					        is_valid[RIGHT] = RightArc.is_valid(st, 0)
 | 
				
			||||||
        is_valid[BREAK] = Break.is_valid(st, -1)
 | 
					        is_valid[BREAK] = Break.is_valid(st, 0)
 | 
				
			||||||
        cdef int i
 | 
					        cdef int i
 | 
				
			||||||
        for i in range(self.n_moves):
 | 
					        for i in range(self.n_moves):
 | 
				
			||||||
            output[i] = is_valid[self.c[i].move]
 | 
					            output[i] = is_valid[self.c[i].move]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
    cdef int set_costs(self, int* is_valid, weight_t* costs,
 | 
					    cdef int set_costs(self, int* is_valid, weight_t* costs,
 | 
				
			||||||
                       StateClass stcls, GoldParse gold) except -1:
 | 
					                       StateClass stcls, GoldParse gold) except -1:
 | 
				
			||||||
        cdef int i, move, label
 | 
					        cdef int i, move
 | 
				
			||||||
 | 
					        cdef attr_t label
 | 
				
			||||||
        cdef label_cost_func_t[N_MOVES] label_cost_funcs
 | 
					        cdef label_cost_func_t[N_MOVES] label_cost_funcs
 | 
				
			||||||
        cdef move_cost_func_t[N_MOVES] move_cost_funcs
 | 
					        cdef move_cost_func_t[N_MOVES] move_cost_funcs
 | 
				
			||||||
        cdef weight_t[N_MOVES] move_costs
 | 
					        cdef weight_t[N_MOVES] move_costs
 | 
				
			||||||
| 
						 | 
					
 | 
				
			||||||
		Loading…
	
		Reference in New Issue
	
	Block a user