Fix serialization for uk trf model

This commit is contained in:
Matthew Honnibal 2024-09-06 22:23:25 +02:00
parent 2a37f97365
commit 3bc5846e83

View File

@ -31,8 +31,10 @@ from typing import (
overload,
)
import numpy
import srsly
from thinc.api import Config, CupyOps, Optimizer, get_current_ops
from thinc.util import convert_recursive
from . import about, ty, util
from .errors import Errors, Warnings
@ -2221,12 +2223,22 @@ class Language:
DOCS: https://spacy.io/api/language#to_disk
"""
def _dump_meta(path, meta) -> None:
"""Helper function for to_disk. It's not entirely clear why this is
necessary: we're seeing numpy.float32 values in the Ukraining trf
model meta, maybe it'll happen elsewhere as well -- but it seems
to originate from the specific transformer code being used by that
model.
"""
meta = convert_recursive(lambda v: isinstance(v, numpy.float32), lambda v: float(v), dict(meta))
srsly.write_json(path, meta)
path = util.ensure_path(path)
serializers = {}
serializers["tokenizer"] = lambda p: self.tokenizer.to_disk( # type: ignore[union-attr]
p, exclude=["vocab"]
)
serializers["meta.json"] = lambda p: srsly.write_json(p, self.meta)
serializers["meta.json"] = _dump_meta
serializers["config.cfg"] = lambda p: self.config.to_disk(p)
for name, proc in self._components:
if name in exclude:
@ -2334,10 +2346,20 @@ class Language:
DOCS: https://spacy.io/api/language#to_bytes
"""
def _dump_meta() -> str:
"""Helper function for to_disk. It's not entirely clear why this is
necessary: we're seeing numpy.float32 values in the Ukraining trf
model meta, maybe it'll happen elsewhere as well -- but it seems
to originate from the specific transformer code being used by that
model.
"""
meta = convert_recursive(lambda v: isinstance(v, numpy.float32), lambda v: float(v), dict(self.meta))
return srsly.json_dumps(meta)
serializers: Dict[str, Callable[[], bytes]] = {}
serializers["vocab"] = lambda: self.vocab.to_bytes(exclude=exclude)
serializers["tokenizer"] = lambda: self.tokenizer.to_bytes(exclude=["vocab"]) # type: ignore[union-attr]
serializers["meta.json"] = lambda: srsly.json_dumps(self.meta)
serializers["meta.json"] = _dump_meta
serializers["config.cfg"] = lambda: self.config.to_bytes()
for name, proc in self._components:
if name in exclude:
@ -2427,7 +2449,6 @@ class DisabledPipes(list):
self.nlp.enable_pipe(name)
self[:] = []
def _copy_examples(
examples: Iterable[Example], *, copy_x: bool = True, copy_y: bool = False
) -> List[Example]: