Fix meta writing for numpy conversion

This commit is contained in:
Matthew Honnibal 2024-10-02 01:10:04 +02:00
parent 75d097155d
commit acb44f8e73

View File

@ -2222,22 +2222,12 @@ class Language:
DOCS: https://spacy.io/api/language#to_disk DOCS: https://spacy.io/api/language#to_disk
""" """
def _dump_meta(path) -> None:
    """Write the pipeline meta to ``path`` as JSON.

    Before writing, values that are instances of ``numpy.floating`` are
    passed through ``float()``. numpy.float32 values have been observed in
    the meta of the Ukrainian trf model; the cause is not fully understood,
    but it appears to come from the specific transformer code that model
    uses, and it may show up elsewhere too.

    path: Location handed to ``srsly.write_json``.
    """
    def is_numpy_float(value) -> bool:
        return isinstance(value, numpy.floating)

    def as_builtin_float(value) -> float:
        return float(value)

    # The conversion is run on a shallow copy of the meta dict.
    meta_copy = dict(self.meta)
    json_ready = convert_recursive(is_numpy_float, as_builtin_float, meta_copy)
    srsly.write_json(path, json_ready)
path = util.ensure_path(path) path = util.ensure_path(path)
serializers = {} serializers = {}
serializers["tokenizer"] = lambda p: self.tokenizer.to_disk( # type: ignore[union-attr] serializers["tokenizer"] = lambda p: self.tokenizer.to_disk( # type: ignore[union-attr]
p, exclude=["vocab"] p, exclude=["vocab"]
) )
serializers["meta.json"] = _dump_meta serializers["meta.json"] = lambda p: srsly.write_json(p, _replace_numpy_floats(self.meta))
serializers["config.cfg"] = lambda p: self.config.to_disk(p) serializers["config.cfg"] = lambda p: self.config.to_disk(p)
for name, proc in self._components: for name, proc in self._components:
if name in exclude: if name in exclude:
@ -2345,20 +2335,10 @@ class Language:
DOCS: https://spacy.io/api/language#to_bytes DOCS: https://spacy.io/api/language#to_bytes
""" """
def _dump_meta() -> str:
    """Return the pipeline meta dumped with ``srsly.json_dumps``.

    Before dumping, values that are instances of ``numpy.floating`` are
    passed through ``float()``. numpy.float32 values have been observed in
    the meta of the Ukrainian trf model; the cause is not fully understood,
    but it appears to come from the specific transformer code that model
    uses, and it may show up elsewhere too.
    """
    def is_numpy_float(value) -> bool:
        return isinstance(value, numpy.floating)

    def as_builtin_float(value) -> float:
        return float(value)

    # The conversion is run on a shallow copy of the meta dict.
    meta_copy = dict(self.meta)
    json_ready = convert_recursive(is_numpy_float, as_builtin_float, meta_copy)
    return srsly.json_dumps(json_ready)
serializers: Dict[str, Callable[[], bytes]] = {} serializers: Dict[str, Callable[[], bytes]] = {}
serializers["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude) serializers["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"]) # type: ignore[union-attr] serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"]) # type: ignore[union-attr]
serializers["meta.json"] = _dump_meta # type: ignore serializers["meta.json"] = lambda: srsly.json_dumps(_replace_numpy_floats(self.meta))
serializers["config.cfg"] = lambda: self.config.to_bytes() serializers["config.cfg"] = lambda: self.config.to_bytes()
for name, proc in self._components: for name, proc in self._components:
if name in exclude: if name in exclude:
@ -2552,4 +2532,14 @@ class _WorkDoneSentinel:
pass pass
def _replace_numpy_floats(container: dict) -> dict:
    """Convert numpy floating-point scalars in a meta dict to built-in floats.

    Used when serializing the pipeline meta, both to disk and to bytes.
    numpy.float32 values have been observed in the meta of the Ukrainian trf
    model; the cause is not fully understood, but it appears to come from the
    specific transformer code that model uses, and it may show up elsewhere
    too.

    container: The dict to convert (in practice, the pipeline meta).
    RETURNS: The result of running ``convert_recursive`` over a shallow copy
        of ``container``, with ``float()`` applied to every value that is an
        instance of ``numpy.floating``.
    """
    # Hand convert_recursive a shallow copy so the caller's own top-level dict
    # is never the object being converted.
    # NOTE(review): nested containers are still shared with the original;
    # confirm whether convert_recursive rebuilds or mutates them.
    return convert_recursive(
        lambda v: isinstance(v, numpy.floating), float, dict(container)
    )
_WORK_DONE_SENTINEL = _WorkDoneSentinel() _WORK_DONE_SENTINEL = _WorkDoneSentinel()