Support custom attributes for tokens and spans in json conversion (#11125)

* Add token and span custom attributes to to_json()

* Change logic for to_json

* Add functionality to from_json

* Small adjustments

* Move token/span attributes to new dict key

* Fix test

* Fix the same test but much better

* Add backwards compatibility tests and adjust logic

* Add test to check that attributes not listed in `underscore` are not saved in the JSON

* Add tests for json compatibility

* Adjust test names

* Fix tests and clean up code

* Fix assert json tests

* small adjustment

* adjust naming and code readability

* Adjust naming, added more tests and changed logic

* Fix typo

* Adjust errors, naming, and small test optimization

* Fix byte tests

* Fix bytes tests

* Change naming and json structure

* update schema

* Update spacy/schemas.py

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>

* Update spacy/tokens/doc.pyx

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>

* Update spacy/tokens/doc.pyx

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>

* Update spacy/schemas.py

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>

* Update schema for underscore attributes

* Adjust underscore schema

* adjust schema tests

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
Edward 2022-08-23 10:05:02 +02:00 committed by GitHub
parent 7e75327893
commit 5afa98aabf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 243 additions and 24 deletions

View File

@ -389,7 +389,7 @@ class Errors(metaclass=ErrorsWithCodes):
"consider using doc.spans instead.") "consider using doc.spans instead.")
E106 = ("Can't find `doc._.{attr}` attribute specified in the underscore " E106 = ("Can't find `doc._.{attr}` attribute specified in the underscore "
"settings: {opts}") "settings: {opts}")
E107 = ("Value of `doc._.{attr}` is not JSON-serializable: {value}") E107 = ("Value of custom attribute `{attr}` is not JSON-serializable: {value}")
E109 = ("Component '{name}' could not be run. Did you forget to " E109 = ("Component '{name}' could not be run. Did you forget to "
"call `initialize()`?") "call `initialize()`?")
E110 = ("Invalid displaCy render wrapper. Expected callable, got: {obj}") E110 = ("Invalid displaCy render wrapper. Expected callable, got: {obj}")

View File

@ -514,6 +514,14 @@ class DocJSONSchema(BaseModel):
tokens: List[Dict[StrictStr, Union[StrictStr, StrictInt]]] = Field( tokens: List[Dict[StrictStr, Union[StrictStr, StrictInt]]] = Field(
..., title="Token information - ID, start, annotations" ..., title="Token information - ID, start, annotations"
) )
_: Optional[Dict[StrictStr, Any]] = Field( underscore_doc: Optional[Dict[StrictStr, Any]] = Field(
None, title="Any custom data stored in the document's _ attribute" None,
title="Any custom data stored in the document's _ attribute",
alias="_",
)
underscore_token: Optional[Dict[StrictStr, Dict[StrictStr, Any]]] = Field(
None, title="Any custom data stored in the token's _ attribute"
)
underscore_span: Optional[Dict[StrictStr, Dict[StrictStr, Any]]] = Field(
None, title="Any custom data stored in the span's _ attribute"
) )

View File

