Serialize AttributeRuler.patterns

Serialize `AttributeRuler.patterns` instead of the individual lists to
simplify the serialized data and so that patterns are reloaded exactly as
they were originally provided (preserving `_attrs_unnormed`).
Adriane Boyd 2020-08-28 20:42:26 +02:00
parent 89f692bc8a
commit 8674b17651
2 changed files with 6 additions and 42 deletions
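
As a quick illustration of the intended behavior (not part of the diff): with the patterns serialized directly, a ruler reloaded from bytes reports exactly the patterns that were added. This is a minimal sketch assuming spaCy v3's `AttributeRuler` API and a blank English pipeline; the example pattern is made up.

# Minimal sketch, not from this commit: bytes round-trip preserves patterns.
import spacy
from spacy.pipeline import AttributeRuler

nlp = spacy.blank("en")
ruler = nlp.add_pipe("attribute_ruler")
ruler.add_patterns([{"patterns": [[{"ORTH": "the"}]], "attrs": {"LEMMA": "the"}}])

# The serialized data is now just `ruler.patterns`, so from_bytes can simply
# call add_patterns() and the round-trip reproduces the original input.
reloaded = AttributeRuler(nlp.vocab).from_bytes(ruler.to_bytes())
assert reloaded.patterns == ruler.patterns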

spacy/pipeline/attributeruler.py

@@ -230,10 +230,7 @@ class AttributeRuler(Pipe):
         """
         serialize = {}
         serialize["vocab"] = self.vocab.to_bytes
-        patterns = {k: self.matcher.get(k)[1] for k in range(len(self.attrs))}
-        serialize["patterns"] = lambda: srsly.msgpack_dumps(patterns)
-        serialize["attrs"] = lambda: srsly.msgpack_dumps(self.attrs)
-        serialize["indices"] = lambda: srsly.msgpack_dumps(self.indices)
+        serialize["patterns"] = lambda: srsly.msgpack_dumps(self.patterns)
         return util.to_bytes(serialize, exclude)
 
     def from_bytes(self, bytes_data: bytes, exclude: Iterable[str] = tuple()):
@@ -245,31 +242,15 @@ class AttributeRuler(Pipe):
         DOCS: https://spacy.io/api/attributeruler#from_bytes
         """
-        data = {"patterns": b""}
-
         def load_patterns(b):
-            data["patterns"] = srsly.msgpack_loads(b)
-
-        def load_attrs(b):
-            self.attrs = srsly.msgpack_loads(b)
-
-        def load_indices(b):
-            self.indices = srsly.msgpack_loads(b)
+            self.add_patterns(srsly.msgpack_loads(b))
 
         deserialize = {
             "vocab": lambda b: self.vocab.from_bytes(b),
             "patterns": load_patterns,
-            "attrs": load_attrs,
-            "indices": load_indices,
         }
         util.from_bytes(bytes_data, deserialize, exclude)
-
-        if data["patterns"]:
-            for key, pattern in data["patterns"].items():
-                self.matcher.add(key, pattern)
-            assert len(self.attrs) == len(data["patterns"])
-            assert len(self.indices) == len(data["patterns"])
         return self
 
     def to_disk(self, path: Union[Path, str], exclude: Iterable[str] = tuple()) -> None:
@@ -279,12 +260,9 @@ class AttributeRuler(Pipe):
         exclude (Iterable[str]): String names of serialization fields to exclude.
 
         DOCS: https://spacy.io/api/attributeruler#to_disk
         """
-        patterns = {k: self.matcher.get(k)[1] for k in range(len(self.attrs))}
         serialize = {
             "vocab": lambda p: self.vocab.to_disk(p),
-            "patterns": lambda p: srsly.write_msgpack(p, patterns),
-            "attrs": lambda p: srsly.write_msgpack(p, self.attrs),
-            "indices": lambda p: srsly.write_msgpack(p, self.indices),
+            "patterns": lambda p: srsly.write_msgpack(p, self.patterns),
         }
         util.to_disk(path, serialize, exclude)
@@ -297,31 +275,15 @@ class AttributeRuler(Pipe):
         exclude (Iterable[str]): String names of serialization fields to exclude.
 
         DOCS: https://spacy.io/api/attributeruler#from_disk
         """
-        data = {"patterns": b""}
-
         def load_patterns(p):
-            data["patterns"] = srsly.read_msgpack(p)
-
-        def load_attrs(p):
-            self.attrs = srsly.read_msgpack(p)
-
-        def load_indices(p):
-            self.indices = srsly.read_msgpack(p)
+            self.add_patterns(srsly.read_msgpack(p))
 
         deserialize = {
             "vocab": lambda p: self.vocab.from_disk(p),
             "patterns": load_patterns,
-            "attrs": load_attrs,
-            "indices": load_indices,
         }
         util.from_disk(path, deserialize, exclude)
-
-        if data["patterns"]:
-            for key, pattern in data["patterns"].items():
-                self.matcher.add(key, pattern)
-            assert len(self.attrs) == len(data["patterns"])
-            assert len(self.indices) == len(data["patterns"])
         return self
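
The disk round-trip works the same way. A hedged sketch under the same assumptions as above; the output directory name is illustrative only.

# Sketch only: to_disk() now writes the vocab plus a single "patterns" msgpack
# file, and from_disk() rebuilds the ruler by calling add_patterns().
from pathlib import Path
import spacy
from spacy.pipeline import AttributeRuler

nlp = spacy.blank("en")
ruler = nlp.add_pipe("attribute_ruler")
ruler.add_patterns([{"patterns": [[{"ORTH": "a"}]], "attrs": {"LEMMA": "a"}}])

out_dir = Path("attribute_ruler_data")  # hypothetical directory
ruler.to_disk(out_dir)
reloaded = AttributeRuler(nlp.vocab).from_disk(out_dir)
assert reloaded.patterns == ruler.patterns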

spacy/tests/pipeline/test_attributeruler.py

@@ -215,6 +215,7 @@ def test_attributeruler_serialize(nlp, pattern_dicts):
     assert a.to_bytes() == a_reloaded.to_bytes()
     doc1 = a_reloaded(nlp.make_doc(text))
     numpy.array_equal(doc.to_array(attrs), doc1.to_array(attrs))
+    assert a.patterns == a_reloaded.patterns
 
     # disk roundtrip
     with make_tempdir() as tmp_dir:
@@ -223,3 +224,4 @@ def test_attributeruler_serialize(nlp, pattern_dicts):
         doc2 = nlp2(text)
         assert nlp2.get_pipe("attribute_ruler").to_bytes() == a.to_bytes()
         assert numpy.array_equal(doc.to_array(attrs), doc2.to_array(attrs))
+        assert a.patterns == nlp2.get_pipe("attribute_ruler").patterns