mirror of
https://github.com/explosion/spaCy.git
synced 2025-03-13 16:05:50 +03:00
Improve indexing on reference implementation
This commit is contained in:
parent
0a7d9278cd
commit
4e894e1076
|
@ -230,6 +230,7 @@ def _forward_reference(model, docs_moves: Tuple[List[Doc], TransitionSystem], is
|
|||
docs, moves = docs_moves
|
||||
states = moves.init_batch(docs)
|
||||
tokvecs, backprop_tok2vec = tok2vec(docs, is_train)
|
||||
tokvecs = model.ops.xp.vstack((tokvecs, lower_pad))
|
||||
all_ids = []
|
||||
all_which = []
|
||||
all_statevecs = []
|
||||
|
@ -244,13 +245,7 @@ def _forward_reference(model, docs_moves: Tuple[List[Doc], TransitionSystem], is
|
|||
state.set_context_tokens(ids, i, nF)
|
||||
# Sum the state features, add the bias and apply the activation (maxout)
|
||||
# to create the state vectors.
|
||||
tokfeats3f = model.ops.alloc3f(ids.shape[0], nF, nI)
|
||||
for i in range(ids.shape[0]):
|
||||
for j in range(nF):
|
||||
if ids[i, j] == -1:
|
||||
tokfeats3f[i, j] = lower_pad
|
||||
else:
|
||||
tokfeats3f[i, j] = tokvecs[ids[i, j]]
|
||||
tokfeats3f = tokvecs[ids]
|
||||
tokfeats = model.ops.reshape2f(tokfeats3f, tokfeats3f.shape[0], -1)
|
||||
preacts2f = model.ops.gemm(tokfeats, lower_W, trans2=True)
|
||||
preacts2f += lower_b
|
||||
|
@ -309,16 +304,9 @@ def _forward_reference(model, docs_moves: Tuple[List[Doc], TransitionSystem], is
|
|||
d_tokfeats = model.ops.gemm(d_preacts2f, lower_W)
|
||||
# Get the gradients of the tokvecs and the padding
|
||||
d_tokfeats3f = model.ops.reshape3f(d_tokfeats, nS, nF, nI)
|
||||
d_lower_pad = model.ops.alloc1f(nI)
|
||||
assert ids.shape[0] == nS
|
||||
for i in range(ids.shape[0]):
|
||||
for j in range(ids.shape[1]):
|
||||
if ids[i, j] == -1:
|
||||
d_lower_pad += d_tokfeats3f[i, j]
|
||||
else:
|
||||
d_tokvecs[ids[i, j]] += d_tokfeats3f[i, j]
|
||||
model.inc_grad("lower_pad", d_lower_pad)
|
||||
return (backprop_tok2vec(d_tokvecs), None)
|
||||
model.ops.scatter_add(d_tokvecs, ids, d_tokfeats3f)
|
||||
model.inc_grad("lower_pad", d_tokvecs[-1])
|
||||
return (backprop_tok2vec(d_tokvecs[:-1]), None)
|
||||
|
||||
return (states, all_scores), backprop_parser
|
||||
|
||||
|
|
Loading…
Reference in New Issue
Block a user