Support custom token/lexeme attribute for vectors (#12625)

* Support custom token/lexeme attribute for vectors

* Fix imports

* Back off to ORTH without Vectors.attr

* Fallback if vectors.attr doesn't exist

* Update docs
This commit is contained in:
Adriane Boyd 2023-06-28 09:43:14 +02:00 committed by GitHub
parent 337a360cc7
commit fb0da3e097
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 146 additions and 35 deletions

View File

@ -32,6 +32,7 @@ def init_vectors_cli(
name: Optional[str] = Opt(None, "--name", "-n", help="Optional name for the word vectors, e.g. en_core_web_lg.vectors"), name: Optional[str] = Opt(None, "--name", "-n", help="Optional name for the word vectors, e.g. en_core_web_lg.vectors"),
verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"), verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
jsonl_loc: Optional[Path] = Opt(None, "--lexemes-jsonl", "-j", help="Location of JSONL-formatted attributes file", hidden=True), jsonl_loc: Optional[Path] = Opt(None, "--lexemes-jsonl", "-j", help="Location of JSONL-formatted attributes file", hidden=True),
attr: str = Opt("ORTH", "--attr", "-a", help="Optional token attribute to use for vectors, e.g. LOWER or NORM"),
# fmt: on # fmt: on
): ):
"""Convert word vectors for use with spaCy. Will export an nlp object that """Convert word vectors for use with spaCy. Will export an nlp object that
@ -50,6 +51,7 @@ def init_vectors_cli(
prune=prune, prune=prune,
name=name, name=name,
mode=mode, mode=mode,
attr=attr,
) )
msg.good(f"Successfully converted {len(nlp.vocab.vectors)} vectors") msg.good(f"Successfully converted {len(nlp.vocab.vectors)} vectors")
nlp.to_disk(output_dir) nlp.to_disk(output_dir)

View File

@ -216,6 +216,9 @@ class Warnings(metaclass=ErrorsWithCodes):
W123 = ("Argument `enable` with value {enable} does not contain all values specified in the config option " W123 = ("Argument `enable` with value {enable} does not contain all values specified in the config option "
"`enabled` ({enabled}). Be aware that this might affect other components in your pipeline.") "`enabled` ({enabled}). Be aware that this might affect other components in your pipeline.")
W124 = ("{host}:{port} is already in use, using the nearest available port {serve_port} as an alternative.") W124 = ("{host}:{port} is already in use, using the nearest available port {serve_port} as an alternative.")
W125 = ("The StaticVectors key_attr is no longer used. To set a custom "
"key attribute for vectors, configure it through Vectors(attr=) or "
"'spacy init vectors --attr'")
class Errors(metaclass=ErrorsWithCodes): class Errors(metaclass=ErrorsWithCodes):

View File

@ -1,3 +1,4 @@
import warnings
from typing import Callable, List, Optional, Sequence, Tuple, cast from typing import Callable, List, Optional, Sequence, Tuple, cast
from thinc.api import Model, Ops, registry from thinc.api import Model, Ops, registry
@ -5,7 +6,8 @@ from thinc.initializers import glorot_uniform_init
from thinc.types import Floats1d, Floats2d, Ints1d, Ragged from thinc.types import Floats1d, Floats2d, Ints1d, Ragged
from thinc.util import partial from thinc.util import partial
from ..errors import Errors from ..attrs import ORTH
from ..errors import Errors, Warnings
from ..tokens import Doc from ..tokens import Doc
from ..vectors import Mode from ..vectors import Mode
from ..vocab import Vocab from ..vocab import Vocab
@ -24,6 +26,8 @@ def StaticVectors(
linear projection to control the dimensionality. If a dropout rate is linear projection to control the dimensionality. If a dropout rate is
specified, the dropout is applied per dimension over the whole batch. specified, the dropout is applied per dimension over the whole batch.
""" """
if key_attr != "ORTH":
warnings.warn(Warnings.W125, DeprecationWarning)
return Model( return Model(
"static_vectors", "static_vectors",
forward, forward,
@ -40,9 +44,9 @@ def forward(
token_count = sum(len(doc) for doc in docs) token_count = sum(len(doc) for doc in docs)
if not token_count: if not token_count:
return _handle_empty(model.ops, model.get_dim("nO")) return _handle_empty(model.ops, model.get_dim("nO"))
key_attr: int = model.attrs["key_attr"]
keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs])
vocab: Vocab = docs[0].vocab vocab: Vocab = docs[0].vocab
key_attr: int = getattr(vocab.vectors, "attr", ORTH)
keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs])
W = cast(Floats2d, model.ops.as_contig(model.get_param("W"))) W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
if vocab.vectors.mode == Mode.default: if vocab.vectors.mode == Mode.default:
V = model.ops.asarray(vocab.vectors.data) V = model.ops.asarray(vocab.vectors.data)

