mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-26 09:14:32 +03:00
Sync Span __eq__ and __hash__ (#5005)
* Sync Span __eq__ and __hash__ Use the same tuple for `__eq__` and `__hash__`, including all attributes except `vector` and `vector_norm`. * Update entity comparison in tests Update `assert_docs_equal()` test util to compare `Span` properties for ents rather than `Span` objects.
This commit is contained in:
parent
0c47a53b5e
commit
3b22eb651b
|
@ -279,3 +279,12 @@ def test_filter_spans(doc):
|
|||
assert len(filtered[1]) == 5
|
||||
assert filtered[0].start == 1 and filtered[0].end == 4
|
||||
assert filtered[1].start == 5 and filtered[1].end == 10
|
||||
|
||||
|
||||
def test_span_eq_hash(doc, doc_not_parsed):
|
||||
assert doc[0:2] == doc[0:2]
|
||||
assert doc[0:2] != doc[1:3]
|
||||
assert doc[0:2] != doc_not_parsed[0:2]
|
||||
assert hash(doc[0:2]) == hash(doc[0:2])
|
||||
assert hash(doc[0:2]) != hash(doc[1:3])
|
||||
assert hash(doc[0:2]) != hash(doc_not_parsed[0:2])
|
||||
|
|
|
@ -95,7 +95,11 @@ def assert_docs_equal(doc1, doc2):
|
|||
|
||||
assert [t.ent_type for t in doc1] == [t.ent_type for t in doc2]
|
||||
assert [t.ent_iob for t in doc1] == [t.ent_iob for t in doc2]
|
||||
assert [ent for ent in doc1.ents] == [ent for ent in doc2.ents]
|
||||
for ent1, ent2 in zip(doc1.ents, doc2.ents):
|
||||
assert ent1.start == ent2.start
|
||||
assert ent1.end == ent2.end
|
||||
assert ent1.label == ent2.label
|
||||
assert ent1.kb_id == ent2.kb_id
|
||||
|
||||
|
||||
def assert_packed_msg_equal(b1, b2):
|
||||
|
|
|
@ -127,22 +127,27 @@ cdef class Span:
|
|||
return False
|
||||
else:
|
||||
return True
|
||||
# Eq
|
||||
# <
|
||||
if op == 0:
|
||||
return self.start_char < other.start_char
|
||||
# <=
|
||||
elif op == 1:
|
||||
return self.start_char <= other.start_char
|
||||
# ==
|
||||
elif op == 2:
|
||||
return self.start_char == other.start_char and self.end_char == other.end_char
|
||||
return (self.doc, self.start_char, self.end_char, self.label, self.kb_id) == (other.doc, other.start_char, other.end_char, other.label, other.kb_id)
|
||||
# !=
|
||||
elif op == 3:
|
||||
return self.start_char != other.start_char or self.end_char != other.end_char
|
||||
return (self.doc, self.start_char, self.end_char, self.label, self.kb_id) != (other.doc, other.start_char, other.end_char, other.label, other.kb_id)
|
||||
# >
|
||||
elif op == 4:
|
||||
return self.start_char > other.start_char
|
||||
# >=
|
||||
elif op == 5:
|
||||
return self.start_char >= other.start_char
|
||||
|
||||
def __hash__(self):
|
||||
return hash((self.doc, self.label, self.start_char, self.end_char))
|
||||
return hash((self.doc, self.start_char, self.end_char, self.label, self.kb_id))
|
||||
|
||||
def __len__(self):
|
||||
"""Get the number of tokens in the span.
|
||||
|
|
Loading…
Reference in New Issue
Block a user