Do dimension inference in span predictor

Paul O'Leary McCann 2022-07-06 19:22:37 +09:00
parent b59b924e49
commit b0800ea855
2 changed files with 92 additions and 10 deletions


@@ -13,7 +13,6 @@ from .coref_util import get_sentence_ids
 @registry.architectures("spacy.SpanPredictor.v1")
 def build_span_predictor(
     tok2vec: Model[List[Doc], List[Floats2d]],
-    tok2vec_size: int = 768,
     hidden_size: int = 1024,
     distance_embedding_size: int = 64,
     conv_channels: int = 4,
@@ -24,23 +23,58 @@ def build_span_predictor(
     # TODO add model return types
     with Model.define_operators({">>": chain, "&": tuplify}):
-        span_predictor = PyTorchWrapper(
+        span_predictor = Model(
+            "span_predictor",
+            forward=span_predictor_forward,
+            init=span_predictor_init,
+            dims={"nI": nI},
+            attrs={
+                "distance_embedding_size": distance_embedding_size,
+                "hidden_size": hidden_size,
+                "conv_channels": conv_channels,
+                "window_size": window_size,
+                "max_distance": max_distance,
+                "prefix": prefix,
+            },
+        )
+        head_info = build_get_head_metadata(prefix)
+        model = (tok2vec & head_info) >> span_predictor
+        model.set_ref("span_predictor", span_predictor)
+
+    return model
+
+
+def span_predictor_init(model: Model, X=None, Y=None):
+    if model.layers:
+        return
+
+    if X is not None and model.has_dim("nI") is None:
+        model.set_dim("nI", get_width(X))
+
+    hidden_size = model.attrs["hidden_size"]
+    distance_embedding_size = model.attrs["distance_embedding_size"]
+    conv_channels = model.attrs["conv_channels"]
+    window_size = model.attrs["window_size"]
+    max_distance = model.attrs["max_distance"]
+    prefix = model.attrs["prefix"]
+
+    model._layers = [
+        PyTorchWrapper(
             SpanPredictor(
-                tok2vec_size,
+                model.get_dim("nI"),
                 hidden_size,
                 distance_embedding_size,
                 conv_channels,
                 window_size,
                 max_distance,
+                prefix,
             ),
             convert_inputs=convert_span_predictor_inputs,
         )
-        # TODO use proper parameter for prefix
-        head_info = build_get_head_metadata(prefix)
-        model = (tok2vec & head_info) >> span_predictor
-    return model
+        # TODO maybe we need mixed precision and grad scaling?
+    ]
+
+
+def span_predictor_forward(model: Model, X, is_train: bool):
+    return model.layers[0](X, is_train)
+
+
 def convert_span_predictor_inputs(
     model: Model, X: Tuple[List[Floats2d], Tuple[List[Ints1d], List[Ints1d]]], is_train: bool
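The change above defers building the PyTorch module until the wrapping thinc Model is initialized, so the input width ("nI") can be read off a sample batch instead of being passed in as tok2vec_size. A minimal, self-contained sketch of that lazy-init idiom follows; the names lazy_init/lazy_forward and the toy identity inner layer are illustrative, not part of this diff.

from thinc.api import Model, get_width
import numpy


def lazy_init(model: Model, X=None, Y=None):
    if model.layers:
        # Already initialized (mirrors the guard in span_predictor_init).
        return
    # Infer the input width from the sample batch if the config left it unset.
    if X is not None and model.has_dim("nI") is None:
        model.set_dim("nI", get_width(X))
    # Only now is the inner layer (the stand-in for the PyTorch module) built.
    inner = Model(
        "inner",
        lambda m, Xi, is_train: (Xi, lambda dY: dY),  # identity forward
        dims={"nI": model.get_dim("nI")},
    )
    model._layers = [inner]


def lazy_forward(model: Model, X, is_train: bool):
    return model.layers[0](X, is_train)


outer = Model("outer", lazy_forward, init=lazy_init, dims={"nI": None})
sample = numpy.zeros((4, 300), dtype="f")
outer.initialize(X=sample)  # width inferred here from the sample data
assert outer.get_dim("nI") == 300

With this pattern the config no longer has to state the tok2vec width; the first initialize() call fills it in.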


@@ -5,6 +5,7 @@ from thinc.types import Floats2d, Floats3d, Ints2d
 from thinc.api import Model, Config, Optimizer, CategoricalCrossentropy
 from thinc.api import set_dropout_rate, to_categorical
 from itertools import islice
+import srsly
 from .trainable_pipe import TrainablePipe
 from ..language import Language
@@ -13,7 +14,7 @@ from ..errors import Errors
 from ..scorer import Scorer, doc2clusters
 from ..tokens import Doc
 from ..vocab import Vocab
-from ..util import registry
+from ..util import registry, from_bytes, from_disk
 from ..ml.models.coref_util import (
     MentionClusters,
@@ -23,7 +24,6 @@ from ..ml.models.coref_util import (
 default_span_predictor_config = """
 [model]
 @architectures = "spacy.SpanPredictor.v1"
-tok2vec_size = 768
 hidden_size = 1024
 distance_embedding_size = 64
 conv_channels = 4
@@ -274,3 +274,51 @@ class SpanPredictor(TrainablePipe):
         assert len(X) > 0, Errors.E923.format(name=self.name)
         self.model.initialize(X=X, Y=Y)
+
+    def from_bytes(self, bytes_data, *, exclude=tuple()):
+        deserializers = {
+            "cfg": lambda b: self.cfg.update(srsly.json_loads(b)),
+            "vocab": lambda b: self.vocab.from_bytes(b, exclude=exclude),
+        }
+        from_bytes(bytes_data, deserializers, exclude)
+
+        self._initialize_from_disk()
+
+        model_deserializers = {
+            "model": lambda b: self.model.from_bytes(b),
+        }
+        from_bytes(bytes_data, model_deserializers, exclude)
+        return self
+
+    def from_disk(self, path, exclude=tuple()):
+        def load_model(p):
+            try:
+                with open(p, "rb") as mfile:
+                    self.model.from_bytes(mfile.read())
+            except AttributeError:
+                raise ValueError(Errors.E149) from None
+
+        deserializers = {
+            "cfg": lambda p: self.cfg.update(srsly.read_json(p)),
+            "vocab": lambda p: self.vocab.from_disk(p, exclude=exclude),
+        }
+        from_disk(path, deserializers, exclude)
+
+        self._initialize_from_disk()
+
+        model_deserializers = {
+            "model": load_model,
+        }
+        from_disk(path, model_deserializers, exclude)
+        return self
+
+    def _initialize_from_disk(self):
+        # The PyTorch model is constructed lazily, so we need to
+        # explicitly initialize the model before deserialization.
+        model = self.model.get_ref("span_predictor")
+        if model.has_dim("nI") is None:
+            model.set_dim("nI", self.cfg["nI"])
+        self.model.initialize()
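The from_bytes/from_disk overrides above enforce a two-phase load: because the wrapped PyTorch layer is only built at initialization time, there is nothing to deserialize weights into until the cfg (which records "nI") has been read and the model initialized. A rough sketch of that ordering, using a plain thinc Linear layer in place of the wrapped span predictor; the cfg handling here is illustrative, not the pipe's actual serialization code.

import srsly
from thinc.api import Linear

# Saving side: record the inferred width alongside the serialized weights.
trained = Linear(nO=2, nI=300)
trained.initialize()
cfg_bytes = srsly.json_dumps({"nI": trained.get_dim("nI")}).encode("utf8")
weight_bytes = trained.to_bytes()

# Loading side: mirror _initialize_from_disk(). Restore the stored width,
# initialize so the parameters exist, and only then load the weights.
loaded = Linear(nO=2)
cfg = srsly.json_loads(cfg_bytes)
if loaded.has_dim("nI") is None:
    loaded.set_dim("nI", cfg["nI"])
loaded.initialize()
loaded.from_bytes(weight_bytes)

In the pipe itself this ordering is driven by the "span_predictor" ref set in build_span_predictor, which is what _initialize_from_disk() looks up before calling self.model.initialize().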