Rename coref params

Paul O'Leary McCann 2022-05-16 16:50:10 +09:00
parent 13481fbcc2
commit 2e8f0e9168
3 changed files with 58 additions and 49 deletions
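
For reference, the renames in this commit are: embedding_size → distance_embedding_size, n_hidden_layers → depth, rough_k → antecedent_limit, and a_scoring_batch_size → antecedent_batch_size (plus dropout_rate → dropout inside the PyTorch modules; the sp_embedding_size setting is also removed from the test config). A minimal sketch of migrating an existing [model] config section follows; migrate_model_config is a hypothetical helper, and only the name mapping itself comes from this commit:

# Hypothetical migration helper: the RENAMES mapping is taken from this
# commit; the function itself is illustrative only.
RENAMES = {
    "embedding_size": "distance_embedding_size",
    "n_hidden_layers": "depth",
    "rough_k": "antecedent_limit",
    "a_scoring_batch_size": "antecedent_batch_size",
}

def migrate_model_config(cfg: dict) -> dict:
    """Return a copy of a [model] section with the old keys renamed."""
    return {RENAMES.get(key, key): value for key, value in cfg.items()}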

View File

@@ -14,14 +14,13 @@ from .coref_util import add_dummy
 @registry.architectures("spacy.Coref.v1")
 def build_wl_coref_model(
     tok2vec: Model[List[Doc], List[Floats2d]],
-    embedding_size: int = 20,
+    distance_embedding_size: int = 20,
     hidden_size: int = 1024,
-    n_hidden_layers: int = 1,  # TODO rename to "depth"?
+    depth: int = 1,
     dropout: float = 0.3,
     # pairs to keep per mention after rough scoring
-    rough_k: int = 50,
-    # TODO is this not a training loop setting?
-    a_scoring_batch_size: int = 512,
+    antecedent_limit: int = 50,
+    antecedent_batch_size: int = 512,
 ):
     # TODO add model return types
     # TODO fix this
@@ -35,12 +34,12 @@ def build_wl_coref_model(
     coref_scorer = PyTorchWrapper(
         CorefScorer(
             dim,
-            embedding_size,
+            distance_embedding_size,
             hidden_size,
-            n_hidden_layers,
+            depth,
             dropout,
-            rough_k,
-            a_scoring_batch_size,
+            antecedent_limit,
+            antecedent_batch_size,
         ),
         convert_inputs=convert_coref_scorer_inputs,
         convert_outputs=convert_coref_scorer_outputs,
@@ -99,7 +98,7 @@ class CorefScorer(torch.nn.Module):
         dist_emb_size: int,
         hidden_size: int,
         n_layers: int,
-        dropout_rate: float,
+        dropout: float,
         roughk: int,
         batch_size: int,
     ):
@@ -109,31 +108,31 @@ class CorefScorer(torch.nn.Module):
         dist_emb_size: Size of the distance embeddings.
         hidden_size: Size of the coreference candidate embeddings.
         n_layers: Numbers of layers in the AnaphoricityScorer.
-        dropout_rate: Dropout probability to apply across all modules.
+        dropout: Dropout probability to apply across all modules.
         roughk: Number of candidates the RoughScorer returns.
         batch_size: Internal batch-size for the more expensive scorer.
         """
-        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.dropout = torch.nn.Dropout(dropout)
         self.batch_size = batch_size
         # Modules
-        self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate)
+        self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
         pair_emb = dim * 3 + self.pw.shape
         self.a_scorer = AnaphoricityScorer(
             pair_emb,
             hidden_size,
             n_layers,
-            dropout_rate
+            dropout
         )
         self.lstm = torch.nn.LSTM(
             input_size=dim,
             hidden_size=dim,
             batch_first=True,
         )
-        self.rough_scorer = RoughScorer(dim, dropout_rate, roughk)
-        self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate)
+        self.rough_scorer = RoughScorer(dim, dropout, roughk)
+        self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
         pair_emb = dim * 3 + self.pw.shape
         self.a_scorer = AnaphoricityScorer(
-            pair_emb, hidden_size, n_layers, dropout_rate
+            pair_emb, hidden_size, n_layers, dropout
         )

     def forward(
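To make the argument flow explicit, here is a hedged sketch of constructing the torch module directly, mapping the renamed registry parameters onto the constructor's own names; the dim value (the tok2vec output width) is an assumption:

# Sketch only: dim normally comes from the tok2vec layer's output width.
scorer = CorefScorer(
    dim=768,           # assumed tok2vec width
    dist_emb_size=20,  # distance_embedding_size at the registry level
    hidden_size=1024,
    n_layers=1,        # depth at the registry level
    dropout=0.3,
    roughk=50,         # antecedent_limit at the registry level
    batch_size=512,    # antecedent_batch_size at the registry level
)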
@@ -190,18 +189,18 @@ class CorefScorer(torch.nn.Module):

 class AnaphoricityScorer(torch.nn.Module):
     """Calculates anaphoricity scores by passing the inputs into a FFNN"""
-    def __init__(self, in_features: int, hidden_size, n_hidden_layers, dropout_rate):
+    def __init__(self, in_features: int, hidden_size, depth, dropout):
         super().__init__()
         hidden_size = hidden_size
-        if not n_hidden_layers:
+        if not depth:
             hidden_size = in_features
         layers = []
-        for i in range(n_hidden_layers):
+        for i in range(depth):
             layers.extend(
                 [
                     torch.nn.Linear(hidden_size if i else in_features, hidden_size),
                     torch.nn.LeakyReLU(),
-                    torch.nn.Dropout(dropout_rate),
+                    torch.nn.Dropout(dropout),
                 ]
             )
         self.hidden = torch.nn.Sequential(*layers)
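As a standalone sanity check of the FFNN built above with the renamed depth and dropout arguments, a minimal runnable sketch; all sizes here are assumptions:

import torch

in_features, hidden_size, depth, dropout = 64, 32, 1, 0.3
layers = []
for i in range(depth):
    layers.extend([
        # the first layer maps in_features -> hidden_size; later ones keep hidden_size
        torch.nn.Linear(hidden_size if i else in_features, hidden_size),
        torch.nn.LeakyReLU(),
        torch.nn.Dropout(dropout),
    ])
hidden = torch.nn.Sequential(*layers)
x = torch.randn(2, 50, in_features)  # (batch, antecedent_limit, in_features)
print(hidden(x).shape)               # torch.Size([2, 50, 32])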
@@ -243,7 +242,7 @@ class AnaphoricityScorer(torch.nn.Module):
     def _ffnn(self, x: torch.Tensor) -> torch.Tensor:
         """
         x: tensor of shape (batch_size x roughk x n_features
-        returns: tensor of shape (batch_size x rough_k)
+        returns: tensor of shape (batch_size x antecedent_limit)
         """
         x = self.out(self.hidden(x))
         return x.squeeze(2)
@@ -289,11 +288,11 @@ class RoughScorer(torch.nn.Module):
     steps to reduce computational cost.
     """
-    def __init__(self, features: int, dropout_rate: float, rough_k: float):
+    def __init__(self, features: int, dropout: float, antecedent_limit: int):
         super().__init__()
-        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.dropout = torch.nn.Dropout(dropout)
         self.bilinear = torch.nn.Linear(features, features)
-        self.k = rough_k
+        self.k = antecedent_limit

     def forward(
         self,  # type: ignore  # pylint: disable=arguments-differ  #35566 in pytorch
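The idea behind the rough scorer, sketched with the renamed antecedent_limit parameter: a cheap bilinear score over all mention pairs, keeping only the top candidates per mention. This is a simplified illustration that omits the masking done in the real forward(), and all sizes are assumptions:

import torch

features, antecedent_limit = 64, 5
bilinear = torch.nn.Linear(features, features)
mentions = torch.randn(20, features)           # 20 candidate mentions
pair_scores = mentions @ bilinear(mentions).T  # (20, 20) cheap pairwise scores
top_scores, top_indices = pair_scores.topk(k=antecedent_limit, dim=1)
print(top_indices.shape)                       # torch.Size([20, 5])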
@@ -317,7 +316,7 @@ class RoughScorer(torch.nn.Module):

 class DistancePairwiseEncoder(torch.nn.Module):
-    def __init__(self, embedding_size, dropout_rate):
+    def __init__(self, distance_embedding_size, dropout):
         """
         Takes the top_indices indicating, which is a ranked
         list for each word and its most likely corresponding

@@ -325,15 +324,15 @@ class DistancePairwiseEncoder(torch.nn.Module):
         up a distance embedding from a table, where the distance
         corresponds to the log-distance.

-        embedding_size: int,
+        distance_embedding_size: int,
             Dimensionality of the distance-embeddings table.
-        dropout_rate: float,
+        dropout: float,
             Dropout probability.
         """
         super().__init__()
-        emb_size = embedding_size
+        emb_size = distance_embedding_size
         self.distance_emb = torch.nn.Embedding(9, emb_size)
-        self.dropout = torch.nn.Dropout(dropout_rate)
+        self.dropout = torch.nn.Dropout(dropout)
         self.shape = emb_size

     def forward(
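A sketch of the log-distance lookup this module performs: pairwise word distances are bucketed roughly logarithmically into a small table (9 entries here). The exact bucketing function below is an assumption, since the forward() body is not part of this hunk:

import torch

distance_embedding_size = 20
emb = torch.nn.Embedding(9, distance_embedding_size)
distances = torch.tensor([1, 2, 5, 17, 120, 4000])
buckets = torch.log2(distances.float()).clamp(0, 8).long()  # assumed bucketing
print(emb(buckets).shape)  # torch.Size([6, 20])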

View File

@@ -31,13 +31,12 @@ from ..coref_scorer import Evaluator, get_cluster_info, lea
 default_config = """
 [model]
 @architectures = "spacy.Coref.v1"
-embedding_size = 20
+distance_embedding_size = 20
 hidden_size = 1024
-n_hidden_layers = 1
+depth = 1
 dropout = 0.3
-rough_k = 50
-a_scoring_batch_size = 512
-sp_embedding_size = 64
+antecedent_limit = 50
+antecedent_batch_size = 512

 [model.tok2vec]
 @architectures = "spacy.Tok2Vec.v2"
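A hedged sketch of resolving this updated default_config through the config system, assuming spaCy is installed and the module registering "spacy.Coref.v1" above has been imported first:

from thinc.api import Config
from spacy.util import registry

# Assumes the architecture registrations ("spacy.Coref.v1",
# "spacy.Tok2Vec.v2") are already in scope.
cfg = Config().from_str(default_config)
resolved = registry.resolve(cfg)
model = resolved["model"]  # Thinc Model wrapping the PyTorch CorefScorer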

View File

@@ -939,12 +939,12 @@ performance if working with only token-level clusters is acceptable.
 >
 > [model]
 > @architectures = "spacy.Coref.v1"
-> embedding_size = 20
+> distance_embedding_size = 20
 > dropout = 0.3
 > hidden_size = 1024
-> n_hidden_layers = 2
-> rough_k = 50
-> a_scoring_batch_size = 512
+> depth = 2
+> antecedent_limit = 50
+> antecedent_batch_size = 512
 >
 > [model.tok2vec]
 > @architectures = "spacy-transformers.TransformerListener.v1"
@@ -955,16 +955,16 @@ performance if working with only token-level clusters is acceptable.
 The `Coref` model architecture is a Thinc `Model`.

 | Name                      | Description |
-| ---------------------- | ----------- |
+| ------------------------- | ----------- |
 | `tok2vec`                 | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ |
-| `embedding_size`          | ~~int~~ |
+| `distance_embedding_size` | A representation of the distance between candidates. ~~int~~ |
 | `dropout`                 | The dropout to use internally. Unlike some Thinc models, this has separate dropout for the internal PyTorch layers. ~~float~~ |
 | `hidden_size`             | Size of the main internal layers. ~~int~~ |
-| `n_hidden_layers`         | Depth of the internal network. ~~int~~ |
-| `rough_k`                 | How many candidate antecedents to keep after rough scoring. This has a significant effect on memory usage. Typical values would be 50 to 200, or higher for very long documents. ~~int~~ |
-| `a_scoring_batch_size`    | Internal batch size. ~~int~~ |
+| `depth`                   | Depth of the internal network. ~~int~~ |
+| `antecedent_limit`        | How many candidate antecedents to keep after rough scoring. This has a significant effect on memory usage. Typical values would be 50 to 200, or higher for very long documents. ~~int~~ |
+| `antecedent_batch_size`   | Internal batch size. ~~int~~ |
 | **CREATES**               | The model using the architecture. ~~Model[List[Doc], Floats2d]~~ |

 ### spacy.SpanPredictor.v1 {#SpanPredictor}
@@ -985,3 +985,14 @@ The `Coref` model architecture is a Thinc `Model`.
 > ```

 The `SpanPredictor` model architecture is a Thinc `Model`.
+
+| Name                      | Description |
+| ------------------------- | ----------- |
+| `tok2vec`                 | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ |
+| `distance_embedding_size` | A representation of the distance between two candidates. ~~int~~ |
+| `dropout`                 | The dropout to use internally. Unlike some Thinc models, this has separate dropout for the internal PyTorch layers. ~~float~~ |
+| `hidden_size`             | Size of the main internal layers. ~~int~~ |
+| `depth`                   | Depth of the internal network. ~~int~~ |
+| `antecedent_limit`        | How many candidate antecedents to keep after rough scoring. This has a significant effect on memory usage. Typical values would be 50 to 200, or higher for very long documents. ~~int~~ |
+| `antecedent_batch_size`   | Internal batch size. ~~int~~ |
+| **CREATES**               | The model using the architecture. ~~Model[List[Doc], TupleFloats2d]~~ |