It works!

Was missing the serialization-related code from biaffine.
Paul O'Leary McCann 2022-07-06 18:58:22 +09:00
parent ba1bf8ae72
commit bd17c38b74
2 changed files with 59 additions and 3 deletions

View File

@@ -44,8 +44,9 @@ def build_wl_coref_model(
            },
        )
-        coref_model = tok2vec >> coref_clusterer
-        return coref_model
+        model = tok2vec >> coref_clusterer
+        model.set_ref("coref_clusterer", coref_clusterer)
+        return model


def coref_init(model: Model, X=None, Y=None):
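Note: the set_ref call is what the serialization changes in the next file rely on; the pipe can pull the clusterer sublayer back out of the chained model by name and read or set its nI dim. A minimal sketch of the ref mechanics, with plain Linear layers standing in for the real tok2vec and clusterer (names and sizes here are illustrative only):

    from thinc.api import Linear, chain

    # Stand-ins for the tok2vec and coref_clusterer layers.
    tok2vec = Linear(nO=64, nI=64)
    clusterer = Linear(nO=2, nI=64)

    model = chain(tok2vec, clusterer)            # equivalent to tok2vec >> clusterer
    model.set_ref("coref_clusterer", clusterer)  # register the sublayer under a name

    # Later, e.g. in the pipe, the sublayer is retrieved by name:
    assert model.get_ref("coref_clusterer") is clusterer
    assert model.get_ref("coref_clusterer").get_dim("nI") == 64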

View File

@@ -6,6 +6,7 @@ from thinc.api import Model, Config, Optimizer
from thinc.api import set_dropout_rate, to_categorical
from itertools import islice
from statistics import mean
+import srsly

from .trainable_pipe import TrainablePipe
from ..language import Language
@@ -13,7 +14,7 @@ from ..training import Example, validate_examples, validate_get_examples
from ..errors import Errors
from ..tokens import Doc
from ..vocab import Vocab
-from ..util import registry
+from ..util import registry, from_disk, from_bytes

from ..ml.models.coref_util import (
    create_gold_scores,
@@ -316,3 +317,57 @@ class CoreferenceResolver(TrainablePipe):
        assert len(X) > 0, Errors.E923.format(name=self.name)
        self.model.initialize(X=X, Y=Y)
+
+        # Store the input dimensionality. nI and nO are not stored explicitly
+        # for PyTorch models. This makes it tricky to reconstruct the model
+        # during deserialization. So, besides storing the labels, we also
+        # store the number of inputs.
+        coref_clusterer = self.model.get_ref("coref_clusterer")
+        self.cfg["nI"] = coref_clusterer.get_dim("nI")
+
+    def from_bytes(self, bytes_data, *, exclude=tuple()):
+        deserializers = {
+            "cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
+            "vocab": lambda b: self.vocab.from_bytes(b, exclude=exclude),
+        }
+        from_bytes(bytes_data, deserializers, exclude)
+
+        self._initialize_from_disk()
+
+        model_deserializers = {
+            "model": lambda b: self.model.from_bytes(b),
+        }
+        from_bytes(bytes_data, model_deserializers, exclude)
+
+        return self
+
+    def from_disk(self, path, exclude=tuple()):
+        def load_model(p):
+            try:
+                with open(p, "rb") as mfile:
+                    self.model.from_bytes(mfile.read())
+            except AttributeError:
+                raise ValueError(Errors.E149) from None
+
+        deserializers = {
+            "cfg": lambda p: self.cfg.update(srsly.read_json(p)),
+            "vocab": lambda p: self.vocab.from_disk(p, exclude=exclude),
+        }
+        from_disk(path, deserializers, exclude)
+
+        self._initialize_from_disk()
+
+        model_deserializers = {
+            "model": load_model,
+        }
+        from_disk(path, model_deserializers, exclude)
+
+        return self
+
+    def _initialize_from_disk(self):
+        # The PyTorch model is constructed lazily, so we need to
+        # explicitly initialize the model before deserialization.
+        model = self.model.get_ref("coref_clusterer")
+        if model.has_dim("nI") is None:
+            model.set_dim("nI", self.cfg["nI"])
+        self.model.initialize()
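The new from_bytes/from_disk methods lean on spaCy's util.from_bytes and util.from_disk helpers, which walk a dict of named deserializer callbacks. Splitting the callbacks into two dicts is what lets _initialize_from_disk run after the cfg (and the stored nI) has been read, but before the lazily built PyTorch layer's weights are loaded. A minimal, self-contained sketch of that helper pattern with a toy cfg dict (not the actual component):

    import srsly
    from spacy.util import to_bytes, from_bytes

    cfg = {"nI": 64}
    serializers = {"cfg": lambda: srsly.json_dumps(cfg)}
    blob = to_bytes(serializers, [])

    # First pass: read the cfg so the model can be rebuilt with the right dims...
    restored_cfg = {}
    deserializers = {"cfg": lambda b: restored_cfg.update(srsly.json_loads(b))}
    from_bytes(blob, deserializers, [])
    assert restored_cfg == {"nI": 64}
    # ...then a second from_bytes pass would load the model weights, exactly as
    # the component does on either side of its _initialize_from_disk call.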