mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-18 21:44:12 +03:00
c4112a1da3
* Require that all SpanGroup spans are from the current doc The restriction on only adding spans from the current doc was already implemented for all operations except for `SpanGroup.__init__`. Initialize copied spans for `SpanGroup.copy` with `Doc.char_span` in order to validate the character offsets and to make it possible to copy spans between documents with differing tokenization. Currently there is no validation that the document texts are identical, but the span char offsets must be valid spans in the target doc, which prevents you from ending up with completely invalid spans. * Undo change in test_beam_overfitting_IO
280 lines
8.5 KiB
Python
280 lines
8.5 KiB
Python
from typing import List
|
|
|
|
import pytest
|
|
from random import Random
|
|
from spacy.matcher import Matcher
|
|
from spacy.tokens import Span, SpanGroup, Doc
|
|
from spacy.util import filter_spans
|
|
|
|
|
|
@pytest.fixture
def doc(en_tokenizer):
    """Build the doc "0 1 2 3 4 5 6" carrying a shuffled "SPANS" span group.

    A validating Matcher receives three rules keyed "4", "2" and "1", made of
    four, two and one empty token patterns respectively. Every match is turned
    into a Span whose label is looked up from the match id in the vocab's
    string store. The spans are shuffled with a fixed seed and stored under
    ``doc.spans["SPANS"]`` with ``attrs={"key": "value"}``.

    RETURNS (Doc): The doc holding the span group.
    """
    vocab = en_tokenizer.vocab
    doc = en_tokenizer("0 1 2 3 4 5 6")
    matcher = Matcher(vocab, validate=True)

    # fmt: off
    rules = {
        "4": [[{}, {}, {}, {}]],
        "2": [[{}, {}, ]],
        "1": [[{}, ]],
    }
    # fmt: on
    for key, patterns in rules.items():
        matcher.add(key, patterns)

    spans = [
        Span(doc, start, end, vocab.strings[match_id])
        for match_id, start, end in matcher(doc)
    ]
    # Fixed seed: the shuffled order is the same on every run.
    Random(42).shuffle(spans)

    group = SpanGroup(doc, name="SPANS", attrs={"key": "value"}, spans=spans)
    doc.spans["SPANS"] = group
    return doc
|
|
|
|
|
|
@pytest.fixture
def other_doc(en_tokenizer):
    """Build a second, separate doc "0 1 2 3 4 5 6" with its own "SPANS" group.

    The construction mirrors the ``doc`` fixture: a validating Matcher with
    rules "4", "2" and "1" (four, two and one empty token patterns), one Span
    per match labelled via the vocab's string store, a fixed-seed shuffle, and
    a SpanGroup named "SPANS" with ``attrs={"key": "value"}``.

    RETURNS (Doc): The doc holding the span group.
    """
    vocab = en_tokenizer.vocab
    doc = en_tokenizer("0 1 2 3 4 5 6")
    matcher = Matcher(vocab, validate=True)

    # fmt: off
    rules = {
        "4": [[{}, {}, {}, {}]],
        "2": [[{}, {}, ]],
        "1": [[{}, ]],
    }
    # fmt: on
    for key, patterns in rules.items():
        matcher.add(key, patterns)

    spans = [
        Span(doc, start, end, vocab.strings[match_id])
        for match_id, start, end in matcher(doc)
    ]
    # Fixed seed: the shuffled order is the same on every run.
    Random(42).shuffle(spans)

    group = SpanGroup(doc, name="SPANS", attrs={"key": "value"}, spans=spans)
    doc.spans["SPANS"] = group
    return doc
|
|
|
|
|
|
@pytest.fixture
def span_group(en_tokenizer):
    """Build a doc with a shuffled "SPANS" span group, without returning it.

    The doc "0 1 2 3 4 5 6" is matched against rules "4", "2" and "1" (four,
    two and one empty token patterns); each match becomes a Span labelled via
    the vocab's string store, the spans are shuffled with a fixed seed and
    stored under ``doc.spans["SPANS"]`` with ``attrs={"key": "value"}``.

    NOTE(review): there is no return statement, so tests requesting this
    fixture receive None and the doc is only referenced locally. Confirm
    whether returning the span group was intended.
    """
    vocab = en_tokenizer.vocab
    doc = en_tokenizer("0 1 2 3 4 5 6")
    matcher = Matcher(vocab, validate=True)

    # fmt: off
    rules = {
        "4": [[{}, {}, {}, {}]],
        "2": [[{}, {}, ]],
        "1": [[{}, ]],
    }
    # fmt: on
    for key, patterns in rules.items():
        matcher.add(key, patterns)

    spans = [
        Span(doc, start, end, vocab.strings[match_id])
        for match_id, start, end in matcher(doc)
    ]
    # Fixed seed: the shuffled order is the same on every run.
    Random(42).shuffle(spans)

    group = SpanGroup(doc, name="SPANS", attrs={"key": "value"}, spans=spans)
    doc.spans["SPANS"] = group
