diff --git a/spacy/cli/init_pipeline.py b/spacy/cli/init_pipeline.py index e0d048c69..13202cb60 100644 --- a/spacy/cli/init_pipeline.py +++ b/spacy/cli/init_pipeline.py @@ -32,6 +32,7 @@ def init_vectors_cli( name: Optional[str] = Opt(None, "--name", "-n", help="Optional name for the word vectors, e.g. en_core_web_lg.vectors"), verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"), jsonl_loc: Optional[Path] = Opt(None, "--lexemes-jsonl", "-j", help="Location of JSONL-formatted attributes file", hidden=True), + attr: str = Opt("ORTH", "--attr", "-a", help="Optional token attribute to use for vectors, e.g. LOWER or NORM"), # fmt: on ): """Convert word vectors for use with spaCy. Will export an nlp object that @@ -50,6 +51,7 @@ def init_vectors_cli( prune=prune, name=name, mode=mode, + attr=attr, ) msg.good(f"Successfully converted {len(nlp.vocab.vectors)} vectors") nlp.to_disk(output_dir) diff --git a/spacy/errors.py b/spacy/errors.py index a95f0c8a2..db1a886aa 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -216,6 +216,9 @@ class Warnings(metaclass=ErrorsWithCodes): W123 = ("Argument `enable` with value {enable} does not contain all values specified in the config option " "`enabled` ({enabled}). Be aware that this might affect other components in your pipeline.") W124 = ("{host}:{port} is already in use, using the nearest available port {serve_port} as an alternative.") + W125 = ("The StaticVectors key_attr is no longer used. To set a custom " + "key attribute for vectors, configure it through Vectors(attr=) or " + "'spacy init vectors --attr'") class Errors(metaclass=ErrorsWithCodes): diff --git a/spacy/ml/staticvectors.py b/spacy/ml/staticvectors.py index 6fcb13ad0..b75240c5d 100644 --- a/spacy/ml/staticvectors.py +++ b/spacy/ml/staticvectors.py @@ -1,3 +1,4 @@ +import warnings from typing import Callable, List, Optional, Sequence, Tuple, cast from thinc.api import Model, Ops, registry @@ -5,7 +6,8 @@ from thinc.initializers import glorot_uniform_init from thinc.types import Floats1d, Floats2d, Ints1d, Ragged from thinc.util import partial -from ..errors import Errors +from ..attrs import ORTH +from ..errors import Errors, Warnings from ..tokens import Doc from ..vectors import Mode from ..vocab import Vocab @@ -24,6 +26,8 @@ def StaticVectors( linear projection to control the dimensionality. If a dropout rate is specified, the dropout is applied per dimension over the whole batch. """ + if key_attr != "ORTH": + warnings.warn(Warnings.W125, DeprecationWarning) return Model( "static_vectors", forward, @@ -40,9 +44,9 @@ def forward( token_count = sum(len(doc) for doc in docs) if not token_count: return _handle_empty(model.ops, model.get_dim("nO")) - key_attr: int = model.attrs["key_attr"] - keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs]) vocab: Vocab = docs[0].vocab + key_attr: int = getattr(vocab.vectors, "attr", ORTH) + keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs]) W = cast(Floats2d, model.ops.as_contig(model.get_param("W"))) if vocab.vectors.mode == Mode.default: V = model.ops.asarray(vocab.vectors.data) diff --git a/spacy/tests/vocab_vectors/test_vectors.py b/spacy/tests/vocab_vectors/test_vectors.py index 70835816d..717291314 100644 --- a/spacy/tests/vocab_vectors/test_vectors.py +++ b/spacy/tests/vocab_vectors/test_vectors.py @@ -402,6 +402,7 @@ def test_vectors_serialize(): row_r = v_r.add("D", vector=OPS.asarray([10, 20, 30, 40], dtype="f")) assert row == row_r assert_equal(OPS.to_numpy(v.data), OPS.to_numpy(v_r.data)) + assert v.attr == v_r.attr def test_vector_is_oov(): @@ -646,3 +647,32 @@ def test_equality(): vectors1.resize((5, 9)) vectors2.resize((5, 9)) assert vectors1 == vectors2 + + +def test_vectors_attr(): + data = numpy.asarray([[0, 0, 0], [1, 2, 3], [9, 8, 7]], dtype="f") + # default ORTH + nlp = English() + nlp.vocab.vectors = Vectors(data=data, keys=["A", "B", "C"]) + assert nlp.vocab.strings["A"] in nlp.vocab.vectors.key2row + assert nlp.vocab.strings["a"] not in nlp.vocab.vectors.key2row + assert nlp.vocab["A"].has_vector is True + assert nlp.vocab["a"].has_vector is False + assert nlp("A")[0].has_vector is True + assert nlp("a")[0].has_vector is False + + # custom LOWER + nlp = English() + nlp.vocab.vectors = Vectors(data=data, keys=["a", "b", "c"], attr="LOWER") + assert nlp.vocab.strings["A"] not in nlp.vocab.vectors.key2row + assert nlp.vocab.strings["a"] in nlp.vocab.vectors.key2row + assert nlp.vocab["A"].has_vector is True + assert nlp.vocab["a"].has_vector is True + assert nlp("A")[0].has_vector is True + assert nlp("a")[0].has_vector is True + # add a new vectors entry + assert nlp.vocab["D"].has_vector is False + assert nlp.vocab["d"].has_vector is False + nlp.vocab.set_vector("D", numpy.asarray([4, 5, 6])) + assert nlp.vocab["D"].has_vector is True + assert nlp.vocab["d"].has_vector is True diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index 206253949..146b276e2 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -35,6 +35,7 @@ from ..attrs cimport ( LENGTH, MORPH, NORM, + ORTH, POS, SENT_START, SPACY, @@ -613,13 +614,26 @@ cdef class Doc: """ if "similarity" in self.user_hooks: return self.user_hooks["similarity"](self, other) - if isinstance(other, (Lexeme, Token)) and self.length == 1: - if self.c[0].lex.orth == other.orth: + attr = getattr(self.vocab.vectors, "attr", ORTH) + cdef Token this_token + cdef Token other_token + cdef Lexeme other_lex + if len(self) == 1 and isinstance(other, Token): + this_token = self[0] + other_token = other + if Token.get_struct_attr(this_token.c, attr) == Token.get_struct_attr(other_token.c, attr): return 1.0 - elif isinstance(other, (Span, Doc)) and len(self) == len(other): + elif len(self) == 1 and isinstance(other, Lexeme): + this_token = self[0] + other_lex = other + if Token.get_struct_attr(this_token.c, attr) == Lexeme.get_struct_attr(other_lex.c, attr): + return 1.0 + elif isinstance(other, (Doc, Span)) and len(self) == len(other): similar = True - for i in range(self.length): - if self[i].orth != other[i].orth: + for i in range(len(self)): + this_token = self[i] + other_token = other[i] + if Token.get_struct_attr(this_token.c, attr) != Token.get_struct_attr(other_token.c, attr): similar = False break if similar: diff --git a/spacy/tokens/span.pyx b/spacy/tokens/span.pyx index 73192b760..59ee21687 100644 --- a/spacy/tokens/span.pyx +++ b/spacy/tokens/span.pyx @@ -8,13 +8,14 @@ import numpy from thinc.api import get_array_module from ..attrs cimport * -from ..attrs cimport attr_id_t +from ..attrs cimport ORTH, attr_id_t from ..lexeme cimport Lexeme from ..parts_of_speech cimport univ_pos_t from ..structs cimport LexemeC, TokenC from ..symbols cimport dep from ..typedefs cimport attr_t, flags_t, hash_t from .doc cimport _get_lca_matrix, get_token_attr, token_by_end, token_by_start +from .token cimport Token from ..errors import Errors, Warnings from ..util import normalize_slice @@ -341,13 +342,26 @@ cdef class Span: """ if "similarity" in self.doc.user_span_hooks: return self.doc.user_span_hooks["similarity"](self, other) - if len(self) == 1 and hasattr(other, "orth"): - if self[0].orth == other.orth: + attr = getattr(self.doc.vocab.vectors, "attr", ORTH) + cdef Token this_token + cdef Token other_token + cdef Lexeme other_lex + if len(self) == 1 and isinstance(other, Token): + this_token = self[0] + other_token = other + if Token.get_struct_attr(this_token.c, attr) == Token.get_struct_attr(other_token.c, attr): + return 1.0 + elif len(self) == 1 and isinstance(other, Lexeme): + this_token = self[0] + other_lex = other + if Token.get_struct_attr(this_token.c, attr) == Lexeme.get_struct_attr(other_lex.c, attr): return 1.0 elif isinstance(other, (Doc, Span)) and len(self) == len(other): similar = True for i in range(len(self)): - if self[i].orth != getattr(other[i], "orth", None): + this_token = self[i] + other_token = other[i] + if Token.get_struct_attr(this_token.c, attr) != Token.get_struct_attr(other_token.c, attr): similar = False break if similar: diff --git a/spacy/tokens/token.pyx b/spacy/tokens/token.pyx index 8c384f417..6018c3112 100644 --- a/spacy/tokens/token.pyx +++ b/spacy/tokens/token.pyx @@ -28,6 +28,7 @@ from ..attrs cimport ( LIKE_EMAIL, LIKE_NUM, LIKE_URL, + ORTH, ) from ..lexeme cimport Lexeme from ..symbols cimport conj @@ -214,11 +215,17 @@ cdef class Token: """ if "similarity" in self.doc.user_token_hooks: return self.doc.user_token_hooks["similarity"](self, other) - if hasattr(other, "__len__") and len(other) == 1 and hasattr(other, "__getitem__"): - if self.c.lex.orth == getattr(other[0], "orth", None): + attr = getattr(self.doc.vocab.vectors, "attr", ORTH) + cdef Token this_token = self + cdef Token other_token + cdef Lexeme other_lex + if isinstance(other, Token): + other_token = other + if Token.get_struct_attr(this_token.c, attr) == Token.get_struct_attr(other_token.c, attr): return 1.0 - elif hasattr(other, "orth"): - if self.c.lex.orth == other.orth: + elif isinstance(other, Lexeme): + other_lex = other + if Token.get_struct_attr(this_token.c, attr) == Lexeme.get_struct_attr(other_lex.c, attr): return 1.0 if self.vocab.vectors.n_keys == 0: warnings.warn(Warnings.W007.format(obj="Token")) @@ -415,7 +422,7 @@ cdef class Token: return self.doc.user_token_hooks["has_vector"](self) if self.vocab.vectors.size == 0 and self.doc.tensor.size != 0: return True - return self.vocab.has_vector(self.c.lex.orth) + return self.vocab.has_vector(Token.get_struct_attr(self.c, self.vocab.vectors.attr)) @property def vector(self): @@ -431,7 +438,7 @@ cdef class Token: if self.vocab.vectors.size == 0 and self.doc.tensor.size != 0: return self.doc.tensor[self.i] else: - return self.vocab.get_vector(self.c.lex.orth) + return self.vocab.get_vector(Token.get_struct_attr(self.c, self.vocab.vectors.attr)) @property def vector_norm(self): diff --git a/spacy/training/initialize.py b/spacy/training/initialize.py index 3a46b6632..82d4ebf24 100644 --- a/spacy/training/initialize.py +++ b/spacy/training/initialize.py @@ -216,9 +216,14 @@ def convert_vectors( prune: int, name: Optional[str] = None, mode: str = VectorsMode.default, + attr: str = "ORTH", ) -> None: vectors_loc = ensure_path(vectors_loc) if vectors_loc and vectors_loc.parts[-1].endswith(".npz"): + if attr != "ORTH": + raise ValueError( + "ORTH is the only attribute supported for vectors in .npz format." + ) nlp.vocab.vectors = Vectors( strings=nlp.vocab.strings, data=numpy.load(vectors_loc.open("rb")) ) @@ -246,11 +251,15 @@ def convert_vectors( nlp.vocab.vectors = Vectors( strings=nlp.vocab.strings, data=vectors_data, + attr=attr, **floret_settings, ) else: nlp.vocab.vectors = Vectors( - strings=nlp.vocab.strings, data=vectors_data, keys=vector_keys + strings=nlp.vocab.strings, + data=vectors_data, + keys=vector_keys, + attr=attr, ) nlp.vocab.deduplicate_vectors() if name is None: diff --git a/spacy/vectors.pyx b/spacy/vectors.pyx index bc654252a..bf79481b8 100644 --- a/spacy/vectors.pyx +++ b/spacy/vectors.pyx @@ -15,9 +15,11 @@ from thinc.api import Ops, get_array_module, get_current_ops from thinc.backends import get_array_ops from thinc.types import Floats2d +from .attrs cimport ORTH, attr_id_t from .strings cimport StringStore from . import util +from .attrs import IDS from .errors import Errors, Warnings from .strings import get_string_id @@ -64,8 +66,9 @@ cdef class Vectors: cdef readonly uint32_t hash_seed cdef readonly unicode bow cdef readonly unicode eow + cdef readonly attr_id_t attr - def __init__(self, *, strings=None, shape=None, data=None, keys=None, name=None, mode=Mode.default, minn=0, maxn=0, hash_count=1, hash_seed=0, bow="<", eow=">"): + def __init__(self, *, strings=None, shape=None, data=None, keys=None, name=None, mode=Mode.default, minn=0, maxn=0, hash_count=1, hash_seed=0, bow="<", eow=">", attr="ORTH"): """Create a new vector store. strings (StringStore): The string store. @@ -80,6 +83,8 @@ cdef class Vectors: hash_seed (int): The floret hash seed (default: 0). bow (str): The floret BOW string (default: "<"). eow (str): The floret EOW string (default: ">"). + attr (Union[int, str]): The token attribute for the vector keys + (default: "ORTH"). DOCS: https://spacy.io/api/vectors#init """ @@ -103,6 +108,14 @@ cdef class Vectors: self.hash_seed = hash_seed self.bow = bow self.eow = eow + if isinstance(attr, (int, long)): + self.attr = attr + else: + attr = attr.upper() + if attr == "TEXT": + attr = "ORTH" + self.attr = IDS.get(attr, ORTH) + if self.mode == Mode.default: if data is None: if shape is None: @@ -546,6 +559,7 @@ cdef class Vectors: "hash_seed": self.hash_seed, "bow": self.bow, "eow": self.eow, + "attr": self.attr, } def _set_cfg(self, cfg): @@ -556,6 +570,7 @@ cdef class Vectors: self.hash_seed = cfg.get("hash_seed", 0) self.bow = cfg.get("bow", "<") self.eow = cfg.get("eow", ">") + self.attr = cfg.get("attr", ORTH) def to_disk(self, path, *, exclude=tuple()): """Save the current state to a directory. diff --git a/spacy/vocab.pyx b/spacy/vocab.pyx index d47122d08..520228b51 100644 --- a/spacy/vocab.pyx +++ b/spacy/vocab.pyx @@ -365,8 +365,13 @@ cdef class Vocab: self[orth] # Make prob negative so it sorts by rank ascending # (key2row contains the rank) - priority = [(-lex.prob, self.vectors.key2row[lex.orth], lex.orth) - for lex in self if lex.orth in self.vectors.key2row] + priority = [] + cdef Lexeme lex + cdef attr_t value + for lex in self: + value = Lexeme.get_struct_attr(lex.c, self.vectors.attr) + if value in self.vectors.key2row: + priority.append((-lex.prob, self.vectors.key2row[value], value)) priority.sort() indices = xp.asarray([i for (prob, i, key) in priority], dtype="uint64") keys = xp.asarray([key for (prob, i, key) in priority], dtype="uint64") @@ -399,8 +404,10 @@ cdef class Vocab: """ if isinstance(orth, str): orth = self.strings.add(orth) - if self.has_vector(orth): - return self.vectors[orth] + cdef Lexeme lex = self[orth] + key = Lexeme.get_struct_attr(lex.c, self.vectors.attr) + if self.has_vector(key): + return self.vectors[key] xp = get_array_module(self.vectors.data) vectors = xp.zeros((self.vectors_length,), dtype="f") return vectors @@ -416,15 +423,16 @@ cdef class Vocab: """ if isinstance(orth, str): orth = self.strings.add(orth) - if self.vectors.is_full and orth not in self.vectors: + cdef Lexeme lex = self[orth] + key = Lexeme.get_struct_attr(lex.c, self.vectors.attr) + if self.vectors.is_full and key not in self.vectors: new_rows = max(100, int(self.vectors.shape[0]*1.3)) if self.vectors.shape[1] == 0: width = vector.size else: width = self.vectors.shape[1] self.vectors.resize((new_rows, width)) - lex = self[orth] # Add word to vocab if necessary - row = self.vectors.add(orth, vector=vector) + row = self.vectors.add(key, vector=vector) if row >= 0: lex.rank = row @@ -439,7 +447,9 @@ cdef class Vocab: """ if isinstance(orth, str): orth = self.strings.add(orth) - return orth in self.vectors + cdef Lexeme lex = self[orth] + key = Lexeme.get_struct_attr(lex.c, self.vectors.attr) + return key in self.vectors property lookups: def __get__(self): diff --git a/website/docs/api/architectures.mdx b/website/docs/api/architectures.mdx index 268c04a07..bab24f13b 100644 --- a/website/docs/api/architectures.mdx +++ b/website/docs/api/architectures.mdx @@ -303,7 +303,7 @@ mapped to a zero vector. See the documentation on | `nM` | The width of the static vectors. ~~Optional[int]~~ | | `dropout` | Optional dropout rate. If set, it's applied per dimension over the whole batch. Defaults to `None`. ~~Optional[float]~~ | | `init_W` | The [initialization function](https://thinc.ai/docs/api-initializers). Defaults to [`glorot_uniform_init`](https://thinc.ai/docs/api-initializers#glorot_uniform_init). ~~Callable[[Ops, Tuple[int, ...]]], FloatsXd]~~ | -| `key_attr` | Defaults to `"ORTH"`. ~~str~~ | +| `key_attr` | This setting is ignored in spaCy v3.6+. To set a custom key attribute for vectors, configure it through [`Vectors`](/api/vectors) or [`spacy init vectors`](/api/cli#init-vectors). Defaults to `"ORTH"`. ~~str~~ | | **CREATES** | The model using the architecture. ~~Model[List[Doc], Ragged]~~ | ### spacy.FeatureExtractor.v1 {id="FeatureExtractor"} diff --git a/website/docs/api/cli.mdx b/website/docs/api/cli.mdx index 5b4bca1ce..6a87f78b8 100644 --- a/website/docs/api/cli.mdx +++ b/website/docs/api/cli.mdx @@ -211,7 +211,8 @@ $ python -m spacy init vectors [lang] [vectors_loc] [output_dir] [--prune] [--tr | `output_dir` | Pipeline output directory. Will be created if it doesn't exist. ~~Path (positional)~~ | | `--truncate`, `-t` | Number of vectors to truncate to when reading in vectors file. Defaults to `0` for no truncation. ~~int (option)~~ | | `--prune`, `-p` | Number of vectors to prune the vocabulary to. Defaults to `-1` for no pruning. ~~int (option)~~ | -| `--mode`, `-m` | Vectors mode: `default` or [`floret`](https://github.com/explosion/floret). Defaults to `default`. ~~Optional[str] \(option)~~ | +| `--mode`, `-m` | Vectors mode: `default` or [`floret`](https://github.com/explosion/floret). Defaults to `default`. ~~str \(option)~~ | +| `--attr`, `-a` | Token attribute to use for vectors, e.g. `LOWER` or `NORM`) Defaults to `ORTH`. ~~str \(option)~~ | | `--name`, `-n` | Name to assign to the word vectors in the `meta.json`, e.g. `en_core_web_md.vectors`. ~~Optional[str] \(option)~~ | | `--verbose`, `-V` | Print additional information and explanations. ~~bool (flag)~~ | | `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ | diff --git a/website/docs/api/vectors.mdx b/website/docs/api/vectors.mdx index d6033c096..fa4cd0c7a 100644 --- a/website/docs/api/vectors.mdx +++ b/website/docs/api/vectors.mdx @@ -60,6 +60,7 @@ modified later. | `hash_seed` 3.2 | The floret hash seed (default: `0`). ~~int~~ | | `bow` 3.2 | The floret BOW string (default: `"<"`). ~~str~~ | | `eow` 3.2 | The floret EOW string (default: `">"`). ~~str~~ | +| `attr` 3.6 | The token attribute for the vector keys (default: `"ORTH"`). ~~Union[int, str]~~ | ## Vectors.\_\_getitem\_\_ {id="getitem",tag="method"} @@ -453,8 +454,9 @@ Load state from a binary string. ## Attributes {id="attributes"} -| Name | Description | -| --------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------- | -| `data` | Stored vectors data. `numpy` is used for CPU vectors, `cupy` for GPU vectors. ~~Union[numpy.ndarray[ndim=1, dtype=float32], cupy.ndarray[ndim=1, dtype=float32]]~~ | -| `key2row` | Dictionary mapping word hashes to rows in the `Vectors.data` table. ~~Dict[int, int]~~ | -| `keys` | Array keeping the keys in order, such that `keys[vectors.key2row[key]] == key`. ~~Union[numpy.ndarray[ndim=1, dtype=float32], cupy.ndarray[ndim=1, dtype=float32]]~~ | +| Name | Description | +| ----------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------- | +| `data` | Stored vectors data. `numpy` is used for CPU vectors, `cupy` for GPU vectors. ~~Union[numpy.ndarray[ndim=1, dtype=float32], cupy.ndarray[ndim=1, dtype=float32]]~~ | +| `key2row` | Dictionary mapping word hashes to rows in the `Vectors.data` table. ~~Dict[int, int]~~ | +| `keys` | Array keeping the keys in order, such that `keys[vectors.key2row[key]] == key`. ~~Union[numpy.ndarray[ndim=1, dtype=float32], cupy.ndarray[ndim=1, dtype=float32]]~~ | +| `attr` 3.6 | The token attribute for the vector keys. ~~int~~ |