Pass excludes when serializing vocab (#8824)

* Pass excludes when serializing vocab

Additional minor bug fix:

* Deserialize vocab in `EntityLinker.from_disk`

* Add test for excluding strings on load

* Fix formatting
Adriane Boyd 2021-08-03 14:42:44 +02:00 committed by GitHub
parent 175847f92c
commit 941a591f3c
8 changed files with 45 additions and 26 deletions
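
In user-facing terms, the change makes exclude names such as "strings" actually reach the vocab on save and load. A minimal sketch of the behavior, mirroring the new test further down (the path is illustrative):

import spacy

nlp = spacy.blank("en")
word = "unlikely_word_" * 20
nlp.vocab.strings.add(word)
nlp.to_disk("/tmp/excl_demo")  # illustrative path

# Before this commit the exclude list never reached Vocab.from_disk, so the
# custom string was restored anyway; now it is skipped:
reloaded = spacy.load("/tmp/excl_demo", exclude=["strings"])
assert word not in reloaded.vocab.strings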

spacy/language.py

@@ -1909,7 +1909,7 @@ class Language:
             if not hasattr(proc, "to_disk"):
                 continue
             serializers[name] = lambda p, proc=proc: proc.to_disk(p, exclude=["vocab"])
-        serializers["vocab"] = lambda p: self.vocab.to_disk(p)
+        serializers["vocab"] = lambda p: self.vocab.to_disk(p, exclude=exclude)
         util.to_disk(path, serializers, exclude)

     def from_disk(
@@ -1940,7 +1940,7 @@ class Language:
         def deserialize_vocab(path: Path) -> None:
             if path.exists():
-                self.vocab.from_disk(path)
+                self.vocab.from_disk(path, exclude=exclude)

         path = util.ensure_path(path)
         deserializers = {}
@@ -1978,7 +1978,7 @@ class Language:
         DOCS: https://spacy.io/api/language#to_bytes
         """
         serializers = {}
-        serializers["vocab"] = lambda: self.vocab.to_bytes()
+        serializers["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
         serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"])
         serializers["meta.json"] = lambda: srsly.json_dumps(self.meta)
         serializers["config.cfg"] = lambda: self.config.to_bytes()
@@ -2014,7 +2014,7 @@ class Language:
                 b, interpolate=False
             )
         deserializers["meta.json"] = deserialize_meta
-        deserializers["vocab"] = self.vocab.from_bytes
+        deserializers["vocab"] = lambda b: self.vocab.from_bytes(b, exclude=exclude)
         deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(
             b, exclude=["vocab"]
         )
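
The hunks above show the pattern repeated in every file below: util.to_disk and util.from_disk walk a dict of callbacks and skip entries whose top-level name is in exclude, while nested names like "strings" only take effect if the same list is forwarded into the vocab's own serializer. A self-contained sketch of that two-level behavior, with simplified stand-ins for the spaCy helpers:

from pathlib import Path

class Vocab:
    def __init__(self, strings):
        self.strings = list(strings)

    def to_disk(self, path, exclude=tuple()):
        path.mkdir(parents=True, exist_ok=True)
        if "strings" not in exclude:  # nested exclude handled inside the vocab
            (path / "strings.txt").write_text("\n".join(self.strings))

def to_disk(path, serializers, exclude):
    # Simplified stand-in for spacy.util.to_disk: skip whole entries whose
    # top-level name is excluded, otherwise call their serializer.
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    for name, serialize in serializers.items():
        if name in exclude:
            continue  # e.g. exclude=["vocab"] drops the vocab entirely
        serialize(path / name)

vocab = Vocab(["cat", "dog"])
exclude = ["strings"]
serializers = {
    # The fix: forward exclude so "strings" can reach Vocab.to_disk.
    "vocab": lambda p: vocab.to_disk(p, exclude=exclude),
}
to_disk("/tmp/vocab_demo", serializers, exclude)  # vocab dir written, no strings.txt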

spacy/pipeline/attributeruler.py

@@ -276,7 +276,7 @@ class AttributeRuler(Pipe):
         DOCS: https://spacy.io/api/attributeruler#to_bytes
         """
         serialize = {}
-        serialize["vocab"] = self.vocab.to_bytes
+        serialize["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
         serialize["patterns"] = lambda: srsly.msgpack_dumps(self.patterns)
         return util.to_bytes(serialize, exclude)
@@ -296,7 +296,7 @@ class AttributeRuler(Pipe):
             self.add_patterns(srsly.msgpack_loads(b))

         deserialize = {
-            "vocab": lambda b: self.vocab.from_bytes(b),
+            "vocab": lambda b: self.vocab.from_bytes(b, exclude=exclude),
             "patterns": load_patterns,
         }
         util.from_bytes(bytes_data, deserialize, exclude)
@@ -313,7 +313,7 @@ class AttributeRuler(Pipe):
         DOCS: https://spacy.io/api/attributeruler#to_disk
         """
         serialize = {
-            "vocab": lambda p: self.vocab.to_disk(p),
+            "vocab": lambda p: self.vocab.to_disk(p, exclude=exclude),
             "patterns": lambda p: srsly.write_msgpack(p, self.patterns),
         }
         util.to_disk(path, serialize, exclude)
@@ -334,7 +334,7 @@ class AttributeRuler(Pipe):
             self.add_patterns(srsly.read_msgpack(p))

         deserialize = {
-            "vocab": lambda p: self.vocab.from_disk(p),
+            "vocab": lambda p: self.vocab.from_disk(p, exclude=exclude),
             "patterns": load_patterns,
         }
         util.from_disk(path, deserialize, exclude)

spacy/pipeline/entity_linker.py

@@ -412,7 +412,7 @@ class EntityLinker(TrainablePipe):
         serialize = {}
         if hasattr(self, "cfg") and self.cfg is not None:
             serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
-        serialize["vocab"] = self.vocab.to_bytes
+        serialize["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
         serialize["kb"] = self.kb.to_bytes
         serialize["model"] = self.model.to_bytes
         return util.to_bytes(serialize, exclude)
@@ -436,7 +436,7 @@ class EntityLinker(TrainablePipe):
         deserialize = {}
         if hasattr(self, "cfg") and self.cfg is not None:
             deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
-        deserialize["vocab"] = lambda b: self.vocab.from_bytes(b)
+        deserialize["vocab"] = lambda b: self.vocab.from_bytes(b, exclude=exclude)
         deserialize["kb"] = lambda b: self.kb.from_bytes(b)
         deserialize["model"] = load_model
         util.from_bytes(bytes_data, deserialize, exclude)
@@ -453,7 +453,7 @@ class EntityLinker(TrainablePipe):
         DOCS: https://spacy.io/api/entitylinker#to_disk
         """
         serialize = {}
-        serialize["vocab"] = lambda p: self.vocab.to_disk(p)
+        serialize["vocab"] = lambda p: self.vocab.to_disk(p, exclude=exclude)
         serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
         serialize["kb"] = lambda p: self.kb.to_disk(p)
         serialize["model"] = lambda p: self.model.to_disk(p)
@@ -480,6 +480,7 @@ class EntityLinker(TrainablePipe):
         deserialize = {}
         deserialize["cfg"] = lambda p: self.cfg.update(deserialize_config(p))
+        deserialize["vocab"] = lambda p: self.vocab.from_disk(p, exclude=exclude)
         deserialize["kb"] = lambda p: self.kb.from_disk(p)
         deserialize["model"] = load_model
         util.from_disk(path, deserialize, exclude)
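
The added deserialize["vocab"] line is the "additional minor bug fix" from the commit message: EntityLinker.from_disk registered no vocab reader at all, so the vocab data written by to_disk was silently ignored on load. A simplified sketch of why a missing entry goes unnoticed, with a hypothetical stand-in for spacy.util.from_disk:

from pathlib import Path

def from_disk(path, deserializers, exclude):
    # Simplified stand-in for spacy.util.from_disk: only registered keys are
    # ever read back; data on disk with no matching reader is silently skipped.
    path = Path(path)
    for name, deserialize in deserializers.items():
        if name in exclude:
            continue
        deserialize(path / name)

base = Path("/tmp/el_demo")  # illustrative path
(base / "vocab").mkdir(parents=True, exist_ok=True)
(base / "kb").mkdir(parents=True, exist_ok=True)

seen = []
deserializers = {"kb": lambda p: seen.append(p.name)}  # no "vocab" reader
from_disk(base, deserializers, exclude=[])
assert seen == ["kb"]  # the vocab directory was never touched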

spacy/pipeline/lemmatizer.py

@@ -269,7 +269,7 @@ class Lemmatizer(Pipe):
         DOCS: https://spacy.io/api/lemmatizer#to_disk
         """
         serialize = {}
-        serialize["vocab"] = lambda p: self.vocab.to_disk(p)
+        serialize["vocab"] = lambda p: self.vocab.to_disk(p, exclude=exclude)
         serialize["lookups"] = lambda p: self.lookups.to_disk(p)
         util.to_disk(path, serialize, exclude)
@@ -285,7 +285,7 @@ class Lemmatizer(Pipe):
         DOCS: https://spacy.io/api/lemmatizer#from_disk
         """
         deserialize = {}
-        deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
+        deserialize["vocab"] = lambda p: self.vocab.from_disk(p, exclude=exclude)
         deserialize["lookups"] = lambda p: self.lookups.from_disk(p)
         util.from_disk(path, deserialize, exclude)
         self._validate_tables()
@@ -300,7 +300,7 @@ class Lemmatizer(Pipe):
         DOCS: https://spacy.io/api/lemmatizer#to_bytes
         """
         serialize = {}
-        serialize["vocab"] = self.vocab.to_bytes
+        serialize["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
         serialize["lookups"] = self.lookups.to_bytes
         return util.to_bytes(serialize, exclude)
@@ -316,7 +316,7 @@ class Lemmatizer(Pipe):
         DOCS: https://spacy.io/api/lemmatizer#from_bytes
         """
         deserialize = {}
-        deserialize["vocab"] = lambda b: self.vocab.from_bytes(b)
+        deserialize["vocab"] = lambda b: self.vocab.from_bytes(b, exclude=exclude)
         deserialize["lookups"] = lambda b: self.lookups.from_bytes(b)
         util.from_bytes(bytes_data, deserialize, exclude)
         self._validate_tables()

spacy/pipeline/trainable_pipe.pyx

@@ -273,7 +273,7 @@ cdef class TrainablePipe(Pipe):
         serialize = {}
         if hasattr(self, "cfg") and self.cfg is not None:
             serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
-        serialize["vocab"] = self.vocab.to_bytes
+        serialize["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
         serialize["model"] = self.model.to_bytes
         return util.to_bytes(serialize, exclude)
@@ -296,7 +296,7 @@ cdef class TrainablePipe(Pipe):
         deserialize = {}
         if hasattr(self, "cfg") and self.cfg is not None:
             deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
-        deserialize["vocab"] = lambda b: self.vocab.from_bytes(b)
+        deserialize["vocab"] = lambda b: self.vocab.from_bytes(b, exclude=exclude)
         deserialize["model"] = load_model
         util.from_bytes(bytes_data, deserialize, exclude)
         return self
@@ -313,7 +313,7 @@ cdef class TrainablePipe(Pipe):
         serialize = {}
         if hasattr(self, "cfg") and self.cfg is not None:
             serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
-        serialize["vocab"] = lambda p: self.vocab.to_disk(p)
+        serialize["vocab"] = lambda p: self.vocab.to_disk(p, exclude=exclude)
         serialize["model"] = lambda p: self.model.to_disk(p)
         util.to_disk(path, serialize, exclude)
@@ -338,7 +338,7 @@ cdef class TrainablePipe(Pipe):
         deserialize = {}
         if hasattr(self, "cfg") and self.cfg is not None:
             deserialize["cfg"] = lambda p: self.cfg.update(deserialize_config(p))
-        deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
+        deserialize["vocab"] = lambda p: self.vocab.from_disk(p, exclude=exclude)
         deserialize["model"] = load_model
         util.from_disk(path, deserialize, exclude)
         return self

spacy/pipeline/transition_parser.pyx

@@ -569,7 +569,7 @@ cdef class Parser(TrainablePipe):
     def to_disk(self, path, exclude=tuple()):
         serializers = {
             "model": lambda p: (self.model.to_disk(p) if self.model is not True else True),
-            "vocab": lambda p: self.vocab.to_disk(p),
+            "vocab": lambda p: self.vocab.to_disk(p, exclude=exclude),
             "moves": lambda p: self.moves.to_disk(p, exclude=["strings"]),
             "cfg": lambda p: srsly.write_json(p, self.cfg)
         }
@@ -577,7 +577,7 @@ cdef class Parser(TrainablePipe):
     def from_disk(self, path, exclude=tuple()):
         deserializers = {
-            "vocab": lambda p: self.vocab.from_disk(p),
+            "vocab": lambda p: self.vocab.from_disk(p, exclude=exclude),
             "moves": lambda p: self.moves.from_disk(p, exclude=["strings"]),
             "cfg": lambda p: self.cfg.update(srsly.read_json(p)),
             "model": lambda p: None,
@@ -597,7 +597,7 @@ cdef class Parser(TrainablePipe):
     def to_bytes(self, exclude=tuple()):
         serializers = {
             "model": lambda: (self.model.to_bytes()),
-            "vocab": lambda: self.vocab.to_bytes(),
+            "vocab": lambda: self.vocab.to_bytes(exclude=exclude),
             "moves": lambda: self.moves.to_bytes(exclude=["strings"]),
             "cfg": lambda: srsly.json_dumps(self.cfg, indent=2, sort_keys=True)
         }
@@ -605,7 +605,7 @@ cdef class Parser(TrainablePipe):
     def from_bytes(self, bytes_data, exclude=tuple()):
        deserializers = {
-            "vocab": lambda b: self.vocab.from_bytes(b),
+            "vocab": lambda b: self.vocab.from_bytes(b, exclude=exclude),
             "moves": lambda b: self.moves.from_bytes(b, exclude=["strings"]),
             "cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
             "model": lambda b: None,

spacy/tests/serialize/test_serialize_pipeline.py

@@ -1,5 +1,5 @@
 import pytest
-from spacy import registry, Vocab
+from spacy import registry, Vocab, load
 from spacy.pipeline import Tagger, DependencyParser, EntityRecognizer
 from spacy.pipeline import TextCategorizer, SentenceRecognizer, TrainablePipe
 from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL
@@ -268,3 +268,21 @@ def test_serialize_custom_trainable_pipe():
         pipe.to_disk(d)
         new_pipe = CustomPipe(Vocab(), Linear()).from_disk(d)
     assert new_pipe.to_bytes() == pipe_bytes
+
+
+def test_load_without_strings():
+    nlp = spacy.blank("en")
+    orig_strings_length = len(nlp.vocab.strings)
+    word = "unlikely_word_" * 20
+    nlp.vocab.strings.add(word)
+    assert len(nlp.vocab.strings) == orig_strings_length + 1
+    with make_tempdir() as d:
+        nlp.to_disk(d)
+        # reload with strings
+        reloaded_nlp = load(d)
+        assert len(nlp.vocab.strings) == len(reloaded_nlp.vocab.strings)
+        assert word in reloaded_nlp.vocab.strings
+        # reload without strings
+        reloaded_nlp = load(d, exclude=["strings"])
+        assert orig_strings_length == len(reloaded_nlp.vocab.strings)
+        assert word not in reloaded_nlp.vocab.strings

spacy/tokenizer.pyx

@@ -765,7 +765,7 @@ cdef class Tokenizer:
         DOCS: https://spacy.io/api/tokenizer#to_bytes
         """
         serializers = {
-            "vocab": lambda: self.vocab.to_bytes(),
+            "vocab": lambda: self.vocab.to_bytes(exclude=exclude),
             "prefix_search": lambda: _get_regex_pattern(self.prefix_search),
             "suffix_search": lambda: _get_regex_pattern(self.suffix_search),
             "infix_finditer": lambda: _get_regex_pattern(self.infix_finditer),
@@ -786,7 +786,7 @@ cdef class Tokenizer:
         """
         data = {}
         deserializers = {
-            "vocab": lambda b: self.vocab.from_bytes(b),
+            "vocab": lambda b: self.vocab.from_bytes(b, exclude=exclude),
             "prefix_search": lambda b: data.setdefault("prefix_search", b),
             "suffix_search": lambda b: data.setdefault("suffix_search", b),
             "infix_finditer": lambda b: data.setdefault("infix_finditer", b),