mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-06 06:30:35 +03:00
More GoldParse excise
This commit is contained in:
parent
60d4e5a9e0
commit
9296d71a54
|
@ -9,7 +9,6 @@ import numpy
|
|||
|
||||
from ..typedefs cimport hash_t, class_t
|
||||
from .transition_system cimport TransitionSystem, Transition
|
||||
from ..gold cimport GoldParse
|
||||
from .stateclass cimport StateC, StateClass
|
||||
|
||||
from ..errors import Errors
|
||||
|
@ -126,12 +125,12 @@ cdef class ParserBeam(object):
|
|||
beam.scores[i][j] = 0
|
||||
beam.costs[i][j] = 0
|
||||
|
||||
def _set_costs(self, Beam beam, GoldParse gold, int follow_gold=False):
|
||||
def _set_costs(self, Beam beam, NewExample example, int follow_gold=False):
|
||||
for i in range(beam.size):
|
||||
state = StateClass.borrow(<StateC*>beam.at(i))
|
||||
if not state.is_final():
|
||||
self.moves.set_costs(beam.is_valid[i], beam.costs[i],
|
||||
state, gold)
|
||||
state, example)
|
||||
if follow_gold:
|
||||
min_cost = 0
|
||||
for j in range(beam.nr_class):
|
||||
|
|
|
@ -20,7 +20,6 @@ import numpy
|
|||
import warnings
|
||||
|
||||
from ..tokens.doc cimport Doc
|
||||
from ..gold cimport GoldParse
|
||||
from ..typedefs cimport weight_t, class_t, hash_t
|
||||
from ._parser_model cimport alloc_activations, free_activations
|
||||
from ._parser_model cimport predict_states, arg_max_if_valid
|
||||
|
@ -567,9 +566,9 @@ cdef class Parser:
|
|||
max_moves = max(max_moves, len(oracle_actions))
|
||||
return states, golds, max_moves
|
||||
|
||||
def get_batch_loss(self, states, golds, float[:, ::1] scores, losses):
|
||||
def get_batch_loss(self, states, examples, float[:, ::1] scores, losses):
|
||||
cdef StateClass state
|
||||
cdef GoldParse gold
|
||||
cdef NewExample example
|
||||
cdef Pool mem = Pool()
|
||||
cdef int i
|
||||
|
||||
|
@ -582,10 +581,10 @@ cdef class Parser:
|
|||
dtype='f', order='C')
|
||||
c_d_scores = <float*>d_scores.data
|
||||
unseen_classes = self.model.attrs["unseen_classes"]
|
||||
for i, (state, gold) in enumerate(zip(states, golds)):
|
||||
for i, (state, eg) in enumerate(zip(states, examples)):
|
||||
memset(is_valid, 0, self.moves.n_moves * sizeof(int))
|
||||
memset(costs, 0, self.moves.n_moves * sizeof(float))
|
||||
self.moves.set_costs(is_valid, costs, state, gold)
|
||||
self.moves.set_costs(is_valid, costs, state, eg)
|
||||
for j in range(self.moves.n_moves):
|
||||
if costs[j] <= 0.0 and j in unseen_classes:
|
||||
unseen_classes.remove(j)
|
||||
|
|
Loading…
Reference in New Issue
Block a user