mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-30 20:06:30 +03:00
8729307e67
* register extract_ngrams layer * fix import * bump spacy-legacy to 3.0.6 * revert bump (wrong PR)
35 lines
1.1 KiB
Python
35 lines
1.1 KiB
Python
from thinc.api import Model
|
|
|
|
from ..util import registry
|
|
from ..attrs import LOWER
|
|
|
|
|
|
@registry.layers("spacy.extract_ngrams.v1")
|
|
def extract_ngrams(ngram_size: int, attr: int = LOWER) -> Model:
|
|
model = Model("extract_ngrams", forward)
|
|
model.attrs["ngram_size"] = ngram_size
|
|
model.attrs["attr"] = attr
|
|
return model
|
|
|
|
|
|
def forward(model: Model, docs, is_train: bool):
|
|
batch_keys = []
|
|
batch_vals = []
|
|
for doc in docs:
|
|
unigrams = model.ops.asarray(doc.to_array([model.attrs["attr"]]))
|
|
ngrams = [unigrams]
|
|
for n in range(2, model.attrs["ngram_size"] + 1):
|
|
ngrams.append(model.ops.ngrams(n, unigrams))
|
|
keys = model.ops.xp.concatenate(ngrams)
|
|
keys, vals = model.ops.xp.unique(keys, return_counts=True)
|
|
batch_keys.append(keys)
|
|
batch_vals.append(vals)
|
|
lengths = model.ops.asarray([arr.shape[0] for arr in batch_keys], dtype="int32")
|
|
batch_keys = model.ops.xp.concatenate(batch_keys)
|
|
batch_vals = model.ops.asarray(model.ops.xp.concatenate(batch_vals), dtype="f")
|
|
|
|
def backprop(dY):
|
|
return []
|
|
|
|
return (batch_keys, batch_vals, lengths), backprop
|