mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Restore Doc attr getter values in Doc.to_json (#11700)
This commit is contained in:
parent
db56600536
commit
40e1000db0
|
@ -370,3 +370,12 @@ def test_json_to_doc_validation_error(doc):
|
||||||
doc_json.pop("tokens")
|
doc_json.pop("tokens")
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
Doc(doc.vocab).from_json(doc_json, validate=True)
|
Doc(doc.vocab).from_json(doc_json, validate=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_to_json_underscore_doc_getters(doc):
|
||||||
|
def get_text_length(doc):
|
||||||
|
return len(doc.text)
|
||||||
|
|
||||||
|
Doc.set_extension("text_length", getter=get_text_length)
|
||||||
|
doc_json = doc.to_json(underscore=["text_length"])
|
||||||
|
assert doc_json["_"]["text_length"] == get_text_length(doc)
|
||||||
|
|
|
@ -1668,6 +1668,20 @@ cdef class Doc:
|
||||||
|
|
||||||
if underscore:
|
if underscore:
|
||||||
user_keys = set()
|
user_keys = set()
|
||||||
|
# Handle doc attributes with .get to include values from getters
|
||||||
|
# and not only values stored in user_data, for backwards
|
||||||
|
# compatibility
|
||||||
|
for attr in underscore:
|
||||||
|
if self.has_extension(attr):
|
||||||
|
if "_" not in data:
|
||||||
|
data["_"] = {}
|
||||||
|
value = self._.get(attr)
|
||||||
|
if not srsly.is_json_serializable(value):
|
||||||
|
raise ValueError(Errors.E107.format(attr=attr, value=repr(value)))
|
||||||
|
data["_"][attr] = value
|
||||||
|
user_keys.add(attr)
|
||||||
|
# Token and span attributes only include values stored in user_data
|
||||||
|
# and not values generated by getters
|
||||||
if self.user_data:
|
if self.user_data:
|
||||||
for data_key, value in self.user_data.copy().items():
|
for data_key, value in self.user_data.copy().items():
|
||||||
if type(data_key) == tuple and len(data_key) >= 4 and data_key[0] == "._.":
|
if type(data_key) == tuple and len(data_key) >= 4 and data_key[0] == "._.":
|
||||||
|
@ -1678,20 +1692,15 @@ cdef class Doc:
|
||||||
user_keys.add(attr)
|
user_keys.add(attr)
|
||||||
if not srsly.is_json_serializable(value):
|
if not srsly.is_json_serializable(value):
|
||||||
raise ValueError(Errors.E107.format(attr=attr, value=repr(value)))
|
raise ValueError(Errors.E107.format(attr=attr, value=repr(value)))
|
||||||
# Check if doc attribute
|
# Token attribute
|
||||||
if start is None:
|
if start is not None and end is None:
|
||||||
if "_" not in data:
|
|
||||||
data["_"] = {}
|
|
||||||
data["_"][attr] = value
|
|
||||||
# Check if token attribute
|
|
||||||
elif end is None:
|
|
||||||
if "underscore_token" not in data:
|
if "underscore_token" not in data:
|
||||||
data["underscore_token"] = {}
|
data["underscore_token"] = {}
|
||||||
if attr not in data["underscore_token"]:
|
if attr not in data["underscore_token"]:
|
||||||
data["underscore_token"][attr] = []
|
data["underscore_token"][attr] = []
|
||||||
data["underscore_token"][attr].append({"start": start, "value": value})
|
data["underscore_token"][attr].append({"start": start, "value": value})
|
||||||
# Else span attribute
|
# Span attribute
|
||||||
else:
|
elif start is not None and end is not None:
|
||||||
if "underscore_span" not in data:
|
if "underscore_span" not in data:
|
||||||
data["underscore_span"] = {}
|
data["underscore_span"] = {}
|
||||||
if attr not in data["underscore_span"]:
|
if attr not in data["underscore_span"]:
|
||||||
|
|
Loading…
Reference in New Issue
Block a user