Improve take_vecs implementation

This pulls out references to needed bits so that other parts (the larger
embeddings) can be freed before backprop.
This commit is contained in:
Paul O'Leary McCann 2021-07-05 21:08:42 +09:00
parent 13bef2ddb6
commit eb5820b593

View File

@ -356,9 +356,11 @@ def build_take_vecs() -> Model[SpanEmbeddings, Floats2d]:
def take_vecs_forward(model, inputs: SpanEmbeddings, is_train) -> Floats2d: def take_vecs_forward(model, inputs: SpanEmbeddings, is_train) -> Floats2d:
idxs = inputs.indices
lens = inputs.vectors.lengths
def backprop(dY: Floats2d) -> SpanEmbeddings: def backprop(dY: Floats2d) -> SpanEmbeddings:
vecs = Ragged(dY, inputs.vectors.lengths) vecs = Ragged(dY, lens)
return SpanEmbeddings(inputs.indices, vecs) return SpanEmbeddings(idxs, vecs)
return inputs.vectors.data, backprop return inputs.vectors.data, backprop