diff --git a/spacy/ml/models/span_predictor.py b/spacy/ml/models/span_predictor.py index d44e632bd..55a966c1d 100644 --- a/spacy/ml/models/span_predictor.py +++ b/spacy/ml/models/span_predictor.py @@ -13,7 +13,6 @@ from .coref_util import get_sentence_ids @registry.architectures("spacy.SpanPredictor.v1") def build_span_predictor( tok2vec: Model[List[Doc], List[Floats2d]], - tok2vec_size: int = 768, hidden_size: int = 1024, distance_embedding_size: int = 64, conv_channels: int = 4, @@ -24,23 +23,58 @@ def build_span_predictor( # TODO add model return types with Model.define_operators({">>": chain, "&": tuplify}): - span_predictor = PyTorchWrapper( + span_predictor = Model( + "span_predictor", + forward=span_predictor_forward, + init=span_predictor_init, + dims={"nI": nI}, + attrs={ + "distance_embedding_size": distance_embedding_size, + "hidden_size": hidden_size, + "conv_channels": conv_channels, + "window_size": window_size, + "max_distance": max_distance, + "prefix": prefix, + }, + ) + head_info = build_get_head_metadata(prefix) + model = (tok2vec & head_info) >> span_predictor + model.set_ref("span_predictor", span_predictor) + + return model + +def span_predictor_init(model: Model, X=None, Y=None): + if model.layers: + return + + if X is not None and model.has_dim("nI") is None: + model.set_dim("nI", get_width(X)) + + hidden_size = model.attrs["hidden_size"] + distance_embedding_size = model.attrs["distance_embedding_size"] + conv_channels = model.attrs["conv_channels"] + window_size = model.attrs["window_size"] + max_distance = model.attrs["max_distance"] + prefix = model.attrs["prefix"] + + model._layers = [ + PyTorchWrapper( SpanPredictor( - tok2vec_size, + model.get_dim("nI"), hidden_size, distance_embedding_size, conv_channels, window_size, max_distance, + prefix, ), convert_inputs=convert_span_predictor_inputs, ) - # TODO use proper parameter for prefix - head_info = build_get_head_metadata(prefix) - model = (tok2vec & head_info) >> span_predictor - - return model + # TODO maybe we need mixed precision and grad scaling? + ] +def span_predictor_forward(model: Model, X, is_train: bool): + return model.layers[0](X, is_train) def convert_span_predictor_inputs( model: Model, X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]], is_train: bool diff --git a/spacy/pipeline/span_predictor.py b/spacy/pipeline/span_predictor.py index d7e96a4b2..eed6ce9f8 100644 --- a/spacy/pipeline/span_predictor.py +++ b/spacy/pipeline/span_predictor.py @@ -5,6 +5,7 @@ from thinc.types import Floats2d, Floats3d, Ints2d from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy from thinc.api import set_dropout_rate, to_categorical from itertools import islice +import srsly from .trainable_pipe import TrainablePipe from ..language import Language @@ -13,7 +14,7 @@ from ..errors import Errors from ..scorer import Scorer, doc2clusters from ..tokens import Doc from ..vocab import Vocab -from ..util import registry +from ..util import registry, from_bytes, from_disk from ..ml.models.coref_util import ( MentionClusters, @@ -23,7 +24,6 @@ from ..ml.models.coref_util import ( default_span_predictor_config = """ [model] @architectures = "spacy.SpanPredictor.v1" -tok2vec_size = 768 hidden_size = 1024 distance_embedding_size = 64 conv_channels = 4 @@ -274,3 +274,51 @@ class SpanPredictor(TrainablePipe): assert len(X) > 0, Errors.E923.format(name=self.name) self.model.initialize(X=X, Y=Y) + + + def from_bytes(self, bytes_data, *, exclude=tuple()): + deserializers = { + "cfg": lambda b: self.cfg.update(srsly.json_loads(b)), + "vocab": lambda b: self.vocab.from_bytes(b, exclude=exclude), + } + from_bytes(bytes_data, deserializers, exclude) + + self._initialize_from_disk() + + model_deserializers = { + "model": lambda b: self.model.from_bytes(b), + } + from_bytes(bytes_data, model_deserializers, exclude) + + return self + + def from_disk(self, path, exclude=tuple()): + def load_model(p): + try: + with open(p, "rb") as mfile: + self.model.from_bytes(mfile.read()) + except AttributeError: + raise ValueError(Errors.E149) from None + + deserializers = { + "cfg": lambda p: self.cfg.update(srsly.read_json(p)), + "vocab": lambda p: self.vocab.from_disk(p, exclude=exclude), + } + from_disk(path, deserializers, exclude) + + self._initialize_from_disk() + + model_deserializers = { + "model": load_model, + } + from_disk(path, model_deserializers, exclude) + + return self + + def _initialize_from_disk(self): + # The PyTorch model is constructed lazily, so we need to + # explicitly initialize the model before deserialization. + model = self.model.get_ref("span_predictor") + if model.has_dim("nI") is None: + model.set_dim("nI", self.cfg["nI"]) + self.model.initialize()