Add basic tuplify init

This commit is contained in:
Paul O'Leary McCann 2021-05-18 19:53:59 +09:00
parent 051715506e
commit 883c137b26

View File

@ -62,6 +62,7 @@ def tuplify(layer1: Model, layer2: Model, *layers) -> Model:
return Model(
"tuple(" + ", ".join(names) + ")",
tuplify_forward,
init=tuplify_init,
layers=layers,
)
@ -83,6 +84,17 @@ def tuplify_forward(model, X, is_train):
return tuple(Ys), backprop_tuplify
#TODO make more robust, see chain
def tuplify_init(model, X, Y) -> Model:
if X is None and Y is None:
for layer in model.layers:
layer.initialize()
return model
for layer in model.layers:
layer.initialize(X=X)
return model
@dataclass
class SpanEmbeddings: