mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-06 14:40:34 +03:00
More GoldParse excise
This commit is contained in:
parent
60d4e5a9e0
commit
9296d71a54
|
@ -9,7 +9,6 @@ import numpy
|
||||||
|
|
||||||
from ..typedefs cimport hash_t, class_t
|
from ..typedefs cimport hash_t, class_t
|
||||||
from .transition_system cimport TransitionSystem, Transition
|
from .transition_system cimport TransitionSystem, Transition
|
||||||
from ..gold cimport GoldParse
|
|
||||||
from .stateclass cimport StateC, StateClass
|
from .stateclass cimport StateC, StateClass
|
||||||
|
|
||||||
from ..errors import Errors
|
from ..errors import Errors
|
||||||
|
@ -126,12 +125,12 @@ cdef class ParserBeam(object):
|
||||||
beam.scores[i][j] = 0
|
beam.scores[i][j] = 0
|
||||||
beam.costs[i][j] = 0
|
beam.costs[i][j] = 0
|
||||||
|
|
||||||
def _set_costs(self, Beam beam, GoldParse gold, int follow_gold=False):
|
def _set_costs(self, Beam beam, NewExample example, int follow_gold=False):
|
||||||
for i in range(beam.size):
|
for i in range(beam.size):
|
||||||
state = StateClass.borrow(<StateC*>beam.at(i))
|
state = StateClass.borrow(<StateC*>beam.at(i))
|
||||||
if not state.is_final():
|
if not state.is_final():
|
||||||
self.moves.set_costs(beam.is_valid[i], beam.costs[i],
|
self.moves.set_costs(beam.is_valid[i], beam.costs[i],
|
||||||
state, gold)
|
state, example)
|
||||||
if follow_gold:
|
if follow_gold:
|
||||||
min_cost = 0
|
min_cost = 0
|
||||||
for j in range(beam.nr_class):
|
for j in range(beam.nr_class):
|
||||||
|
|
|
@ -20,7 +20,6 @@ import numpy
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from ..tokens.doc cimport Doc
|
from ..tokens.doc cimport Doc
|
||||||
from ..gold cimport GoldParse
|
|
||||||
from ..typedefs cimport weight_t, class_t, hash_t
|
from ..typedefs cimport weight_t, class_t, hash_t
|
||||||
from ._parser_model cimport alloc_activations, free_activations
|
from ._parser_model cimport alloc_activations, free_activations
|
||||||
from ._parser_model cimport predict_states, arg_max_if_valid
|
from ._parser_model cimport predict_states, arg_max_if_valid
|
||||||
|
@ -567,9 +566,9 @@ cdef class Parser:
|
||||||
max_moves = max(max_moves, len(oracle_actions))
|
max_moves = max(max_moves, len(oracle_actions))
|
||||||
return states, golds, max_moves
|
return states, golds, max_moves
|
||||||
|
|
||||||
def get_batch_loss(self, states, golds, float[:, ::1] scores, losses):
|
def get_batch_loss(self, states, examples, float[:, ::1] scores, losses):
|
||||||
cdef StateClass state
|
cdef StateClass state
|
||||||
cdef GoldParse gold
|
cdef NewExample example
|
||||||
cdef Pool mem = Pool()
|
cdef Pool mem = Pool()
|
||||||
cdef int i
|
cdef int i
|
||||||
|
|
||||||
|
@ -582,10 +581,10 @@ cdef class Parser:
|
||||||
dtype='f', order='C')
|
dtype='f', order='C')
|
||||||
c_d_scores = <float*>d_scores.data
|
c_d_scores = <float*>d_scores.data
|
||||||
unseen_classes = self.model.attrs["unseen_classes"]
|
unseen_classes = self.model.attrs["unseen_classes"]
|
||||||
for i, (state, gold) in enumerate(zip(states, golds)):
|
for i, (state, eg) in enumerate(zip(states, examples)):
|
||||||
memset(is_valid, 0, self.moves.n_moves * sizeof(int))
|
memset(is_valid, 0, self.moves.n_moves * sizeof(int))
|
||||||
memset(costs, 0, self.moves.n_moves * sizeof(float))
|
memset(costs, 0, self.moves.n_moves * sizeof(float))
|
||||||
self.moves.set_costs(is_valid, costs, state, gold)
|
self.moves.set_costs(is_valid, costs, state, eg)
|
||||||
for j in range(self.moves.n_moves):
|
for j in range(self.moves.n_moves):
|
||||||
if costs[j] <= 0.0 and j in unseen_classes:
|
if costs[j] <= 0.0 and j in unseen_classes:
|
||||||
unseen_classes.remove(j)
|
unseen_classes.remove(j)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user