diff --git a/pyproject.toml b/pyproject.toml
index 7a0e34376..882b31162 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -5,7 +5,7 @@ requires = [
     "cymem>=2.0.2,<2.1.0",
     "preshed>=3.0.2,<3.1.0",
     "murmurhash>=0.28.0,<1.1.0",
-    "thinc>=8.0.3,<8.1.0",
+    "thinc>=8.0.4,<8.1.0",
     "blis>=0.4.0,<0.8.0",
     "pathy",
     "numpy>=1.15.0",
diff --git a/requirements.txt b/requirements.txt
index 46337389c..9837933ab 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,8 +1,8 @@
 # Our libraries
-spacy-legacy>=3.0.5,<3.1.0
+spacy-legacy>=3.0.6,<3.1.0
 cymem>=2.0.2,<2.1.0
 preshed>=3.0.2,<3.1.0
-thinc>=8.0.3,<8.1.0
+thinc>=8.0.4,<8.1.0
 blis>=0.4.0,<0.8.0
 ml_datasets>=0.2.0,<0.3.0
 murmurhash>=0.28.0,<1.1.0
diff --git a/setup.cfg b/setup.cfg
index 99bae6ac8..37c432205 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -37,14 +37,14 @@ setup_requires =
     cymem>=2.0.2,<2.1.0
     preshed>=3.0.2,<3.1.0
     murmurhash>=0.28.0,<1.1.0
-    thinc>=8.0.3,<8.1.0
+    thinc>=8.0.4,<8.1.0
 install_requires =
     # Our libraries
-    spacy-legacy>=3.0.5,<3.1.0
+    spacy-legacy>=3.0.6,<3.1.0
     murmurhash>=0.28.0,<1.1.0
     cymem>=2.0.2,<2.1.0
     preshed>=3.0.2,<3.1.0
-    thinc>=8.0.3,<8.1.0
+    thinc>=8.0.4,<8.1.0
     blis>=0.4.0,<0.8.0
     wasabi>=0.8.1,<1.1.0
     srsly>=2.4.1,<3.0.0
diff --git a/spacy/cli/templates/quickstart_training.jinja b/spacy/cli/templates/quickstart_training.jinja
index 0d422318b..3f8b3e1cc 100644
--- a/spacy/cli/templates/quickstart_training.jinja
+++ b/spacy/cli/templates/quickstart_training.jinja
@@ -151,14 +151,14 @@ grad_factor = 1.0
 @layers = "reduce_mean.v1"
 
 [components.textcat.model.linear_model]
-@architectures = "spacy.TextCatBOW.v1"
+@architectures = "spacy.TextCatBOW.v2"
 exclusive_classes = true
 ngram_size = 1
 no_output_layer = false
 
 {% else -%}
 [components.textcat.model]
-@architectures = "spacy.TextCatBOW.v1"
+@architectures = "spacy.TextCatBOW.v2"
 exclusive_classes = true
 ngram_size = 1
 no_output_layer = false
@@ -182,14 +182,14 @@ grad_factor = 1.0
 @layers = "reduce_mean.v1"
 
 [components.textcat_multilabel.model.linear_model]
-@architectures = "spacy.TextCatBOW.v1"
+@architectures = "spacy.TextCatBOW.v2"
 exclusive_classes = false
 ngram_size = 1
 no_output_layer = false
 
 {% else -%}
 [components.textcat_multilabel.model]
-@architectures = "spacy.TextCatBOW.v1"
+@architectures = "spacy.TextCatBOW.v2"
 exclusive_classes = false
 ngram_size = 1
 no_output_layer = false
@@ -316,14 +316,14 @@ nO = null
 width = ${components.tok2vec.model.encode.width}
 
 [components.textcat.model.linear_model]
-@architectures = "spacy.TextCatBOW.v1"
+@architectures = "spacy.TextCatBOW.v2"
 exclusive_classes = true
 ngram_size = 1
 no_output_layer = false
 
 {% else -%}
 [components.textcat.model]
-@architectures = "spacy.TextCatBOW.v1"
+@architectures = "spacy.TextCatBOW.v2"
 exclusive_classes = true
 ngram_size = 1
 no_output_layer = false
@@ -344,14 +344,14 @@ nO = null
 width = ${components.tok2vec.model.encode.width}
 
 [components.textcat_multilabel.model.linear_model]
-@architectures = "spacy.TextCatBOW.v1"
+@architectures = "spacy.TextCatBOW.v2"
 exclusive_classes = false
 ngram_size = 1
 no_output_layer = false
 
 {% else -%}
 [components.textcat_multilabel.model]
-@architectures = "spacy.TextCatBOW.v1"
+@architectures = "spacy.TextCatBOW.v2"
 exclusive_classes = false
 ngram_size = 1
 no_output_layer = false
diff --git a/spacy/ml/models/textcat.py b/spacy/ml/models/textcat.py
index a1855c5a0..e3f6e944a 100644
--- a/spacy/ml/models/textcat.py
+++ b/spacy/ml/models/textcat.py
@@ -1,11 +1,13 @@
+from functools import partial
 from typing import Optional, List
 
 from thinc.types import Floats2d
 from thinc.api import Model, reduce_mean, Linear, list2ragged, Logistic
 from thinc.api import chain, concatenate, clone, Dropout, ParametricAttention
 from thinc.api import SparseLinear, Softmax, softmax_activation, Maxout, reduce_sum
-from thinc.api import with_cpu, Relu, residual, LayerNorm
+from thinc.api import with_cpu, Relu, residual, LayerNorm, resizable
 from thinc.layers.chain import init as init_chain
+from thinc.layers.resizable import resize_model, resize_linear_weighted
 
 from ...attrs import ORTH
 from ...util import registry
@@ -15,7 +17,10 @@ from ...tokens import Doc
 from .tok2vec import get_tok2vec_width
 
 
-@registry.architectures("spacy.TextCatCNN.v1")
+NEG_VALUE = -5000
+
+
+@registry.architectures("spacy.TextCatCNN.v2")
 def build_simple_cnn_text_classifier(
     tok2vec: Model, exclusive_classes: bool, nO: Optional[int] = None
 ) -> Model[List[Doc], Floats2d]:
@@ -25,38 +30,75 @@ def build_simple_cnn_text_classifier(
     outputs sum to 1. If exclusive_classes=False, a logistic non-linearity
     is applied instead, so that outputs are in the range [0, 1].
     """
+    fill_defaults = {"b": 0, "W": 0}
     with Model.define_operators({">>": chain}):
         cnn = tok2vec >> list2ragged() >> reduce_mean()
+        nI = tok2vec.maybe_get_dim("nO")
         if exclusive_classes:
-            output_layer = Softmax(nO=nO, nI=tok2vec.maybe_get_dim("nO"))
-            model = cnn >> output_layer
-            model.set_ref("output_layer", output_layer)
+            output_layer = Softmax(nO=nO, nI=nI)
+            fill_defaults["b"] = NEG_VALUE
+            resizable_layer = resizable(
+                output_layer,
+                resize_layer=partial(
+                    resize_linear_weighted, fill_defaults=fill_defaults
+                ),
+            )
+            model = cnn >> resizable_layer
         else:
-            linear_layer = Linear(nO=nO, nI=tok2vec.maybe_get_dim("nO"))
-            model = cnn >> linear_layer >> Logistic()
-            model.set_ref("output_layer", linear_layer)
+            output_layer = Linear(nO=nO, nI=nI)
+            resizable_layer = resizable(
+                output_layer,
+                resize_layer=partial(
+                    resize_linear_weighted, fill_defaults=fill_defaults
+                ),
+            )
+            model = cnn >> resizable_layer >> Logistic()
+        model.set_ref("output_layer", output_layer)
+        model.attrs["resize_output"] = partial(
+            resize_and_set_ref,
+            resizable_layer=resizable_layer,
+        )
     model.set_ref("tok2vec", tok2vec)
     model.set_dim("nO", nO)
     model.attrs["multi_label"] = not exclusive_classes
     return model
 
 
-@registry.architectures("spacy.TextCatBOW.v1")
+def resize_and_set_ref(model, new_nO, resizable_layer):
+    resizable_layer = resize_model(resizable_layer, new_nO)
+    model.set_ref("output_layer", resizable_layer.layers[0])
+    model.set_dim("nO", new_nO, force=True)
+    return model
+
+
+@registry.architectures("spacy.TextCatBOW.v2")
 def build_bow_text_classifier(
     exclusive_classes: bool,
     ngram_size: int,
     no_output_layer: bool,
     nO: Optional[int] = None,
 ) -> Model[List[Doc], Floats2d]:
+    fill_defaults = {"b": 0, "W": 0}
     with Model.define_operators({">>": chain}):
-        sparse_linear = SparseLinear(nO)
-        model = extract_ngrams(ngram_size, attr=ORTH) >> sparse_linear
-        model = with_cpu(model, model.ops)
+        sparse_linear = SparseLinear(nO=nO)
+        output_layer = None
         if not no_output_layer:
+            fill_defaults["b"] = NEG_VALUE
             output_layer = softmax_activation() if exclusive_classes else Logistic()
+        resizable_layer = resizable(
+            sparse_linear,
+            resize_layer=partial(resize_linear_weighted, fill_defaults=fill_defaults),
+        )
+        model = extract_ngrams(ngram_size, attr=ORTH) >> resizable_layer
+        model = with_cpu(model, model.ops)
+        if output_layer:
             model = model >> with_cpu(output_layer, output_layer.ops)
+    model.set_dim("nO", nO)
     model.set_ref("output_layer", sparse_linear)
     model.attrs["multi_label"] = not exclusive_classes
+    model.attrs["resize_output"] = partial(
+        resize_and_set_ref, resizable_layer=resizable_layer
+    )
     return model
 
 
@@ -69,9 +111,7 @@ def build_text_classifier_v2(
     exclusive_classes = not linear_model.attrs["multi_label"]
     with Model.define_operators({">>": chain, "|": concatenate}):
         width = tok2vec.maybe_get_dim("nO")
-        attention_layer = ParametricAttention(
-            width
-        )  # TODO: benchmark performance difference of this layer
+        attention_layer = ParametricAttention(width)
         maxout_layer = Maxout(nO=width, nI=width)
         norm_layer = LayerNorm(nI=width)
         cnn_model = (
diff --git a/spacy/ml/tb_framework.py b/spacy/ml/tb_framework.py
index 4ab5830cd..e7e5561af 100644
--- a/spacy/ml/tb_framework.py
+++ b/spacy/ml/tb_framework.py
@@ -15,7 +15,7 @@ def TransitionModel(
     return Model(
         name="parser_model",
         forward=forward,
-        dims={"nI": tok2vec.get_dim("nI") if tok2vec.has_dim("nI") else None},
+        dims={"nI": tok2vec.maybe_get_dim("nI")},
         layers=[tok2vec, lower, upper],
         refs={"tok2vec": tok2vec, "lower": lower, "upper": upper},
         init=init,
diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py
index 1d652a483..0d3bbdf35 100644
--- a/spacy/pipeline/textcat.py
+++ b/spacy/pipeline/textcat.py
@@ -35,7 +35,7 @@ maxout_pieces = 3
 depth = 2
 
 [model.linear_model]
-@architectures = "spacy.TextCatBOW.v1"
+@architectures = "spacy.TextCatBOW.v2"
 exclusive_classes = true
 ngram_size = 1
 no_output_layer = false
@@ -44,7 +44,7 @@ DEFAULT_SINGLE_TEXTCAT_MODEL = Config().from_str(single_label_default_config)["m
 
 single_label_bow_config = """
 [model]
-@architectures = "spacy.TextCatBOW.v1"
+@architectures = "spacy.TextCatBOW.v2"
 exclusive_classes = true
 ngram_size = 1
 no_output_layer = false
@@ -52,7 +52,7 @@ no_output_layer = false
 
 single_label_cnn_config = """
 [model]
-@architectures = "spacy.TextCatCNN.v1"
+@architectures = "spacy.TextCatCNN.v2"
 exclusive_classes = true
 
 [model.tok2vec]
@@ -298,6 +298,8 @@ class TextCategorizer(TrainablePipe):
             return 0
         self._allow_extra_label()
         self.cfg["labels"].append(label)
+        if self.model and "resize_output" in self.model.attrs:
+            self.model = self.model.attrs["resize_output"](self.model, len(self.cfg["labels"]))
         self.vocab.strings.add(label)
         return 1
 
diff --git a/spacy/pipeline/textcat_multilabel.py b/spacy/pipeline/textcat_multilabel.py
index 7267735b4..ba36881af 100644
--- a/spacy/pipeline/textcat_multilabel.py
+++ b/spacy/pipeline/textcat_multilabel.py
@@ -35,7 +35,7 @@ maxout_pieces = 3
 depth = 2
 
 [model.linear_model]
-@architectures = "spacy.TextCatBOW.v1"
+@architectures = "spacy.TextCatBOW.v2"
 exclusive_classes = false
 ngram_size = 1
 no_output_layer = false
@@ -44,7 +44,7 @@ DEFAULT_MULTI_TEXTCAT_MODEL = Config().from_str(multi_label_default_config)["mod
 
 multi_label_bow_config = """
 [model]
-@architectures = "spacy.TextCatBOW.v1"
+@architectures = "spacy.TextCatBOW.v2"
 exclusive_classes = false
 ngram_size = 1
 no_output_layer = false
@@ -52,7 +52,7 @@ no_output_layer = false
 
 multi_label_cnn_config = """
 [model]
-@architectures = "spacy.TextCatCNN.v1"
+@architectures = "spacy.TextCatCNN.v2"
 exclusive_classes = false
 
 [model.tok2vec]
diff --git a/spacy/pipeline/trainable_pipe.pyx b/spacy/pipeline/trainable_pipe.pyx
index fe51f38e5..926e92e91 100644
--- a/spacy/pipeline/trainable_pipe.pyx
+++ b/spacy/pipeline/trainable_pipe.pyx
@@ -213,7 +213,12 @@ cdef class TrainablePipe(Pipe):
 
     def _allow_extra_label(self) -> None:
         """Raise an error if the component can not add any more labels."""
-        if self.model.has_dim("nO") and self.model.get_dim("nO") == len(self.labels):
+        nO = None
+        if self.model.has_dim("nO"):
+            nO = self.model.get_dim("nO")
+        elif self.model.has_ref("output_layer") and self.model.get_ref("output_layer").has_dim("nO"):
+            nO = self.model.get_ref("output_layer").get_dim("nO")
+        if nO is not None and nO == len(self.labels):
             if not self.is_resizable:
                 raise ValueError(Errors.E922.format(name=self.name, nO=self.model.get_dim("nO")))
 
diff --git a/spacy/tests/pipeline/test_pipe_factories.py b/spacy/tests/pipeline/test_pipe_factories.py
index b99e9a863..5a5ca140c 100644
--- a/spacy/tests/pipeline/test_pipe_factories.py
+++ b/spacy/tests/pipeline/test_pipe_factories.py
@@ -160,7 +160,7 @@ def test_pipe_class_component_model():
             "@architectures": "spacy.TextCatEnsemble.v2",
             "tok2vec": DEFAULT_TOK2VEC_MODEL,
             "linear_model": {
-                "@architectures": "spacy.TextCatBOW.v1",
+                "@architectures": "spacy.TextCatBOW.v2",
                 "exclusive_classes": False,
                 "ngram_size": 1,
                 "no_output_layer": False,
diff --git a/spacy/tests/pipeline/test_textcat.py b/spacy/tests/pipeline/test_textcat.py
index 43dfff147..6f1d22eba 100644
--- a/spacy/tests/pipeline/test_textcat.py
+++ b/spacy/tests/pipeline/test_textcat.py
@@ -131,19 +131,129 @@ def test_implicit_label(name, get_examples):
     nlp.initialize(get_examples=get_examples(nlp))
 
 
-@pytest.mark.parametrize("name", ["textcat", "textcat_multilabel"])
-def test_no_resize(name):
+#fmt: off
+@pytest.mark.parametrize(
+    "name,textcat_config",
+    [
+        # BOW
+        ("textcat", {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "no_output_layer": False, "ngram_size": 3}),
+        ("textcat", {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "no_output_layer": True, "ngram_size": 3}),
+        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
+        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
+        # ENSEMBLE
+        ("textcat", {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "no_output_layer": False, "ngram_size": 3}}),
+        ("textcat", {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "no_output_layer": True, "ngram_size": 3}}),
+        ("textcat_multilabel", {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}}),
+        ("textcat_multilabel", {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}}),
+        # CNN
+        ("textcat", {"@architectures": "spacy.TextCatCNN.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
+        ("textcat_multilabel", {"@architectures": "spacy.TextCatCNN.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
+    ],
+)
+#fmt: on
+def test_no_resize(name, textcat_config):
+    """The old textcat architectures weren't resizable"""
     nlp = Language()
-    textcat = nlp.add_pipe(name)
+    pipe_config = {"model": textcat_config}
+    textcat = nlp.add_pipe(name, config=pipe_config)
     textcat.add_label("POSITIVE")
     textcat.add_label("NEGATIVE")
     nlp.initialize()
-    assert textcat.model.get_dim("nO") >= 2
+    assert textcat.model.maybe_get_dim("nO") in [2, None]
     # this throws an error because the textcat can't be resized after initialization
     with pytest.raises(ValueError):
         textcat.add_label("NEUTRAL")
 
 
+#fmt: off
+@pytest.mark.parametrize(
+    "name,textcat_config",
+    [
+        # BOW
+        ("textcat", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "no_output_layer": False, "ngram_size": 3}),
+        ("textcat", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "no_output_layer": True, "ngram_size": 3}),
+        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
+        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
+        # CNN
+        ("textcat", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
+        ("textcat_multilabel", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
+    ],
+)
+#fmt: on
+def test_resize(name, textcat_config):
+    """The new textcat architectures are resizable"""
+    nlp = Language()
+    pipe_config = {"model": textcat_config}
+    textcat = nlp.add_pipe(name, config=pipe_config)
+    textcat.add_label("POSITIVE")
+    textcat.add_label("NEGATIVE")
+    assert textcat.model.maybe_get_dim("nO") in [2, None]
+    nlp.initialize()
+    assert textcat.model.maybe_get_dim("nO") in [2, None]
+    textcat.add_label("NEUTRAL")
+    assert textcat.model.maybe_get_dim("nO") in [3, None]
+
+
+#fmt: off
+@pytest.mark.parametrize(
+    "name,textcat_config",
+    [
+        # BOW
+        ("textcat", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "no_output_layer": False, "ngram_size": 3}),
+        ("textcat", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "no_output_layer": True, "ngram_size": 3}),
+        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
+        ("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
+        # CNN
+        ("textcat", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
+        ("textcat_multilabel", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
+    ],
+)
+#fmt: on
+def test_resize_same_results(name, textcat_config):
+    # Ensure that the resized textcat classifiers still produce the same results for old labels
+    fix_random_seed(0)
+    nlp = English()
+    pipe_config = {"model": textcat_config}
+    textcat = nlp.add_pipe(name, config=pipe_config)
+
+    train_examples = []
+    for text, annotations in TRAIN_DATA_SINGLE_LABEL:
+        train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
+    optimizer = nlp.initialize(get_examples=lambda: train_examples)
+    assert textcat.model.maybe_get_dim("nO") in [2, None]
+
+    for i in range(5):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
+
+    # test the trained model before resizing
+    test_text = "I am happy."
+    doc = nlp(test_text)
+    assert len(doc.cats) == 2
+    pos_pred = doc.cats["POSITIVE"]
+    neg_pred = doc.cats["NEGATIVE"]
+
+    # test the trained model again after resizing
+    textcat.add_label("NEUTRAL")
+    doc = nlp(test_text)
+    assert len(doc.cats) == 3
+    assert doc.cats["POSITIVE"] == pos_pred
+    assert doc.cats["NEGATIVE"] == neg_pred
+    assert doc.cats["NEUTRAL"] <= 1
+
+    for i in range(5):
+        losses = {}
+        nlp.update(train_examples, sgd=optimizer, losses=losses)
+
+    # test the trained model again after training further with new label
+    doc = nlp(test_text)
+    assert len(doc.cats) == 3
+    assert doc.cats["POSITIVE"] != pos_pred
+    assert doc.cats["NEGATIVE"] != neg_pred
+    for cat in doc.cats:
+        assert doc.cats[cat] <= 1
+
+
 def test_error_with_multi_labels():
     nlp = Language()
     nlp.add_pipe("textcat")
@@ -286,14 +396,14 @@ def test_overfitting_IO_multi():
 @pytest.mark.parametrize(
     "name,train_data,textcat_config",
     [
-        ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False}),
-        ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 4, "no_output_layer": False}),
-        ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 3, "no_output_layer": True}),
-        ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 2, "no_output_layer": True}),
-        ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False}}),
-        ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 5, "no_output_layer": False}}),
-        ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatCNN.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
-        ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatCNN.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
+        ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False}),
+        ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "ngram_size": 4, "no_output_layer": False}),
+        ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "ngram_size": 3, "no_output_layer": True}),
+        ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "ngram_size": 2, "no_output_layer": True}),
+        ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False}}),
+        ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "ngram_size": 5, "no_output_layer": False}}),
+        ("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
+        ("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
     ],
 )
 # fmt: on
