Support custom attributes for tokens and spans in json conversion (#11125)

* Add token and span custom attributes to to_json()

* Change logic for to_json

* Add functionality to from_json

* Small adjustments

* Move token/span attributes to new dict key

* Fix test

* Fix the same test but much better

* Add backwards compatibility tests and adjust logic

* Add test to check if attributes not set in underscore are not saved in the json

* Add tests for json compatibility

* Adjust test names

* Fix tests and clean up code

* Fix assert json tests

* small adjustment

* adjust naming and code readability

* Adjust naming, added more tests and changed logic

* Fix typo

* Adjust errors, naming, and small test optimization

* Fix byte tests

* Fix bytes tests

* Change naming and json structure

* update schema

* Update spacy/schemas.py

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>

* Update spacy/tokens/doc.pyx

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>

* Update spacy/tokens/doc.pyx

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>

* Update spacy/schemas.py

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>

* Update schema for underscore attributes

* Adjust underscore schema

* adjust schema tests

Co-authored-by: Adriane Boyd <adrianeboyd@gmail.com>
This commit is contained in:
Edward 2022-08-23 10:05:02 +02:00 committed by GitHub
parent 7e75327893
commit 5afa98aabf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 243 additions and 24 deletions

View File

@ -389,7 +389,7 @@ class Errors(metaclass=ErrorsWithCodes):
"consider using doc.spans instead.")
E106 = ("Can't find `doc._.{attr}` attribute specified in the underscore "
"settings: {opts}")
E107 = ("Value of `doc._.{attr}` is not JSON-serializable: {value}")
E107 = ("Value of custom attribute `{attr}` is not JSON-serializable: {value}")
E109 = ("Component '{name}' could not be run. Did you forget to "
"call `initialize()`?")
E110 = ("Invalid displaCy render wrapper. Expected callable, got: {obj}")

View File

