mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Fix numpy floats in meta.json
This commit is contained in:
parent
2f1e7ed09a
commit
0576a1ff56
|
@ -9,6 +9,7 @@ from contextlib import ExitStack, contextmanager
|
|||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from itertools import chain, cycle
|
||||
import numpy
|
||||
from pathlib import Path
|
||||
from timeit import default_timer as timer
|
||||
from typing import (
|
||||
|
@ -33,6 +34,7 @@ from typing import (
|
|||
import srsly
|
||||
from cymem.cymem import Pool
|
||||
from thinc.api import Config, CupyOps, Optimizer, get_current_ops
|
||||
from thinc.util import convert_recursive
|
||||
|
||||
from . import about, ty, util
|
||||
from .compat import Literal
|
||||
|
@ -2141,7 +2143,7 @@ class Language:
|
|||
serializers["tokenizer"] = lambda p: self.tokenizer.to_disk( # type: ignore[union-attr]
|
||||
p, exclude=["vocab"]
|
||||
)
|
||||
serializers["meta.json"] = lambda p: srsly.write_json(p, self.meta)
|
||||
serializers["meta.json"] = lambda p: srsly.write_json(p, _replace_numpy_floats(self.meta))
|
||||
serializers["config.cfg"] = lambda p: self.config.to_disk(p)
|
||||
for name, proc in self._components:
|
||||
if name in exclude:
|
||||
|
@ -2255,7 +2257,7 @@ class Language:
|
|||
serializers: Dict[str, Callable[[], bytes]] = {}
|
||||
serializers["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
|
||||
serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"]) # type: ignore[union-attr]
|
||||
serializers["meta.json"] = lambda: srsly.json_dumps(self.meta)
|
||||
serializers["meta.json"] = lambda: srsly.json_dumps(_replace_numpy_floats(self.meta))
|
||||
serializers["config.cfg"] = lambda: self.config.to_bytes()
|
||||
for name, proc in self._components:
|
||||
if name in exclude:
|
||||
|
@ -2306,6 +2308,10 @@ class Language:
|
|||
return self
|
||||
|
||||
|
||||
def _replace_numpy_floats(meta_dict: dict) -> dict:
|
||||
return convert_recursive(lambda v: isinstance(v, numpy.floaty), lambda v: float(v), dict(meta_dict))
|
||||
|
||||
|
||||
@dataclass
|
||||
class FactoryMeta:
|
||||
"""Dataclass containing information about a component and its defaults
|
||||
|
|
Loading…
Reference in New Issue
Block a user