Rename coref params

This commit is contained in:
Paul O'Leary McCann 2022-05-16 16:50:10 +09:00
parent 13481fbcc2
commit 2e8f0e9168
3 changed files with 58 additions and 49 deletions

View File

@ -14,14 +14,13 @@ from .coref_util import add_dummy
@registry.architectures("spacy.Coref.v1")
def build_wl_coref_model(
tok2vec: Model[List[Doc], List[Floats2d]],
embedding_size: int = 20,
distance_embedding_size: int = 20,
hidden_size: int = 1024,
n_hidden_layers: int = 1, # TODO rename to "depth"?
depth: int = 1,
dropout: float = 0.3,
# pairs to keep per mention after rough scoring
rough_k: int = 50,
# TODO is this not a training loop setting?
a_scoring_batch_size: int = 512,
antecedent_limit: int = 50,
antecedent_batch_size: int = 512,
):
# TODO add model return types
# TODO fix this
@ -35,12 +34,12 @@ def build_wl_coref_model(
coref_scorer = PyTorchWrapper(
CorefScorer(
dim,
embedding_size,
distance_embedding_size,
hidden_size,
n_hidden_layers,
depth,
dropout,
rough_k,
a_scoring_batch_size,
antecedent_limit,
antecedent_batch_size,
),
convert_inputs=convert_coref_scorer_inputs,
convert_outputs=convert_coref_scorer_outputs,
@ -99,7 +98,7 @@ class CorefScorer(torch.nn.Module):
dist_emb_size: int,
hidden_size: int,
n_layers: int,
dropout_rate: float,
dropout: float,
roughk: int,
batch_size: int,
):
@ -109,31 +108,31 @@ class CorefScorer(torch.nn.Module):
dist_emb_size: Size of the distance embeddings.
hidden_size: Size of the coreference candidate embeddings.
n_layers: Numbers of layers in the AnaphoricityScorer.
dropout_rate: Dropout probability to apply across all modules.
dropout: Dropout probability to apply across all modules.
roughk: Number of candidates the RoughScorer returns.
batch_size: Internal batch-size for the more expensive scorer.
"""
self.dropout = torch.nn.Dropout(dropout_rate)
self.dropout = torch.nn.Dropout(dropout)
self.batch_size = batch_size
# Modules
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate)
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
pair_emb = dim * 3 + self.pw.shape
self.a_scorer = AnaphoricityScorer(
pair_emb,
hidden_size,
n_layers,
dropout_rate
dropout
)
self.lstm = torch.nn.LSTM(
input_size=dim,
hidden_size=dim,
batch_first=True,
)
self.rough_scorer = RoughScorer(dim, dropout_rate, roughk)
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout_rate)
self.rough_scorer = RoughScorer(dim, dropout, roughk)
self.pw = DistancePairwiseEncoder(dist_emb_size, dropout)
pair_emb = dim * 3 + self.pw.shape
self.a_scorer = AnaphoricityScorer(
pair_emb, hidden_size, n_layers, dropout_rate
pair_emb, hidden_size, n_layers, dropout
)
def forward(
@ -190,18 +189,18 @@ class CorefScorer(torch.nn.Module):
class AnaphoricityScorer(torch.nn.Module):
"""Calculates anaphoricity scores by passing the inputs into a FFNN"""
def __init__(self, in_features: int, hidden_size, n_hidden_layers, dropout_rate):
def __init__(self, in_features: int, hidden_size, depth, dropout):
super().__init__()
hidden_size = hidden_size
if not n_hidden_layers:
if not depth:
hidden_size = in_features
layers = []
for i in range(n_hidden_layers):
for i in range(depth):
layers.extend(
[
torch.nn.Linear(hidden_size if i else in_features, hidden_size),
torch.nn.LeakyReLU(),
torch.nn.Dropout(dropout_rate),
torch.nn.Dropout(dropout),
]
)
self.hidden = torch.nn.Sequential(*layers)
@ -243,7 +242,7 @@ class AnaphoricityScorer(torch.nn.Module):
def _ffnn(self, x: torch.Tensor) -> torch.Tensor:
"""
x: tensor of shape (batch_size x roughk x n_features
returns: tensor of shape (batch_size x rough_k)
returns: tensor of shape (batch_size x antecedent_limit)
"""
x = self.out(self.hidden(x))
return x.squeeze(2)
@ -289,11 +288,11 @@ class RoughScorer(torch.nn.Module):
steps to reduce computational cost.
"""
def __init__(self, features: int, dropout_rate: float, rough_k: float):
def __init__(self, features: int, dropout: float, antecedent_limit: int):
super().__init__()
self.dropout = torch.nn.Dropout(dropout_rate)
self.dropout = torch.nn.Dropout(dropout)
self.bilinear = torch.nn.Linear(features, features)
self.k = rough_k
self.k = antecedent_limit
def forward(
self, # type: ignore # pylint: disable=arguments-differ #35566 in pytorch
@ -317,7 +316,7 @@ class RoughScorer(torch.nn.Module):
class DistancePairwiseEncoder(torch.nn.Module):
def __init__(self, embedding_size, dropout_rate):
def __init__(self, distance_embedding_size, dropout):
"""
Takes the top_indices indicating, which is a ranked
list for each word and its most likely corresponding
@ -325,15 +324,15 @@ class DistancePairwiseEncoder(torch.nn.Module):
up a distance embedding from a table, where the distance
corresponds to the log-distance.
embedding_size: int,
distance_embedding_size: int,
Dimensionality of the distance-embeddings table.
dropout_rate: float,
dropout: float,
Dropout probability.
"""
super().__init__()
emb_size = embedding_size
emb_size = distance_embedding_size
self.distance_emb = torch.nn.Embedding(9, emb_size)
self.dropout = torch.nn.Dropout(dropout_rate)
self.dropout = torch.nn.Dropout(dropout)
self.shape = emb_size
def forward(

View File

@ -31,13 +31,12 @@ from ..coref_scorer import Evaluator, get_cluster_info, lea
default_config = """
[model]
@architectures = "spacy.Coref.v1"
embedding_size = 20
distance_embedding_size = 20
hidden_size = 1024
n_hidden_layers = 1
depth = 1
dropout = 0.3
rough_k = 50
a_scoring_batch_size = 512
sp_embedding_size = 64
antecedent_limit = 50
antecedent_batch_size = 512
[model.tok2vec]
@architectures = "spacy.Tok2Vec.v2"

View File

@ -939,12 +939,12 @@ performance if working with only token-level clusters is acceptable.
>
> [model]
> @architectures = "spacy.Coref.v1"
> embedding_size = 20
> distance_embedding_size = 20
> dropout = 0.3
> hidden_size = 1024
> n_hidden_layers = 2
> rough_k = 50
> a_scoring_batch_size = 512
> depth = 2
> antecedent_limit = 50
> antecedent_batch_size = 512
>
> [model.tok2vec]
> @architectures = "spacy-transformers.TransformerListener.v1"
@ -956,14 +956,14 @@ performance if working with only token-level clusters is acceptable.
The `Coref` model architecture is a Thinc `Model`.
| Name | Description |
| ---------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| ------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `tok2vec` | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ |
| `embedding_size` | ~~int~~ |
| `distance_embedding_size` | A representation of the distance between candidates. ~~int~~ |
| `dropout` | The dropout to use internally. Unlike some Thinc models, this has separate dropout for the internal PyTorch layers. ~~float~~ |
| `hidden_size` | Size of the main internal layers. ~~int~~ |
| `n_hidden_layers` | Depth of the internal network. ~~int~~ |
| `rough_k` | How many candidate antecedents to keep after rough scoring. This has a significant effect on memory usage. Typical values would be 50 to 200, or higher for very long documents. ~~int~~ |
| `a_scoring_batch_size` | Internal batch size. ~~int~~ |
| `depth` | Depth of the internal network. ~~int~~ |
| `antecedent_limit` | How many candidate antecedents to keep after rough scoring. This has a significant effect on memory usage. Typical values would be 50 to 200, or higher for very long documents. ~~int~~ |
| `antecedent_batch_size` | Internal batch size. ~~int~~ |
| **CREATES** | The model using the architecture. ~~Model[List[Doc], Floats2d]~~ |
### spacy.SpanPredictor.v1 {#SpanPredictor}
@ -985,3 +985,14 @@ The `Coref` model architecture is a Thinc `Model`.
> ```
The `SpanPredictor` model architecture is a Thinc `Model`.
| Name | Description |
| ------------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `tok2vec` | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ |
| `distance_embedding_size` | A representation of the distance between two candidates. ~~int~~ |
| `dropout` | The dropout to use internally. Unlike some Thinc models, this has separate dropout for the internal PyTorch layers. ~~float~~ |
| `hidden_size` | Size of the main internal layers. ~~int~~ |
| `depth` | Depth of the internal network. ~~int~~ |
| `antecedent_limit` | How many candidate antecedents to keep after rough scoring. This has a significant effect on memory usage. Typical values would be 50 to 200, or higher for very long documents. ~~int~~ |
| `antecedent_batch_size` | Internal batch size. ~~int~~ |
| **CREATES** | The model using the architecture. ~~Model[List[Doc], TupleFloats2d]~~ |