mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-27 02:16:32 +03:00
Span richcmp fix (#9956)
* Corrected Span's __richcmp__ implementation to take end, label and kb_id into consideration * Updated test * Updated test * Removed formatting from a test for readability's sake * Use same tuples for all comparisons Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
parent
e8a047a8d4
commit
47ea6704f1
|
@ -573,6 +573,55 @@ def test_span_with_vectors(doc):
|
||||||
doc.vocab.vectors = prev_vectors
|
doc.vocab.vectors = prev_vectors
|
||||||
|
|
||||||
|
|
||||||
|
# fmt: off
|
||||||
|
def test_span_comparison(doc):
|
||||||
|
|
||||||
|
# Identical start and end; spans differ at most in label and kb_id
|
||||||
|
assert Span(doc, 0, 3) == Span(doc, 0, 3)
|
||||||
|
assert Span(doc, 0, 3, "LABEL") == Span(doc, 0, 3, "LABEL")
|
||||||
|
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") == Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
|
||||||
|
|
||||||
|
assert Span(doc, 0, 3) != Span(doc, 0, 3, "LABEL")
|
||||||
|
assert Span(doc, 0, 3) != Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
|
||||||
|
assert Span(doc, 0, 3, "LABEL") != Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
|
||||||
|
|
||||||
|
assert Span(doc, 0, 3) <= Span(doc, 0, 3) and Span(doc, 0, 3) >= Span(doc, 0, 3)
|
||||||
|
assert Span(doc, 0, 3, "LABEL") <= Span(doc, 0, 3, "LABEL") and Span(doc, 0, 3, "LABEL") >= Span(doc, 0, 3, "LABEL")
|
||||||
|
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") <= Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
|
||||||
|
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") >= Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
|
||||||
|
|
||||||
|
assert (Span(doc, 0, 3) < Span(doc, 0, 3, "", kb_id="KB_ID") < Span(doc, 0, 3, "LABEL") < Span(doc, 0, 3, "LABEL", kb_id="KB_ID"))
|
||||||
|
assert (Span(doc, 0, 3) <= Span(doc, 0, 3, "", kb_id="KB_ID") <= Span(doc, 0, 3, "LABEL") <= Span(doc, 0, 3, "LABEL", kb_id="KB_ID"))
|
||||||
|
|
||||||
|
assert (Span(doc, 0, 3, "LABEL", kb_id="KB_ID") > Span(doc, 0, 3, "LABEL") > Span(doc, 0, 3, "", kb_id="KB_ID") > Span(doc, 0, 3))
|
||||||
|
assert (Span(doc, 0, 3, "LABEL", kb_id="KB_ID") >= Span(doc, 0, 3, "LABEL") >= Span(doc, 0, 3, "", kb_id="KB_ID") >= Span(doc, 0, 3))
|
||||||
|
|
||||||
|
# Different end
|
||||||
|
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") < Span(doc, 0, 4, "LABEL", kb_id="KB_ID")
|
||||||
|
|
||||||
|
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") < Span(doc, 0, 4)
|
||||||
|
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") <= Span(doc, 0, 4)
|
||||||
|
assert Span(doc, 0, 4) > Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
|
||||||
|
assert Span(doc, 0, 4) >= Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
|
||||||
|
|
||||||
|
# Different start
|
||||||
|
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") != Span(doc, 1, 3, "LABEL", kb_id="KB_ID")
|
||||||
|
|
||||||
|
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") < Span(doc, 1, 3)
|
||||||
|
assert Span(doc, 0, 3, "LABEL", kb_id="KB_ID") <= Span(doc, 1, 3)
|
||||||
|
assert Span(doc, 1, 3) > Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
|
||||||
|
assert Span(doc, 1, 3) >= Span(doc, 0, 3, "LABEL", kb_id="KB_ID")
|
||||||
|
|
||||||
|
# Different start & different end
|
||||||
|
assert Span(doc, 0, 4, "LABEL", kb_id="KB_ID") != Span(doc, 1, 3, "LABEL", kb_id="KB_ID")
|
||||||
|
|
||||||
|
assert Span(doc, 0, 4, "LABEL", kb_id="KB_ID") < Span(doc, 1, 3)
|
||||||
|
assert Span(doc, 0, 4, "LABEL", kb_id="KB_ID") <= Span(doc, 1, 3)
|
||||||
|
assert Span(doc, 1, 3) > Span(doc, 0, 4, "LABEL", kb_id="KB_ID")
|
||||||
|
assert Span(doc, 1, 3) >= Span(doc, 0, 4, "LABEL", kb_id="KB_ID")
|
||||||
|
# fmt: on
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"start,end,expected_sentences,expected_sentences_with_hook",
|
"start,end,expected_sentences,expected_sentences_with_hook",
|
||||||
[
|
[
|
||||||
|
|
|
@ -126,38 +126,26 @@ cdef class Span:
|
||||||
return False
|
return False
|
||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
self_tuple = (self.c.start_char, self.c.end_char, self.c.label, self.c.kb_id, self.doc)
|
||||||
|
other_tuple = (other.c.start_char, other.c.end_char, other.c.label, other.c.kb_id, other.doc)
|
||||||
# <
|
# <
|
||||||
if op == 0:
|
if op == 0:
|
||||||
return self.c.start_char < other.c.start_char
|
return self_tuple < other_tuple
|
||||||
# <=
|
# <=
|
||||||
elif op == 1:
|
elif op == 1:
|
||||||
return self.c.start_char <= other.c.start_char
|
return self_tuple <= other_tuple
|
||||||
# ==
|
# ==
|
||||||
elif op == 2:
|
elif op == 2:
|
||||||
# Do the cheap comparisons first
|
return self_tuple == other_tuple
|
||||||
return (
|
|
||||||
(self.c.start_char == other.c.start_char) and \
|
|
||||||
(self.c.end_char == other.c.end_char) and \
|
|
||||||
(self.c.label == other.c.label) and \
|
|
||||||
(self.c.kb_id == other.c.kb_id) and \
|
|
||||||
(self.doc == other.doc)
|
|
||||||
)
|
|
||||||
# !=
|
# !=
|
||||||
elif op == 3:
|
elif op == 3:
|
||||||
# Do the cheap comparisons first
|
return self_tuple != other_tuple
|
||||||
return not (
|
|
||||||
(self.c.start_char == other.c.start_char) and \
|
|
||||||
(self.c.end_char == other.c.end_char) and \
|
|
||||||
(self.c.label == other.c.label) and \
|
|
||||||
(self.c.kb_id == other.c.kb_id) and \
|
|
||||||
(self.doc == other.doc)
|
|
||||||
)
|
|
||||||
# >
|
# >
|
||||||
elif op == 4:
|
elif op == 4:
|
||||||
return self.c.start_char > other.c.start_char
|
return self_tuple > other_tuple
|
||||||
# >=
|
# >=
|
||||||
elif op == 5:
|
elif op == 5:
|
||||||
return self.c.start_char >= other.c.start_char
|
return self_tuple >= other_tuple
|
||||||
|
|
||||||
def __hash__(self):
|
def __hash__(self):
|
||||||
return hash((self.doc, self.c.start_char, self.c.end_char, self.c.label, self.c.kb_id))
|
return hash((self.doc, self.c.start_char, self.c.end_char, self.c.label, self.c.kb_id))
|
||||||
|
|
Loading…
Reference in New Issue
Block a user