Fix parser resizing for cupy (#6758)

This commit is contained in:
Adriane Boyd 2021-01-18 20:43:15 +01:00 committed by GitHub
parent c2a18e4fa3
commit 26c34ab8b0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -185,8 +185,7 @@ def _resize_lower(model, new_nO):
nI = smaller.maybe_get_dim("nI") nI = smaller.maybe_get_dim("nI")
nF = smaller.maybe_get_dim("nF") nF = smaller.maybe_get_dim("nF")
nP = smaller.maybe_get_dim("nP") nP = smaller.maybe_get_dim("nP")
with use_ops("numpy"): larger = _define_lower(nO=new_nO, nI=nI, nF=nF, nP=nP)
larger = _define_lower(nO=new_nO, nI=nI, nF=nF, nP=nP)
# it could be that the model is not initialized yet, then skip this bit # it could be that the model is not initialized yet, then skip this bit
if smaller.has_param("W"): if smaller.has_param("W"):
larger_W = larger.ops.alloc4f(nF, new_nO, nP, nI) larger_W = larger.ops.alloc4f(nF, new_nO, nP, nI)