Fix Embed and HistoryFeatures

This commit is contained in:
Matthew Honnibal 2017-10-04 19:55:34 -05:00
parent 246612cb53
commit 92066b04d6

View File

@ -231,6 +231,8 @@ class Embed(Model):
def __init__(self, nO, nV=None, **kwargs):
Model.__init__(self, **kwargs)
if 'name' in kwargs:
self.name = kwargs['name']
self.column = kwargs.get('column', 0)
self.nO = nO
self.nV = nV
@ -238,12 +240,12 @@ class Embed(Model):
def predict(self, ids):
if ids.ndim == 2:
ids = ids[:, self.column]
return self._embed(ids)
return self.ops.xp.ascontiguousarray(self.vectors[ids])
def begin_update(self, ids, drop=0.):
if ids.ndim == 2:
ids = ids[:, self.column]
vectors = self.vectors[ids]
vectors = self.ops.xp.ascontiguousarray(self.vectors[ids])
def backprop_embed(d_vectors, sgd=None):
n_vectors = d_vectors.shape[0]
self.ops.scatter_add(self.d_vectors, ids, d_vectors)
@ -255,7 +257,8 @@ class Embed(Model):
def HistoryFeatures(nr_class, hist_size=8, nr_dim=8):
'''Wrap a model, adding features representing action history.'''
embed_tables = [Embed(nr_dim, nr_class, column=i) for i in range(hist_size)]
embed_tables = [Embed(nr_dim, nr_class, column=i, name='embed%d')
for i in range(hist_size)]
embed = concatenate(*embed_tables)
ops = embed.ops
def add_history_fwd(vectors_hists, drop=0.):