diff --git a/spacy/errors.py b/spacy/errors.py index a1420c8fc..608305a06 100644 --- a/spacy/errors.py +++ b/spacy/errors.py @@ -389,7 +389,7 @@ class Errors(metaclass=ErrorsWithCodes): "consider using doc.spans instead.") E106 = ("Can't find `doc._.{attr}` attribute specified in the underscore " "settings: {opts}") - E107 = ("Value of `doc._.{attr}` is not JSON-serializable: {value}") + E107 = ("Value of custom attribute `{attr}` is not JSON-serializable: {value}") E109 = ("Component '{name}' could not be run. Did you forget to " "call `initialize()`?") E110 = ("Invalid displaCy render wrapper. Expected callable, got: {obj}") diff --git a/spacy/schemas.py b/spacy/schemas.py index 9f91451a9..048082134 100644 --- a/spacy/schemas.py +++ b/spacy/schemas.py @@ -514,6 +514,14 @@ class DocJSONSchema(BaseModel): tokens: List[Dict[StrictStr, Union[StrictStr, StrictInt]]] = Field( ..., title="Token information - ID, start, annotations" ) - _: Optional[Dict[StrictStr, Any]] = Field( - None, title="Any custom data stored in the document's _ attribute" + underscore_doc: Optional[Dict[StrictStr, Any]] = Field( + None, + title="Any custom data stored in the document's _ attribute", + alias="_", + ) + underscore_token: Optional[Dict[StrictStr, Dict[StrictStr, Any]]] = Field( + None, title="Any custom data stored in the token's _ attribute" + ) + underscore_span: Optional[Dict[StrictStr, Dict[StrictStr, Any]]] = Field( + None, title="Any custom data stored in the span's _ attribute" ) diff --git a/spacy/tests/doc/test_json_doc_conversion.py b/spacy/tests/doc/test_json_doc_conversion.py index 85e4def29..0d7c061c9 100644 --- a/spacy/tests/doc/test_json_doc_conversion.py +++ b/spacy/tests/doc/test_json_doc_conversion.py @@ -1,12 +1,15 @@ import pytest import spacy from spacy import schemas -from spacy.tokens import Doc, Span +from spacy.tokens import Doc, Span, Token +import srsly +from .test_underscore import clean_underscore # noqa: F401 @pytest.fixture() def doc(en_vocab): words = ["c", "d", "e"] + spaces = [True, True, True] pos = ["VERB", "NOUN", "NOUN"] tags = ["VBP", "NN", "NN"] heads = [0, 0, 1] @@ -17,6 +20,7 @@ def doc(en_vocab): return Doc( en_vocab, words=words, + spaces=spaces, pos=pos, tags=tags, heads=heads, @@ -45,6 +49,47 @@ def doc_without_deps(en_vocab): ) +@pytest.fixture() +def doc_json(): + return { + "text": "c d e ", + "ents": [{"start": 2, "end": 3, "label": "ORG"}], + "sents": [{"start": 0, "end": 5}], + "tokens": [ + { + "id": 0, + "start": 0, + "end": 1, + "tag": "VBP", + "pos": "VERB", + "morph": "Feat1=A", + "dep": "ROOT", + "head": 0, + }, + { + "id": 1, + "start": 2, + "end": 3, + "tag": "NN", + "pos": "NOUN", + "morph": "Feat1=B", + "dep": "dobj", + "head": 0, + }, + { + "id": 2, + "start": 4, + "end": 5, + "tag": "NN", + "pos": "NOUN", + "morph": "Feat1=A|Feat2=D", + "dep": "dobj", + "head": 1, + }, + ], + } + + def test_doc_to_json(doc): json_doc = doc.to_json() assert json_doc["text"] == "c d e " @@ -56,7 +101,8 @@ def test_doc_to_json(doc): assert json_doc["ents"][0]["start"] == 2 # character offset! assert json_doc["ents"][0]["end"] == 3 # character offset! assert json_doc["ents"][0]["label"] == "ORG" - assert not schemas.validate(schemas.DocJSONSchema, json_doc) + assert len(schemas.validate(schemas.DocJSONSchema, json_doc)) == 0 + assert srsly.json_loads(srsly.json_dumps(json_doc)) == json_doc def test_doc_to_json_underscore(doc): @@ -64,11 +110,96 @@ def test_doc_to_json_underscore(doc): Doc.set_extension("json_test2", default=False) doc._.json_test1 = "hello world" doc._.json_test2 = [1, 2, 3] + json_doc = doc.to_json(underscore=["json_test1", "json_test2"]) assert "_" in json_doc assert json_doc["_"]["json_test1"] == "hello world" assert json_doc["_"]["json_test2"] == [1, 2, 3] - assert not schemas.validate(schemas.DocJSONSchema, json_doc) + assert len(schemas.validate(schemas.DocJSONSchema, json_doc)) == 0 + assert srsly.json_loads(srsly.json_dumps(json_doc)) == json_doc + + +def test_doc_to_json_with_token_span_attributes(doc): + Doc.set_extension("json_test1", default=False) + Doc.set_extension("json_test2", default=False) + Token.set_extension("token_test", default=False) + Span.set_extension("span_test", default=False) + + doc._.json_test1 = "hello world" + doc._.json_test2 = [1, 2, 3] + doc[0:1]._.span_test = "span_attribute" + doc[0]._.token_test = 117 + doc.spans["span_group"] = [doc[0:1]] + json_doc = doc.to_json( + underscore=["json_test1", "json_test2", "token_test", "span_test"] + ) + + assert "_" in json_doc + assert json_doc["_"]["json_test1"] == "hello world" + assert json_doc["_"]["json_test2"] == [1, 2, 3] + assert "underscore_token" in json_doc + assert "underscore_span" in json_doc + assert json_doc["underscore_token"]["token_test"]["value"] == 117 + assert json_doc["underscore_span"]["span_test"]["value"] == "span_attribute" + assert len(schemas.validate(schemas.DocJSONSchema, json_doc)) == 0 + assert srsly.json_loads(srsly.json_dumps(json_doc)) == json_doc + + +def test_doc_to_json_with_custom_user_data(doc): + Doc.set_extension("json_test", default=False) + Token.set_extension("token_test", default=False) + Span.set_extension("span_test", default=False) + + doc._.json_test = "hello world" + doc[0:1]._.span_test = "span_attribute" + doc[0]._.token_test = 117 + json_doc = doc.to_json(underscore=["json_test", "token_test", "span_test"]) + doc.user_data["user_data_test"] = 10 + doc.user_data[("user_data_test2", True)] = 10 + + assert "_" in json_doc + assert json_doc["_"]["json_test"] == "hello world" + assert "underscore_token" in json_doc + assert "underscore_span" in json_doc + assert json_doc["underscore_token"]["token_test"]["value"] == 117 + assert json_doc["underscore_span"]["span_test"]["value"] == "span_attribute" + assert len(schemas.validate(schemas.DocJSONSchema, json_doc)) == 0 + assert srsly.json_loads(srsly.json_dumps(json_doc)) == json_doc + + +def test_doc_to_json_with_token_span_same_identifier(doc): + Doc.set_extension("my_ext", default=False) + Token.set_extension("my_ext", default=False) + Span.set_extension("my_ext", default=False) + + doc._.my_ext = "hello world" + doc[0:1]._.my_ext = "span_attribute" + doc[0]._.my_ext = 117 + json_doc = doc.to_json(underscore=["my_ext"]) + + assert "_" in json_doc + assert json_doc["_"]["my_ext"] == "hello world" + assert "underscore_token" in json_doc + assert "underscore_span" in json_doc + assert json_doc["underscore_token"]["my_ext"]["value"] == 117 + assert json_doc["underscore_span"]["my_ext"]["value"] == "span_attribute" + assert len(schemas.validate(schemas.DocJSONSchema, json_doc)) == 0 + assert srsly.json_loads(srsly.json_dumps(json_doc)) == json_doc + + +def test_doc_to_json_with_token_attributes_missing(doc): + Token.set_extension("token_test", default=False) + Span.set_extension("span_test", default=False) + + doc[0:1]._.span_test = "span_attribute" + doc[0]._.token_test = 117 + json_doc = doc.to_json(underscore=["span_test"]) + + assert "underscore_token" in json_doc + assert "underscore_span" in json_doc + assert json_doc["underscore_span"]["span_test"]["value"] == "span_attribute" + assert "token_test" not in json_doc["underscore_token"] + assert len(schemas.validate(schemas.DocJSONSchema, json_doc)) == 0 def test_doc_to_json_underscore_error_attr(doc): @@ -94,11 +225,29 @@ def test_doc_to_json_span(doc): assert len(json_doc["spans"]) == 1 assert len(json_doc["spans"]["test"]) == 2 assert json_doc["spans"]["test"][0]["start"] == 0 - assert not schemas.validate(schemas.DocJSONSchema, json_doc) + assert len(schemas.validate(schemas.DocJSONSchema, json_doc)) == 0 def test_json_to_doc(doc): - new_doc = Doc(doc.vocab).from_json(doc.to_json(), validate=True) + json_doc = doc.to_json() + json_doc = srsly.json_loads(srsly.json_dumps(json_doc)) + new_doc = Doc(doc.vocab).from_json(json_doc, validate=True) + assert new_doc.text == doc.text == "c d e " + assert len(new_doc) == len(doc) == 3 + assert new_doc[0].pos == doc[0].pos + assert new_doc[0].tag == doc[0].tag + assert new_doc[0].dep == doc[0].dep + assert new_doc[0].head.idx == doc[0].head.idx + assert new_doc[0].lemma == doc[0].lemma + assert len(new_doc.ents) == 1 + assert new_doc.ents[0].start == 1 + assert new_doc.ents[0].end == 2 + assert new_doc.ents[0].label_ == "ORG" + assert doc.to_bytes() == new_doc.to_bytes() + + +def test_json_to_doc_compat(doc, doc_json): + new_doc = Doc(doc.vocab).from_json(doc_json, validate=True) new_tokens = [token for token in new_doc] assert new_doc.text == doc.text == "c d e " assert len(new_tokens) == len([token for token in doc]) == 3 @@ -114,11 +263,8 @@ def test_json_to_doc(doc): def test_json_to_doc_underscore(doc): - if not Doc.has_extension("json_test1"): - Doc.set_extension("json_test1", default=False) - if not Doc.has_extension("json_test2"): - Doc.set_extension("json_test2", default=False) - + Doc.set_extension("json_test1", default=False) + Doc.set_extension("json_test2", default=False) doc._.json_test1 = "hello world" doc._.json_test2 = [1, 2, 3] json_doc = doc.to_json(underscore=["json_test1", "json_test2"]) @@ -126,6 +272,34 @@ def test_json_to_doc_underscore(doc): assert all([new_doc.has_extension(f"json_test{i}") for i in range(1, 3)]) assert new_doc._.json_test1 == "hello world" assert new_doc._.json_test2 == [1, 2, 3] + assert doc.to_bytes() == new_doc.to_bytes() + + +def test_json_to_doc_with_token_span_attributes(doc): + Doc.set_extension("json_test1", default=False) + Doc.set_extension("json_test2", default=False) + Token.set_extension("token_test", default=False) + Span.set_extension("span_test", default=False) + doc._.json_test1 = "hello world" + doc._.json_test2 = [1, 2, 3] + doc[0:1]._.span_test = "span_attribute" + doc[0]._.token_test = 117 + + json_doc = doc.to_json( + underscore=["json_test1", "json_test2", "token_test", "span_test"] + ) + json_doc = srsly.json_loads(srsly.json_dumps(json_doc)) + new_doc = Doc(doc.vocab).from_json(json_doc, validate=True) + + assert all([new_doc.has_extension(f"json_test{i}") for i in range(1, 3)]) + assert new_doc._.json_test1 == "hello world" + assert new_doc._.json_test2 == [1, 2, 3] + assert new_doc[0]._.token_test == 117 + assert new_doc[0:1]._.span_test == "span_attribute" + assert new_doc.user_data == doc.user_data + assert new_doc.to_bytes(exclude=["user_data"]) == doc.to_bytes( + exclude=["user_data"] + ) def test_json_to_doc_spans(doc): diff --git a/spacy/tokens/doc.pyx b/spacy/tokens/doc.pyx index d9a104ac8..7ba9a3341 100644 --- a/spacy/tokens/doc.pyx +++ b/spacy/tokens/doc.pyx @@ -1602,13 +1602,30 @@ cdef class Doc: ents.append(char_span) self.ents = ents - # Add custom attributes. Note that only Doc extensions are currently considered, Token and Span extensions are - # not yet supported. + # Add custom attributes for the whole Doc object. for attr in doc_json.get("_", {}): if not Doc.has_extension(attr): Doc.set_extension(attr) self._.set(attr, doc_json["_"][attr]) + if doc_json.get("underscore_token", {}): + for token_attr in doc_json["underscore_token"]: + token_start = doc_json["underscore_token"][token_attr]["token_start"] + value = doc_json["underscore_token"][token_attr]["value"] + + if not Token.has_extension(token_attr): + Token.set_extension(token_attr) + self[token_start]._.set(token_attr, value) + + if doc_json.get("underscore_span", {}): + for span_attr in doc_json["underscore_span"]: + token_start = doc_json["underscore_span"][span_attr]["token_start"] + token_end = doc_json["underscore_span"][span_attr]["token_end"] + value = doc_json["underscore_span"][span_attr]["value"] + + if not Span.has_extension(span_attr): + Span.set_extension(span_attr) + self[token_start:token_end]._.set(span_attr, value) return self def to_json(self, underscore=None): @@ -1650,20 +1667,40 @@ cdef class Doc: for span_group in self.spans: data["spans"][span_group] = [] for span in self.spans[span_group]: - span_data = { - "start": span.start_char, "end": span.end_char, "label": span.label_, "kb_id": span.kb_id_ - } + span_data = {"start": span.start_char, "end": span.end_char, "label": span.label_, "kb_id": span.kb_id_} data["spans"][span_group].append(span_data) if underscore: - data["_"] = {} + user_keys = set() + if self.user_data: + data["_"] = {} + data["underscore_token"] = {} + data["underscore_span"] = {} + for data_key in self.user_data: + if type(data_key) == tuple and len(data_key) >= 4 and data_key[0] == "._.": + attr = data_key[1] + start = data_key[2] + end = data_key[3] + if attr in underscore: + user_keys.add(attr) + value = self.user_data[data_key] + if not srsly.is_json_serializable(value): + raise ValueError(Errors.E107.format(attr=attr, value=repr(value))) + # Check if doc attribute + if start is None: + data["_"][attr] = value + # Check if token attribute + elif end is None: + if attr not in data["underscore_token"]: + data["underscore_token"][attr] = {"token_start": start, "value": value} + # Else span attribute + else: + if attr not in data["underscore_span"]: + data["underscore_span"][attr] = {"token_start": start, "token_end": end, "value": value} + for attr in underscore: - if not self.has_extension(attr): + if attr not in user_keys: raise ValueError(Errors.E106.format(attr=attr, opts=underscore)) - value = self._.get(attr) - if not srsly.is_json_serializable(value): - raise ValueError(Errors.E107.format(attr=attr, value=repr(value))) - data["_"][attr] = value return data def to_utf8_array(self, int nr_char=-1):