Fix coref size inference (#10916)

* Add explicit tok2vec_size parameter in clusterer

* Add tok2vec size to span predictor config

* Minor fixes
This commit is contained in:
Paul O'Leary McCann 2022-06-08 20:03:41 +09:00 committed by GitHub
parent aa2eb2789c
commit 196886bbca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 14 additions and 21 deletions

View File

@@ -19,19 +19,15 @@ def build_wl_coref_model(
# pairs to keep per mention after rough scoring # pairs to keep per mention after rough scoring
antecedent_limit: int = 50, antecedent_limit: int = 50,
antecedent_batch_size: int = 512, antecedent_batch_size: int = 512,
tok2vec_size: int = 768, # tok2vec size
): ):
# TODO add model return types # TODO add model return types
# TODO fix this # dim = tok2vec.maybe_get_dim("n0")
try:
dim = tok2vec.get_dim("nO")
except ValueError:
# happens with transformer listener
dim = 768
with Model.define_operators({">>": chain}): with Model.define_operators({">>": chain}):
coref_clusterer = PyTorchWrapper( coref_clusterer = PyTorchWrapper(
CorefClusterer( CorefClusterer(
dim, tok2vec_size,
distance_embedding_size, distance_embedding_size,
hidden_size, hidden_size,
depth, depth,
@@ -56,7 +52,7 @@ def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bo
def backprop(args: ArgsKwargs) -> List[Floats2d]: def backprop(args: ArgsKwargs) -> List[Floats2d]:
# convert to xp and wrap in list # convert to xp and wrap in list
gradients = torch2xp(args.args[0]) gradients = torch2xp(args.args[0])
assert isinstance(gradients, Floats2d) # assert isinstance(gradients, Floats2d)
return [gradients] return [gradients]
return ArgsKwargs(args=(word_features,), kwargs={}), backprop return ArgsKwargs(args=(word_features,), kwargs={}), backprop
@@ -89,7 +85,7 @@ class CorefClusterer(torch.nn.Module):
def __init__( def __init__(
self, self,
dim: int, # tok2vec size dim: int,
dist_emb_size: int, dist_emb_size: int,
hidden_size: int, hidden_size: int,
n_layers: int, n_layers: int,
@@ -109,19 +105,19 @@ class CorefClusterer(torch.nn.Module):
""" """
self.dropout = torch.nn.Dropout(dropout) self.dropout = torch.nn.Dropout(dropout)
self.batch_size = batch_size self.batch_size = batch_size
# Modules
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout) self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
pair_emb = dim * 3 + self.pw.shape pair_emb = dim * 3 + self.pw.shape
self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout) self.a_scorer = AnaphoricityScorer(
pair_emb, hidden_size, n_layers, dropout
)
self.lstm = torch.nn.LSTM( self.lstm = torch.nn.LSTM(
input_size=dim, input_size=dim,
hidden_size=dim, hidden_size=dim,
batch_first=True, batch_first=True,
) )
self.rough_scorer = RoughScorer(dim, dropout, roughk) self.rough_scorer = RoughScorer(dim, dropout, roughk)
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
pair_emb = dim * 3 + self.pw.shape
self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
def forward(self, word_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: def forward(self, word_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
""" """

View File

@@ -13,6 +13,7 @@ from .coref_util import get_sentence_ids
@registry.architectures("spacy.SpanPredictor.v1") @registry.architectures("spacy.SpanPredictor.v1")
def build_span_predictor( def build_span_predictor(
tok2vec: Model[List[Doc], List[Floats2d]], tok2vec: Model[List[Doc], List[Floats2d]],
tok2vec_size: int = 768,
hidden_size: int = 1024, hidden_size: int = 1024,
distance_embedding_size: int = 64, distance_embedding_size: int = 64,
conv_channels: int = 4, conv_channels: int = 4,
@@ -21,17 +22,11 @@ def build_span_predictor(
prefix: str = "coref_head_clusters", prefix: str = "coref_head_clusters",
): ):
# TODO add model return types # TODO add model return types
# TODO fix this
try:
dim = tok2vec.get_dim("nO")
except ValueError:
# happens with transformer listener
dim = 768
with Model.define_operators({">>": chain, "&": tuplify}): with Model.define_operators({">>": chain, "&": tuplify}):
span_predictor = PyTorchWrapper( span_predictor = PyTorchWrapper(
SpanPredictor( SpanPredictor(
dim, tok2vec_size,
hidden_size, hidden_size,
distance_embedding_size, distance_embedding_size,
conv_channels, conv_channels,

View File

@@ -30,6 +30,7 @@ from ..coref_scorer import Evaluator, get_cluster_info, lea
default_config = """ default_config = """
[model] [model]
@architectures = "spacy.Coref.v1" @architectures = "spacy.Coref.v1"
tok2vec_size = 768
distance_embedding_size = 20 distance_embedding_size = 20
hidden_size = 1024 hidden_size = 1024
depth = 1 depth = 1

View File

@@ -24,6 +24,7 @@ from ..ml.models.coref_util import (
default_span_predictor_config = """ default_span_predictor_config = """
[model] [model]
@architectures = "spacy.SpanPredictor.v1" @architectures = "spacy.SpanPredictor.v1"
tok2vec_size = 768
hidden_size = 1024 hidden_size = 1024
distance_embedding_size = 64 distance_embedding_size = 64
conv_channels = 4 conv_channels = 4