Filter cut states depending on whether their actions have a cost

This commit is contained in:
Daniël de Kok 2023-02-24 15:55:00 +01:00
parent 0f3b23420b
commit c3defdf66e
5 changed files with 56 additions and 24 deletions

View File

@ -2,6 +2,8 @@
from cymem.cymem cimport Pool, Address from cymem.cymem cimport Pool, Address
from libc.stdint cimport int32_t from libc.stdint cimport int32_t
from libcpp.vector cimport vector from libcpp.vector cimport vector
import numpy
cimport numpy as np
from collections import defaultdict, Counter from collections import defaultdict, Counter
@ -16,6 +18,7 @@ from .stateclass cimport StateClass
from ._state cimport StateC, ArcC from ._state cimport StateC, ArcC
from ...errors import Errors from ...errors import Errors
from .search cimport Beam from .search cimport Beam
from .transition_system import OracleSequence
cdef weight_t MIN_SCORE = -90000 cdef weight_t MIN_SCORE = -90000
cdef attr_t SUBTOK_LABEL = hash_string('subtok') cdef attr_t SUBTOK_LABEL = hash_string('subtok')
@ -834,19 +837,22 @@ cdef class ArcEager(TransitionSystem):
cdef Pool mem = Pool() cdef Pool mem = Pool()
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
assert self.n_moves > 0 assert self.n_moves > 0
costs = <float*>mem.alloc(self.n_moves, sizeof(float)) cdef np.ndarray costs
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int)) is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
history = [] history = []
cost_matrix = []
debug_log = [] debug_log = []
failed = False failed = False
while not state.is_final(): while not state.is_final():
costs = numpy.zeros((self.n_moves,), dtype="f")
try: try:
self.set_costs(is_valid, costs, state.c, gold) self.set_costs(is_valid, <float*>costs.data, state.c, gold)
except ValueError: except ValueError:
failed = True failed = True
break break
min_cost = min(costs[i] for i in range(self.n_moves)) cost_matrix.append(costs)
min_cost = costs.min()
for i in range(self.n_moves): for i in range(self.n_moves):
if is_valid[i] and costs[i] <= min_cost: if is_valid[i] and costs[i] <= min_cost:
action = self.c[i] action = self.c[i]
@ -901,4 +907,4 @@ cdef class ArcEager(TransitionSystem):
print("Stack", [example.x[i] for i in state.stack]) print("Stack", [example.x[i] for i in state.stack])
print("Buffer", [example.x[i] for i in state.queue]) print("Buffer", [example.x[i] for i in state.queue])
raise ValueError(Errors.E024) raise ValueError(Errors.E024)
return history return OracleSequence(history, numpy.array(cost_matrix))

View File