View File

@ -402,6 +402,7 @@ def test_vectors_serialize():
row_r = v_r.add("D", vector=OPS.asarray([10, 20, 30, 40], dtype="f")) row_r = v_r.add("D", vector=OPS.asarray([10, 20, 30, 40], dtype="f"))
assert row == row_r assert row == row_r
assert_equal(OPS.to_numpy(v.data), OPS.to_numpy(v_r.data)) assert_equal(OPS.to_numpy(v.data), OPS.to_numpy(v_r.data))
assert v.attr == v_r.attr
def test_vector_is_oov(): def test_vector_is_oov():
@ -646,3 +647,32 @@ def test_equality():
vectors1.resize((5, 9)) vectors1.resize((5, 9))
vectors2.resize((5, 9)) vectors2.resize((5, 9))
assert vectors1 == vectors2 assert vectors1 == vectors2
def test_vectors_attr():
    """Vector lookups follow the key attribute configured on ``Vectors``.

    Covers three situations: the default ORTH keys (exact surface form
    only), a custom LOWER key attribute (both "A" and "a" resolve to the
    same entry), and adding a new entry through ``Vocab.set_vector`` while
    the LOWER attribute is active.
    """
    rows = numpy.asarray([[0, 0, 0], [1, 2, 3], [9, 8, 7]], dtype="f")

    # Default key attribute (ORTH): only the exact form "A" has a vector.
    nlp = English()
    nlp.vocab.vectors = Vectors(data=rows, keys=["A", "B", "C"])
    assert nlp.vocab.strings["A"] in nlp.vocab.vectors.key2row
    assert nlp.vocab.strings["a"] not in nlp.vocab.vectors.key2row
    orth_expected = (("A", True), ("a", False))
    for text, flag in orth_expected:
        assert nlp.vocab[text].has_vector is flag
    for text, flag in orth_expected:
        assert nlp(text)[0].has_vector is flag

    # Custom key attribute (LOWER): the table is keyed on lowercase forms,
    # yet both casings are expected to report a vector.
    nlp = English()
    nlp.vocab.vectors = Vectors(data=rows, keys=["a", "b", "c"], attr="LOWER")
    assert nlp.vocab.strings["A"] not in nlp.vocab.vectors.key2row
    assert nlp.vocab.strings["a"] in nlp.vocab.vectors.key2row
    for text in ("A", "a"):
        assert nlp.vocab[text].has_vector is True
    for text in ("A", "a"):
        assert nlp(text)[0].has_vector is True

    # Adding a vector for "D" must make both "D" and "d" resolvable.
    for text in ("D", "d"):
        assert nlp.vocab[text].has_vector is False
    nlp.vocab.set_vector("D", numpy.asarray([4, 5, 6]))
    for text in ("D", "d"):
        assert nlp.vocab[text].has_vector is True

View File

