Hackishly fix resizing. 3 failures

This commit is contained in:
Matthew Honnibal 2021-10-31 01:28:20 +02:00
parent 931b3e112b
commit 44b01d2a87

View File

@ -71,7 +71,12 @@ def resize_output(model: Model, new_nO: int) -> Model:
model.attrs["unseen_classes"].add(i) model.attrs["unseen_classes"].add(i)
model.set_param("upper_W", new_W) model.set_param("upper_W", new_W)
model.set_param("upper_b", new_b) model.set_param("upper_b", new_b)
model.set_dim("nO", new_nO, force=True) # TODO: Avoid this private intrusion
model._dims["nO"] = new_nO
if model.has_grad("upper_W"):
model.set_grad("upper_W", model.get_param("upper_W") * 0)
if model.has_grad("upper_b"):
model.set_grad("upper_b", model.get_param("upper_b") * 0)
return model return model