Deal with generators in tuplify

This commit is contained in:
Paul O'Leary McCann 2021-05-18 19:55:52 +09:00
parent a7d9c8156d
commit 0620820857

View File

@ -70,6 +70,10 @@ def tuplify(layer1: Model, layer2: Model, *layers) -> Model:
def tuplify_forward(model, X, is_train):
Ys = []
backprops = []
# If the input is a generator we need to unroll it.
# The type check is necessary because arrays etc. are also OK.
if isinstance(X, Generator):
X = list(X)
for layer in model.layers:
Y, backprop = layer(X, is_train)
Ys.append(Y)