diff --git a/spacy/ml/tb_framework.py b/spacy/ml/tb_framework.py
index fa796f21e..fb62828f3 100644
--- a/spacy/ml/tb_framework.py
+++ b/spacy/ml/tb_framework.py
@@ -1,5 +1,6 @@
 from typing import List, Tuple, Any, Optional
 from thinc.api import Ops, Model, normal_init, chain, list2array, Linear
+from thinc.api import uniform_init
 from thinc.types import Floats1d, Floats2d, Floats3d, Ints2d, Floats4d
 import numpy
 from ..tokens.doc import Doc
@@ -27,7 +28,7 @@ def TransitionModel(
 
     return Model(
         name="parser_model",
-        forward=forward,
+        forward=_forward_reference,
         init=init,
         layers=[tok2vec_projected],
         refs={"tok2vec": tok2vec_projected},
@@ -184,6 +185,137 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: boo
         d_statevecs = model.ops.gemm(d_scores, upper_W)
         # Backprop through the maxout activation
         d_preacts = model.ops.backprop_maxout(d_statevecs, which, model.get_dim("nP"))
+        d_preacts2f = model.ops.reshape2f(d_preacts, d_preacts.shape[0], -1)
+        model.inc_grad("lower_b", d_preacts2f.sum(axis=0))
+        model.inc_grad("lower_W", model.ops.gemm(d_preacts2f, tokfeats, trans1=True))
+        d_tokfeats = model.ops.gemm(d_preacts2f, lower_W)
+        d_tokfeats3f = model.ops.reshape3f(d_tokfeats, nS, nF, nI)
+        d_lower_pad = model.ops.alloc2f(nF, nI)
+        for i in range(ids.shape[0]):
+            for j in range(ids.shape[1]):
+                if ids[i, j] == -1:
+                    d_lower_pad[j] += d_tokfeats3f[i, j]
+                else:
+                    d_tokvecs[ids[i, j]] += d_tokfeats3f[i, j]
+        model.inc_grad("lower_pad", d_lower_pad)
+        # We don't need to backprop the summation, because we pass back the IDs instead
+        # d_state_features = backprop_feats((d_preacts, all_ids))
+        # ids1d = model.ops.xp.vstack(all_ids).flatten()
+        # d_state_features = d_state_features.reshape((ids1d.size, -1))
+        # d_tokvecs = model.ops.alloc((tokvecs.shape[0] + 1, tokvecs.shape[1]))
+        # model.ops.scatter_add(d_tokvecs, ids1d, d_state_features)
+        return (backprop_tok2vec(d_tokvecs), None)
+
+    return (states, all_scores), backprop_parser
+
+
+
+def _forward_reference(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: bool):
+    """Slow reference implementation, without the precomputation"""
+    nF = model.get_dim("nF")
+    tok2vec = model.get_ref("tok2vec")
+    lower_pad = model.get_param("lower_pad")
+    lower_W = model.get_param("lower_W")
+    lower_b = model.get_param("lower_b")
+    upper_W = model.get_param("upper_W")
+    upper_b = model.get_param("upper_b")
+    nH = model.get_dim("nH")
+    nP = model.get_dim("nP")
+    nO = model.get_dim("nO")
+    nI = model.get_dim("nI")
+
+    ops = model.ops
+    docs, moves = docs_moves
+    states = moves.init_batch(docs)
+    tokvecs, backprop_tok2vec = tok2vec(docs, is_train)
+    all_ids = []
+    all_which = []
+    all_statevecs = []
+    all_scores = []
+    all_tokfeats = []
+    next_states = [s for s in states if not s.is_final()]
+    unseen_mask = _get_unseen_mask(model)
+    assert unseen_mask.all()  # TODO unhack
+    ids = numpy.zeros((len(states), nF), dtype="i")
+    while next_states:
+        ids = ids[: len(next_states)]
+        for i, state in enumerate(next_states):
+            state.set_context_tokens(ids, i, nF)
+        # Sum the state features, add the bias and apply the activation (maxout)
+        # to create the state vectors.
+        tokfeats3f = model.ops.alloc3f(ids.shape[0], nF, nI)
+        for i in range(ids.shape[0]):
+            for j in range(nF):
+                if ids[i, j] == -1:
+                    tokfeats3f[i, j] = lower_pad
+                else:
+                    tokfeats3f[i, j] = tokvecs[ids[i, j]]
+        tokfeats = model.ops.reshape2f(tokfeats3f, tokfeats3f.shape[0], -1)
+        preacts2f = model.ops.gemm(tokfeats, lower_W, trans2=True)
+        preacts2f += lower_b
+        preacts = model.ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP)
+        statevecs, which = ops.maxout(preacts)
+        # Multiply the state-vector by the scores weights and add the bias,
+        # to get the logits.
+        scores = model.ops.gemm(statevecs, upper_W, trans2=True)
+        scores += upper_b
+        scores[:, unseen_mask == 0] = model.ops.xp.nanmin(scores)
+        # Transition the states, filtering out any that are finished.
+        next_states = moves.transition_states(next_states, scores)
+        all_scores.append(scores)
+        if is_train:
+            # Remember intermediate results for the backprop.
+            all_tokfeats.append(tokfeats)
+            all_ids.append(ids.copy())
+            all_statevecs.append(statevecs)
+            all_which.append(which)
+
+    nS = sum(len(s.history) for s in states)
+
+    def backprop_parser(d_states_d_scores):
+        d_tokvecs = model.ops.alloc2f(tokvecs.shape[0], tokvecs.shape[1])
+        ids = model.ops.xp.vstack(all_ids)
+        which = ops.xp.vstack(all_which)
+        statevecs = model.ops.xp.vstack(all_statevecs)
+        tokfeats = model.ops.xp.vstack(all_tokfeats)
+        _, d_scores = d_states_d_scores
+        if model.attrs.get("unseen_classes"):
+            # If we have a negative gradient (i.e. the probability should
+            # increase) on any classes we filtered out as unseen, mark
+            # them as seen.
+            for clas in set(model.attrs["unseen_classes"]):
+                if (d_scores[:, clas] < 0).any():
+                    model.attrs["unseen_classes"].remove(clas)
+        d_scores *= unseen_mask
+        assert statevecs.shape == (nS, nH), statevecs.shape
+        assert d_scores.shape == (nS, nO), d_scores.shape
+        # Calculate the gradients for the parameters of the upper layer.
+        # The weight gemm is (nS, nO).T @ (nS, nH)
+        model.inc_grad("upper_b", d_scores.sum(axis=0))
+        model.inc_grad("upper_W", model.ops.gemm(d_scores, statevecs, trans1=True))
+        # Now calculate d_statevecs, by backpropagating through the upper linear layer.
+        # This gemm is (nS, nO) @ (nO, nH)
+        d_statevecs = model.ops.gemm(d_scores, upper_W)
+        # Backprop through the maxout activation
+        d_preacts = model.ops.backprop_maxout(d_statevecs, which, nP)
+        d_preacts2f = model.ops.reshape2f(d_preacts, d_preacts.shape[0], nH*nP)
+        # Now increment the gradients for the lower layer.
+        # The gemm here is (nS, nH*nP).T @ (nS, nF*nI)
+        model.inc_grad("lower_b", d_preacts2f.sum(axis=0))
+        model.inc_grad("lower_W", model.ops.gemm(d_preacts2f, tokfeats, trans1=True))
+        # Calculate d_tokfeats
+        # The gemm here is (nS, nH*nP) @ (nH*nP, nF*nI)
+        d_tokfeats = model.ops.gemm(d_preacts2f, lower_W)
+        # Get the gradients of the tokvecs and the padding
+        d_tokfeats3f = model.ops.reshape3f(d_tokfeats, nS, nF, nI)
+        d_lower_pad = model.ops.alloc1f(nI)
+        for i in range(ids.shape[0]):
+            for j in range(ids.shape[1]):
+                if ids[i, j] == -1:
+                    d_lower_pad += d_tokfeats3f[i, j]
+                else:
+                    d_tokvecs[ids[i, j]] += d_tokfeats3f[i, j]
+        model.inc_grad("lower_pad", d_lower_pad)
         # We don't need to backprop the summation, because we pass back the IDs instead
         d_state_features = backprop_feats((d_preacts, all_ids))
         ids1d = model.ops.xp.vstack(all_ids).flatten()
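
For reviewers who want to sanity-check the gemm shapes called out in the comments above, below is a minimal NumPy sketch of the same forward/backward shape algebra. This is not spaCy's or thinc's API, just plain-NumPy stand-ins for the gemm/maxout/backprop_maxout calls, with toy dimensions and random placeholder data:

    import numpy as np

    nS, nF, nI, nH, nP, nO = 6, 3, 4, 5, 2, 7  # toy dimensions
    rng = np.random.default_rng(0)

    tokfeats = rng.standard_normal((nS, nF * nI))    # concatenated context-token features
    lower_W = rng.standard_normal((nH * nP, nF * nI))
    lower_b = rng.standard_normal(nH * nP)
    upper_W = rng.standard_normal((nO, nH))
    upper_b = rng.standard_normal(nO)

    # Forward, lower layer: gemm with trans2=True, add bias, maxout over nP pieces.
    preacts = (tokfeats @ lower_W.T + lower_b).reshape(nS, nH, nP)
    which = preacts.argmax(axis=-1)                  # winning piece per hidden unit
    statevecs = np.take_along_axis(preacts, which[..., None], axis=-1)[..., 0]  # (nS, nH)

    # Forward, upper layer: the logits.
    scores = statevecs @ upper_W.T + upper_b         # (nS, nO)

    # Backward, given some upstream gradient on the scores.
    d_scores = rng.standard_normal((nS, nO))
    d_upper_b = d_scores.sum(axis=0)
    d_upper_W = d_scores.T @ statevecs               # trans1=True gemm: (nO, nS) @ (nS, nH)
    d_statevecs = d_scores @ upper_W                 # (nS, nO) @ (nO, nH)

    # Stand-in for backprop_maxout: the gradient is routed only to the
    # winning piece of each hidden unit, all other pieces get zero.
    d_preacts = np.zeros((nS, nH, nP))
    np.put_along_axis(d_preacts, which[..., None], d_statevecs[..., None], axis=-1)
    d_preacts2f = d_preacts.reshape(nS, nH * nP)

    d_lower_b = d_preacts2f.sum(axis=0)
    d_lower_W = d_preacts2f.T @ tokfeats             # (nH*nP, nS) @ (nS, nF*nI)
    d_tokfeats = d_preacts2f @ lower_W               # (nS, nH*nP) @ (nH*nP, nF*nI)

    assert d_lower_W.shape == lower_W.shape
    assert d_tokfeats.shape == tokfeats.shape

Here the "which" array plays the role of the maxout bookkeeping that ops.maxout returns and ops.backprop_maxout consumes; the remaining step in the diff (scattering d_tokfeats back into d_tokvecs and d_lower_pad by token ID) is the loop shown at the end of backprop_parser.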