@ -1,8 +1,11 @@
# cython: infer_types=True # cython: infer_types=True
from __future__ import print_function from __future__ import print_function
from typing import List, Optional
from cymem.cymem cimport Pool from cymem.cymem cimport Pool
from libc.stdlib cimport calloc, free from libc.stdlib cimport calloc, free
from libcpp.vector cimport vector from libcpp.vector cimport vector
import numpy
cimport numpy as np
from collections import Counter from collections import Counter
import srsly import srsly
@ -25,6 +28,22 @@ class OracleError(Exception):
pass pass
class OracleSequence:
    """An oracle action history together with the costs recorded per step.

    actions: the action (class) indices chosen by the oracle, in order.
    cost_matrix: array of action costs; the first axis is indexed by step
        (callers build it from one cost vector per oracle step).
    """

    # Only these two attributes exist on instances (no per-instance __dict__).
    __slots__ = ("actions", "cost_matrix")

    actions: List[int]
    cost_matrix: numpy.ndarray

    def __init__(self, actions: List[int], cost_matrix: numpy.ndarray):
        self.actions = actions
        self.cost_matrix = cost_matrix

    def has_cost(self, begin: int = 0, end: Optional[int] = None) -> bool:
        """Return True if any cost in steps [begin, end) is non-zero.

        begin: first step (row of cost_matrix) to inspect.
        end: one past the last step to inspect; None means "to the end".
            Values past the last row are clipped by the slice.
        """
        # Slicing with end=None already runs to the last row, so no special
        # case is needed. Convert the count to a real bool to match the
        # annotated return type.
        return bool(numpy.count_nonzero(self.cost_matrix[begin:end]))
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL: cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
cdef StateC* st = new StateC(<const TokenC*>tokens, length) cdef StateC* st = new StateC(<const TokenC*>tokens, length)
return <void*>st return <void*>st
@ -87,10 +106,10 @@ cdef class TransitionSystem:
def get_oracle_sequence(self, Example example, _debug=False): def get_oracle_sequence(self, Example example, _debug=False):
if not self.has_gold(example): if not self.has_gold(example):
return [] return OracleSequence([], numpy.zeros(0, self.n_moves))
states, golds, _ = self.init_gold_batch([example]) states, golds, _ = self.init_gold_batch([example])
if not states: if not states:
return [] return OracleSequence([], numpy.zeros(0, self.n_moves))
state = states[0] state = states[0]
gold = golds[0] gold = golds[0]
if _debug: if _debug:
@ -100,17 +119,20 @@ cdef class TransitionSystem:
def get_oracle_sequence_from_state(self, StateClass state, gold, _debug=None): def get_oracle_sequence_from_state(self, StateClass state, gold, _debug=None):
if state.is_final(): if state.is_final():
return [] return OracleSequence([], numpy.zeros(0, self.n_moves))
cdef Pool mem = Pool() cdef Pool mem = Pool()
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
assert self.n_moves > 0 assert self.n_moves > 0
costs = <float*>mem.alloc(self.n_moves, sizeof(float)) cdef np.ndarray costs
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int)) is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
history = [] history = []
cost_matrix = []
debug_log = [] debug_log = []
while not state.is_final(): while not state.is_final():
self.set_costs(is_valid, costs, state.c, gold) costs = numpy.zeros((self.n_moves,), dtype="f")
self.set_costs(is_valid, <float*>costs.data, state.c, gold)
cost_matrix.append(costs)
for i in range(self.n_moves): for i in range(self.n_moves):
if is_valid[i] and costs[i] <= 0: if is_valid[i] and costs[i] <= 0:
action = self.c[i] action = self.c[i]
@ -147,7 +169,7 @@ cdef class TransitionSystem:
))) )))
print("\n".join(debug_log)) print("\n".join(debug_log))
raise ValueError(Errors.E024) raise ValueError(Errors.E024)
return history return OracleSequence(history, numpy.array(cost_matrix))
def apply_transition(self, StateClass state, name): def apply_transition(self, StateClass state, name):
if not self.is_valid(state, name): if not self.is_valid(state, name):

View File

@ -715,22 +715,26 @@ class Parser(TrainablePipe):
states.append(state) states.append(state)
golds.append(gold) golds.append(gold)
else: else:
oracle_actions = moves.get_oracle_sequence_from_state( oracle_seq = moves.get_oracle_sequence_from_state(
state.copy(), gold) state.copy(), gold)
to_cut.append((eg, state, gold, oracle_actions)) to_cut.append((eg, state, gold, oracle_seq))
if not to_cut: if not to_cut:
return states, golds, 0 return states, golds, 0
cdef int clas cdef int clas
for eg, state, gold, oracle_actions in to_cut: for eg, state, gold, oracle_seq in to_cut:
for i in range(0, len(oracle_actions), max_length): for i in range(0, len(oracle_seq.actions), max_length):
start_state = state.copy() start_state = state.copy()
for clas in oracle_actions[i:i+max_length]: for clas in oracle_seq.actions[i:i+max_length]:
action = moves.c[clas] action = moves.c[clas]
action.do(state.c, action.label) action.do(state.c, action.label)
if state.is_final(): if state.is_final():
break break
states.append(start_state) # If all actions along the history are zero-cost actions, there
golds.append(gold) # is nothing to learn from this state in max_length steps, so
# we skip it.
if oracle_seq.has_cost(i, i+max_length):
states.append(start_state)
golds.append(gold)
if state.is_final(): if state.is_final():
break break
return states, golds, max_length return states, golds, max_length

View File

