spaCy/spacy/ml/_precomputable_affine.py
Matthew Honnibal 333b1a308b
Adapt parser and NER for transformers (#5449)
Co-authored-by: Ines Montani <ines@ines.io>
2020-05-18 22:23:33 +02:00


from thinc.api import Model, normal_init


def PrecomputableAffine(nO, nI, nF, nP):
    model = Model(
        "precomputable_affine",
        forward,
        init=init,
        dims={"nO": nO, "nI": nI, "nF": nF, "nP": nP},
        params={"W": None, "b": None, "pad": None},
    )
    return model
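
# Example of how this layer might be wired up (an illustrative sketch only;
# the dimension values below are made up, not taken from spaCy's parser config):
#
#     model = PrecomputableAffine(nO=64, nI=96, nF=6, nP=3)
#     model.initialize()
#     Yf = model.predict(X)  # X: (n_tokens, nI) -> Yf: (n_tokens + 1, nF, nO, nP)
#
# The extra leading row of Yf is the "pad" parameter that forward() prepends,
# giving downstream code a filler row to use for missing features.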


def forward(model, X, is_train):
    nF = model.get_dim("nF")
    nO = model.get_dim("nO")
    nP = model.get_dim("nP")
    nI = model.get_dim("nI")
    W = model.get_param("W")
    Yf = model.ops.gemm(X, W.reshape((nF * nO * nP, nI)), trans2=True)
    Yf = Yf.reshape((Yf.shape[0], nF, nO, nP))
    Yf = model.ops.xp.vstack((model.get_param("pad"), Yf))
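    # The single gemm above computes every token's contribution to each of the
    # nF feature slots at once. As a plain-numpy sketch (illustrative only,
    # not how it is executed here):
    #
    #     Yf[b, f, o, p] == (X[b] * W[f, o, p]).sum()
    #     # i.e. numpy.einsum("bi,fopi->bfop", X, W)
    #
    # Precomputing these rows per token lets the parser assemble state vectors
    # later by summing the relevant rows, rather than re-running the affine
    # layer for every parser state.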

    def backward(dY_ids):
        # This backprop is particularly tricky, because we get back a different
        # thing from what we put out. We put out an array of shape:
        # (nB, nF, nO, nP), and get back:
        # (nB, nO, nP) and ids (nB, nF)
        # The ids tell us the values of nF, so we would have:
        #
        # dYf = zeros((nB, nF, nO, nP))
        # for b in range(nB):
        #     for f in range(nF):
        #         dYf[b, ids[b, f]] += dY[b]
        #
        # However, we avoid building that array for efficiency -- and just pass
        # in the indices.
        dY, ids = dY_ids
        assert dY.ndim == 3
        assert dY.shape[1] == nO, dY.shape
        assert dY.shape[2] == nP, dY.shape
        # nB = dY.shape[0]
        model.inc_grad("pad", _backprop_precomputable_affine_padding(model, dY, ids))
        Xf = X[ids]
        Xf = Xf.reshape((Xf.shape[0], nF * nI))

        model.inc_grad("b", dY.sum(axis=0))
        dY = dY.reshape((dY.shape[0], nO * nP))

        Wopfi = W.transpose((1, 2, 0, 3))
        Wopfi = model.ops.xp.ascontiguousarray(Wopfi)
        Wopfi = Wopfi.reshape((nO * nP, nF * nI))
        dXf = model.ops.gemm(dY.reshape((dY.shape[0], nO * nP)), Wopfi)

        # Reuse the buffer
        dWopfi = Wopfi
        dWopfi.fill(0.0)
        model.ops.gemm(dY, Xf, out=dWopfi, trans1=True)
        dWopfi = dWopfi.reshape((nO, nP, nF, nI))
        # (o, p, f, i) --> (f, o, p, i)
        model.inc_grad("W", dWopfi.transpose((2, 0, 1, 3)))
        return dXf.reshape((dXf.shape[0], nF, nI))

    return Yf, backward
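
# For reference, the reshaped gemm calls in backward() correspond to the usual
# affine-layer gradients. As an illustrative numpy sketch (assuming dY has
# shape (nB, nO, nP) and Xf has shape (nB, nF, nI); not executed here):
#
#     dXf = numpy.einsum("bop,fopi->bfi", dY, W)
#     dW  = numpy.einsum("bop,bfi->fopi", dY, Xf)
#
# Flattening the (nF, nO, nP) axes keeps each of these a single BLAS call.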


def _backprop_precomputable_affine_padding(model, dY, ids):
    nB = dY.shape[0]
    nF = model.get_dim("nF")
    nP = model.get_dim("nP")
    nO = model.get_dim("nO")
    # Backprop the "padding", used as a filler for missing values.
    # Values that are missing are set to -1, and each state vector could
    # have multiple missing values. The padding has different values for
    # different missing features. The gradient of the padding vector is:
    #
    # for b in range(nB):
    #     for f in range(nF):
    #         if ids[b, f] < 0:
    #             d_pad[f] += dY[b]
    #
    # Which can be rewritten as:
    #
    # (ids < 0).T @ dY
    mask = model.ops.asarray(ids < 0, dtype="f")
    d_pad = model.ops.gemm(mask, dY.reshape(nB, nO * nP), trans1=True)
    return d_pad.reshape((1, nF, nO, nP))
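
# A quick way to sanity-check the vectorised padding gradient against the
# naive loop from the comment above (a hypothetical test sketch, not part of
# the module):
#
#     naive = numpy.zeros((nF, nO, nP), dtype="f")
#     for b in range(nB):
#         for f in range(nF):
#             if ids[b, f] < 0:
#                 naive[f] += dY[b]
#     numpy.testing.assert_allclose(
#         naive.reshape((1, nF, nO, nP)),
#         _backprop_precomputable_affine_padding(model, dY, ids),
#         rtol=1e-5,
#     )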


def init(model, X=None, Y=None):
    """This is like the 'layer sequential unit variance', but instead
    of taking the actual inputs, we randomly generate whitened data.

    Why's this all so complicated? We have a huge number of inputs,
    and the maxout unit makes guessing the dynamics tricky. Instead
    we set the maxout weights to values that empirically result in
    whitened outputs given whitened inputs.
    """
    if model.has_param("W") and model.get_param("W").any():
        return

    nF = model.get_dim("nF")
    nO = model.get_dim("nO")
    nP = model.get_dim("nP")
    nI = model.get_dim("nI")

    W = model.ops.alloc4f(nF, nO, nP, nI)
    b = model.ops.alloc2f(nO, nP)
    pad = model.ops.alloc4f(1, nF, nO, nP)

    ops = model.ops
    W = normal_init(ops, W.shape, mean=float(ops.xp.sqrt(1.0 / nF * nI)))
    model.set_param("W", W)
    model.set_param("b", b)
    model.set_param("pad", pad)
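
    # Generate random calibration data: fake feature ids (random integers) and
    # token vectors drawn from a standard normal, standing in for the whitened
    # inputs the docstring describes. Real parser states aren't available at
    # init time.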
    ids = ops.alloc((5000, nF), dtype="f")
    ids += ops.xp.random.uniform(0, 1000, ids.shape)
    ids = ops.asarray(ids, dtype="i")
    tokvecs = ops.alloc((5000, nI), dtype="f")
    tokvecs += ops.xp.random.normal(loc=0.0, scale=1.0, size=tokvecs.size).reshape(
        tokvecs.shape
    )

    def predict(ids, tokvecs):
        # nS ids. nW tokvecs. Exclude the padding array.
        hiddens = model.predict(tokvecs[:-1])  # (nW, f, o, p)
        vectors = model.ops.alloc((ids.shape[0], nO * nP), dtype="f")
        # need nS vectors
        hiddens = hiddens.reshape((hiddens.shape[0] * nF, nO * nP))
        model.ops.scatter_add(vectors, ids.flatten(), hiddens)
        vectors = vectors.reshape((vectors.shape[0], nO, nP))
        vectors += b
        vectors = model.ops.asarray(vectors)
        if nP >= 2:
            return model.ops.maxout(vectors)[0]
        else:
            return vectors * (vectors >= 0)
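
    # Calibration loop: rescale W until the simulated activations have roughly
    # unit variance, then shift b until they have roughly zero mean, stopping
    # early once both are within tolerance (or after t_max rounds).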
    tol_var = 0.01
    tol_mean = 0.01
    t_max = 10
    W = model.get_param("W").copy()
    b = model.get_param("b").copy()
    for t_i in range(t_max):
        acts1 = predict(ids, tokvecs)
        var = model.ops.xp.var(acts1)
        mean = model.ops.xp.mean(acts1)
        if abs(var - 1.0) >= tol_var:
            W /= model.ops.xp.sqrt(var)
            model.set_param("W", W)
        elif abs(mean) >= tol_mean:
            b -= mean
            model.set_param("b", b)
        else:
            break