mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-10 19:57:17 +03:00
Pass excludes when serializing vocab (#8824)
* Pass excludes when serializing vocab Additional minor bug fix: * Deserialize vocab in `EntityLinker.from_disk` * Add test for excluding strings on load * Fix formatting
This commit is contained in:
parent
175847f92c
commit
941a591f3c
|
@ -1909,7 +1909,7 @@ class Language:
|
|||
if not hasattr(proc, "to_disk"):
|
||||
continue
|
||||
serializers[name] = lambda p, proc=proc: proc.to_disk(p, exclude=["vocab"])
|
||||
serializers["vocab"] = lambda p: self.vocab.to_disk(p)
|
||||
serializers["vocab"] = lambda p: self.vocab.to_disk(p, exclude=exclude)
|
||||
util.to_disk(path, serializers, exclude)
|
||||
|
||||
def from_disk(
|
||||
|
@ -1940,7 +1940,7 @@ class Language:
|
|||
|
||||
def deserialize_vocab(path: Path) -> None:
|
||||
if path.exists():
|
||||
self.vocab.from_disk(path)
|
||||
self.vocab.from_disk(path, exclude=exclude)
|
||||
|
||||
path = util.ensure_path(path)
|
||||
deserializers = {}
|
||||
|
@ -1978,7 +1978,7 @@ class Language:
|
|||
DOCS: https://spacy.io/api/language#to_bytes
|
||||
"""
|
||||
serializers = {}
|
||||
serializers["vocab"] = lambda: self.vocab.to_bytes()
|
||||
serializers["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
|
||||
serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"])
|
||||
serializers["meta.json"] = lambda: srsly.json_dumps(self.meta)
|
||||
serializers["config.cfg"] = lambda: self.config.to_bytes()
|
||||
|
@ -2014,7 +2014,7 @@ class Language:
|
|||
b, interpolate=False
|
||||
)
|
||||
deserializers["meta.json"] = deserialize_meta
|
||||
deserializers["vocab"] = self.vocab.from_bytes
|
||||
deserializers["vocab"] = lambda b: self.vocab.from_bytes(b, exclude=exclude)
|
||||
deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(
|
||||
b, exclude=["vocab"]
|
||||
)
|
||||
|
|
|
@ -276,7 +276,7 @@ class AttributeRuler(Pipe):
|
|||
DOCS: https://spacy.io/api/attributeruler#to_bytes
|
||||
"""
|
||||
serialize = {}
|
||||
serialize["vocab"] = self.vocab.to_bytes
|
||||
serialize["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
|
||||
serialize["patterns"] = lambda: srsly.msgpack_dumps(self.patterns)
|
||||
return util.to_bytes(serialize, exclude)
|
||||
|
||||
|
@ -296,7 +296,7 @@ class AttributeRuler(Pipe):
|
|||
self.add_patterns(srsly.msgpack_loads(b))
|
||||
|
||||
deserialize = {
|
||||
"vocab": lambda b: self.vocab.from_bytes(b),
|
||||
"vocab": lambda b: self.vocab.from_bytes(b, exclude=exclude),
|
||||
"patterns": load_patterns,
|
||||
}
|
||||
util.from_bytes(bytes_data, deserialize, exclude)
|
||||
|
@ -313,7 +313,7 @@ class AttributeRuler(Pipe):
|
|||
DOCS: https://spacy.io/api/attributeruler#to_disk
|
||||
"""
|
||||
serialize = {
|
||||
"vocab": lambda p: self.vocab.to_disk(p),
|
||||
"vocab": lambda p: self.vocab.to_disk(p, exclude=exclude),
|
||||
"patterns": lambda p: srsly.write_msgpack(p, self.patterns),
|
||||
}
|
||||
util.to_disk(path, serialize, exclude)
|
||||
|
@ -334,7 +334,7 @@ class AttributeRuler(Pipe):
|
|||
self.add_patterns(srsly.read_msgpack(p))
|
||||
|
||||
deserialize = {
|
||||
"vocab": lambda p: self.vocab.from_disk(p),
|
||||
"vocab": lambda p: self.vocab.from_disk(p, exclude=exclude),
|
||||
"patterns": load_patterns,
|
||||
}
|
||||
util.from_disk(path, deserialize, exclude)
|
||||
|
|
|
@ -412,7 +412,7 @@ class EntityLinker(TrainablePipe):
|
|||
serialize = {}
|
||||
if hasattr(self, "cfg") and self.cfg is not None:
|
||||
serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
|
||||
serialize["vocab"] = self.vocab.to_bytes
|
||||
serialize["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
|
||||
serialize["kb"] = self.kb.to_bytes
|
||||
serialize["model"] = self.model.to_bytes
|
||||
return util.to_bytes(serialize, exclude)
|
||||
|
@ -436,7 +436,7 @@ class EntityLinker(TrainablePipe):
|
|||
deserialize = {}
|
||||
if hasattr(self, "cfg") and self.cfg is not None:
|
||||
deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
|
||||
deserialize["vocab"] = lambda b: self.vocab.from_bytes(b)
|
||||
deserialize["vocab"] = lambda b: self.vocab.from_bytes(b, exclude=exclude)
|
||||
deserialize["kb"] = lambda b: self.kb.from_bytes(b)
|
||||
deserialize["model"] = load_model
|
||||
util.from_bytes(bytes_data, deserialize, exclude)
|
||||
|
@ -453,7 +453,7 @@ class EntityLinker(TrainablePipe):
|
|||
DOCS: https://spacy.io/api/entitylinker#to_disk
|
||||
"""
|
||||
serialize = {}
|
||||
serialize["vocab"] = lambda p: self.vocab.to_disk(p)
|
||||
serialize["vocab"] = lambda p: self.vocab.to_disk(p, exclude=exclude)
|
||||
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
|
||||
serialize["kb"] = lambda p: self.kb.to_disk(p)
|
||||
serialize["model"] = lambda p: self.model.to_disk(p)
|
||||
|
@ -480,6 +480,7 @@ class EntityLinker(TrainablePipe):
|
|||
|
||||
deserialize = {}
|
||||
deserialize["cfg"] = lambda p: self.cfg.update(deserialize_config(p))
|
||||
deserialize["vocab"] = lambda p: self.vocab.from_disk(p, exclude=exclude)
|
||||
deserialize["kb"] = lambda p: self.kb.from_disk(p)
|
||||
deserialize["model"] = load_model
|
||||
util.from_disk(path, deserialize, exclude)
|
||||
|
|
|
@ -269,7 +269,7 @@ class Lemmatizer(Pipe):
|
|||
DOCS: https://spacy.io/api/lemmatizer#to_disk
|
||||
"""
|
||||
serialize = {}
|
||||
serialize["vocab"] = lambda p: self.vocab.to_disk(p)
|
||||
serialize["vocab"] = lambda p: self.vocab.to_disk(p, exclude=exclude)
|
||||
serialize["lookups"] = lambda p: self.lookups.to_disk(p)
|
||||
util.to_disk(path, serialize, exclude)
|
||||
|
||||
|
@ -285,7 +285,7 @@ class Lemmatizer(Pipe):
|
|||
DOCS: https://spacy.io/api/lemmatizer#from_disk
|
||||
"""
|
||||
deserialize = {}
|
||||
deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
|
||||
deserialize["vocab"] = lambda p: self.vocab.from_disk(p, exclude=exclude)
|
||||
deserialize["lookups"] = lambda p: self.lookups.from_disk(p)
|
||||
util.from_disk(path, deserialize, exclude)
|
||||
self._validate_tables()
|
||||
|
@ -300,7 +300,7 @@ class Lemmatizer(Pipe):
|
|||
DOCS: https://spacy.io/api/lemmatizer#to_bytes
|
||||
"""
|
||||
serialize = {}
|
||||
serialize["vocab"] = self.vocab.to_bytes
|
||||
serialize["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
|
||||
serialize["lookups"] = self.lookups.to_bytes
|
||||
return util.to_bytes(serialize, exclude)
|
||||
|
||||
|
@ -316,7 +316,7 @@ class Lemmatizer(Pipe):
|
|||
DOCS: https://spacy.io/api/lemmatizer#from_bytes
|
||||
"""
|
||||
deserialize = {}
|
||||
deserialize["vocab"] = lambda b: self.vocab.from_bytes(b)
|
||||
deserialize["vocab"] = lambda b: self.vocab.from_bytes(b, exclude=exclude)
|
||||
deserialize["lookups"] = lambda b: self.lookups.from_bytes(b)
|
||||
util.from_bytes(bytes_data, deserialize, exclude)
|
||||
self._validate_tables()
|
||||
|
|
|
@ -273,7 +273,7 @@ cdef class TrainablePipe(Pipe):
|
|||
serialize = {}
|
||||
if hasattr(self, "cfg") and self.cfg is not None:
|
||||
serialize["cfg"] = lambda: srsly.json_dumps(self.cfg)
|
||||
serialize["vocab"] = self.vocab.to_bytes
|
||||
serialize["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
|
||||
serialize["model"] = self.model.to_bytes
|
||||
return util.to_bytes(serialize, exclude)
|
||||
|
||||
|
@ -296,7 +296,7 @@ cdef class TrainablePipe(Pipe):
|
|||
deserialize = {}
|
||||
if hasattr(self, "cfg") and self.cfg is not None:
|
||||
deserialize["cfg"] = lambda b: self.cfg.update(srsly.json_loads(b))
|
||||
deserialize["vocab"] = lambda b: self.vocab.from_bytes(b)
|
||||
deserialize["vocab"] = lambda b: self.vocab.from_bytes(b, exclude=exclude)
|
||||
deserialize["model"] = load_model
|
||||
util.from_bytes(bytes_data, deserialize, exclude)
|
||||
return self
|
||||
|
@ -313,7 +313,7 @@ cdef class TrainablePipe(Pipe):
|
|||
serialize = {}
|
||||
if hasattr(self, "cfg") and self.cfg is not None:
|
||||
serialize["cfg"] = lambda p: srsly.write_json(p, self.cfg)
|
||||
serialize["vocab"] = lambda p: self.vocab.to_disk(p)
|
||||
serialize["vocab"] = lambda p: self.vocab.to_disk(p, exclude=exclude)
|
||||
serialize["model"] = lambda p: self.model.to_disk(p)
|
||||
util.to_disk(path, serialize, exclude)
|
||||
|
||||
|
@ -338,7 +338,7 @@ cdef class TrainablePipe(Pipe):
|
|||
deserialize = {}
|
||||
if hasattr(self, "cfg") and self.cfg is not None:
|
||||
deserialize["cfg"] = lambda p: self.cfg.update(deserialize_config(p))
|
||||
deserialize["vocab"] = lambda p: self.vocab.from_disk(p)
|
||||
deserialize["vocab"] = lambda p: self.vocab.from_disk(p, exclude=exclude)
|
||||
deserialize["model"] = load_model
|
||||
util.from_disk(path, deserialize, exclude)
|
||||
return self
|
||||
|
|
|
@ -569,7 +569,7 @@ cdef class Parser(TrainablePipe):
|
|||
def to_disk(self, path, exclude=tuple()):
|
||||
serializers = {
|
||||
"model": lambda p: (self.model.to_disk(p) if self.model is not True else True),
|
||||
"vocab": lambda p: self.vocab.to_disk(p),
|
||||
"vocab": lambda p: self.vocab.to_disk(p, exclude=exclude),
|
||||
"moves": lambda p: self.moves.to_disk(p, exclude=["strings"]),
|
||||
"cfg": lambda p: srsly.write_json(p, self.cfg)
|
||||
}
|
||||
|
@ -577,7 +577,7 @@ cdef class Parser(TrainablePipe):
|
|||
|
||||
def from_disk(self, path, exclude=tuple()):
|
||||
deserializers = {
|
||||
"vocab": lambda p: self.vocab.from_disk(p),
|
||||
"vocab": lambda p: self.vocab.from_disk(p, exclude=exclude),
|
||||
"moves": lambda p: self.moves.from_disk(p, exclude=["strings"]),
|
||||
"cfg": lambda p: self.cfg.update(srsly.read_json(p)),
|
||||
"model": lambda p: None,
|
||||
|
@ -597,7 +597,7 @@ cdef class Parser(TrainablePipe):
|
|||
def to_bytes(self, exclude=tuple()):
|
||||
serializers = {
|
||||
"model": lambda: (self.model.to_bytes()),
|
||||
"vocab": lambda: self.vocab.to_bytes(),
|
||||
"vocab": lambda: self.vocab.to_bytes(exclude=exclude),
|
||||
"moves": lambda: self.moves.to_bytes(exclude=["strings"]),
|
||||
"cfg": lambda: srsly.json_dumps(self.cfg, indent=2, sort_keys=True)
|
||||
}
|
||||
|
@ -605,7 +605,7 @@ cdef class Parser(TrainablePipe):
|
|||
|
||||
def from_bytes(self, bytes_data, exclude=tuple()):
|
||||
deserializers = {
|
||||
"vocab": lambda b: self.vocab.from_bytes(b),
|
||||
"vocab": lambda b: self.vocab.from_bytes(b, exclude=exclude),
|
||||
"moves": lambda b: self.moves.from_bytes(b, exclude=["strings"]),
|
||||
"cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
|
||||
"model": lambda b: None,
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
import pytest
|
||||
from spacy import registry, Vocab
|
||||
from spacy import registry, Vocab, load
|
||||
from spacy.pipeline import Tagger, DependencyParser, EntityRecognizer
|
||||
from spacy.pipeline import TextCategorizer, SentenceRecognizer, TrainablePipe
|
||||
from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL
|
||||
|
@ -268,3 +268,21 @@ def test_serialize_custom_trainable_pipe():
|
|||
pipe.to_disk(d)
|
||||
new_pipe = CustomPipe(Vocab(), Linear()).from_disk(d)
|
||||
assert new_pipe.to_bytes() == pipe_bytes
|
||||
|
||||
|
||||
def test_load_without_strings():
|
||||
nlp = spacy.blank("en")
|
||||
orig_strings_length = len(nlp.vocab.strings)
|
||||
word = "unlikely_word_" * 20
|
||||
nlp.vocab.strings.add(word)
|
||||
assert len(nlp.vocab.strings) == orig_strings_length + 1
|
||||
with make_tempdir() as d:
|
||||
nlp.to_disk(d)
|
||||
# reload with strings
|
||||
reloaded_nlp = load(d)
|
||||
assert len(nlp.vocab.strings) == len(reloaded_nlp.vocab.strings)
|
||||
assert word in reloaded_nlp.vocab.strings
|
||||
# reload without strings
|
||||
reloaded_nlp = load(d, exclude=["strings"])
|
||||
assert orig_strings_length == len(reloaded_nlp.vocab.strings)
|
||||
assert word not in reloaded_nlp.vocab.strings
|
||||
|
|
|
@ -765,7 +765,7 @@ cdef class Tokenizer:
|
|||
DOCS: https://spacy.io/api/tokenizer#to_bytes
|
||||
"""
|
||||
serializers = {
|
||||
"vocab": lambda: self.vocab.to_bytes(),
|
||||
"vocab": lambda: self.vocab.to_bytes(exclude=exclude),
|
||||
"prefix_search": lambda: _get_regex_pattern(self.prefix_search),
|
||||
"suffix_search": lambda: _get_regex_pattern(self.suffix_search),
|
||||
"infix_finditer": lambda: _get_regex_pattern(self.infix_finditer),
|
||||
|
@ -786,7 +786,7 @@ cdef class Tokenizer:
|
|||
"""
|
||||
data = {}
|
||||
deserializers = {
|
||||
"vocab": lambda b: self.vocab.from_bytes(b),
|
||||
"vocab": lambda b: self.vocab.from_bytes(b, exclude=exclude),
|
||||
"prefix_search": lambda b: data.setdefault("prefix_search", b),
|
||||
"suffix_search": lambda b: data.setdefault("suffix_search", b),
|
||||
"infix_finditer": lambda b: data.setdefault("infix_finditer", b),
|
||||
|
|
Loading…
Reference in New Issue
Block a user