mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-11 17:10:36 +03:00
Add missing costs to NER oracle
This commit is contained in:
parent
f73fa77bb9
commit
ad50c8baca
|
@ -49,6 +49,10 @@ cdef class BiluoGold:
|
||||||
self.mem = Pool()
|
self.mem = Pool()
|
||||||
self.c = create_gold_state(self.mem, moves, stcls, example)
|
self.c = create_gold_state(self.mem, moves, stcls, example)
|
||||||
|
|
||||||
|
def update(self, StateClass stcls):
|
||||||
|
update_gold_state(&self.c, stcls)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
cdef GoldNERStateC create_gold_state(
|
cdef GoldNERStateC create_gold_state(
|
||||||
Pool mem,
|
Pool mem,
|
||||||
|
@ -58,30 +62,15 @@ cdef GoldNERStateC create_gold_state(
|
||||||
) except *:
|
) except *:
|
||||||
cdef GoldNERStateC gs
|
cdef GoldNERStateC gs
|
||||||
gs.ner = <Transition*>mem.alloc(example.x.length, sizeof(Transition))
|
gs.ner = <Transition*>mem.alloc(example.x.length, sizeof(Transition))
|
||||||
ner_tags = get_aligned_ner(example)
|
ner_tags = example.get_aligned_ner()
|
||||||
for i, ner_tag in enumerate(ner_tags):
|
for i, ner_tag in enumerate(ner_tags):
|
||||||
gs.ner[i] = moves.lookup_transition(ner_tag)
|
gs.ner[i] = moves.lookup_transition(ner_tag)
|
||||||
return gs
|
return gs
|
||||||
|
|
||||||
|
|
||||||
def get_aligned_ner(Example example):
|
cdef void update_gold_state(GoldNERStateC* gs, StateClass stcls) except *:
|
||||||
cand_to_gold = example.alignment.cand_to_gold
|
# We don't need to update each time, unlike the parser.
|
||||||
i2j_multi = example.alignment.i2j_multi
|
|
||||||
y_tags = biluo_tags_from_offsets(
|
|
||||||
example.y,
|
|
||||||
[(e.start_char, e.end_char, e.label_) for e in example.y.ents]
|
|
||||||
)
|
|
||||||
x_tags = [None] * example.x.length
|
|
||||||
for i in range(example.x.length):
|
|
||||||
if example.x[i].is_space:
|
|
||||||
pass
|
pass
|
||||||
elif cand_to_gold[i] is not None:
|
|
||||||
x_tags[i] = y_tags[cand_to_gold[i]]
|
|
||||||
elif i in i2j_multi:
|
|
||||||
# Assign O/- for many-to-one O/- NER tags
|
|
||||||
if y_tags[i2j_multi[i]] in ("O", "-"):
|
|
||||||
x_tags[i] = y_tags[i2j_multi[i]]
|
|
||||||
return y_tags
|
|
||||||
|
|
||||||
|
|
||||||
cdef do_func_t[N_MOVES] do_funcs
|
cdef do_func_t[N_MOVES] do_funcs
|
||||||
|
@ -120,11 +109,12 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
for action in (BEGIN, IN, LAST, UNIT):
|
for action in (BEGIN, IN, LAST, UNIT):
|
||||||
actions[action][entity_type] = 1
|
actions[action][entity_type] = 1
|
||||||
moves = ('M', 'B', 'I', 'L', 'U')
|
moves = ('M', 'B', 'I', 'L', 'U')
|
||||||
for example in kwargs.get('gold_parses', []):
|
for example in kwargs.get('examples', []):
|
||||||
for ner_tag in example.get_aligned("ENT_TYPE", as_string=True):
|
for token in example.y:
|
||||||
if ner_tag != 'O' and ner_tag != '-':
|
ent_type = token.ent_type_
|
||||||
|
if ent_type:
|
||||||
for action in (BEGIN, IN, LAST, UNIT):
|
for action in (BEGIN, IN, LAST, UNIT):
|
||||||
actions[action][ner_tag] += 1
|
actions[action][ent_type] += 1
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -247,6 +237,37 @@ cdef class BiluoPushDown(TransitionSystem):
|
||||||
self.add_action(UNIT, st._sent[i].ent_type)
|
self.add_action(UNIT, st._sent[i].ent_type)
|
||||||
self.add_action(LAST, st._sent[i].ent_type)
|
self.add_action(LAST, st._sent[i].ent_type)
|
||||||
|
|
||||||
|
def get_cost(self, StateClass stcls, gold, int i):
|
||||||
|
if not isinstance(gold, BiluoGold):
|
||||||
|
raise TypeError("Expected BiluoGold")
|
||||||
|
cdef BiluoGold gold_ = gold
|
||||||
|
gold_state = gold_.c
|
||||||
|
n_gold = 0
|
||||||
|
if self.c[i].is_valid(stcls.c, self.c[i].label):
|
||||||
|
cost = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
|
||||||
|
else:
|
||||||
|
cost = 9000
|
||||||
|
return cost
|
||||||
|
|
||||||
|
cdef int set_costs(self, int* is_valid, weight_t* costs,
|
||||||
|
StateClass stcls, gold) except -1:
|
||||||
|
if not isinstance(gold, BiluoGold):
|
||||||
|
raise TypeError("Expected BiluoGold")
|
||||||
|
cdef BiluoGold gold_ = gold
|
||||||
|
gold_.update(stcls)
|
||||||
|
gold_state = gold_.c
|
||||||
|
n_gold = 0
|
||||||
|
for i in range(self.n_moves):
|
||||||
|
if self.c[i].is_valid(stcls.c, self.c[i].label):
|
||||||
|
is_valid[i] = True
|
||||||
|
costs[i] = self.c[i].get_cost(stcls, &gold_state, self.c[i].label)
|
||||||
|
n_gold += costs[i] <= 0
|
||||||
|
else:
|
||||||
|
is_valid[i] = False
|
||||||
|
costs[i] = 9000
|
||||||
|
if n_gold < 1:
|
||||||
|
raise ValueError
|
||||||
|
|
||||||
|
|
||||||
cdef class Missing:
|
cdef class Missing:
|
||||||
@staticmethod
|
@staticmethod
|
||||||
|
|
Loading…
Reference in New Issue
Block a user