From acb44f8e73ef04b4b019637d5e72e6ad92508e73 Mon Sep 17 00:00:00 2001 From: Matthew Honnibal Date: Wed, 2 Oct 2024 01:10:04 +0200 Subject: [PATCH] Fix meta writing for numpy conversion --- spacy/language.py | 34 ++++++++++++---------------------- 1 file changed, 12 insertions(+), 22 deletions(-) diff --git a/spacy/language.py b/spacy/language.py index eba54dae2..098b83e5a 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -2222,22 +2222,12 @@ class Language: DOCS: https://spacy.io/api/language#to_disk """ - def _dump_meta(path) -> None: - """Helper function for to_disk. It's not entirely clear why this is - necessary: we're seeing numpy.float32 values in the Ukraining trf - model meta, maybe it'll happen elsewhere as well -- but it seems - to originate from the specific transformer code being used by that - model. - """ - meta = convert_recursive(lambda v: isinstance(v, numpy.floating), lambda v: float(v), dict(self.meta)) - srsly.write_json(path, meta) - path = util.ensure_path(path) serializers = {} serializers["tokenizer"] = lambda p: self.tokenizer.to_disk( # type: ignore[union-attr] p, exclude=["vocab"] ) - serializers["meta.json"] = _dump_meta + serializers["meta.json"] = lambda p: srsly.write_json(p, _replace_numpy_floats(self.meta)) serializers["config.cfg"] = lambda p: self.config.to_disk(p) for name, proc in self._components: if name in exclude: @@ -2345,20 +2335,10 @@ class Language: DOCS: https://spacy.io/api/language#to_bytes """ - def _dump_meta() -> str: - """Helper function for to_disk. It's not entirely clear why this is - necessary: we're seeing numpy.float32 values in the Ukraining trf - model meta, maybe it'll happen elsewhere as well -- but it seems - to originate from the specific transformer code being used by that - model. - """ - meta = convert_recursive(lambda v: isinstance(v, numpy.floating), lambda v: float(v), dict(self.meta)) - return srsly.json_dumps(meta) - serializers: Dict[str, Callable[[], bytes]] = {} serializers["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude) serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"]) # type: ignore[union-attr] - serializers["meta.json"] = _dump_meta # type: ignore + serializers["meta.json"] = lambda: srsly.json_dumps(_replace_numpy_floats(self.meta)) serializers["config.cfg"] = lambda: self.config.to_bytes() for name, proc in self._components: if name in exclude: @@ -2552,4 +2532,14 @@ class _WorkDoneSentinel: pass +def _replace_numpy_floats(container: dict) -> dict: + """Helper function for to_disk. It's not entirely clear why this is + necessary: we're seeing numpy.float32 values in the Ukraining trf + model meta, maybe it'll happen elsewhere as well -- but it seems + to originate from the specific transformer code being used by that + model. + """ + return convert_recursive(lambda v: isinstance(v, numpy.floating), lambda v: float(v), container) + + _WORK_DONE_SENTINEL = _WorkDoneSentinel()