diff --git a/spacy/ml/models/coref.py b/spacy/ml/models/coref.py index 4967e7f23..377f02236 100644 --- a/spacy/ml/models/coref.py +++ b/spacy/ml/models/coref.py @@ -44,8 +44,9 @@ def build_wl_coref_model( }, ) - coref_model = tok2vec >> coref_clusterer - return coref_model + model = tok2vec >> coref_clusterer + model.set_ref("coref_clusterer", coref_clusterer) + return model def coref_init(model: Model, X=None, Y=None): diff --git a/spacy/pipeline/coref.py b/spacy/pipeline/coref.py index cd07f80e8..ef74a83a4 100644 --- a/spacy/pipeline/coref.py +++ b/spacy/pipeline/coref.py @@ -6,6 +6,7 @@ from thinc.api import Model, Config, Optimizer from thinc.api import set_dropout_rate, to_categorical from itertools import islice from statistics import mean +import srsly from .trainable_pipe import TrainablePipe from ..language import Language @@ -13,7 +14,7 @@ from ..training import Example, validate_examples, validate_get_examples from ..errors import Errors from ..tokens import Doc from ..vocab import Vocab -from ..util import registry +from ..util import registry, from_disk, from_bytes from ..ml.models.coref_util import ( create_gold_scores, @@ -316,3 +317,57 @@ class CoreferenceResolver(TrainablePipe): assert len(X) > 0, Errors.E923.format(name=self.name) self.model.initialize(X=X, Y=Y) + + # Store the input dimensionality. nI and nO are not stored explicitly + # for PyTorch models. This makes it tricky to reconstruct the model + # during deserialization. So, besides storing the labels, we also + # store the number of inputs. + coref_clusterer = self.model.get_ref("coref_clusterer") + self.cfg["nI"] = coref_clusterer.get_dim("nI") + + def from_bytes(self, bytes_data, *, exclude=tuple()): + deserializers = { + "cfg": lambda b: self.cfg.update(srsly.json_loads(b)), + "vocab": lambda b: self.vocab.from_bytes(b, exclude=exclude), + } + from_bytes(bytes_data, deserializers, exclude) + + self._initialize_from_disk() + + model_deserializers = { + "model": lambda b: self.model.from_bytes(b), + } + from_bytes(bytes_data, model_deserializers, exclude) + + return self + + def from_disk(self, path, exclude=tuple()): + def load_model(p): + try: + with open(p, "rb") as mfile: + self.model.from_bytes(mfile.read()) + except AttributeError: + raise ValueError(Errors.E149) from None + + deserializers = { + "cfg": lambda p: self.cfg.update(srsly.read_json(p)), + "vocab": lambda p: self.vocab.from_disk(p, exclude=exclude), + } + from_disk(path, deserializers, exclude) + + self._initialize_from_disk() + + model_deserializers = { + "model": load_model, + } + from_disk(path, model_deserializers, exclude) + + return self + + def _initialize_from_disk(self): + # The PyTorch model is constructed lazily, so we need to + # explicitly initialize the model before deserialization. + model = self.model.get_ref("coref_clusterer") + if model.has_dim("nI") is None: + model.set_dim("nI", self.cfg["nI"]) + self.model.initialize()