@ -35,6 +35,7 @@ from ..attrs cimport (
LENGTH, LENGTH,
MORPH, MORPH,
NORM, NORM,
ORTH,
POS, POS,
SENT_START, SENT_START,
SPACY, SPACY,
@ -613,13 +614,26 @@ cdef class Doc:
""" """
if "similarity" in self.user_hooks: if "similarity" in self.user_hooks:
return self.user_hooks["similarity"](self, other) return self.user_hooks["similarity"](self, other)
if isinstance(other, (Lexeme, Token)) and self.length == 1: attr = getattr(self.vocab.vectors, "attr", ORTH)
if self.c[0].lex.orth == other.orth: cdef Token this_token
cdef Token other_token
cdef Lexeme other_lex
if len(self) == 1 and isinstance(other, Token):
this_token = self[0]
other_token = other
if Token.get_struct_attr(this_token.c, attr) == Token.get_struct_attr(other_token.c, attr):
return 1.0 return 1.0
elif isinstance(other, (Span, Doc)) and len(self) == len(other): elif len(self) == 1 and isinstance(other, Lexeme):
this_token = self[0]
other_lex = other
if Token.get_struct_attr(this_token.c, attr) == Lexeme.get_struct_attr(other_lex.c, attr):
return 1.0
elif isinstance(other, (Doc, Span)) and len(self) == len(other):
similar = True similar = True
for i in range(self.length): for i in range(len(self)):
if self[i].orth != other[i].orth: this_token = self[i]
other_token = other[i]
if Token.get_struct_attr(this_token.c, attr) != Token.get_struct_attr(other_token.c, attr):
similar = False similar = False
break break
if similar: if similar:

View File

@ -8,13 +8,14 @@ import numpy
from thinc.api import get_array_module from thinc.api import get_array_module
from ..attrs cimport * from ..attrs cimport *
from ..attrs cimport attr_id_t from ..attrs cimport ORTH, attr_id_t
from ..lexeme cimport Lexeme from ..lexeme cimport Lexeme
from ..parts_of_speech cimport univ_pos_t from ..parts_of_speech cimport univ_pos_t
from ..structs cimport LexemeC, TokenC from ..structs cimport LexemeC, TokenC
from ..symbols cimport dep from ..symbols cimport dep
from ..typedefs cimport attr_t, flags_t, hash_t from ..typedefs cimport attr_t, flags_t, hash_t
from .doc cimport _get_lca_matrix, get_token_attr, token_by_end, token_by_start from .doc cimport _get_lca_matrix, get_token_attr, token_by_end, token_by_start
from .token cimport Token
from ..errors import Errors, Warnings from ..errors import Errors, Warnings
from ..util import normalize_slice from ..util import normalize_slice
@ -341,13 +342,26 @@ cdef class Span:
""" """
if "similarity" in self.doc.user_span_hooks: if "similarity" in self.doc.user_span_hooks:
return self.doc.user_span_hooks["similarity"](self, other) return self.doc.user_span_hooks["similarity"](self, other)
if len(self) == 1 and hasattr(other, "orth"): attr = getattr(self.doc.vocab.vectors, "attr", ORTH)
if self[0].orth == other.orth: cdef Token this_token
cdef Token other_token
cdef Lexeme other_lex
if len(self) == 1 and isinstance(other, Token):
this_token = self[0]
other_token = other
if Token.get_struct_attr(this_token.c, attr) == Token.get_struct_attr(other_token.c, attr):
return 1.0
elif len(self) == 1 and isinstance(other, Lexeme):
this_token = self[0]
other_lex = other
if Token.get_struct_attr(this_token.c, attr) == Lexeme.get_struct_attr(other_lex.c, attr):
return 1.0 return 1.0
elif isinstance(other, (Doc, Span)) and len(self) == len(other): elif isinstance(other, (Doc, Span)) and len(self) == len(other):
similar = True similar = True
for i in range(len(self)): for i in range(len(self)):
if self[i].orth != getattr(other[i], "orth", None): this_token = self[i]
other_token = other[i]
if Token.get_struct_attr(this_token.c, attr) != Token.get_struct_attr(other_token.c, attr):
similar = False similar = False
break break
if similar: if similar:

View File

@ -28,6 +28,7 @@ from ..attrs cimport (
LIKE_EMAIL, LIKE_EMAIL,
LIKE_NUM, LIKE_NUM,
LIKE_URL, LIKE_URL,
ORTH,
) )
from ..lexeme cimport Lexeme from ..lexeme cimport Lexeme
from ..symbols cimport conj from ..symbols cimport conj
@ -214,11 +215,17 @@ cdef class Token:
""" """
if "similarity" in self.doc.user_token_hooks: if "similarity" in self.doc.user_token_hooks:
return self.doc.user_token_hooks["similarity"](self, other) return self.doc.user_token_hooks["similarity"](self, other)
if hasattr(other, "__len__") and len(other) == 1 and hasattr(other, "__getitem__"): attr = getattr(self.doc.vocab.vectors, "attr", ORTH)
if self.c.lex.orth == getattr(other[0], "orth", None): cdef Token this_token = self
cdef Token other_token
cdef Lexeme other_lex
if isinstance(other, Token):
other_token = other
if Token.get_struct_attr(this_token.c, attr) == Token.get_struct_attr(other_token.c, attr):
return 1.0 return 1.0
elif hasattr(other, "orth"): elif isinstance(other, Lexeme):
if self.c.lex.orth == other.orth: other_lex = other
if Token.get_struct_attr(this_token.c, attr) == Lexeme.get_struct_attr(other_lex.c, attr):
return 1.0 return 1.0
if self.vocab.vectors.n_keys == 0: if self.vocab.vectors.n_keys == 0:
warnings.warn(Warnings.W007.format(obj="Token")) warnings.warn(Warnings.W007.format(obj="Token"))
@ -415,7 +422,7 @@ cdef class Token:
return self.doc.user_token_hooks["has_vector"](self) return self.doc.user_token_hooks["has_vector"](self)
if self.vocab.vectors.size == 0 and self.doc.tensor.size != 0: if self.vocab.vectors.size == 0 and self.doc.tensor.size != 0:
return True return True
return self.vocab.has_vector(self.c.lex.orth) return self.vocab.has_vector(Token.get_struct_attr(self.c, self.vocab.vectors.attr))
@property @property
def vector(self): def vector(self):
@ -431,7 +438,7 @@ cdef class Token:
if self.vocab.vectors.size == 0 and self.doc.tensor.size != 0: if self.vocab.vectors.size == 0 and self.doc.tensor.size != 0:
return self.doc.tensor[self.i] return self.doc.tensor[self.i]
else: else:
return self.vocab.get_vector(self.c.lex.orth) return self.vocab.get_vector(Token.get_struct_attr(self.c, self.vocab.vectors.attr))
@property @property
def vector_norm(self): def vector_norm(self):

