Filter cut states depending on whether their actions have cost

This commit is contained in:
Daniël de Kok 2023-02-24 15:55:00 +01:00
parent 0f3b23420b
commit c3defdf66e
5 changed files with 56 additions and 24 deletions

View File

@ -2,6 +2,8 @@
from cymem.cymem cimport Pool, Address
from libc.stdint cimport int32_t
from libcpp.vector cimport vector
import numpy
cimport numpy as np
from collections import defaultdict, Counter
@ -16,6 +18,7 @@ from .stateclass cimport StateClass
from ._state cimport StateC, ArcC
from ...errors import Errors
from .search cimport Beam
from .transition_system import OracleSequence
cdef weight_t MIN_SCORE = -90000
cdef attr_t SUBTOK_LABEL = hash_string('subtok')
@ -834,19 +837,22 @@ cdef class ArcEager(TransitionSystem):
cdef Pool mem = Pool()
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
assert self.n_moves > 0
costs = <float*>mem.alloc(self.n_moves, sizeof(float))
cdef np.ndarray costs
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
history = []
cost_matrix = []
debug_log = []
failed = False
while not state.is_final():
costs = numpy.zeros((self.n_moves,), dtype="f")
try:
self.set_costs(is_valid, costs, state.c, gold)
self.set_costs(is_valid, <float*>costs.data, state.c, gold)
except ValueError:
failed = True
break
min_cost = min(costs[i] for i in range(self.n_moves))
cost_matrix.append(costs)
min_cost = costs.min()
for i in range(self.n_moves):
if is_valid[i] and costs[i] <= min_cost:
action = self.c[i]
@ -901,4 +907,4 @@ cdef class ArcEager(TransitionSystem):
print("Stack", [example.x[i] for i in state.stack])
print("Buffer", [example.x[i] for i in state.queue])
raise ValueError(Errors.E024)
return history
return OracleSequence(history, numpy.array(cost_matrix))

View File

@ -1,8 +1,11 @@
# cython: infer_types=True
from __future__ import print_function
from typing import List, Optional
from cymem.cymem cimport Pool
from libc.stdlib cimport calloc, free
from libcpp.vector cimport vector
import numpy
cimport numpy as np
from collections import Counter
import srsly
@ -25,6 +28,22 @@ class OracleError(Exception):
pass
class OracleSequence:
    """An oracle action sequence together with the costs seen at each step.

    actions: The oracle action (class index) chosen at each step.
    cost_matrix: The per-step cost arrays stacked into a single array; the
        first axis indexes the steps of the sequence.
    """

    # Restrict instances to these two attributes (the original spelled this
    # `__slots`, which Python name-mangles and otherwise ignores).
    __slots__ = ["actions", "cost_matrix"]

    actions: List[int]
    cost_matrix: numpy.ndarray

    def __init__(self, actions: List[int], cost_matrix: numpy.ndarray):
        self.actions = actions
        self.cost_matrix = cost_matrix

    def has_cost(self, begin: int = 0, end: Optional[int] = None) -> bool:
        """Check whether any cost in the steps [begin, end) is non-zero.

        begin: Index of the first step to inspect.
        end: Index one past the last step to inspect. Defaults to the number
            of steps in the cost matrix. Values beyond the number of steps
            are clipped by regular slicing.
        RETURNS: True if at least one entry in the selected steps is
            non-zero, False otherwise (including for an empty selection).
        """
        if end is None:
            end = self.cost_matrix.shape[0]
        # count_nonzero returns an integer count; convert it so that the
        # result matches the annotated return type.
        return bool(numpy.count_nonzero(self.cost_matrix[begin:end]))
# Construct a new StateC from `tokens` (cast to `const TokenC*`) and `length`
# and return it as an untyped pointer. The `mem` pool argument is not used in
# this body.
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
cdef StateC* st = new StateC(<const TokenC*>tokens, length)
return <void*>st
@ -87,10 +106,10 @@ cdef class TransitionSystem:
def get_oracle_sequence(self, Example example, _debug=False):
if not self.has_gold(example):
return []
return OracleSequence([], numpy.zeros(0, self.n_moves))
states, golds, _ = self.init_gold_batch([example])
if not states:
return []
return OracleSequence([], numpy.zeros(0, self.n_moves))
state = states[0]
gold = golds[0]
if _debug:
@ -100,17 +119,20 @@ cdef class TransitionSystem:
def get_oracle_sequence_from_state(self, StateClass state, gold, _debug=None):
if state.is_final():
return []
return OracleSequence([], numpy.zeros(0, self.n_moves))
cdef Pool mem = Pool()
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
assert self.n_moves > 0
costs = <float*>mem.alloc(self.n_moves, sizeof(float))
cdef np.ndarray costs
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
history = []
cost_matrix = []
debug_log = []
while not state.is_final():
self.set_costs(is_valid, costs, state.c, gold)
costs = numpy.zeros((self.n_moves,), dtype="f")
self.set_costs(is_valid, <float*>costs.data, state.c, gold)
cost_matrix.append(costs)
for i in range(self.n_moves):
if is_valid[i] and costs[i] <= 0:
action = self.c[i]
@ -147,7 +169,7 @@ cdef class TransitionSystem:
)))
print("\n".join(debug_log))
raise ValueError(Errors.E024)
return history
return OracleSequence(history, numpy.array(cost_matrix))
def apply_transition(self, StateClass state, name):
if not self.is_valid(state, name):

