mirror of
https://github.com/explosion/spaCy.git
synced 2025-01-12 02:06:31 +03:00
Add strings and ENT_KB_ID to Doc serialization (#5691)
* Add strings for all writeable Token attributes to `Doc.to/from_bytes()`. * Add ENT_KB_ID to default attributes.
This commit is contained in:
parent
971826a96d
commit
a77c4c3465
|
@ -106,10 +106,16 @@ def test_doc_api_getitem(en_tokenizer):
|
||||||
)
|
)
|
||||||
def test_doc_api_serialize(en_tokenizer, text):
|
def test_doc_api_serialize(en_tokenizer, text):
|
||||||
tokens = en_tokenizer(text)
|
tokens = en_tokenizer(text)
|
||||||
|
tokens[0].lemma_ = "lemma"
|
||||||
|
tokens[0].norm_ = "norm"
|
||||||
|
tokens[0].ent_kb_id_ = "ent_kb_id"
|
||||||
new_tokens = Doc(tokens.vocab).from_bytes(tokens.to_bytes())
|
new_tokens = Doc(tokens.vocab).from_bytes(tokens.to_bytes())
|
||||||
assert tokens.text == new_tokens.text
|
assert tokens.text == new_tokens.text
|
||||||
assert [t.text for t in tokens] == [t.text for t in new_tokens]
|
assert [t.text for t in tokens] == [t.text for t in new_tokens]
|
||||||
assert [t.orth for t in tokens] == [t.orth for t in new_tokens]
|
assert [t.orth for t in tokens] == [t.orth for t in new_tokens]
|
||||||
|
assert new_tokens[0].lemma_ == "lemma"
|
||||||
|
assert new_tokens[0].norm_ == "norm"
|
||||||
|
assert new_tokens[0].ent_kb_id_ == "ent_kb_id"
|
||||||
|
|
||||||
new_tokens = Doc(tokens.vocab).from_bytes(
|
new_tokens = Doc(tokens.vocab).from_bytes(
|
||||||
tokens.to_bytes(exclude=["tensor"]), exclude=["tensor"]
|
tokens.to_bytes(exclude=["tensor"]), exclude=["tensor"]
|
||||||
|
|
|
@ -892,7 +892,7 @@ cdef class Doc:
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/doc#to_bytes
|
DOCS: https://spacy.io/api/doc#to_bytes
|
||||||
"""
|
"""
|
||||||
array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE, ENT_ID, NORM] # TODO: ENT_KB_ID ?
|
array_head = [LENGTH, SPACY, LEMMA, ENT_IOB, ENT_TYPE, ENT_ID, NORM, ENT_KB_ID]
|
||||||
if self.is_tagged:
|
if self.is_tagged:
|
||||||
array_head.extend([TAG, POS])
|
array_head.extend([TAG, POS])
|
||||||
# If doc parsed add head and dep attribute
|
# If doc parsed add head and dep attribute
|
||||||
|
@ -901,6 +901,14 @@ cdef class Doc:
|
||||||
# Otherwise add sent_start
|
# Otherwise add sent_start
|
||||||
else:
|
else:
|
||||||
array_head.append(SENT_START)
|
array_head.append(SENT_START)
|
||||||
|
strings = set()
|
||||||
|
for token in self:
|
||||||
|
strings.add(token.tag_)
|
||||||
|
strings.add(token.lemma_)
|
||||||
|
strings.add(token.dep_)
|
||||||
|
strings.add(token.ent_type_)
|
||||||
|
strings.add(token.ent_kb_id_)
|
||||||
|
strings.add(token.norm_)
|
||||||
# Msgpack doesn't distinguish between lists and tuples, which is
|
# Msgpack doesn't distinguish between lists and tuples, which is
|
||||||
# vexing for user data. As a best guess, we *know* that within
|
# vexing for user data. As a best guess, we *know* that within
|
||||||
# keys, we must have tuples. In values we just have to hope
|
# keys, we must have tuples. In values we just have to hope
|
||||||
|
@ -912,6 +920,7 @@ cdef class Doc:
|
||||||
"sentiment": lambda: self.sentiment,
|
"sentiment": lambda: self.sentiment,
|
||||||
"tensor": lambda: self.tensor,
|
"tensor": lambda: self.tensor,
|
||||||
"cats": lambda: self.cats,
|
"cats": lambda: self.cats,
|
||||||
|
"strings": lambda: list(strings),
|
||||||
}
|
}
|
||||||
for key in kwargs:
|
for key in kwargs:
|
||||||
if key in serializers or key in ("user_data", "user_data_keys", "user_data_values"):
|
if key in serializers or key in ("user_data", "user_data_keys", "user_data_values"):
|
||||||
|
@ -942,6 +951,7 @@ cdef class Doc:
|
||||||
"sentiment": lambda b: None,
|
"sentiment": lambda b: None,
|
||||||
"tensor": lambda b: None,
|
"tensor": lambda b: None,
|
||||||
"cats": lambda b: None,
|
"cats": lambda b: None,
|
||||||
|
"strings": lambda b: None,
|
||||||
"user_data_keys": lambda b: None,
|
"user_data_keys": lambda b: None,
|
||||||
"user_data_values": lambda b: None,
|
"user_data_values": lambda b: None,
|
||||||
}
|
}
|
||||||
|
@ -965,6 +975,9 @@ cdef class Doc:
|
||||||
self.tensor = msg["tensor"]
|
self.tensor = msg["tensor"]
|
||||||
if "cats" not in exclude and "cats" in msg:
|
if "cats" not in exclude and "cats" in msg:
|
||||||
self.cats = msg["cats"]
|
self.cats = msg["cats"]
|
||||||
|
if "strings" not in exclude and "strings" in msg:
|
||||||
|
for s in msg["strings"]:
|
||||||
|
self.vocab.strings.add(s)
|
||||||
start = 0
|
start = 0
|
||||||
cdef const LexemeC* lex
|
cdef const LexemeC* lex
|
||||||
cdef unicode orth_
|
cdef unicode orth_
|
||||||
|
|
Loading…
Reference in New Issue
Block a user