Try to fix listener

This commit is contained in:
Matthew Honnibal 2020-08-28 16:08:00 +02:00
parent ef9888c1f7
commit 1b8d2ed14f

View File

@ -291,8 +291,14 @@ class Tok2VecListener(Model):
def forward(model: Tok2VecListener, inputs, is_train: bool): def forward(model: Tok2VecListener, inputs, is_train: bool):
"""Supply the outputs from the upstream Tok2Vec component.""" """Supply the outputs from the upstream Tok2Vec component."""
if is_train: if is_train:
model.verify_inputs(inputs) model.verify_inputs(inputs)
return model._outputs, model._backprop return model._outputs, model._backprop
else: else:
return [doc.tensor for doc in inputs], lambda dX: [] if model._outputs is None:
outputs = [model.ops.alloc2f(len(doc), width) for doc in inputs]
else:
outputs = model._outputs
model._outputs = None
return outputs, lambda dX: []