From b65d652881644f9a62a38d7979aee683853c818a Mon Sep 17 00:00:00 2001 From: Adriane Boyd Date: Thu, 12 May 2022 10:06:25 +0200 Subject: [PATCH] Override SpanGroups.setdefault to provide default SpanGroup (#10772) * Fix mistake in SpanGroup API docs * Restrict SpanGroups.setdefault to SpanGroup only * Refactor to support default span iterable --- spacy/tests/doc/test_doc_api.py | 12 +++++++++++- spacy/tokens/_dict_proxies.py | 9 +++++++++ website/docs/api/spangroup.md | 2 +- 3 files changed, 21 insertions(+), 2 deletions(-) diff --git a/spacy/tests/doc/test_doc_api.py b/spacy/tests/doc/test_doc_api.py index 19b554572..dd4942989 100644 --- a/spacy/tests/doc/test_doc_api.py +++ b/spacy/tests/doc/test_doc_api.py @@ -11,7 +11,7 @@ from spacy.lang.en import English from spacy.lang.xx import MultiLanguage from spacy.language import Language from spacy.lexeme import Lexeme -from spacy.tokens import Doc, Span, Token +from spacy.tokens import Doc, Span, SpanGroup, Token from spacy.vocab import Vocab from .test_underscore import clean_underscore # noqa: F401 @@ -964,3 +964,13 @@ def test_doc_spans_copy(en_tokenizer): assert weakref.ref(doc1) == doc1.spans.doc_ref doc2 = doc1.copy() assert weakref.ref(doc2) == doc2.spans.doc_ref + + +def test_doc_spans_setdefault(en_tokenizer): + doc = en_tokenizer("Some text about Colombia and the Czech Republic") + doc.spans.setdefault("key1") + assert len(doc.spans["key1"]) == 0 + doc.spans.setdefault("key2", default=[doc[0:1]]) + assert len(doc.spans["key2"]) == 1 + doc.spans.setdefault("key3", default=SpanGroup(doc, spans=[doc[0:1], doc[1:2]])) + assert len(doc.spans["key3"]) == 2 diff --git a/spacy/tokens/_dict_proxies.py b/spacy/tokens/_dict_proxies.py index 8643243fa..d9506769b 100644 --- a/spacy/tokens/_dict_proxies.py +++ b/spacy/tokens/_dict_proxies.py @@ -43,6 +43,15 @@ class SpanGroups(UserDict): doc = self._ensure_doc() return SpanGroups(doc).from_bytes(self.to_bytes()) + def setdefault(self, key, default=None): + if not isinstance(default, SpanGroup): + if default is None: + spans = [] + else: + spans = default + default = self._make_span_group(key, spans) + return super().setdefault(key, default=default) + def to_bytes(self) -> bytes: # We don't need to serialize this as a dict, because the groups # know their names. diff --git a/website/docs/api/spangroup.md b/website/docs/api/spangroup.md index 1e2d18a82..8dbdefc01 100644 --- a/website/docs/api/spangroup.md +++ b/website/docs/api/spangroup.md @@ -233,7 +233,7 @@ group. > doc.spans["errors"] = [] > doc.spans["errors"].extend([doc[1:3], doc[0:1]]) > assert len(doc.spans["errors"]) == 2 -> span_group = SpanGroup([doc[1:4], doc[0:3]) +> span_group = SpanGroup(doc, spans=[doc[1:4], doc[0:3]]) > doc.spans["errors"].extend(span_group) > ```