View File

@ -216,9 +216,14 @@ def convert_vectors(
prune: int, prune: int,
name: Optional[str] = None, name: Optional[str] = None,
mode: str = VectorsMode.default, mode: str = VectorsMode.default,
attr: str = "ORTH",
) -> None: ) -> None:
vectors_loc = ensure_path(vectors_loc) vectors_loc = ensure_path(vectors_loc)
if vectors_loc and vectors_loc.parts[-1].endswith(".npz"): if vectors_loc and vectors_loc.parts[-1].endswith(".npz"):
if attr != "ORTH":
raise ValueError(
"ORTH is the only attribute supported for vectors in .npz format."
)
nlp.vocab.vectors = Vectors( nlp.vocab.vectors = Vectors(
strings=nlp.vocab.strings, data=numpy.load(vectors_loc.open("rb")) strings=nlp.vocab.strings, data=numpy.load(vectors_loc.open("rb"))
) )
@ -246,11 +251,15 @@ def convert_vectors(
nlp.vocab.vectors = Vectors( nlp.vocab.vectors = Vectors(
strings=nlp.vocab.strings, strings=nlp.vocab.strings,
data=vectors_data, data=vectors_data,
attr=attr,
**floret_settings, **floret_settings,
) )
else: else:
nlp.vocab.vectors = Vectors( nlp.vocab.vectors = Vectors(
strings=nlp.vocab.strings, data=vectors_data, keys=vector_keys strings=nlp.vocab.strings,
data=vectors_data,
keys=vector_keys,
attr=attr,
) )
nlp.vocab.deduplicate_vectors() nlp.vocab.deduplicate_vectors()
if name is None: if name is None:

View File

@ -15,9 +15,11 @@ from thinc.api import Ops, get_array_module, get_current_ops
from thinc.backends import get_array_ops from thinc.backends import get_array_ops
from thinc.types import Floats2d from thinc.types import Floats2d
from .attrs cimport ORTH, attr_id_t
from .strings cimport StringStore from .strings cimport StringStore
from . import util from . import util
from .attrs import IDS
from .errors import Errors, Warnings from .errors import Errors, Warnings
from .strings import get_string_id from .strings import get_string_id
@ -64,8 +66,9 @@ cdef class Vectors:
cdef readonly uint32_t hash_seed cdef readonly uint32_t hash_seed
cdef readonly unicode bow cdef readonly unicode bow
cdef readonly unicode eow cdef readonly unicode eow
cdef readonly attr_id_t attr
def __init__(self, *, strings=None, shape=None, data=None, keys=None, name=None, mode=Mode.default, minn=0, maxn=0, hash_count=1, hash_seed=0, bow="<", eow=">"): def __init__(self, *, strings=None, shape=None, data=None, keys=None, name=None, mode=Mode.default, minn=0, maxn=0, hash_count=1, hash_seed=0, bow="<", eow=">", attr="ORTH"):
"""Create a new vector store. """Create a new vector store.
strings (StringStore): The string store. strings (StringStore): The string store.
@ -80,6 +83,8 @@ cdef class Vectors:
hash_seed (int): The floret hash seed (default: 0). hash_seed (int): The floret hash seed (default: 0).
bow (str): The floret BOW string (default: "<"). bow (str): The floret BOW string (default: "<").
eow (str): The floret EOW string (default: ">"). eow (str): The floret EOW string (default: ">").
attr (Union[int, str]): The token attribute for the vector keys
(default: "ORTH").
DOCS: https://spacy.io/api/vectors#init DOCS: https://spacy.io/api/vectors#init
""" """
@ -103,6 +108,14 @@ cdef class Vectors:
self.hash_seed = hash_seed self.hash_seed = hash_seed
self.bow = bow self.bow = bow
self.eow = eow self.eow = eow
if isinstance(attr, (int, long)):
self.attr = attr
else:
attr = attr.upper()
if attr == "TEXT":
attr = "ORTH"
self.attr = IDS.get(attr, ORTH)
if self.mode == Mode.default: if self.mode == Mode.default:
if data is None: if data is None:
if shape is None: if shape is None:
@ -546,6 +559,7 @@ cdef class Vectors:
"hash_seed": self.hash_seed, "hash_seed": self.hash_seed,
"bow": self.bow, "bow": self.bow,
"eow": self.eow, "eow": self.eow,
"attr": self.attr,
} }
def _set_cfg(self, cfg): def _set_cfg(self, cfg):
@ -556,6 +570,7 @@ cdef class Vectors:
self.hash_seed = cfg.get("hash_seed", 0) self.hash_seed = cfg.get("hash_seed", 0)
self.bow = cfg.get("bow", "<") self.bow = cfg.get("bow", "<")
self.eow = cfg.get("eow", ">") self.eow = cfg.get("eow", ">")
self.attr = cfg.get("attr", ORTH)
def to_disk(self, path, *, exclude=tuple()): def to_disk(self, path, *, exclude=tuple()):
"""Save the current state to a directory. """Save the current state to a directory.