diff --git a/spacy/tests/test_misc.py b/spacy/tests/test_misc.py
index b38a50f71..b851641d9 100644
--- a/spacy/tests/test_misc.py
+++ b/spacy/tests/test_misc.py
@@ -297,7 +297,7 @@ def test_util_dot_section():
     factory = "textcat"
 
     [components.textcat.model]
-    @architectures = "spacy.TextCatBOW.v1"
+    @architectures = "spacy.TextCatBOW.v2"
     exclusive_classes = true
     ngram_size = 1
     no_output_layer = false
diff --git a/website/docs/api/architectures.md b/website/docs/api/architectures.md
index e09352ec9..4923ce18f 100644
--- a/website/docs/api/architectures.md
+++ b/website/docs/api/architectures.md
@@ -611,7 +611,7 @@ single-label use-cases where `exclusive_classes = true`, while the
 > nO = null
 >
 > [model.linear_model]
-> @architectures = "spacy.TextCatBOW.v1"
+> @architectures = "spacy.TextCatBOW.v2"
 > exclusive_classes = true
 > ngram_size = 1
 > no_output_layer = false
@@ -666,13 +666,13 @@ taking it as argument:
 
 
 
-### spacy.TextCatCNN.v1 {#TextCatCNN}
+### spacy.TextCatCNN.v2 {#TextCatCNN}
 
 > #### Example Config
 >
 > ```ini
 > [model]
-> @architectures = "spacy.TextCatCNN.v1"
+> @architectures = "spacy.TextCatCNN.v2"
 > exclusive_classes = false
 > nO = null
 >
@@ -698,13 +698,20 @@ architecture is usually less accurate than the ensemble, but runs faster.
 | `nO`                | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `initialize` is called. ~~Optional[int]~~ |
 | **CREATES**         | The model using the architecture. ~~Model[List[Doc], Floats2d]~~                                                                                                                               |
 
-### spacy.TextCatBOW.v1 {#TextCatBOW}
+
+
+[TextCatCNN.v1](/api/legacy#TextCatCNN_v1) had the exact same signature, but was not yet resizable. 
+Since v2, new labels can be added to this component, even after training.
+
+
+
+### spacy.TextCatBOW.v2 {#TextCatBOW}
 
 > #### Example Config
 >
 > ```ini
 > [model]
