Support custom token/lexeme attribute for vectors

This commit is contained in:
Adriane Boyd 2023-05-11 21:59:03 +02:00
parent 3252f6b13f
commit a6177187c2
10 changed files with 130 additions and 26 deletions

View File

@ -24,6 +24,7 @@ def init_vectors_cli(
name: Optional[str] = Opt(None, "--name", "-n", help="Optional name for the word vectors, e.g. en_core_web_lg.vectors"), name: Optional[str] = Opt(None, "--name", "-n", help="Optional name for the word vectors, e.g. en_core_web_lg.vectors"),
verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"), verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
jsonl_loc: Optional[Path] = Opt(None, "--lexemes-jsonl", "-j", help="Location of JSONL-formatted attributes file", hidden=True), jsonl_loc: Optional[Path] = Opt(None, "--lexemes-jsonl", "-j", help="Location of JSONL-formatted attributes file", hidden=True),
attr: str = Opt("ORTH", "--attr", "-a", help="Optional token attribute to use for vectors, e.g. LOWER or NORM"),
# fmt: on # fmt: on
): ):
"""Convert word vectors for use with spaCy. Will export an nlp object that """Convert word vectors for use with spaCy. Will export an nlp object that
@ -42,6 +43,7 @@ def init_vectors_cli(
prune=prune, prune=prune,
name=name, name=name,
mode=mode, mode=mode,
attr=attr,
) )
msg.good(f"Successfully converted {len(nlp.vocab.vectors)} vectors") msg.good(f"Successfully converted {len(nlp.vocab.vectors)} vectors")
nlp.to_disk(output_dir) nlp.to_disk(output_dir)

View File

@ -215,6 +215,9 @@ class Warnings(metaclass=ErrorsWithCodes):
W123 = ("Argument `enable` with value {enable} does not contain all values specified in the config option " W123 = ("Argument `enable` with value {enable} does not contain all values specified in the config option "
"`enabled` ({enabled}). Be aware that this might affect other components in your pipeline.") "`enabled` ({enabled}). Be aware that this might affect other components in your pipeline.")
W124 = ("{host}:{port} is already in use, using the nearest available port {serve_port} as an alternative.") W124 = ("{host}:{port} is already in use, using the nearest available port {serve_port} as an alternative.")
W125 = ("The StaticVectors key_attr is no longer used. To set a custom "
"key attribute for vectors, configure it through Vectors(attr=) or "
"'spacy init vectors --attr'")
class Errors(metaclass=ErrorsWithCodes): class Errors(metaclass=ErrorsWithCodes):

View File

@ -23,6 +23,8 @@ def StaticVectors(
linear projection to control the dimensionality. If a dropout rate is linear projection to control the dimensionality. If a dropout rate is
specified, the dropout is applied per dimension over the whole batch. specified, the dropout is applied per dimension over the whole batch.
""" """
if key_attr != "ORTH":
warnings.warn(Warnings.W125, DeprecationWarning)
return Model( return Model(
"static_vectors", "static_vectors",
forward, forward,
@ -39,9 +41,9 @@ def forward(
token_count = sum(len(doc) for doc in docs) token_count = sum(len(doc) for doc in docs)
if not token_count: if not token_count:
return _handle_empty(model.ops, model.get_dim("nO")) return _handle_empty(model.ops, model.get_dim("nO"))
key_attr: int = model.attrs["key_attr"]
keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs])
vocab: Vocab = docs[0].vocab vocab: Vocab = docs[0].vocab
key_attr: int = vocab.vectors.attr
keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs])
W = cast(Floats2d, model.ops.as_contig(model.get_param("W"))) W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
if vocab.vectors.mode == Mode.default: if vocab.vectors.mode == Mode.default:
V = model.ops.asarray(vocab.vectors.data) V = model.ops.asarray(vocab.vectors.data)

View File

