Fix numpy floating values in meta.json for serialization

This commit is contained in:
Matthew Honnibal 2024-09-14 13:13:05 +02:00
parent a9ed8bb401
commit 57cbac78f4

View File

@ -33,6 +33,7 @@ from typing import (
import srsly import srsly
from cymem.cymem import Pool from cymem.cymem import Pool
from thinc.api import Config, CupyOps, Optimizer, get_current_ops from thinc.api import Config, CupyOps, Optimizer, get_current_ops
from thinc.util import convert_recursive
from . import about, ty, util from . import about, ty, util
from .compat import Literal from .compat import Literal
@ -2141,7 +2142,7 @@ class Language:
serializers["tokenizer"] = lambda p: self.tokenizer.to_disk( # type: ignore[union-attr] serializers["tokenizer"] = lambda p: self.tokenizer.to_disk( # type: ignore[union-attr]
p, exclude=["vocab"] p, exclude=["vocab"]
) )
serializers["meta.json"] = lambda p: srsly.write_json(p, self.meta) serializers["meta.json"] = lambda p: srsly.write_json(p, _replace_numpy_floats(self.meta))
serializers["config.cfg"] = lambda p: self.config.to_disk(p) serializers["config.cfg"] = lambda p: self.config.to_disk(p)
for name, proc in self._components: for name, proc in self._components:
if name in exclude: if name in exclude:
@ -2255,7 +2256,7 @@ class Language:
serializers: Dict[str, Callable[[], bytes]] = {} serializers: Dict[str, Callable[[], bytes]] = {}
serializers["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude) serializers["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"]) # type: ignore[union-attr] serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"]) # type: ignore[union-attr]
serializers["meta.json"] = lambda: srsly.json_dumps(self.meta) serializers["meta.json"] = lambda: srsly.json_dumps(_replace_numpy_floats(self.meta))
serializers["config.cfg"] = lambda: self.config.to_bytes() serializers["config.cfg"] = lambda: self.config.to_bytes()
for name, proc in self._components: for name, proc in self._components:
if name in exclude: if name in exclude:
@ -2306,6 +2307,10 @@ class Language:
return self return self
def _replace_numpy_floats(meta_dict: dict) -> dict:
return convert_recursive(lambda v: isinstance(v, numpy.floating), lambda v: float(v), dict(meta_dict))
@dataclass @dataclass
class FactoryMeta: class FactoryMeta:
"""Dataclass containing information about a component and its defaults """Dataclass containing information about a component and its defaults