Add support for floret vectors (#8909)
* Add support for fasttext-bloom hash-only vectors

  Overview:

  * Extend `Vectors` to have two modes: `default` and `ngram`
    * `default` is the default mode and equivalent to the current `Vectors`
    * `ngram` supports the hash-only ngram tables from `fasttext-bloom`
  * Extend `spacy.StaticVectors.v2` to handle both modes with no changes for `default` vectors
  * Extend `spacy init vectors` to support ngram tables

  The `ngram` mode **only** supports vector tables produced by this fork of fastText, which adds an option to represent all vectors using only the ngram buckets table and which uses the exact same ngram generation algorithm and hash function (`MurmurHash3_x64_128`). `fasttext-bloom` produces an additional `.hashvec` table, which can be loaded by `spacy init vectors --fasttext-bloom-vectors`.

  https://github.com/adrianeboyd/fastText/tree/feature/bloom

  Implementation details:

  * `Vectors` now includes the `StringStore` as `Vectors.strings` so that the API can stay consistent for both `default` (which can look up from `str` or `int`) and `ngram` (which requires `str` to calculate the ngrams).
  * In ngram mode `Vectors` uses a default `Vectors` object as a cache, since the ngram vector lookups are relatively expensive.
    * The default cache size is the same size as the provided ngram vector table.
    * Once the cache is full, no more entries are added. The user is responsible for managing the cache in cases where the initial documents are not representative of the texts.
    * The cache can be resized by setting `Vectors.ngram_cache_size` or cleared with `vectors._ngram_cache.clear()`.
  * The API ends up a bit split between methods for `default` and for `ngram`, so functions that only make sense for `default` or `ngram` include warnings with custom messages suggesting alternatives where possible.
  * `Vocab.vectors` becomes a property so that the string stores can be synced when assigning vectors to a vocab.
  * `Vectors` serializes its own config settings as `vectors.cfg`.
  * The `Vectors` serialization methods have added support for `exclude` so that the `Vocab` can exclude the `Vectors` strings while serializing.

  Removed:

  * The `minn` and `maxn` options and related code from `Vocab.get_vector`, which do not work in a meaningful way for default vector tables.
  * The unused `GlobalRegistry` in `Vectors`.

* Refactor to use reduce_mean

  Refactor to use reduce_mean and remove the ngram vectors cache.

* Rename to floret
* Rename to floret in error messages
* Use --vectors-mode in CLI, vector init
* Fix vectors mode in init
* Remove unused var
* Minor API and docstrings adjustments
* Rename `--vectors-mode` to `--mode` in `init vectors` CLI
* Rename `Vectors.get_floret_vectors` to `Vectors.get_batch` and support both modes.
* Minor updates to Vectors docstrings.
* Update API docs for Vectors and init vectors CLI
* Update types for StaticVectors
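To make the new workflow concrete before the diff: a minimal sketch using only calls that appear in this commit's tests (`convert_vectors` with `mode="floret"`, the Python equivalent of the new `spacy init vectors --mode floret` CLI option). The `vectors.floret` path is a hypothetical placeholder for a floret-trained table.

```python
# Minimal sketch of the workflow added by this commit, based on the tests
# below. "vectors.floret" is a hypothetical path to a floret-trained table.
from spacy.lang.en import English
from spacy.training.initialize import convert_vectors

nlp = English()
convert_vectors(nlp, "vectors.floret", truncate=0, prune=-1, mode="floret")

# Any string now gets a vector composed from its character-ngram rows,
# so even unseen words are not OOV.
assert nlp.vocab["anyword"].has_vector
print(nlp.vocab["anyword"].vector.shape)
```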
Parent: 0c97ed2746 · Commit: c053f158c5
@@ -20,6 +20,7 @@ def init_vectors_cli(
     output_dir: Path = Arg(..., help="Pipeline output directory"),
     prune: int = Opt(-1, "--prune", "-p", help="Optional number of vectors to prune to"),
     truncate: int = Opt(0, "--truncate", "-t", help="Optional number of vectors to truncate to when reading in vectors file"),
+    mode: str = Opt("default", "--mode", "-m", help="Vectors mode: default or floret"),
     name: Optional[str] = Opt(None, "--name", "-n", help="Optional name for the word vectors, e.g. en_core_web_lg.vectors"),
     verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
     jsonl_loc: Optional[Path] = Opt(None, "--lexemes-jsonl", "-j", help="Location of JSONL-formatted attributes file", hidden=True),
@@ -34,7 +35,14 @@ def init_vectors_cli(
     nlp = util.get_lang_class(lang)()
     if jsonl_loc is not None:
         update_lexemes(nlp, jsonl_loc)
-    convert_vectors(nlp, vectors_loc, truncate=truncate, prune=prune, name=name)
+    convert_vectors(
+        nlp,
+        vectors_loc,
+        truncate=truncate,
+        prune=prune,
+        name=name,
+        mode=mode,
+    )
     msg.good(f"Successfully converted {len(nlp.vocab.vectors)} vectors")
     nlp.to_disk(output_dir)
     msg.good(
@@ -27,6 +27,9 @@ def setup_default_warnings():
     # warn once about lemmatizer without required POS
     filter_warning("once", error_msg=Warnings.W108)

+    # floret vector table cannot be modified
+    filter_warning("once", error_msg="[W114]")
+

 def filter_warning(action: str, error_msg: str):
     """Customize how spaCy should handle a certain warning.
@@ -192,6 +195,8 @@ class Warnings:
             "vectors are not identical to current pipeline vectors.")
     W114 = ("Using multiprocessing with GPU models is not recommended and may "
             "lead to errors.")
+    W115 = ("Skipping {method}: the floret vector table cannot be modified. "
+            "Vectors are calculated from character ngrams.")


 @add_codes
@@ -518,9 +523,19 @@ class Errors:
     E199 = ("Unable to merge 0-length span at `doc[{start}:{end}]`.")
     E200 = ("Can't yet set {attr} from Span. Vote for this feature on the "
             "issue tracker: http://github.com/explosion/spaCy/issues")
-    E202 = ("Unsupported alignment mode '{mode}'. Supported modes: {modes}.")
+    E202 = ("Unsupported {name} mode '{mode}'. Supported modes: {modes}.")

     # New errors added in v3.x
+    E858 = ("The {mode} vector table does not support this operation. "
+            "{alternative}")
+    E859 = ("The floret vector table cannot be modified.")
+    E860 = ("Can't truncate fasttext-bloom vectors.")
+    E861 = ("No 'keys' should be provided when initializing floret vectors "
+            "with 'minn' and 'maxn'.")
+    E862 = ("'hash_count' must be between 1-4 for floret vectors.")
+    E863 = ("'maxn' must be greater than or equal to 'minn'.")
+    E864 = ("The complete vector table 'data' is required to initialize floret "
+            "vectors.")
     E865 = ("A SpanGroup is not functional after the corresponding Doc has "
             "been garbage collected. To keep using the spans, make sure that "
             "the corresponding Doc object is still available in the scope of "
@@ -228,6 +228,7 @@ class Language:
             "vectors": len(self.vocab.vectors),
             "keys": self.vocab.vectors.n_keys,
             "name": self.vocab.vectors.name,
+            "mode": self.vocab.vectors.mode,
         }
         self._meta["labels"] = dict(self.pipe_labels)
         # TODO: Adding this back to prevent breaking people's code etc., but
@@ -1,11 +1,13 @@
-from typing import List, Tuple, Callable, Optional, cast
+from typing import List, Tuple, Callable, Optional, Sequence, cast
 from thinc.initializers import glorot_uniform_init
 from thinc.util import partial
-from thinc.types import Ragged, Floats2d, Floats1d
+from thinc.types import Ragged, Floats2d, Floats1d, Ints1d
 from thinc.api import Model, Ops, registry

 from ..tokens import Doc
 from ..errors import Errors
+from ..vectors import Mode
+from ..vocab import Vocab


 @registry.layers("spacy.StaticVectors.v2")
@@ -34,19 +36,31 @@ def StaticVectors(
 def forward(
     model: Model[List[Doc], Ragged], docs: List[Doc], is_train: bool
 ) -> Tuple[Ragged, Callable]:
-    if not sum(len(doc) for doc in docs):
+    token_count = sum(len(doc) for doc in docs)
+    if not token_count:
         return _handle_empty(model.ops, model.get_dim("nO"))
-    key_attr = model.attrs["key_attr"]
-    W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
-    V = cast(Floats2d, model.ops.asarray(docs[0].vocab.vectors.data))
-    rows = model.ops.flatten(
-        [doc.vocab.vectors.find(keys=doc.to_array(key_attr)) for doc in docs]
+    key_attr: int = model.attrs["key_attr"]
+    keys: Ints1d = model.ops.flatten(
+        cast(Sequence, [doc.to_array(key_attr) for doc in docs])
     )
+    vocab: Vocab = docs[0].vocab
+    W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
+    if vocab.vectors.mode == Mode.default:
+        V = cast(Floats2d, model.ops.asarray(vocab.vectors.data))
+        rows = vocab.vectors.find(keys=keys)
+        V = model.ops.as_contig(V[rows])
+    elif vocab.vectors.mode == Mode.floret:
+        V = cast(Floats2d, vocab.vectors.get_batch(keys))
+        V = model.ops.as_contig(V)
+    else:
+        raise RuntimeError(Errors.E896)
     try:
-        vectors_data = model.ops.gemm(model.ops.as_contig(V[rows]), W, trans2=True)
+        vectors_data = model.ops.gemm(V, W, trans2=True)
     except ValueError:
         raise RuntimeError(Errors.E896)
-    # Convert negative indices to 0-vectors (TODO: more options for UNK tokens)
-    vectors_data[rows < 0] = 0
+    if vocab.vectors.mode == Mode.default:
+        # Convert negative indices to 0-vectors
+        # TODO: more options for UNK tokens
        vectors_data[rows < 0] = 0
     output = Ragged(
         vectors_data, model.ops.asarray([len(doc) for doc in docs], dtype="i")  # type: ignore
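To make the control flow above concrete, here is a small illustrative numpy sketch (toy shapes and made-up rows, not the spaCy API): both modes end in the same `V @ W.T` projection, but only the default mode has missing-key rows that need zeroing.

```python
# Illustrative numpy sketch (not spaCy API) of the two lookup paths above:
# "default" gathers rows by key, "floret" composes each row from ngram
# hashes, and both feed the same learned projection W.
import numpy as np

rng = np.random.default_rng(0)
V = rng.normal(size=(10, 4)).astype("f")   # vector table (rows x dims)
W = rng.normal(size=(3, 4)).astype("f")    # projection (nO x dims)

# default mode: rows come from a key->row lookup; -1 marks a missing key
rows = np.asarray([2, 5, -1])
out_default = V[rows] @ W.T
out_default[rows < 0] = 0                  # zero the unknown-key rows

# floret mode: every token already has a composed vector, so no masking
V_floret = V[[1, 3, 7]].mean(axis=0, keepdims=True)  # stand-in for get_batch
out_floret = V_floret @ W.T
```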
@@ -1,7 +1,9 @@
 import pytest
 import pickle
+from thinc.api import get_current_ops
 from spacy.vocab import Vocab
 from spacy.strings import StringStore
+from spacy.vectors import Vectors

 from ..util import make_tempdir

@@ -129,7 +131,11 @@ def test_serialize_stringstore_roundtrip_disk(strings1, strings2):
 @pytest.mark.parametrize("strings,lex_attr", test_strings_attrs)
 def test_pickle_vocab(strings, lex_attr):
     vocab = Vocab(strings=strings)
+    ops = get_current_ops()
+    vectors = Vectors(data=ops.xp.zeros((10, 10)), mode="floret", hash_count=1)
+    vocab.vectors = vectors
     vocab[strings[0]].norm_ = lex_attr
     vocab_pickled = pickle.dumps(vocab)
     vocab_unpickled = pickle.loads(vocab_pickled)
     assert vocab.to_bytes() == vocab_unpickled.to_bytes()
+    assert vocab_unpickled.vectors.mode == "floret"
@@ -1,12 +1,14 @@
 import pytest
 import numpy
-from numpy.testing import assert_allclose, assert_equal
+from numpy.testing import assert_allclose, assert_equal, assert_almost_equal
 from thinc.api import get_current_ops
+from spacy.lang.en import English
 from spacy.vocab import Vocab
 from spacy.vectors import Vectors
 from spacy.tokenizer import Tokenizer
 from spacy.strings import hash_string  # type: ignore
 from spacy.tokens import Doc
+from spacy.training.initialize import convert_vectors

 from ..util import add_vecs_to_vocab, get_cosine, make_tempdir

@@ -29,22 +31,6 @@ def vectors():
     ]


-@pytest.fixture
-def ngrams_vectors():
-    return [
-        ("apple", OPS.asarray([1, 2, 3])),
-        ("app", OPS.asarray([-0.1, -0.2, -0.3])),
-        ("ppl", OPS.asarray([-0.2, -0.3, -0.4])),
-        ("pl", OPS.asarray([0.7, 0.8, 0.9])),
-    ]
-
-
-@pytest.fixture()
-def ngrams_vocab(en_vocab, ngrams_vectors):
-    add_vecs_to_vocab(en_vocab, ngrams_vectors)
-    return en_vocab
-
-
 @pytest.fixture
 def data():
     return numpy.asarray([[0.0, 1.0, 2.0], [3.0, -2.0, 4.0]], dtype="f")
@@ -125,6 +111,7 @@ def test_init_vectors_with_data(strings, data):
 def test_init_vectors_with_shape(strings):
     v = Vectors(shape=(len(strings), 3))
     assert v.shape == (len(strings), 3)
+    assert v.is_full is False


 def test_get_vector(strings, data):
@@ -180,30 +167,6 @@ def test_vectors_token_vector(tokenizer_v, vectors, text):
     assert all([a == b for a, b in zip(vectors[1][1], doc[2].vector)])


-@pytest.mark.parametrize("text", ["apple"])
-def test_vectors__ngrams_word(ngrams_vocab, ngrams_vectors, text):
-    assert list(ngrams_vocab.get_vector(text)) == list(ngrams_vectors[0][1])
-
-
-@pytest.mark.parametrize("text", ["applpie"])
-def test_vectors__ngrams_subword(ngrams_vocab, ngrams_vectors, text):
-    truth = list(ngrams_vocab.get_vector(text, 1, 6))
-    test = list(
-        [
-            (
-                ngrams_vectors[1][1][i]
-                + ngrams_vectors[2][1][i]
-                + ngrams_vectors[3][1][i]
-            )
-            / 3
-            for i in range(len(ngrams_vectors[1][1]))
-        ]
-    )
-    eps = [abs(truth[i] - test[i]) for i in range(len(truth))]
-    for i in eps:
-        assert i < 1e-6
-
-
 @pytest.mark.parametrize("text", ["apple", "orange"])
 def test_vectors_lexeme_vector(vocab, text):
     lex = vocab[text]
@@ -379,3 +342,178 @@ def test_vector_is_oov():
     assert vocab["cat"].is_oov is False
     assert vocab["dog"].is_oov is False
     assert vocab["hamster"].is_oov is True
+
+
+def test_init_vectors_unset():
+    v = Vectors(shape=(10, 10))
+    assert v.is_full is False
+    assert v.data.shape == (10, 10)
+
+    with pytest.raises(ValueError):
+        v = Vectors(shape=(10, 10), mode="floret")
+
+    v = Vectors(data=OPS.xp.zeros((10, 10)), mode="floret", hash_count=1)
+    assert v.is_full is True
+
+
+def test_vectors_clear():
+    data = OPS.asarray([[4, 2, 2, 2], [4, 2, 2, 2], [1, 1, 1, 1]], dtype="f")
+    v = Vectors(data=data, keys=["A", "B", "C"])
+    assert v.is_full is True
+    assert hash_string("A") in v
+    v.clear()
+    # no keys
+    assert v.key2row == {}
+    assert list(v) == []
+    assert v.is_full is False
+    assert "A" not in v
+    with pytest.raises(KeyError):
+        v["A"]
+
+
+def test_vectors_get_batch():
+    data = OPS.asarray([[4, 2, 2, 2], [4, 2, 2, 2], [1, 1, 1, 1]], dtype="f")
+    v = Vectors(data=data, keys=["A", "B", "C"])
+    # check with mixed int/str keys
+    words = ["C", "B", "A", v.strings["B"]]
+    rows = v.find(keys=words)
+    vecs = OPS.as_contig(v.data[rows])
+    assert_equal(OPS.to_numpy(vecs), OPS.to_numpy(v.get_batch(words)))
+
+
+@pytest.fixture()
+def floret_vectors_hashvec_str():
+    """The full hashvec table from floret with the settings:
+    bucket 10, dim 10, minn 2, maxn 3, hash count 2, hash seed 2166136261,
+    bow <, eow >"""
+    return """10 10 2 3 2 2166136261 < >
+0 -2.2611 3.9302 2.6676 -11.233 0.093715 -10.52 -9.6463 -0.11853 2.101 -0.10145
+1 -3.12 -1.7981 10.7 -6.171 4.4527 10.967 9.073 6.2056 -6.1199 -2.0402
+2 9.5689 5.6721 -8.4832 -1.2249 2.1871 -3.0264 -2.391 -5.3308 -3.2847 -4.0382
+3 3.6268 4.2759 -1.7007 1.5002 5.5266 1.8716 -12.063 0.26314 2.7645 2.4929
+4 -11.683 -7.7068 2.1102 2.214 7.2202 0.69799 3.2173 -5.382 -2.0838 5.0314
+5 -4.3024 8.0241 2.0714 -1.0174 -0.28369 1.7622 7.8797 -1.7795 6.7541 5.6703
+6 8.3574 -5.225 8.6529 8.5605 -8.9465 3.767 -5.4636 -1.4635 -0.98947 -0.58025
+7 -10.01 3.3894 -4.4487 1.1669 -11.904 6.5158 4.3681 0.79913 -6.9131 -8.687
+8 -5.4576 7.1019 -8.8259 1.7189 4.955 -8.9157 -3.8905 -0.60086 -2.1233 5.892
+9 8.0678 -4.4142 3.6236 4.5889 -2.7611 2.4455 0.67096 -4.2822 2.0875 4.6274
+"""
+
+
+@pytest.fixture()
+def floret_vectors_vec_str():
+    """The top 10 rows from floret with the settings above, to verify
+    that the spacy floret vectors are equivalent to the fasttext static
+    vectors."""
+    return """10 10
+, -5.7814 2.6918 0.57029 -3.6985 -2.7079 1.4406 1.0084 1.7463 -3.8625 -3.0565
+. 3.8016 -1.759 0.59118 3.3044 -0.72975 0.45221 -2.1412 -3.8933 -2.1238 -0.47409
+der 0.08224 2.6601 -1.173 1.1549 -0.42821 -0.097268 -2.5589 -1.609 -0.16968 0.84687
+die -2.8781 0.082576 1.9286 -0.33279 0.79488 3.36 3.5609 -0.64328 -2.4152 0.17266
+und 2.1558 1.8606 -1.382 0.45424 -0.65889 1.2706 0.5929 -2.0592 -2.6949 -1.6015
+" -1.1242 1.4588 -1.6263 1.0382 -2.7609 -0.99794 -0.83478 -1.5711 -1.2137 1.0239
+in -0.87635 2.0958 4.0018 -2.2473 -1.2429 2.3474 1.8846 0.46521 -0.506 -0.26653
+von -0.10589 1.196 1.1143 -0.40907 -1.0848 -0.054756 -2.5016 -1.0381 -0.41598 0.36982
+( 0.59263 2.1856 0.67346 1.0769 1.0701 1.2151 1.718 -3.0441 2.7291 3.719
+) 0.13812 3.3267 1.657 0.34729 -3.5459 0.72372 0.63034 -1.6145 1.2733 0.37798
+"""
+
+
+def test_floret_vectors(floret_vectors_vec_str, floret_vectors_hashvec_str):
+    nlp = English()
+    nlp_plain = English()
+    # load both vec and hashvec tables
+    with make_tempdir() as tmpdir:
+        p = tmpdir / "test.hashvec"
+        with open(p, "w") as fileh:
+            fileh.write(floret_vectors_hashvec_str)
+        convert_vectors(nlp, p, truncate=0, prune=-1, mode="floret")
+        p = tmpdir / "test.vec"
+        with open(p, "w") as fileh:
+            fileh.write(floret_vectors_vec_str)
+        convert_vectors(nlp_plain, p, truncate=0, prune=-1)
+
+    word = "der"
+    # ngrams: full padded word + padded 2-grams + padded 3-grams
+    ngrams = nlp.vocab.vectors._get_ngrams(word)
+    assert ngrams == ["<der>", "<d", "de", "er", "r>", "<de", "der", "er>"]
+    # rows: 2 rows per ngram
+    rows = OPS.xp.asarray(
+        [
+            h % nlp.vocab.vectors.data.shape[0]
+            for ngram in ngrams
+            for h in nlp.vocab.vectors._get_ngram_hashes(ngram)
+        ],
+        dtype="uint32",
+    )
+    assert_equal(
+        OPS.to_numpy(rows),
+        numpy.asarray([5, 6, 7, 5, 8, 2, 8, 9, 3, 3, 4, 6, 7, 3, 0, 2]),
+    )
+    assert len(rows) == len(ngrams) * nlp.vocab.vectors.hash_count
+    # all vectors are equivalent for plain static table vs. hash ngrams
+    for word in nlp_plain.vocab.vectors:
+        word = nlp_plain.vocab.strings.as_string(word)
+        assert_almost_equal(
+            nlp.vocab[word].vector, nlp_plain.vocab[word].vector, decimal=3
+        )
+
+    # every word has a vector
+    assert nlp.vocab[word * 5].has_vector
+
+    # check that single and batched vector lookups are identical
+    words = [s for s in nlp_plain.vocab.vectors]
+    single_vecs = OPS.to_numpy(OPS.asarray([nlp.vocab[word].vector for word in words]))
+    batch_vecs = OPS.to_numpy(nlp.vocab.vectors.get_batch(words))
+    assert_equal(single_vecs, batch_vecs)
+
+    # an empty key returns 0s
+    assert_equal(
+        OPS.to_numpy(nlp.vocab[""].vector),
+        numpy.zeros((nlp.vocab.vectors.data.shape[0],)),
+    )
+    # an empty batch returns 0s
+    assert_equal(
+        OPS.to_numpy(nlp.vocab.vectors.get_batch([""])),
+        numpy.zeros((1, nlp.vocab.vectors.data.shape[0])),
+    )
+    # an empty key within a batch returns 0s
+    assert_equal(
+        OPS.to_numpy(nlp.vocab.vectors.get_batch(["a", "", "b"])[1]),
+        numpy.zeros((nlp.vocab.vectors.data.shape[0],)),
+    )
+
+    # the loaded ngram vector table cannot be modified
+    # except for clear: warning, then return without modifications
+    vector = list(range(nlp.vocab.vectors.shape[1]))
+    orig_bytes = nlp.vocab.vectors.to_bytes(exclude=["strings"])
+    with pytest.warns(UserWarning):
+        nlp.vocab.set_vector("the", vector)
+    assert orig_bytes == nlp.vocab.vectors.to_bytes(exclude=["strings"])
+    with pytest.warns(UserWarning):
+        nlp.vocab[word].vector = vector
+    assert orig_bytes == nlp.vocab.vectors.to_bytes(exclude=["strings"])
+    with pytest.warns(UserWarning):
+        nlp.vocab.vectors.add("the", row=6)
+    assert orig_bytes == nlp.vocab.vectors.to_bytes(exclude=["strings"])
+    with pytest.warns(UserWarning):
+        nlp.vocab.vectors.resize(shape=(100, 10))
+    assert orig_bytes == nlp.vocab.vectors.to_bytes(exclude=["strings"])
+    with pytest.raises(ValueError):
+        nlp.vocab.vectors.clear()
+
+    # data and settings are serialized correctly
+    with make_tempdir() as d:
+        nlp.vocab.to_disk(d)
+        vocab_r = Vocab()
+        vocab_r.from_disk(d)
+        assert nlp.vocab.vectors.to_bytes() == vocab_r.vectors.to_bytes()
+        assert_equal(
+            OPS.to_numpy(nlp.vocab.vectors.data), OPS.to_numpy(vocab_r.vectors.data)
+        )
+        assert_equal(nlp.vocab.vectors._get_cfg(), vocab_r.vectors._get_cfg())
+        assert_almost_equal(
+            OPS.to_numpy(nlp.vocab[word].vector),
+            OPS.to_numpy(vocab_r[word].vector),
+            decimal=6,
+        )
@@ -138,8 +138,8 @@ class Doc:
     def count_by(
         self, attr_id: int, exclude: Optional[Any] = ..., counts: Optional[Any] = ...
     ) -> Dict[Any, int]: ...
-    def from_array(self, attrs: List[int], array: Ints2d) -> Doc: ...
-    def to_array(self, py_attr_ids: List[int]) -> numpy.ndarray: ...
+    def from_array(self, attrs: Union[int, str, List[Union[int, str]]], array: Ints2d) -> Doc: ...
+    def to_array(self, py_attr_ids: Union[int, str, List[Union[int, str]]]) -> numpy.ndarray: ...
     @staticmethod
     def from_docs(
         docs: List[Doc],
@@ -534,7 +534,13 @@ cdef class Doc:
             kb_id = self.vocab.strings.add(kb_id)
         alignment_modes = ("strict", "contract", "expand")
         if alignment_mode not in alignment_modes:
-            raise ValueError(Errors.E202.format(mode=alignment_mode, modes=", ".join(alignment_modes)))
+            raise ValueError(
+                Errors.E202.format(
+                    name="alignment",
+                    mode=alignment_mode,
+                    modes=", ".join(alignment_modes),
+                )
+            )
         cdef int start = token_by_char(self.c, self.length, start_idx)
         if start < 0 or (alignment_mode == "strict" and start_idx != self[start].idx):
             return None
@@ -13,7 +13,7 @@ import warnings

 from .pretrain import get_tok2vec_ref
 from ..lookups import Lookups
-from ..vectors import Vectors
+from ..vectors import Vectors, Mode as VectorsMode
 from ..errors import Errors, Warnings
 from ..schemas import ConfigSchemaTraining
 from ..util import registry, load_model_from_config, resolve_dot_names, logger
@@ -160,7 +165,13 @@ def load_vectors_into_model(
         err = ConfigValidationError.from_error(e, title=title, desc=desc)
         raise err from None

-    if len(vectors_nlp.vocab.vectors.keys()) == 0:
+    if (
+        len(vectors_nlp.vocab.vectors.keys()) == 0
+        and vectors_nlp.vocab.vectors.mode != VectorsMode.floret
+    ) or (
+        vectors_nlp.vocab.vectors.data.shape[0] == 0
+        and vectors_nlp.vocab.vectors.mode == VectorsMode.floret
+    ):
         logger.warning(Warnings.W112.format(name=name))

     for lex in nlp.vocab:
@@ -197,39 +203,78 @@ def convert_vectors(
     truncate: int,
     prune: int,
     name: Optional[str] = None,
+    mode: str = VectorsMode.default,
 ) -> None:
     vectors_loc = ensure_path(vectors_loc)
     if vectors_loc and vectors_loc.parts[-1].endswith(".npz"):
-        nlp.vocab.vectors = Vectors(data=numpy.load(vectors_loc.open("rb")))
+        nlp.vocab.vectors = Vectors(
+            strings=nlp.vocab.strings, data=numpy.load(vectors_loc.open("rb"))
+        )
         for lex in nlp.vocab:
             if lex.rank and lex.rank != OOV_RANK:
                 nlp.vocab.vectors.add(lex.orth, row=lex.rank)  # type: ignore[attr-defined]
     else:
         if vectors_loc:
             logger.info(f"Reading vectors from {vectors_loc}")
-            vectors_data, vector_keys = read_vectors(vectors_loc, truncate)
+            vectors_data, vector_keys, floret_settings = read_vectors(
+                vectors_loc,
+                truncate,
+                mode=mode,
+            )
             logger.info(f"Loaded vectors from {vectors_loc}")
         else:
             vectors_data, vector_keys = (None, None)
-        if vector_keys is not None:
+        if vector_keys is not None and mode != VectorsMode.floret:
             for word in vector_keys:
                 if word not in nlp.vocab:
                     nlp.vocab[word]
         if vectors_data is not None:
-            nlp.vocab.vectors = Vectors(data=vectors_data, keys=vector_keys)
+            if mode == VectorsMode.floret:
+                nlp.vocab.vectors = Vectors(
+                    strings=nlp.vocab.strings,
+                    data=vectors_data,
+                    **floret_settings,
+                )
+            else:
+                nlp.vocab.vectors = Vectors(
+                    strings=nlp.vocab.strings, data=vectors_data, keys=vector_keys
+                )
     if name is None:
         # TODO: Is this correct? Does this matter?
         nlp.vocab.vectors.name = f"{nlp.meta['lang']}_{nlp.meta['name']}.vectors"
     else:
         nlp.vocab.vectors.name = name
     nlp.meta["vectors"]["name"] = nlp.vocab.vectors.name
-    if prune >= 1:
+    if prune >= 1 and mode != VectorsMode.floret:
         nlp.vocab.prune_vectors(prune)


-def read_vectors(vectors_loc: Path, truncate_vectors: int):
+def read_vectors(
+    vectors_loc: Path, truncate_vectors: int, *, mode: str = VectorsMode.default
+):
     f = ensure_shape(vectors_loc)
-    shape = tuple(int(size) for size in next(f).split())
+    header_parts = next(f).split()
+    shape = tuple(int(size) for size in header_parts[:2])
+    floret_settings = {}
+    if mode == VectorsMode.floret:
+        if len(header_parts) != 8:
+            raise ValueError(
+                "Invalid header for floret vectors. "
+                "Expected: bucket dim minn maxn hash_count hash_seed BOW EOW"
+            )
+        floret_settings = {
+            "mode": "floret",
+            "minn": int(header_parts[2]),
+            "maxn": int(header_parts[3]),
+            "hash_count": int(header_parts[4]),
+            "hash_seed": int(header_parts[5]),
+            "bow": header_parts[6],
+            "eow": header_parts[7],
+        }
+        if truncate_vectors >= 1:
+            raise ValueError(Errors.E860)
+    else:
+        assert len(header_parts) == 2
+        if truncate_vectors >= 1:
+            shape = (truncate_vectors, shape[1])
     vectors_data = numpy.zeros(shape=shape, dtype="f")
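As a quick aside, the floret header parsed by `read_vectors` above is a single whitespace-separated line. A standalone sketch of that parsing, using the sample header from the test fixture earlier in this commit:

```python
# Standalone illustration of the floret header handled by read_vectors above.
# The sample line comes from this commit's test fixture.
header = "10 10 2 3 2 2166136261 < >"
parts = header.split()
shape = (int(parts[0]), int(parts[1]))  # (bucket, dim) == (10, 10)
floret_settings = {
    "mode": "floret",
    "minn": int(parts[2]),        # shortest character ngram
    "maxn": int(parts[3]),        # longest character ngram
    "hash_count": int(parts[4]),  # hash values (rows) per ngram, 1-4
    "hash_seed": int(parts[5]),
    "bow": parts[6],              # begin-of-word marker
    "eow": parts[7],              # end-of-word marker
}
print(shape, floret_settings)
```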
@@ -244,7 +289,7 @@ def read_vectors(vectors_loc: Path, truncate_vectors: int):
         vectors_keys.append(word)
         if i == truncate_vectors - 1:
             break
-    return vectors_data, vectors_keys
+    return vectors_data, vectors_keys, floret_settings


 def open_file(loc: Union[str, Path]) -> IO:
@@ -271,7 +316,7 @@ def ensure_shape(vectors_loc):
     lines = open_file(vectors_loc)
     first_line = next(lines)
     try:
-        shape = tuple(int(size) for size in first_line.split())
+        shape = tuple(int(size) for size in first_line.split()[:2])
     except ValueError:
         shape = None
     if shape is not None:
@@ -1,16 +1,23 @@
 cimport numpy as np
+from libc.stdint cimport uint32_t
 from cython.operator cimport dereference as deref
 from libcpp.set cimport set as cppset
+from murmurhash.mrmr cimport hash128_x64

 import functools
 import numpy
+from typing import cast
+import warnings
+from enum import Enum
 import srsly
 from thinc.api import get_array_module, get_current_ops
+from thinc.backends import get_array_ops
+from thinc.types import Floats2d

+from .strings cimport StringStore
+
 from .strings import get_string_id
-from .errors import Errors
+from .errors import Errors, Warnings
 from . import util

@@ -18,18 +25,13 @@ def unpickle_vectors(bytes_data):
     return Vectors().from_bytes(bytes_data)


-class GlobalRegistry:
-    """Global store of vectors, to avoid repeatedly loading the data."""
-    data = {}
+class Mode(str, Enum):
+    default = "default"
+    floret = "floret"

     @classmethod
-    def register(cls, name, data):
-        cls.data[name] = data
-        return functools.partial(cls.get, name)
-
-    @classmethod
-    def get(cls, name):
-        return cls.data[name]
+    def values(cls):
+        return list(cls.__members__.keys())


 cdef class Vectors:
@@ -37,45 +39,93 @@ cdef class Vectors:

     Vectors data is kept in the vectors.data attribute, which should be an
     instance of numpy.ndarray (for CPU vectors) or cupy.ndarray
-    (for GPU vectors). `vectors.key2row` is a dictionary mapping word hashes to
-    rows in the vectors.data table.
+    (for GPU vectors).

-    Multiple keys can be mapped to the same vector, and not all of the rows in
-    the table need to be assigned - so len(list(vectors.keys())) may be
-    greater or smaller than vectors.shape[0].
+    In the default mode, `vectors.key2row` is a dictionary mapping word hashes
+    to rows in the vectors.data table. Multiple keys can be mapped to the same
+    vector, and not all of the rows in the table need to be assigned - so
+    len(list(vectors.keys())) may be greater or smaller than vectors.shape[0].
+
+    In floret mode, the floret settings (minn, maxn, etc.) are used to
+    calculate the vector from the rows corresponding to the key's ngrams.

     DOCS: https://spacy.io/api/vectors
     """
+    cdef public object strings
     cdef public object name
+    cdef readonly object mode
     cdef public object data
     cdef public object key2row
     cdef cppset[int] _unset
+    cdef readonly uint32_t minn
+    cdef readonly uint32_t maxn
+    cdef readonly uint32_t hash_count
+    cdef readonly uint32_t hash_seed
+    cdef readonly unicode bow
+    cdef readonly unicode eow

-    def __init__(self, *, shape=None, data=None, keys=None, name=None):
+    def __init__(self, *, strings=None, shape=None, data=None, keys=None, name=None, mode=Mode.default, minn=0, maxn=0, hash_count=1, hash_seed=0, bow="<", eow=">"):
         """Create a new vector store.

+        strings (StringStore): The string store.
         shape (tuple): Size of the table, as (# entries, # columns)
         data (numpy.ndarray or cupy.ndarray): The vector data.
         keys (iterable): A sequence of keys, aligned with the data.
         name (str): A name to identify the vectors table.
+        mode (str): Vectors mode: "default" or "floret" (default: "default").
+        minn (int): The floret char ngram minn (default: 0).
+        maxn (int): The floret char ngram maxn (default: 0).
+        hash_count (int): The floret hash count (1-4, default: 1).
+        hash_seed (int): The floret hash seed (default: 0).
+        bow (str): The floret BOW string (default: "<").
+        eow (str): The floret EOW string (default: ">").

         DOCS: https://spacy.io/api/vectors#init
         """
+        self.strings = strings
+        if self.strings is None:
+            self.strings = StringStore()
         self.name = name
-        if data is None:
-            if shape is None:
-                shape = (0,0)
-            ops = get_current_ops()
-            data = ops.xp.zeros(shape, dtype="f")
-        self.data = data
-        self.key2row = {}
-        if self.data is not None:
-            self._unset = cppset[int]({i for i in range(self.data.shape[0])})
-        else:
-            self._unset = cppset[int]()
-        if keys is not None:
-            for i, key in enumerate(keys):
-                self.add(key, row=i)
+        if mode not in Mode.values():
+            raise ValueError(
+                Errors.E202.format(
+                    name="vectors",
+                    mode=mode,
+                    modes=str(Mode.values())
+                )
+            )
+        self.mode = Mode(mode).value
+        self.key2row = {}
+        self.minn = minn
+        self.maxn = maxn
+        self.hash_count = hash_count
+        self.hash_seed = hash_seed
+        self.bow = bow
+        self.eow = eow
+        if self.mode == Mode.default:
+            if data is None:
+                if shape is None:
+                    shape = (0,0)
+                ops = get_current_ops()
+                data = ops.xp.zeros(shape, dtype="f")
+                self._unset = cppset[int]({i for i in range(data.shape[0])})
+            else:
+                self._unset = cppset[int]()
+            self.data = data
+            if keys is not None:
+                for i, key in enumerate(keys):
+                    self.add(key, row=i)
+        elif self.mode == Mode.floret:
+            if maxn < minn:
+                raise ValueError(Errors.E863)
+            if hash_count < 1 or hash_count >= 5:
+                raise ValueError(Errors.E862)
+            if data is None:
+                raise ValueError(Errors.E864)
+            if keys is not None:
+                raise ValueError(Errors.E861)
+            self.data = data
+            self._unset = cppset[int]()

     @property
     def shape(self):
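A sketch based on the `__init__` signature above, showing the two ways to create a `Vectors` table; shapes and settings are arbitrary toy values.

```python
# Sketch of the two construction paths defined in __init__ above.
from spacy.vectors import Vectors
from thinc.api import get_current_ops

ops = get_current_ops()

# default mode: keys map to rows; an empty table can be grown later
v_default = Vectors(shape=(2, 4))
v_default.add("apple", vector=ops.xp.asarray([1.0, 2.0, 3.0, 4.0], dtype="f"))

# floret mode: the full data and ngram/hash settings are required up front,
# and the table cannot be modified afterwards
v_floret = Vectors(
    data=ops.xp.zeros((10, 4), dtype="f"),
    mode="floret",
    minn=2,
    maxn=3,
    hash_count=2,
)
```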
@@ -106,6 +156,8 @@ cdef class Vectors:

         DOCS: https://spacy.io/api/vectors#is_full
         """
+        if self.mode == Mode.floret:
+            return True
         return self._unset.size() == 0

     @property
@@ -113,7 +165,8 @@ cdef class Vectors:
         """Get the number of keys in the table. Note that this is the number
         of all keys, not just unique vectors.

-        RETURNS (int): The number of keys in the table.
+        RETURNS (int): The number of keys in the table for default vectors.
+        For floret vectors, return -1.

         DOCS: https://spacy.io/api/vectors#n_keys
         """
@@ -125,25 +178,33 @@ cdef class Vectors:
     def __getitem__(self, key):
         """Get a vector by key. If the key is not found, a KeyError is raised.

-        key (int): The key to get the vector for.
+        key (str/int): The key to get the vector for.
         RETURNS (ndarray): The vector for the key.

         DOCS: https://spacy.io/api/vectors#getitem
         """
-        i = self.key2row[key]
-        if i is None:
-            raise KeyError(Errors.E058.format(key=key))
-        else:
-            return self.data[i]
+        if self.mode == Mode.default:
+            i = self.key2row.get(get_string_id(key), None)
+            if i is None:
+                raise KeyError(Errors.E058.format(key=key))
+            else:
+                return self.data[i]
+        elif self.mode == Mode.floret:
+            return self.get_batch([key])[0]
+        raise KeyError(Errors.E058.format(key=key))

     def __setitem__(self, key, vector):
         """Set a vector for the given key.

-        key (int): The key to set the vector for.
+        key (str/int): The key to set the vector for.
         vector (ndarray): The vector to set.

         DOCS: https://spacy.io/api/vectors#setitem
         """
+        if self.mode == Mode.floret:
+            warnings.warn(Warnings.W115.format(method="Vectors.__setitem__"))
+            return
+        key = get_string_id(key)
         i = self.key2row[key]
         self.data[i] = vector
         if self._unset.count(i):
@@ -175,6 +236,9 @@ cdef class Vectors:

         DOCS: https://spacy.io/api/vectors#contains
         """
-        return key in self.key2row
+        if self.mode == Mode.floret:
+            return True
+        else:
+            return key in self.key2row

     def resize(self, shape, inplace=False):
@@ -192,6 +256,9 @@ cdef class Vectors:

         DOCS: https://spacy.io/api/vectors#resize
         """
+        if self.mode == Mode.floret:
+            warnings.warn(Warnings.W115.format(method="Vectors.resize"))
+            return -1
         xp = get_array_module(self.data)
         if inplace:
             if shape[1] != self.data.shape[1]:
@@ -244,16 +311,23 @@ cdef class Vectors:
     def find(self, *, key=None, keys=None, row=None, rows=None):
         """Look up one or more keys by row, or vice versa.

-        key (str / int): Find the row that the given key points to.
+        key (Union[int, str]): Find the row that the given key points to.
             Returns int, -1 if missing.
-        keys (iterable): Find rows that the keys point to.
+        keys (Iterable[Union[int, str]]): Find rows that the keys point to.
             Returns ndarray.
         row (int): Find the first key that points to the row.
             Returns int.
-        rows (iterable): Find the keys that point to the rows.
+        rows (Iterable[int]): Find the keys that point to the rows.
             Returns ndarray.
         RETURNS: The requested key, keys, row or rows.
         """
+        if self.mode == Mode.floret:
+            raise ValueError(
+                Errors.E858.format(
+                    mode=self.mode,
+                    alternative="Use Vectors[key] instead.",
+                )
+            )
         if sum(arg is None for arg in (key, keys, row, rows)) != 3:
             bad_kwargs = {"key": key, "keys": keys, "row": row, "rows": rows}
             raise ValueError(Errors.E059.format(kwargs=bad_kwargs))
@@ -273,6 +347,67 @@ cdef class Vectors:
         results = [row2key[row] for row in rows]
         return xp.asarray(results, dtype="uint64")

+    def _get_ngram_hashes(self, unicode s):
+        """Calculate up to 4 32-bit hash values with MurmurHash3_x64_128 using
+        the floret hash settings.
+        key (str): The string key.
+        RETURNS: A list of the integer hashes.
+        """
+        cdef uint32_t[4] out
+        chars = s.encode("utf8")
+        cdef char* utf8_string = chars
+        hash128_x64(utf8_string, len(chars), self.hash_seed, &out)
+        rows = [out[i] for i in range(min(self.hash_count, 4))]
+        return rows
+
+    def _get_ngrams(self, unicode key):
+        """Get all padded ngram strings using the ngram settings.
+        key (str): The string key.
+        RETURNS: A list of the ngram strings for the padded key.
+        """
+        key = self.bow + key + self.eow
+        ngrams = [key] + [
+            key[start:start+ngram_size]
+            for ngram_size in range(self.minn, self.maxn + 1)
+            for start in range(0, len(key) - ngram_size + 1)
+        ]
+        return ngrams
+
+    def get_batch(self, keys):
+        """Get the vectors for the provided keys efficiently as a batch.
+        keys (Iterable[Union[int, str]]): The keys.
+        RETURNS: The requested vectors from the vector table.
+        """
+        ops = get_array_ops(self.data)
+        if self.mode == Mode.default:
+            rows = self.find(keys=keys)
+            vecs = self.data[rows]
+        elif self.mode == Mode.floret:
+            keys = [self.strings.as_string(key) for key in keys]
+            if sum(len(key) for key in keys) == 0:
+                return ops.xp.zeros((len(keys), self.data.shape[1]))
+            unique_keys = tuple(set(keys))
+            row_index = {key: i for i, key in enumerate(unique_keys)}
+            rows = [row_index[key] for key in keys]
+            indices = []
+            lengths = []
+            for key in unique_keys:
+                if key == "":
+                    ngram_rows = []
+                else:
+                    ngram_rows = [
+                        h % self.data.shape[0]
+                        for ngram in self._get_ngrams(key)
+                        for h in self._get_ngram_hashes(ngram)
+                    ]
+                indices.extend(ngram_rows)
+                lengths.append(len(ngram_rows))
+            indices = ops.asarray(indices, dtype="int32")
+            lengths = ops.asarray(lengths, dtype="int32")
+            vecs = ops.reduce_mean(cast(Floats2d, self.data[indices]), lengths)
+            vecs = vecs[rows]
+        return ops.as_contig(vecs)
+
     def add(self, key, *, vector=None, row=None):
         """Add a key to the table. Keys can be mapped to an existing vector
         by setting `row`, or a new vector can be added.
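A pure-Python sketch of the floret lookup implemented above, useful for seeing the ngram generation and hashing outside Cython. It assumes the third-party `mmh3` package as a stand-in for spaCy's murmurhash binding; whether `mmh3`'s byte order and seed handling match spaCy's `hash128_x64` exactly is an assumption of this sketch.

```python
# Hedged pure-Python sketch of the floret lookup above (mmh3 is assumed to
# behave like spaCy's MurmurHash3_x64_128 binding; verify before relying on it).
import struct

import mmh3
import numpy as np

def ngrams(word, minn=2, maxn=3, bow="<", eow=">"):
    # full padded word first, then all padded ngrams from minn to maxn
    key = bow + word + eow
    return [key] + [
        key[start:start + n]
        for n in range(minn, maxn + 1)
        for start in range(len(key) - n + 1)
    ]

def ngram_rows(ngram, n_rows, hash_count=2, seed=0):
    # MurmurHash3_x64_128 yields four 32-bit values; keep hash_count of them.
    # (floret's test fixture uses hash seed 2166136261.)
    digest = mmh3.hash_bytes(ngram.encode("utf8"), seed)
    return [h % n_rows for h in struct.unpack("<4I", digest)[:hash_count]]

table = np.zeros((10, 4), dtype="f")  # toy (bucket, dim) table
rows = [r for ng in ngrams("der") for r in ngram_rows(ng, table.shape[0])]
vector = table[rows].mean(axis=0)     # reduce_mean over all ngram rows
```

For `"der"` with `minn=2, maxn=3`, `ngrams` returns exactly the list asserted in the tests: `["<der>", "<d", "de", "er", "r>", "<de", "der", "er>"]`.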
@@ -284,6 +419,9 @@ cdef class Vectors:

         DOCS: https://spacy.io/api/vectors#add
         """
+        if self.mode == Mode.floret:
+            warnings.warn(Warnings.W115.format(method="Vectors.add"))
+            return -1
         # use int for all keys and rows in key2row for more efficient access
         # and serialization
         key = int(get_string_id(key))
@@ -324,6 +462,11 @@ cdef class Vectors:
         RETURNS (tuple): The most similar entries as a `(keys, best_rows, scores)`
             tuple.
         """
+        if self.mode == Mode.floret:
+            raise ValueError(Errors.E858.format(
+                mode=self.mode,
+                alternative="",
+            ))
         xp = get_array_module(self.data)
         filled = sorted(list({row for row in self.key2row.values()}))
         if len(filled) < n:
@@ -368,7 +511,32 @@ cdef class Vectors:
             for i in range(len(queries)) ], dtype="uint64")
         return (keys, best_rows, scores)

-    def to_disk(self, path, **kwargs):
+    def _get_cfg(self):
+        if self.mode == Mode.default:
+            return {
+                "mode": Mode(self.mode).value,
+            }
+        elif self.mode == Mode.floret:
+            return {
+                "mode": Mode(self.mode).value,
+                "minn": self.minn,
+                "maxn": self.maxn,
+                "hash_count": self.hash_count,
+                "hash_seed": self.hash_seed,
+                "bow": self.bow,
+                "eow": self.eow,
+            }
+
+    def _set_cfg(self, cfg):
+        self.mode = Mode(cfg.get("mode", Mode.default)).value
+        self.minn = cfg.get("minn", 0)
+        self.maxn = cfg.get("maxn", 0)
+        self.hash_count = cfg.get("hash_count", 0)
+        self.hash_seed = cfg.get("hash_seed", 0)
+        self.bow = cfg.get("bow", "<")
+        self.eow = cfg.get("eow", ">")
+
+    def to_disk(self, path, *, exclude=tuple()):
         """Save the current state to a directory.

         path (str / Path): A path to a directory, which will be created if
@@ -390,12 +558,14 @@ cdef class Vectors:
             save_array(self.data, _file)

         serializers = {
+            "strings": lambda p: self.strings.to_disk(p.with_suffix(".json")),
             "vectors": lambda p: save_vectors(p),
-            "key2row": lambda p: srsly.write_msgpack(p, self.key2row)
+            "key2row": lambda p: srsly.write_msgpack(p, self.key2row),
+            "vectors.cfg": lambda p: srsly.write_json(p, self._get_cfg()),
         }
-        return util.to_disk(path, serializers, [])
+        return util.to_disk(path, serializers, exclude)

-    def from_disk(self, path, **kwargs):
+    def from_disk(self, path, *, exclude=tuple()):
         """Loads state from a directory. Modifies the object in place and
         returns it.

@@ -422,17 +592,23 @@ cdef class Vectors:
             if path.exists():
                 self.data = ops.xp.load(str(path))

+        def load_settings(path):
+            if path.exists():
+                self._set_cfg(srsly.read_json(path))
+
         serializers = {
+            "strings": lambda p: self.strings.from_disk(p.with_suffix(".json")),
             "vectors": load_vectors,
             "keys": load_keys,
             "key2row": load_key2row,
+            "vectors.cfg": load_settings,
         }

-        util.from_disk(path, serializers, [])
+        util.from_disk(path, serializers, exclude)
         self._sync_unset()
         return self

-    def to_bytes(self, **kwargs):
+    def to_bytes(self, *, exclude=tuple()):
         """Serialize the current state to a binary string.

         exclude (list): String names of serialization fields to exclude.
@@ -447,12 +623,14 @@ cdef class Vectors:
             return srsly.msgpack_dumps(self.data)

         serializers = {
+            "strings": lambda: self.strings.to_bytes(),
             "key2row": lambda: srsly.msgpack_dumps(self.key2row),
-            "vectors": serialize_weights
+            "vectors": serialize_weights,
+            "vectors.cfg": lambda: srsly.json_dumps(self._get_cfg()),
         }
-        return util.to_bytes(serializers, [])
+        return util.to_bytes(serializers, exclude)

-    def from_bytes(self, data, **kwargs):
+    def from_bytes(self, data, *, exclude=tuple()):
         """Load state from a binary string.

         data (bytes): The data to load from.
@@ -469,13 +647,25 @@ cdef class Vectors:
             self.data = xp.asarray(srsly.msgpack_loads(b))

         deserializers = {
+            "strings": lambda b: self.strings.from_bytes(b),
             "key2row": lambda b: self.key2row.update(srsly.msgpack_loads(b)),
-            "vectors": deserialize_weights
+            "vectors": deserialize_weights,
+            "vectors.cfg": lambda b: self._set_cfg(srsly.json_loads(b))
         }
-        util.from_bytes(data, deserializers, [])
+        util.from_bytes(data, deserializers, exclude)
+        self._sync_unset()
         return self

     def clear(self):
         """Clear all entries in the vector table.

         DOCS: https://spacy.io/api/vectors#clear
         """
+        if self.mode == Mode.floret:
+            raise ValueError(Errors.E859)
         self.key2row = {}
+        self._sync_unset()
+
+    def _sync_unset(self):
+        filled = {row for row in self.key2row.values()}
+        self._unset = cppset[int]({row for row in range(self.data.shape[0]) if row not in filled})
@@ -27,7 +27,7 @@ cdef class Vocab:
     cdef Pool mem
     cdef readonly StringStore strings
     cdef public Morphology morphology
-    cdef public object vectors
+    cdef public object _vectors
     cdef public object _lookups
     cdef public object writing_system
     cdef public object get_noun_chunks
@@ -14,7 +14,7 @@ from .attrs cimport LANG, ORTH
 from .compat import copy_reg
 from .errors import Errors
 from .attrs import intify_attrs, NORM, IS_STOP
-from .vectors import Vectors
+from .vectors import Vectors, Mode as VectorsMode
 from .util import registry
 from .lookups import Lookups
 from . import util
@@ -77,11 +77,21 @@ cdef class Vocab:
             _ = self[string]
         self.lex_attr_getters = lex_attr_getters
         self.morphology = Morphology(self.strings)
-        self.vectors = Vectors(name=vectors_name)
+        self.vectors = Vectors(strings=self.strings, name=vectors_name)
         self.lookups = lookups
         self.writing_system = writing_system
         self.get_noun_chunks = get_noun_chunks

+    property vectors:
+        def __get__(self):
+            return self._vectors
+
+        def __set__(self, vectors):
+            for s in vectors.strings:
+                self.strings.add(s)
+            self._vectors = vectors
+            self._vectors.strings = self.strings
+
     @property
     def lang(self):
         langfunc = None
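A sketch of what the new `Vocab.vectors` property enables: assigning a `Vectors` object merges its strings into the vocab's shared `StringStore` and then shares that store, so both objects resolve the same strings. Toy shapes below.

```python
# Hedged sketch of the Vocab.vectors setter defined above.
from spacy.vectors import Vectors
from spacy.vocab import Vocab
from thinc.api import get_current_ops

ops = get_current_ops()
vocab = Vocab()
vectors = Vectors(data=ops.xp.zeros((2, 4), dtype="f"), keys=["apple", "orange"])
vocab.vectors = vectors  # __set__ copies vector strings, then shares the store
assert vocab.vectors.strings is vocab.strings
```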
@@ -282,10 +292,10 @@ cdef class Vocab:
         if width is not None and shape is not None:
             raise ValueError(Errors.E065.format(width=width, shape=shape))
         elif shape is not None:
-            self.vectors = Vectors(shape=shape)
+            self.vectors = Vectors(strings=self.strings, shape=shape)
         else:
             width = width if width is not None else self.vectors.data.shape[1]
-            self.vectors = Vectors(shape=(self.vectors.shape[0], width))
+            self.vectors = Vectors(strings=self.strings, shape=(self.vectors.shape[0], width))

     def prune_vectors(self, nr_row, batch_size=1024):
         """Reduce the current vector table to `nr_row` unique entries. Words
@@ -314,6 +324,8 @@ cdef class Vocab:

         DOCS: https://spacy.io/api/vocab#prune_vectors
         """
+        if self.vectors.mode != VectorsMode.default:
+            raise ValueError(Errors.E866)
         ops = get_current_ops()
         xp = get_array_module(self.vectors.data)
         # Make sure all vectors are in the vocab
@@ -328,7 +340,7 @@ cdef class Vocab:
         keys = xp.asarray([key for (prob, i, key) in priority], dtype="uint64")
         keep = xp.ascontiguousarray(self.vectors.data[indices[:nr_row]])
         toss = xp.ascontiguousarray(self.vectors.data[indices[nr_row:]])
-        self.vectors = Vectors(data=keep, keys=keys[:nr_row], name=self.vectors.name)
+        self.vectors = Vectors(strings=self.strings, data=keep, keys=keys[:nr_row], name=self.vectors.name)
         syn_keys, syn_rows, scores = self.vectors.most_similar(toss, batch_size=batch_size)
         syn_keys = ops.to_numpy(syn_keys)
         remap = {}
@@ -340,19 +352,12 @@ cdef class Vocab:
             remap[word] = (synonym, score)
         return remap

-    def get_vector(self, orth, minn=None, maxn=None):
+    def get_vector(self, orth):
         """Retrieve a vector for a word in the vocabulary. Words can be looked
         up by string or int ID. If no vectors data is loaded, ValueError is
         raised.

-        If `minn` is defined, then the resulting vector uses Fasttext's
-        subword features by average over ngrams of `orth`.
-
-        orth (int / str): The hash value of a word, or its unicode string.
-        minn (int): Minimum n-gram length used for Fasttext's ngram computation.
-            Defaults to the length of `orth`.
-        maxn (int): Maximum n-gram length used for Fasttext's ngram computation.
-            Defaults to the length of `orth`.
+        orth (int / unicode): The hash value of a word, or its unicode string.
         RETURNS (numpy.ndarray or cupy.ndarray): A word vector. Size
         and shape determined by the `vocab.vectors` instance. Usually, a
         numpy ndarray of shape (300,) and dtype float32.
@@ -361,40 +366,10 @@ cdef class Vocab:
         """
         if isinstance(orth, str):
             orth = self.strings.add(orth)
-        word = self[orth].orth_
-        if orth in self.vectors.key2row:
+        if self.has_vector(orth):
             return self.vectors[orth]
         xp = get_array_module(self.vectors.data)
         vectors = xp.zeros((self.vectors_length,), dtype="f")
-        if minn is None:
-            return vectors
-        # Fasttext's ngram computation taken from
-        # https://github.com/facebookresearch/fastText
-        # Assign default ngram limit to maxn which is the length of the word.
-        if maxn is None:
-            maxn = len(word)
-        ngrams_size = 0;
-        for i in range(len(word)):
-            ngram = ""
-            if (word[i] and 0xC0) == 0x80:
-                continue
-            n = 1
-            j = i
-            while (j < len(word) and n <= maxn):
-                if n > maxn:
-                    break
-                ngram += word[j]
-                j = j + 1
-                while (j < len(word) and (word[j] and 0xC0) == 0x80):
-                    ngram += word[j]
-                    j = j + 1
-                if (n >= minn and not (n == 1 and (i == 0 or j == len(word)))):
-                    if self.strings[ngram] in self.vectors.key2row:
-                        vectors = xp.add(self.vectors[self.strings[ngram]], vectors)
-                        ngrams_size += 1
-                n = n + 1
-        if ngrams_size > 0:
-            vectors = vectors * (1.0/ngrams_size)
         return vectors

     def set_vector(self, orth, vector):
@@ -417,6 +392,7 @@ cdef class Vocab:
             self.vectors.resize((new_rows, width))
         lex = self[orth]  # Add word to vocab if necessary
         row = self.vectors.add(orth, vector=vector)
-        lex.rank = row
+        if row >= 0:
+            lex.rank = row

     def has_vector(self, orth):
@@ -461,7 +437,7 @@ cdef class Vocab:
         if "strings" not in exclude:
             self.strings.to_disk(path / "strings.json")
         if "vectors" not in "exclude":
-            self.vectors.to_disk(path)
+            self.vectors.to_disk(path, exclude=["strings"])
         if "lookups" not in "exclude":
             self.lookups.to_disk(path)

@@ -504,7 +480,7 @@ cdef class Vocab:
             if self.vectors is None:
                 return None
             else:
-                return self.vectors.to_bytes()
+                return self.vectors.to_bytes(exclude=["strings"])

         getters = {
             "strings": lambda: self.strings.to_bytes(),
@@ -526,7 +502,7 @@ cdef class Vocab:
             if self.vectors is None:
                 return None
             else:
-                return self.vectors.from_bytes(b)
+                return self.vectors.from_bytes(b, exclude=["strings"])

         setters = {
             "strings": lambda b: self.strings.from_bytes(b),
@@ -208,6 +208,7 @@ $ python -m spacy init vectors [lang] [vectors_loc] [output_dir] [--prune] [--tr
 | `output_dir`       | Pipeline output directory. Will be created if it doesn't exist. ~~Path (positional)~~ |
 | `--truncate`, `-t` | Number of vectors to truncate to when reading in vectors file. Defaults to `0` for no truncation. ~~int (option)~~ |
 | `--prune`, `-p`    | Number of vectors to prune the vocabulary to. Defaults to `-1` for no pruning. ~~int (option)~~ |
+| `--mode`, `-m`     | Vectors mode: `default` or [`floret`](https://github.com/explosion/floret). Defaults to `default`. ~~Optional[str] \(option)~~ |
 | `--name`, `-n`     | Name to assign to the word vectors in the `meta.json`, e.g. `en_core_web_md.vectors`. ~~Optional[str] \(option)~~ |
 | `--verbose`, `-V`  | Print additional information and explanations. ~~bool (flag)~~ |
 | `--help`, `-h`     | Show help message and available arguments. ~~bool (flag)~~ |
@@ -8,15 +8,30 @@ new: 2

 Vectors data is kept in the `Vectors.data` attribute, which should be an
 instance of `numpy.ndarray` (for CPU vectors) or `cupy.ndarray` (for GPU
-vectors). Multiple keys can be mapped to the same vector, and not all of the
-rows in the table need to be assigned – so `vectors.n_keys` may be greater or
-smaller than `vectors.shape[0]`.
+vectors).
+
+As of spaCy v3.2, `Vectors` supports two types of vector tables:
+
+- `default`: A standard vector table (as in spaCy v3.1 and earlier) where each
+  key is mapped to one row in the vector table. Multiple keys can be mapped to
+  the same vector, and not all of the rows in the table need to be assigned – so
+  `vectors.n_keys` may be greater or smaller than `vectors.shape[0]`.
+- `floret`: Only supports vectors trained with
+  [floret](https://github.com/explosion/floret), an extended version of
+  [fastText](https://fasttext.cc) that produces compact vector tables by
+  combining fastText's subword ngrams with Bloom embeddings. The compact tables
+  are similar to the [`HashEmbed`](https://thinc.ai/docs/api-layers#hashembed)
+  embeddings already used in many spaCy components. Each word is represented as
+  the sum of one or more rows as determined by the settings related to character
+  ngrams and the hash table.

 ## Vectors.\_\_init\_\_ {#init tag="method"}

-Create a new vector store. You can set the vector values and keys directly on
-initialization, or supply a `shape` keyword argument to create an empty table
-you can add vectors to later.
+Create a new vector store. With the default mode, you can set the vector values
+and keys directly on initialization, or supply a `shape` keyword argument to
+create an empty table you can add vectors to later. In floret mode, the complete
+vector data and settings must be provided on initialization and cannot be
+modified later.

 > #### Example
 >
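A hedged illustration of the practical difference between the two modes, using only APIs documented in this commit; the data values are toy zeros.

```python
# Behavioral contrast between the two table types described above.
from spacy.vectors import Vectors
from thinc.api import get_current_ops

ops = get_current_ops()

default_table = Vectors(data=ops.xp.zeros((3, 4), dtype="f"), keys=["a", "b", "c"])
assert default_table.n_keys == 3          # keys table is used

floret_table = Vectors(
    data=ops.xp.zeros((10, 4), dtype="f"),
    mode="floret", minn=2, maxn=3, hash_count=2,
)
assert floret_table.n_keys == -1          # no keys table in floret mode
assert "anything" in floret_table         # every key maps to ngram rows
vecs = floret_table.get_batch(["spaCy", "unseen-word"])  # shape (2, 4)
```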
@@ -31,12 +46,20 @@ you can add vectors to later.
 > ```

 | Name           | Description |
-| -------------- | ----------- |
+| ----------------------------------------- | ----------- |
 | _keyword-only_ | |
+| `strings`      | The string store. A new string store is created if one is not provided. Defaults to `None`. ~~Optional[StringStore]~~ |
 | `shape`        | Size of the table as `(n_entries, n_columns)`, the number of entries and number of columns. Not required if you're initializing the object with `data` and `keys`. ~~Tuple[int, int]~~ |
 | `data`         | The vector data. ~~numpy.ndarray[ndim=1, dtype=float32]~~ |
 | `keys`         | A sequence of keys aligned with the data. ~~Iterable[Union[str, int]]~~ |
 | `name`         | A name to identify the vectors table. ~~str~~ |
+| `mode` <Tag variant="new">3.2</Tag> | Vectors mode: `"default"` or [`"floret"`](https://github.com/explosion/floret) (default: `"default"`). ~~str~~ |
+| `minn` <Tag variant="new">3.2</Tag> | The floret char ngram minn (default: `0`). ~~int~~ |
+| `maxn` <Tag variant="new">3.2</Tag> | The floret char ngram maxn (default: `0`). ~~int~~ |
+| `hash_count` <Tag variant="new">3.2</Tag> | The floret hash count. Supported values: 1--4 (default: `1`). ~~int~~ |
+| `hash_seed` <Tag variant="new">3.2</Tag> | The floret hash seed (default: `0`). ~~int~~ |
+| `bow` <Tag variant="new">3.2</Tag> | The floret BOW string (default: `"<"`). ~~str~~ |
+| `eow` <Tag variant="new">3.2</Tag> | The floret EOW string (default: `">"`). ~~str~~ |

 ## Vectors.\_\_getitem\_\_ {#getitem tag="method"}

@@ -53,12 +76,12 @@ raised.

 | Name        | Description |
 | ----------- | ----------- |
-| `key`       | The key to get the vector for. ~~int~~ |
+| `key`       | The key to get the vector for. ~~Union[int, str]~~ |
 | **RETURNS** | The vector for the key. ~~numpy.ndarray[ndim=1, dtype=float32]~~ |

 ## Vectors.\_\_setitem\_\_ {#setitem tag="method"}

-Set a vector for the given key.
+Set a vector for the given key. Not supported for `floret` mode.

 > #### Example
 >
@@ -75,7 +98,8 @@ Set a vector for the given key.

 ## Vectors.\_\_iter\_\_ {#iter tag="method"}

-Iterate over the keys in the table.
+Iterate over the keys in the table. In `floret` mode, the keys table is not
+used.

 > #### Example
 >
@@ -105,7 +129,8 @@ Return the number of vectors in the table.

 ## Vectors.\_\_contains\_\_ {#contains tag="method"}

-Check whether a key has been mapped to a vector entry in the table.
+Check whether a key has been mapped to a vector entry in the table. In `floret`
+mode, returns `True` for all keys.

 > #### Example
 >
@@ -123,11 +148,8 @@ Check whether a key has been mapped to a vector entry in the table.
 ## Vectors.add {#add tag="method"}

 Add a key to the table, optionally setting a vector value as well. Keys can be
-mapped to an existing vector by setting `row`, or a new vector can be added.
-When adding string keys, keep in mind that the `Vectors` class itself has no
-[`StringStore`](/api/stringstore), so you have to store the hash-to-string
-mapping separately. If you need to manage the strings, you should use the
-`Vectors` via the [`Vocab`](/api/vocab) class, e.g. `vocab.vectors`.
+mapped to an existing vector by setting `row`, or a new vector can be added. Not
+supported for `floret` mode.

 > #### Example
 >
@@ -152,7 +174,8 @@ Resize the underlying vectors array. If `inplace=True`, the memory is
 reallocated. This may cause other references to the data to become invalid, so
 only use `inplace=True` if you're sure that's what you want. If the number of
 vectors is reduced, keys mapped to rows that have been deleted are removed.
-These removed items are returned as a list of `(key, row)` tuples.
+These removed items are returned as a list of `(key, row)` tuples. Not supported
+for `floret` mode.

 > #### Example
 >
@@ -168,7 +191,8 @@ These removed items are returned as a list of `(key, row)` tuples.

 ## Vectors.keys {#keys tag="method"}

-A sequence of the keys in the table.
+A sequence of the keys in the table. In `floret` mode, the keys table is not
+used.

 > #### Example
 >
@@ -185,7 +209,7 @@ A sequence of the keys in the table.

 Iterate over vectors that have been assigned to at least one key. Note that some
 vectors may be unassigned, so the number of vectors returned may be less than
-the length of the vectors table.
+the length of the vectors table. In `floret` mode, the keys table is not used.

 > #### Example
 >
@@ -200,7 +224,8 @@ the length of the vectors table.

 ## Vectors.items {#items tag="method"}

-Iterate over `(key, vector)` pairs, in order.
+Iterate over `(key, vector)` pairs, in order. In `floret` mode, the keys table
+is empty.

 > #### Example
 >
@@ -215,7 +240,7 @@ Iterate over `(key, vector)` pairs, in order.

 ## Vectors.find {#find tag="method"}

-Look up one or more keys by row, or vice versa.
+Look up one or more keys by row, or vice versa. Not supported for `floret` mode.

 > #### Example
 >
@@ -273,7 +298,8 @@ The vector size, i.e. `rows * dims`.

 Whether the vectors table is full and has no slots available for new keys.
 If a table is full, it can be resized using
-[`Vectors.resize`](/api/vectors#resize).
+[`Vectors.resize`](/api/vectors#resize). In `floret` mode, the table is always
+full and cannot be resized.

 > #### Example
 >
@@ -291,7 +317,7 @@ If a table is full, it can be resized using

 Get the number of keys in the table. Note that this is the number of _all_ keys,
 not just unique vectors. If several keys are mapped to the same vectors, they
-will be counted individually.
+will be counted individually. In `floret` mode, the keys table is not used.

 > #### Example
 >
@@ -311,7 +337,8 @@ For each of the given vectors, find the `n` most similar entries to it by
 cosine. Queries are by vector. Results are returned as a
 `(keys, best_rows, scores)` tuple. If `queries` is large, the calculations are
 performed in chunks to avoid consuming too much memory. You can set the
-`batch_size` to control the size/space trade-off during the calculations.
+`batch_size` to control the size/space trade-off during the calculations. Not
+supported for `floret` mode.

 > #### Example
 >
@@ -329,6 +356,21 @@ performed in chunks to avoid consuming too much memory. You can set the
 | `sort`      | Whether to sort the entries returned by score. Defaults to `True`. ~~bool~~ |
 | **RETURNS** | The most similar entries as a `(keys, best_rows, scores)` tuple. ~~Tuple[numpy.ndarray, numpy.ndarray, numpy.ndarray]~~ |

+## Vectors.get_batch {#get_batch tag="method" new="3.2"}
+
+Get the vectors for the provided keys efficiently as a batch.
+
+> #### Example
+>
+> ```python
+> words = ["cat", "dog"]
+> vectors = nlp.vocab.vectors.get_batch(words)
+> ```
+
+| Name   | Description |
+| ------ | ----------- |
+| `keys` | The keys. ~~Iterable[Union[int, str]]~~ |
+
 ## Vectors.to_disk {#to_disk tag="method"}

 Save the current state to a directory.