From 5ae9e3480d88963b5c5d64f84550042d80822bc1 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Fri, 19 Jun 2020 00:11:59 +0200 Subject: [PATCH] Return ArcEagerGoldParse from ArcEager --- spacy/syntax/arc_eager.pyx | 18 +++++++++++++++++- spacy/syntax/nn_parser.pyx | 35 ++++++++++++++--------------------- 2 files changed, 31 insertions(+), 22 deletions(-) diff --git a/spacy/syntax/arc_eager.pyx b/spacy/syntax/arc_eager.pyx index dc0ce9f59..23e72916e 100644 --- a/spacy/syntax/arc_eager.pyx +++ b/spacy/syntax/arc_eager.pyx @@ -132,6 +132,16 @@ cdef GoldParseStateC create_gold_state(Pool mem, StateClass stcls, Example examp ) return gs + +cdef class ArcEagerGoldParse: + cdef Pool mem + cdef GoldParseStateC c + def __init__(self, StateClass stcls, Example example): + self.mem = Pool() + self.c = create_gold_state(self.mem, stcls, example) + + + cdef int check_state_gold(char state_bits, char flag) nogil: cdef char one = 1 return state_bits & (one << flag) @@ -156,7 +166,6 @@ cdef int is_head_in_buffer(const GoldParseStateC* gold, int i) nogil: cdef int is_head_unknown(const GoldParseStateC* gold, int i) nogil: return check_state_gold(gold.state_bits[i], HEAD_UNKNOWN) - # Helper functions for the arc-eager oracle cdef weight_t push_cost(StateClass stcls, const void* _gold, int target) nogil: @@ -500,6 +509,14 @@ cdef class ArcEager(TransitionSystem): def preprocess_gold(self, example): raise NotImplementedError + def init_gold_batch(self, examples): + states = self.init_batch([eg.predicted for eg in examples]) + keeps = [i for i, s in enumerate(states) if not s.is_final()] + states = [states[i] for i in keeps] + examples = [examples[i] for i in keeps] + n_steps = sum([s.buffer_length() * 4 for s in states]) + return states, examples, n_steps + cdef Transition lookup_transition(self, object name_or_id) except *: if isinstance(name_or_id, int): return self.c[name_or_id] diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index d139d8c35..22e0e7995 100644 --- 
a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -268,15 +268,11 @@ cdef class Parser: for multitask in self._multitasks: multitask.update(examples, drop=drop, sgd=sgd) set_dropout_rate(self.model, drop) - try: - states, golds, max_steps = self._init_gold_batch_no_cut(examples) - except AttributeError: - types = set([type(eg) for eg in examples]) - raise ValueError(Errors.E978.format(name="Parser", method="update", types=types)) - states_golds = [(s, g) for (s, g) in zip(states, golds) - if not s.is_final() and g is not None] # Prepare the stepwise model, and get the callback for finishing the batch - model, backprop_tok2vec = self.model.begin_update([eg.doc for eg in examples]) + model, backprop_tok2vec = self.model.begin_update( + [eg.predicted for eg in examples]) + states, golds, max_steps = self.moves.init_gold_batch(examples) + states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()] all_states = list(states) for _ in range(max_steps): if not states_golds: @@ -287,12 +283,12 @@ cdef class Parser: backprop(d_scores) # Follow the predicted action self.transition_states(states, scores) - states_golds = [eg for eg in states_golds if not eg[0].is_final()] + states_golds = [(s, g) for (s, g) in zip(states, golds) if not s.is_final()] backprop_tok2vec(golds) if sgd is not None: self.model.finish_update(sgd) if set_annotations: - docs = [eg.doc for eg in examples] + docs = [eg.predicted for eg in examples] self.set_annotations(docs, all_states) return losses @@ -307,7 +303,7 @@ cdef class Parser: return None losses.setdefault(self.name, 0.) - docs = [eg.doc for eg in examples] + docs = [eg.predicted for eg in examples] states = self.moves.init_batch(docs) # This is pretty dirty, but the NER can resize itself in init_batch, # if labels are missing. 
We therefore have to check whether we need to @@ -356,11 +351,7 @@ cdef class Parser: queue.extend(node._layers) return gradients - def _init_gold_batch_no_cut(self, examples): - states = self.moves.init_batch([eg.predicted for eg in examples]) - return states, examples - - def get_batch_loss(self, states, examples, float[:, ::1] scores, losses): + def get_batch_loss(self, states, golds, float[:, ::1] scores, losses): cdef StateClass state cdef Example example cdef Pool mem = Pool() @@ -375,10 +366,10 @@ cdef class Parser: dtype='f', order='C') c_d_scores = d_scores.data unseen_classes = self.model.attrs["unseen_classes"] - for i, (state, eg) in enumerate(zip(states, examples)): + for i, (state, gold) in enumerate(zip(states, golds)): memset(is_valid, 0, self.moves.n_moves * sizeof(int)) memset(costs, 0, self.moves.n_moves * sizeof(float)) - self.moves.set_costs(is_valid, costs, state, eg) + self.moves.set_costs(is_valid, costs, state, gold) for j in range(self.moves.n_moves): if costs[j] <= 0.0 and j in unseen_classes: unseen_classes.remove(j) @@ -403,9 +394,11 @@ cdef class Parser: if not hasattr(get_examples, '__call__'): gold_tuples = get_examples get_examples = lambda: gold_tuples - actions = self.moves.get_actions(gold_parses=get_examples(), - min_freq=self.cfg['min_action_freq'], - learn_tokens=self.cfg["learn_tokens"]) + actions = self.moves.get_actions( + examples=get_examples(), + min_freq=self.cfg['min_action_freq'], + learn_tokens=self.cfg["learn_tokens"] + ) for action, labels in self.moves.labels.items(): actions.setdefault(action, {}) for label, freq in labels.items():