mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-13 05:07:03 +03:00
Add equality definition for vectors (#11806)
* Add equality definition for vectors This re-uses the check from sourcing components. * Use the equality check * Format Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
parent
caa9efad59
commit
c0c54e44bc
|
@ -706,13 +706,7 @@ class Language:
|
||||||
# Check source type
|
# Check source type
|
||||||
if not isinstance(source, Language):
|
if not isinstance(source, Language):
|
||||||
raise ValueError(Errors.E945.format(name=source_name, source=type(source)))
|
raise ValueError(Errors.E945.format(name=source_name, source=type(source)))
|
||||||
# Check vectors, with faster checks first
|
if self.vocab.vectors != source.vocab.vectors:
|
||||||
if (
|
|
||||||
self.vocab.vectors.shape != source.vocab.vectors.shape
|
|
||||||
or self.vocab.vectors.key2row != source.vocab.vectors.key2row
|
|
||||||
or self.vocab.vectors.to_bytes(exclude=["strings"])
|
|
||||||
!= source.vocab.vectors.to_bytes(exclude=["strings"])
|
|
||||||
):
|
|
||||||
warnings.warn(Warnings.W113.format(name=source_name))
|
warnings.warn(Warnings.W113.format(name=source_name))
|
||||||
if source_name not in source.component_names:
|
if source_name not in source.component_names:
|
||||||
raise KeyError(
|
raise KeyError(
|
||||||
|
|
|
@ -626,3 +626,23 @@ def test_floret_vectors(floret_vectors_vec_str, floret_vectors_hashvec_str):
|
||||||
OPS.to_numpy(vocab_r[word].vector),
|
OPS.to_numpy(vocab_r[word].vector),
|
||||||
decimal=6,
|
decimal=6,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_equality():
|
||||||
|
vectors1 = Vectors(shape=(10, 10))
|
||||||
|
vectors2 = Vectors(shape=(10, 8))
|
||||||
|
|
||||||
|
assert vectors1 != vectors2
|
||||||
|
|
||||||
|
vectors2 = Vectors(shape=(10, 10))
|
||||||
|
assert vectors1 == vectors2
|
||||||
|
|
||||||
|
vectors1.add("hello", row=2)
|
||||||
|
assert vectors1 != vectors2
|
||||||
|
|
||||||
|
vectors2.add("hello", row=2)
|
||||||
|
assert vectors1 == vectors2
|
||||||
|
|
||||||
|
vectors1.resize((5, 9))
|
||||||
|
vectors2.resize((5, 9))
|
||||||
|
assert vectors1 == vectors2
|
||||||
|
|
|
@ -243,6 +243,15 @@ cdef class Vectors:
|
||||||
else:
|
else:
|
||||||
return key in self.key2row
|
return key in self.key2row
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
# Check for equality, with faster checks first
|
||||||
|
return (
|
||||||
|
self.shape == other.shape
|
||||||
|
and self.key2row == other.key2row
|
||||||
|
and self.to_bytes(exclude=["strings"])
|
||||||
|
== other.to_bytes(exclude=["strings"])
|
||||||
|
)
|
||||||
|
|
||||||
def resize(self, shape, inplace=False):
|
def resize(self, shape, inplace=False):
|
||||||
"""Resize the underlying vectors array. If inplace=True, the memory
|
"""Resize the underlying vectors array. If inplace=True, the memory
|
||||||
is reallocated. This may cause other references to the data to become
|
is reallocated. This may cause other references to the data to become
|
||||||
|
|
Loading…
Reference in New Issue
Block a user