|
|
|
|
|
|
def test_span_group_copy(doc):
    """Check SpanGroup.copy, both within a doc and onto another doc."""
    original = doc.spans["SPANS"]
    duplicate = original.copy()

    # Right after copying: the groups compare unequal but hold matching data.
    assert duplicate != original
    assert duplicate.name == original.name
    assert duplicate.attrs == original.attrs
    assert len(duplicate) == len(original)
    assert list(original) == list(duplicate)

    # Changes made to the copy must not show up in the source group.
    duplicate.name = "new_name"
    duplicate.attrs["key"] = "new_value"
    duplicate.append(Span(doc, 0, 6, "LABEL"))
    assert duplicate.name != original.name
    assert duplicate.attrs != original.attrs
    assert original.attrs["key"] == "value"
    assert list(original) != list(duplicate)

    # can't copy if the character offsets don't align to tokens
    suffixed_words = [token.text + "x" for token in doc]
    misaligned_doc = Doc(doc.vocab, words=suffixed_words)
    with pytest.raises(ValueError):
        original.copy(doc=misaligned_doc)

    # can copy with valid character offsets despite different tokenization
    merged_doc = doc.copy()
    with merged_doc.retokenize() as retokenizer:
        retokenizer.merge(merged_doc[0:2])
        retokenizer.merge(merged_doc[3:6])
    source_group = SpanGroup(doc, spans=[doc[0:6], doc[3:6]])
    copied_group = source_group.copy(doc=merged_doc)
    for source_span, copied_span in zip(source_group, copied_group):
        assert source_span.start_char == copied_span.start_char
        assert source_span.end_char == copied_span.end_char
|
|
|
|
|
|
def test_span_group_set_item(doc, other_doc):
    """Check item assignment on a SpanGroup, including its error cases."""
    group = doc.spans["SPANS"]
    position = 5

    # Modify a span fetched from the group; until it is assigned back, the
    # entry read from the group is expected to keep its old label and kb_id.
    edited = group[position]
    edited.label_ = "NEW LABEL"
    edited.kb_id = doc.vocab.strings["KB_ID"]
    assert group[position].label != edited.label
    assert group[position].kb_id != edited.kb_id

    # After assignment the stored entry matches the edited span.
    group[position] = edited
    stored = group[position]
    assert stored.start == edited.start
    assert stored.end == edited.end
    assert stored.label == edited.label
    assert stored.kb_id == edited.kb_id
    assert group[position] == edited

    # Indices outside the group are rejected in both directions.
    for bad_index in (-100, 100):
        with pytest.raises(IndexError):
            group[bad_index] = edited

    # A span built on a different doc cannot be assigned.
    foreign_span = Span(other_doc, 0, 2)
    with pytest.raises(ValueError):
        group[position] = foreign_span
|
|
|
|
|
|
def test_span_group_has_overlap(doc):
    """Check that the fixture's "SPANS" group reports overlapping spans."""
    assert doc.spans["SPANS"].has_overlap
|
|
|
|
|
|
def test_span_group_concat(doc, other_doc):
    """Check SpanGroup._concat in its default and in-place modes."""
    first = doc.spans["SPANS"]
    second = SpanGroup(
        doc,
        name="MORE_SPANS",
        attrs={"key": "new_value", "new_key": "new_value"},
        spans=[doc[0:5], doc[0:6]],
    )
    # "key" keeps the first group's value; "new_key" comes from the second.
    merged_attrs = {"key": "value", "new_key": "new_value"}

    combined = first._concat(second)
    assert combined.name == first.name
    assert combined.attrs == merged_attrs
    expected_spans = list(first) + list(second)
    assert list(combined) == list(expected_spans)

    # Inplace
    # The expectation is captured before the call because `first` is the
    # group being concatenated into.
    expected_spans = list(first) + list(second)
    combined = first._concat(second, inplace=True)
    assert combined == first
    assert combined.name == first.name
    assert combined.attrs == merged_attrs
    assert list(combined) == list(expected_spans)

    # Concatenating with a group taken from another doc must fail.
    foreign_group = other_doc.spans["SPANS"]
    with pytest.raises(ValueError):
        first._concat(foreign_group)
|
|
|
|
|
|
def test_span_doc_delitem(doc):
    """Check deleting a span from a SpanGroup by index."""
    group = doc.spans["SPANS"]
    size_before = len(group)
    position = 5
    removed, successor = group[position], group[position + 1]

    del group[position]

    # The group shrinks by one and the following span moves into the slot.
    assert len(group) == size_before - 1
    assert group[position] != removed
    assert group[position] == successor

    # Indices outside the group are rejected in both directions.
    for bad_index in (-100, 100):
        with pytest.raises(IndexError):
            del group[bad_index]
|
|
|
|
|
|
def test_span_group_add(doc):
    """Check that ``group + group`` agrees with SpanGroup._concat."""
    left = doc.spans["SPANS"]
    right = SpanGroup(
        doc,
        name="MORE_SPANS",
        attrs={"key": "new_value", "new_key": "new_value"},
        spans=[doc[0:5], doc[0:6]],
    )

    expected = left._concat(right)
    total = left + right

    assert len(total) == len(expected)
    assert total.attrs == {"key": "value", "new_key": "new_value"}
    assert list(total) == list(expected)
|
|
|
|
|
|
def test_span_group_iadd(doc):
    """Check ``+=`` on a SpanGroup with a SpanGroup and with a list of spans."""
    target = doc.spans["SPANS"].copy()
    extra_spans = [doc[0:5], doc[0:6]]
    extra_group = SpanGroup(
        doc,
        name="MORE_SPANS",
        attrs={"key": "new_value", "new_key": "new_value"},
        spans=extra_spans,
    )
    expected = target._concat(extra_group)

    # Adding a SpanGroup: spans are appended and attrs are merged.
    target += extra_group
    assert len(target) == len(expected)
    assert target.attrs == {"key": "value", "new_key": "new_value"}
    assert list(target) == list(expected)

    # Adding a plain list of spans to a fresh copy: attrs stay as they were.
    target = doc.spans["SPANS"].copy()
    target += extra_spans
    assert len(target) == len(expected)
    assert target.attrs == {"key": "value"}
    assert list(target) == list(expected)
|
|
|
|
|
|
def test_span_group_extend(doc):
    """Check SpanGroup.extend with a SpanGroup and with a list of spans."""
    target = doc.spans["SPANS"].copy()
    extra_spans = [doc[0:5], doc[0:6]]
    extra_group = SpanGroup(
        doc,
        name="MORE_SPANS",
        attrs={"key": "new_value", "new_key": "new_value"},
        spans=extra_spans,
    )
    expected = target._concat(extra_group)

    # Extending with a SpanGroup: spans are appended and attrs are merged.
    target.extend(extra_group)
    assert len(target) == len(expected)
    assert target.attrs == {"key": "value", "new_key": "new_value"}
    assert list(target) == list(expected)

    # Extending with a plain list of spans; this time the doc's own group is
    # used directly rather than a copy.
    target = doc.spans["SPANS"]
    target.extend(extra_spans)
    assert len(target) == len(expected)
    assert target.attrs == {"key": "value"}
    assert list(target) == list(expected)
|
|
|
|
|
|
def test_span_group_dealloc(span_group):
    """Check that reading ``.doc`` on the fixture value raises AttributeError.

    NOTE(review): the ``span_group`` fixture in this file has no return
    statement, so the argument appears to be None and the AttributeError
    would come from ``None.doc`` rather than from a SpanGroup. Confirm that
    this is the intended check.
    """
    expect_attribute_error = pytest.raises(AttributeError)
    with expect_attribute_error:
        print(span_group.doc)
|
|
|
|
|
|
@pytest.mark.issue(11975)
def test_span_group_typing(doc: Doc):
    """Use a `SpanGroup` as an `Iterable[Span]`-like object so that mypy can
    confirm this typing is accepted."""
    group: SpanGroup = doc.spans["SPANS"]
    as_list: List[Span] = list(group)
    # Iteration, indexing and the materialised list must agree item by item.
    for position, member in enumerate(group):
        assert member == group[position] == as_list[position]
    # The group is passed directly where spans are expected.
    filter_spans(group)
|
|
|
|
|
|
def test_span_group_init_doc(en_tokenizer):
    """Test that all spans must come from the specified doc."""
    text = "a b c"
    doc1 = en_tokenizer(text)
    doc2 = en_tokenizer(text)

    # Spans taken only from doc1 are accepted.
    same_doc_spans = [doc1[0:1], doc1[1:2]]
    span_group = SpanGroup(doc1, spans=same_doc_spans)

    # Including a span from doc2 must be rejected, even though both docs were
    # produced from the same text.
    mixed_doc_spans = [doc1[0:1], doc2[1:2]]
    with pytest.raises(ValueError):
        span_group = SpanGroup(doc1, spans=mixed_doc_spans)
|