From dea702b4b7a6786dc373e16a9a50ccd9070a4c5d Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 31 Oct 2021 01:28:20 +0200 Subject: [PATCH] Hackishly fix resizing. 3 failures --- spacy/ml/tb_framework.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/spacy/ml/tb_framework.py b/spacy/ml/tb_framework.py index cd543131a..9f852c628 100644 --- a/spacy/ml/tb_framework.py +++ b/spacy/ml/tb_framework.py @@ -73,7 +73,12 @@ def resize_output(model: Model, new_nO: int) -> Model: model.attrs["unseen_classes"].add(i) model.set_param("upper_W", new_W) model.set_param("upper_b", new_b) - model.set_dim("nO", new_nO, force=True) + # TODO: Avoid this private intrusion + model._dims["nO"] = new_nO + if model.has_grad("upper_W"): + model.set_grad("upper_W", model.get_param("upper_W") * 0) + if model.has_grad("upper_b"): + model.set_grad("upper_b", model.get_param("upper_b") * 0) return model