mirror of
https://github.com/explosion/spaCy.git
synced 2025-02-11 17:10:36 +03:00
Hackishly fix resizing. 3 failures
This commit is contained in:
parent
931b3e112b
commit
44b01d2a87
|
@ -71,7 +71,12 @@ def resize_output(model: Model, new_nO: int) -> Model:
|
||||||
model.attrs["unseen_classes"].add(i)
|
model.attrs["unseen_classes"].add(i)
|
||||||
model.set_param("upper_W", new_W)
|
model.set_param("upper_W", new_W)
|
||||||
model.set_param("upper_b", new_b)
|
model.set_param("upper_b", new_b)
|
||||||
model.set_dim("nO", new_nO, force=True)
|
# TODO: Avoid this private intrusion
|
||||||
|
model._dims["nO"] = new_nO
|
||||||
|
if model.has_grad("upper_W"):
|
||||||
|
model.set_grad("upper_W", model.get_param("upper_W") * 0)
|
||||||
|
if model.has_grad("upper_b"):
|
||||||
|
model.set_grad("upper_b", model.get_param("upper_b") * 0)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue
Block a user