diff --git a/spacy/syntax/_beam_utils.pyx b/spacy/syntax/_beam_utils.pyx index 03702e54e..52d49db82 100644 --- a/spacy/syntax/_beam_utils.pyx +++ b/spacy/syntax/_beam_utils.pyx @@ -9,7 +9,6 @@ import numpy from ..typedefs cimport hash_t, class_t from .transition_system cimport TransitionSystem, Transition -from ..gold cimport GoldParse from .stateclass cimport StateC, StateClass from ..errors import Errors @@ -126,12 +125,12 @@ cdef class ParserBeam(object): beam.scores[i][j] = 0 beam.costs[i][j] = 0 - def _set_costs(self, Beam beam, GoldParse gold, int follow_gold=False): + def _set_costs(self, Beam beam, NewExample example, int follow_gold=False): for i in range(beam.size): state = StateClass.borrow(beam.at(i)) if not state.is_final(): self.moves.set_costs(beam.is_valid[i], beam.costs[i], - state, gold) + state, example) if follow_gold: min_cost = 0 for j in range(beam.nr_class): diff --git a/spacy/syntax/nn_parser.pyx b/spacy/syntax/nn_parser.pyx index 7bd9562e2..4e3721cda 100644 --- a/spacy/syntax/nn_parser.pyx +++ b/spacy/syntax/nn_parser.pyx @@ -20,7 +20,6 @@ import numpy import warnings from ..tokens.doc cimport Doc -from ..gold cimport GoldParse from ..typedefs cimport weight_t, class_t, hash_t from ._parser_model cimport alloc_activations, free_activations from ._parser_model cimport predict_states, arg_max_if_valid @@ -567,9 +566,9 @@ cdef class Parser: max_moves = max(max_moves, len(oracle_actions)) return states, golds, max_moves - def get_batch_loss(self, states, golds, float[:, ::1] scores, losses): + def get_batch_loss(self, states, examples, float[:, ::1] scores, losses): cdef StateClass state - cdef GoldParse gold + cdef NewExample example cdef Pool mem = Pool() cdef int i @@ -582,10 +581,10 @@ cdef class Parser: dtype='f', order='C') c_d_scores = d_scores.data unseen_classes = self.model.attrs["unseen_classes"] - for i, (state, gold) in enumerate(zip(states, golds)): + for i, (state, eg) in enumerate(zip(states, examples)): memset(is_valid, 0, self.moves.n_moves * sizeof(int)) memset(costs, 0, self.moves.n_moves * sizeof(float)) - self.moves.set_costs(is_valid, costs, state, gold) + self.moves.set_costs(is_valid, costs, state, eg) for j in range(self.moves.n_moves): if costs[j] <= 0.0 and j in unseen_classes: unseen_classes.remove(j)