@ -514,6 +514,14 @@ class DocJSONSchema(BaseModel):
tokens: List[Dict[StrictStr, Union[StrictStr, StrictInt]]] = Field(
..., title="Token information - ID, start, annotations"
)
_: Optional[Dict[StrictStr, Any]] = Field(
None, title="Any custom data stored in the document's _ attribute"
underscore_doc: Optional[Dict[StrictStr, Any]] = Field(
None,
title="Any custom data stored in the document's _ attribute",
alias="_",
)
underscore_token: Optional[Dict[StrictStr, Dict[StrictStr, Any]]] = Field(
None, title="Any custom data stored in the token's _ attribute"
)
underscore_span: Optional[Dict[StrictStr, Dict[StrictStr, Any]]] = Field(
None, title="Any custom data stored in the span's _ attribute"
)

View File

@ -1,12 +1,15 @@
import pytest
import spacy
from spacy import schemas
from spacy.tokens import Doc, Span
from spacy.tokens import Doc, Span, Token
import srsly
from .test_underscore import clean_underscore # noqa: F401
@pytest.fixture()
def doc(en_vocab):
words = ["c", "d", "e"]
spaces = [True, True, True]
pos = ["VERB", "NOUN", "NOUN"]
tags = ["VBP", "NN", "NN"]
heads = [0, 0, 1]
@ -17,6 +20,7 @@ def doc(en_vocab):
return Doc(
en_vocab,
words=words,
spaces=spaces,
pos=pos,
tags=tags,
heads=heads,
@ -45,6 +49,47 @@ def doc_without_deps(en_vocab):
)
@pytest.fixture()
def doc_json():
return {
"text": "c d e ",
"ents": [{"start": 2, "end": 3, "label": "ORG"}],
"sents": [{"start": 0, "end": 5}],
"tokens": [
{
"id": 0,
"start": 0,
"end": 1,
"tag": "VBP",
"pos": "VERB",
"morph": "Feat1=A",
"dep": "ROOT",
"head": 0,
},
{
"id": 1,
"start": 2,
"end": 3,
"tag": "NN",
"pos": "NOUN",
"morph": "Feat1=B",
"dep": "dobj",
"head": 0,
},
{
"id": 2,
"start": 4,
"end": 5,
"tag": "NN",
"pos": "NOUN",
"morph": "Feat1=A|Feat2=D",
"dep": "dobj",
"head": 1,
},
],
}
def test_doc_to_json(doc):
json_doc = doc.to_json()
assert json_doc["text"] == "c d e "
@ -56,7 +101,8 @@ def test_doc_to_json(doc):
assert json_doc["ents"][0]["start"] == 2 # character offset!
assert json_doc["ents"][0]["end"] == 3 # character offset!
assert json_doc["ents"][0]["label"] == "ORG"
assert not schemas.validate(schemas.DocJSONSchema, json_doc)
assert len(schemas.validate(schemas.DocJSONSchema, json_doc)) == 0
assert srsly.json_loads(srsly.json_dumps(json_doc)) == json_doc
def test_doc_to_json_underscore(doc):
@ -64,11 +110,96 @@ def test_doc_to_json_underscore(doc):
Doc.set_extension("json_test2", default=False)
doc._.json_test1 = "hello world"
doc._.json_test2 = [1, 2, 3]
json_doc = doc.to_json(underscore=["json_test1", "json_test2"])
assert "_" in json_doc
assert json_doc["_"]["json_test1"] == "hello world"
assert json_doc["_"]["json_test2"] == [1, 2, 3]
assert not schemas.validate(schemas.DocJSONSchema, json_doc)
assert len(schemas.validate(schemas.DocJSONSchema, json_doc)) == 0
assert srsly.json_loads(srsly.json_dumps(json_doc)) == json_doc
def test_doc_to_json_with_token_span_attributes(doc):
Doc.set_extension("json_test1", default=False)
Doc.set_extension("json_test2", default=False)
Token.set_extension("token_test", default=False)
Span.set_extension("span_test", default=False)
doc._.json_test1 = "hello world"
doc._.json_test2 = [1, 2, 3]
doc[0:1]._.span_test = "span_attribute"
doc[0]._.token_test = 117
doc.spans["span_group"] = [doc[0:1]]
json_doc = doc.to_json(
underscore=["json_test1", "json_test2", "token_test", "span_test"]
)
assert "_" in json_doc
assert json_doc["_"]["json_test1"] == "hello world"
assert json_doc["_"]["json_test2"] == [1, 2, 3]
assert "underscore_token" in json_doc
assert "underscore_span" in json_doc
assert json_doc["underscore_token"]["token_test"]["value"] == 117
assert json_doc["underscore_span"]["span_test"]["value"] == "span_attribute"
assert len(schemas.validate(schemas.DocJSONSchema, json_doc)) == 0
assert srsly.json_loads(srsly.json_dumps(json_doc)) == json_doc
def test_doc_to_json_with_custom_user_data(doc):
Doc.set_extension("json_test", default=False)
Token.set_extension("token_test", default=False)
Span.set_extension("span_test", default=False)
doc._.json_test = "hello world"
doc[0:1]._.span_test = "span_attribute"
doc[0]._.token_test = 117
json_doc = doc.to_json(underscore=["json_test", "token_test", "span_test"])
doc.user_data["user_data_test"] = 10
doc.user_data[("user_data_test2", True)] = 10
assert "_" in json_doc
assert json_doc["_"]["json_test"] == "hello world"
assert "underscore_token" in json_doc
assert "underscore_span" in json_doc
assert json_doc["underscore_token"]["token_test"]["value"] == 117
assert json_doc["underscore_span"]["span_test"]["value"] == "span_attribute"
assert len(schemas.validate(schemas.DocJSONSchema, json_doc)) == 0
assert srsly.json_loads(srsly.json_dumps(json_doc)) == json_doc
def test_doc_to_json_with_token_span_same_identifier(doc):
Doc.set_extension("my_ext", default=False)
Token.set_extension("my_ext", default=False)
Span.set_extension("my_ext", default=False)
doc._.my_ext = "hello world"
doc[0:1]._.my_ext = "span_attribute"
doc[0]._.my_ext = 117
json_doc = doc.to_json(underscore=["my_ext"])
assert "_" in json_doc
assert json_doc["_"]["my_ext"] == "hello world"
assert "underscore_token" in json_doc
assert "underscore_span" in json_doc
assert json_doc["underscore_token"]["my_ext"]["value"] == 117
assert json_doc["underscore_span"]["my_ext"]["value"] == "span_attribute"
assert len(schemas.validate(schemas.DocJSONSchema, json_doc)) == 0
assert srsly.json_loads(srsly.json_dumps(json_doc)) == json_doc
def test_doc_to_json_with_token_attributes_missing(doc):
Token.set_extension("token_test", default=False)
Span.set_extension("span_test", default=False)
doc[0:1]._.span_test = "span_attribute"
doc[0]._.token_test = 117
json_doc = doc.to_json(underscore=["span_test"])
assert "underscore_token" in json_doc
assert "underscore_span" in json_doc
assert json_doc["underscore_span"]["span_test"]["value"] == "span_attribute"
assert "token_test" not in json_doc["underscore_token"]
assert len(schemas.validate(schemas.DocJSONSchema, json_doc)) == 0
def test_doc_to_json_underscore_error_attr(doc):
@ -94,11 +225,29 @@ def test_doc_to_json_span(doc):
assert len(json_doc["spans"]) == 1
assert len(json_doc["spans"]["test"]) == 2
assert json_doc["spans"]["test"][0]["start"] == 0
assert not schemas.validate(schemas.DocJSONSchema, json_doc)
assert len(schemas.validate(schemas.DocJSONSchema, json_doc)) == 0
def test_json_to_doc(doc):
new_doc = Doc(doc.vocab).from_json(doc.to_json(), validate=True)
json_doc = doc.to_json()
json_doc = srsly.json_loads(srsly.json_dumps(json_doc))
new_doc = Doc(doc.vocab).from_json(json_doc, validate=True)
assert new_doc.text == doc.text == "c d e "
assert len(new_doc) == len(doc) == 3
assert new_doc[0].pos == doc[0].pos
assert new_doc[0].tag == doc[0].tag
assert new_doc[0].dep == doc[0].dep
assert new_doc[0].head.idx == doc[0].head.idx
assert new_doc[0].lemma == doc[0].lemma
assert len(new_doc.ents) == 1
assert new_doc.ents[0].start == 1
assert new_doc.ents[0].end == 2
assert new_doc.ents[0].label_ == "ORG"
assert doc.to_bytes() == new_doc.to_bytes()
def test_json_to_doc_compat(doc, doc_json):
new_doc = Doc(doc.vocab).from_json(doc_json, validate=True)
new_tokens = [token for token in new_doc]
assert new_doc.text == doc.text == "c d e "
assert len(new_tokens) == len([token for token in doc]) == 3
@ -114,11 +263,8 @@ def test_json_to_doc(doc):
def test_json_to_doc_underscore(doc):
if not Doc.has_extension("json_test1"):
Doc.set_extension("json_test1", default=False)
if not Doc.has_extension("json_test2"):
Doc.set_extension("json_test2", default=False)
doc._.json_test1 = "hello world"
doc._.json_test2 = [1, 2, 3]
json_doc = doc.to_json(underscore=["json_test1", "json_test2"])
@ -126,6 +272,34 @@ def test_json_to_doc_underscore(doc):
assert all([new_doc.has_extension(f"json_test{i}") for i in range(1, 3)])
assert new_doc._.json_test1 == "hello world"
assert new_doc._.json_test2 == [1, 2, 3]
assert doc.to_bytes() == new_doc.to_bytes()
def test_json_to_doc_with_token_span_attributes(doc):
Doc.set_extension("json_test1", default=False)
Doc.set_extension("json_test2", default=False)
Token.set_extension("token_test", default=False)
Span.set_extension("span_test", default=False)
doc._.json_test1 = "hello world"
doc._.json_test2 = [1, 2, 3]
doc[0:1]._.span_test = "span_attribute"
doc[0]._.token_test = 117
json_doc = doc.to_json(
underscore=["json_test1", "json_test2", "token_test", "span_test"]
)
json_doc = srsly.json_loads(srsly.json_dumps(json_doc))
new_doc = Doc(doc.vocab).from_json(json_doc, validate=True)
assert all([new_doc.has_extension(f"json_test{i}") for i in range(1, 3)])
assert new_doc._.json_test1 == "hello world"
assert new_doc._.json_test2 == [1, 2, 3]
assert new_doc[0]._.token_test == 117
assert new_doc[0:1]._.span_test == "span_attribute"
assert new_doc.user_data == doc.user_data
assert new_doc.to_bytes(exclude=["user_data"]) == doc.to_bytes(
exclude=["user_data"]
)
def test_json_to_doc_spans(doc):

View File

@ -1602,13 +1602,30 @@ cdef class Doc:
ents.append(char_span)
self.ents = ents
# Add custom attributes. Note that only Doc extensions are currently considered, Token and Span extensions are
# not yet supported.
# Add custom attributes for the whole Doc object.
for attr in doc_json.get("_", {}):
if not Doc.has_extension(attr):
Doc.set_extension(attr)
self._.set(attr, doc_json["_"][attr])
if doc_json.get("underscore_token", {}):
for token_attr in doc_json["underscore_token"]:
token_start = doc_json["underscore_token"][token_attr]["token_start"]
value = doc_json["underscore_token"][token_attr]["value"]
if not Token.has_extension(token_attr):
Token.set_extension(token_attr)
self[token_start]._.set(token_attr, value)
if doc_json.get("underscore_span", {}):
for span_attr in doc_json["underscore_span"]:
token_start = doc_json["underscore_span"][span_attr]["token_start"]
token_end = doc_json["underscore_span"][span_attr]["token_end"]
value = doc_json["underscore_span"][span_attr]["value"]
if not Span.has_extension(span_attr):
Span.set_extension(span_attr)
self[token_start:token_end]._.set(span_attr, value)
return self
def to_json(self, underscore=None):
@ -1650,20 +1667,40 @@ cdef class Doc:
for span_group in self.spans:
data["spans"][span_group] = []
for span in self.spans[span_group]:
span_data = {
"start": span.start_char, "end": span.end_char, "label": span.label_, "kb_id": span.kb_id_
}
span_data = {"start": span.start_char, "end": span.end_char, "label": span.label_, "kb_id": span.kb_id_}
data["spans"][span_group].append(span_data)
if underscore:
user_keys = set()
if self.user_data:
data["_"] = {}
for attr in underscore:
if not self.has_extension(attr):
raise ValueError(Errors.E106.format(attr=attr, opts=underscore))
value = self._.get(attr)
data["underscore_token"] = {}
data["underscore_span"] = {}
for data_key in self.user_data:
if type(data_key) == tuple and len(data_key) >= 4 and data_key[0] == "._.":
attr = data_key[1]
start = data_key[2]
end = data_key[3]
if attr in underscore:
user_keys.add(attr)
value = self.user_data[data_key]
if not srsly.is_json_serializable(value):
raise ValueError(Errors.E107.format(attr=attr, value=repr(value)))
# Check if doc attribute
if start is None:
data["_"][attr] = value
# Check if token attribute
elif end is None:
if attr not in data["underscore_token"]:
data["underscore_token"][attr] = {"token_start": start, "value": value}
# Else span attribute
else:
if attr not in data["underscore_span"]:
data["underscore_span"][attr] = {"token_start": start, "token_end": end, "value": value}
for attr in underscore:
if attr not in user_keys:
raise ValueError(Errors.E106.format(attr=attr, opts=underscore))
return data
def to_utf8_array(self, int nr_char=-1):