import warnings
from typing import Callable, List, Optional, Sequence, Tuple, cast

from thinc.api import Model, Ops, registry
from thinc.initializers import glorot_uniform_init
from thinc.types import Floats1d, Floats2d, Ints1d, Ragged
from thinc.util import partial

from ..attrs import ORTH
from ..errors import Errors, Warnings
from ..tokens import Doc
from ..vectors import Mode, Vectors
from ..vocab import Vocab


@registry.layers("spacy.StaticVectors.v2")
def StaticVectors(
    nO: Optional[int] = None,
    nM: Optional[int] = None,
    *,
    dropout: Optional[float] = None,
    init_W: Callable = glorot_uniform_init,
    key_attr: str = "ORTH"
) -> Model[List[Doc], Ragged]:
    """Embed Doc objects with their vocab's vectors table, applying a learned
    linear projection to control the dimensionality. If a dropout rate is
    specified, the dropout is applied per dimension over the whole batch.

    nO (Optional[int]): Output width of the projection. May be left unset and
        inferred from sample output during initialization.
    nM (Optional[int]): Width of the vectors table. May be left unset and
        inferred from sample input during initialization.
    dropout (Optional[float]): Dropout rate, stored as the "dropout_rate"
        attribute. No mask is applied when it is None.
    init_W (Callable): Initializer for the projection matrix "W"; it is called
        as init_W(ops, (nO, nM)).
    key_attr (str): Deprecated. Any value other than "ORTH" emits a
        DeprecationWarning (W125). The value is only stored in the model
        attrs; the forward pass takes the key attribute from the vectors
        table's `attr`, falling back to ORTH.
    RETURNS (Model[List[Doc], Ragged]): The "static_vectors" layer.
    """
    if key_attr != "ORTH":
        warnings.warn(Warnings.W125, DeprecationWarning)
    return Model(
        "static_vectors",
        forward,
        init=partial(init, init_W),
        params={"W": None},
        attrs={"key_attr": key_attr, "dropout_rate": dropout},
        dims={"nO": nO, "nM": nM},
    )


def forward(
    model: Model[List[Doc], Ragged], docs: List[Doc], is_train: bool
) -> Tuple[Ragged, Callable]:
    """Look up the static vector of every token and project it with "W".

    model (Model): The layer; provides the "W" param, the "nO" dim and the
        "dropout_rate" attribute.
    docs (List[Doc]): The batch. The vectors table is taken from the vocab of
        the first doc only.
    is_train (bool): Whether to draw and apply a dropout mask.
    RETURNS (Tuple[Ragged, Callable]): The projected vectors with one length
        per doc, and a backprop callback that accumulates the gradient of "W"
        and returns an empty list.
    RAISES (RuntimeError): E896 if the vectors object is neither a
        default-mode `Vectors` table nor something offering `get_batch`, or
        if the projection raises a ValueError.
    """
    token_count = sum(len(doc) for doc in docs)
    if not token_count:
        return _handle_empty(model.ops, model.get_dim("nO"))
    vocab: Vocab = docs[0].vocab
    vectors = vocab.vectors
    key_attr: int = getattr(vectors, "attr", ORTH)
    keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs])
    W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
    # Only default-mode tables are indexed by row; everything else is asked
    # for a ready-made batch of vectors.
    uses_table_rows = isinstance(vectors, Vectors) and vectors.mode == Mode.default
    if uses_table_rows:
        table = model.ops.asarray(vectors.data)
        rows = vectors.find(keys=keys)
        V = model.ops.as_contig(table[rows])
    elif (
        isinstance(vectors, Vectors) and vectors.mode == Mode.floret
    ) or hasattr(vectors, "get_batch"):
        V = model.ops.as_contig(vectors.get_batch(keys))
    else:
        raise RuntimeError(Errors.E896)
    try:
        vectors_data = model.ops.gemm(V, W, trans2=True)
    except ValueError as err:
        # Keep the original failure attached as the explicit cause.
        raise RuntimeError(Errors.E896) from err
    if uses_table_rows:
        # Convert negative indices to 0-vectors
        # TODO: more options for UNK tokens
        # NOTE(review): V still holds whatever rows the negative indices
        # selected, and backprop multiplies the incoming gradient by V, so
        # these tokens appear to contribute to the gradient of W even though
        # their output is zeroed here -- confirm whether that is intended.
        vectors_data[rows < 0] = 0
    output = Ragged(vectors_data, model.ops.asarray1i([len(doc) for doc in docs]))
    mask = None
    if is_train:
        mask = _get_drop_mask(model.ops, W.shape[0], model.attrs.get("dropout_rate"))
        if mask is not None:
            output.data *= mask

    def backprop(d_output: Ragged) -> List[Doc]:
        # The same mask used in the forward pass is applied to the incoming
        # gradient; note that d_output.data is modified in place.
        if mask is not None:
            d_output.data *= mask
        model.inc_grad(
            "W",
            model.ops.gemm(
                cast(Floats2d, d_output.data),
                cast(Floats2d, model.ops.as_contig(V)),
                trans1=True,
            ),
        )
        # Nothing is propagated back to the input docs.
        return []

    return output, backprop


def init(
    init_W: Callable,
    model: Model[List[Doc], Ragged],
    X: Optional[List[Doc]] = None,
    Y: Optional[Ragged] = None,
) -> Model[List[Doc], Ragged]:
    """Resolve the "nM" and "nO" dimensions and allocate the "W" parameter.

    init_W (Callable): Initializer, called as init_W(model.ops, (nO, nM)).
    model (Model): The layer to initialize.
    X (Optional[List[Doc]]): Sample input. If non-empty, the second dimension
        of the first doc's vectors table overrides any preset "nM".
    Y (Optional[Ragged]): Sample output. If given, the second dimension of its
        data overrides any preset "nO".
    RETURNS (Model): The same model, with both dims and "W" set.
    RAISES (ValueError): E905 if "nM" cannot be determined, E904 if "nO"
        cannot be determined.
    """
    nM = model.get_dim("nM") if model.has_dim("nM") else None
    nO = model.get_dim("nO") if model.has_dim("nO") else None
    # Sample data, when provided, takes precedence over preset dimensions.
    if X is not None and len(X):
        nM = X[0].vocab.vectors.shape[1]
    if Y is not None:
        nO = Y.data.shape[1]

    if nM is None:
        raise ValueError(Errors.E905)
    if nO is None:
        raise ValueError(Errors.E904)
    model.set_dim("nM", nM)
    model.set_dim("nO", nO)
    model.set_param("W", init_W(model.ops, (nO, nM)))
    return model


def _handle_empty(ops: Ops, nO: int) -> Tuple[Ragged, Callable]:
    """Return the forward result for a batch without tokens: a Ragged with a
    (0, nO) data array and no lengths, plus a backprop that returns []."""
    return Ragged(ops.alloc2f(0, nO), ops.alloc1i(0)), lambda d_ragged: []


def _get_drop_mask(ops: Ops, nO: int, rate: Optional[float]) -> Optional[Floats1d]:
    """Return a dropout mask of shape (nO,) from `ops`, or None if no rate is
    configured."""
    if rate is None:
        return None
    return ops.get_dropout_mask((nO,), rate)  # type: ignore