@ -1,12 +1,15 @@
import pytest import pytest
import spacy import spacy
from spacy import schemas from spacy import schemas
from spacy.tokens import Doc, Span from spacy.tokens import Doc, Span, Token
import srsly
from .test_underscore import clean_underscore # noqa: F401
@pytest.fixture() @pytest.fixture()
def doc(en_vocab): def doc(en_vocab):
words = ["c", "d", "e"] words = ["c", "d", "e"]
spaces = [True, True, True]
pos = ["VERB", "NOUN", "NOUN"] pos = ["VERB", "NOUN", "NOUN"]
tags = ["VBP", "NN", "NN"] tags = ["VBP", "NN", "NN"]
heads = [0, 0, 1] heads = [0, 0, 1]
@ -17,6 +20,7 @@ def doc(en_vocab):
return Doc( return Doc(
en_vocab, en_vocab,
words=words, words=words,
spaces=spaces,
pos=pos, pos=pos,
tags=tags, tags=tags,
heads=heads, heads=heads,
@ -45,6 +49,47 @@ def doc_without_deps(en_vocab):
) )
@pytest.fixture()
def doc_json():
    """Return a hand-written Doc JSON dict for "c d e ".

    The payload has only the "text", "ents", "sents" and "tokens" keys and no
    underscore sections. It is consumed by ``test_json_to_doc_compat``.
    """
    # Column order of each row below; also the key order of every token dict.
    token_fields = ("id", "start", "end", "tag", "pos", "morph", "dep", "head")
    token_rows = [
        (0, 0, 1, "VBP", "VERB", "Feat1=A", "ROOT", 0),
        (1, 2, 3, "NN", "NOUN", "Feat1=B", "dobj", 0),
        (2, 4, 5, "NN", "NOUN", "Feat1=A|Feat2=D", "dobj", 1),
    ]
    return {
        "text": "c d e ",
        "ents": [{"start": 2, "end": 3, "label": "ORG"}],
        "sents": [{"start": 0, "end": 5}],
        "tokens": [dict(zip(token_fields, row)) for row in token_rows],
    }
def test_doc_to_json(doc): def test_doc_to_json(doc):
json_doc = doc.to_json() json_doc = doc.to_json()
assert json_doc["text"] == "c d e " assert json_doc["text"] == "c d e "
@ -56,7 +101,8 @@ def test_doc_to_json(doc):
assert json_doc["ents"][0]["start"] == 2 # character offset! assert json_doc["ents"][0]["start"] == 2 # character offset!
assert json_doc["ents"][0]["end"] == 3 # character offset! assert json_doc["ents"][0]["end"] == 3 # character offset!
assert json_doc["ents"][0]["label"] == "ORG" assert json_doc["ents"][0]["label"] == "ORG"
assert not schemas.validate(schemas.DocJSONSchema, json_doc) assert len(schemas.validate(schemas.DocJSONSchema, json_doc)) == 0
assert srsly.json_loads(srsly.json_dumps(json_doc)) == json_doc
def test_doc_to_json_underscore(doc): def test_doc_to_json_underscore(doc):
@ -64,11 +110,96 @@ def test_doc_to_json_underscore(doc):
Doc.set_extension("json_test2", default=False) Doc.set_extension("json_test2", default=False)
doc._.json_test1 = "hello world" doc._.json_test1 = "hello world"
doc._.json_test2 = [1, 2, 3] doc._.json_test2 = [1, 2, 3]
json_doc = doc.to_json(underscore=["json_test1", "json_test2"]) json_doc = doc.to_json(underscore=["json_test1", "json_test2"])
assert "_" in json_doc assert "_" in json_doc
assert json_doc["_"]["json_test1"] == "hello world" assert json_doc["_"]["json_test1"] == "hello world"
assert json_doc["_"]["json_test2"] == [1, 2, 3] assert json_doc["_"]["json_test2"] == [1, 2, 3]
assert not schemas.validate(schemas.DocJSONSchema, json_doc) assert len(schemas.validate(schemas.DocJSONSchema, json_doc)) == 0
assert srsly.json_loads(srsly.json_dumps(json_doc)) == json_doc
def test_doc_to_json_with_token_span_attributes(doc):
    """Doc, Token and Span extensions requested together all appear in the JSON."""
    Doc.set_extension("json_test1", default=False)
    Doc.set_extension("json_test2", default=False)
    Token.set_extension("token_test", default=False)
    Span.set_extension("span_test", default=False)

    doc._.json_test1 = "hello world"
    doc._.json_test2 = [1, 2, 3]
    doc[0:1]._.span_test = "span_attribute"
    doc[0]._.token_test = 117
    doc.spans["span_group"] = [doc[0:1]]

    requested = ["json_test1", "json_test2", "token_test", "span_test"]
    exported = doc.to_json(underscore=requested)

    # Doc-level values live under "_".
    assert "_" in exported
    doc_level = exported["_"]
    assert doc_level["json_test1"] == "hello world"
    assert doc_level["json_test2"] == [1, 2, 3]

    # Token- and span-level values live under their own top-level keys.
    assert "underscore_token" in exported
    assert "underscore_span" in exported
    assert exported["underscore_token"]["token_test"]["value"] == 117
    assert exported["underscore_span"]["span_test"]["value"] == "span_attribute"

    # The result must satisfy the schema and survive a JSON string round trip.
    schema_errors = schemas.validate(schemas.DocJSONSchema, exported)
    assert len(schema_errors) == 0
    assert srsly.json_loads(srsly.json_dumps(exported)) == exported
