mirror of
https://github.com/explosion/spaCy.git
synced 2025-04-22 18:12:00 +03:00
Filter cut states depending on whether their actions have cost
This commit is contained in:
parent
0f3b23420b
commit
c3defdf66e
|
@ -2,6 +2,8 @@
|
|||
from cymem.cymem cimport Pool, Address
|
||||
from libc.stdint cimport int32_t
|
||||
from libcpp.vector cimport vector
|
||||
import numpy
|
||||
cimport numpy as np
|
||||
|
||||
from collections import defaultdict, Counter
|
||||
|
||||
|
@ -16,6 +18,7 @@ from .stateclass cimport StateClass
|
|||
from ._state cimport StateC, ArcC
|
||||
from ...errors import Errors
|
||||
from .search cimport Beam
|
||||
from .transition_system import OracleSequence
|
||||
|
||||
cdef weight_t MIN_SCORE = -90000
|
||||
cdef attr_t SUBTOK_LABEL = hash_string('subtok')
|
||||
|
@ -834,19 +837,22 @@ cdef class ArcEager(TransitionSystem):
|
|||
cdef Pool mem = Pool()
|
||||
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
||||
assert self.n_moves > 0
|
||||
costs = <float*>mem.alloc(self.n_moves, sizeof(float))
|
||||
cdef np.ndarray costs
|
||||
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
|
||||
|
||||
history = []
|
||||
cost_matrix = []
|
||||
debug_log = []
|
||||
failed = False
|
||||
while not state.is_final():
|
||||
costs = numpy.zeros((self.n_moves,), dtype="f")
|
||||
try:
|
||||
self.set_costs(is_valid, costs, state.c, gold)
|
||||
self.set_costs(is_valid, <float*>costs.data, state.c, gold)
|
||||
except ValueError:
|
||||
failed = True
|
||||
break
|
||||
min_cost = min(costs[i] for i in range(self.n_moves))
|
||||
cost_matrix.append(costs)
|
||||
min_cost = costs.min()
|
||||
for i in range(self.n_moves):
|
||||
if is_valid[i] and costs[i] <= min_cost:
|
||||
action = self.c[i]
|
||||
|
@ -901,4 +907,4 @@ cdef class ArcEager(TransitionSystem):
|
|||
print("Stack", [example.x[i] for i in state.stack])
|
||||
print("Buffer", [example.x[i] for i in state.queue])
|
||||
raise ValueError(Errors.E024)
|
||||
return history
|
||||
return OracleSequence(history, numpy.array(cost_matrix))
|
||||
|
|
|
@ -1,8 +1,11 @@
|
|||
# cython: infer_types=True
|
||||
from __future__ import print_function
|
||||
from typing import List, Optional
|
||||
from cymem.cymem cimport Pool
|
||||
from libc.stdlib cimport calloc, free
|
||||
from libcpp.vector cimport vector
|
||||
import numpy
|
||||
cimport numpy as np
|
||||
|
||||
from collections import Counter
|
||||
import srsly
|
||||
|
@ -25,6 +28,22 @@ class OracleError(Exception):
|
|||
pass
|
||||
|
||||
|
||||
class OracleSequence:
    """An oracle action sequence together with the costs computed per step.

    actions: the action (class) indices chosen by the oracle, in order.
    cost_matrix: array of the costs recorded while following the oracle.
        The producers visible alongside this class append one
        ``(n_moves,)`` float array per step and wrap the list in
        ``numpy.array``, so this is presumably of shape
        ``(n_steps, n_moves)`` -- confirm against other callers.
    """

    # Restrict instances to exactly these two attributes. The attribute must
    # be spelled ``__slots__`` (with trailing underscores) to take effect.
    __slots__ = ("actions", "cost_matrix")

    actions: List[int]
    cost_matrix: numpy.ndarray

    def __init__(self, actions: List[int], cost_matrix: numpy.ndarray):
        self.actions = actions
        self.cost_matrix = cost_matrix

    def has_cost(self, begin: int = 0, end: Optional[int] = None) -> bool:
        """Check whether any cost in rows ``begin:end`` is non-zero.

        begin: index of the first row to inspect (defaults to 0).
        end: index one past the last row to inspect. None means "up to the
            last row". Values past the end are clipped by normal slicing.
        RETURNS: True if at least one entry in the selected rows is
            non-zero, False otherwise (including for an empty selection).
        """
        if end is None:
            end = self.cost_matrix.shape[0]
        # count_nonzero returns an integer count; convert it so the result
        # matches the declared bool return type.
        return bool(numpy.count_nonzero(self.cost_matrix[begin:end]))
