diff --git a/spacy/ml/tb_framework.py b/spacy/ml/tb_framework.py
index 55eaefec9..753c99cb9 100644
--- a/spacy/ml/tb_framework.py
+++ b/spacy/ml/tb_framework.py
@@ -30,7 +30,7 @@ def TransitionModel(
 
     return Model(
         name="parser_model",
-        forward=_forward_reference,
+        forward=forward,
         init=init,
         layers=[tok2vec_projected],
         refs={"tok2vec": tok2vec_projected},
@@ -113,7 +113,7 @@ def init(
     Wu = ops.alloc2f(nO, nH)
     bu = ops.alloc1f(nO)
     Wu = zero_init(ops, Wu.shape)
-    #Wl = zero_init(ops, Wl.shape)
+    # Wl = zero_init(ops, Wl.shape)
     Wl = glorot_uniform_init(ops, Wl.shape)
     padl = uniform_init(ops, padl.shape)  # type: ignore
     # TODO: Experiment with whether better to initialize upper_W
@@ -143,12 +143,12 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: boo
     docs, moves = docs_moves
     states = moves.init_batch(docs)
     tokvecs, backprop_tok2vec = tok2vec(docs, is_train)
+    tokvecs = model.ops.xp.vstack((tokvecs, lower_pad))
     feats, backprop_feats = _forward_precomputable_affine(model, tokvecs, is_train)
     all_ids = []
     all_which = []
     all_statevecs = []
     all_scores = []
-    all_tokfeats = []
     next_states = [s for s in states if not s.is_final()]
     unseen_mask = _get_unseen_mask(model)
     ids = numpy.zeros((len(states), nF), dtype="i")
@@ -157,11 +157,16 @@ def forward(model, docs_moves: Tuple[List[Doc], TransitionSystem], is_train: boo
         ids = ids[: len(next_states)]
         for i, state in enumerate(next_states):
             state.set_context_tokens(ids, i, nF)
-        preacts = feats[ids, arange].sum(axis=1)  # type: ignore
+        # Sum the state features, add the bias and apply the activation (maxout)
+        # to create the state vectors.
+        preacts2f = feats[ids, arange].sum(axis=1)  # type: ignore
+        preacts2f += lower_b
+        preacts = model.ops.reshape3f(preacts2f, preacts2f.shape[0], nH, nP)
+        assert preacts.shape[0] == len(next_states), preacts.shape
         statevecs, which = ops.maxout(preacts)
         # Multiply the state-vector by the scores weights and add the bias,
         # to get the logits.
-        scores = ops.gemm(statevecs, upper_W, trans2=True)
+        scores = model.ops.gemm(statevecs, upper_W, trans2=True)
         scores += upper_b
         scores[:, unseen_mask == 0] = model.ops.xp.nanmin(scores)
         # Transition the states, filtering out any that are finished.
@@ -169,17 +174,15 @@
         all_scores.append(scores)
         if is_train:
             # Remember intermediate results for the backprop.
-            all_tokfeats.append(tokfeats)
             all_ids.append(ids.copy())
             all_statevecs.append(statevecs)
             all_which.append(which)
 
-    nS = sum(len(s.history) for s in states)
-
     def backprop_parser(d_states_d_scores):
         d_tokvecs = model.ops.alloc2f(tokvecs.shape[0], tokvecs.shape[1])
         ids = model.ops.xp.vstack(all_ids)
         which = ops.xp.vstack(all_which)
+        statevecs = model.ops.xp.vstack(all_statevecs)
         _, d_scores = d_states_d_scores
         if model.attrs.get("unseen_classes"):
             # If we have a negative gradient (i.e. the probability should
@@ -189,26 +192,23 @@
                 if (d_scores[:, clas] < 0).any():
                     model.attrs["unseen_classes"].remove(clas)
         d_scores *= unseen_mask
-        statevecs = ops.xp.vstack(all_statevecs)
-        tokfeats = ops.xp.vstack(all_tokfeats)
-        assert statevecs.shape == (nS, nH), statevecs.shape
-        assert d_scores.shape == (nS, nO), d_scores.shape
         # Calculate the gradients for the parameters of the upper layer.
+        # The weight gemm is (nS, nO).T @ (nS, nH)
         model.inc_grad("upper_b", d_scores.sum(axis=0))
         model.inc_grad("upper_W", model.ops.gemm(d_scores, statevecs, trans1=True))
         # Now calculate d_statevecs, by backproping through the upper linear layer.
+        # This gemm is (nS, nO) @ (nO, nH)
         d_statevecs = model.ops.gemm(d_scores, upper_W)
         # Backprop through the maxout activation
-        d_preacts = model.ops.backprop_maxout(d_statevecs, which, model.get_dim("nP"))
-        model.inc_grad("lower_b", d_preacts.sum(axis=0))
-        model.inc_grad("lower_W", model.ops.gemm(d_preacts, tokfeats, trans1=True))
+        d_preacts = model.ops.backprop_maxout(d_statevecs, which, nP)
+        d_preacts2f = model.ops.reshape2f(d_preacts, d_preacts.shape[0], nH * nP)
+        model.inc_grad("lower_b", d_preacts2f.sum(axis=0))
         # We don't need to backprop the summation, because we pass back the IDs instead
-        d_state_features = backprop_feats((d_preacts, all_ids))
-        ids1d = model.ops.xp.vstack(all_ids).flatten()
-        d_state_features = d_state_features.reshape((ids1d.size, -1))
-        d_tokvecs = model.ops.alloc((tokvecs.shape[0] + 1, tokvecs.shape[1]))
-        model.ops.scatter_add(d_tokvecs, ids1d, d_state_features)
-        return (backprop_tok2vec(d_tokvecs), None)
+        d_state_features = backprop_feats((d_preacts2f, ids))
+        d_tokvecs = model.ops.alloc2f(tokvecs.shape[0], tokvecs.shape[1])
+        model.ops.scatter_add(d_tokvecs, ids, d_state_features)
+        model.inc_grad("lower_pad", d_tokvecs[-1])
+        return (backprop_tok2vec(d_tokvecs[:-1]), None)
 
     return (states, all_scores), backprop_parser
 
@@ -314,7 +314,6 @@ def _forward_reference(
     return (states, all_scores), backprop_parser
 
 
-
 def _get_unseen_mask(model: Model) -> Floats1d:
     mask = model.ops.alloc1f(model.get_dim("nO"))
     mask.fill(1)
@@ -324,17 +323,18 @@ def _get_unseen_mask(model: Model) -> Floats1d:
 
 
 def _forward_precomputable_affine(model, X: Floats2d, is_train: bool):
-
-    W: Floats4d = model.get_param("lower_W")
-    pad: Floats4d = model.get_param("lower_pad")
+    W: Floats2d = model.get_param("lower_W")
     nF = model.get_dim("nF")
     nH = model.get_dim("nH")
     nP = model.get_dim("nP")
     nI = model.get_dim("nI")
+    # The weights start out (nH * nP, nF * nI). Transpose and reshape to (nF * nH * nP, nI).
+    W3f = model.ops.reshape3f(W, nH * nP, nF, nI)
+    W3f = W3f.transpose((1, 0, 2))
+    W2f = model.ops.reshape2f(W3f, nF * nH * nP, nI)
     assert X.shape == (X.shape[0], nI), X.shape
-    Yf_ = model.ops.gemm(X, model.ops.reshape2f(W, nF * nH * nP, nI), trans2=True)
-    Yf = model.ops.reshape4f(Yf_, Yf_.shape[0], nF, nH, nP)
-    Yf = model.ops.xp.vstack((Yf, pad))
+    Yf_ = model.ops.gemm(X, W2f, trans2=True)
+    Yf = model.ops.reshape3f(Yf_, Yf_.shape[0], nF, nH * nP)
 
     def backward(dY_ids: Tuple[Floats3d, Ints2d]):
         # This backprop is particularly tricky, because we get back a different
@@ -351,54 +351,15 @@ def _forward_precomputable_affine(model, X: Floats2d, is_train: bool):
         # However, we avoid building that array for efficiency -- and just pass
         # in the indices.
         dY, ids = dY_ids
-        assert dY.ndim == 3
-        assert dY.shape[1] == nH, dY.shape
-        assert dY.shape[2] == nP, dY.shape
-        # nB = dY.shape[0]
-        # model.inc_grad(
-        #     "lower_pad", _backprop_precomputable_affine_padding(model, dY, ids)
-        # )
-        # model.inc_grad("lower_b", dY.sum(axis=0))  # type: ignore
-        dY = model.ops.reshape2f(dY, dY.shape[0], nH * nP)
-        Wopfi = W.transpose((1, 2, 0, 3))
-        Wopfi = Wopfi.reshape((nH * nP, nF * nI))
-        dXf = model.ops.gemm(dY.reshape((dY.shape[0], nH * nP)), Wopfi)
-        ids1d = model.ops.xp.vstack(ids).flatten()
-        Xf = model.ops.reshape2f(X[ids1d], -1, nF * nI)
-        dWopfi = model.ops.gemm(dY, Xf, trans1=True)
-        dWopfi = dWopfi.reshape((nH, nP, nF, nI))
-        # (o, p, f, i) --> (f, o, p, i)
-        dWopfi = dWopfi.transpose((2, 0, 1, 3))
-        model.inc_grad("lower_W", dWopfi)
+        dXf = model.ops.gemm(dY, W)
+        Xf = X[ids].reshape((ids.shape[0], -1))
+        dW = model.ops.gemm(dY, Xf, trans1=True)
+        model.inc_grad("lower_W", dW)
         return model.ops.reshape3f(dXf, dXf.shape[0], nF, nI)
 
     return Yf, backward
 
 
-def _backprop_precomputable_affine_padding(model, dY, ids):
-    ids = model.ops.xp.vstack(ids)
-    nB = dY.shape[0]
-    nF = model.get_dim("nF")
-    nP = model.get_dim("nP")
-    nH = model.get_dim("nH")
-    # Backprop the "padding", used as a filler for missing values.
-    # Values that are missing are set to -1, and each state vector could
-    # have multiple missing values. The padding has different values for
-    # different missing features. The gradient of the padding vector is:
-    #
-    #     for b in range(nB):
-    #         for f in range(nF):
-    #             if ids[b, f] < 0:
-    #                 d_pad[f] += dY[b]
-    #
-    # Which can be rewritten as:
-    #
-    #     (ids < 0).T @ dY
-    mask = model.ops.asarray(ids < 0, dtype="f")
-    d_pad = model.ops.gemm(mask, dY.reshape(nB, nH * nP), trans1=True)
-    return d_pad.reshape((1, nF, nH, nP))
-
-
 def _infer_nO(Y: Optional[Tuple[List[State], List[Floats2d]]]) -> Optional[int]:
     if Y is None:
         return None
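
Note for reviewers (illustration, not part of the patch): the numpy sketch below walks through the precomputed-affine trick that this diff reworks onto 2d arrays. lower_W is applied to every token exactly once up front; each state's preactivation is then only a gather of nF precomputed rows by token id, a sum, the bias, and the maxout, as in forward() above. The dimension names (nF, nH, nP, nI) and parameter names (lower_W, lower_b, lower_pad) mirror the diff, but the sizes and the plain-numpy stand-ins for the thinc ops are illustrative assumptions, not spaCy's API.

import numpy as np

# Illustrative sizes: nF context features, nH hidden units, nP maxout
# pieces, nI input width (all hypothetical).
nF, nH, nP, nI = 3, 4, 2, 5
n_tokens, n_states = 10, 6
rng = np.random.default_rng(0)

lower_W = rng.standard_normal((nH * nP, nF * nI))  # stored 2d, as in init()
lower_b = rng.standard_normal(nH * nP)
lower_pad = rng.standard_normal((1, nI))
tokvecs = rng.standard_normal((n_tokens, nI))

# Append the pad row, so that feature id -1 ("missing token") selects it,
# as the vstack in forward() does.
tokvecs = np.vstack((tokvecs, lower_pad))

# Precompute every token's contribution to each (feature, hidden, piece)
# slot once, before parsing: feats has shape (n_tokens + 1, nF, nH * nP).
W2f = (lower_W.reshape(nH * nP, nF, nI)
              .transpose(1, 0, 2)
              .reshape(nF * nH * nP, nI))
feats = (tokvecs @ W2f.T).reshape(-1, nF, nH * nP)

# Per parser step: gather nF rows per state by token id, sum them, add
# the bias, and apply the maxout, exactly as forward() does.
ids = rng.integers(-1, n_tokens, size=(n_states, nF))
preacts2f = feats[ids, np.arange(nF)].sum(axis=1) + lower_b  # (n_states, nH * nP)
preacts = preacts2f.reshape(n_states, nH, nP)
statevecs = preacts.max(axis=2)    # maxout output
which = preacts.argmax(axis=2)     # piece index, saved for the backward pass

On the backward pass only ids and which need to be kept: backprop_feats((d_preacts2f, ids)) scatter-adds the state-feature gradients back onto token rows instead of materializing the dense (nB, nF, nH, nP) array described in the comment block above, and the pad row's gradient is peeled off into "lower_pad" (d_tokvecs[-1]) before the rest is handed to backprop_tok2vec.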