spaCy/spacy/ml/staticvectors.py
Daniël de Kok e2b70df012
Configure isort to use the Black profile, recursively isort the spacy module (#12721)
* Use isort with Black profile

* isort all the things

* Fix import cycles as a result of import sorting

* Add DOCBIN_ALL_ATTRS type definition

* Add isort to requirements

* Remove isort from build dependencies check

* Typo
2023-06-14 17:48:41 +02:00

119 lines
3.8 KiB
Python

from typing import Callable, List, Optional, Sequence, Tuple, cast
from thinc.api import Model, Ops, registry
from thinc.initializers import glorot_uniform_init
from thinc.types import Floats1d, Floats2d, Ints1d, Ragged
from thinc.util import partial
from ..errors import Errors
from ..tokens import Doc
from ..vectors import Mode
from ..vocab import Vocab
@registry.layers("spacy.StaticVectors.v2")
def StaticVectors(
    nO: Optional[int] = None,
    nM: Optional[int] = None,
    *,
    dropout: Optional[float] = None,
    init_W: Callable = glorot_uniform_init,
    key_attr: str = "ORTH"
) -> Model[List[Doc], Ragged]:
    """Create a layer that embeds Docs using the vectors table of their vocab.

    The looked-up vectors are passed through a learned linear projection
    (parameter "W") so the output width can be chosen independently of the
    width of the vectors table. When a dropout rate is given, one mask per
    output dimension is applied across the whole batch.

    nO (Optional[int]): Output width; stored as dimension "nO".
    nM (Optional[int]): Width of the static vectors; stored as dimension "nM".
    dropout (Optional[float]): Dropout rate, stored as attr "dropout_rate".
    init_W (Callable): Initializer bound into the layer's init callback.
    key_attr (str): Token attribute used as lookup key, stored as attr
        "key_attr".
    RETURNS (Model[List[Doc], Ragged]): The "static_vectors" model.
    """
    layer_attrs = {"key_attr": key_attr, "dropout_rate": dropout}
    layer_dims = {"nO": nO, "nM": nM}
    # W stays unallocated here; it is created by the init callback.
    layer_params = {"W": None}
    initializer = partial(init, init_W)
    return Model(
        "static_vectors",
        forward,
        init=initializer,
        params=layer_params,
        attrs=layer_attrs,
        dims=layer_dims,
    )
def forward(
    model: Model[List[Doc], Ragged], docs: List[Doc], is_train: bool
) -> Tuple[Ragged, Callable]:
    """Embed every token of `docs` with its static vector projected by W.

    model (Model): The static vectors layer; provides the "W" parameter and
        the "key_attr" / "dropout_rate" attrs.
    docs (List[Doc]): The batch. The vocab of the first doc is used for the
        whole batch.
    is_train (bool): If True and a dropout rate is configured, a dropout mask
        is applied to the output (and to the incoming gradient in backprop).
    RETURNS (Tuple[Ragged, Callable]): The projected vectors with one length
        per doc, and a callback that accumulates the gradient of "W" and
        returns an empty list.
    RAISES (RuntimeError): With Errors.E896 if the vectors mode is neither
        default nor floret, or if the projection fails with a ValueError.
    """
    token_count = sum(len(doc) for doc in docs)
    if not token_count:
        return _handle_empty(model.ops, model.get_dim("nO"))
    # The attr is set from the string-typed `key_attr` argument of the layer.
    key_attr: str = model.attrs["key_attr"]
    keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs])
    vocab: Vocab = docs[0].vocab
    W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
    if vocab.vectors.mode == Mode.default:
        V = model.ops.asarray(vocab.vectors.data)
        rows = vocab.vectors.find(keys=keys)
        # Indexing with an integer array gathers the rows into a fresh array,
        # so the vectors table itself is not modified below.
        V = model.ops.as_contig(V[rows])
        # Negative row indices mark tokens without a vector, but the gather
        # above resolves them to rows counted from the end of the table.
        # Zero those rows so that such tokens produce a 0-vector output AND
        # contribute nothing to the gradient of W in backprop.
        # TODO: more options for UNK tokens
        V[rows < 0] = 0
    elif vocab.vectors.mode == Mode.floret:
        V = vocab.vectors.get_batch(keys)
        V = model.ops.as_contig(V)
    else:
        raise RuntimeError(Errors.E896)
    try:
        vectors_data = model.ops.gemm(V, W, trans2=True)
    except ValueError as err:
        raise RuntimeError(Errors.E896) from err
    output = Ragged(vectors_data, model.ops.asarray1i([len(doc) for doc in docs]))
    mask = None
    if is_train:
        mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate"))
        if mask is not None:
            output.data *= mask

    def backprop(d_output: Ragged) -> List[Doc]:
        # Re-apply the forward dropout mask to the incoming gradient (in place).
        if mask is not None:
            d_output.data *= mask
        # dW = d_output^T @ V; rows of V for tokens without a vector are zero.
        model.inc_grad(
            "W",
            model.ops.gemm(
                cast(Floats2d, d_output.data),
                cast(Floats2d, model.ops.as_contig(V)),
                trans1=True,
            ),
        )
        # The static vectors are not trained, so nothing flows back to the docs.
        return []

    return output, backprop
def init(
    init_W: Callable,
    model: Model[List[Doc], Ragged],
    X: Optional[List[Doc]] = None,
    Y: Optional[Ragged] = None,
) -> Model[List[Doc], Ragged]:
    """Resolve the "nM" and "nO" dimensions and allocate the "W" parameter.

    Dimensions already set on the model are used as a starting point. A
    non-empty sample input X overrides nM with the width of the first doc's
    vectors table, and a sample output Y overrides nO with its data width.

    init_W (Callable): Called as init_W(model.ops, (nO, nM)) to create W.
    model (Model): The layer to initialize.
    X (Optional[List[Doc]]): Optional sample input.
    Y (Optional[Ragged]): Optional sample output.
    RETURNS (Model): The same model, with dims and W set.
    RAISES (ValueError): With Errors.E905 if nM cannot be determined, or
        Errors.E904 if nO cannot be determined.
    """
    # Start from whatever the model already knows.
    vector_width = None
    if model.has_dim("nM"):
        vector_width = model.get_dim("nM")
    output_width = None
    if model.has_dim("nO"):
        output_width = model.get_dim("nO")

    # Sample data, when supplied, takes precedence over preset dims.
    if X is not None and len(X) > 0:
        vector_width = X[0].vocab.vectors.shape[1]
    if Y is not None:
        output_width = Y.data.shape[1]

    if vector_width is None:
        raise ValueError(Errors.E905)
    if output_width is None:
        raise ValueError(Errors.E904)

    model.set_dim("nM", vector_width)
    model.set_dim("nO", output_width)
    weights = init_W(model.ops, (output_width, vector_width))
    model.set_param("W", weights)
    return model
def _handle_empty(ops: Ops, nO: int):
    """Return the forward result for a batch that contains no tokens.

    The output is a Ragged with a (0, nO) float array and an empty lengths
    array, paired with a backprop callback that ignores its argument and
    returns an empty list.
    """
    empty_data = ops.alloc2f(0, nO)
    empty_lengths = ops.alloc1i(0)

    def backprop_empty(d_ragged):
        return []

    return Ragged(empty_data, empty_lengths), backprop_empty
def _get_drop_mask(ops: Ops, nO: int, rate: Optional[float]) -> Optional[Floats1d]:
    """Return a dropout mask of shape (nO,), or None if no rate is configured."""
    if rate is None:
        return None
    return ops.get_dropout_mask((nO,), rate)  # type: ignore