mirror of https://github.com/explosion/spaCy.git
synced 2025-07-18 04:02:20 +03:00
It works!
Was missing the serialization-related code from biaffine.
This commit is contained in:
parent ba1bf8ae72
commit bd17c38b74
@@ -44,8 +44,9 @@ def build_wl_coref_model(
             },
         )
 
-        coref_model = tok2vec >> coref_clusterer
-        return coref_model
+        model = tok2vec >> coref_clusterer
+        model.set_ref("coref_clusterer", coref_clusterer)
+        return model
 
 
 def coref_init(model: Model, X=None, Y=None):
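The point of the new set_ref call is that the composed model keeps a named handle on the clusterer sublayer, which the pipeline component can later fetch with get_ref (the serialization code further down depends on this). Below is a minimal sketch of the thinc pattern, not the spaCy code itself; the layers and names are stand-ins.

# Minimal sketch of thinc's set_ref/get_ref pattern (stand-in layers, not
# the real tok2vec or coref clusterer).
from thinc.api import Linear, chain

encoder = Linear(nO=4, nI=8)    # stand-in for the tok2vec layer
clusterer = Linear(nO=2, nI=4)  # stand-in for the coref clusterer

# chain() (the ">>" operator above) wraps both layers in a new Model whose
# sublayers are only reachable by position; set_ref attaches a stable name.
model = chain(encoder, clusterer)
model.set_ref("clusterer", clusterer)

assert model.get_ref("clusterer") is clusterer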
@@ -6,6 +6,7 @@ from thinc.api import Model, Config, Optimizer
 from thinc.api import set_dropout_rate, to_categorical
 from itertools import islice
 from statistics import mean
+import srsly
 
 from .trainable_pipe import TrainablePipe
 from ..language import Language
@@ -13,7 +14,7 @@ from ..training import Example, validate_examples, validate_get_examples
 from ..errors import Errors
 from ..tokens import Doc
 from ..vocab import Vocab
-from ..util import registry
+from ..util import registry, from_disk, from_bytes
 
 from ..ml.models.coref_util import (
     create_gold_scores,
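The newly imported from_disk and from_bytes are spaCy's generic serialization helpers: each takes a mapping from section name to a deserializer callback and skips anything listed in exclude, which is exactly how the methods added below call them. The following is a simplified sketch of that dispatch convention, an illustration rather than the actual spacy.util source.

# Illustration of the callback-dispatch convention used below; the helper
# names here are hypothetical, not the real spacy.util functions.
from pathlib import Path
from typing import Callable, Dict, Iterable

import srsly


def apply_byte_setters(bytes_data: bytes, setters: Dict[str, Callable], exclude: Iterable[str]) -> None:
    # Assumes the payload is a msgpack-encoded dict keyed by section name,
    # as produced by the matching to_bytes() side.
    msg = srsly.msgpack_loads(bytes_data)
    for key, setter in setters.items():
        if key not in exclude and key in msg:
            setter(msg[key])


def apply_path_readers(path: Path, readers: Dict[str, Callable], exclude: Iterable[str]) -> None:
    # Each reader is handed the sub-path for its section, e.g. path / "cfg".
    for key, reader in readers.items():
        if key not in exclude:
            reader(path / key)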
@@ -316,3 +317,57 @@ class CoreferenceResolver(TrainablePipe):
 
         assert len(X) > 0, Errors.E923.format(name=self.name)
         self.model.initialize(X=X, Y=Y)
+
+        # Store the input dimensionality. nI and nO are not stored explicitly
+        # for PyTorch models. This makes it tricky to reconstruct the model
+        # during deserialization. So, besides storing the labels, we also
+        # store the number of inputs.
+        coref_clusterer = self.model.get_ref("coref_clusterer")
+        self.cfg["nI"] = coref_clusterer.get_dim("nI")
+
+    def from_bytes(self, bytes_data, *, exclude=tuple()):
+        deserializers = {
+            "cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
+            "vocab": lambda b: self.vocab.from_bytes(b, exclude=exclude),
+        }
+        from_bytes(bytes_data, deserializers, exclude)
+
+        self._initialize_from_disk()
+
+        model_deserializers = {
+            "model": lambda b: self.model.from_bytes(b),
+        }
+        from_bytes(bytes_data, model_deserializers, exclude)
+
+        return self
+
+    def from_disk(self, path, exclude=tuple()):
+        def load_model(p):
+            try:
+                with open(p, "rb") as mfile:
+                    self.model.from_bytes(mfile.read())
+            except AttributeError:
+                raise ValueError(Errors.E149) from None
+
+        deserializers = {
+            "cfg": lambda p: self.cfg.update(srsly.read_json(p)),
+            "vocab": lambda p: self.vocab.from_disk(p, exclude=exclude),
+        }
+        from_disk(path, deserializers, exclude)
+
+        self._initialize_from_disk()
+
+        model_deserializers = {
+            "model": load_model,
+        }
+        from_disk(path, model_deserializers, exclude)
+
+        return self
+
+    def _initialize_from_disk(self):
+        # The PyTorch model is constructed lazily, so we need to
+        # explicitly initialize the model before deserialization.
+        model = self.model.get_ref("coref_clusterer")
+        if model.has_dim("nI") is None:
+            model.set_dim("nI", self.cfg["nI"])
+        self.model.initialize()
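The _initialize_from_disk step is the heart of the fix: because the wrapped PyTorch clusterer is built lazily, its weights have no shape until the model is initialized, so the input dimension saved in cfg["nI"] has to be restored and initialize() called before the serialized weights can be loaded. Here is a small runnable sketch of that order of operations, using a plain thinc Linear as a stand-in for the lazily constructed clusterer.

# Sketch of the "set dims, initialize, then load weights" order that the new
# _initialize_from_disk() enforces; Linear and cfg are stand-ins, not the
# actual pipe internals.
from thinc.api import Linear

cfg = {"nI": 64}  # pretend this was read back by the "cfg" deserializer

model = Linear(nO=2)                # nI left unset, as for a freshly built pipe
assert model.has_dim("nI") is None  # dim is registered but has no value yet

if model.has_dim("nI") is None:
    model.set_dim("nI", cfg["nI"])
model.initialize()                  # weights now have a concrete shape,
                                    # so model.from_bytes(...) can fill them in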