From 160dbc58eae17ed8ecd25fb498519f664a3241ac Mon Sep 17 00:00:00 2001
From: Matthew Honnibal
Date: Mon, 1 Nov 2021 00:23:15 +0100
Subject: [PATCH] Improve indexing on reference implementation

---
 spacy/ml/tb_framework.py | 22 +++++-----------------
 1 file changed, 5 insertions(+), 17 deletions(-)

diff --git a/spacy/ml/tb_framework.py b/spacy/ml/tb_framework.py
index fba35fbfd..55eaefec9 100644
--- a/spacy/ml/tb_framework.py
+++ b/spacy/ml/tb_framework.py
@@ -233,6 +233,7 @@ def _forward_reference(
     docs, moves = docs_moves
     states = moves.init_batch(docs)
     tokvecs, backprop_tok2vec = tok2vec(docs, is_train)
+    tokvecs = model.ops.xp.vstack((tokvecs, lower_pad))
     all_ids = []
     all_which = []
     all_statevecs = []
@@ -247,13 +248,7 @@ def _forward_reference(
             state.set_context_tokens(ids, i, nF)
         # Sum the state features, add the bias and apply the activation (maxout)
         # to create the state vectors.
-        tokfeats3f = model.ops.alloc3f(ids.shape[0], nF, nI)
-        for i in range(ids.shape[0]):
-            for j in range(nF):
-                if ids[i, j] == -1:
-                    tokfeats3f[i, j] = lower_pad
-                else:
-                    tokfeats3f[i, j] = tokvecs[ids[i, j]]
+        tokfeats3f = tokvecs[ids]
         tokfeats = model.ops.reshape2f(tokfeats3f, tokfeats3f.shape[0], -1)
         preacts2f = model.ops.gemm(tokfeats, lower_W, trans2=True)
         preacts2f += lower_b
@@ -312,16 +307,9 @@ def _forward_reference(
         d_tokfeats = model.ops.gemm(d_preacts2f, lower_W)
         # Get the gradients of the tokvecs and the padding
         d_tokfeats3f = model.ops.reshape3f(d_tokfeats, nS, nF, nI)
-        d_lower_pad = model.ops.alloc1f(nI)
-        assert ids.shape[0] == nS
-        for i in range(ids.shape[0]):
-            for j in range(ids.shape[1]):
-                if ids[i, j] == -1:
-                    d_lower_pad += d_tokfeats3f[i, j]
-                else:
-                    d_tokvecs[ids[i, j]] += d_tokfeats3f[i, j]
-        model.inc_grad("lower_pad", d_lower_pad)
-        return (backprop_tok2vec(d_tokvecs), None)
+        model.ops.scatter_add(d_tokvecs, ids, d_tokfeats3f)
+        model.inc_grad("lower_pad", d_tokvecs[-1])
+        return (backprop_tok2vec(d_tokvecs[:-1]), None)

     return (states, all_scores), backprop_parser
