spaCy/spacy/ml/extract_ngrams.py
Daniël de Kok e2b70df012
Configure isort to use the Black profile, recursively isort the spacy module (#12721)
* Use isort with Black profile

* isort all the things

* Fix import cycles as a result of import sorting

* Add DOCBIN_ALL_ATTRS type definition

* Add isort to requirements

* Remove isort from build dependencies check

* Typo
2023-06-14 17:48:41 +02:00

35 lines
1.2 KiB
Python

from thinc.api import Model
from ..attrs import LOWER
from ..util import registry
@registry.layers("spacy.extract_ngrams.v1")
def extract_ngrams(ngram_size: int, attr: int = LOWER) -> Model:
    """Build a layer that turns a batch of docs into n-gram key/count features.

    ngram_size (int): Largest n-gram order to extract. Unigrams are always
        included; n-grams of order 2 through ``ngram_size`` are added on top.
    attr (int): ID of the token attribute whose per-token values serve as the
        unigram keys. Defaults to ``LOWER``.
    RETURNS (Model): A Model named "extract_ngrams" that runs ``forward``, with
        both settings stored in ``model.attrs`` under "ngram_size" and "attr".
    """
    layer: Model = Model("extract_ngrams", forward)
    # The forward pass reads its configuration back out of the model attrs.
    layer.attrs.update({"ngram_size": ngram_size, "attr": attr})
    return layer
def forward(model: Model, docs, is_train: bool):
    """Map each doc to its unique n-gram keys and their occurrence counts.

    For every doc, the values of the token attribute ``model.attrs["attr"]``
    are used as unigram keys. ``model.ops.ngrams`` derives additional keys for
    each n-gram order from 2 up to ``model.attrs["ngram_size"]``. All keys of
    a doc are pooled and reduced to unique keys with their counts.

    model (Model): The layer created by ``extract_ngrams``.
    docs (Iterable): The batch of documents; each must provide ``to_array``.
    is_train (bool): Unused; the layer has no trainable state.
    RETURNS (Tuple[Tuple[array, array, array], Callable]): The triple
        ``(keys, vals, lengths)`` and a backprop callback that returns ``[]``.
        ``keys`` holds every doc's unique keys concatenated across the batch,
        ``vals`` the matching counts cast to float32 (dtype "f"), and
        ``lengths`` (int32) the number of unique keys each doc contributed.
        An empty batch yields zero-length arrays instead of raising.
    """
    ops = model.ops
    xp = ops.xp
    # Hoisted out of the per-doc loop: both settings are fixed for the batch.
    attr = model.attrs["attr"]
    ngram_size = model.attrs["ngram_size"]
    keys_per_doc = []
    counts_per_doc = []
    for doc in docs:
        unigrams = ops.asarray(doc.to_array([attr]))
        # Unigram keys plus one key array per n-gram order 2..ngram_size.
        ngrams = [unigrams]
        ngrams.extend(
            ops.ngrams(n, unigrams)  # type: ignore[arg-type]
            for n in range(2, ngram_size + 1)
        )
        keys, counts = xp.unique(xp.concatenate(ngrams), return_counts=True)
        keys_per_doc.append(keys)
        counts_per_doc.append(counts)
    lengths = ops.asarray([arr.shape[0] for arr in keys_per_doc], dtype="int32")
    if keys_per_doc:
        batch_keys = xp.concatenate(keys_per_doc)
        batch_counts = xp.concatenate(counts_per_doc)
    else:
        # xp.concatenate raises ValueError on an empty list, so an empty batch
        # gets explicit zero-length arrays instead of crashing.
        # NOTE(review): uint64 is assumed to match the key dtype produced by
        # Doc.to_array / ops.ngrams — confirm.
        batch_keys = xp.zeros((0,), dtype="uint64")
        batch_counts = xp.zeros((0,))
    batch_vals = ops.asarray(batch_counts, dtype="f")

    def backprop(dY):
        # No gradient is propagated back to the input docs.
        return []

    return (batch_keys, batch_vals, lengths), backprop