From 25de2a2191c168ce133d922c4e2e041684431228 Mon Sep 17 00:00:00 2001
From: Adriane Boyd
Date: Wed, 27 May 2020 14:48:54 +0200
Subject: [PATCH] Improve vector name loading from model meta

---
 spacy/language.py | 38 +++++++++++++++++++++++++++++---------
 1 file changed, 29 insertions(+), 9 deletions(-)

diff --git a/spacy/language.py b/spacy/language.py
index 53a788f2a..2058def8a 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -934,15 +934,26 @@ class Language(object):
 
         DOCS: https://spacy.io/api/language#from_disk
         """
+        def deserialize_meta(path):
+            if path.exists():
+                data = srsly.read_json(path)
+                self.meta.update(data)
+                # self.meta always overrides meta["vectors"] with the metadata
+                # from self.vocab.vectors, so set the name directly
+                self.vocab.vectors.name = data.get("vectors", {}).get("name")
+
+        def deserialize_vocab(path):
+            if path.exists():
+                self.vocab.from_disk(path)
+                _fix_pretrained_vectors_name(self)
+
         if disable is not None:
             warnings.warn(Warnings.W014, DeprecationWarning)
             exclude = disable
         path = util.ensure_path(path)
         deserializers = OrderedDict()
-        deserializers["meta.json"] = lambda p: self.meta.update(srsly.read_json(p))
-        deserializers["vocab"] = lambda p: self.vocab.from_disk(
-            p
-        ) and _fix_pretrained_vectors_name(self)
+        deserializers["meta.json"] = deserialize_meta
+        deserializers["vocab"] = deserialize_vocab
         deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(
             p, exclude=["vocab"]
         )
@@ -996,14 +1007,23 @@ class Language(object):
 
         DOCS: https://spacy.io/api/language#from_bytes
         """
+        def deserialize_meta(b):
+            data = srsly.json_loads(b)
+            self.meta.update(data)
+            # self.meta always overrides meta["vectors"] with the metadata
+            # from self.vocab.vectors, so set the name directly
+            self.vocab.vectors.name = data.get("vectors", {}).get("name")
+
+        def deserialize_vocab(b):
+            self.vocab.from_bytes(b)
+            _fix_pretrained_vectors_name(self)
+
         if disable is not None:
             warnings.warn(Warnings.W014, DeprecationWarning)
             exclude = disable
         deserializers = OrderedDict()
-        deserializers["meta.json"] = lambda b: self.meta.update(srsly.json_loads(b))
-        deserializers["vocab"] = lambda b: self.vocab.from_bytes(
-            b
-        ) and _fix_pretrained_vectors_name(self)
+        deserializers["meta.json"] = deserialize_meta
+        deserializers["vocab"] = deserialize_vocab
         deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(
             b, exclude=["vocab"]
         )
@@ -1069,7 +1089,7 @@ class component(object):
 def _fix_pretrained_vectors_name(nlp):
     # TODO: Replace this once we handle vectors consistently as static
     # data
-    if "vectors" in nlp.meta and nlp.meta["vectors"].get("name"):
+    if "vectors" in nlp.meta and "name" in nlp.meta["vectors"]:
         nlp.vocab.vectors.name = nlp.meta["vectors"]["name"]
     elif not nlp.vocab.vectors.size:
         nlp.vocab.vectors.name = None
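
Note (not part of the patch): replacing the lambdas chained with "and" by named
deserializers makes the ordering explicit -- meta.json is read before the vocab,
so the vector name recorded in the meta can be applied to vocab.vectors instead
of being clobbered by the meta property's override. A rough way to exercise the
new behaviour is sketched below; it assumes a spaCy 2.x install, and the vector
name "example_vectors_md" is purely illustrative, not taken from the patch or
the spaCy test suite.

    import spacy

    # Rough sketch: give the blank pipeline's vectors a made-up name.
    nlp = spacy.blank("en")
    nlp.vocab.vectors.name = "example_vectors_md"

    # Round-trip through to_bytes()/from_bytes(); meta.json carries the name.
    nlp2 = spacy.blank("en").from_bytes(nlp.to_bytes())

    # With this patch, deserialize_meta() copies meta["vectors"]["name"] onto
    # vocab.vectors, so the reloaded pipeline keeps the original vector name.
    assert nlp2.vocab.vectors.name == "example_vectors_md"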