mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-11 04:08:09 +03:00
Merge pull request #5514 from adrianeboyd/bugfix/load-vector-name
Improve vector name loading from model meta
This commit is contained in:
commit
e7ac12b598
|
@ -934,15 +934,26 @@ class Language(object):
|
|||
|
||||
DOCS: https://spacy.io/api/language#from_disk
|
||||
"""
|
||||
def deserialize_meta(path):
|
||||
if path.exists():
|
||||
data = srsly.read_json(path)
|
||||
self.meta.update(data)
|
||||
# self.meta always overrides meta["vectors"] with the metadata
|
||||
# from self.vocab.vectors, so set the name directly
|
||||
self.vocab.vectors.name = data.get("vectors", {}).get("name")
|
||||
|
||||
def deserialize_vocab(path):
|
||||
if path.exists():
|
||||
self.vocab.from_disk(path)
|
||||
_fix_pretrained_vectors_name(self)
|
||||
|
||||
if disable is not None:
|
||||
warnings.warn(Warnings.W014, DeprecationWarning)
|
||||
exclude = disable
|
||||
path = util.ensure_path(path)
|
||||
deserializers = OrderedDict()
|
||||
deserializers["meta.json"] = lambda p: self.meta.update(srsly.read_json(p))
|
||||
deserializers["vocab"] = lambda p: self.vocab.from_disk(
|
||||
p
|
||||
) and _fix_pretrained_vectors_name(self)
|
||||
deserializers["meta.json"] = deserialize_meta
|
||||
deserializers["vocab"] = deserialize_vocab
|
||||
deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(
|
||||
p, exclude=["vocab"]
|
||||
)
|
||||
|
@ -996,14 +1007,23 @@ class Language(object):
|
|||
|
||||
DOCS: https://spacy.io/api/language#from_bytes
|
||||
"""
|
||||
def deserialize_meta(b):
|
||||
data = srsly.json_loads(b)
|
||||
self.meta.update(data)
|
||||
# self.meta always overrides meta["vectors"] with the metadata
|
||||
# from self.vocab.vectors, so set the name directly
|
||||
self.vocab.vectors.name = data.get("vectors", {}).get("name")
|
||||
|
||||
def deserialize_vocab(b):
|
||||
self.vocab.from_bytes(b)
|
||||
_fix_pretrained_vectors_name(self)
|
||||
|
||||
if disable is not None:
|
||||
warnings.warn(Warnings.W014, DeprecationWarning)
|
||||
exclude = disable
|
||||
deserializers = OrderedDict()
|
||||
deserializers["meta.json"] = lambda b: self.meta.update(srsly.json_loads(b))
|
||||
deserializers["vocab"] = lambda b: self.vocab.from_bytes(
|
||||
b
|
||||
) and _fix_pretrained_vectors_name(self)
|
||||
deserializers["meta.json"] = deserialize_meta
|
||||
deserializers["vocab"] = deserialize_vocab
|
||||
deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(
|
||||
b, exclude=["vocab"]
|
||||
)
|
||||
|
@ -1069,7 +1089,7 @@ class component(object):
|
|||
def _fix_pretrained_vectors_name(nlp):
|
||||
# TODO: Replace this once we handle vectors consistently as static
|
||||
# data
|
||||
if "vectors" in nlp.meta and nlp.meta["vectors"].get("name"):
|
||||
if "vectors" in nlp.meta and "name" in nlp.meta["vectors"]:
|
||||
nlp.vocab.vectors.name = nlp.meta["vectors"]["name"]
|
||||
elif not nlp.vocab.vectors.size:
|
||||
nlp.vocab.vectors.name = None
|
||||
|
|
Loading…
Reference in New Issue
Block a user