View File

@ -715,22 +715,26 @@ class Parser(TrainablePipe):
states.append(state)
golds.append(gold)
else:
oracle_actions = moves.get_oracle_sequence_from_state(
oracle_seq = moves.get_oracle_sequence_from_state(
state.copy(), gold)
to_cut.append((eg, state, gold, oracle_actions))
to_cut.append((eg, state, gold, oracle_seq))
if not to_cut:
return states, golds, 0
cdef int clas
for eg, state, gold, oracle_actions in to_cut:
for i in range(0, len(oracle_actions), max_length):
for eg, state, gold, oracle_seq in to_cut:
for i in range(0, len(oracle_seq.actions), max_length):
start_state = state.copy()
for clas in oracle_actions[i:i+max_length]:
for clas in oracle_seq.actions[i:i+max_length]:
action = moves.c[clas]
action.do(state.c, action.label)
if state.is_final():
break
states.append(start_state)
golds.append(gold)
# If all actions along the history are zero-cost actions, there
# is nothing to learn from this state in max_length steps, so
# we skip it.
if oracle_seq.has_cost(i, i+max_length):
states.append(start_state)
golds.append(gold)
if state.is_final():
break
return states, golds, max_length

View File

@ -168,7 +168,7 @@ def test_get_oracle_actions():
example = Example.from_dict(
doc, {"words": words, "tags": tags, "heads": heads, "deps": deps}
)
parser.moves.get_oracle_sequence(example)
parser.moves.get_oracle_sequence(example).actions
def test_oracle_dev_sentence(vocab, arc_eager):
@ -254,7 +254,7 @@ def test_oracle_dev_sentence(vocab, arc_eager):
arc_eager.add_action(3, dep) # Right
doc = Doc(Vocab(), words=gold_words)
example = Example.from_dict(doc, {"heads": gold_heads, "deps": gold_deps})
ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False)
ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False).actions
ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions]
assert ae_oracle_actions == expected_transitions
@ -288,6 +288,6 @@ def test_oracle_bad_tokenization(vocab, arc_eager):
reference.vocab, words=["[", "catalase", "]", ":", "that", "is", "bad"]
)
example = Example(predicted=predicted, reference=reference)
ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False)
ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False).actions
ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions]
assert ae_oracle_actions

View File

@ -231,7 +231,7 @@ def test_issue4313():
def test_get_oracle_moves(tsys, doc, entity_annots):
example = Example.from_dict(doc, {"entities": entity_annots})
act_classes = tsys.get_oracle_sequence(example, _debug=False)
act_classes = tsys.get_oracle_sequence(example, _debug=False).actions
names = [tsys.get_class_name(act) for act in act_classes]
assert names == ["U-PERSON", "O", "O", "B-GPE", "L-GPE", "O"]
@ -250,7 +250,7 @@ def test_negative_samples_two_word_input(tsys, vocab, neg_key):
Span(example.y, 0, 1, label="O"),
Span(example.y, 0, 2, label="PERSON"),
]
act_classes = tsys.get_oracle_sequence(example)
act_classes = tsys.get_oracle_sequence(example).actions
names = [tsys.get_class_name(act) for act in act_classes]
assert names
assert names[0] != "O"
@ -270,7 +270,7 @@ def test_negative_samples_three_word_input(tsys, vocab, neg_key):
Span(example.y, 0, 1, label="O"),
Span(example.y, 0, 2, label="PERSON"),
]
act_classes = tsys.get_oracle_sequence(example)
act_classes = tsys.get_oracle_sequence(example).actions
names = [tsys.get_class_name(act) for act in act_classes]
assert names
assert names[0] != "O"
@ -289,7 +289,7 @@ def test_negative_samples_U_entity(tsys, vocab, neg_key):
Span(example.y, 0, 1, label="O"),
Span(example.y, 0, 1, label="PERSON"),
]
act_classes = tsys.get_oracle_sequence(example)
act_classes = tsys.get_oracle_sequence(example).actions
names = [tsys.get_class_name(act) for act in act_classes]
assert names
assert names[0] != "O"