* Fix NER oracle

This commit is contained in:
Matthew Honnibal 2015-06-05 17:11:26 +02:00
parent c04e6ebca6
commit 0114e7600d
2 changed files with 124 additions and 6 deletions

View File

@ -1,6 +1,7 @@
from .transition_system cimport TransitionSystem from .transition_system cimport TransitionSystem
from .transition_system cimport Transition from .transition_system cimport Transition
from ._state cimport State from ._state cimport State
from ..gold cimport GoldParseC
cdef class BiluoPushDown(TransitionSystem): cdef class BiluoPushDown(TransitionSystem):

View File

@ -186,8 +186,13 @@ cdef class Begin:
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not Begin.is_valid(s, label):
return 9000
cdef int g_act = gold.ner[s.i].move cdef int g_act = gold.ner[s.i].move
cdef int g_tag = gold.ner[s.i].label cdef int g_tag = gold.ner[s.i].label
if g_act == MISSING:
return 0
if g_act == BEGIN: if g_act == BEGIN:
# B, Gold B --> Label match # B, Gold B --> Label match
return label != g_tag return label != g_tag
@ -211,12 +216,17 @@ cdef class In:
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not In.is_valid(s, label):
return 9000
move = IN
cdef int next_act = gold.ner[s.i+1].move if s.i < s.sent_len else OUT cdef int next_act = gold.ner[s.i+1].move if s.i < s.sent_len else OUT
cdef int g_act = gold.ner[s.i].move cdef int g_act = gold.ner[s.i].move
cdef int g_tag = gold.ner[s.i].label cdef int g_tag = gold.ner[s.i].label
cdef bint is_sunk = _entity_is_sunk(s, gold.ner) cdef bint is_sunk = _entity_is_sunk(s, gold.ner)
if g_act == BEGIN: if g_act == MISSING:
return 0
elif g_act == BEGIN:
# I, Gold B --> True (P of bad open entity sunk, R of this entity sunk) # I, Gold B --> True (P of bad open entity sunk, R of this entity sunk)
return 0 return 0
elif g_act == IN: elif g_act == IN:
@ -231,6 +241,8 @@ cdef class In:
elif g_act == UNIT: elif g_act == UNIT:
# I, Gold U --> True iff next tag == O # I, Gold U --> True iff next tag == O
return next_act != OUT return next_act != OUT
else:
return 1
@ -248,10 +260,16 @@ cdef class Last:
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not Last.is_valid(s, label):
return 9000
move = LAST
cdef int g_act = gold.ner[s.i].move cdef int g_act = gold.ner[s.i].move
cdef int g_tag = gold.ner[s.i].label cdef int g_tag = gold.ner[s.i].label
if g_act == BEGIN: if g_act == MISSING:
return 0
elif g_act == BEGIN:
# L, Gold B --> True # L, Gold B --> True
return 0 return 0
elif g_act == IN: elif g_act == IN:
@ -266,6 +284,8 @@ cdef class Last:
elif g_act == UNIT: elif g_act == UNIT:
# L, Gold U --> True # L, Gold U --> True
return 0 return 0
else:
return 1
cdef class Unit: cdef class Unit:
@ -286,10 +306,14 @@ cdef class Unit:
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not Unit.is_valid(s, label):
return 9000
cdef int g_act = gold.ner[s.i].move cdef int g_act = gold.ner[s.i].move
cdef int g_tag = gold.ner[s.i].label cdef int g_tag = gold.ner[s.i].label
if g_act == UNIT: if g_act == MISSING:
return 0
elif g_act == UNIT:
# U, Gold U --> True iff tag match # U, Gold U --> True iff tag match
return label != g_tag return label != g_tag
else: else:
@ -312,10 +336,16 @@ cdef class Out:
@staticmethod @staticmethod
cdef int cost(const State* s, const GoldParseC* gold, int label) except -1: cdef int cost(const State* s, const GoldParseC* gold, int label) except -1:
if not Out.is_valid(s, label):
return 9000
cdef int g_act = gold.ner[s.i].move cdef int g_act = gold.ner[s.i].move
cdef int g_tag = gold.ner[s.i].label cdef int g_tag = gold.ner[s.i].label
if g_act == BEGIN:
if g_act == MISSING:
return 0
elif g_act == BEGIN:
# O, Gold B --> False # O, Gold B --> False
return 1 return 1
elif g_act == IN: elif g_act == IN:
@ -330,6 +360,93 @@ cdef class Out:
elif g_act == UNIT: elif g_act == UNIT:
# O, Gold U --> False # O, Gold U --> False
return 1 return 1
else:
return 1
"""
# TODO: Move this logic into the cost functions
cdef int _get_cost(int move, int label, const State* s, const GoldParseC* gold) except -1:
    # Cost of taking (move, label) at the current token s.i, judged against
    # the gold NER annotation: 0 when _is_gold accepts the action, 1 otherwise.
    #
    # move, label -- the candidate transition and its entity label.
    # s           -- parser state; s.i indexes the current token and
    #                s.sent_len is compared against it as the upper bound.
    # gold        -- gold annotation; gold.ner[k].move / .label are read
    #                for the current token and for the one after it.
    cdef bint is_sunk = _entity_is_sunk(s, gold.ner)
    # The look-ahead reads gold.ner[s.i + 1], so the bound check has to be
    # made on s.i + 1 rather than on s.i; otherwise the last token of the
    # sentence reads one slot past the current sentence.
    # NOTE(review): assumes gold.ner holds s.sent_len entries -- confirm
    # against the allocation of GoldParseC.ner.
    cdef int next_act = gold.ner[s.i + 1].move if s.i + 1 < s.sent_len else OUT
    cdef bint is_gold = _is_gold(move, label, gold.ner[s.i].move,
                                 gold.ner[s.i].label, next_act, is_sunk)
    return not is_gold
cdef bint _is_gold(int act, int tag, int g_act, int g_tag,
                   int next_act, bint is_sunk):
    # Decide whether predicting (act, tag) is acceptable given the gold
    # move/label (g_act, g_tag) for the same token.
    #
    # next_act -- gold move of the following token, as supplied by the caller.
    # is_sunk  -- flag supplied by the caller; per the decision table it marks
    #             the current entity as already "sunk".
    #
    # Any (act, g_act) pair not covered by the table below yields False.
    cdef bint next_is_out_or_missing = next_act == OUT or next_act == MISSING

    # No gold annotation for this token: every action is accepted.
    if g_act == MISSING:
        return True

    if act == BEGIN:
        # B, Gold B --> label match
        # B, Gold I/L/O/U --> False (P)
        return g_act == BEGIN and tag == g_tag

    if act == UNIT:
        # U, Gold U --> True iff tag match
        # U, Gold B/I/L/O --> False
        return g_act == UNIT and tag == g_tag

    if act == IN:
        # I, Gold B --> True (P of bad open entity sunk, R of this entity sunk)
        # I, Gold I --> True (label forced by prev; if mismatch, P and R both sunk)
        if g_act == BEGIN or g_act == IN:
            return True
        # I, Gold L --> True iff this entity sunk and next tag == O
        if g_act == LAST:
            return is_sunk and next_is_out_or_missing
        # I, Gold O --> True iff next tag == O
        if g_act == OUT:
            return next_is_out_or_missing
        # I, Gold U --> True iff next tag == O (MISSING is not accepted here)
        if g_act == UNIT:
            return next_act == OUT
        return False

    if act == LAST:
        # L, Gold I --> True iff this entity sunk
        if g_act == IN:
            return is_sunk
        # L, Gold B/L/O/U --> True
        return g_act == BEGIN or g_act == LAST or g_act == OUT or g_act == UNIT

    if act == OUT:
        # O, Gold I/L/O --> True
        # O, Gold B/U --> False
        return g_act == IN or g_act == LAST or g_act == OUT

    return False
"""
class OracleError(Exception): class OracleError(Exception):