Sync Span __eq__ and __hash__ (#5005)

* Sync Span __eq__ and __hash__

Use the same tuple for `__eq__` and `__hash__`, including all attributes
except `vector` and `vector_norm`.

* Update entity comparison in tests

Update `assert_docs_equal()` test util to compare `Span` properties for
ents rather than `Span` objects.
This commit is contained in:
adrianeboyd 2020-02-16 17:20:36 +01:00 committed by GitHub
parent 0c47a53b5e
commit 3b22eb651b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 23 additions and 5 deletions

View File

@ -279,3 +279,12 @@ def test_filter_spans(doc):
assert len(filtered[1]) == 5 assert len(filtered[1]) == 5
assert filtered[0].start == 1 and filtered[0].end == 4 assert filtered[0].start == 1 and filtered[0].end == 4
assert filtered[1].start == 5 and filtered[1].end == 10 assert filtered[1].start == 5 and filtered[1].end == 10
def test_span_eq_hash(doc, doc_not_parsed):
assert doc[0:2] == doc[0:2]
assert doc[0:2] != doc[1:3]
assert doc[0:2] != doc_not_parsed[0:2]
assert hash(doc[0:2]) == hash(doc[0:2])
assert hash(doc[0:2]) != hash(doc[1:3])
assert hash(doc[0:2]) != hash(doc_not_parsed[0:2])

View File

@ -95,7 +95,11 @@ def assert_docs_equal(doc1, doc2):
assert [t.ent_type for t in doc1] == [t.ent_type for t in doc2] assert [t.ent_type for t in doc1] == [t.ent_type for t in doc2]
assert [t.ent_iob for t in doc1] == [t.ent_iob for t in doc2] assert [t.ent_iob for t in doc1] == [t.ent_iob for t in doc2]
assert [ent for ent in doc1.ents] == [ent for ent in doc2.ents] for ent1, ent2 in zip(doc1.ents, doc2.ents):
assert ent1.start == ent2.start
assert ent1.end == ent2.end
assert ent1.label == ent2.label
assert ent1.kb_id == ent2.kb_id
def assert_packed_msg_equal(b1, b2): def assert_packed_msg_equal(b1, b2):

View File

@ -127,22 +127,27 @@ cdef class Span:
return False return False
else: else:
return True return True
# Eq # <
if op == 0: if op == 0:
return self.start_char < other.start_char return self.start_char < other.start_char
# <=
elif op == 1: elif op == 1:
return self.start_char <= other.start_char return self.start_char <= other.start_char
# ==
elif op == 2: elif op == 2:
return self.start_char == other.start_char and self.end_char == other.end_char return (self.doc, self.start_char, self.end_char, self.label, self.kb_id) == (other.doc, other.start_char, other.end_char, other.label, other.kb_id)
# !=
elif op == 3: elif op == 3:
return self.start_char != other.start_char or self.end_char != other.end_char return (self.doc, self.start_char, self.end_char, self.label, self.kb_id) != (other.doc, other.start_char, other.end_char, other.label, other.kb_id)
# >
elif op == 4: elif op == 4:
return self.start_char > other.start_char return self.start_char > other.start_char
# >=
elif op == 5: elif op == 5:
return self.start_char >= other.start_char return self.start_char >= other.start_char
def __hash__(self): def __hash__(self):
return hash((self.doc, self.label, self.start_char, self.end_char)) return hash((self.doc, self.start_char, self.end_char, self.label, self.kb_id))
def __len__(self): def __len__(self):
"""Get the number of tokens in the span. """Get the number of tokens in the span.