@ -402,6 +402,7 @@ def test_vectors_serialize():
row_r = v_r.add("D", vector=OPS.asarray([10, 20, 30, 40], dtype="f")) row_r = v_r.add("D", vector=OPS.asarray([10, 20, 30, 40], dtype="f"))
assert row == row_r assert row == row_r
assert_equal(OPS.to_numpy(v.data), OPS.to_numpy(v_r.data)) assert_equal(OPS.to_numpy(v.data), OPS.to_numpy(v_r.data))
assert v.attr == v_r.attr
def test_vector_is_oov(): def test_vector_is_oov():
@ -646,3 +647,32 @@ def test_equality():
vectors1.resize((5, 9)) vectors1.resize((5, 9))
vectors2.resize((5, 9)) vectors2.resize((5, 9))
assert vectors1 == vectors2 assert vectors1 == vectors2
def test_vectors_attr():
data = numpy.asarray([[0, 0, 0], [1, 2, 3], [9, 8, 7]], dtype="f")
# default ORTH
nlp = English()
nlp.vocab.vectors = Vectors(data=data, keys=["A", "B", "C"])
assert nlp.vocab.strings["A"] in nlp.vocab.vectors.key2row
assert nlp.vocab.strings["a"] not in nlp.vocab.vectors.key2row
assert nlp.vocab["A"].has_vector is True
assert nlp.vocab["a"].has_vector is False
assert nlp("A")[0].has_vector is True
assert nlp("a")[0].has_vector is False
# custom LOWER
nlp = English()
nlp.vocab.vectors = Vectors(data=data, keys=["a", "b", "c"], attr="LOWER")
assert nlp.vocab.strings["A"] not in nlp.vocab.vectors.key2row
assert nlp.vocab.strings["a"] in nlp.vocab.vectors.key2row
assert nlp.vocab["A"].has_vector is True
assert nlp.vocab["a"].has_vector is True
assert nlp("A")[0].has_vector is True
assert nlp("a")[0].has_vector is True
# add a new vectors entry
assert nlp.vocab["D"].has_vector is False
assert nlp.vocab["d"].has_vector is False
nlp.vocab.set_vector("D", numpy.asarray([4, 5, 6]))
assert nlp.vocab["D"].has_vector is True
assert nlp.vocab["d"].has_vector is True

View File

@ -591,13 +591,26 @@ cdef class Doc:
""" """
if "similarity" in self.user_hooks: if "similarity" in self.user_hooks:
return self.user_hooks["similarity"](self, other) return self.user_hooks["similarity"](self, other)
if isinstance(other, (Lexeme, Token)) and self.length == 1: attr = self.doc.vocab.vectors.attr
if self.c[0].lex.orth == other.orth: cdef Token this_token
cdef Token other_token
cdef Lexeme other_lex
if len(self) == 1 and isinstance(other, Token):
this_token = self[0]
other_token = other
if Token.get_struct_attr(this_token.c, attr) == Token.get_struct_attr(other_token.c, attr):
return 1.0 return 1.0
elif isinstance(other, (Span, Doc)) and len(self) == len(other): elif len(self) == 1 and isinstance(other, Lexeme):
this_token = self[0]
other_lex = other
if Token.get_struct_attr(this_token.c, attr) == Lexeme.get_struct_attr(other_lex.c, attr):
return 1.0
elif isinstance(other, (Doc, Span)) and len(self) == len(other):
similar = True similar = True
for i in range(self.length): for i in range(len(self)):
if self[i].orth != other[i].orth: this_token = self[i]
other_token = other[i]
if Token.get_struct_attr(this_token.c, attr) != Token.get_struct_attr(other_token.c, attr):
similar = False similar = False
break break
if similar: if similar:

View File

