from typing import Callable, List, Optional, Sequence, Tuple, cast

from thinc.api import Model, Ops, registry
from thinc.initializers import glorot_uniform_init
from thinc.types import Floats1d, Floats2d, Ints1d, Ragged
from thinc.util import partial

from ..errors import Errors
from ..tokens import Doc
from ..vectors import Mode
from ..vocab import Vocab


@registry.layers("spacy.StaticVectors.v2")
def StaticVectors(
    nO: Optional[int] = None,
    nM: Optional[int] = None,
    *,
    dropout: Optional[float] = None,
    init_W: Callable = glorot_uniform_init,
    key_attr: str = "ORTH"
) -> Model[List[Doc], Ragged]:
    """Embed Doc objects with their vocab's vectors table, applying a learned
    linear projection to control the dimensionality. If a dropout rate is
    specified, the dropout is applied per dimension over the whole batch.

    nO (Optional[int]): Width of the output. If unset, it is taken from the
        sample output passed to `init`.
    nM (Optional[int]): Width of the vectors table. If unset, it is taken from
        the vocab of the first sample Doc passed to `init`.
    dropout (Optional[float]): Dropout rate used to build a mask over the nO
        output dimensions during training. No mask is applied when None.
    init_W (Callable): Initializer called as `init_W(ops, (nO, nM))` to create
        the projection matrix.
    key_attr (str): Token attribute passed to `Doc.to_array` to obtain the
        keys that are looked up in the vectors table.
    RETURNS (Model[List[Doc], Ragged]): The embedding layer.
    """
    return Model(
        "static_vectors",
        forward,
        init=partial(init, init_W),
        params={"W": None},
        attrs={"key_attr": key_attr, "dropout_rate": dropout},
        dims={"nO": nO, "nM": nM},
    )


def forward(
    model: Model[List[Doc], Ragged], docs: List[Doc], is_train: bool
) -> Tuple[Ragged, Callable]:
    """Look up a vector for every token in the batch and project it with W.

    model (Model): The static vectors layer; provides "W", "nO" and the
        "key_attr" / "dropout_rate" attrs.
    docs (List[Doc]): The batch. The vectors table is taken from the vocab of
        the first Doc.
    is_train (bool): Whether to apply the dropout mask (if a rate is set).
    RETURNS (Tuple[Ragged, Callable]): The projected vectors with one length
        per Doc, and a callback that accumulates the gradient of W and returns
        an empty list (nothing is backpropagated to the docs).
    RAISES (RuntimeError): If the vectors mode is not supported, or if the
        vectors cannot be multiplied with W.
    """
    token_count = sum(len(doc) for doc in docs)
    if not token_count:
        return _handle_empty(model.ops, model.get_dim("nO"))
    key_attr: str = model.attrs["key_attr"]
    keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs])
    vocab: Vocab = docs[0].vocab
    W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
    # Boolean mask of tokens without a vector; only set in default mode.
    oov_rows = None
    if vocab.vectors.mode == Mode.default:
        table = model.ops.asarray(vocab.vectors.data)
        rows = vocab.vectors.find(keys=keys)
        oov_rows = rows < 0
        # Indexing with an array of rows gathers into a new array, so the
        # in-place edit below does not touch the vectors table itself.
        V = model.ops.as_contig(table[rows])
        # Negative rows mark unknown keys, but as indices they pick up real
        # rows counted from the end of the table. Zero them here so that they
        # contribute nothing to the gradient of W in backprop, matching the
        # zero output they get in the forward pass.
        # TODO: more options for UNK tokens
        V[oov_rows] = 0
    elif vocab.vectors.mode == Mode.floret:
        V = vocab.vectors.get_batch(keys)
        V = model.ops.as_contig(V)
    else:
        raise RuntimeError(Errors.E896)
    try:
        vectors_data = model.ops.gemm(V, W, trans2=True)
    except ValueError as err:
        raise RuntimeError(Errors.E896) from err
    if oov_rows is not None:
        # Keep the output for unknown tokens at exactly zero, independent of
        # the values currently held in W.
        vectors_data[oov_rows] = 0
    output = Ragged(vectors_data, model.ops.asarray1i([len(doc) for doc in docs]))
    mask = None
    if is_train:
        # W has shape (nO, nM), so the mask covers the output dimensions.
        mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate"))
        if mask is not None:
            output.data *= mask

    def backprop(d_output: Ragged) -> List[Doc]:
        # Dimensions dropped in the forward pass receive no gradient.
        if mask is not None:
            d_output.data *= mask
        model.inc_grad(
            "W",
            model.ops.gemm(
                cast(Floats2d, d_output.data),
                cast(Floats2d, model.ops.as_contig(V)),
                trans1=True,
            ),
        )
        return []

    return output, backprop


def init(
    init_W: Callable,
    model: Model[List[Doc], Ragged],
    X: Optional[List[Doc]] = None,
    Y: Optional[Ragged] = None,
) -> Model[List[Doc], Ragged]:
    """Resolve the nM and nO dimensions and create the projection matrix W.

    Dimensions already set on the model are used as a starting point; a
    non-empty X overrides nM with the width of the first Doc's vectors table,
    and Y overrides nO with the width of its data.

    init_W (Callable): Initializer called as `init_W(ops, (nO, nM))`.
    model (Model): The static vectors layer to initialize.
    X (Optional[List[Doc]]): Sample input.
    Y (Optional[Ragged]): Sample output.
    RETURNS (Model): The initialized model.
    RAISES (ValueError): If nM or nO cannot be determined.
    """
    nM = model.get_dim("nM") if model.has_dim("nM") else None
    nO = model.get_dim("nO") if model.has_dim("nO") else None
    if X is not None and len(X):
        nM = X[0].vocab.vectors.shape[1]
    if Y is not None:
        nO = Y.data.shape[1]

    if nM is None:
        raise ValueError(Errors.E905)
    if nO is None:
        raise ValueError(Errors.E904)

    model.set_dim("nM", nM)
    model.set_dim("nO", nO)
    model.set_param("W", init_W(model.ops, (nO, nM)))
    return model


def _handle_empty(ops: Ops, nO: int) -> Tuple[Ragged, Callable]:
    """Return an empty (0, nO) Ragged and a backprop callback that returns an
    empty list, for batches that contain no tokens.
    """
    return Ragged(ops.alloc2f(0, nO), ops.alloc1i(0)), lambda d_ragged: []


def _get_drop_mask(ops: Ops, nO: int, rate: Optional[float]) -> Optional[Floats1d]:
    """Create a dropout mask of shape (nO,), or return None if no rate is set."""
    if rate is None:
        return None
    return ops.get_dropout_mask((nO,), rate)  # type: ignore