Pass excludes when serializing vocab (#8824)

* Pass excludes when serializing vocab

Additional minor bug fix:

* Deserialize vocab in `EntityLinker.from_disk`

* Add test for excluding strings on load

* Fix formatting
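
In effect, excludes passed to top-level serialization calls are now forwarded to the vocab. A minimal usage sketch (not part of the commit; the path and word are placeholder values), mirroring the new test below:

import spacy

nlp = spacy.blank("en")
nlp.vocab.strings.add("some_unusual_word")
nlp.to_disk("/tmp/pipeline")

# The exclude list now reaches Vocab.from_disk, so string-store
# entries are genuinely skipped when excluded on load.
reloaded = spacy.load("/tmp/pipeline", exclude=["strings"])
assert "some_unusual_word" not in reloaded.vocab.strings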
Adriane Boyd 2021-08-03 14:42:44 +02:00 committed by GitHub
parent 175847f92c
commit 941a591f3c
8 changed files with 45 additions and 26 deletions

spacy/language.py

@@ -1909,7 +1909,7 @@ class Language:
             if not hasattr(proc, "to_disk"):
                 continue
             serializers[name] = lambda p, proc=proc: proc.to_disk(p, exclude=["vocab"])
-        serializers["vocab"] = lambda p: self.vocab.to_disk(p)
+        serializers["vocab"] = lambda p: self.vocab.to_disk(p, exclude=exclude)
         util.to_disk(path, serializers, exclude)
 
     def from_disk(
@@ -1940,7 +1940,7 @@ class Language:
         def deserialize_vocab(path: Path) -> None:
             if path.exists():
-                self.vocab.from_disk(path)
+                self.vocab.from_disk(path, exclude=exclude)
 
         path = util.ensure_path(path)
         deserializers = {}
@@ -1978,7 +1978,7 @@ class Language:
         DOCS: https://spacy.io/api/language#to_bytes
         """
         serializers = {}
-        serializers["vocab"] = lambda: self.vocab.to_bytes()
+        serializers["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
         serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"])
         serializers["meta.json"] = lambda: srsly.json_dumps(self.meta)
         serializers["config.cfg"] = lambda: self.config.to_bytes()
@@ -2014,7 +2014,7 @@ class Language:
             b, interpolate=False
         )
         deserializers["meta.json"] = deserialize_meta
-        deserializers["vocab"] = self.vocab.from_bytes
+        deserializers["vocab"] = lambda b: self.vocab.from_bytes(b, exclude=exclude)
         deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(
             b, exclude=["vocab"]
         )
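
For context, the hunks above all follow spaCy's serializer-dict pattern: a dict of callbacks is filtered against the exclude list by util.to_disk/to_bytes. A simplified sketch of that pattern (an illustration of the assumed behavior, not spaCy's exact util code) shows why the fix is needed: the outer exclude can only drop whole top-level keys such as "vocab", so a nested key like "strings" has no effect unless it is also forwarded into Vocab.to_bytes, which is what each change above does:

import srsly

def to_bytes(getters, exclude):
    # Run every getter whose key is not excluded; nested excludes
    # must be handled inside the getters themselves.
    serialized = {key: getter() for key, getter in getters.items()
                  if key not in exclude}
    return srsly.msgpack_dumps(serialized)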

spacy/pipeline/attributeruler.py

@@ -276,7 +276,7 @@ class AttributeRuler(Pipe):
         DOCS: https://spacy.io/api/attributeruler#to_bytes
         """
         serialize = {}
-        serialize["vocab"] = self.vocab.to_bytes
+        serialize["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
         serialize["patterns"] = lambda: srsly.msgpack_dumps(self.patterns)
         return util.to_bytes(serialize, exclude)
@@ -296,7 +296,7 @@ class AttributeRuler(Pipe):
             self.add_patterns(srsly.msgpack_loads(b))
 
         deserialize = {
-            "vocab": lambda b: self.vocab.from_bytes(b),
+            "vocab": lambda b: self.vocab.from_bytes(b, exclude=exclude),
             "patterns": load_patterns,
         }
         util.from_bytes(bytes_data, deserialize, exclude)
@@ -313,7 +313,7 @@ class AttributeRuler(Pipe):
         DOCS: https://spacy.io/api/attributeruler#to_disk
         """
         serialize = {
-            "vocab": lambda p: self.vocab.to_disk(p),
+            "vocab": lambda p: self.vocab.to_disk(p, exclude=exclude),
            "patterns": lambda p: srsly.write_msgpack(p, self.patterns),
         }
         util.to_disk(path, serialize, exclude)
@@ -334,7 +334,7 @@ class AttributeRuler(Pipe):
             self.add_patterns(srsly.read_msgpack(p))
 
         deserialize = {
-            "vocab": lambda p: self.vocab.from_disk(p),
+            "vocab": lambda p: self.vocab.from_disk(p, exclude=exclude),
             "patterns": load_patterns,
         }
         util.from_disk(path, deserialize, exclude)

spacy/pipeline/entity_linker.py

@@ -412,7 +412,7 @@ class EntityLinker(TrainablePipe):
         serialize = {}
         if hasattr(self, "cfg") and self.cfg is not None:
             serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
-        serialize["vocab"] = self.vocab.to_bytes
+        serialize["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
         serialize["kb"] = self.kb.to_bytes
         serialize["model"] = self.model.to_bytes
         return util.to_bytes(serialize, exclude)
@@ -436,7 +436,7 @@ class EntityLinker(TrainablePipe):
         deserialize = {}
         if hasattr(self, "cfg") and self.cfg is not None:
             deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
-        deserialize["vocab"] = lambda b: self.vocab.from_bytes(b)
+        deserialize["vocab"] = lambda b: self.vocab.from_bytes(b, exclude=exclude)
         deserialize["kb"] = lambda b: self.kb.from_bytes(b)
         deserialize["model"] = load_model
         util.from_bytes(bytes_data, deserialize, exclude)
@@ -453,7 +453,7 @@ class EntityLinker(TrainablePipe):
         DOCS: https://spacy.io/api/entitylinker#to_disk
         """
         serialize = {}
-        serialize["vocab"] = lambda p: self.vocab.to_disk(p)
+        serialize["vocab"] = lambda p: self.vocab.to_disk(p, exclude=exclude)
         serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
         serialize["kb"] = lambda p: self.kb.to_disk(p)
         serialize["model"] = lambda p: self.model.to_disk(p)
@@ -480,6 +480,7 @@ class EntityLinker(TrainablePipe):
         deserialize = {}
         deserialize["cfg"] = lambda p: self.cfg.update(deserialize_config(p))
+        deserialize["vocab"] = lambda p: self.vocab.from_disk(p, exclude=exclude)
         deserialize["kb"] = lambda p: self.kb.from_disk(p)
         deserialize["model"] = load_model
         util.from_disk(path, deserialize, exclude)

spacy/pipeline/lemmatizer.py

@@ -269,7 +269,7 @@ class Lemmatizer(Pipe):
         DOCS: https://spacy.io/api/lemmatizer#to_disk
         """
         serialize = {}
-        serialize["vocab"] = lambda p: self.vocab.to_disk(p)
+        serialize["vocab"] = lambda p: self.vocab.to_disk(p, exclude=exclude)
         serialize["lookups"] = lambda p: self.lookups.to_disk(p)
         util.to_disk(path, serialize, exclude)
@@ -285,7 +285,7 @@ class Lemmatizer(Pipe):
         DOCS: https://spacy.io/api/lemmatizer#from_disk
         """
         deserialize = {}
-        deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
+        deserialize["vocab"] = lambda p: self.vocab.from_disk(p, exclude=exclude)
         deserialize["lookups"] = lambda p: self.lookups.from_disk(p)
         util.from_disk(path, deserialize, exclude)
         self._validate_tables()
@@ -300,7 +300,7 @@ class Lemmatizer(Pipe):
         DOCS: https://spacy.io/api/lemmatizer#to_bytes
         """
         serialize = {}
-        serialize["vocab"] = self.vocab.to_bytes
+        serialize["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
         serialize["lookups"] = self.lookups.to_bytes
         return util.to_bytes(serialize, exclude)
@@ -316,7 +316,7 @@ class Lemmatizer(Pipe):
         DOCS: https://spacy.io/api/lemmatizer#from_bytes
         """
         deserialize = {}
-        deserialize["vocab"] = lambda b: self.vocab.from_bytes(b)
+        deserialize["vocab"] = lambda b: self.vocab.from_bytes(b, exclude=exclude)
         deserialize["lookups"] = lambda b: self.lookups.from_bytes(b)
         util.from_bytes(bytes_data, deserialize, exclude)
         self._validate_tables()

spacy/pipeline/trainable_pipe.pyx

@@ -273,7 +273,7 @@ cdef class TrainablePipe(Pipe):
         serialize = {}
         if hasattr(self, "cfg") and self.cfg is not None:
             serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
-        serialize["vocab"] = self.vocab.to_bytes
+        serialize["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
         serialize["model"] = self.model.to_bytes
         return util.to_bytes(serialize, exclude)
@@ -296,7 +296,7 @@ cdef class TrainablePipe(Pipe):
         deserialize = {}
         if hasattr(self, "cfg") and self.cfg is not None:
             deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
-        deserialize["vocab"] = lambda b: self.vocab.from_bytes(b)
+        deserialize["vocab"] = lambda b: self.vocab.from_bytes(b, exclude=exclude)
         deserialize["model"] = load_model
         util.from_bytes(bytes_data, deserialize, exclude)
         return self
@@ -313,7 +313,7 @@ cdef class TrainablePipe(Pipe):
         serialize = {}
         if hasattr(self, "cfg") and self.cfg is not None:
             serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
-        serialize["vocab"] = lambda p: self.vocab.to_disk(p)
+        serialize["vocab"] = lambda p: self.vocab.to_disk(p, exclude=exclude)
         serialize["model"] = lambda p: self.model.to_disk(p)
         util.to_disk(path, serialize, exclude)
@@ -338,7 +338,7 @@ cdef class TrainablePipe(Pipe):
         deserialize = {}
         if hasattr(self, "cfg") and self.cfg is not None:
             deserialize["cfg"] = lambda p: self.cfg.update(deserialize_config(p))
-        deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
+        deserialize["vocab"] = lambda p: self.vocab.from_disk(p, exclude=exclude)
         deserialize["model"] = load_model
         util.from_disk(path, deserialize, exclude)
         return self

spacy/pipeline/transition_parser.pyx

@@ -569,7 +569,7 @@ cdef class Parser(TrainablePipe):
     def to_disk(self, path, exclude=tuple()):
         serializers = {
             "model": lambda p: (self.model.to_disk(p) if self.model is not True else True),
-            "vocab": lambda p: self.vocab.to_disk(p),
+            "vocab": lambda p: self.vocab.to_disk(p, exclude=exclude),
             "moves": lambda p: self.moves.to_disk(p, exclude=["strings"]),
             "cfg": lambda p: srsly.write_json(p, self.cfg)
         }
@@ -577,7 +577,7 @@ cdef class Parser(TrainablePipe):
     def from_disk(self, path, exclude=tuple()):
         deserializers = {
-            "vocab": lambda p: self.vocab.from_disk(p),
+            "vocab": lambda p: self.vocab.from_disk(p, exclude=exclude),
             "moves": lambda p: self.moves.from_disk(p, exclude=["strings"]),
             "cfg": lambda p: self.cfg.update(srsly.read_json(p)),
             "model": lambda p: None,
@@ -597,7 +597,7 @@ cdef class Parser(TrainablePipe):
     def to_bytes(self, exclude=tuple()):
         serializers = {
             "model": lambda: (self.model.to_bytes()),
-            "vocab": lambda: self.vocab.to_bytes(),
+            "vocab": lambda: self.vocab.to_bytes(exclude=exclude),
             "moves": lambda: self.moves.to_bytes(exclude=["strings"]),
             "cfg": lambda: srsly.json_dumps(self.cfg, indent=2, sort_keys=True)
         }
@@ -605,7 +605,7 @@ cdef class Parser(TrainablePipe):
     def from_bytes(self, bytes_data, exclude=tuple()):
         deserializers = {
-            "vocab": lambda b: self.vocab.from_bytes(b),
+            "vocab": lambda b: self.vocab.from_bytes(b, exclude=exclude),
             "moves": lambda b: self.moves.from_bytes(b, exclude=["strings"]),
             "cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
             "model": lambda b: None,

spacy/tests/serialize/test_serialize_pipeline.py

@@ -1,5 +1,5 @@
 import pytest
-from spacy import registry, Vocab
+from spacy import registry, Vocab, load
 from spacy.pipeline import Tagger, DependencyParser, EntityRecognizer
 from spacy.pipeline import TextCategorizer, SentenceRecognizer, TrainablePipe
 from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL
@@ -268,3 +268,21 @@ def test_serialize_custom_trainable_pipe():
         pipe.to_disk(d)
         new_pipe = CustomPipe(Vocab(), Linear()).from_disk(d)
         assert new_pipe.to_bytes() == pipe_bytes
+
+
+def test_load_without_strings():
+    nlp = spacy.blank("en")
+    orig_strings_length = len(nlp.vocab.strings)
+    word = "unlikely_word_" * 20
+    nlp.vocab.strings.add(word)
+    assert len(nlp.vocab.strings) == orig_strings_length + 1
+    with make_tempdir() as d:
+        nlp.to_disk(d)
+        # reload with strings
+        reloaded_nlp = load(d)
+        assert len(nlp.vocab.strings) == len(reloaded_nlp.vocab.strings)
+        assert word in reloaded_nlp.vocab.strings
+        # reload without strings
+        reloaded_nlp = load(d, exclude=["strings"])
+        assert orig_strings_length == len(reloaded_nlp.vocab.strings)
+        assert word not in reloaded_nlp.vocab.strings

spacy/tokenizer.pyx

@@ -765,7 +765,7 @@ cdef class Tokenizer:
         DOCS: https://spacy.io/api/tokenizer#to_bytes
         """
         serializers = {
-            "vocab": lambda: self.vocab.to_bytes(),
+            "vocab": lambda: self.vocab.to_bytes(exclude=exclude),
             "prefix_search": lambda: _get_regex_pattern(self.prefix_search),
             "suffix_search": lambda: _get_regex_pattern(self.suffix_search),
             "infix_finditer": lambda: _get_regex_pattern(self.infix_finditer),
@@ -786,7 +786,7 @@ cdef class Tokenizer:
         """
         data = {}
         deserializers = {
-            "vocab": lambda b: self.vocab.from_bytes(b),
+            "vocab": lambda b: self.vocab.from_bytes(b, exclude=exclude),
             "prefix_search": lambda b: data.setdefault("prefix_search", b),
             "suffix_search": lambda b: data.setdefault("suffix_search", b),
             "infix_finditer": lambda b: data.setdefault("infix_finditer", b),