mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 10:16:27 +03:00
* Fix bug in label assignment: ensure null-label transitions receive the label 0
This commit is contained in:
parent
ee927fbbb4
commit
f729164c01
|
@ -42,8 +42,8 @@ cdef get_cost_func_t[N_MOVES] get_cost_funcs
|
||||||
cdef class ArcEager(TransitionSystem):
|
cdef class ArcEager(TransitionSystem):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_labels(cls, gold_parses):
|
def get_labels(cls, gold_parses):
|
||||||
move_labels = {SHIFT: {'ROOT': True}, REDUCE: {'ROOT': True}, RIGHT: {},
|
move_labels = {SHIFT: {'': True}, REDUCE: {'': True}, RIGHT: {},
|
||||||
LEFT: {}, BREAK: {'ROOT': True}}
|
LEFT: {}, BREAK: {'': True}}
|
||||||
for raw_text, segmented, (ids, tags, heads, labels, iob) in gold_parses:
|
for raw_text, segmented, (ids, tags, heads, labels, iob) in gold_parses:
|
||||||
for i, (head, label) in enumerate(zip(heads, labels)):
|
for i, (head, label) in enumerate(zip(heads, labels)):
|
||||||
if label != 'ROOT':
|
if label != 'ROOT':
|
||||||
|
|
|
@ -70,8 +70,8 @@ cdef int _is_valid(int act, int label, const State* s) except -1:
|
||||||
cdef class BiluoPushDown(TransitionSystem):
|
cdef class BiluoPushDown(TransitionSystem):
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_labels(cls, gold_tuples):
|
def get_labels(cls, gold_tuples):
|
||||||
move_labels = {MISSING: {'ROOT': True}, BEGIN: {}, IN: {}, LAST: {}, UNIT: {},
|
move_labels = {MISSING: {'': True}, BEGIN: {}, IN: {}, LAST: {}, UNIT: {},
|
||||||
OUT: {'ROOT': True}}
|
OUT: {'': True}}
|
||||||
moves = ('M', 'B', 'I', 'L', 'U')
|
moves = ('M', 'B', 'I', 'L', 'U')
|
||||||
for (raw_text, toks, (ids, tags, heads, labels, biluo)) in gold_tuples:
|
for (raw_text, toks, (ids, tags, heads, labels, biluo)) in gold_tuples:
|
||||||
for i, ner_tag in enumerate(biluo):
|
for i, ner_tag in enumerate(biluo):
|
||||||
|
@ -99,7 +99,7 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
label = 0
|
label = 0
|
||||||
elif '-' in name:
|
elif '-' in name:
|
||||||
move_str, label_str = name.split('-', 1)
|
move_str, label_str = name.split('-', 1)
|
||||||
label = self.label_ids[label_str]
|
label = self.strings[label_str]
|
||||||
else:
|
else:
|
||||||
move_str = name
|
move_str = name
|
||||||
label = 0
|
label = 0
|
||||||
|
|
|
@ -21,7 +21,7 @@ cdef class TransitionSystem:
|
||||||
self.strings = string_table
|
self.strings = string_table
|
||||||
for action, label_strs in sorted(labels_by_action.items()):
|
for action, label_strs in sorted(labels_by_action.items()):
|
||||||
for label_str in sorted(label_strs):
|
for label_str in sorted(label_strs):
|
||||||
label_id = self.strings[unicode(label_str)]
|
label_id = self.strings[unicode(label_str)] if label_str else 0
|
||||||
moves[i] = self.init_transition(i, int(action), label_id)
|
moves[i] = self.init_transition(i, int(action), label_id)
|
||||||
i += 1
|
i += 1
|
||||||
self.c = moves
|
self.c = moves
|
||||||
|
|
Loading…
Reference in New Issue
Block a user