From fe3a4aa846dcfde51eb3343cd04560da2b5ba705 Mon Sep 17 00:00:00 2001
From: Adriane Boyd
Date: Mon, 17 May 2021 10:06:11 +0200
Subject: [PATCH] Add ENT_ID and NORM to DocBin strings (#8054)

Save strings for token attributes `ENT_ID` and `NORM` in `DocBin`
strings.
---
 spacy/tests/serialize/test_serialize_doc.py | 6 +++++-
 spacy/tokens/_serialize.py                  | 2 ++
 2 files changed, 7 insertions(+), 1 deletion(-)

diff --git a/spacy/tests/serialize/test_serialize_doc.py b/spacy/tests/serialize/test_serialize_doc.py
index 837c128af..5ce2549aa 100644
--- a/spacy/tests/serialize/test_serialize_doc.py
+++ b/spacy/tests/serialize/test_serialize_doc.py
@@ -64,13 +64,15 @@ def test_serialize_doc_span_groups(en_vocab):
 
 
 def test_serialize_doc_bin():
-    doc_bin = DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE"], store_user_data=True)
+    doc_bin = DocBin(attrs=["LEMMA", "ENT_IOB", "ENT_TYPE", "NORM", "ENT_ID"], store_user_data=True)
     texts = ["Some text", "Lots of texts...", "..."]
     cats = {"A": 0.5}
     nlp = English()
     for doc in nlp.pipe(texts):
         doc.cats = cats
         doc.spans["start"] = [doc[0:2]]
+        doc[0].norm_ = "UNUSUAL_TOKEN_NORM"
+        doc[0].ent_id_ = "UNUSUAL_TOKEN_ENT_ID"
         doc_bin.add(doc)
     bytes_data = doc_bin.to_bytes()
 
@@ -82,6 +84,8 @@ def test_serialize_doc_bin():
         assert doc.text == texts[i]
         assert doc.cats == cats
         assert len(doc.spans) == 1
+        assert doc[0].norm_ == "UNUSUAL_TOKEN_NORM"
+        assert doc[0].ent_id_ == "UNUSUAL_TOKEN_ENT_ID"
 
 
 def test_serialize_doc_bin_unknown_spaces(en_vocab):
diff --git a/spacy/tokens/_serialize.py b/spacy/tokens/_serialize.py
index d5b4e4ff7..e3251c29e 100644
--- a/spacy/tokens/_serialize.py
+++ b/spacy/tokens/_serialize.py
@@ -103,10 +103,12 @@ class DocBin:
             self.strings.add(token.text)
             self.strings.add(token.tag_)
             self.strings.add(token.lemma_)
+            self.strings.add(token.norm_)
             self.strings.add(str(token.morph))
             self.strings.add(token.dep_)
             self.strings.add(token.ent_type_)
             self.strings.add(token.ent_kb_id_)
+            self.strings.add(token.ent_id_)
         self.cats.append(doc.cats)
         self.user_data.append(srsly.msgpack_dumps(doc.user_data))
         self.span_groups.append(doc.spans.to_bytes())
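
Note (not part of the patch): the sketch below is a minimal round-trip illustrating why the fix matters. Token attributes such as NORM and ENT_ID are stored as hashes into the StringStore, so when a DocBin is deserialized into a fresh vocab, those hashes only resolve back to text if the DocBin also carried the strings, which is exactly what this change adds. Assumes a spaCy 3.x build that includes this patch.

# Minimal round-trip sketch (assumes spaCy 3.x with this patch applied).
from spacy.lang.en import English
from spacy.tokens import DocBin

nlp = English()
doc = nlp("Some text")
doc[0].norm_ = "UNUSUAL_TOKEN_NORM"      # stored on the token as a hash
doc[0].ent_id_ = "UNUSUAL_TOKEN_ENT_ID"  # likewise a hash into the StringStore

doc_bin = DocBin(attrs=["NORM", "ENT_ID"])
doc_bin.add(doc)
bytes_data = doc_bin.to_bytes()

# Deserialize into a fresh vocab: the NORM/ENT_ID hashes resolve back to
# text only because DocBin now saves the corresponding strings as well.
reloaded = list(DocBin().from_bytes(bytes_data).get_docs(English().vocab))
assert reloaded[0][0].norm_ == "UNUSUAL_TOKEN_NORM"
assert reloaded[0][0].ent_id_ == "UNUSUAL_TOKEN_ENT_ID"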