Fix StaticVectors class

This commit is contained in:
Matthew Honnibal 2020-07-28 15:52:55 +02:00
parent 44d350dc94
commit 475d7c1c7c

View File

@ -37,15 +37,14 @@ def forward(
if not len(docs): if not len(docs):
return _handle_empty(model.ops, model.get_dim("nO")) return _handle_empty(model.ops, model.get_dim("nO"))
key_attr = model.attrs["key_attr"] key_attr = model.attrs["key_attr"]
W = cast(Floats2d, model.get_param("W")) W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
V = cast(Floats2d, docs[0].vocab.vectors.data) V = cast(Floats2d, docs[0].vocab.vectors.data)
mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate")) mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate"))
rows = model.ops.flatten( rows = model.ops.flatten(
[doc.vocab.vectors.find(keys=doc.to_array(key_attr)) for doc in docs] [doc.vocab.vectors.find(keys=doc.to_array(key_attr)) for doc in docs]
) )
output = Ragged( output = Ragged(
model.ops.gemm(V[rows], W, trans2=True), model.ops.gemm(model.ops.as_contig(V[rows]), W, trans2=True),
model.ops.asarray([len(doc) for doc in docs], dtype="i") model.ops.asarray([len(doc) for doc in docs], dtype="i")
) )
if mask is not None: if mask is not None:
@ -54,7 +53,14 @@ def forward(
def backprop(d_output: Ragged) -> List[Doc]: def backprop(d_output: Ragged) -> List[Doc]:
if mask is not None: if mask is not None:
d_output.data *= mask d_output.data *= mask
model.inc_grad("W", model.ops.gemm(d_output.data, V[rows], trans1=True)) model.inc_grad(
"W",
model.ops.gemm(
d_output.data,
model.ops.as_contig(V[rows]),
trans1=True
)
)
return [] return []
return output, backprop return output, backprop