diff --git a/spacy/language.py b/spacy/language.py index eba54dae2..098b83e5a 100644 --- a/spacy/language.py +++ b/spacy/language.py @@ -2222,22 +2222,12 @@ class Language: DOCS: https://spacy.io/api/language#to_disk """ - def _dump_meta(path) -> None: - """Helper function for to_disk. It's not entirely clear why this is - necessary: we're seeing numpy.float32 values in the Ukraining trf - model meta, maybe it'll happen elsewhere as well -- but it seems - to originate from the specific transformer code being used by that - model. - """ - meta = convert_recursive(lambda v: isinstance(v, numpy.floating), lambda v: float(v), dict(self.meta)) - srsly.write_json(path, meta) - path = util.ensure_path(path) serializers = {} serializers["tokenizer"] = lambda p: self.tokenizer.to_disk( # type: ignore[union-attr] p, exclude=["vocab"] ) - serializers["meta.json"] = _dump_meta + serializers["meta.json"] = lambda p: srsly.write_json(p, _replace_numpy_floats(self.meta)) serializers["config.cfg"] = lambda p: self.config.to_disk(p) for name, proc in self._components: if name in exclude: @@ -2345,20 +2335,10 @@ class Language: DOCS: https://spacy.io/api/language#to_bytes """ - def _dump_meta() -> str: - """Helper function for to_disk. It's not entirely clear why this is - necessary: we're seeing numpy.float32 values in the Ukraining trf - model meta, maybe it'll happen elsewhere as well -- but it seems - to originate from the specific transformer code being used by that - model. - """ - meta = convert_recursive(lambda v: isinstance(v, numpy.floating), lambda v: float(v), dict(self.meta)) - return srsly.json_dumps(meta) - serializers: Dict[str, Callable[[], bytes]] = {} serializers["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude) serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"]) # type: ignore[union-attr] - serializers["meta.json"] = _dump_meta # type: ignore + serializers["meta.json"] = lambda: srsly.json_dumps(_replace_numpy_floats(self.meta)) serializers["config.cfg"] = lambda: self.config.to_bytes() for name, proc in self._components: if name in exclude: @@ -2552,4 +2532,14 @@ class _WorkDoneSentinel: pass +def _replace_numpy_floats(container: dict) -> dict: + """Helper function for to_disk. It's not entirely clear why this is + necessary: we're seeing numpy.float32 values in the Ukraining trf + model meta, maybe it'll happen elsewhere as well -- but it seems + to originate from the specific transformer code being used by that + model. + """ + return convert_recursive(lambda v: isinstance(v, numpy.floating), lambda v: float(v), container) + + _WORK_DONE_SENTINEL = _WorkDoneSentinel()