Override SpanGroups.setdefault to provide default SpanGroup (#10772)

* Fix mistake in SpanGroup API docs

* Restrict SpanGroups.setdefault to SpanGroup only

* Refactor to support default span iterable
This commit is contained in:
Adriane Boyd 2022-05-12 10:06:25 +02:00 committed by GitHub
parent d524f6415f
commit b65d652881
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 21 additions and 2 deletions

View File

@ -11,7 +11,7 @@ from spacy.lang.en import English
from spacy.lang.xx import MultiLanguage
from spacy.language import Language
from spacy.lexeme import Lexeme
from spacy.tokens import Doc, Span, Token
from spacy.tokens import Doc, Span, SpanGroup, Token
from spacy.vocab import Vocab
from .test_underscore import clean_underscore # noqa: F401
@ -964,3 +964,13 @@ def test_doc_spans_copy(en_tokenizer):
assert weakref.ref(doc1) == doc1.spans.doc_ref
doc2 = doc1.copy()
assert weakref.ref(doc2) == doc2.spans.doc_ref
def test_doc_spans_setdefault(en_tokenizer):
    """Exercise ``doc.spans.setdefault`` with each supported kind of default.

    Three cases are covered: no default at all, a plain list of spans, and
    an already-constructed ``SpanGroup``. In every case the key must be
    retrievable from ``doc.spans`` afterwards with the expected length.
    """
    doc = en_tokenizer("Some text about Colombia and the Czech Republic")

    # Case 1: default omitted -> an empty group is stored under the key.
    doc.spans.setdefault("key1")
    assert len(doc.spans["key1"]) == 0

    # Case 2: default is a bare list holding a single span.
    span_list = [doc[0:1]]
    doc.spans.setdefault("key2", default=span_list)
    assert len(doc.spans["key2"]) == 1

    # Case 3: default is a ready-made SpanGroup holding two spans.
    ready_group = SpanGroup(doc, spans=[doc[0:1], doc[1:2]])
    doc.spans.setdefault("key3", default=ready_group)
    assert len(doc.spans["key3"]) == 2

View File

@ -43,6 +43,15 @@ class SpanGroups(UserDict):
doc = self._ensure_doc()
return SpanGroups(doc).from_bytes(self.to_bytes())
def setdefault(self, key, default=None):
    """Return the group stored under ``key``, inserting ``default`` if absent.

    key: Name under which the span group is looked up / stored.
    default: Either a ``SpanGroup`` (passed through unchanged), some other
        value that is handed to ``self._make_span_group`` as the spans, or
        ``None``, in which case an empty list of spans is used.
    RETURNS: Whatever the parent class's ``setdefault`` returns for ``key``.
    """
    if isinstance(default, SpanGroup):
        # Already a SpanGroup: hand it to the parent class as-is.
        group = default
    else:
        # Anything else is wrapped into a group first; a missing default
        # becomes a fresh empty list of spans.
        group = self._make_span_group(key, [] if default is None else default)
    return super().setdefault(key, default=group)
def to_bytes(self) -> bytes:
# We don't need to serialize this as a dict, because the groups
# know their names.

View File

@ -233,7 +233,7 @@ group.
> doc.spans["errors"] = []
> doc.spans["errors"].extend([doc[1:3], doc[0:1]])
> assert len(doc.spans["errors"]) == 2
> span_group = SpanGroup([doc[1:4], doc[0:3])
> span_group = SpanGroup(doc, spans=[doc[1:4], doc[0:3]])
> doc.spans["errors"].extend(span_group)
> ```