Fix coref size inference (#10916)
* Add explicit tok2vec_size parameter in clusterer
* Add tok2vec size to span predictor config
* Minor fixes
This commit is contained in:
parent aa2eb2789c
commit 196886bbca
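The core change: rather than inferring the hidden width from the tok2vec layer at build time, which fails when that layer is a transformer listener whose output size is only known after initialization, the width is now an explicit tok2vec_size setting in the config. A minimal sketch of the failure mode, using a bare thinc layer as a stand-in for a listener (the names and values here are illustrative, not spaCy's actual components):

from thinc.api import Linear

# a Linear layer with no output width set stands in for a transformer
# listener whose size cannot be introspected before initialization
tok2vec = Linear()
print(tok2vec.maybe_get_dim("nO"))  # None: nothing to infer yet

try:
    dim = tok2vec.get_dim("nO")  # what the old code attempted
except ValueError:
    # the old fallback silently hard-coded 768; the commit replaces
    # this guesswork with an explicit tok2vec_size config setting
    dim = 768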
@@ -19,19 +19,15 @@ def build_wl_coref_model(
     # pairs to keep per mention after rough scoring
     antecedent_limit: int = 50,
     antecedent_batch_size: int = 512,
+    tok2vec_size: int = 768,  # tok2vec size
 ):
     # TODO add model return types
-    # TODO fix this
-    try:
-        dim = tok2vec.get_dim("nO")
-    except ValueError:
-        # happens with transformer listener
-        dim = 768
-        # dim = tok2vec.maybe_get_dim("n0")
 
     with Model.define_operators({">>": chain}):
         coref_clusterer = PyTorchWrapper(
             CorefClusterer(
-                dim,
+                tok2vec_size,
                 distance_embedding_size,
                 hidden_size,
                 depth,
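With the width exposed as a parameter of the registered architecture, it arrives from the config instead of from introspection. A hedged sketch of how a [model] block maps onto the builder's keyword arguments (the resolve call is left commented out because a real config would also need the nested tok2vec block):

from thinc.api import Config

# every key in this block is handed to the function registered as
# "spacy.Coref.v1" as a keyword argument, so tok2vec_size = 768
# lands directly on build_wl_coref_model's new parameter
cfg = Config().from_str("""
[model]
@architectures = "spacy.Coref.v1"
tok2vec_size = 768
distance_embedding_size = 20
hidden_size = 1024
depth = 1
""")
# from spacy.util import registry
# registry.resolve(cfg)  # would call build_wl_coref_model(tok2vec_size=768, ...)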
@@ -56,7 +52,7 @@ def convert_coref_clusterer_inputs(model: Model, X: List[Floats2d], is_train: bool):
     def backprop(args: ArgsKwargs) -> List[Floats2d]:
         # convert to xp and wrap in list
         gradients = torch2xp(args.args[0])
-        assert isinstance(gradients, Floats2d)
+        # assert isinstance(gradients, Floats2d)
         return [gradients]
 
     return ArgsKwargs(args=(word_features,), kwargs={}), backprop
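For context, the function above is the convert_inputs hook of a thinc PyTorchWrapper. A self-contained sketch of that pattern, with a plain Linear module and toy shapes standing in for the clusterer (everything here is illustrative, not spaCy's actual code):

import numpy
import torch
from thinc.api import ArgsKwargs, PyTorchWrapper, torch2xp, xp2torch

def convert_inputs(model, X, is_train):
    # forward: turn the xp array into a torch tensor argument
    word_features = xp2torch(X, requires_grad=is_train)

    def backprop(args: ArgsKwargs):
        # backward: convert the torch gradient back to xp and wrap it
        # in a list, mirroring the backprop callback in the diff above
        return [torch2xp(args.args[0])]

    return ArgsKwargs(args=(word_features,), kwargs={}), backprop

wrapped = PyTorchWrapper(torch.nn.Linear(4, 2), convert_inputs=convert_inputs)
wrapped.initialize()
Y = wrapped.predict(numpy.zeros((3, 4), dtype="f"))
print(Y.shape)  # (3, 2)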
@@ -89,7 +85,7 @@ class CorefClusterer(torch.nn.Module):
 
     def __init__(
         self,
-        dim: int,  # tok2vec size
+        dim: int,
         dist_emb_size: int,
         hidden_size: int,
         n_layers: int,
@@ -109,19 +105,19 @@ class CorefClusterer(torch.nn.Module):
         """
         self.dropout = torch.nn.Dropout(dropout)
         self.batch_size = batch_size
         # Modules
         self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
 
         pair_emb = dim * 3 + self.pw.shape
-        self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
+        self.a_scorer = AnaphoricityScorer(
+            pair_emb, hidden_size, n_layers, dropout
+        )
         self.lstm = torch.nn.LSTM(
             input_size=dim,
             hidden_size=dim,
             batch_first=True,
         )
 
         self.rough_scorer = RoughScorer(dim, dropout, roughk)
-        self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
-        pair_emb = dim * 3 + self.pw.shape
-        self.a_scorer = AnaphoricityScorer(pair_emb, hidden_size, n_layers, dropout)
 
     def forward(self, word_features: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
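Two shape details in this constructor are worth spelling out: the LSTM is square, so token features keep their width, and the pair embedding concatenates what are presumably the two mention vectors and their elementwise product (hence dim * 3) plus the distance features. A small sketch with stand-in numbers:

import torch

dim = 768       # tok2vec width, as passed down from tok2vec_size
pw_shape = 64   # stand-in for self.pw.shape from DistancePairwiseEncoder
pair_emb = dim * 3 + pw_shape
print(pair_emb)  # 2368

# input_size == hidden_size == dim, so width is preserved end to end
lstm = torch.nn.LSTM(input_size=dim, hidden_size=dim, batch_first=True)
words = torch.zeros(1, 10, dim)  # one document, ten tokens
out, _ = lstm(words)
assert out.shape == (1, 10, dim)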
@@ -13,6 +13,7 @@ from .coref_util import get_sentence_ids
 @registry.architectures("spacy.SpanPredictor.v1")
 def build_span_predictor(
     tok2vec: Model[List[Doc], List[Floats2d]],
+    tok2vec_size: int = 768,
     hidden_size: int = 1024,
     distance_embedding_size: int = 64,
     conv_channels: int = 4,
@@ -21,17 +22,11 @@ def build_span_predictor(
     prefix: str = "coref_head_clusters",
 ):
     # TODO add model return types
-    # TODO fix this
-    try:
-        dim = tok2vec.get_dim("nO")
-    except ValueError:
-        # happens with transformer listener
-        dim = 768
 
     with Model.define_operators({">>": chain, "&": tuplify}):
         span_predictor = PyTorchWrapper(
             SpanPredictor(
-                dim,
+                tok2vec_size,
                 hidden_size,
                 distance_embedding_size,
                 conv_channels,
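The span predictor builder binds a second operator: alongside ">>" for chain, "&" is bound to tuplify, which runs two models on the same input and returns both outputs as a tuple. A hedged sketch with plain layers (tuplify ships with recent thinc releases; the shapes are arbitrary):

import numpy
from thinc.api import Linear, Model, chain, tuplify

with Model.define_operators({">>": chain, "&": tuplify}):
    # both layers receive the same (5, 4) input; outputs come back paired
    model = Linear(2, 4) & Linear(3, 4)

model.initialize()
y1, y2 = model.predict(numpy.zeros((5, 4), dtype="f"))
print(y1.shape, y2.shape)  # (5, 2) (5, 3)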
@@ -30,6 +30,7 @@ from ..coref_scorer import Evaluator, get_cluster_info, lea
 default_config = """
 [model]
 @architectures = "spacy.Coref.v1"
+tok2vec_size = 768
 distance_embedding_size = 20
 hidden_size = 1024
 depth = 1
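On the pipeline side, tok2vec_size now appears in the component's default model config, so a user config can state the encoder width instead of inheriting the 768 fallback. A hypothetical excerpt (the component name and block path are assumptions following spaCy's usual config layout; 96 matches the width of spaCy's default small tok2vec):

[components.coref.model]
@architectures = "spacy.Coref.v1"
tok2vec_size = 96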
@@ -24,6 +24,7 @@ from ..ml.models.coref_util import (
 default_span_predictor_config = """
 [model]
 @architectures = "spacy.SpanPredictor.v1"
+tok2vec_size = 768
 hidden_size = 1024
 distance_embedding_size = 64
 conv_channels = 4