def test_doc_to_json_with_custom_user_data(doc):
    """to_json() must cope with user_data entries that are not extension values.

    Extension values are stored in ``doc.user_data`` next to whatever else the
    user puts there. This test adds a plain string key and a short tuple key
    and checks that the export still produces the expected underscore sections.
    """
    Doc.set_extension("json_test", default=False)
    Token.set_extension("token_test", default=False)
    Span.set_extension("span_test", default=False)

    doc._.json_test = "hello world"
    doc[0:1]._.span_test = "span_attribute"
    doc[0]._.token_test = 117

    # The custom entries have to exist *before* the export; adding them after
    # to_json() has run would leave the unrelated-key handling untested.
    doc.user_data["user_data_test"] = 10
    doc.user_data[("user_data_test2", True)] = 10

    json_doc = doc.to_json(underscore=["json_test", "token_test", "span_test"])

    assert "_" in json_doc
    assert json_doc["_"]["json_test"] == "hello world"
    assert "underscore_token" in json_doc
    assert "underscore_span" in json_doc
    assert json_doc["underscore_token"]["token_test"]["value"] == 117
    assert json_doc["underscore_span"]["span_test"]["value"] == "span_attribute"
    assert len(schemas.validate(schemas.DocJSONSchema, json_doc)) == 0
    assert srsly.json_loads(srsly.json_dumps(json_doc)) == json_doc
def test_doc_to_json_with_token_span_same_identifier(doc):
    """One extension name registered on Doc, Token and Span is exported for each."""
    Doc.set_extension("my_ext", default=False)
    Token.set_extension("my_ext", default=False)
    Span.set_extension("my_ext", default=False)

    doc._.my_ext = "hello world"
    doc[0:1]._.my_ext = "span_attribute"
    doc[0]._.my_ext = 117

    # A single requested name should fan out into all three sections.
    exported = doc.to_json(underscore=["my_ext"])

    assert "_" in exported
    assert exported["_"]["my_ext"] == "hello world"
    assert "underscore_token" in exported
    assert "underscore_span" in exported
    assert exported["underscore_token"]["my_ext"]["value"] == 117
    assert exported["underscore_span"]["my_ext"]["value"] == "span_attribute"

    schema_errors = schemas.validate(schemas.DocJSONSchema, exported)
    assert len(schema_errors) == 0
    assert srsly.json_loads(srsly.json_dumps(exported)) == exported
def test_doc_to_json_with_token_attributes_missing(doc):
    """An extension that is set but not requested stays out of the JSON."""
    Token.set_extension("token_test", default=False)
    Span.set_extension("span_test", default=False)

    doc[0:1]._.span_test = "span_attribute"
    doc[0]._.token_test = 117

    # Only the span attribute is requested; token_test is deliberately omitted.
    exported = doc.to_json(underscore=["span_test"])

    assert "underscore_token" in exported
    assert "underscore_span" in exported
    assert exported["underscore_span"]["span_test"]["value"] == "span_attribute"
    assert "token_test" not in exported["underscore_token"]

    schema_errors = schemas.validate(schemas.DocJSONSchema, exported)
    assert len(schema_errors) == 0
