More GoldParse excise

This commit is contained in:
Matthew Honnibal 2020-06-14 17:26:54 +02:00
parent 60d4e5a9e0
commit 9296d71a54
2 changed files with 6 additions and 8 deletions

View File

@ -9,7 +9,6 @@ import numpy
from ..typedefs cimport hash_t, class_t from ..typedefs cimport hash_t, class_t
from .transition_system cimport TransitionSystem, Transition from .transition_system cimport TransitionSystem, Transition
from ..gold cimport GoldParse
from .stateclass cimport StateC, StateClass from .stateclass cimport StateC, StateClass
from ..errors import Errors from ..errors import Errors
@ -126,12 +125,12 @@ cdef class ParserBeam(object):
beam.scores[i][j] = 0 beam.scores[i][j] = 0
beam.costs[i][j] = 0 beam.costs[i][j] = 0
def _set_costs(self, Beam beam, GoldParse gold, int follow_gold=False): def _set_costs(self, Beam beam, NewExample example, int follow_gold=False):
for i in range(beam.size): for i in range(beam.size):
state = StateClass.borrow(<StateC*>beam.at(i)) state = StateClass.borrow(<StateC*>beam.at(i))
if not state.is_final(): if not state.is_final():
self.moves.set_costs(beam.is_valid[i], beam.costs[i], self.moves.set_costs(beam.is_valid[i], beam.costs[i],
state, gold) state, example)
if follow_gold: if follow_gold:
min_cost = 0 min_cost = 0
for j in range(beam.nr_class): for j in range(beam.nr_class):

View File

@ -20,7 +20,6 @@ import numpy
import warnings import warnings
from ..tokens.doc cimport Doc from ..tokens.doc cimport Doc
from ..gold cimport GoldParse
from ..typedefs cimport weight_t, class_t, hash_t from ..typedefs cimport weight_t, class_t, hash_t
from ._parser_model cimport alloc_activations, free_activations from ._parser_model cimport alloc_activations, free_activations
from ._parser_model cimport predict_states, arg_max_if_valid from ._parser_model cimport predict_states, arg_max_if_valid
@ -567,9 +566,9 @@ cdef class Parser:
max_moves = max(max_moves, len(oracle_actions)) max_moves = max(max_moves, len(oracle_actions))
return states, golds, max_moves return states, golds, max_moves
def get_batch_loss(self, states, golds, float[:, ::1] scores, losses): def get_batch_loss(self, states, examples, float[:, ::1] scores, losses):
cdef StateClass state cdef StateClass state
cdef GoldParse gold cdef NewExample example
cdef Pool mem = Pool() cdef Pool mem = Pool()
cdef int i cdef int i
@ -582,10 +581,10 @@ cdef class Parser:
dtype='f', order='C') dtype='f', order='C')
c_d_scores = <float*>d_scores.data c_d_scores = <float*>d_scores.data
unseen_classes = self.model.attrs["unseen_classes"] unseen_classes = self.model.attrs["unseen_classes"]
for i, (state, gold) in enumerate(zip(states, golds)): for i, (state, eg) in enumerate(zip(states, examples)):
memset(is_valid, 0, self.moves.n_moves * sizeof(int)) memset(is_valid, 0, self.moves.n_moves * sizeof(int))
memset(costs, 0, self.moves.n_moves * sizeof(float)) memset(costs, 0, self.moves.n_moves * sizeof(float))
self.moves.set_costs(is_valid, costs, state, gold) self.moves.set_costs(is_valid, costs, state, eg)
for j in range(self.moves.n_moves): for j in range(self.moves.n_moves):
if costs[j] <= 0.0 and j in unseen_classes: if costs[j] <= 0.0 and j in unseen_classes:
unseen_classes.remove(j) unseen_classes.remove(j)