diff --git a/spacy/cli/templates/quickstart_training.jinja b/spacy/cli/templates/quickstart_training.jinja index 0db4c8a59..2c7ce024b 100644 --- a/spacy/cli/templates/quickstart_training.jinja +++ b/spacy/cli/templates/quickstart_training.jinja @@ -93,6 +93,29 @@ grad_factor = 1.0 @layers = "reduce_mean.v1" {% endif -%} +{% if "textcat" in components %} +[components.textcat] +factory = "textcat" + +{% if optimize == "accuracy" %} +[components.textcat.model] +@architectures = "spacy.TextCatEnsemble.v1" +exclusive_classes = false +width = 64 +conv_depth = 2 +embed_size = 2000 +window_size = 1 +ngram_size = 1 +nO = null + +{% else -%} +[components.textcat.model] +@architectures = "spacy.TextCatBOW.v1" +exclusive_classes = false +ngram_size = 1 +{%- endif %} +{%- endif %} + {# NON-TRANSFORMER PIPELINE #} {% else -%} @@ -167,10 +190,33 @@ nO = null @architectures = "spacy.Tok2VecListener.v1" width = ${components.tok2vec.model.encode.width} {% endif %} + +{% if "textcat" in components %} +[components.textcat] +factory = "textcat" + +{% if optimize == "accuracy" %} +[components.textcat.model] +@architectures = "spacy.TextCatEnsemble.v1" +exclusive_classes = false +width = 64 +conv_depth = 2 +embed_size = 2000 +window_size = 1 +ngram_size = 1 +nO = null + +{% else -%} +[components.textcat.model] +@architectures = "spacy.TextCatBOW.v1" +exclusive_classes = false +ngram_size = 1 +{%- endif %} +{%- endif %} {% endif %} {% for pipe in components %} -{% if pipe not in ["tagger", "parser", "ner"] %} +{% if pipe not in ["tagger", "parser", "ner", "textcat"] %} {# Other components defined by the user: we just assume they're factories #} [components.{{ pipe }}] factory = "{{ pipe }}"