@ -168,7 +168,7 @@ def test_get_oracle_actions():
example = Example.from_dict( example = Example.from_dict(
doc, {"words": words, "tags": tags, "heads": heads, "deps": deps} doc, {"words": words, "tags": tags, "heads": heads, "deps": deps}
) )
parser.moves.get_oracle_sequence(example) parser.moves.get_oracle_sequence(example).actions
def test_oracle_dev_sentence(vocab, arc_eager): def test_oracle_dev_sentence(vocab, arc_eager):
@ -254,7 +254,7 @@ def test_oracle_dev_sentence(vocab, arc_eager):
arc_eager.add_action(3, dep) # Right arc_eager.add_action(3, dep) # Right
doc = Doc(Vocab(), words=gold_words) doc = Doc(Vocab(), words=gold_words)
example = Example.from_dict(doc, {"heads": gold_heads, "deps": gold_deps}) example = Example.from_dict(doc, {"heads": gold_heads, "deps": gold_deps})
ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False) ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False).actions
ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions] ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions]
assert ae_oracle_actions == expected_transitions assert ae_oracle_actions == expected_transitions
@ -288,6 +288,6 @@ def test_oracle_bad_tokenization(vocab, arc_eager):
reference.vocab, words=["[", "catalase", "]", ":", "that", "is", "bad"] reference.vocab, words=["[", "catalase", "]", ":", "that", "is", "bad"]
) )
example = Example(predicted=predicted, reference=reference) example = Example(predicted=predicted, reference=reference)
ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False) ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False).actions
ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions] ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions]
assert ae_oracle_actions assert ae_oracle_actions

View File

@ -231,7 +231,7 @@ def test_issue4313():
def test_get_oracle_moves(tsys, doc, entity_annots): def test_get_oracle_moves(tsys, doc, entity_annots):
example = Example.from_dict(doc, {"entities": entity_annots}) example = Example.from_dict(doc, {"entities": entity_annots})
act_classes = tsys.get_oracle_sequence(example, _debug=False) act_classes = tsys.get_oracle_sequence(example, _debug=False).actions
names = [tsys.get_class_name(act) for act in act_classes] names = [tsys.get_class_name(act) for act in act_classes]
assert names == ["U-PERSON", "O", "O", "B-GPE", "L-GPE", "O"] assert names == ["U-PERSON", "O", "O", "B-GPE", "L-GPE", "O"]
@ -250,7 +250,7 @@ def test_negative_samples_two_word_input(tsys, vocab, neg_key):
Span(example.y, 0, 1, label="O"), Span(example.y, 0, 1, label="O"),
Span(example.y, 0, 2, label="PERSON"), Span(example.y, 0, 2, label="PERSON"),
] ]
act_classes = tsys.get_oracle_sequence(example) act_classes = tsys.get_oracle_sequence(example).actions
names = [tsys.get_class_name(act) for act in act_classes] names = [tsys.get_class_name(act) for act in act_classes]
assert names assert names
assert names[0] != "O" assert names[0] != "O"
@ -270,7 +270,7 @@ def test_negative_samples_three_word_input(tsys, vocab, neg_key):
Span(example.y, 0, 1, label="O"), Span(example.y, 0, 1, label="O"),
Span(example.y, 0, 2, label="PERSON"), Span(example.y, 0, 2, label="PERSON"),
] ]
act_classes = tsys.get_oracle_sequence(example) act_classes = tsys.get_oracle_sequence(example).actions
names = [tsys.get_class_name(act) for act in act_classes] names = [tsys.get_class_name(act) for act in act_classes]
assert names assert names
assert names[0] != "O" assert names[0] != "O"
@ -289,7 +289,7 @@ def test_negative_samples_U_entity(tsys, vocab, neg_key):
Span(example.y, 0, 1, label="O"), Span(example.y, 0, 1, label="O"),
Span(example.y, 0, 1, label="PERSON"), Span(example.y, 0, 1, label="PERSON"),
] ]
act_classes = tsys.get_oracle_sequence(example) act_classes = tsys.get_oracle_sequence(example).actions
names = [tsys.get_class_name(act) for act in act_classes] names = [tsys.get_class_name(act) for act in act_classes]
assert names assert names
assert names[0] != "O" assert names[0] != "O"