mirror of
https://github.com/explosion/spaCy.git
synced 2024-11-14 21:57:15 +03:00
Merge pull request #5514 from adrianeboyd/bugfix/load-vector-name
Improve vector name loading from model meta
This commit is contained in:
commit
e7ac12b598
|
@ -934,15 +934,26 @@ class Language(object):
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/language#from_disk
|
DOCS: https://spacy.io/api/language#from_disk
|
||||||
"""
|
"""
|
||||||
|
def deserialize_meta(path):
|
||||||
|
if path.exists():
|
||||||
|
data = srsly.read_json(path)
|
||||||
|
self.meta.update(data)
|
||||||
|
# self.meta always overrides meta["vectors"] with the metadata
|
||||||
|
# from self.vocab.vectors, so set the name directly
|
||||||
|
self.vocab.vectors.name = data.get("vectors", {}).get("name")
|
||||||
|
|
||||||
|
def deserialize_vocab(path):
|
||||||
|
if path.exists():
|
||||||
|
self.vocab.from_disk(path)
|
||||||
|
_fix_pretrained_vectors_name(self)
|
||||||
|
|
||||||
if disable is not None:
|
if disable is not None:
|
||||||
warnings.warn(Warnings.W014, DeprecationWarning)
|
warnings.warn(Warnings.W014, DeprecationWarning)
|
||||||
exclude = disable
|
exclude = disable
|
||||||
path = util.ensure_path(path)
|
path = util.ensure_path(path)
|
||||||
deserializers = OrderedDict()
|
deserializers = OrderedDict()
|
||||||
deserializers["meta.json"] = lambda p: self.meta.update(srsly.read_json(p))
|
deserializers["meta.json"] = deserialize_meta
|
||||||
deserializers["vocab"] = lambda p: self.vocab.from_disk(
|
deserializers["vocab"] = deserialize_vocab
|
||||||
p
|
|
||||||
) and _fix_pretrained_vectors_name(self)
|
|
||||||
deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(
|
deserializers["tokenizer"] = lambda p: self.tokenizer.from_disk(
|
||||||
p, exclude=["vocab"]
|
p, exclude=["vocab"]
|
||||||
)
|
)
|
||||||
|
@ -996,14 +1007,23 @@ class Language(object):
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/language#from_bytes
|
DOCS: https://spacy.io/api/language#from_bytes
|
||||||
"""
|
"""
|
||||||
|
def deserialize_meta(b):
|
||||||
|
data = srsly.json_loads(b)
|
||||||
|
self.meta.update(data)
|
||||||
|
# self.meta always overrides meta["vectors"] with the metadata
|
||||||
|
# from self.vocab.vectors, so set the name directly
|
||||||
|
self.vocab.vectors.name = data.get("vectors", {}).get("name")
|
||||||
|
|
||||||
|
def deserialize_vocab(b):
|
||||||
|
self.vocab.from_bytes(b)
|
||||||
|
_fix_pretrained_vectors_name(self)
|
||||||
|
|
||||||
if disable is not None:
|
if disable is not None:
|
||||||
warnings.warn(Warnings.W014, DeprecationWarning)
|
warnings.warn(Warnings.W014, DeprecationWarning)
|
||||||
exclude = disable
|
exclude = disable
|
||||||
deserializers = OrderedDict()
|
deserializers = OrderedDict()
|
||||||
deserializers["meta.json"] = lambda b: self.meta.update(srsly.json_loads(b))
|
deserializers["meta.json"] = deserialize_meta
|
||||||
deserializers["vocab"] = lambda b: self.vocab.from_bytes(
|
deserializers["vocab"] = deserialize_vocab
|
||||||
b
|
|
||||||
) and _fix_pretrained_vectors_name(self)
|
|
||||||
deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(
|
deserializers["tokenizer"] = lambda b: self.tokenizer.from_bytes(
|
||||||
b, exclude=["vocab"]
|
b, exclude=["vocab"]
|
||||||
)
|
)
|
||||||
|
@ -1069,7 +1089,7 @@ class component(object):
|
||||||
def _fix_pretrained_vectors_name(nlp):
|
def _fix_pretrained_vectors_name(nlp):
|
||||||
# TODO: Replace this once we handle vectors consistently as static
|
# TODO: Replace this once we handle vectors consistently as static
|
||||||
# data
|
# data
|
||||||
if "vectors" in nlp.meta and nlp.meta["vectors"].get("name"):
|
if "vectors" in nlp.meta and "name" in nlp.meta["vectors"]:
|
||||||
nlp.vocab.vectors.name = nlp.meta["vectors"]["name"]
|
nlp.vocab.vectors.name = nlp.meta["vectors"]["name"]
|
||||||
elif not nlp.vocab.vectors.size:
|
elif not nlp.vocab.vectors.size:
|
||||||
nlp.vocab.vectors.name = None
|
nlp.vocab.vectors.name = None
|
||||||
|
|
Loading…
Reference in New Issue
Block a user