def test_doc_to_json_underscore_error_attr(doc): def test_doc_to_json_underscore_error_attr(doc):
@ -94,11 +225,29 @@ def test_doc_to_json_span(doc):
assert len(json_doc["spans"]) == 1 assert len(json_doc["spans"]) == 1
assert len(json_doc["spans"]["test"]) == 2 assert len(json_doc["spans"]["test"]) == 2
assert json_doc["spans"]["test"][0]["start"] == 0 assert json_doc["spans"]["test"][0]["start"] == 0
assert not schemas.validate(schemas.DocJSONSchema, json_doc) assert len(schemas.validate(schemas.DocJSONSchema, json_doc)) == 0
def test_json_to_doc(doc): def test_json_to_doc(doc):
new_doc = Doc(doc.vocab).from_json(doc.to_json(), validate=True) json_doc = doc.to_json()
json_doc = srsly.json_loads(srsly.json_dumps(json_doc))
new_doc = Doc(doc.vocab).from_json(json_doc, validate=True)
assert new_doc.text == doc.text == "c d e "
assert len(new_doc) == len(doc) == 3
assert new_doc[0].pos == doc[0].pos
assert new_doc[0].tag == doc[0].tag
assert new_doc[0].dep == doc[0].dep
assert new_doc[0].head.idx == doc[0].head.idx
assert new_doc[0].lemma == doc[0].lemma
assert len(new_doc.ents) == 1
assert new_doc.ents[0].start == 1
assert new_doc.ents[0].end == 2
assert new_doc.ents[0].label_ == "ORG"
assert doc.to_bytes() == new_doc.to_bytes()
def test_json_to_doc_compat(doc, doc_json):
new_doc = Doc(doc.vocab).from_json(doc_json, validate=True)
new_tokens = [token for token in new_doc] new_tokens = [token for token in new_doc]
assert new_doc.text == doc.text == "c d e " assert new_doc.text == doc.text == "c d e "
assert len(new_tokens) == len([token for token in doc]) == 3 assert len(new_tokens) == len([token for token in doc]) == 3
@ -114,11 +263,8 @@ def test_json_to_doc(doc):
def test_json_to_doc_underscore(doc): def test_json_to_doc_underscore(doc):
if not Doc.has_extension("json_test1"): Doc.set_extension("json_test1", default=False)
Doc.set_extension("json_test1", default=False) Doc.set_extension("json_test2", default=False)
if not Doc.has_extension("json_test2"):
Doc.set_extension("json_test2", default=False)
doc._.json_test1 = "hello world" doc._.json_test1 = "hello world"
doc._.json_test2 = [1, 2, 3] doc._.json_test2 = [1, 2, 3]
json_doc = doc.to_json(underscore=["json_test1", "json_test2"]) json_doc = doc.to_json(underscore=["json_test1", "json_test2"])
@ -126,6 +272,34 @@ def test_json_to_doc_underscore(doc):
assert all([new_doc.has_extension(f"json_test{i}") for i in range(1, 3)]) assert all([new_doc.has_extension(f"json_test{i}") for i in range(1, 3)])
assert new_doc._.json_test1 == "hello world" assert new_doc._.json_test1 == "hello world"
assert new_doc._.json_test2 == [1, 2, 3] assert new_doc._.json_test2 == [1, 2, 3]
assert doc.to_bytes() == new_doc.to_bytes()
def test_json_to_doc_with_token_span_attributes(doc):
    """Doc, Token and Span extension values survive a to_json/from_json round trip."""
    Doc.set_extension("json_test1", default=False)
    Doc.set_extension("json_test2", default=False)
    Token.set_extension("token_test", default=False)
    Span.set_extension("span_test", default=False)

    doc._.json_test1 = "hello world"
    doc._.json_test2 = [1, 2, 3]
    doc[0:1]._.span_test = "span_attribute"
    doc[0]._.token_test = 117

    requested = ["json_test1", "json_test2", "token_test", "span_test"]
    exported = doc.to_json(underscore=requested)
    # Go through an actual JSON string so from_json() receives decoded JSON data.
    payload = srsly.json_loads(srsly.json_dumps(exported))
    restored = Doc(doc.vocab).from_json(payload, validate=True)

    for idx in range(1, 3):
        assert restored.has_extension(f"json_test{idx}")
    assert restored._.json_test1 == "hello world"
    assert restored._.json_test2 == [1, 2, 3]
    assert restored[0]._.token_test == 117
    assert restored[0:1]._.span_test == "span_attribute"
    assert restored.user_data == doc.user_data

    # user_data is compared directly above, so leave it out of the byte comparison.
    skipped = ["user_data"]
    assert restored.to_bytes(exclude=skipped) == doc.to_bytes(exclude=skipped)
