Merge pull request #5514 from adrianeboyd/bugfix/load-vector-name

Improve vector name loading from model meta
This commit is contained in:
Matthew Honnibal 2020-05-27 20:39:23 +02:00 committed by GitHub
commit e7ac12b598
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -934,15 +934,26 @@ class Language(object):
DOCS: https://spacy.io/api/language#from_disk
"""
def deserialize_meta(path):
if path.exists():
data = srsly.read_json(path)
self.meta.update(data)
# self.meta always overrides meta["vectors"] with the metadata
# from self.vocab.vectors, so set the name directly
self.vocab.vectors.name = data.get("vectors", {}).get("name")
def deserialize_vocab(path):
if path.exists():
self.vocab.from_disk(path)
_fix_pretrained_vectors_name(self)
if disable is not None:
warnings.warn(Warnings.W014, DeprecationWarning)
exclude = disable
path = util.ensure_path(path)
deserializers = OrderedDict()
deserializers["meta.json"] = lambda p: self.meta.update(srsly.read_json(p))
deserializers["vocab"] = lambda p: self.vocab.from_disk(
p
) and _fix_pretrained_vectors_name(self)
deserializers["meta.json"] = deserialize_meta
deserializers["vocab"] = deserialize_vocab
deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(
p, exclude=["vocab"]
)
@ -996,14 +1007,23 @@ class Language(object):
DOCS: https://spacy.io/api/language#from_bytes
"""
def deserialize_meta(b):
data = srsly.json_loads(b)
self.meta.update(data)
# self.meta always overrides meta["vectors"] with the metadata
# from self.vocab.vectors, so set the name directly
self.vocab.vectors.name = data.get("vectors", {}).get("name")
def deserialize_vocab(b):
self.vocab.from_bytes(b)
_fix_pretrained_vectors_name(self)
if disable is not None:
warnings.warn(Warnings.W014, DeprecationWarning)
exclude = disable
deserializers = OrderedDict()
deserializers["meta.json"] = lambda b: self.meta.update(srsly.json_loads(b))
deserializers["vocab"] = lambda b: self.vocab.from_bytes(
b
) and _fix_pretrained_vectors_name(self)
deserializers["meta.json"] = deserialize_meta
deserializers["vocab"] = deserialize_vocab
deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(
b, exclude=["vocab"]
)
@ -1069,7 +1089,7 @@ class component(object):
def _fix_pretrained_vectors_name(nlp):
# TODO: Replace this once we handle vectors consistently as static
# data
if "vectors" in nlp.meta and nlp.meta["vectors"].get("name"):
if "vectors" in nlp.meta and "name" in nlp.meta["vectors"]:
nlp.vocab.vectors.name = nlp.meta["vectors"]["name"]
elif not nlp.vocab.vectors.size:
nlp.vocab.vectors.name = None