-> @architectures = "spacy.TextCatBOW.v1"
+> @architectures = "spacy.TextCatBOW.v2"
 > exclusive_classes = false
 > ngram_size = 1
 > no_output_layer = false
@@ -722,6 +729,13 @@ the others, but may not be as accurate, especially if texts are short.
 | `nO`                | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `initialize` is called. ~~Optional[int]~~ |
 | **CREATES**         | The model using the architecture. ~~Model[List[Doc], Floats2d]~~                                                                                                                               |
 
+
+
+[TextCatBOW.v1](/api/legacy#TextCatBOW_v1) had the exact same signature, but was not yet resizable. 
+Since v2, new labels can be added to this component, even after training.
+
+
+
 ## Entity linking architectures {#entitylinker source="spacy/ml/models/entity_linker.py"}
 
 An [`EntityLinker`](/api/entitylinker) component disambiguates textual mentions
diff --git a/website/docs/api/data-formats.md b/website/docs/api/data-formats.md
index 2b1c3480c..4ca5fb24d 100644
--- a/website/docs/api/data-formats.md
+++ b/website/docs/api/data-formats.md
@@ -93,7 +93,7 @@ Defines the `nlp` object, its tokenizer and
 > labels = ["POSITIVE", "NEGATIVE"]
 >
 > [components.textcat.model]
-> @architectures = "spacy.TextCatBOW.v1"
+> @architectures = "spacy.TextCatBOW.v2"
 > exclusive_classes = true
 > ngram_size = 1
 > no_output_layer = false
diff --git a/website/docs/api/legacy.md b/website/docs/api/legacy.md
index 96bc199bf..563d5aea8 100644
--- a/website/docs/api/legacy.md
+++ b/website/docs/api/legacy.md
@@ -176,6 +176,68 @@ added to an existing vectors table. See more details in
 
 
 
+### spacy.TextCatCNN.v1 {#TextCatCNN_v1}
+
+Since `spacy.TextCatCNN.v2`, this architecture has become resizable, which means that you can add 
+labels to a previously trained textcat. `TextCatCNN` v1 did not yet support that.
+
+> #### Example Config
+>
+> ```ini
+> [model]
+> @architectures = "spacy.TextCatCNN.v1"
+> exclusive_classes = false
+> nO = null
+>
+> [model.tok2vec]
+> @architectures = "spacy.HashEmbedCNN.v1"
+> pretrained_vectors = null
+> width = 96
+> depth = 4
+> embed_size = 2000
+> window_size = 1
+> maxout_pieces = 3
+> subword_features = true
+> ```
+
+A neural network model where token vectors are calculated using a CNN. The
+vectors are mean pooled and used as features in a feed-forward network. This
+architecture is usually less accurate than the ensemble, but runs faster.
+
+| Name                | Description                                                                                                                                                                                    |
+| ------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `exclusive_classes` | Whether or not categories are mutually exclusive. ~~bool~~                                                                                                                                     |
+| `tok2vec`           | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~                                                                                                                                        |
+| `nO`                | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `initialize` is called. ~~Optional[int]~~ |
+| **CREATES**         | The model using the architecture. ~~Model[List[Doc], Floats2d]~~                                                                                                                               |
+
+### spacy.TextCatBOW.v1 {#TextCatBOW_v1}
+
+Since `spacy.TextCatBOW.v2`, this architecture has become resizable, which means that you can add 
+labels to a previously trained textcat. `TextCatBOW` v1 did not yet support that.
+
+> #### Example Config
+>
+> ```ini
+> [model]
+> @architectures = "spacy.TextCatBOW.v1"
+> exclusive_classes = false
+> ngram_size = 1
+> no_output_layer = false
+> nO = null
+> ```
+
+An n-gram "bag-of-words" model. This architecture should run much faster than
+the others, but may not be as accurate, especially if texts are short.
+
+| Name                | Description                                                                                                                                                                                    |
+| ------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| `exclusive_classes` | Whether or not categories are mutually exclusive. ~~bool~~                                                                                                                                     |
+| `ngram_size`        | Determines the maximum length of the n-grams in the BOW model. For instance, `ngram_size=3` would give unigram, trigram and bigram features. ~~int~~                                           |
+| `no_output_layer`   | Whether or not to add an output layer to the model (`Softmax` activation if `exclusive_classes` is `True`, else `Logistic`). ~~bool~~                                                          |
+| `nO`                | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `initialize` is called. ~~Optional[int]~~ |
+| **CREATES**         | The model using the architecture. ~~Model[List[Doc], Floats2d]~~                                                                                                                               |
+
 ## Loggers {#loggers}
 
 These functions are available from `@spacy.registry.loggers`.
diff --git a/website/docs/usage/layers-architectures.md b/website/docs/usage/layers-architectures.md
index 8fe2cf489..17043d599 100644
--- a/website/docs/usage/layers-architectures.md
+++ b/website/docs/usage/layers-architectures.md
@@ -151,7 +151,7 @@ maxout_pieces = 3
 depth = 2
 
 [components.textcat.model.linear_model]
-@architectures = "spacy.TextCatBOW.v1"
+@architectures = "spacy.TextCatBOW.v2"
 exclusive_classes = true
 ngram_size = 1
 no_output_layer = false
@@ -169,7 +169,7 @@ factory = "textcat"
 labels = []
 
 [components.textcat.model]
-@architectures = "spacy.TextCatBOW.v1"
+@architectures = "spacy.TextCatBOW.v2"
 exclusive_classes = true
 ngram_size = 1
 no_output_layer = false
diff --git a/website/docs/usage/processing-pipelines.md b/website/docs/usage/processing-pipelines.md
index bde3ab84f..87feee54a 100644
--- a/website/docs/usage/processing-pipelines.md
+++ b/website/docs/usage/processing-pipelines.md
@@ -1324,7 +1324,7 @@ labels = []
 # This function is created and then passed to the "textcat" component as
 # the argument "model"
 [components.textcat.model]
-@architectures = "spacy.TextCatBOW.v1"
+@architectures = "spacy.TextCatBOW.v2"
 exclusive_classes = true
 ngram_size = 1
 no_output_layer = false