From c3defdf66ece6ead6305ad55572b1cfb03f2b566 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Danie=CC=88l=20de=20Kok?= Date: Fri, 24 Feb 2023 15:55:00 +0100 Subject: [PATCH] Filter cut states depending on whether its actions have cost --- .../pipeline/_parser_internals/arc_eager.pyx | 14 +++++--- .../_parser_internals/transition_system.pyx | 34 +++++++++++++++---- spacy/pipeline/transition_parser.pyx | 18 ++++++---- spacy/tests/parser/test_arc_eager_oracle.py | 6 ++-- spacy/tests/parser/test_ner.py | 8 ++--- 5 files changed, 56 insertions(+), 24 deletions(-) diff --git a/spacy/pipeline/_parser_internals/arc_eager.pyx b/spacy/pipeline/_parser_internals/arc_eager.pyx index 9c358475a..68015bb17 100644 --- a/spacy/pipeline/_parser_internals/arc_eager.pyx +++ b/spacy/pipeline/_parser_internals/arc_eager.pyx @@ -2,6 +2,8 @@ from cymem.cymem cimport Pool, Address from libc.stdint cimport int32_t from libcpp.vector cimport vector +import numpy +cimport numpy as np from collections import defaultdict, Counter @@ -16,6 +18,7 @@ from .stateclass cimport StateClass from ._state cimport StateC, ArcC from ...errors import Errors from .search cimport Beam +from .transition_system import OracleSequence cdef weight_t MIN_SCORE = -90000 cdef attr_t SUBTOK_LABEL = hash_string('subtok') @@ -834,19 +837,22 @@ cdef class ArcEager(TransitionSystem): cdef Pool mem = Pool() # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc assert self.n_moves > 0 - costs = mem.alloc(self.n_moves, sizeof(float)) + cdef np.ndarray costs is_valid = mem.alloc(self.n_moves, sizeof(int)) history = [] + cost_matrix = [] debug_log = [] failed = False while not state.is_final(): + costs = numpy.zeros((self.n_moves,), dtype="f") try: - self.set_costs(is_valid, costs, state.c, gold) + self.set_costs(is_valid, costs.data, state.c, gold) except ValueError: failed = True break - min_cost = min(costs[i] for i in range(self.n_moves)) + cost_matrix.append(costs) + min_cost = costs.min() for i in range(self.n_moves): if is_valid[i] and costs[i] <= min_cost: action = self.c[i] @@ -901,4 +907,4 @@ cdef class ArcEager(TransitionSystem): print("Stack", [example.x[i] for i in state.stack]) print("Buffer", [example.x[i] for i in state.queue]) raise ValueError(Errors.E024) - return history + return OracleSequence(history, numpy.array(cost_matrix)) diff --git a/spacy/pipeline/_parser_internals/transition_system.pyx b/spacy/pipeline/_parser_internals/transition_system.pyx index 89f9e8ae8..c1850a542 100644 --- a/spacy/pipeline/_parser_internals/transition_system.pyx +++ b/spacy/pipeline/_parser_internals/transition_system.pyx @@ -1,8 +1,11 @@ # cython: infer_types=True from __future__ import print_function +from typing import List, Optional from cymem.cymem cimport Pool from libc.stdlib cimport calloc, free from libcpp.vector cimport vector +import numpy +cimport numpy as np from collections import Counter import srsly @@ -25,6 +28,22 @@ class OracleError(Exception): pass +class OracleSequence: + actions: List[int] + cost_matrix: numpy.ndarray + + def __init__(self, actions: List[int], cost_matrix: numpy.ndarray): + self.actions = actions + self.cost_matrix = cost_matrix + + __slots = ["actions", "cost_matrix"] + + def has_cost(self, begin: int=0, end: Optional[int]=None) -> bool: + if end is None: + end = self.cost_matrix.shape[0] + return numpy.count_nonzero(self.cost_matrix[begin:end]) + + cdef void* _init_state(Pool mem, int length, void* tokens) except NULL: cdef StateC* st = new StateC(tokens, length) return st @@ -87,10 +106,10 @@ cdef class TransitionSystem: def get_oracle_sequence(self, Example example, _debug=False): if not self.has_gold(example): - return [] + return OracleSequence([], numpy.zeros(0, self.n_moves)) states, golds, _ = self.init_gold_batch([example]) if not states: - return [] + return OracleSequence([], numpy.zeros(0, self.n_moves)) state = states[0] gold = golds[0] if _debug: @@ -100,17 +119,20 @@ cdef class TransitionSystem: def get_oracle_sequence_from_state(self, StateClass state, gold, _debug=None): if state.is_final(): - return [] + return OracleSequence([], numpy.zeros(0, self.n_moves)) cdef Pool mem = Pool() # n_moves should not be zero at this point, but make sure to avoid zero-length mem alloc assert self.n_moves > 0 - costs = mem.alloc(self.n_moves, sizeof(float)) + cdef np.ndarray costs is_valid = mem.alloc(self.n_moves, sizeof(int)) history = [] + cost_matrix = [] debug_log = [] while not state.is_final(): - self.set_costs(is_valid, costs, state.c, gold) + costs = numpy.zeros((self.n_moves,), dtype="f") + self.set_costs(is_valid, costs.data, state.c, gold) + cost_matrix.append(costs) for i in range(self.n_moves): if is_valid[i] and costs[i] <= 0: action = self.c[i] @@ -147,7 +169,7 @@ cdef class TransitionSystem: ))) print("\n".join(debug_log)) raise ValueError(Errors.E024) - return history + return OracleSequence(history, numpy.array(cost_matrix)) def apply_transition(self, StateClass state, name): if not self.is_valid(state, name): diff --git a/spacy/pipeline/transition_parser.pyx b/spacy/pipeline/transition_parser.pyx index 7b49402a2..06a5c2183 100644 --- a/spacy/pipeline/transition_parser.pyx +++ b/spacy/pipeline/transition_parser.pyx @@ -715,22 +715,26 @@ class Parser(TrainablePipe): states.append(state) golds.append(gold) else: - oracle_actions = moves.get_oracle_sequence_from_state( + oracle_seq = moves.get_oracle_sequence_from_state( state.copy(), gold) - to_cut.append((eg, state, gold, oracle_actions)) + to_cut.append((eg, state, gold, oracle_seq)) if not to_cut: return states, golds, 0 cdef int clas - for eg, state, gold, oracle_actions in to_cut: - for i in range(0, len(oracle_actions), max_length): + for eg, state, gold, oracle_seq in to_cut: + for i in range(0, len(oracle_seq.actions), max_length): start_state = state.copy() - for clas in oracle_actions[i:i+max_length]: + for clas in oracle_seq.actions[i:i+max_length]: action = moves.c[clas] action.do(state.c, action.label) if state.is_final(): break - states.append(start_state) - golds.append(gold) + # If all actions along the history are zero-cost actions, there + # is nothing to learn from this state in max_length stepss, so + # we skip it. + if oracle_seq.has_cost(i, i+max_length): + states.append(start_state) + golds.append(gold) if state.is_final(): break return states, golds, max_length diff --git a/spacy/tests/parser/test_arc_eager_oracle.py b/spacy/tests/parser/test_arc_eager_oracle.py index bb226f9c5..1dd01bc52 100644 --- a/spacy/tests/parser/test_arc_eager_oracle.py +++ b/spacy/tests/parser/test_arc_eager_oracle.py @@ -168,7 +168,7 @@ def test_get_oracle_actions(): example = Example.from_dict( doc, {"words": words, "tags": tags, "heads": heads, "deps": deps} ) - parser.moves.get_oracle_sequence(example) + parser.moves.get_oracle_sequence(example).actions def test_oracle_dev_sentence(vocab, arc_eager): @@ -254,7 +254,7 @@ def test_oracle_dev_sentence(vocab, arc_eager): arc_eager.add_action(3, dep) # Right doc = Doc(Vocab(), words=gold_words) example = Example.from_dict(doc, {"heads": gold_heads, "deps": gold_deps}) - ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False) + ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False).actions ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions] assert ae_oracle_actions == expected_transitions @@ -288,6 +288,6 @@ def test_oracle_bad_tokenization(vocab, arc_eager): reference.vocab, words=["[", "catalase", "]", ":", "that", "is", "bad"] ) example = Example(predicted=predicted, reference=reference) - ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False) + ae_oracle_actions = arc_eager.get_oracle_sequence(example, _debug=False).actions ae_oracle_actions = [arc_eager.get_class_name(i) for i in ae_oracle_actions] assert ae_oracle_actions diff --git a/spacy/tests/parser/test_ner.py b/spacy/tests/parser/test_ner.py index 8d46b57d5..ce1bdfce0 100644 --- a/spacy/tests/parser/test_ner.py +++ b/spacy/tests/parser/test_ner.py @@ -231,7 +231,7 @@ def test_issue4313(): def test_get_oracle_moves(tsys, doc, entity_annots): example = Example.from_dict(doc, {"entities": entity_annots}) - act_classes = tsys.get_oracle_sequence(example, _debug=False) + act_classes = tsys.get_oracle_sequence(example, _debug=False).actions names = [tsys.get_class_name(act) for act in act_classes] assert names == ["U-PERSON", "O", "O", "B-GPE", "L-GPE", "O"] @@ -250,7 +250,7 @@ def test_negative_samples_two_word_input(tsys, vocab, neg_key): Span(example.y, 0, 1, label="O"), Span(example.y, 0, 2, label="PERSON"), ] - act_classes = tsys.get_oracle_sequence(example) + act_classes = tsys.get_oracle_sequence(example).actions names = [tsys.get_class_name(act) for act in act_classes] assert names assert names[0] != "O" @@ -270,7 +270,7 @@ def test_negative_samples_three_word_input(tsys, vocab, neg_key): Span(example.y, 0, 1, label="O"), Span(example.y, 0, 2, label="PERSON"), ] - act_classes = tsys.get_oracle_sequence(example) + act_classes = tsys.get_oracle_sequence(example).actions names = [tsys.get_class_name(act) for act in act_classes] assert names assert names[0] != "O" @@ -289,7 +289,7 @@ def test_negative_samples_U_entity(tsys, vocab, neg_key): Span(example.y, 0, 1, label="O"), Span(example.y, 0, 1, label="PERSON"), ] - act_classes = tsys.get_oracle_sequence(example) + act_classes = tsys.get_oracle_sequence(example).actions names = [tsys.get_class_name(act) for act in act_classes] assert names assert names[0] != "O"