mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-10 00:20:35 +03:00
Fix arc eager label costs for uint64
This commit is contained in:
parent
b127645afc
commit
be4a640f0c
|
@ -8,6 +8,7 @@ from .syntax.transition_system cimport Transition
|
||||||
cdef struct GoldParseC:
|
cdef struct GoldParseC:
|
||||||
int* tags
|
int* tags
|
||||||
int* heads
|
int* heads
|
||||||
|
int* has_dep
|
||||||
attr_t* labels
|
attr_t* labels
|
||||||
int** brackets
|
int** brackets
|
||||||
Transition* ner
|
Transition* ner
|
||||||
|
|
|
@ -385,6 +385,7 @@ cdef class GoldParse:
|
||||||
self.c.tags = <int*>self.mem.alloc(len(doc), sizeof(int))
|
self.c.tags = <int*>self.mem.alloc(len(doc), sizeof(int))
|
||||||
self.c.heads = <int*>self.mem.alloc(len(doc), sizeof(int))
|
self.c.heads = <int*>self.mem.alloc(len(doc), sizeof(int))
|
||||||
self.c.labels = <attr_t*>self.mem.alloc(len(doc), sizeof(attr_t))
|
self.c.labels = <attr_t*>self.mem.alloc(len(doc), sizeof(attr_t))
|
||||||
|
self.c.has_dep = <int*>self.mem.alloc(len(doc), sizeof(int))
|
||||||
self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition))
|
self.c.ner = <Transition*>self.mem.alloc(len(doc), sizeof(Transition))
|
||||||
|
|
||||||
self.words = [None] * len(doc)
|
self.words = [None] * len(doc)
|
||||||
|
|
|
@ -60,7 +60,7 @@ cdef weight_t push_cost(StateClass stcls, const GoldParseC* gold, int target) no
|
||||||
cost += 1
|
cost += 1
|
||||||
if gold.heads[S_i] == target and (NON_MONOTONIC or not stcls.has_head(S_i)):
|
if gold.heads[S_i] == target and (NON_MONOTONIC or not stcls.has_head(S_i)):
|
||||||
cost += 1
|
cost += 1
|
||||||
cost += Break.is_valid(stcls.c, -1) and Break.move_cost(stcls, gold) == 0
|
cost += Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
|
||||||
|
@ -73,7 +73,7 @@ cdef weight_t pop_cost(StateClass stcls, const GoldParseC* gold, int target) nog
|
||||||
cost += gold.heads[target] == B_i
|
cost += gold.heads[target] == B_i
|
||||||
if gold.heads[B_i] == B_i or gold.heads[B_i] < target:
|
if gold.heads[B_i] == B_i or gold.heads[B_i] < target:
|
||||||
break
|
break
|
||||||
if Break.is_valid(stcls.c, -1) and Break.move_cost(stcls, gold) == 0:
|
if Break.is_valid(stcls.c, 0) and Break.move_cost(stcls, gold) == 0:
|
||||||
cost += 1
|
cost += 1
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
@ -84,14 +84,14 @@ cdef weight_t arc_cost(StateClass stcls, const GoldParseC* gold, int head, int c
|
||||||
elif stcls.H(child) == gold.heads[child]:
|
elif stcls.H(child) == gold.heads[child]:
|
||||||
return 1
|
return 1
|
||||||
# Head in buffer
|
# Head in buffer
|
||||||
elif gold.heads[child] >= stcls.B(0) and stcls.B(1) != -1:
|
elif gold.heads[child] >= stcls.B(0) and stcls.B(1) != 0:
|
||||||
return 1
|
return 1
|
||||||
else:
|
else:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) nogil:
|
cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) nogil:
|
||||||
if gold.labels[child] == -1:
|
if not gold.has_dep[child]:
|
||||||
return True
|
return True
|
||||||
elif gold.heads[child] == head:
|
elif gold.heads[child] == head:
|
||||||
return True
|
return True
|
||||||
|
@ -100,9 +100,9 @@ cdef bint arc_is_gold(const GoldParseC* gold, int head, int child) nogil:
|
||||||
|
|
||||||
|
|
||||||
cdef bint label_is_gold(const GoldParseC* gold, int head, int child, attr_t label) nogil:
|
cdef bint label_is_gold(const GoldParseC* gold, int head, int child, attr_t label) nogil:
|
||||||
if gold.labels[child] == -1:
|
if not gold.has_dep[child]:
|
||||||
return True
|
return True
|
||||||
elif label == -1:
|
elif label == 0:
|
||||||
return True
|
return True
|
||||||
elif gold.labels[child] == label:
|
elif gold.labels[child] == label:
|
||||||
return True
|
return True
|
||||||
|
@ -111,8 +111,7 @@ cdef bint label_is_gold(const GoldParseC* gold, int head, int child, attr_t labe
|
||||||
|
|
||||||
|
|
||||||
cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
|
cdef bint _is_gold_root(const GoldParseC* gold, int word) nogil:
|
||||||
return gold.labels[word] == -1 or gold.heads[word] == word
|
return gold.heads[word] == word or not gold.has_dep[word]
|
||||||
|
|
||||||
|
|
||||||
cdef class Shift:
|
cdef class Shift:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
@ -165,7 +164,7 @@ cdef class Reduce:
|
||||||
cost -= 1
|
cost -= 1
|
||||||
if gold.heads[S_i] == st.S(0):
|
if gold.heads[S_i] == st.S(0):
|
||||||
cost -= 1
|
cost -= 1
|
||||||
if Break.is_valid(st.c, -1) and Break.move_cost(st, gold) == 0:
|
if Break.is_valid(st.c, 0) and Break.move_cost(st, gold) == 0:
|
||||||
cost -= 1
|
cost -= 1
|
||||||
return cost
|
return cost
|
||||||
|
|
||||||
|
@ -285,9 +284,9 @@ cdef class Break:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
cdef int _get_root(int word, const GoldParseC* gold) nogil:
|
cdef int _get_root(int word, const GoldParseC* gold) nogil:
|
||||||
while gold.heads[word] != word and gold.labels[word] != -1 and word >= 0:
|
while gold.heads[word] != word and not gold.has_dep[word] and word >= 0:
|
||||||
word = gold.heads[word]
|
word = gold.heads[word]
|
||||||
if gold.labels[word] == -1:
|
if not gold.has_dep[word]:
|
||||||
return -1
|
return -1
|
||||||
else:
|
else:
|
||||||
return word
|
return word
|
||||||
|
@ -363,9 +362,10 @@ cdef class ArcEager(TransitionSystem):
|
||||||
for i in range(gold.length):
|
for i in range(gold.length):
|
||||||
if gold.heads[i] is None: # Missing values
|
if gold.heads[i] is None: # Missing values
|
||||||
gold.c.heads[i] = i
|
gold.c.heads[i] = i
|
||||||
gold.c.labels[i] = -1
|
gold.c.has_dep[i] = False
|
||||||
else:
|
else:
|
||||||
label = gold.labels[i]
|
label = gold.labels[i]
|
||||||
|
gold.c.has_dep[i] = True
|
||||||
if label.upper() == 'ROOT':
|
if label.upper() == 'ROOT':
|
||||||
label = 'ROOT'
|
label = 'ROOT'
|
||||||
gold.c.heads[i] = gold.heads[i]
|
gold.c.heads[i] = gold.heads[i]
|
||||||
|
@ -440,18 +440,19 @@ cdef class ArcEager(TransitionSystem):
|
||||||
|
|
||||||
cdef int set_valid(self, int* output, const StateC* st) nogil:
|
cdef int set_valid(self, int* output, const StateC* st) nogil:
|
||||||
cdef bint[N_MOVES] is_valid
|
cdef bint[N_MOVES] is_valid
|
||||||
is_valid[SHIFT] = Shift.is_valid(st, -1)
|
is_valid[SHIFT] = Shift.is_valid(st, 0)
|
||||||
is_valid[REDUCE] = Reduce.is_valid(st, -1)
|
is_valid[REDUCE] = Reduce.is_valid(st, 0)
|
||||||
is_valid[LEFT] = LeftArc.is_valid(st, -1)
|
is_valid[LEFT] = LeftArc.is_valid(st, 0)
|
||||||
is_valid[RIGHT] = RightArc.is_valid(st, -1)
|
is_valid[RIGHT] = RightArc.is_valid(st, 0)
|
||||||
is_valid[BREAK] = Break.is_valid(st, -1)
|
is_valid[BREAK] = Break.is_valid(st, 0)
|
||||||
cdef int i
|
cdef int i
|
||||||
for i in range(self.n_moves):
|
for i in range(self.n_moves):
|
||||||
output[i] = is_valid[self.c[i].move]
|
output[i] = is_valid[self.c[i].move]
|
||||||
|
|
||||||
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
||||||
StateClass stcls, GoldParse gold) except -1:
|
StateClass stcls, GoldParse gold) except -1:
|
||||||
cdef int i, move, label
|
cdef int i, move
|
||||||
|
cdef attr_t label
|
||||||
cdef label_cost_func_t[N_MOVES] label_cost_funcs
|
cdef label_cost_func_t[N_MOVES] label_cost_funcs
|
||||||
cdef move_cost_func_t[N_MOVES] move_cost_funcs
|
cdef move_cost_func_t[N_MOVES] move_cost_funcs
|
||||||
cdef weight_t[N_MOVES] move_costs
|
cdef weight_t[N_MOVES] move_costs
|
||||||
|
|
Loading…
Reference in New Issue
Block a user