|
||||
|
||||
|
||||
cdef void* _init_state(Pool mem, int length, void* tokens) except NULL:
|
||||
cdef StateC* st = new StateC(<const TokenC*>tokens, length)
|
||||
return <void*>st
|
||||
|
@ -87,10 +106,10 @@ cdef class TransitionSystem:
|
|||
|
||||
def get_oracle_sequence(self, Example example, _debug=False):
|
||||
if not self.has_gold(example):
|
||||
return []
|
||||
return OracleSequence([], numpy.zeros(0, self.n_moves))
|
||||
states, golds, _ = self.init_gold_batch([example])
|
||||
if not states:
|
||||
return []
|
||||
return OracleSequence([], numpy.zeros(0, self.n_moves))
|
||||
state = states[0]
|
||||
gold = golds[0]
|
||||
if _debug:
|
||||
|
@ -100,17 +119,20 @@ cdef class TransitionSystem:
|
|||
|
||||
def get_oracle_sequence_from_state(self, StateClass state, gold, _debug=None):
|
||||
if state.is_final():
|
||||
return []
|
||||
return OracleSequence([], numpy.zeros(0, self.n_moves))
|
||||
cdef Pool mem = Pool()
|
||||
# n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc
|
||||
assert self.n_moves > 0
|
||||
costs = <float*>mem.alloc(self.n_moves, sizeof(float))
|
||||
cdef np.ndarray costs
|
||||
is_valid = <int*>mem.alloc(self.n_moves, sizeof(int))
|
||||
|
||||
history = []
|
||||
cost_matrix = []
|
||||
debug_log = []
|
||||
while not state.is_final():
|
||||
self.set_costs(is_valid, costs, state.c, gold)
|
||||
costs = numpy.zeros((self.n_moves,), dtype="f")
|
||||
self.set_costs(is_valid, <float*>costs.data, state.c, gold)
|
||||
cost_matrix.append(costs)
|
||||
for i in range(self.n_moves):
|
||||
if is_valid[i] and costs[i] <= 0:
|
||||
action = self.c[i]
|
||||
|
@ -147,7 +169,7 @@ cdef class TransitionSystem:
|
|||
)))
|
||||
print("\n".join(debug_log))
|
||||
raise ValueError(Errors.E024)
|
||||
return history
|
||||
return OracleSequence(history, numpy.array(cost_matrix))
|
||||
|
||||
def apply_transition(self, StateClass state, name):
|
||||
if not self.is_valid(state, name):
|
||||
|
|
|
@ -715,22 +715,26 @@ class Parser(TrainablePipe):
|
|||
states.append(state)
|
||||
golds.append(gold)
|
||||
else:
|
||||
oracle_actions = moves.get_oracle_sequence_from_state(
|
||||
oracle_seq = moves.get_oracle_sequence_from_state(
|
||||
state.copy(), gold)
|
||||
to_cut.append((eg, state, gold, oracle_actions))
|
||||
to_cut.append((eg, state, gold, oracle_seq))
|
||||
if not to_cut:
|
||||
return states, golds, 0
|
||||
cdef int clas
|
||||
for eg, state, gold, oracle_actions in to_cut:
|
||||
for i in range(0, len(oracle_actions), max_length):
|
||||
for eg, state, gold, oracle_seq in to_cut:
|
||||
for i in range(0, len(oracle_seq.actions), max_length):
|
||||
start_state = state.copy()
|
||||
for clas in oracle_actions[i:i+max_length]:
|
||||
for clas in oracle_seq.actions[i:i+max_length]:
|
||||
action = moves.c[clas]
|
||||
action.do(state.c, action.label)
|
||||
if state.is_final():
|
||||
break
|
||||
states.append(start_state)
|
||||
golds.append(gold)
|
||||
# If all actions along the history are zero-cost actions, there
|
||||
# is nothing to learn from this state in max_length steps, so
|
||||
# we skip it.
|
||||
if oracle_seq.has_cost(i, i+max_length):
|
||||
states.append(start_state)
|
||||
golds.append(gold)
|
||||
if state.is_final():
|
||||
break
|
||||
return states, golds, max_length
|
||||
|
|
|
@ -168,7 +168,7 @@ def test_get_oracle_actions():
|
|||
example = Example.from_dict(
|
||||
doc, {"words": words, "tags": tags, "heads": heads, "deps": deps}
|
||||
)
|
||||
parser.moves.get_oracle_sequence(example)
|
||||
parser.moves.get_oracle_sequence(example).actions
|
||||
|
||||
|
||||
def test_oracle_dev_sentence(vocab, arc_eager):
|
||||
|
@ -254,7 +254,7 @@ def test_oracle_dev_sentence(vocab, arc_eager):
|
|||
arc_eager.add_action(3, dep) # Right
|
||||
doc = Doc(Vocab(), words=gold_words)
|
||||
example = Example.from_dict(doc, {"heads": gold_heads, "deps": gold_deps})
|
||||
ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False)
|
||||
ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False).actions
|
||||
ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions]
|
||||
assert ae_oracle_actions == expected_transitions
|
||||
|
||||
|
@ -288,6 +288,6 @@ def test_oracle_bad_tokenization(vocab, arc_eager):
|
|||
reference.vocab, words=["[", "catalase", "]", ":", "that", "is", "bad"]
|
||||
)
|
||||
example = Example(predicted=predicted, reference=reference)
|
||||
ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False)
|
||||
ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False).actions
|
||||
ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions]
|
||||
assert ae_oracle_actions
|
||||
|
|
|
@ -231,7 +231,7 @@ def test_issue4313():
|
|||
|
||||
def test_get_oracle_moves(tsys, doc, entity_annots):
|
||||
example = Example.from_dict(doc, {"entities": entity_annots})
|
||||
act_classes = tsys.get_oracle_sequence(example, _debug=False)
|
||||
act_classes = tsys.get_oracle_sequence(example, _debug=False).actions
|
||||
names = [tsys.get_class_name(act) for act in act_classes]
|
||||
assert names == ["U-PERSON", "O", "O", "B-GPE", "L-GPE", "O"]
|
||||
|
||||
|
@ -250,7 +250,7 @@ def test_negative_samples_two_word_input(tsys, vocab, neg_key):
|
|||
Span(example.y, 0, 1, label="O"),
|
||||
Span(example.y, 0, 2, label="PERSON"),
|
||||
]
|
||||
act_classes = tsys.get_oracle_sequence(example)
|
||||
act_classes = tsys.get_oracle_sequence(example).actions
|
||||
names = [tsys.get_class_name(act) for act in act_classes]
|
||||
assert names
|
||||
assert names[0] != "O"
|
||||
|
@ -270,7 +270,7 @@ def test_negative_samples_three_word_input(tsys, vocab, neg_key):
|
|||
Span(example.y, 0, 1, label="O"),
|
||||
Span(example.y, 0, 2, label="PERSON"),
|
||||
]
|
||||
act_classes = tsys.get_oracle_sequence(example)
|
||||
act_classes = tsys.get_oracle_sequence(example).actions
|
||||
names = [tsys.get_class_name(act) for act in act_classes]
|
||||
assert names
|
||||
assert names[0] != "O"
|
||||
|
@ -289,7 +289,7 @@ def test_negative_samples_U_entity(tsys, vocab, neg_key):
|
|||
Span(example.y, 0, 1, label="O"),
|
||||
Span(example.y, 0, 1, label="PERSON"),
|
||||
]
|
||||
act_classes = tsys.get_oracle_sequence(example)
|
||||
act_classes = tsys.get_oracle_sequence(example).actions
|
||||
names = [tsys.get_class_name(act) for act in act_classes]
|
||||
assert names
|
||||
assert names[0] != "O"
|
||||
|
|
Loading…
Reference in New Issue
Block a user