diff --git a/spacy/language.py b/spacy/language.py
index 204508dc4..b0e85446d 100644
--- a/spacy/language.py
+++ b/spacy/language.py
@@ -31,8 +31,10 @@ from typing import (
     overload,
 )
 
+import numpy
 import srsly
 from thinc.api import Config, CupyOps, Optimizer, get_current_ops
+from thinc.util import convert_recursive
 
 from . import about, ty, util
 from .errors import Errors, Warnings
@@ -2221,12 +2223,29 @@ class Language:
 
         DOCS: https://spacy.io/api/language#to_disk
         """
+        def _dump_meta(path) -> None:
+            """Helper function for to_disk: write self.meta to `path` as JSON.
+
+            It's not entirely clear why the conversion is necessary: we're
+            seeing numpy.float32 values in the Ukrainian trf model meta, maybe
+            it'll happen elsewhere as well -- but it seems to originate from
+            the specific transformer code being used by that model. Any numpy
+            float is converted to a Python float before writing.
+
+            The serializer callbacks take the target path as their only
+            argument, so the meta is read from self rather than passed in.
+            """
+            meta = convert_recursive(
+                lambda v: isinstance(v, numpy.floating), float, dict(self.meta)
+            )
+            srsly.write_json(path, meta)
+
         path = util.ensure_path(path)
         serializers = {}
         serializers["tokenizer"] = lambda p: self.tokenizer.to_disk(  # type: ignore[union-attr]
             p, exclude=["vocab"]
         )
-        serializers["meta.json"] = lambda p: srsly.write_json(p, self.meta)
+        serializers["meta.json"] = _dump_meta
         serializers["config.cfg"] = lambda p: self.config.to_disk(p)
         for name, proc in self._components:
             if name in exclude:
@@ -2334,10 +2353,24 @@ class Language:
 
         DOCS: https://spacy.io/api/language#to_bytes
         """
+        def _dump_meta() -> str:
+            """Helper function for to_bytes: return self.meta as a JSON string.
+
+            It's not entirely clear why the conversion is necessary: we're
+            seeing numpy.float32 values in the Ukrainian trf model meta, maybe
+            it'll happen elsewhere as well -- but it seems to originate from
+            the specific transformer code being used by that model. Any numpy
+            float is converted to a Python float before dumping.
+            """
+            meta = convert_recursive(
+                lambda v: isinstance(v, numpy.floating), float, dict(self.meta)
+            )
+            return srsly.json_dumps(meta)
+
         serializers: Dict[str, Callable[[], bytes]] = {}
         serializers["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
         serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"])  # type: ignore[union-attr]
-        serializers["meta.json"] = lambda: srsly.json_dumps(self.meta)
+        serializers["meta.json"] = _dump_meta
         serializers["config.cfg"] = lambda: self.config.to_bytes()
         for name, proc in self._components:
             if name in exclude: