mirror of
				https://github.com/explosion/spaCy.git
				synced 2025-11-04 09:57:26 +03:00 
			
		
		
		
	* Correct Span.label, Span.kb_id types. Fix SpanGroup.__iter__(). * Extend test. * Rename test. Fix typo. * Add comment. * Fix types for Span.label, Span.kb_id, Span.char_span(). * Update spacy/tests/doc/test_span_group.py Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> * Update docs. * Fix typo. * Update spacy/tokens/span_group.pyx Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com> Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
		
			
				
	
	
		
			256 lines
		
	
	
		
			7.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			256 lines
		
	
	
		
			7.5 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
from typing import List
 | 
						|
 | 
						|
import pytest
 | 
						|
from random import Random
 | 
						|
from spacy.matcher import Matcher
 | 
						|
from spacy.tokens import Span, SpanGroup, Doc
 | 
						|
from spacy.util import filter_spans
 | 
						|
 | 
						|
 | 
						|
@pytest.fixture
 | 
						|
def doc(en_tokenizer):
 | 
						|
    doc = en_tokenizer("0 1 2 3 4 5 6")
 | 
						|
    matcher = Matcher(en_tokenizer.vocab, validate=True)
 | 
						|
 | 
						|
    # fmt: off
 | 
						|
    matcher.add("4", [[{}, {}, {}, {}]])
 | 
						|
    matcher.add("2", [[{}, {}, ]])
 | 
						|
    matcher.add("1", [[{}, ]])
 | 
						|
    # fmt: on
 | 
						|
    matches = matcher(doc)
 | 
						|
    spans = []
 | 
						|
    for match in matches:
 | 
						|
        spans.append(
 | 
						|
            Span(doc, match[1], match[2], en_tokenizer.vocab.strings[match[0]])
 | 
						|
        )
 | 
						|
    Random(42).shuffle(spans)
 | 
						|
    doc.spans["SPANS"] = SpanGroup(
 | 
						|
        doc, name="SPANS", attrs={"key": "value"}, spans=spans
 | 
						|
    )
 | 
						|
    return doc
 | 
						|
 | 
						|
 | 
						|
@pytest.fixture
 | 
						|
def other_doc(en_tokenizer):
 | 
						|
    doc = en_tokenizer("0 1 2 3 4 5 6")
 | 
						|
    matcher = Matcher(en_tokenizer.vocab, validate=True)
 | 
						|
 | 
						|
    # fmt: off
 | 
						|
    matcher.add("4", [[{}, {}, {}, {}]])
 | 
						|
    matcher.add("2", [[{}, {}, ]])
 | 
						|
    matcher.add("1", [[{}, ]])
 | 
						|
    # fmt: on
 | 
						|
 | 
						|
    matches = matcher(doc)
 | 
						|
    spans = []
 | 
						|
    for match in matches:
 | 
						|
        spans.append(
 | 
						|
            Span(doc, match[1], match[2], en_tokenizer.vocab.strings[match[0]])
 | 
						|
        )
 | 
						|
    Random(42).shuffle(spans)
 | 
						|
    doc.spans["SPANS"] = SpanGroup(
 | 
						|
        doc, name="SPANS", attrs={"key": "value"}, spans=spans
 | 
						|
    )
 | 
						|
    return doc
 | 
						|
 | 
						|
 | 
						|
@pytest.fixture
 | 
						|
def span_group(en_tokenizer):
 | 
						|
    doc = en_tokenizer("0 1 2 3 4 5 6")
 | 
						|
    matcher = Matcher(en_tokenizer.vocab, validate=True)
 | 
						|
 | 
						|
    # fmt: off
 | 
						|
    matcher.add("4", [[{}, {}, {}, {}]])
 | 
						|
    matcher.add("2", [[{}, {}, ]])
 | 
						|
    matcher.add("1", [[{}, ]])
 | 
						|
    # fmt: on
 | 
						|
 | 
						|
    matches = matcher(doc)
 | 
						|
    spans = []
 | 
						|
    for match in matches:
 | 
						|
        spans.append(
 | 
						|
            Span(doc, match[1], match[2], en_tokenizer.vocab.strings[match[0]])
 | 
						|
        )
 | 
						|
    Random(42).shuffle(spans)
 | 
						|
    doc.spans["SPANS"] = SpanGroup(
 | 
						|
        doc, name="SPANS", attrs={"key": "value"}, spans=spans
 | 
						|
    )
 | 
						|
 | 
						|
 | 
						|
def test_span_group_copy(doc):
 | 
						|
    span_group = doc.spans["SPANS"]
 | 
						|
    clone = span_group.copy()
 | 
						|
    assert clone != span_group
 | 
						|
    assert clone.name == span_group.name
 | 
						|
    assert clone.attrs == span_group.attrs
 | 
						|
    assert len(clone) == len(span_group)
 | 
						|
    assert list(span_group) == list(clone)
 | 
						|
    clone.name = "new_name"
 | 
						|
    clone.attrs["key"] = "new_value"
 | 
						|
    clone.append(Span(doc, 0, 6, "LABEL"))
 | 
						|
    assert clone.name != span_group.name
 | 
						|
    assert clone.attrs != span_group.attrs
 | 
						|
    assert span_group.attrs["key"] == "value"
 | 
						|
    assert list(span_group) != list(clone)
 | 
						|
 | 
						|
 | 
						|
def test_span_group_set_item(doc, other_doc):
 | 
						|
    span_group = doc.spans["SPANS"]
 | 
						|
 | 
						|
    index = 5
 | 
						|
    span = span_group[index]
 | 
						|
    span.label_ = "NEW LABEL"
 | 
						|
    span.kb_id = doc.vocab.strings["KB_ID"]
 | 
						|
 | 
						|
    assert span_group[index].label != span.label
 | 
						|
    assert span_group[index].kb_id != span.kb_id
 | 
						|
 | 
						|
    span_group[index] = span
 | 
						|
    assert span_group[index].start == span.start
 | 
						|
    assert span_group[index].end == span.end
 | 
						|
    assert span_group[index].label == span.label
 | 
						|
    assert span_group[index].kb_id == span.kb_id
 | 
						|
    assert span_group[index] == span
 | 
						|
 | 
						|
    with pytest.raises(IndexError):
 | 
						|
        span_group[-100] = span
 | 
						|
    with pytest.raises(IndexError):
 | 
						|
        span_group[100] = span
 | 
						|
 | 
						|
    span = Span(other_doc, 0, 2)
 | 
						|
    with pytest.raises(ValueError):
 | 
						|
        span_group[index] = span
 | 
						|
 | 
						|
 | 
						|
def test_span_group_has_overlap(doc):
 | 
						|
    span_group = doc.spans["SPANS"]
 | 
						|
    assert span_group.has_overlap
 | 
						|
 | 
						|
 | 
						|
def test_span_group_concat(doc, other_doc):
 | 
						|
    span_group_1 = doc.spans["SPANS"]
 | 
						|
    spans = [doc[0:5], doc[0:6]]
 | 
						|
    span_group_2 = SpanGroup(
 | 
						|
        doc,
 | 
						|
        name="MORE_SPANS",
 | 
						|
        attrs={"key": "new_value", "new_key": "new_value"},
 | 
						|
        spans=spans,
 | 
						|
    )
 | 
						|
    span_group_3 = span_group_1._concat(span_group_2)
 | 
						|
    assert span_group_3.name == span_group_1.name
 | 
						|
    assert span_group_3.attrs == {"key": "value", "new_key": "new_value"}
 | 
						|
    span_list_expected = list(span_group_1) + list(span_group_2)
 | 
						|
    assert list(span_group_3) == list(span_list_expected)
 | 
						|
 | 
						|
    # Inplace
 | 
						|
    span_list_expected = list(span_group_1) + list(span_group_2)
 | 
						|
    span_group_3 = span_group_1._concat(span_group_2, inplace=True)
 | 
						|
    assert span_group_3 == span_group_1
 | 
						|
    assert span_group_3.name == span_group_1.name
 | 
						|
    assert span_group_3.attrs == {"key": "value", "new_key": "new_value"}
 | 
						|
    assert list(span_group_3) == list(span_list_expected)
 | 
						|
 | 
						|
    span_group_2 = other_doc.spans["SPANS"]
 | 
						|
    with pytest.raises(ValueError):
 | 
						|
        span_group_1._concat(span_group_2)
 | 
						|
 | 
						|
 | 
						|
def test_span_doc_delitem(doc):
 | 
						|
    span_group = doc.spans["SPANS"]
 | 
						|
    length = len(span_group)
 | 
						|
    index = 5
 | 
						|
    span = span_group[index]
 | 
						|
    next_span = span_group[index + 1]
 | 
						|
    del span_group[index]
 | 
						|
    assert len(span_group) == length - 1
 | 
						|
    assert span_group[index] != span
 | 
						|
    assert span_group[index] == next_span
 | 
						|
 | 
						|
    with pytest.raises(IndexError):
 | 
						|
        del span_group[-100]
 | 
						|
    with pytest.raises(IndexError):
 | 
						|
        del span_group[100]
 | 
						|
 | 
						|
 | 
						|
def test_span_group_add(doc):
 | 
						|
    span_group_1 = doc.spans["SPANS"]
 | 
						|
    spans = [doc[0:5], doc[0:6]]
 | 
						|
    span_group_2 = SpanGroup(
 | 
						|
        doc,
 | 
						|
        name="MORE_SPANS",
 | 
						|
        attrs={"key": "new_value", "new_key": "new_value"},
 | 
						|
        spans=spans,
 | 
						|
    )
 | 
						|
 | 
						|
    span_group_3_expected = span_group_1._concat(span_group_2)
 | 
						|
 | 
						|
    span_group_3 = span_group_1 + span_group_2
 | 
						|
    assert len(span_group_3) == len(span_group_3_expected)
 | 
						|
    assert span_group_3.attrs == {"key": "value", "new_key": "new_value"}
 | 
						|
    assert list(span_group_3) == list(span_group_3_expected)
 | 
						|
 | 
						|
 | 
						|
def test_span_group_iadd(doc):
 | 
						|
    span_group_1 = doc.spans["SPANS"].copy()
 | 
						|
    spans = [doc[0:5], doc[0:6]]
 | 
						|
    span_group_2 = SpanGroup(
 | 
						|
        doc,
 | 
						|
        name="MORE_SPANS",
 | 
						|
        attrs={"key": "new_value", "new_key": "new_value"},
 | 
						|
        spans=spans,
 | 
						|
    )
 | 
						|
 | 
						|
    span_group_1_expected = span_group_1._concat(span_group_2)
 | 
						|
 | 
						|
    span_group_1 += span_group_2
 | 
						|
    assert len(span_group_1) == len(span_group_1_expected)
 | 
						|
    assert span_group_1.attrs == {"key": "value", "new_key": "new_value"}
 | 
						|
    assert list(span_group_1) == list(span_group_1_expected)
 | 
						|
 | 
						|
    span_group_1 = doc.spans["SPANS"].copy()
 | 
						|
    span_group_1 += spans
 | 
						|
    assert len(span_group_1) == len(span_group_1_expected)
 | 
						|
    assert span_group_1.attrs == {
 | 
						|
        "key": "value",
 | 
						|
    }
 | 
						|
    assert list(span_group_1) == list(span_group_1_expected)
 | 
						|
 | 
						|
 | 
						|
def test_span_group_extend(doc):
 | 
						|
    span_group_1 = doc.spans["SPANS"].copy()
 | 
						|
    spans = [doc[0:5], doc[0:6]]
 | 
						|
    span_group_2 = SpanGroup(
 | 
						|
        doc,
 | 
						|
        name="MORE_SPANS",
 | 
						|
        attrs={"key": "new_value", "new_key": "new_value"},
 | 
						|
        spans=spans,
 | 
						|
    )
 | 
						|
 | 
						|
    span_group_1_expected = span_group_1._concat(span_group_2)
 | 
						|
 | 
						|
    span_group_1.extend(span_group_2)
 | 
						|
    assert len(span_group_1) == len(span_group_1_expected)
 | 
						|
    assert span_group_1.attrs == {"key": "value", "new_key": "new_value"}
 | 
						|
    assert list(span_group_1) == list(span_group_1_expected)
 | 
						|
 | 
						|
    span_group_1 = doc.spans["SPANS"]
 | 
						|
    span_group_1.extend(spans)
 | 
						|
    assert len(span_group_1) == len(span_group_1_expected)
 | 
						|
    assert span_group_1.attrs == {"key": "value"}
 | 
						|
    assert list(span_group_1) == list(span_group_1_expected)
 | 
						|
 | 
						|
 | 
						|
def test_span_group_dealloc(span_group):
 | 
						|
    with pytest.raises(AttributeError):
 | 
						|
        print(span_group.doc)
 | 
						|
 | 
						|
 | 
						|
@pytest.mark.issue(11975)
 | 
						|
def test_span_group_typing(doc: Doc):
 | 
						|
    """Tests whether typing of `SpanGroup` as `Iterable[Span]`-like object is accepted by mypy."""
 | 
						|
    span_group: SpanGroup = doc.spans["SPANS"]
 | 
						|
    spans: List[Span] = list(span_group)
 | 
						|
    for i, span in enumerate(span_group):
 | 
						|
        assert span == span_group[i] == spans[i]
 | 
						|
    filter_spans(span_group)
 |