Fix history features

This commit is contained in:
Matthew Honnibal 2017-10-03 13:27:10 +02:00
parent b770f4e108
commit 278a4c17c6

View File

@ -70,6 +70,8 @@ from ..attrs cimport ID, TAG, DEP, ORTH, NORM, PREFIX, SUFFIX, TAG
from . import _beam_utils from . import _beam_utils
USE_HISTORY = True USE_HISTORY = True
HIST_SIZE = 2
HIST_DIMS = 16
def get_templates(*args, **kwargs): def get_templates(*args, **kwargs):
return [] return []
@ -262,13 +264,11 @@ cdef class Parser:
with Model.use_device('cpu'): with Model.use_device('cpu'):
if depth == 0: if depth == 0:
hist_size = 8
nr_dim = 8
if USE_HISTORY: if USE_HISTORY:
upper = chain( upper = chain(
HistoryFeatures(nr_class=nr_class, hist_size=hist_size, HistoryFeatures(nr_class=nr_class, hist_size=HIST_SIZE,
nr_dim=nr_dim), nr_dim=HIST_DIMS),
zero_init(Affine(nr_class, nr_class+hist_size*nr_dim, zero_init(Affine(nr_class, nr_class+HIST_SIZE*HIST_DIMS,
drop_factor=0.0))) drop_factor=0.0)))
upper.is_noop = False upper.is_noop = False
else: else:
@ -736,15 +736,13 @@ cdef class Parser:
cdef StateClass state cdef StateClass state
cdef int[500] is_valid # TODO: Unhack cdef int[500] is_valid # TODO: Unhack
cdef float* c_scores = &scores[0, 0] cdef float* c_scores = &scores[0, 0]
hists = []
for state in states: for state in states:
self.moves.set_valid(is_valid, state.c) self.moves.set_valid(is_valid, state.c)
guess = arg_max_if_valid(c_scores, is_valid, scores.shape[1]) guess = arg_max_if_valid(c_scores, is_valid, scores.shape[1])
action = self.moves.c[guess] action = self.moves.c[guess]
action.do(state.c, action.label) action.do(state.c, action.label)
c_scores += scores.shape[1] c_scores += scores.shape[1]
hists.append(guess) state.c.push_hist(guess)
return hists
def get_batch_loss(self, states, golds, float[:, ::1] scores): def get_batch_loss(self, states, golds, float[:, ::1] scores):
cdef StateClass state cdef StateClass state