mirror of
https://github.com/explosion/spaCy.git
synced 2025-06-23 06:23:06 +03:00
Support custom token/lexeme attribute for vectors (#12625)
* Support custom token/lexeme attribute for vectors
* Fix imports
* Back off to ORTH without Vectors.attr
* Fallback if vectors.attr doesn't exist
* Update docs
This commit is contained in:
parent
337a360cc7
commit
fb0da3e097
|
@ -32,6 +32,7 @@ def init_vectors_cli(
|
||||||
name: Optional[str] = Opt(None, "--name", "-n", help="Optional name for the word vectors, e.g. en_core_web_lg.vectors"),
|
name: Optional[str] = Opt(None, "--name", "-n", help="Optional name for the word vectors, e.g. en_core_web_lg.vectors"),
|
||||||
verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
|
verbose: bool = Opt(False, "--verbose", "-V", "-VV", help="Display more information for debugging purposes"),
|
||||||
jsonl_loc: Optional[Path] = Opt(None, "--lexemes-jsonl", "-j", help="Location of JSONL-formatted attributes file", hidden=True),
|
jsonl_loc: Optional[Path] = Opt(None, "--lexemes-jsonl", "-j", help="Location of JSONL-formatted attributes file", hidden=True),
|
||||||
|
attr: str = Opt("ORTH", "--attr", "-a", help="Optional token attribute to use for vectors, e.g. LOWER or NORM"),
|
||||||
# fmt: on
|
# fmt: on
|
||||||
):
|
):
|
||||||
"""Convert word vectors for use with spaCy. Will export an nlp object that
|
"""Convert word vectors for use with spaCy. Will export an nlp object that
|
||||||
|
@ -50,6 +51,7 @@ def init_vectors_cli(
|
||||||
prune=prune,
|
prune=prune,
|
||||||
name=name,
|
name=name,
|
||||||
mode=mode,
|
mode=mode,
|
||||||
|
attr=attr,
|
||||||
)
|
)
|
||||||
msg.good(f"Successfully converted {len(nlp.vocab.vectors)} vectors")
|
msg.good(f"Successfully converted {len(nlp.vocab.vectors)} vectors")
|
||||||
nlp.to_disk(output_dir)
|
nlp.to_disk(output_dir)
|
||||||
|
|
|
@ -216,6 +216,9 @@ class Warnings(metaclass=ErrorsWithCodes):
|
||||||
W123 = ("Argument `enable` with value {enable} does not contain all values specified in the config option "
|
W123 = ("Argument `enable` with value {enable} does not contain all values specified in the config option "
|
||||||
"`enabled` ({enabled}). Be aware that this might affect other components in your pipeline.")
|
"`enabled` ({enabled}). Be aware that this might affect other components in your pipeline.")
|
||||||
W124 = ("{host}:{port} is already in use, using the nearest available port {serve_port} as an alternative.")
|
W124 = ("{host}:{port} is already in use, using the nearest available port {serve_port} as an alternative.")
|
||||||
|
W125 = ("The StaticVectors key_attr is no longer used. To set a custom "
|
||||||
|
"key attribute for vectors, configure it through Vectors(attr=) or "
|
||||||
|
"'spacy init vectors --attr'")
|
||||||
|
|
||||||
|
|
||||||
class Errors(metaclass=ErrorsWithCodes):
|
class Errors(metaclass=ErrorsWithCodes):
|
||||||
|
|
|
@ -1,3 +1,4 @@
|
||||||
|
import warnings
|
||||||
from typing import Callable, List, Optional, Sequence, Tuple, cast
|
from typing import Callable, List, Optional, Sequence, Tuple, cast
|
||||||
|
|
||||||
from thinc.api import Model, Ops, registry
|
from thinc.api import Model, Ops, registry
|
||||||
|
@ -5,7 +6,8 @@ from thinc.initializers import glorot_uniform_init
|
||||||
from thinc.types import Floats1d, Floats2d, Ints1d, Ragged
|
from thinc.types import Floats1d, Floats2d, Ints1d, Ragged
|
||||||
from thinc.util import partial
|
from thinc.util import partial
|
||||||
|
|
||||||
from ..errors import Errors
|
from ..attrs import ORTH
|
||||||
|
from ..errors import Errors, Warnings
|
||||||
from ..tokens import Doc
|
from ..tokens import Doc
|
||||||
from ..vectors import Mode
|
from ..vectors import Mode
|
||||||
from ..vocab import Vocab
|
from ..vocab import Vocab
|
||||||
|
@ -24,6 +26,8 @@ def StaticVectors(
|
||||||
linear projection to control the dimensionality. If a dropout rate is
|
linear projection to control the dimensionality. If a dropout rate is
|
||||||
specified, the dropout is applied per dimension over the whole batch.
|
specified, the dropout is applied per dimension over the whole batch.
|
||||||
"""
|
"""
|
||||||
|
if key_attr != "ORTH":
|
||||||
|
warnings.warn(Warnings.W125, DeprecationWarning)
|
||||||
return Model(
|
return Model(
|
||||||
"static_vectors",
|
"static_vectors",
|
||||||
forward,
|
forward,
|
||||||
|
@ -40,9 +44,9 @@ def forward(
|
||||||
token_count = sum(len(doc) for doc in docs)
|
token_count = sum(len(doc) for doc in docs)
|
||||||
if not token_count:
|
if not token_count:
|
||||||
return _handle_empty(model.ops, model.get_dim("nO"))
|
return _handle_empty(model.ops, model.get_dim("nO"))
|
||||||
key_attr: int = model.attrs["key_attr"]
|
|
||||||
keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs])
|
|
||||||
vocab: Vocab = docs[0].vocab
|
vocab: Vocab = docs[0].vocab
|
||||||
|
key_attr: int = getattr(vocab.vectors, "attr", ORTH)
|
||||||
|
keys = model.ops.flatten([cast(Ints1d, doc.to_array(key_attr)) for doc in docs])
|
||||||
W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
|
W = cast(Floats2d, model.ops.as_contig(model.get_param("W")))
|
||||||
if vocab.vectors.mode == Mode.default:
|
if vocab.vectors.mode == Mode.default:
|
||||||
V = model.ops.asarray(vocab.vectors.data)
|
V = model.ops.asarray(vocab.vectors.data)
|
||||||
|
|
|
@ -402,6 +402,7 @@ def test_vectors_serialize():
|
||||||
row_r = v_r.add("D", vector=OPS.asarray([10, 20, 30, 40], dtype="f"))
|
row_r = v_r.add("D", vector=OPS.asarray([10, 20, 30, 40], dtype="f"))
|
||||||
assert row == row_r
|
assert row == row_r
|
||||||
assert_equal(OPS.to_numpy(v.data), OPS.to_numpy(v_r.data))
|
assert_equal(OPS.to_numpy(v.data), OPS.to_numpy(v_r.data))
|
||||||
|
assert v.attr == v_r.attr
|
||||||
|
|
||||||
|
|
||||||
def test_vector_is_oov():
|
def test_vector_is_oov():
|
||||||
|
@ -646,3 +647,32 @@ def test_equality():
|
||||||
vectors1.resize((5, 9))
|
vectors1.resize((5, 9))
|
||||||
vectors2.resize((5, 9))
|
vectors2.resize((5, 9))
|
||||||
assert vectors1 == vectors2
|
assert vectors1 == vectors2
|
||||||
|
|
||||||
|
|
||||||
|
def test_vectors_attr():
|
||||||
|
data = numpy.asarray([[0, 0, 0], [1, 2, 3], [9, 8, 7]], dtype="f")
|
||||||
|
# default ORTH
|
||||||
|
nlp = English()
|
||||||
|
nlp.vocab.vectors = Vectors(data=data, keys=["A", "B", "C"])
|
||||||
|
assert nlp.vocab.strings["A"] in nlp.vocab.vectors.key2row
|
||||||
|
assert nlp.vocab.strings["a"] not in nlp.vocab.vectors.key2row
|
||||||
|
assert nlp.vocab["A"].has_vector is True
|
||||||
|
assert nlp.vocab["a"].has_vector is False
|
||||||
|
assert nlp("A")[0].has_vector is True
|
||||||
|
assert nlp("a")[0].has_vector is False
|
||||||
|
|
||||||
|
# custom LOWER
|
||||||
|
nlp = English()
|
||||||
|
nlp.vocab.vectors = Vectors(data=data, keys=["a", "b", "c"], attr="LOWER")
|
||||||
|
assert nlp.vocab.strings["A"] not in nlp.vocab.vectors.key2row
|
||||||
|
assert nlp.vocab.strings["a"] in nlp.vocab.vectors.key2row
|
||||||
|
assert nlp.vocab["A"].has_vector is True
|
||||||
|
assert nlp.vocab["a"].has_vector is True
|
||||||
|
assert nlp("A")[0].has_vector is True
|
||||||
|
assert nlp("a")[0].has_vector is True
|
||||||
|
# add a new vectors entry
|
||||||
|
assert nlp.vocab["D"].has_vector is False
|
||||||
|
assert nlp.vocab["d"].has_vector is False
|
||||||
|
nlp.vocab.set_vector("D", numpy.asarray([4, 5, 6]))
|
||||||
|
assert nlp.vocab["D"].has_vector is True
|
||||||
|
assert nlp.vocab["d"].has_vector is True
|
||||||
|
|
|
@ -35,6 +35,7 @@ from ..attrs cimport (
|
||||||
LENGTH,
|
LENGTH,
|
||||||
MORPH,
|
MORPH,
|
||||||
NORM,
|
NORM,
|
||||||
|
ORTH,
|
||||||
POS,
|
POS,
|
||||||
SENT_START,
|
SENT_START,
|
||||||
SPACY,
|
SPACY,
|
||||||
|
@ -613,13 +614,26 @@ cdef class Doc:
|
||||||
"""
|
"""
|
||||||
if "similarity" in self.user_hooks:
|
if "similarity" in self.user_hooks:
|
||||||
return self.user_hooks["similarity"](self, other)
|
return self.user_hooks["similarity"](self, other)
|
||||||
if isinstance(other, (Lexeme, Token)) and self.length == 1:
|
attr = getattr(self.vocab.vectors, "attr", ORTH)
|
||||||
if self.c[0].lex.orth == other.orth:
|
cdef Token this_token
|
||||||
|
cdef Token other_token
|
||||||
|
cdef Lexeme other_lex
|
||||||
|
if len(self) == 1 and isinstance(other, Token):
|
||||||
|
this_token = self[0]
|
||||||
|
other_token = other
|
||||||
|
if Token.get_struct_attr(this_token.c, attr) == Token.get_struct_attr(other_token.c, attr):
|
||||||
return 1.0
|
return 1.0
|
||||||
elif isinstance(other, (Span, Doc)) and len(self) == len(other):
|
elif len(self) == 1 and isinstance(other, Lexeme):
|
||||||
|
this_token = self[0]
|
||||||
|
other_lex = other
|
||||||
|
if Token.get_struct_attr(this_token.c, attr) == Lexeme.get_struct_attr(other_lex.c, attr):
|
||||||
|
return 1.0
|
||||||
|
elif isinstance(other, (Doc, Span)) and len(self) == len(other):
|
||||||
similar = True
|
similar = True
|
||||||
for i in range(self.length):
|
for i in range(len(self)):
|
||||||
if self[i].orth != other[i].orth:
|
this_token = self[i]
|
||||||
|
other_token = other[i]
|
||||||
|
if Token.get_struct_attr(this_token.c, attr) != Token.get_struct_attr(other_token.c, attr):
|
||||||
similar = False
|
similar = False
|
||||||
break
|
break
|
||||||
if similar:
|
if similar:
|
||||||
|
|
|
@ -8,13 +8,14 @@ import numpy
|
||||||
from thinc.api import get_array_module
|
from thinc.api import get_array_module
|
||||||
|
|
||||||
from ..attrs cimport *
|
from ..attrs cimport *
|
||||||
from ..attrs cimport attr_id_t
|
from ..attrs cimport ORTH, attr_id_t
|
||||||
from ..lexeme cimport Lexeme
|
from ..lexeme cimport Lexeme
|
||||||
from ..parts_of_speech cimport univ_pos_t
|
from ..parts_of_speech cimport univ_pos_t
|
||||||
from ..structs cimport LexemeC, TokenC
|
from ..structs cimport LexemeC, TokenC
|
||||||
from ..symbols cimport dep
|
from ..symbols cimport dep
|
||||||
from ..typedefs cimport attr_t, flags_t, hash_t
|
from ..typedefs cimport attr_t, flags_t, hash_t
|
||||||
from .doc cimport _get_lca_matrix, get_token_attr, token_by_end, token_by_start
|
from .doc cimport _get_lca_matrix, get_token_attr, token_by_end, token_by_start
|
||||||
|
from .token cimport Token
|
||||||
|
|
||||||
from ..errors import Errors, Warnings
|
from ..errors import Errors, Warnings
|
||||||
from ..util import normalize_slice
|
from ..util import normalize_slice
|
||||||
|
@ -341,13 +342,26 @@ cdef class Span:
|
||||||
"""
|
"""
|
||||||
if "similarity" in self.doc.user_span_hooks:
|
if "similarity" in self.doc.user_span_hooks:
|
||||||
return self.doc.user_span_hooks["similarity"](self, other)
|
return self.doc.user_span_hooks["similarity"](self, other)
|
||||||
if len(self) == 1 and hasattr(other, "orth"):
|
attr = getattr(self.doc.vocab.vectors, "attr", ORTH)
|
||||||
if self[0].orth == other.orth:
|
cdef Token this_token
|
||||||
|
cdef Token other_token
|
||||||
|
cdef Lexeme other_lex
|
||||||
|
if len(self) == 1 and isinstance(other, Token):
|
||||||
|
this_token = self[0]
|
||||||
|
other_token = other
|
||||||
|
if Token.get_struct_attr(this_token.c, attr) == Token.get_struct_attr(other_token.c, attr):
|
||||||
|
return 1.0
|
||||||
|
elif len(self) == 1 and isinstance(other, Lexeme):
|
||||||
|
this_token = self[0]
|
||||||
|
other_lex = other
|
||||||
|
if Token.get_struct_attr(this_token.c, attr) == Lexeme.get_struct_attr(other_lex.c, attr):
|
||||||
return 1.0
|
return 1.0
|
||||||
elif isinstance(other, (Doc, Span)) and len(self) == len(other):
|
elif isinstance(other, (Doc, Span)) and len(self) == len(other):
|
||||||
similar = True
|
similar = True
|
||||||
for i in range(len(self)):
|
for i in range(len(self)):
|
||||||
if self[i].orth != getattr(other[i], "orth", None):
|
this_token = self[i]
|
||||||
|
other_token = other[i]
|
||||||
|
if Token.get_struct_attr(this_token.c, attr) != Token.get_struct_attr(other_token.c, attr):
|
||||||
similar = False
|
similar = False
|
||||||
break
|
break
|
||||||
if similar:
|
if similar:
|
||||||
|
|
|
@ -28,6 +28,7 @@ from ..attrs cimport (
|
||||||
LIKE_EMAIL,
|
LIKE_EMAIL,
|
||||||
LIKE_NUM,
|
LIKE_NUM,
|
||||||
LIKE_URL,
|
LIKE_URL,
|
||||||
|
ORTH,
|
||||||
)
|
)
|
||||||
from ..lexeme cimport Lexeme
|
from ..lexeme cimport Lexeme
|
||||||
from ..symbols cimport conj
|
from ..symbols cimport conj
|
||||||
|
@ -214,11 +215,17 @@ cdef class Token:
|
||||||
"""
|
"""
|
||||||
if "similarity" in self.doc.user_token_hooks:
|
if "similarity" in self.doc.user_token_hooks:
|
||||||
return self.doc.user_token_hooks["similarity"](self, other)
|
return self.doc.user_token_hooks["similarity"](self, other)
|
||||||
if hasattr(other, "__len__") and len(other) == 1 and hasattr(other, "__getitem__"):
|
attr = getattr(self.doc.vocab.vectors, "attr", ORTH)
|
||||||
if self.c.lex.orth == getattr(other[0], "orth", None):
|
cdef Token this_token = self
|
||||||
|
cdef Token other_token
|
||||||
|
cdef Lexeme other_lex
|
||||||
|
if isinstance(other, Token):
|
||||||
|
other_token = other
|
||||||
|
if Token.get_struct_attr(this_token.c, attr) == Token.get_struct_attr(other_token.c, attr):
|
||||||
return 1.0
|
return 1.0
|
||||||
elif hasattr(other, "orth"):
|
elif isinstance(other, Lexeme):
|
||||||
if self.c.lex.orth == other.orth:
|
other_lex = other
|
||||||
|
if Token.get_struct_attr(this_token.c, attr) == Lexeme.get_struct_attr(other_lex.c, attr):
|
||||||
return 1.0
|
return 1.0
|
||||||
if self.vocab.vectors.n_keys == 0:
|
if self.vocab.vectors.n_keys == 0:
|
||||||
warnings.warn(Warnings.W007.format(obj="Token"))
|
warnings.warn(Warnings.W007.format(obj="Token"))
|
||||||
|
@ -415,7 +422,7 @@ cdef class Token:
|
||||||
return self.doc.user_token_hooks["has_vector"](self)
|
return self.doc.user_token_hooks["has_vector"](self)
|
||||||
if self.vocab.vectors.size == 0 and self.doc.tensor.size != 0:
|
if self.vocab.vectors.size == 0 and self.doc.tensor.size != 0:
|
||||||
return True
|
return True
|
||||||
return self.vocab.has_vector(self.c.lex.orth)
|
return self.vocab.has_vector(Token.get_struct_attr(self.c, self.vocab.vectors.attr))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def vector(self):
|
def vector(self):
|
||||||
|
@ -431,7 +438,7 @@ cdef class Token:
|
||||||
if self.vocab.vectors.size == 0 and self.doc.tensor.size != 0:
|
if self.vocab.vectors.size == 0 and self.doc.tensor.size != 0:
|
||||||
return self.doc.tensor[self.i]
|
return self.doc.tensor[self.i]
|
||||||
else:
|
else:
|
||||||
return self.vocab.get_vector(self.c.lex.orth)
|
return self.vocab.get_vector(Token.get_struct_attr(self.c, self.vocab.vectors.attr))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def vector_norm(self):
|
def vector_norm(self):
|
||||||
|
|
|
@ -216,9 +216,14 @@ def convert_vectors(
|
||||||
prune: int,
|
prune: int,
|
||||||
name: Optional[str] = None,
|
name: Optional[str] = None,
|
||||||
mode: str = VectorsMode.default,
|
mode: str = VectorsMode.default,
|
||||||
|
attr: str = "ORTH",
|
||||||
) -> None:
|
) -> None:
|
||||||
vectors_loc = ensure_path(vectors_loc)
|
vectors_loc = ensure_path(vectors_loc)
|
||||||
if vectors_loc and vectors_loc.parts[-1].endswith(".npz"):
|
if vectors_loc and vectors_loc.parts[-1].endswith(".npz"):
|
||||||
|
if attr != "ORTH":
|
||||||
|
raise ValueError(
|
||||||
|
"ORTH is the only attribute supported for vectors in .npz format."
|
||||||
|
)
|
||||||
nlp.vocab.vectors = Vectors(
|
nlp.vocab.vectors = Vectors(
|
||||||
strings=nlp.vocab.strings, data=numpy.load(vectors_loc.open("rb"))
|
strings=nlp.vocab.strings, data=numpy.load(vectors_loc.open("rb"))
|
||||||
)
|
)
|
||||||
|
@ -246,11 +251,15 @@ def convert_vectors(
|
||||||
nlp.vocab.vectors = Vectors(
|
nlp.vocab.vectors = Vectors(
|
||||||
strings=nlp.vocab.strings,
|
strings=nlp.vocab.strings,
|
||||||
data=vectors_data,
|
data=vectors_data,
|
||||||
|
attr=attr,
|
||||||
**floret_settings,
|
**floret_settings,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
nlp.vocab.vectors = Vectors(
|
nlp.vocab.vectors = Vectors(
|
||||||
strings=nlp.vocab.strings, data=vectors_data, keys=vector_keys
|
strings=nlp.vocab.strings,
|
||||||
|
data=vectors_data,
|
||||||
|
keys=vector_keys,
|
||||||
|
attr=attr,
|
||||||
)
|
)
|
||||||
nlp.vocab.deduplicate_vectors()
|
nlp.vocab.deduplicate_vectors()
|
||||||
if name is None:
|
if name is None:
|
||||||
|
|
|
@ -15,9 +15,11 @@ from thinc.api import Ops, get_array_module, get_current_ops
|
||||||
from thinc.backends import get_array_ops
|
from thinc.backends import get_array_ops
|
||||||
from thinc.types import Floats2d
|
from thinc.types import Floats2d
|
||||||
|
|
||||||
|
from .attrs cimport ORTH, attr_id_t
|
||||||
from .strings cimport StringStore
|
from .strings cimport StringStore
|
||||||
|
|
||||||
from . import util
|
from . import util
|
||||||
|
from .attrs import IDS
|
||||||
from .errors import Errors, Warnings
|
from .errors import Errors, Warnings
|
||||||
from .strings import get_string_id
|
from .strings import get_string_id
|
||||||
|
|
||||||
|
@ -64,8 +66,9 @@ cdef class Vectors:
|
||||||
cdef readonly uint32_t hash_seed
|
cdef readonly uint32_t hash_seed
|
||||||
cdef readonly unicode bow
|
cdef readonly unicode bow
|
||||||
cdef readonly unicode eow
|
cdef readonly unicode eow
|
||||||
|
cdef readonly attr_id_t attr
|
||||||
|
|
||||||
def __init__(self, *, strings=None, shape=None, data=None, keys=None, name=None, mode=Mode.default, minn=0, maxn=0, hash_count=1, hash_seed=0, bow="<", eow=">"):
|
def __init__(self, *, strings=None, shape=None, data=None, keys=None, name=None, mode=Mode.default, minn=0, maxn=0, hash_count=1, hash_seed=0, bow="<", eow=">", attr="ORTH"):
|
||||||
"""Create a new vector store.
|
"""Create a new vector store.
|
||||||
|
|
||||||
strings (StringStore): The string store.
|
strings (StringStore): The string store.
|
||||||
|
@ -80,6 +83,8 @@ cdef class Vectors:
|
||||||
hash_seed (int): The floret hash seed (default: 0).
|
hash_seed (int): The floret hash seed (default: 0).
|
||||||
bow (str): The floret BOW string (default: "<").
|
bow (str): The floret BOW string (default: "<").
|
||||||
eow (str): The floret EOW string (default: ">").
|
eow (str): The floret EOW string (default: ">").
|
||||||
|
attr (Union[int, str]): The token attribute for the vector keys
|
||||||
|
(default: "ORTH").
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/vectors#init
|
DOCS: https://spacy.io/api/vectors#init
|
||||||
"""
|
"""
|
||||||
|
@ -103,6 +108,14 @@ cdef class Vectors:
|
||||||
self.hash_seed = hash_seed
|
self.hash_seed = hash_seed
|
||||||
self.bow = bow
|
self.bow = bow
|
||||||
self.eow = eow
|
self.eow = eow
|
||||||
|
if isinstance(attr, (int, long)):
|
||||||
|
self.attr = attr
|
||||||
|
else:
|
||||||
|
attr = attr.upper()
|
||||||
|
if attr == "TEXT":
|
||||||
|
attr = "ORTH"
|
||||||
|
self.attr = IDS.get(attr, ORTH)
|
||||||
|
|
||||||
if self.mode == Mode.default:
|
if self.mode == Mode.default:
|
||||||
if data is None:
|
if data is None:
|
||||||
if shape is None:
|
if shape is None:
|
||||||
|
@ -546,6 +559,7 @@ cdef class Vectors:
|
||||||
"hash_seed": self.hash_seed,
|
"hash_seed": self.hash_seed,
|
||||||
"bow": self.bow,
|
"bow": self.bow,
|
||||||
"eow": self.eow,
|
"eow": self.eow,
|
||||||
|
"attr": self.attr,
|
||||||
}
|
}
|
||||||
|
|
||||||
def _set_cfg(self, cfg):
|
def _set_cfg(self, cfg):
|
||||||
|
@ -556,6 +570,7 @@ cdef class Vectors:
|
||||||
self.hash_seed = cfg.get("hash_seed", 0)
|
self.hash_seed = cfg.get("hash_seed", 0)
|
||||||
self.bow = cfg.get("bow", "<")
|
self.bow = cfg.get("bow", "<")
|
||||||
self.eow = cfg.get("eow", ">")
|
self.eow = cfg.get("eow", ">")
|
||||||
|
self.attr = cfg.get("attr", ORTH)
|
||||||
|
|
||||||
def to_disk(self, path, *, exclude=tuple()):
|
def to_disk(self, path, *, exclude=tuple()):
|
||||||
"""Save the current state to a directory.
|
"""Save the current state to a directory.
|
||||||
|
|
|
@ -365,8 +365,13 @@ cdef class Vocab:
|
||||||
self[orth]
|
self[orth]
|
||||||
# Make prob negative so it sorts by rank ascending
|
# Make prob negative so it sorts by rank ascending
|
||||||
# (key2row contains the rank)
|
# (key2row contains the rank)
|
||||||
priority = [(-lex.prob, self.vectors.key2row[lex.orth], lex.orth)
|
priority = []
|
||||||
for lex in self if lex.orth in self.vectors.key2row]
|
cdef Lexeme lex
|
||||||
|
cdef attr_t value
|
||||||
|
for lex in self:
|
||||||
|
value = Lexeme.get_struct_attr(lex.c, self.vectors.attr)
|
||||||
|
if value in self.vectors.key2row:
|
||||||
|
priority.append((-lex.prob, self.vectors.key2row[value], value))
|
||||||
priority.sort()
|
priority.sort()
|
||||||
indices = xp.asarray([i for (prob, i, key) in priority], dtype="uint64")
|
indices = xp.asarray([i for (prob, i, key) in priority], dtype="uint64")
|
||||||
keys = xp.asarray([key for (prob, i, key) in priority], dtype="uint64")
|
keys = xp.asarray([key for (prob, i, key) in priority], dtype="uint64")
|
||||||
|
@ -399,8 +404,10 @@ cdef class Vocab:
|
||||||
"""
|
"""
|
||||||
if isinstance(orth, str):
|
if isinstance(orth, str):
|
||||||
orth = self.strings.add(orth)
|
orth = self.strings.add(orth)
|
||||||
if self.has_vector(orth):
|
cdef Lexeme lex = self[orth]
|
||||||
return self.vectors[orth]
|
key = Lexeme.get_struct_attr(lex.c, self.vectors.attr)
|
||||||
|
if self.has_vector(key):
|
||||||
|
return self.vectors[key]
|
||||||
xp = get_array_module(self.vectors.data)
|
xp = get_array_module(self.vectors.data)
|
||||||
vectors = xp.zeros((self.vectors_length,), dtype="f")
|
vectors = xp.zeros((self.vectors_length,), dtype="f")
|
||||||
return vectors
|
return vectors
|
||||||
|
@ -416,15 +423,16 @@ cdef class Vocab:
|
||||||
"""
|
"""
|
||||||
if isinstance(orth, str):
|
if isinstance(orth, str):
|
||||||
orth = self.strings.add(orth)
|
orth = self.strings.add(orth)
|
||||||
if self.vectors.is_full and orth not in self.vectors:
|
cdef Lexeme lex = self[orth]
|
||||||
|
key = Lexeme.get_struct_attr(lex.c, self.vectors.attr)
|
||||||
|
if self.vectors.is_full and key not in self.vectors:
|
||||||
new_rows = max(100, int(self.vectors.shape[0]*1.3))
|
new_rows = max(100, int(self.vectors.shape[0]*1.3))
|
||||||
if self.vectors.shape[1] == 0:
|
if self.vectors.shape[1] == 0:
|
||||||
width = vector.size
|
width = vector.size
|
||||||
else:
|
else:
|
||||||
width = self.vectors.shape[1]
|
width = self.vectors.shape[1]
|
||||||
self.vectors.resize((new_rows, width))
|
self.vectors.resize((new_rows, width))
|
||||||
lex = self[orth] # Add word to vocab if necessary
|
row = self.vectors.add(key, vector=vector)
|
||||||
row = self.vectors.add(orth, vector=vector)
|
|
||||||
if row >= 0:
|
if row >= 0:
|
||||||
lex.rank = row
|
lex.rank = row
|
||||||
|
|
||||||
|
@ -439,7 +447,9 @@ cdef class Vocab:
|
||||||
"""
|
"""
|
||||||
if isinstance(orth, str):
|
if isinstance(orth, str):
|
||||||
orth = self.strings.add(orth)
|
orth = self.strings.add(orth)
|
||||||
return orth in self.vectors
|
cdef Lexeme lex = self[orth]
|
||||||
|
key = Lexeme.get_struct_attr(lex.c, self.vectors.attr)
|
||||||
|
return key in self.vectors
|
||||||
|
|
||||||
property lookups:
|
property lookups:
|
||||||
def __get__(self):
|
def __get__(self):
|
||||||
|
|
|
@ -303,7 +303,7 @@ mapped to a zero vector. See the documentation on
|
||||||
| `nM` | The width of the static vectors. ~~Optional[int]~~ |
|
| `nM` | The width of the static vectors. ~~Optional[int]~~ |
|
||||||
| `dropout` | Optional dropout rate. If set, it's applied per dimension over the whole batch. Defaults to `None`. ~~Optional[float]~~ |
|
| `dropout` | Optional dropout rate. If set, it's applied per dimension over the whole batch. Defaults to `None`. ~~Optional[float]~~ |
|
||||||
| `init_W` | The [initialization function](https://thinc.ai/docs/api-initializers). Defaults to [`glorot_uniform_init`](https://thinc.ai/docs/api-initializers#glorot_uniform_init). ~~Callable[[Ops, Tuple[int, ...]]], FloatsXd]~~ |
|
| `init_W` | The [initialization function](https://thinc.ai/docs/api-initializers). Defaults to [`glorot_uniform_init`](https://thinc.ai/docs/api-initializers#glorot_uniform_init). ~~Callable[[Ops, Tuple[int, ...]]], FloatsXd]~~ |
|
||||||
| `key_attr` | Defaults to `"ORTH"`. ~~str~~ |
|
| `key_attr` | This setting is ignored in spaCy v3.6+. To set a custom key attribute for vectors, configure it through [`Vectors`](/api/vectors) or [`spacy init vectors`](/api/cli#init-vectors). Defaults to `"ORTH"`. ~~str~~ |
|
||||||
| **CREATES** | The model using the architecture. ~~Model[List[Doc], Ragged]~~ |
|
| **CREATES** | The model using the architecture. ~~Model[List[Doc], Ragged]~~ |
|
||||||
|
|
||||||
### spacy.FeatureExtractor.v1 {id="FeatureExtractor"}
|
### spacy.FeatureExtractor.v1 {id="FeatureExtractor"}
|
||||||
|
|
|
@ -211,7 +211,8 @@ $ python -m spacy init vectors [lang] [vectors_loc] [output_dir] [--prune] [--tr
|
||||||
| `output_dir` | Pipeline output directory. Will be created if it doesn't exist. ~~Path (positional)~~ |
|
| `output_dir` | Pipeline output directory. Will be created if it doesn't exist. ~~Path (positional)~~ |
|
||||||
| `--truncate`, `-t` | Number of vectors to truncate to when reading in vectors file. Defaults to `0` for no truncation. ~~int (option)~~ |
|
| `--truncate`, `-t` | Number of vectors to truncate to when reading in vectors file. Defaults to `0` for no truncation. ~~int (option)~~ |
|
||||||
| `--prune`, `-p` | Number of vectors to prune the vocabulary to. Defaults to `-1` for no pruning. ~~int (option)~~ |
|
| `--prune`, `-p` | Number of vectors to prune the vocabulary to. Defaults to `-1` for no pruning. ~~int (option)~~ |
|
||||||
| `--mode`, `-m` | Vectors mode: `default` or [`floret`](https://github.com/explosion/floret). Defaults to `default`. ~~Optional[str] \(option)~~ |
|
| `--mode`, `-m` | Vectors mode: `default` or [`floret`](https://github.com/explosion/floret). Defaults to `default`. ~~str \(option)~~ |
|
||||||
|
| `--attr`, `-a` | Token attribute to use for vectors, e.g. `LOWER` or `NORM`. Defaults to `ORTH`. ~~str \(option)~~ |
|
||||||
| `--name`, `-n` | Name to assign to the word vectors in the `meta.json`, e.g. `en_core_web_md.vectors`. ~~Optional[str] \(option)~~ |
|
| `--name`, `-n` | Name to assign to the word vectors in the `meta.json`, e.g. `en_core_web_md.vectors`. ~~Optional[str] \(option)~~ |
|
||||||
| `--verbose`, `-V` | Print additional information and explanations. ~~bool (flag)~~ |
|
| `--verbose`, `-V` | Print additional information and explanations. ~~bool (flag)~~ |
|
||||||
| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ |
|
| `--help`, `-h` | Show help message and available arguments. ~~bool (flag)~~ |
|
||||||
|
|
|
@ -60,6 +60,7 @@ modified later.
|
||||||
| `hash_seed` <Tag variant="new">3.2</Tag> | The floret hash seed (default: `0`). ~~int~~ |
|
| `hash_seed` <Tag variant="new">3.2</Tag> | The floret hash seed (default: `0`). ~~int~~ |
|
||||||
| `bow` <Tag variant="new">3.2</Tag> | The floret BOW string (default: `"<"`). ~~str~~ |
|
| `bow` <Tag variant="new">3.2</Tag> | The floret BOW string (default: `"<"`). ~~str~~ |
|
||||||
| `eow` <Tag variant="new">3.2</Tag> | The floret EOW string (default: `">"`). ~~str~~ |
|
| `eow` <Tag variant="new">3.2</Tag> | The floret EOW string (default: `">"`). ~~str~~ |
|
||||||
|
| `attr` <Tag variant="new">3.6</Tag> | The token attribute for the vector keys (default: `"ORTH"`). ~~Union[int, str]~~ |
|
||||||
|
|
||||||
## Vectors.\_\_getitem\_\_ {id="getitem",tag="method"}
|
## Vectors.\_\_getitem\_\_ {id="getitem",tag="method"}
|
||||||
|
|
||||||
|
@ -454,7 +455,8 @@ Load state from a binary string.
|
||||||
## Attributes {id="attributes"}
|
## Attributes {id="attributes"}
|
||||||
|
|
||||||
| Name | Description |
|
| Name | Description |
|
||||||
| --------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
| ----------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
| `data` | Stored vectors data. `numpy` is used for CPU vectors, `cupy` for GPU vectors. ~~Union[numpy.ndarray[ndim=1, dtype=float32], cupy.ndarray[ndim=1, dtype=float32]]~~ |
|
| `data` | Stored vectors data. `numpy` is used for CPU vectors, `cupy` for GPU vectors. ~~Union[numpy.ndarray[ndim=1, dtype=float32], cupy.ndarray[ndim=1, dtype=float32]]~~ |
|
||||||
| `key2row` | Dictionary mapping word hashes to rows in the `Vectors.data` table. ~~Dict[int, int]~~ |
|
| `key2row` | Dictionary mapping word hashes to rows in the `Vectors.data` table. ~~Dict[int, int]~~ |
|
||||||
| `keys` | Array keeping the keys in order, such that `keys[vectors.key2row[key]] == key`. ~~Union[numpy.ndarray[ndim=1, dtype=float32], cupy.ndarray[ndim=1, dtype=float32]]~~ |
|
| `keys` | Array keeping the keys in order, such that `keys[vectors.key2row[key]] == key`. ~~Union[numpy.ndarray[ndim=1, dtype=float32], cupy.ndarray[ndim=1, dtype=float32]]~~ |
|
||||||
|
| `attr` <Tag variant="new">3.6</Tag> | The token attribute for the vector keys. ~~int~~ |
|
||||||
|
|
Loading…
Reference in New Issue
Block a user