First take at dimension inference

This follows the pattern used in the Biaffine Parser, which uses an init
function to get the size only after the tok2vec is available.

This works at first, but serialization fails with an error.
This commit is contained in:
Paul O'Leary McCann 2022-07-06 18:40:05 +09:00
parent c59aeeb0ae
commit ba1bf8ae72

View File

@ -1,6 +1,6 @@
from typing import List, Tuple from typing import List, Tuple
from thinc.api import Model, chain from thinc.api import Model, chain, get_width
from thinc.api import PyTorchWrapper, ArgsKwargs from thinc.api import PyTorchWrapper, ArgsKwargs
from thinc.types import Floats2d from thinc.types import Floats2d
from thinc.util import torch, xp2torch, torch2xp from thinc.util import torch, xp2torch, torch2xp
@ -25,12 +25,48 @@ def build_wl_coref_model(
tok2vec_size: int = 768, # tok2vec size tok2vec_size: int = 768, # tok2vec size
): ):
# TODO add model return types # TODO add model return types
# dim = tok2vec.maybe_get_dim("n0")
nI = None
with Model.define_operators({">>": chain}): with Model.define_operators({">>": chain}):
coref_clusterer = PyTorchWrapper( coref_clusterer = Model(
"coref_clusterer",
forward=coref_forward,
init=coref_init,
dims={"nI": nI},
attrs={
"distance_embedding_size": distance_embedding_size,
"hidden_size": hidden_size,
"depth": depth,
"dropout": dropout,
"antecedent_limit": antecedent_limit,
"antecedent_batch_size": antecedent_batch_size,
},
)
coref_model = tok2vec >> coref_clusterer
return coref_model
def coref_init(model: Model, X=None, Y=None):
if model.layers:
return
if X is not None and model.has_dim("nI") is None:
model.set_dim("nI", get_width(X))
hidden_size = model.attrs["hidden_size"]
depth = model.attrs["depth"]
dropout = model.attrs["dropout"]
antecedent_limit = model.attrs["antecedent_limit"]
antecedent_batch_size = model.attrs["antecedent_batch_size"]
distance_embedding_size = model.attrs["distance_embedding_size"]
PyTorchWrapper = registry.get("layers", "PyTorchWrapper.v2")
model._layers = [
PyTorchWrapper(
CorefClusterer( CorefClusterer(
tok2vec_size, model.get_dim("nI"),
distance_embedding_size, distance_embedding_size,
hidden_size, hidden_size,
depth, depth,
@ -41,13 +77,15 @@ def build_wl_coref_model(
convert_inputs=convert_coref_clusterer_inputs, convert_inputs=convert_coref_clusterer_inputs,
convert_outputs=convert_coref_clusterer_outputs, convert_outputs=convert_coref_clusterer_outputs,
) )
coref_model = tok2vec >> coref_clusterer # TODO maybe we need mixed precision and grad scaling?
return coref_model ]
def convert_coref_clusterer_inputs( def coref_forward(model: Model, X, is_train: bool):
model: Model, X: List[Floats2d], is_train: bool return model.layers[0](X, is_train)
):
def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
# The input here is List[Floats2d], one for each doc # The input here is List[Floats2d], one for each doc
# just use the first # just use the first
# TODO real batching # TODO real batching
@ -63,9 +101,7 @@ def convert_coref_clusterer_inputs(
return ArgsKwargs(args=(word_features,), kwargs={}), backprop return ArgsKwargs(args=(word_features,), kwargs={}), backprop
def convert_coref_clusterer_outputs( def convert_coref_clusterer_outputs(model: Model, inputs_outputs, is_train: bool):
model: Model, inputs_outputs, is_train: bool
):
_, outputs = inputs_outputs _, outputs = inputs_outputs
scores, indices = outputs scores, indices = outputs
@ -115,9 +151,7 @@ class CorefClusterer(torch.nn.Module):
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout) self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
pair_emb = dim * 3 + self.pw.shape pair_emb = dim * 3 + self.pw.shape
self.a_scorer = AnaphoricityScorer( self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
pair_emb, hidden_size, n_layers, dropout
)
self.lstm = torch.nn.LSTM( self.lstm = torch.nn.LSTM(
input_size=dim, input_size=dim,
hidden_size=dim, hidden_size=dim,