mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-13 18:56:36 +03:00
* Fix NER oracle
This commit is contained in:
parent
c04e6ebca6
commit
0114e7600d
|
@ -1,6 +1,7 @@
|
||||||
from .transition_system cimport TransitionSystem
|
from .transition_system cimport TransitionSystem
|
||||||
from .transition_system cimport Transition
|
from .transition_system cimport Transition
|
||||||
from ._state cimport State
|
from ._state cimport State
|
||||||
|
from ..gold cimport GoldParseC
|
||||||
|
|
||||||
|
|
||||||
cdef class BiluoPushDown(TransitionSystem):
|
cdef class BiluoPushDown(TransitionSystem):
|
||||||
|
|
|
@ -186,8 +186,13 @@ cdef class Begin:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||||
|
if not Begin.is_valid(s, label):
|
||||||
|
return 9000
|
||||||
cdef int g_act = gold.ner[s.i].move
|
cdef int g_act = gold.ner[s.i].move
|
||||||
cdef int g_tag = gold.ner[s.i].label
|
cdef int g_tag = gold.ner[s.i].label
|
||||||
|
|
||||||
|
if g_act == MISSING:
|
||||||
|
return 0
|
||||||
if g_act == BEGIN:
|
if g_act == BEGIN:
|
||||||
# B, Gold B --> Label match
|
# B, Gold B --> Label match
|
||||||
return label != g_tag
|
return label != g_tag
|
||||||
|
@ -211,12 +216,17 @@ cdef class In:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||||
|
if not In.is_valid(s, label):
|
||||||
|
return 9000
|
||||||
|
move = IN
|
||||||
cdef int next_act = gold.ner[s.i+1].move if s.i < s.sent_len else OUT
|
cdef int next_act = gold.ner[s.i+1].move if s.i < s.sent_len else OUT
|
||||||
cdef int g_act = gold.ner[s.i].move
|
cdef int g_act = gold.ner[s.i].move
|
||||||
cdef int g_tag = gold.ner[s.i].label
|
cdef int g_tag = gold.ner[s.i].label
|
||||||
cdef bint is_sunk = _entity_is_sunk(s, gold.ner)
|
cdef bint is_sunk = _entity_is_sunk(s, gold.ner)
|
||||||
|
|
||||||
if g_act == BEGIN:
|
if g_act == MISSING:
|
||||||
|
return 0
|
||||||
|
elif g_act == BEGIN:
|
||||||
# I, Gold B --> True (P of bad open entity sunk, R of this entity sunk)
|
# I, Gold B --> True (P of bad open entity sunk, R of this entity sunk)
|
||||||
return 0
|
return 0
|
||||||
elif g_act == IN:
|
elif g_act == IN:
|
||||||
|
@ -231,6 +241,8 @@ cdef class In:
|
||||||
elif g_act == UNIT:
|
elif g_act == UNIT:
|
||||||
# I, Gold U --> True iff next tag == O
|
# I, Gold U --> True iff next tag == O
|
||||||
return next_act != OUT
|
return next_act != OUT
|
||||||
|
else:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@ -248,10 +260,16 @@ cdef class Last:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||||
|
if not Last.is_valid(s, label):
|
||||||
|
return 9000
|
||||||
|
move = LAST
|
||||||
|
|
||||||
cdef int g_act = gold.ner[s.i].move
|
cdef int g_act = gold.ner[s.i].move
|
||||||
cdef int g_tag = gold.ner[s.i].label
|
cdef int g_tag = gold.ner[s.i].label
|
||||||
|
|
||||||
if g_act == BEGIN:
|
if g_act == MISSING:
|
||||||
|
return 0
|
||||||
|
elif g_act == BEGIN:
|
||||||
# L, Gold B --> True
|
# L, Gold B --> True
|
||||||
return 0
|
return 0
|
||||||
elif g_act == IN:
|
elif g_act == IN:
|
||||||
|
@ -266,6 +284,8 @@ cdef class Last:
|
||||||
elif g_act == UNIT:
|
elif g_act == UNIT:
|
||||||
# L, Gold U --> True
|
# L, Gold U --> True
|
||||||
return 0
|
return 0
|
||||||
|
else:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
|
||||||
cdef class Unit:
|
cdef class Unit:
|
||||||
|
@ -286,10 +306,14 @@ cdef class Unit:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||||
|
if not Unit.is_valid(s, label):
|
||||||
|
return 9000
|
||||||
cdef int g_act = gold.ner[s.i].move
|
cdef int g_act = gold.ner[s.i].move
|
||||||
cdef int g_tag = gold.ner[s.i].label
|
cdef int g_tag = gold.ner[s.i].label
|
||||||
|
|
||||||
if g_act == UNIT:
|
if g_act == MISSING:
|
||||||
|
return 0
|
||||||
|
elif g_act == UNIT:
|
||||||
# U, Gold U --> True iff tag match
|
# U, Gold U --> True iff tag match
|
||||||
return label != g_tag
|
return label != g_tag
|
||||||
else:
|
else:
|
||||||
|
@ -312,10 +336,16 @@ cdef class Out:
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
|
||||||
|
if not Out.is_valid(s, label):
|
||||||
|
return 9000
|
||||||
|
|
||||||
cdef int g_act = gold.ner[s.i].move
|
cdef int g_act = gold.ner[s.i].move
|
||||||
cdef int g_tag = gold.ner[s.i].label
|
cdef int g_tag = gold.ner[s.i].label
|
||||||
|
|
||||||
if g_act == BEGIN:
|
|
||||||
|
if g_act == MISSING:
|
||||||
|
return 0
|
||||||
|
elif g_act == BEGIN:
|
||||||
# O, Gold B --> False
|
# O, Gold B --> False
|
||||||
return 1
|
return 1
|
||||||
elif g_act == IN:
|
elif g_act == IN:
|
||||||
|
@ -330,6 +360,93 @@ cdef class Out:
|
||||||
elif g_act == UNIT:
|
elif g_act == UNIT:
|
||||||
# O, Gold U --> False
|
# O, Gold U --> False
|
||||||
return 1
|
return 1
|
||||||
|
else:
|
||||||
|
return 1
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
# NOTE: Disabled legacy oracle, kept quoted out for reference only; the per-move cost() functions above now carry this logic.
|
||||||
|
cdef int _get_cost(int move, int label, const State* s, const GoldParseC* gold) except -1:
|
||||||
|
cdef bint is_sunk = _entity_is_sunk(s, gold.ner)
|
||||||
|
cdef int next_act = gold.ner[s.i+1].move if s.i < s.sent_len else OUT
|
||||||
|
cdef bint is_gold = _is_gold(move, label, gold.ner[s.i].move,
|
||||||
|
gold.ner[s.i].label, next_act, is_sunk)
|
||||||
|
return not is_gold
|
||||||
|
|
||||||
|
|
||||||
|
cdef bint _is_gold(int act, int tag, int g_act, int g_tag,
|
||||||
|
int next_act, bint is_sunk):
|
||||||
|
if g_act == MISSING:
|
||||||
|
return True
|
||||||
|
if act == BEGIN:
|
||||||
|
if g_act == BEGIN:
|
||||||
|
# B, Gold B --> Label match
|
||||||
|
return tag == g_tag
|
||||||
|
else:
|
||||||
|
# B, Gold I --> False (P)
|
||||||
|
# B, Gold L --> False (P)
|
||||||
|
# B, Gold O --> False (P)
|
||||||
|
# B, Gold U --> False (P)
|
||||||
|
return False
|
||||||
|
elif act == IN:
|
||||||
|
if g_act == BEGIN:
|
||||||
|
# I, Gold B --> True (P of bad open entity sunk, R of this entity sunk)
|
||||||
|
return True
|
||||||
|
elif g_act == IN:
|
||||||
|
# I, Gold I --> True (label forced by prev, if mismatch, P and R both sunk)
|
||||||
|
return True
|
||||||
|
elif g_act == LAST:
|
||||||
|
# I, Gold L --> True iff this entity sunk and next tag == O
|
||||||
|
return is_sunk and (next_act == OUT or next_act == MISSING)
|
||||||
|
elif g_act == OUT:
|
||||||
|
# I, Gold O --> True iff next tag == O
|
||||||
|
return next_act == OUT or next_act == MISSING
|
||||||
|
elif g_act == UNIT:
|
||||||
|
# I, Gold U --> True iff next tag == O
|
||||||
|
return next_act == OUT
|
||||||
|
elif act == LAST:
|
||||||
|
if g_act == BEGIN:
|
||||||
|
# L, Gold B --> True
|
||||||
|
return True
|
||||||
|
elif g_act == IN:
|
||||||
|
# L, Gold I --> True iff this entity sunk
|
||||||
|
return is_sunk
|
||||||
|
elif g_act == LAST:
|
||||||
|
# L, Gold L --> True
|
||||||
|
return True
|
||||||
|
elif g_act == OUT:
|
||||||
|
# L, Gold O --> True
|
||||||
|
return True
|
||||||
|
elif g_act == UNIT:
|
||||||
|
# L, Gold U --> True
|
||||||
|
return True
|
||||||
|
elif act == OUT:
|
||||||
|
if g_act == BEGIN:
|
||||||
|
# O, Gold B --> False
|
||||||
|
return False
|
||||||
|
elif g_act == IN:
|
||||||
|
# O, Gold I --> True
|
||||||
|
return True
|
||||||
|
elif g_act == LAST:
|
||||||
|
# O, Gold L --> True
|
||||||
|
return True
|
||||||
|
elif g_act == OUT:
|
||||||
|
# O, Gold O --> True
|
||||||
|
return True
|
||||||
|
elif g_act == UNIT:
|
||||||
|
# O, Gold U --> False
|
||||||
|
return False
|
||||||
|
elif act == UNIT:
|
||||||
|
if g_act == UNIT:
|
||||||
|
# U, Gold U --> True iff tag match
|
||||||
|
return tag == g_tag
|
||||||
|
else:
|
||||||
|
# U, Gold B --> False
|
||||||
|
# U, Gold I --> False
|
||||||
|
# U, Gold L --> False
|
||||||
|
# U, Gold O --> False
|
||||||
|
return False
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class OracleError(Exception):
|
class OracleError(Exception):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user