From 44b01d2a875428a81d2fd1c7c370886da1f30279 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Sun, 31 Oct 2021 01:28:20 +0200 Subject: [PATCH] Hackishly fix resizing. 3 failures --- spacy/ml/tb_framework.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/spacy/ml/tb_framework.py b/spacy/ml/tb_framework.py index 207f4bd5d..fa796f21e 100644 --- a/spacy/ml/tb_framework.py +++ b/spacy/ml/tb_framework.py @@ -71,7 +71,12 @@ def resize_output(model: Model, new_nO: int) -> Model: model.attrs["unseen_classes"].add(i) model.set_param("upper_W", new_W) model.set_param("upper_b", new_b) - model.set_dim("nO", new_nO, force=True) + # TODO: Avoid this private intrusion + model._dims["nO"] = new_nO + if model.has_grad("upper_W"): + model.set_grad("upper_W", model.get_param("upper_W") * 0) + if model.has_grad("upper_b"): + model.set_grad("upper_b", model.get_param("upper_b") * 0) return model