@ -7,6 +7,7 @@ import warnings
import copy import copy
from .doc cimport token_by_start, token_by_end, get_token_attr, _get_lca_matrix from .doc cimport token_by_start, token_by_end, get_token_attr, _get_lca_matrix
from .token cimport Token
from ..structs cimport TokenC, LexemeC from ..structs cimport TokenC, LexemeC
from ..typedefs cimport flags_t, attr_t, hash_t from ..typedefs cimport flags_t, attr_t, hash_t
from ..attrs cimport attr_id_t from ..attrs cimport attr_id_t
@ -340,13 +341,26 @@ cdef class Span:
""" """
if "similarity" in self.doc.user_span_hooks: if "similarity" in self.doc.user_span_hooks:
return self.doc.user_span_hooks["similarity"](self, other) return self.doc.user_span_hooks["similarity"](self, other)
if len(self) == 1 and hasattr(other, "orth"): attr = self.doc.vocab.vectors.attr
if self[0].orth == other.orth: cdef Token this_token
cdef Token other_token
cdef Lexeme other_lex
if len(self) == 1 and isinstance(other, Token):
this_token = self[0]
other_token = other
if Token.get_struct_attr(this_token.c, attr) == Token.get_struct_attr(other_token.c, attr):
return 1.0
elif len(self) == 1 and isinstance(other, Lexeme):
this_token = self[0]
other_lex = other
if Token.get_struct_attr(this_token.c, attr) == Lexeme.get_struct_attr(other_lex.c, attr):
return 1.0 return 1.0
elif isinstance(other, (Doc, Span)) and len(self) == len(other): elif isinstance(other, (Doc, Span)) and len(self) == len(other):
similar = True similar = True
for i in range(len(self)): for i in range(len(self)):
if self[i].orth != getattr(other[i], "orth", None): this_token = self[i]
other_token = other[i]
if Token.get_struct_attr(this_token.c, attr) != Token.get_struct_attr(other_token.c, attr):
similar = False similar = False
break break
if similar: if similar:

View File

@ -197,11 +197,17 @@ cdef class Token:
""" """
if "similarity" in self.doc.user_token_hooks: if "similarity" in self.doc.user_token_hooks:
return self.doc.user_token_hooks["similarity"](self, other) return self.doc.user_token_hooks["similarity"](self, other)
if hasattr(other, "__len__") and len(other) == 1 and hasattr(other, "__getitem__"): attr = self.doc.vocab.vectors.attr
if self.c.lex.orth == getattr(other[0], "orth", None): cdef Token this_token = self
cdef Token other_token
cdef Lexeme other_lex
if isinstance(other, Token):
other_token = other
if Token.get_struct_attr(this_token.c, attr) == Token.get_struct_attr(other_token.c, attr):
return 1.0 return 1.0
elif hasattr(other, "orth"): elif isinstance(other, Lexeme):
if self.c.lex.orth == other.orth: other_lex = other
if Token.get_struct_attr(this_token.c, attr) == Lexeme.get_struct_attr(other_lex.c, attr):
return 1.0 return 1.0
if self.vocab.vectors.n_keys == 0: if self.vocab.vectors.n_keys == 0:
warnings.warn(Warnings.W007.format(obj="Token")) warnings.warn(Warnings.W007.format(obj="Token"))
@ -398,7 +404,7 @@ cdef class Token:
return self.doc.user_token_hooks["has_vector"](self) return self.doc.user_token_hooks["has_vector"](self)
if self.vocab.vectors.size == 0 and self.doc.tensor.size != 0: if self.vocab.vectors.size == 0 and self.doc.tensor.size != 0:
return True return True
return self.vocab.has_vector(self.c.lex.orth) return self.vocab.has_vector(Token.get_struct_attr(self.c, self.vocab.vectors.attr))
@property @property
def vector(self): def vector(self):
@ -414,7 +420,7 @@ cdef class Token:
if self.vocab.vectors.size == 0 and self.doc.tensor.size != 0: if self.vocab.vectors.size == 0 and self.doc.tensor.size != 0:
return self.doc.tensor[self.i] return self.doc.tensor[self.i]
else: else:
return self.vocab.get_vector(self.c.lex.orth) return self.vocab.get_vector(Token.get_struct_attr(self.c, self.vocab.vectors.attr))
@property @property
def vector_norm(self): def vector_norm(self):

View File

