Span richcmp fix (#9956)

* Corrected Span's __richcmp__ implementation to take end, label and kb_id in consideration

* Updated test

* Updated test

* Removed formatting from a test for readability sake

* Use same tuples for all comparisons

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
Natalia Rodnova 2022-01-17 03:17:49 -07:00 committed by GitHub
parent e8a047a8d4
commit 47ea6704f1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 57 additions and 20 deletions

View File

@ -573,6 +573,55 @@ def test_span_with_vectors(doc):
doc.vocab.vectors = prev_vectors
# fmt: off
def test_span_comparison(doc):
# Identical start, end, only differ in label and kb_id
assert Span(doc, 0, 3) == Span(doc, 0, 3)
assert Span(doc, 0, 3, "LABEL") == Span(doc, 0, 3, "LABEL")
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") == Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
assert Span(doc, 0, 3) != Span(doc, 0, 3, "LABEL")
assert Span(doc, 0, 3) != Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
assert Span(doc, 0, 3, "LABEL") != Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
assert Span(doc, 0, 3) <= Span(doc, 0, 3) and Span(doc, 0, 3) >= Span(doc, 0, 3)
assert Span(doc, 0, 3, "LABEL") <= Span(doc, 0, 3, "LABEL") and Span(doc, 0, 3, "LABEL") >= Span(doc, 0, 3, "LABEL")
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") <= Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") >= Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
assert (Span(doc, 0, 3) < Span(doc, 0, 3, "", kb_id="KB_ID") < Span(doc, 0, 3, "LABEL") < Span(doc, 0, 3, "LABEL", kb_id="KB_ID"))
assert (Span(doc, 0, 3) <= Span(doc, 0, 3, "", kb_id="KB_ID") <= Span(doc, 0, 3, "LABEL") <= Span(doc, 0, 3, "LABEL", kb_id="KB_ID"))
assert (Span(doc, 0, 3, "LABEL", kb_id="KB_ID") > Span(doc, 0, 3, "LABEL") > Span(doc, 0, 3, "", kb_id="KB_ID") > Span(doc, 0, 3))
assert (Span(doc, 0, 3, "LABEL", kb_id="KB_ID") >= Span(doc, 0, 3, "LABEL") >= Span(doc, 0, 3, "", kb_id="KB_ID") >= Span(doc, 0, 3))
# Different end
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") < Span(doc, 0, 4, "LABEL", kb_id="KB_ID")
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") < Span(doc, 0, 4)
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") <= Span(doc, 0, 4)
assert Span(doc, 0, 4) > Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
assert Span(doc, 0, 4) >= Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
# Different start
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") != Span(doc, 1, 3, "LABEL", kb_id="KB_ID")
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") < Span(doc, 1, 3)
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") <= Span(doc, 1, 3)
assert Span(doc, 1, 3) > Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
assert Span(doc, 1, 3) >= Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
# Different start & different end
assert Span(doc, 0, 4, "LABEL", kb_id="KB_ID") != Span(doc, 1, 3, "LABEL", kb_id="KB_ID")
assert Span(doc, 0, 4, "LABEL", kb_id="KB_ID") < Span(doc, 1, 3)
assert Span(doc, 0, 4, "LABEL", kb_id="KB_ID") <= Span(doc, 1, 3)
assert Span(doc, 1, 3) > Span(doc, 0, 4, "LABEL", kb_id="KB_ID")
assert Span(doc, 1, 3) >= Span(doc, 0, 4, "LABEL", kb_id="KB_ID")
# fmt: on
@pytest.mark.parametrize(
"start,end,expected_sentences,expected_sentences_with_hook",
[

View File

@ -126,38 +126,26 @@ cdef class Span:
return False
else:
return True
self_tuple = (self.c.start_char, self.c.end_char, self.c.label, self.c.kb_id, self.doc)
other_tuple = (other.c.start_char, other.c.end_char, other.c.label, other.c.kb_id, other.doc)
# <
if op == 0:
return self.c.start_char < other.c.start_char
return self_tuple < other_tuple
# <=
elif op == 1:
return self.c.start_char <= other.c.start_char
return self_tuple <= other_tuple
# ==
elif op == 2:
# Do the cheap comparisons first
return (
(self.c.start_char == other.c.start_char) and \
(self.c.end_char == other.c.end_char) and \
(self.c.label == other.c.label) and \
(self.c.kb_id == other.c.kb_id) and \
(self.doc == other.doc)
)
return self_tuple == other_tuple
# !=
elif op == 3:
# Do the cheap comparisons first
return not (
(self.c.start_char == other.c.start_char) and \
(self.c.end_char == other.c.end_char) and \
(self.c.label == other.c.label) and \
(self.c.kb_id == other.c.kb_id) and \
(self.doc == other.doc)
)
return self_tuple != other_tuple
# >
elif op == 4:
return self.c.start_char > other.c.start_char
return self_tuple > other_tuple
# >=
elif op == 5:
return self.c.start_char >= other.c.start_char
return self_tuple >= other_tuple
def __hash__(self):
return hash((self.doc, self.c.start_char, self.c.end_char, self.c.label, self.c.kb_id))