diff --git a/website/docs/api/architectures.mdx b/website/docs/api/architectures.mdx
index f2cc61fb0..2839a8576 100644
--- a/website/docs/api/architectures.mdx
+++ b/website/docs/api/architectures.mdx
@@ -495,16 +495,24 @@ pre-trained model. The
[`init fill-curated-transformer`](/api/cli#init-fill-curated-transformer) CLI
command can be used to automatically fill in these values.
-### spacy-curated-transformers.AlbertTransformer.v1
+### spacy-curated-transformers.AlbertTransformer.v2
Construct an ALBERT transformer model.
+
+<Infobox variant="warning">
+
+`v2` of this model added the `dtype` argument to support other PyTorch data types besides `float32`.
+
+</Infobox>
+
| Name | Description |
| ------------------------------ | ---------------------------------------------------------------------------------------- |
| `vocab_size` | Vocabulary size. ~~int~~ |
| `with_spans` | Callback that constructs a span generator model. ~~Callable~~ |
| `piece_encoder` | The piece encoder to segment input tokens. ~~Model~~ |
| `attention_probs_dropout_prob` | Dropout probability of the self-attention layers. ~~float~~ |
+| `dtype` | Torch data type (e.g. `"float32"`). ~~str~~ |
| `embedding_width` | Width of the embedding representations. ~~int~~ |
| `hidden_act` | Activation used by the point-wise feed-forward layers. ~~str~~ |
| `hidden_dropout_prob` | Dropout probability of the point-wise feed-forward and embedding layers. ~~float~~ |
@@ -522,16 +530,23 @@ Construct an ALBERT transformer model.
| `grad_scaler_config` | Configuration passed to the PyTorch gradient scaler. ~~dict~~ |
| **CREATES** | The model using the architecture ~~Model~~ |
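+
+As an illustrative sketch only, the new argument is set in the model block of
+the transformer component (assumed here to be named `transformer`); the excerpt
+omits the other required arguments, and `"float16"` merely stands in for any
+supported PyTorch data type name:
+
+```ini
+[components.transformer.model]
+@architectures = "spacy-curated-transformers.AlbertTransformer.v2"
+dtype = "float16"
+```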
-### spacy-curated-transformers.BertTransformer.v1
+### spacy-curated-transformers.BertTransformer.v2
Construct a BERT transformer model.
+<Infobox variant="warning">
+
+`v2` of this model added the `dtype` argument to support other PyTorch data types besides `float32`.
+
+</Infobox>
+
| Name | Description |
| ------------------------------ | ---------------------------------------------------------------------------------------- |
| `vocab_size` | Vocabulary size. ~~int~~ |
| `with_spans` | Callback that constructs a span generator model. ~~Callable~~ |
| `piece_encoder` | The piece encoder to segment input tokens. ~~Model~~ |
| `attention_probs_dropout_prob` | Dropout probability of the self-attention layers. ~~float~~ |
+| `dtype` | Torch data type (e.g. `"float32"`). ~~str~~ |
| `hidden_act` | Activation used by the point-wise feed-forward layers. ~~str~~ |
| `hidden_dropout_prob` | Dropout probability of the point-wise feed-forward and embedding layers. ~~float~~ |
| `hidden_width` | Width of the final representations. ~~int~~ |
@@ -547,16 +562,23 @@ Construct a BERT transformer model.
| `grad_scaler_config` | Configuration passed to the PyTorch gradient scaler. ~~dict~~ |
| **CREATES** | The model using the architecture ~~Model~~ |
-### spacy-curated-transformers.CamembertTransformer.v1
+### spacy-curated-transformers.CamembertTransformer.v2
Construct a CamemBERT transformer model.
+<Infobox variant="warning">
+
+`v2` of this model added the `dtype` argument to support other PyTorch data types besides `float32`.
+
+</Infobox>
+
| Name | Description |
| ------------------------------ | ---------------------------------------------------------------------------------------- |
| `vocab_size` | Vocabulary size. ~~int~~ |
| `with_spans` | Callback that constructs a span generator model. ~~Callable~~ |
| `piece_encoder` | The piece encoder to segment input tokens. ~~Model~~ |
| `attention_probs_dropout_prob` | Dropout probability of the self-attention layers. ~~float~~ |
+| `dtype` | Torch data type (e.g. `"float32"`). ~~str~~ |
| `hidden_act` | Activation used by the point-wise feed-forward layers. ~~str~~ |
| `hidden_dropout_prob` | Dropout probability of the point-wise feed-forward and embedding layers. ~~float~~ |
| `hidden_width` | Width of the final representations. ~~int~~ |
@@ -572,16 +594,23 @@ Construct a CamemBERT transformer model.
| `grad_scaler_config` | Configuration passed to the PyTorch gradient scaler. ~~dict~~ |
| **CREATES** | The model using the architecture ~~Model~~ |
-### spacy-curated-transformers.RobertaTransformer.v1
+### spacy-curated-transformers.RobertaTransformer.v2
Construct a RoBERTa transformer model.
+<Infobox variant="warning">
+
+`v2` of this model added the `dtype` argument to support other PyTorch data types besides `float32`.
+
+</Infobox>
+
| Name | Description |
| ------------------------------ | ---------------------------------------------------------------------------------------- |
| `vocab_size` | Vocabulary size. ~~int~~ |
| `with_spans` | Callback that constructs a span generator model. ~~Callable~~ |
| `piece_encoder` | The piece encoder to segment input tokens. ~~Model~~ |
| `attention_probs_dropout_prob` | Dropout probability of the self-attention layers. ~~float~~ |
+| `dtype` | Torch data type (e.g. `"float32"`). ~~str~~ |
| `hidden_act` | Activation used by the point-wise feed-forward layers. ~~str~~ |
| `hidden_dropout_prob` | Dropout probability of the point-wise feed-forward and embedding layers. ~~float~~ |
| `hidden_width` | Width of the final representations. ~~int~~ |
@@ -597,16 +626,23 @@ Construct a RoBERTa transformer model.
| `grad_scaler_config` | Configuration passed to the PyTorch gradient scaler. ~~dict~~ |
| **CREATES** | The model using the architecture ~~Model~~ |
-### spacy-curated-transformers.XlmrTransformer.v1
+### spacy-curated-transformers.XlmrTransformer.v2
Construct an XLM-RoBERTa transformer model.
+<Infobox variant="warning">
+
+`v2` of this model added the `dtype` argument to support other PyTorch data types besides `float32`.
+
+</Infobox>
+
| Name | Description |
| ------------------------------ | ---------------------------------------------------------------------------------------- |
| `vocab_size` | Vocabulary size. ~~int~~ |
| `with_spans` | Callback that constructs a span generator model. ~~Callable~~ |
| `piece_encoder` | The piece encoder to segment input tokens. ~~Model~~ |
| `attention_probs_dropout_prob` | Dropout probability of the self-attention layers. ~~float~~ |
+| `dtype` | Torch data type (e.g. `"float32"`). ~~str~~ |
| `hidden_act` | Activation used by the point-wise feed-forward layers. ~~str~~ |
| `hidden_dropout_prob` | Dropout probability of the point-wise feed-forward and embedding layers. ~~float~~ |
| `hidden_width` | Width of the final representations. ~~int~~ |
diff --git a/website/docs/api/curatedtransformer.mdx b/website/docs/api/curatedtransformer.mdx
index 3e63ef7c2..43ad4e5c8 100644
--- a/website/docs/api/curatedtransformer.mdx
+++ b/website/docs/api/curatedtransformer.mdx
@@ -91,12 +91,12 @@ https://github.com/explosion/spacy-curated-transformers/blob/main/spacy_curated_
> # Construction via add_pipe with custom config
> config = {
> "model": {
-> "@architectures": "spacy-curated-transformers.XlmrTransformer.v1",
+> "@architectures": "spacy-curated-transformers.XlmrTransformer.v2",
> "vocab_size": 250002,
> "num_hidden_layers": 12,
> "hidden_width": 768,
> "piece_encoder": {
-> "@architectures": "spacy-curated-transformers.XlmrSentencepieceEncoder.v1"
+> "@architectures": "spacy-curated-transformers.XlmrSentencepieceEncoder.v2"
> }
> }
> }
@@ -503,14 +503,24 @@ from a corresponding HuggingFace model.
| `name` | Name of the HuggingFace model. ~~str~~ |
| `revision` | Name of the model revision/branch. ~~str~~ |
-### PyTorchCheckpointLoader.v1 {id="pytorch_checkpoint_loader",tag="registered_function"}
+### PyTorchCheckpointLoader.v2 {id="pytorch_checkpoint_loader",tag="registered_function"}
Construct a callback that initializes a supported transformer model with weights
-from a PyTorch checkpoint.
+from a PyTorch checkpoint. The given directory must contain PyTorch and/or
+Safetensors checkpoints. Sharded checkpoints are also supported.
-| Name | Description |
-| ------ | ---------------------------------------- |
-| `path` | Path to the PyTorch checkpoint. ~~Path~~ |
+
+<Infobox variant="warning">
+
+`PyTorchCheckpointLoader.v1` required the path to the checkpoint file itself
+rather than to the directory that holds it. `v1` is deprecated, but is still
+provided for compatibility with older configurations.
+
+</Infobox>
+
+| Name | Description |
+| ------ | -------------------------------------------------- |
+| `path` | Path to the PyTorch checkpoint directory. ~~Path~~ |
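+
+As an illustrative sketch only, the loader can be referenced from the
+`initialize` block of the config; the component name `transformer`, the
+`encoder_loader` key, the `@model_loaders` registry name and the path below are
+placeholders chosen for the example:
+
+```ini
+[initialize.components.transformer.encoder_loader]
+@model_loaders = "spacy-curated-transformers.PyTorchCheckpointLoader.v2"
+path = "checkpoints/my-model"
+```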
## Tokenizer Loaders
@@ -578,3 +588,35 @@ catastrophic forgetting during fine-tuning.
| Name | Description |
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `target_pipes` | A dictionary whose keys and values correspond to the names of Transformer components and the training step at which they should be unfrozen respectively. ~~Dict[str, int]~~ |
+
+## Learning Rate Schedules
+
+### transformer_discriminative.v1 {id="transformer_discriminative",tag="registered_function",version="4"}
+
+> #### Example config
+>
+> ```ini
+> [training.optimizer.learn_rate]
+> @schedules = "spacy-curated-transformers.transformer_discriminative.v1"
+>
+> [training.optimizer.learn_rate.default_schedule]
+> @schedules = "warmup_linear.v1"
+> warmup_steps = 250
+> total_steps = 20000
+> initial_rate = 1e-3
+>
+> [training.optimizer.learn_rate.transformer_schedule]
+> @schedules = "warmup_linear.v1"
+> warmup_steps = 1000
+> total_steps = 20000
+> initial_rate = 5e-5
+> ```
+
+Construct a discriminative learning rate schedule for transformers. This is a
+compound schedule that allows you to use different schedules for transformer
+parameters (`transformer_schedule`) and other parameters (`default_schedule`).
+
+| Name | Description |
+| ---------------------- | -------------------------------------------------------------------------- |
+| `default_schedule` | Learning rate schedule to use for non-transformer parameters. ~~Schedule~~ |
+| `transformer_schedule` | Learning rate schedule to use for transformer parameters. ~~Schedule~~ |