@ -206,9 +206,14 @@ def convert_vectors(
prune: int, prune: int,
name: Optional[str] = None, name: Optional[str] = None,
mode: str = VectorsMode.default, mode: str = VectorsMode.default,
attr: str = "ORTH",
) -> None: ) -> None:
vectors_loc = ensure_path(vectors_loc) vectors_loc = ensure_path(vectors_loc)
if vectors_loc and vectors_loc.parts[-1].endswith(".npz"): if vectors_loc and vectors_loc.parts[-1].endswith(".npz"):
if attr != "ORTH":
raise ValueError(
"ORTH is the only attribute supported for vectors in .npz format."
)
nlp.vocab.vectors = Vectors( nlp.vocab.vectors = Vectors(
strings=nlp.vocab.strings, data=numpy.load(vectors_loc.open("rb")) strings=nlp.vocab.strings, data=numpy.load(vectors_loc.open("rb"))
) )
@ -236,11 +241,15 @@ def convert_vectors(
nlp.vocab.vectors = Vectors( nlp.vocab.vectors = Vectors(
strings=nlp.vocab.strings, strings=nlp.vocab.strings,
data=vectors_data, data=vectors_data,
attr=attr,
**floret_settings, **floret_settings,
) )
else: else:
nlp.vocab.vectors = Vectors( nlp.vocab.vectors = Vectors(
strings=nlp.vocab.strings, data=vectors_data, keys=vector_keys strings=nlp.vocab.strings,
data=vectors_data,
keys=vector_keys,
attr=attr,
) )
nlp.vocab.deduplicate_vectors() nlp.vocab.deduplicate_vectors()
if name is None: if name is None:

View File

@ -14,8 +14,10 @@ from thinc.api import Ops, get_array_module, get_current_ops
from thinc.backends import get_array_ops from thinc.backends import get_array_ops
from thinc.types import Floats2d from thinc.types import Floats2d
from .attrs cimport attr_id_t, ORTH
from .strings cimport StringStore from .strings cimport StringStore
from .attrs import IDS
from .strings import get_string_id from .strings import get_string_id
from .errors import Errors, Warnings from .errors import Errors, Warnings
from . import util from . import util
@ -63,8 +65,9 @@ cdef class Vectors:
cdef readonly uint32_t hash_seed cdef readonly uint32_t hash_seed
cdef readonly unicode bow cdef readonly unicode bow
cdef readonly unicode eow cdef readonly unicode eow
cdef readonly attr_id_t attr
def __init__(self, *, strings=None, shape=None, data=None, keys=None, name=None, mode=Mode.default, minn=0, maxn=0, hash_count=1, hash_seed=0, bow="<", eow=">"): def __init__(self, *, strings=None, shape=None, data=None, keys=None, name=None, mode=Mode.default, minn=0, maxn=0, hash_count=1, hash_seed=0, bow="<", eow=">", attr="ORTH"):
"""Create a new vector store. """Create a new vector store.
strings (StringStore): The string store. strings (StringStore): The string store.
@ -79,6 +82,8 @@ cdef class Vectors:
hash_seed (int): The floret hash seed (default: 0). hash_seed (int): The floret hash seed (default: 0).
bow (str): The floret BOW string (default: "<"). bow (str): The floret BOW string (default: "<").
eow (str): The floret EOW string (default: ">"). eow (str): The floret EOW string (default: ">").
attr (Union[int, str]): The Token attribute for the vector keys
(default: "ORTH").
DOCS: https://spacy.io/api/vectors#init DOCS: https://spacy.io/api/vectors#init
""" """
@ -102,6 +107,14 @@ cdef class Vectors:
self.hash_seed = hash_seed self.hash_seed = hash_seed
self.bow = bow self.bow = bow
self.eow = eow self.eow = eow
if isinstance(attr, (int, long)):
self.attr = attr
else:
attr = attr.upper()
if attr == "TEXT":
attr = "ORTH"
self.attr = IDS.get(attr, ORTH)
if self.mode == Mode.default: if self.mode == Mode.default:
if data is None: if data is None:
if shape is None: if shape is None:
@ -545,6 +558,7 @@ cdef class Vectors:
"hash_seed": self.hash_seed, "hash_seed": self.hash_seed,
"bow": self.bow, "bow": self.bow,
"eow": self.eow, "eow": self.eow,
"attr": self.attr,
} }
def _set_cfg(self, cfg): def _set_cfg(self, cfg):
@ -555,6 +569,7 @@ cdef class Vectors:
self.hash_seed = cfg.get("hash_seed", 0) self.hash_seed = cfg.get("hash_seed", 0)
self.bow = cfg.get("bow", "<") self.bow = cfg.get("bow", "<")
self.eow = cfg.get("eow", ">") self.eow = cfg.get("eow", ">")
self.attr = cfg.get("attr", ORTH)
def to_disk(self, path, *, exclude=tuple()): def to_disk(self, path, *, exclude=tuple()):
"""Save the current state to a directory. """Save the current state to a directory.

