Add Span.kb_id/Span.id strings to Doc/DocBin serialization if set (#12493)

* Add Span.kb_id/Span.id strings to Doc/DocBin serialization if set

* Format
This commit is contained in:
Adriane Boyd 2023-04-03 15:11:12 +02:00 committed by GitHub
parent 4538ceb507
commit 4a1ec332de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 24 additions and 2 deletions

View File

@ -213,6 +213,13 @@ def test_serialize_doc_exclude(en_vocab):
def test_serialize_doc_span_groups(en_vocab):
doc = Doc(en_vocab, words=["hello", "world", "!"])
doc.spans["content"] = [doc[0:2]]
span = doc[0:2]
span.label_ = "test_serialize_doc_span_groups_label"
span.id_ = "test_serialize_doc_span_groups_id"
span.kb_id_ = "test_serialize_doc_span_groups_kb_id"
doc.spans["content"] = [span]
new_doc = Doc(en_vocab).from_bytes(doc.to_bytes())
assert len(new_doc.spans["content"]) == 1
assert new_doc.spans["content"][0].label_ == "test_serialize_doc_span_groups_label"
assert new_doc.spans["content"][0].id_ == "test_serialize_doc_span_groups_id"
assert new_doc.spans["content"][0].kb_id_ == "test_serialize_doc_span_groups_kb_id"

View File

@ -49,7 +49,11 @@ def test_serialize_doc_bin():
nlp = English()
for doc in nlp.pipe(texts):
doc.cats = cats
doc.spans["start"] = [doc[0:2]]
span = doc[0:2]
span.label_ = "UNUSUAL_SPAN_LABEL"
span.id_ = "UNUSUAL_SPAN_ID"
span.kb_id_ = "UNUSUAL_SPAN_KB_ID"
doc.spans["start"] = [span]
doc[0].norm_ = "UNUSUAL_TOKEN_NORM"
doc[0].ent_id_ = "UNUSUAL_TOKEN_ENT_ID"
doc_bin.add(doc)
@ -63,6 +67,9 @@ def test_serialize_doc_bin():
assert doc.text == texts[i]
assert doc.cats == cats
assert len(doc.spans) == 1
assert doc.spans["start"][0].label_ == "UNUSUAL_SPAN_LABEL"
assert doc.spans["start"][0].id_ == "UNUSUAL_SPAN_ID"
assert doc.spans["start"][0].kb_id_ == "UNUSUAL_SPAN_KB_ID"
assert doc[0].norm_ == "UNUSUAL_TOKEN_NORM"
assert doc[0].ent_id_ == "UNUSUAL_TOKEN_ENT_ID"

View File

@ -124,6 +124,10 @@ class DocBin:
for key, group in doc.spans.items():
for span in group:
self.strings.add(span.label_)
if span.kb_id in span.doc.vocab.strings:
self.strings.add(span.kb_id_)
if span.id in span.doc.vocab.strings:
self.strings.add(span.id_)
def get_docs(self, vocab: Vocab) -> Iterator[Doc]:
"""Recover Doc objects from the annotations, using the given vocab.

View File

@ -1346,6 +1346,10 @@ cdef class Doc:
for group in self.spans.values():
for span in group:
strings.add(span.label_)
if span.kb_id in span.doc.vocab.strings:
strings.add(span.kb_id_)
if span.id in span.doc.vocab.strings:
strings.add(span.id_)
# Msgpack doesn't distinguish between lists and tuples, which is
# vexing for user data. As a best guess, we *know* that within
# keys, we must have tuples. In values we just have to hope