mirror of
https://github.com/explosion/spaCy.git
synced 2025-07-11 00:32:40 +03:00
Fix StaticVectors class
This commit is contained in:
parent
44d350dc94
commit
475d7c1c7c
|
@ -37,15 +37,14 @@ def forward(
|
||||||
if not len(docs):
|
if not len(docs):
|
||||||
return _handle_empty(model.ops, model.get_dim("nO"))
|
return _handle_empty(model.ops, model.get_dim("nO"))
|
||||||
key_attr = model.attrs["key_attr"]
|
key_attr = model.attrs["key_attr"]
|
||||||
W = cast(Floats2d, model.get_param("W"))
|
W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
|
||||||
V = cast(Floats2d, docs[0].vocab.vectors.data)
|
V = cast(Floats2d, docs[0].vocab.vectors.data)
|
||||||
mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate"))
|
mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate"))
|
||||||
|
|
||||||
rows = model.ops.flatten(
|
rows = model.ops.flatten(
|
||||||
[doc.vocab.vectors.find(keys=doc.to_array(key_attr)) for doc in docs]
|
[doc.vocab.vectors.find(keys=doc.to_array(key_attr)) for doc in docs]
|
||||||
)
|
)
|
||||||
output = Ragged(
|
output = Ragged(
|
||||||
model.ops.gemm(V[rows], W, trans2=True),
|
model.ops.gemm(model.ops.as_contig(V[rows]), W, trans2=True),
|
||||||
model.ops.asarray([len(doc) for doc in docs], dtype="i")
|
model.ops.asarray([len(doc) for doc in docs], dtype="i")
|
||||||
)
|
)
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
|
@ -54,7 +53,14 @@ def forward(
|
||||||
def backprop(d_output: Ragged) -> List[Doc]:
|
def backprop(d_output: Ragged) -> List[Doc]:
|
||||||
if mask is not None:
|
if mask is not None:
|
||||||
d_output.data *= mask
|
d_output.data *= mask
|
||||||
model.inc_grad("W", model.ops.gemm(d_output.data, V[rows], trans1=True))
|
model.inc_grad(
|
||||||
|
"W",
|
||||||
|
model.ops.gemm(
|
||||||
|
d_output.data,
|
||||||
|
model.ops.as_contig(V[rows]),
|
||||||
|
trans1=True
|
||||||
|
)
|
||||||
|
)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
return output, backprop
|
return output, backprop
|
||||||
|
|
Loading…
Reference in New Issue
Block a user