Add Span.kb_id and Span.id strings to Doc/DocBin serialization

* Switch tests that use int-only `kb_id/id` to strings.
This commit is contained in:
Adriane Boyd 2023-03-17 09:53:49 +01:00
parent 4c5a3a2a7b
commit befff3ebf8
4 changed files with 10 additions and 6 deletions

View File

@ -313,7 +313,7 @@ def test_json_to_doc_spans(doc):
"""Test that Doc.from_json() includes correct.spans.""" """Test that Doc.from_json() includes correct.spans."""
doc.spans["test"] = [ doc.spans["test"] = [
Span(doc, 0, 2, label="test"), Span(doc, 0, 2, label="test"),
Span(doc, 0, 1, label="test", kb_id=7), Span(doc, 0, 1, label="test", kb_id="7"),
] ]
json_doc = doc.to_json() json_doc = doc.to_json()
new_doc = Doc(doc.vocab).from_json(json_doc, validate=True) new_doc = Doc(doc.vocab).from_json(json_doc, validate=True)

View File

@ -599,22 +599,22 @@ def test_overfitting_IO(use_upper):
# test that kb_id is preserved # test that kb_id is preserved
test_text = "I like London and London." test_text = "I like London and London."
doc = nlp.make_doc(test_text) doc = nlp.make_doc(test_text)
doc.ents = [Span(doc, 2, 3, label="LOC", kb_id=1234)] doc.ents = [Span(doc, 2, 3, label="LOC", kb_id="1234")]
ents = doc.ents ents = doc.ents
assert len(ents) == 1 assert len(ents) == 1
assert ents[0].text == "London" assert ents[0].text == "London"
assert ents[0].label_ == "LOC" assert ents[0].label_ == "LOC"
assert ents[0].kb_id == 1234 assert ents[0].kb_id_ == "1234"
doc = nlp.get_pipe("ner")(doc) doc = nlp.get_pipe("ner")(doc)
ents = doc.ents ents = doc.ents
assert len(ents) == 2 assert len(ents) == 2
assert ents[0].text == "London" assert ents[0].text == "London"
assert ents[0].label_ == "LOC" assert ents[0].label_ == "LOC"
assert ents[0].kb_id == 1234 assert ents[0].kb_id_ == "1234"
# ent added by ner has kb_id == 0 # ent added by ner has kb_id == ""
assert ents[1].text == "London" assert ents[1].text == "London"
assert ents[1].label_ == "LOC" assert ents[1].label_ == "LOC"
assert ents[1].kb_id == 0 assert ents[1].kb_id_ == ""
def test_beam_ner_scores(): def test_beam_ner_scores():

View File

@ -124,6 +124,8 @@ class DocBin:
for key, group in doc.spans.items(): for key, group in doc.spans.items():
for span in group: for span in group:
self.strings.add(span.label_) self.strings.add(span.label_)
self.strings.add(span.kb_id_)
self.strings.add(span.id_)
def get_docs(self, vocab: Vocab) -> Iterator[Doc]: def get_docs(self, vocab: Vocab) -> Iterator[Doc]:
"""Recover Doc objects from the annotations, using the given vocab. """Recover Doc objects from the annotations, using the given vocab.

View File

@ -1350,6 +1350,8 @@ cdef class Doc:
for group in self.spans.values(): for group in self.spans.values():
for span in group: for span in group:
strings.add(span.label_) strings.add(span.label_)
strings.add(span.kb_id_)
strings.add(span.id_)
# Msgpack doesn't distinguish between lists and tuples, which is # Msgpack doesn't distinguish between lists and tuples, which is
# vexing for user data. As a best guess, we *know* that within # vexing for user data. As a best guess, we *know* that within
# keys, we must have tuples. In values we just have to hope # keys, we must have tuples. In values we just have to hope