View File

@ -364,8 +364,13 @@ cdef class Vocab:
self[orth] self[orth]
# Make prob negative so it sorts by rank ascending # Make prob negative so it sorts by rank ascending
# (key2row contains the rank) # (key2row contains the rank)
priority = [(-lex.prob, self.vectors.key2row[lex.orth], lex.orth) priority = []
for lex in self if lex.orth in self.vectors.key2row] cdef Lexeme lex
cdef attr_t value
for lex in self:
value = Lexeme.get_struct_attr(lex.c, self.vectors.attr)
if value in self.vectors.key2row:
priority.append((-lex.prob, self.vectors.key2row[value], value))
priority.sort() priority.sort()
indices = xp.asarray([i for (prob, i, key) in priority], dtype="uint64") indices = xp.asarray([i for (prob, i, key) in priority], dtype="uint64")
keys = xp.asarray([key for (prob, i, key) in priority], dtype="uint64") keys = xp.asarray([key for (prob, i, key) in priority], dtype="uint64")
@ -398,8 +403,10 @@ cdef class Vocab:
""" """
if isinstance(orth, str): if isinstance(orth, str):
orth = self.strings.add(orth) orth = self.strings.add(orth)
if self.has_vector(orth): cdef Lexeme lex = self[orth]
return self.vectors[orth] key = Lexeme.get_struct_attr(lex.c, self.vectors.attr)
if self.has_vector(key):
return self.vectors[key]
xp = get_array_module(self.vectors.data) xp = get_array_module(self.vectors.data)
vectors = xp.zeros((self.vectors_length,), dtype="f") vectors = xp.zeros((self.vectors_length,), dtype="f")
return vectors return vectors
@ -415,15 +422,16 @@ cdef class Vocab:
""" """
if isinstance(orth, str): if isinstance(orth, str):
orth = self.strings.add(orth) orth = self.strings.add(orth)
if self.vectors.is_full and orth not in self.vectors: cdef Lexeme lex = self[orth]
key = Lexeme.get_struct_attr(lex.c, self.vectors.attr)
if self.vectors.is_full and key not in self.vectors:
new_rows = max(100, int(self.vectors.shape[0]*1.3)) new_rows = max(100, int(self.vectors.shape[0]*1.3))
if self.vectors.shape[1] == 0: if self.vectors.shape[1] == 0:
width = vector.size width = vector.size
else: else:
width = self.vectors.shape[1] width = self.vectors.shape[1]
self.vectors.resize((new_rows, width)) self.vectors.resize((new_rows, width))
lex = self[orth] # Add word to vocab if necessary row = self.vectors.add(key, vector=vector)
row = self.vectors.add(orth, vector=vector)
if row >= 0: if row >= 0:
lex.rank = row lex.rank = row
@ -438,7 +446,9 @@ cdef class Vocab:
""" """
if isinstance(orth, str): if isinstance(orth, str):
orth = self.strings.add(orth) orth = self.strings.add(orth)
return orth in self.vectors cdef Lexeme lex = self[orth]
key = Lexeme.get_struct_attr(lex.c, self.vectors.attr)
return key in self.vectors
property lookups: property lookups:
def __get__(self): def __get__(self):