Fix embed class in history features

Matthew Honnibal 2017-10-03 13:26:55 +02:00
parent b50a359e11
commit b770f4e108


@@ -21,7 +21,6 @@ from thinc.neural._classes.affine import _set_dimensions_if_needed
 from thinc.api import FeatureExtracter, with_getitem
 from thinc.neural.pooling import Pooling, max_pool, mean_pool, sum_pool
 from thinc.neural._classes.attention import ParametricAttention
-from thinc.neural._classes.embed import Embed
 from thinc.linear.linear import LinearModel
 from thinc.api import uniqued, wrap, flatten_add_lengths, noop
@@ -212,23 +211,60 @@ class PrecomputableMaxouts(Model):
             return dXf
         return Yfp, backward
+# Thinc's Embed class is a bit broken atm, so drop this here.
+from thinc import describe
+from thinc.neural._classes.embed import _uniform_init
+@describe.attributes(
+    nV=describe.Dimension("Number of vectors"),
+    nO=describe.Dimension("Size of output"),
+    vectors=describe.Weights("Embedding table",
+        lambda obj: (obj.nV, obj.nO),
+        _uniform_init(-0.1, 0.1)
+    ),
+    d_vectors=describe.Gradient("vectors")
+)
+class Embed(Model):
+    name = 'embed'
+    def __init__(self, nO, nV=None, **kwargs):
+        Model.__init__(self, **kwargs)
+        self.column = kwargs.get('column', 0)
+        self.nO = nO
+        self.nV = nV
+    def predict(self, ids):
+        if ids.ndim == 2:
+            ids = ids[:, self.column]
+        return self._embed(ids)
+    def begin_update(self, ids, drop=0.):
+        if ids.ndim == 2:
+            ids = ids[:, self.column]
+        vectors = self.vectors[ids]
+        def backprop_embed(d_vectors, sgd=None):
+            n_vectors = d_vectors.shape[0]
+            self.ops.scatter_add(self.d_vectors, ids, d_vectors)
+            if sgd is not None:
+                sgd(self._mem.weights, self._mem.gradient, key=self.id)
+            return None
+        return vectors, backprop_embed
 def HistoryFeatures(nr_class, hist_size=8, nr_dim=8):
     '''Wrap a model, adding features representing action history.'''
-    embed = Embed(nr_dim, nr_dim, nr_class)
+    embed_tables = [Embed(nr_dim, nr_class, column=i) for i in range(hist_size)]
+    embed = concatenate(*embed_tables)
     ops = embed.ops
     def add_history_fwd(vectors_hists, drop=0.):
         vectors, hist_ids = vectors_hists
-        flat_hists, bp_hists = embed.begin_update(hist_ids.flatten(), drop=drop)
-        hists = flat_hists.reshape((hist_ids.shape[0],
-            hist_ids.shape[1] * flat_hists.shape[1]))
-        outputs = ops.xp.hstack((vectors, hists))
+        hist_feats, bp_hists = embed.begin_update(hist_ids)
+        outputs = ops.xp.hstack((vectors, hist_feats))
         def add_history_bwd(d_outputs, sgd=None):
             d_vectors = d_outputs[:, :vectors.shape[1]]
             d_hists = d_outputs[:, vectors.shape[1]:]
-            bp_hists(d_hists.reshape((d_hists.shape[0]*hist_size,
-                int(d_hists.shape[1]/hist_size))), sgd=sgd)
+            bp_hists(d_hists, sgd=sgd)
             return embed.ops.xp.ascontiguousarray(d_vectors)
         return outputs, add_history_bwd
     return wrap(add_history_fwd, embed)
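
For readers new to the pattern, here is a minimal NumPy-only sketch of what a single column embedding in the class above computes; the function names are hypothetical and this is not Thinc's API. The forward pass gathers one table row per id, and the backward pass accumulates each row's gradient back into the table, which is the job of the ops.scatter_add call in backprop_embed.

import numpy as np

def embed_forward(ids, table, column=0):
    # ids: (n,) or (n, n_columns) integer ids; table: (nV, nO) embedding table
    if ids.ndim == 2:
        ids = ids[:, column]
    return table[ids]                      # gather: one table row per id, shape (n, nO)

def embed_backward(d_vectors, ids, table, column=0):
    # Accumulate each row's gradient into a zeroed copy of the table,
    # the NumPy analogue of ops.scatter_add in backprop_embed above.
    if ids.ndim == 2:
        ids = ids[:, column]
    d_table = np.zeros_like(table)
    np.add.at(d_table, ids, d_vectors)     # repeated ids sum their gradients
    return d_table

table = np.random.uniform(-0.1, 0.1, (20, 8))    # nV=20 classes, nO=8 dims
ids = np.array([[3, 5], [3, 7]])                 # two rows, two history columns
vecs = embed_forward(ids, table, column=0)       # (2, 8); both rows look up id 3
d_table = embed_backward(np.ones_like(vecs), ids, table, column=0)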
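
And a similarly hedged sketch of the whole HistoryFeatures composition after the fix, again in plain NumPy with invented names: one table per history column, the per-column lookups concatenated into an (n, hist_size * nr_dim) block, that block stacked next to the state vectors, and the incoming gradient split at the original vector width on the way back so that only the vector part is returned upstream.

import numpy as np

def history_features_forward(vectors, hist_ids, tables):
    # vectors: (n, width) state vectors; hist_ids: (n, hist_size) action ids;
    # tables: one (nr_class, nr_dim) embedding table per history column
    hist_feats = np.hstack([tables[i][hist_ids[:, i]] for i in range(len(tables))])
    return np.hstack((vectors, hist_feats))          # (n, width + hist_size * nr_dim)

def history_features_backward(d_outputs, width, hist_ids, tables):
    d_vectors = d_outputs[:, :width]                 # flows back to the upstream model
    d_hists = d_outputs[:, width:]                   # flows into the embedding tables
    nr_dim = tables[0].shape[1]
    d_tables = [np.zeros_like(t) for t in tables]
    for i, d_table in enumerate(d_tables):
        np.add.at(d_table, hist_ids[:, i], d_hists[:, i * nr_dim:(i + 1) * nr_dim])
    return np.ascontiguousarray(d_vectors), d_tables

rng = np.random.RandomState(0)
tables = [rng.uniform(-0.1, 0.1, (20, 8)) for _ in range(8)]    # nr_class=20, nr_dim=8
vectors = rng.normal(size=(4, 128))                             # 4 states, width 128
hist_ids = rng.randint(0, 20, (4, 8))                           # 8 history slots
outputs = history_features_forward(vectors, hist_ids, tables)   # shape (4, 128 + 64)
d_vectors, d_tables = history_features_backward(
    np.ones_like(outputs), vectors.shape[1], hist_ids, tables)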