mirror of
https://github.com/explosion/spaCy.git
synced 2025-09-21 03:22:37 +03:00
Add basic tuplify init
This commit is contained in:
parent
051715506e
commit
883c137b26
|
@ -62,6 +62,7 @@ def tuplify(layer1: Model, layer2: Model, *layers) -> Model:
|
||||||
return Model(
|
return Model(
|
||||||
"tuple(" + ", ".join(names) + ")",
|
"tuple(" + ", ".join(names) + ")",
|
||||||
tuplify_forward,
|
tuplify_forward,
|
||||||
|
init=tuplify_init,
|
||||||
layers=layers,
|
layers=layers,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -83,6 +84,17 @@ def tuplify_forward(model, X, is_train):
|
||||||
|
|
||||||
return tuple(Ys), backprop_tuplify
|
return tuple(Ys), backprop_tuplify
|
||||||
|
|
||||||
|
#TODO make more robust, see chain
|
||||||
|
def tuplify_init(model, X, Y) -> Model:
|
||||||
|
if X is None and Y is None:
|
||||||
|
for layer in model.layers:
|
||||||
|
layer.initialize()
|
||||||
|
|
||||||
|
return model
|
||||||
|
|
||||||
|
for layer in model.layers:
|
||||||
|
layer.initialize(X=X)
|
||||||
|
return model
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class SpanEmbeddings:
|
class SpanEmbeddings:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user