mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Fix history features
This commit is contained in:
parent
b770f4e108
commit
278a4c17c6
|
@ -70,6 +70,8 @@ from ..attrs cimport ID, TAG, DEP, ORTH, NORM, PREFIX, SUFFIX, TAG
|
|||
from . import _beam_utils
|
||||
|
||||
USE_HISTORY = True
|
||||
HIST_SIZE = 2
|
||||
HIST_DIMS = 16
|
||||
|
||||
def get_templates(*args, **kwargs):
|
||||
return []
|
||||
|
@ -262,13 +264,11 @@ cdef class Parser:
|
|||
|
||||
with Model.use_device('cpu'):
|
||||
if depth == 0:
|
||||
hist_size = 8
|
||||
nr_dim = 8
|
||||
if USE_HISTORY:
|
||||
upper = chain(
|
||||
HistoryFeatures(nr_class=nr_class, hist_size=hist_size,
|
||||
nr_dim=nr_dim),
|
||||
zero_init(Affine(nr_class, nr_class+hist_size*nr_dim,
|
||||
HistoryFeatures(nr_class=nr_class, hist_size=HIST_SIZE,
|
||||
nr_dim=HIST_DIMS),
|
||||
zero_init(Affine(nr_class, nr_class+HIST_SIZE*HIST_DIMS,
|
||||
drop_factor=0.0)))
|
||||
upper.is_noop = False
|
||||
else:
|
||||
|
@ -736,15 +736,13 @@ cdef class Parser:
|
|||
cdef StateClass state
|
||||
cdef int[500] is_valid # TODO: Unhack
|
||||
cdef float* c_scores = &scores[0, 0]
|
||||
hists = []
|
||||
for state in states:
|
||||
self.moves.set_valid(is_valid, state.c)
|
||||
guess = arg_max_if_valid(c_scores, is_valid, scores.shape[1])
|
||||
action = self.moves.c[guess]
|
||||
action.do(state.c, action.label)
|
||||
c_scores += scores.shape[1]
|
||||
hists.append(guess)
|
||||
return hists
|
||||
state.c.push_hist(guess)
|
||||
|
||||
def get_batch_loss(self, states, golds, float[:, ::1] scores):
|
||||
cdef StateClass state
|
||||
|
|
Loading…
Reference in New Issue
Block a user