View File

@ -365,8 +365,13 @@ cdef class Vocab:
self[orth] self[orth]
# Make prob negative so it sorts by rank ascending # Make prob negative so it sorts by rank ascending
# (key2row contains the rank) # (key2row contains the rank)
priority = [(-lex.prob, self.vectors.key2row[lex.orth], lex.orth) priority = []
for lex in self if lex.orth in self.vectors.key2row] cdef Lexeme lex
cdef attr_t value
for lex in self:
value = Lexeme.get_struct_attr(lex.c, self.vectors.attr)
if value in self.vectors.key2row:
priority.append((-lex.prob, self.vectors.key2row[value], value))
priority.sort() priority.sort()
indices = xp.asarray([i for (prob, i, key) in priority], dtype="uint64") indices = xp.asarray([i for (prob, i, key) in priority], dtype="uint64")
keys = xp.asarray([key for (prob, i, key) in priority], dtype="uint64") keys = xp.asarray([key for (prob, i, key) in priority], dtype="uint64")
@ -399,8 +404,10 @@ cdef class Vocab:
""" """
if isinstance(orth, str): if isinstance(orth, str):
orth = self.strings.add(orth) orth = self.strings.add(orth)
if self.has_vector(orth): cdef Lexeme lex = self[orth]
return self.vectors[orth] key = Lexeme.get_struct_attr(lex.c, self.vectors.attr)
if self.has_vector(key):
return self.vectors[key]
xp = get_array_module(self.vectors.data) xp = get_array_module(self.vectors.data)
vectors = xp.zeros((self.vectors_length,), dtype="f") vectors = xp.zeros((self.vectors_length,), dtype="f")
return vectors return vectors
@ -416,15 +423,16 @@ cdef class Vocab:
""" """
if isinstance(orth, str): if isinstance(orth, str):
orth = self.strings.add(orth) orth = self.strings.add(orth)
if self.vectors.is_full and orth not in self.vectors: cdef Lexeme lex = self[orth]
key = Lexeme.get_struct_attr(lex.c, self.vectors.attr)
if self.vectors.is_full and key not in self.vectors:
new_rows = max(100, int(self.vectors.shape[0]*1.3)) new_rows = max(100, int(self.vectors.shape[0]*1.3))
if self.vectors.shape[1] == 0: if self.vectors.shape[1] == 0:
width = vector.size width = vector.size
else: else:
width = self.vectors.shape[1] width = self.vectors.shape[1]
self.vectors.resize((new_rows, width)) self.vectors.resize((new_rows, width))
lex = self[orth] # Add word to vocab if necessary row = self.vectors.add(key, vector=vector)
row = self.vectors.add(orth, vector=vector)
if row >= 0: if row >= 0:
lex.rank = row lex.rank = row
@ -439,7 +447,9 @@ cdef class Vocab:
""" """
if isinstance(orth, str): if isinstance(orth, str):
orth = self.strings.add(orth) orth = self.strings.add(orth)
return orth in self.vectors cdef Lexeme lex = self[orth]
key = Lexeme.get_struct_attr(lex.c, self.vectors.attr)
return key in self.vectors
property lookups: property lookups:
def __get__(self): def __get__(self):