def test_json_to_doc_spans(doc): def test_json_to_doc_spans(doc):

View File

@ -1602,13 +1602,30 @@ cdef class Doc:
ents.append(char_span) ents.append(char_span)
self.ents = ents self.ents = ents
# Add custom attributes. Note that only Doc extensions are currently considered, Token and Span extensions are # Add custom attributes for the whole Doc object.
# not yet supported.
for attr in doc_json.get("_", {}): for attr in doc_json.get("_", {}):
if not Doc.has_extension(attr): if not Doc.has_extension(attr):
Doc.set_extension(attr) Doc.set_extension(attr)
self._.set(attr, doc_json["_"][attr]) self._.set(attr, doc_json["_"][attr])
if doc_json.get("underscore_token", {}):
for token_attr in doc_json["underscore_token"]:
token_start = doc_json["underscore_token"][token_attr]["token_start"]
value = doc_json["underscore_token"][token_attr]["value"]
if not Token.has_extension(token_attr):
Token.set_extension(token_attr)
self[token_start]._.set(token_attr, value)
if doc_json.get("underscore_span", {}):
for span_attr in doc_json["underscore_span"]:
token_start = doc_json["underscore_span"][span_attr]["token_start"]
token_end = doc_json["underscore_span"][span_attr]["token_end"]
value = doc_json["underscore_span"][span_attr]["value"]
if not Span.has_extension(span_attr):
Span.set_extension(span_attr)
self[token_start:token_end]._.set(span_attr, value)
return self return self
def to_json(self, underscore=None): def to_json(self, underscore=None):
@ -1650,20 +1667,40 @@ cdef class Doc:
for span_group in self.spans: for span_group in self.spans:
data["spans"][span_group] = [] data["spans"][span_group] = []
for span in self.spans[span_group]: for span in self.spans[span_group]:
span_data = { span_data = {"start": span.start_char, "end": span.end_char, "label": span.label_, "kb_id": span.kb_id_}
"start": span.start_char, "end": span.end_char, "label": span.label_, "kb_id": span.kb_id_
}
data["spans"][span_group].append(span_data) data["spans"][span_group].append(span_data)
if underscore: if underscore:
data["_"] = {} user_keys = set()
if self.user_data:
data["_"] = {}
data["underscore_token"] = {}
data["underscore_span"] = {}
for data_key in self.user_data:
if type(data_key) == tuple and len(data_key) >= 4 and data_key[0] == "._.":
attr = data_key[1]
start = data_key[2]
end = data_key[3]
if attr in underscore:
user_keys.add(attr)
value = self.user_data[data_key]
if not srsly.is_json_serializable(value):
raise ValueError(Errors.E107.format(attr=attr, value=repr(value)))
# Check if doc attribute
if start is None:
data["_"][attr] = value
# Check if token attribute
elif end is None:
if attr not in data["underscore_token"]:
data["underscore_token"][attr] = {"token_start": start, "value": value}
# Else span attribute
else:
if attr not in data["underscore_span"]:
data["underscore_span"][attr] = {"token_start": start, "token_end": end, "value": value}
for attr in underscore: for attr in underscore:
if not self.has_extension(attr): if attr not in user_keys:
raise ValueError(Errors.E106.format(attr=attr, opts=underscore)) raise ValueError(Errors.E106.format(attr=attr, opts=underscore))
value = self._.get(attr)
if not srsly.is_json_serializable(value):
raise ValueError(Errors.E107.format(attr=attr, value=repr(value)))
data["_"][attr] = value
return data return data
def to_utf8_array(self, int nr_char=-1): def to_utf8_array(self, int nr_char=-1):