View File

@ -303,7 +303,7 @@ mapped to a zero vector. See the documentation on
| `nM` | The width of the static vectors. ~~Optional[int]~~ | | `nM` | The width of the static vectors. ~~Optional[int]~~ |
| `dropout` | Optional dropout rate. If set, it's applied per dimension over the whole batch. Defaults to `None`. ~~Optional[float]~~ | | `dropout` | Optional dropout rate. If set, it's applied per dimension over the whole batch. Defaults to `None`. ~~Optional[float]~~ |
| `init_W` | The [initialization function](https://thinc.ai/docs/api-initializers). Defaults to [`glorot_uniform_init`](https://thinc.ai/docs/api-initializers#glorot_uniform_init). ~~Callable[[Ops, Tuple[int, ...]]], FloatsXd]~~ | | `init_W` | The [initialization function](https://thinc.ai/docs/api-initializers). Defaults to [`glorot_uniform_init`](https://thinc.ai/docs/api-initializers#glorot_uniform_init). ~~Callable[[Ops, Tuple[int, ...]]], FloatsXd]~~ |
| `key_attr` | Defaults to `"ORTH"`. ~~str~~ | | `key_attr` | This setting is ignored in spaCy v3.6+. To set a custom key attribute for vectors, configure it through [`Vectors`](/api/vectors) or [`spacy init vectors`](/api/cli#init-vectors). Defaults to `"ORTH"`. ~~str~~ |
| **CREATES** | The model using the architecture. ~~Model[List[Doc], Ragged]~~ | | **CREATES** | The model using the architecture. ~~Model[List[Doc], Ragged]~~ |
### spacy.FeatureExtractor.v1 {id="FeatureExtractor"} ### spacy.FeatureExtractor.v1 {id="FeatureExtractor"}

View File

@ -211,7 +211,8 @@ $ python -m spacy init vectors [lang] [vectors_loc] [output_dir] [--prune] [--tr
| `output_dir` | Pipeline output directory. Will be created if it doesn't exist. ~~Path (positional)~~ | | `output_dir` | Pipeline output directory. Will be created if it doesn't exist. ~~Path (positional)~~ |
| `--truncate`, `-t` | Number of vectors to truncate to when reading in vectors file. Defaults to `0` for no truncation. ~~int (option)~~ | | `--truncate`, `-t` | Number of vectors to truncate to when reading in vectors file. Defaults to `0` for no truncation. ~~int (option)~~ |
| `--prune`, `-p` | Number of vectors to prune the vocabulary to. Defaults to `-1` for no pruning. ~~int (option)~~ | | `--prune`, `-p` | Number of vectors to prune the vocabulary to. Defaults to `-1` for no pruning. ~~int (option)~~ |
| `--mode`, `-m` | Vectors mode: `default` or [`floret`](https://github.com/explosion/floret). Defaults to `default`. ~~Optional[str] \(option)~~ | | `--mode`, `-m` | Vectors mode: `default` or [`floret`](https://github.com/explosion/floret). Defaults to `default`. ~~str \(option)~~ |
| `--attr`, `-a` | Token attribute to use for vectors, e.g. `LOWER` or `NORM`. Defaults to `ORTH`. ~~str \(option)~~ |
| `--name`, `-n` | Name to assign to the word vectors in the `meta.json`, e.g. `en_core_web_md.vectors`. ~~Optional[str] \(option)~~ | | `--name`, `-n` | Name to assign to the word vectors in the `meta.json`, e.g. `en_core_web_md.vectors`. ~~Optional[str] \(option)~~ |
| `--verbose`, `-V` | Print additional information and explanations. ~~bool (flag)~~ | | `--verbose`, `-V` | Print additional information and explanations. ~~bool (flag)~~ |
| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ | | `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ |

View File

@ -60,6 +60,7 @@ modified later.
| `hash_seed` <Tag variant="new">3.2</Tag> | The floret hash seed (default: `0`). ~~int~~ | | `hash_seed` <Tag variant="new">3.2</Tag> | The floret hash seed (default: `0`). ~~int~~ |
| `bow` <Tag variant="new">3.2</Tag> | The floret BOW string (default: `"<"`). ~~str~~ | | `bow` <Tag variant="new">3.2</Tag> | The floret BOW string (default: `"<"`). ~~str~~ |
| `eow` <Tag variant="new">3.2</Tag> | The floret EOW string (default: `">"`). ~~str~~ | | `eow` <Tag variant="new">3.2</Tag> | The floret EOW string (default: `">"`). ~~str~~ |
| `attr` <Tag variant="new">3.6</Tag> | The token attribute for the vector keys (default: `"ORTH"`). ~~Union[int, str]~~ |
## Vectors.\_\_getitem\_\_ {id="getitem",tag="method"} ## Vectors.\_\_getitem\_\_ {id="getitem",tag="method"}
@ -454,7 +455,8 @@ Load state from a binary string.
## Attributes {id="attributes"} ## Attributes {id="attributes"}
| Name | Description | | Name | Description |
| --------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------- | | ----------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `data` | Stored vectors data. `numpy` is used for CPU vectors, `cupy` for GPU vectors. ~~Union[numpy.ndarray[ndim=1, dtype=float32], cupy.ndarray[ndim=1, dtype=float32]]~~ | | `data` | Stored vectors data. `numpy` is used for CPU vectors, `cupy` for GPU vectors. ~~Union[numpy.ndarray[ndim=1, dtype=float32], cupy.ndarray[ndim=1, dtype=float32]]~~ |
| `key2row` | Dictionary mapping word hashes to rows in the `Vectors.data` table. ~~Dict[int, int]~~ | | `key2row` | Dictionary mapping word hashes to rows in the `Vectors.data` table. ~~Dict[int, int]~~ |
| `keys` | Array keeping the keys in order, such that `keys[vectors.key2row[key]] == key`. ~~Union[numpy.ndarray[ndim=1, dtype=float32], cupy.ndarray[ndim=1, dtype=float32]]~~ | | `keys` | Array keeping the keys in order, such that `keys[vectors.key2row[key]] == key`. ~~Union[numpy.ndarray[ndim=1, dtype=float32], cupy.ndarray[ndim=1, dtype=float32]]~~ |
| `attr` <Tag variant="new">3.6</Tag> | The token attribute for the vector keys. ~~int~~ |