mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 09:26:27 +03:00
Resizable textcat (#7862)
* implement textcat resizing for TextCatCNN
* resizing textcat in-place
* simplify code
* ensure predictions for old textcat labels remain the same after resizing (WIP)
* fix for softmax
* store softmax as attr
* fix ensemble weight copy and cleanup
* restructure slightly
* adjust documentation, update tests and quickstart templates to use latest versions
* extend unit test slightly
* revert unnecessary edits
* fix typo
* ensemble architecture won't be resizable for now
* use resizable layer (WIP)
* revert using resizable layer
* resizable container while avoiding shape inference trouble
* cleanup
* ensure model continues training after resizing
* use fill_b parameter
* use fill_defaults
* resize_layer callback
* format
* bump thinc to 8.0.4
* bump spacy-legacy to 3.0.6
This commit is contained in:
parent
19521d525b
commit
e796aab4b3
|
@ -5,7 +5,7 @@ requires = [
|
||||||
"cymem>=2.0.2,<2.1.0",
|
"cymem>=2.0.2,<2.1.0",
|
||||||
"preshed>=3.0.2,<3.1.0",
|
"preshed>=3.0.2,<3.1.0",
|
||||||
"murmurhash>=0.28.0,<1.1.0",
|
"murmurhash>=0.28.0,<1.1.0",
|
||||||
"thinc>=8.0.3,<8.1.0",
|
"thinc>=8.0.4,<8.1.0",
|
||||||
"blis>=0.4.0,<0.8.0",
|
"blis>=0.4.0,<0.8.0",
|
||||||
"pathy",
|
"pathy",
|
||||||
"numpy>=1.15.0",
|
"numpy>=1.15.0",
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
# Our libraries
|
# Our libraries
|
||||||
spacy-legacy>=3.0.5,<3.1.0
|
spacy-legacy>=3.0.6,<3.1.0
|
||||||
cymem>=2.0.2,<2.1.0
|
cymem>=2.0.2,<2.1.0
|
||||||
preshed>=3.0.2,<3.1.0
|
preshed>=3.0.2,<3.1.0
|
||||||
thinc>=8.0.3,<8.1.0
|
thinc>=8.0.4,<8.1.0
|
||||||
blis>=0.4.0,<0.8.0
|
blis>=0.4.0,<0.8.0
|
||||||
ml_datasets>=0.2.0,<0.3.0
|
ml_datasets>=0.2.0,<0.3.0
|
||||||
murmurhash>=0.28.0,<1.1.0
|
murmurhash>=0.28.0,<1.1.0
|
||||||
|
|
|
@ -37,14 +37,14 @@ setup_requires =
|
||||||
cymem>=2.0.2,<2.1.0
|
cymem>=2.0.2,<2.1.0
|
||||||
preshed>=3.0.2,<3.1.0
|
preshed>=3.0.2,<3.1.0
|
||||||
murmurhash>=0.28.0,<1.1.0
|
murmurhash>=0.28.0,<1.1.0
|
||||||
thinc>=8.0.3,<8.1.0
|
thinc>=8.0.4,<8.1.0
|
||||||
install_requires =
|
install_requires =
|
||||||
# Our libraries
|
# Our libraries
|
||||||
spacy-legacy>=3.0.5,<3.1.0
|
spacy-legacy>=3.0.6,<3.1.0
|
||||||
murmurhash>=0.28.0,<1.1.0
|
murmurhash>=0.28.0,<1.1.0
|
||||||
cymem>=2.0.2,<2.1.0
|
cymem>=2.0.2,<2.1.0
|
||||||
preshed>=3.0.2,<3.1.0
|
preshed>=3.0.2,<3.1.0
|
||||||
thinc>=8.0.3,<8.1.0
|
thinc>=8.0.4,<8.1.0
|
||||||
blis>=0.4.0,<0.8.0
|
blis>=0.4.0,<0.8.0
|
||||||
wasabi>=0.8.1,<1.1.0
|
wasabi>=0.8.1,<1.1.0
|
||||||
srsly>=2.4.1,<3.0.0
|
srsly>=2.4.1,<3.0.0
|
||||||
|
|
|
@ -151,14 +151,14 @@ grad_factor = 1.0
|
||||||
@layers = "reduce_mean.v1"
|
@layers = "reduce_mean.v1"
|
||||||
|
|
||||||
[components.textcat.model.linear_model]
|
[components.textcat.model.linear_model]
|
||||||
@architectures = "spacy.TextCatBOW.v1"
|
@architectures = "spacy.TextCatBOW.v2"
|
||||||
exclusive_classes = true
|
exclusive_classes = true
|
||||||
ngram_size = 1
|
ngram_size = 1
|
||||||
no_output_layer = false
|
no_output_layer = false
|
||||||
|
|
||||||
{% else -%}
|
{% else -%}
|
||||||
[components.textcat.model]
|
[components.textcat.model]
|
||||||
@architectures = "spacy.TextCatBOW.v1"
|
@architectures = "spacy.TextCatBOW.v2"
|
||||||
exclusive_classes = true
|
exclusive_classes = true
|
||||||
ngram_size = 1
|
ngram_size = 1
|
||||||
no_output_layer = false
|
no_output_layer = false
|
||||||
|
@ -182,14 +182,14 @@ grad_factor = 1.0
|
||||||
@layers = "reduce_mean.v1"
|
@layers = "reduce_mean.v1"
|
||||||
|
|
||||||
[components.textcat_multilabel.model.linear_model]
|
[components.textcat_multilabel.model.linear_model]
|
||||||
@architectures = "spacy.TextCatBOW.v1"
|
@architectures = "spacy.TextCatBOW.v2"
|
||||||
exclusive_classes = false
|
exclusive_classes = false
|
||||||
ngram_size = 1
|
ngram_size = 1
|
||||||
no_output_layer = false
|
no_output_layer = false
|
||||||
|
|
||||||
{% else -%}
|
{% else -%}
|
||||||
[components.textcat_multilabel.model]
|
[components.textcat_multilabel.model]
|
||||||
@architectures = "spacy.TextCatBOW.v1"
|
@architectures = "spacy.TextCatBOW.v2"
|
||||||
exclusive_classes = false
|
exclusive_classes = false
|
||||||
ngram_size = 1
|
ngram_size = 1
|
||||||
no_output_layer = false
|
no_output_layer = false
|
||||||
|
@ -316,14 +316,14 @@ nO = null
|
||||||
width = ${components.tok2vec.model.encode.width}
|
width = ${components.tok2vec.model.encode.width}
|
||||||
|
|
||||||
[components.textcat.model.linear_model]
|
[components.textcat.model.linear_model]
|
||||||
@architectures = "spacy.TextCatBOW.v1"
|
@architectures = "spacy.TextCatBOW.v2"
|
||||||
exclusive_classes = true
|
exclusive_classes = true
|
||||||
ngram_size = 1
|
ngram_size = 1
|
||||||
no_output_layer = false
|
no_output_layer = false
|
||||||
|
|
||||||
{% else -%}
|
{% else -%}
|
||||||
[components.textcat.model]
|
[components.textcat.model]
|
||||||
@architectures = "spacy.TextCatBOW.v1"
|
@architectures = "spacy.TextCatBOW.v2"
|
||||||
exclusive_classes = true
|
exclusive_classes = true
|
||||||
ngram_size = 1
|
ngram_size = 1
|
||||||
no_output_layer = false
|
no_output_layer = false
|
||||||
|
@ -344,14 +344,14 @@ nO = null
|
||||||
width = ${components.tok2vec.model.encode.width}
|
width = ${components.tok2vec.model.encode.width}
|
||||||
|
|
||||||
[components.textcat_multilabel.model.linear_model]
|
[components.textcat_multilabel.model.linear_model]
|
||||||
@architectures = "spacy.TextCatBOW.v1"
|
@architectures = "spacy.TextCatBOW.v2"
|
||||||
exclusive_classes = false
|
exclusive_classes = false
|
||||||
ngram_size = 1
|
ngram_size = 1
|
||||||
no_output_layer = false
|
no_output_layer = false
|
||||||
|
|
||||||
{% else -%}
|
{% else -%}
|
||||||
[components.textcat_multilabel.model]
|
[components.textcat_multilabel.model]
|
||||||
@architectures = "spacy.TextCatBOW.v1"
|
@architectures = "spacy.TextCatBOW.v2"
|
||||||
exclusive_classes = false
|
exclusive_classes = false
|
||||||
ngram_size = 1
|
ngram_size = 1
|
||||||
no_output_layer = false
|
no_output_layer = false
|
||||||
|
|
|
@ -1,11 +1,13 @@
|
||||||
|
from functools import partial
|
||||||
from typing import Optional, List
|
from typing import Optional, List
|
||||||
|
|
||||||
from thinc.types import Floats2d
|
from thinc.types import Floats2d
|
||||||
from thinc.api import Model, reduce_mean, Linear, list2ragged, Logistic
|
from thinc.api import Model, reduce_mean, Linear, list2ragged, Logistic
|
||||||
from thinc.api import chain, concatenate, clone, Dropout, ParametricAttention
|
from thinc.api import chain, concatenate, clone, Dropout, ParametricAttention
|
||||||
from thinc.api import SparseLinear, Softmax, softmax_activation, Maxout, reduce_sum
|
from thinc.api import SparseLinear, Softmax, softmax_activation, Maxout, reduce_sum
|
||||||
from thinc.api import with_cpu, Relu, residual, LayerNorm
|
from thinc.api import with_cpu, Relu, residual, LayerNorm, resizable
|
||||||
from thinc.layers.chain import init as init_chain
|
from thinc.layers.chain import init as init_chain
|
||||||
|
from thinc.layers.resizable import resize_model, resize_linear_weighted
|
||||||
|
|
||||||
from ...attrs import ORTH
|
from ...attrs import ORTH
|
||||||
from ...util import registry
|
from ...util import registry
|
||||||
|
@ -15,7 +17,10 @@ from ...tokens import Doc
|
||||||
from .tok2vec import get_tok2vec_width
|
from .tok2vec import get_tok2vec_width
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures("spacy.TextCatCNN.v1")
|
NEG_VALUE = -5000
|
||||||
|
|
||||||
|
|
||||||
|
@registry.architectures("spacy.TextCatCNN.v2")
|
||||||
def build_simple_cnn_text_classifier(
|
def build_simple_cnn_text_classifier(
|
||||||
tok2vec: Model, exclusive_classes: bool, nO: Optional[int] = None
|
tok2vec: Model, exclusive_classes: bool, nO: Optional[int] = None
|
||||||
) -> Model[List[Doc], Floats2d]:
|
) -> Model[List[Doc], Floats2d]:
|
||||||
|
@ -25,38 +30,75 @@ def build_simple_cnn_text_classifier(
|
||||||
outputs sum to 1. If exclusive_classes=False, a logistic non-linearity
|
outputs sum to 1. If exclusive_classes=False, a logistic non-linearity
|
||||||
is applied instead, so that outputs are in the range [0, 1].
|
is applied instead, so that outputs are in the range [0, 1].
|
||||||
"""
|
"""
|
||||||
|
fill_defaults = {"b": 0, "W": 0}
|
||||||
with Model.define_operators({">>": chain}):
|
with Model.define_operators({">>": chain}):
|
||||||
cnn = tok2vec >> list2ragged() >> reduce_mean()
|
cnn = tok2vec >> list2ragged() >> reduce_mean()
|
||||||
|
nI = tok2vec.maybe_get_dim("nO")
|
||||||
if exclusive_classes:
|
if exclusive_classes:
|
||||||
output_layer = Softmax(nO=nO, nI=tok2vec.maybe_get_dim("nO"))
|
output_layer = Softmax(nO=nO, nI=nI)
|
||||||
model = cnn >> output_layer
|
fill_defaults["b"] = NEG_VALUE
|
||||||
model.set_ref("output_layer", output_layer)
|
resizable_layer = resizable(
|
||||||
|
output_layer,
|
||||||
|
resize_layer=partial(
|
||||||
|
resize_linear_weighted, fill_defaults=fill_defaults
|
||||||
|
),
|
||||||
|
)
|
||||||
|
model = cnn >> resizable_layer
|
||||||
else:
|
else:
|
||||||
linear_layer = Linear(nO=nO, nI=tok2vec.maybe_get_dim("nO"))
|
output_layer = Linear(nO=nO, nI=nI)
|
||||||
model = cnn >> linear_layer >> Logistic()
|
resizable_layer = resizable(
|
||||||
model.set_ref("output_layer", linear_layer)
|
output_layer,
|
||||||
|
resize_layer=partial(
|
||||||
|
resize_linear_weighted, fill_defaults=fill_defaults
|
||||||
|
),
|
||||||
|
)
|
||||||
|
model = cnn >> resizable_layer >> Logistic()
|
||||||
|
model.set_ref("output_layer", output_layer)
|
||||||
|
model.attrs["resize_output"] = partial(
|
||||||
|
resize_and_set_ref,
|
||||||
|
resizable_layer=resizable_layer,
|
||||||
|
)
|
||||||
model.set_ref("tok2vec", tok2vec)
|
model.set_ref("tok2vec", tok2vec)
|
||||||
model.set_dim("nO", nO)
|
model.set_dim("nO", nO)
|
||||||
model.attrs["multi_label"] = not exclusive_classes
|
model.attrs["multi_label"] = not exclusive_classes
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
@registry.architectures("spacy.TextCatBOW.v1")
|
def resize_and_set_ref(model, new_nO, resizable_layer):
|
||||||
|
resizable_layer = resize_model(resizable_layer, new_nO)
|
||||||
|
model.set_ref("output_layer", resizable_layer.layers[0])
|
||||||
|
model.set_dim("nO", new_nO, force=True)
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@registry.architectures("spacy.TextCatBOW.v2")
|
||||||
def build_bow_text_classifier(
|
def build_bow_text_classifier(
|
||||||
exclusive_classes: bool,
|
exclusive_classes: bool,
|
||||||
ngram_size: int,
|
ngram_size: int,
|
||||||
no_output_layer: bool,
|
no_output_layer: bool,
|
||||||
nO: Optional[int] = None,
|
nO: Optional[int] = None,
|
||||||
) -> Model[List[Doc], Floats2d]:
|
) -> Model[List[Doc], Floats2d]:
|
||||||
|
fill_defaults = {"b": 0, "W": 0}
|
||||||
with Model.define_operators({">>": chain}):
|
with Model.define_operators({">>": chain}):
|
||||||
sparse_linear = SparseLinear(nO)
|
sparse_linear = SparseLinear(nO=nO)
|
||||||
model = extract_ngrams(ngram_size, attr=ORTH) >> sparse_linear
|
output_layer = None
|
||||||
model = with_cpu(model, model.ops)
|
|
||||||
if not no_output_layer:
|
if not no_output_layer:
|
||||||
|
fill_defaults["b"] = NEG_VALUE
|
||||||
output_layer = softmax_activation() if exclusive_classes else Logistic()
|
output_layer = softmax_activation() if exclusive_classes else Logistic()
|
||||||
|
resizable_layer = resizable(
|
||||||
|
sparse_linear,
|
||||||
|
resize_layer=partial(resize_linear_weighted, fill_defaults=fill_defaults),
|
||||||
|
)
|
||||||
|
model = extract_ngrams(ngram_size, attr=ORTH) >> resizable_layer
|
||||||
|
model = with_cpu(model, model.ops)
|
||||||
|
if output_layer:
|
||||||
model = model >> with_cpu(output_layer, output_layer.ops)
|
model = model >> with_cpu(output_layer, output_layer.ops)
|
||||||
|
model.set_dim("nO", nO)
|
||||||
model.set_ref("output_layer", sparse_linear)
|
model.set_ref("output_layer", sparse_linear)
|
||||||
model.attrs["multi_label"] = not exclusive_classes
|
model.attrs["multi_label"] = not exclusive_classes
|
||||||
|
model.attrs["resize_output"] = partial(
|
||||||
|
resize_and_set_ref, resizable_layer=resizable_layer
|
||||||
|
)
|
||||||
return model
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
@ -69,9 +111,7 @@ def build_text_classifier_v2(
|
||||||
exclusive_classes = not linear_model.attrs["multi_label"]
|
exclusive_classes = not linear_model.attrs["multi_label"]
|
||||||
with Model.define_operators({">>": chain, "|": concatenate}):
|
with Model.define_operators({">>": chain, "|": concatenate}):
|
||||||
width = tok2vec.maybe_get_dim("nO")
|
width = tok2vec.maybe_get_dim("nO")
|
||||||
attention_layer = ParametricAttention(
|
attention_layer = ParametricAttention(width)
|
||||||
width
|
|
||||||
) # TODO: benchmark performance difference of this layer
|
|
||||||
maxout_layer = Maxout(nO=width, nI=width)
|
maxout_layer = Maxout(nO=width, nI=width)
|
||||||
norm_layer = LayerNorm(nI=width)
|
norm_layer = LayerNorm(nI=width)
|
||||||
cnn_model = (
|
cnn_model = (
|
||||||
|
|
|
@ -15,7 +15,7 @@ def TransitionModel(
|
||||||
return Model(
|
return Model(
|
||||||
name="parser_model",
|
name="parser_model",
|
||||||
forward=forward,
|
forward=forward,
|
||||||
dims={"nI": tok2vec.get_dim("nI") if tok2vec.has_dim("nI") else None},
|
dims={"nI": tok2vec.maybe_get_dim("nI")},
|
||||||
layers=[tok2vec, lower, upper],
|
layers=[tok2vec, lower, upper],
|
||||||
refs={"tok2vec": tok2vec, "lower": lower, "upper": upper},
|
refs={"tok2vec": tok2vec, "lower": lower, "upper": upper},
|
||||||
init=init,
|
init=init,
|
||||||
|
|
|
@ -35,7 +35,7 @@ maxout_pieces = 3
|
||||||
depth = 2
|
depth = 2
|
||||||
|
|
||||||
[model.linear_model]
|
[model.linear_model]
|
||||||
@architectures = "spacy.TextCatBOW.v1"
|
@architectures = "spacy.TextCatBOW.v2"
|
||||||
exclusive_classes = true
|
exclusive_classes = true
|
||||||
ngram_size = 1
|
ngram_size = 1
|
||||||
no_output_layer = false
|
no_output_layer = false
|
||||||
|
@ -44,7 +44,7 @@ DEFAULT_SINGLE_TEXTCAT_MODEL = Config().from_str(single_label_default_config)["m
|
||||||
|
|
||||||
single_label_bow_config = """
|
single_label_bow_config = """
|
||||||
[model]
|
[model]
|
||||||
@architectures = "spacy.TextCatBOW.v1"
|
@architectures = "spacy.TextCatBOW.v2"
|
||||||
exclusive_classes = true
|
exclusive_classes = true
|
||||||
ngram_size = 1
|
ngram_size = 1
|
||||||
no_output_layer = false
|
no_output_layer = false
|
||||||
|
@ -52,7 +52,7 @@ no_output_layer = false
|
||||||
|
|
||||||
single_label_cnn_config = """
|
single_label_cnn_config = """
|
||||||
[model]
|
[model]
|
||||||
@architectures = "spacy.TextCatCNN.v1"
|
@architectures = "spacy.TextCatCNN.v2"
|
||||||
exclusive_classes = true
|
exclusive_classes = true
|
||||||
|
|
||||||
[model.tok2vec]
|
[model.tok2vec]
|
||||||
|
@ -298,6 +298,8 @@ class TextCategorizer(TrainablePipe):
|
||||||
return 0
|
return 0
|
||||||
self._allow_extra_label()
|
self._allow_extra_label()
|
||||||
self.cfg["labels"].append(label)
|
self.cfg["labels"].append(label)
|
||||||
|
if self.model and "resize_output" in self.model.attrs:
|
||||||
|
self.model = self.model.attrs["resize_output"](self.model, len(self.cfg["labels"]))
|
||||||
self.vocab.strings.add(label)
|
self.vocab.strings.add(label)
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
|
|
|
@ -35,7 +35,7 @@ maxout_pieces = 3
|
||||||
depth = 2
|
depth = 2
|
||||||
|
|
||||||
[model.linear_model]
|
[model.linear_model]
|
||||||
@architectures = "spacy.TextCatBOW.v1"
|
@architectures = "spacy.TextCatBOW.v2"
|
||||||
exclusive_classes = false
|
exclusive_classes = false
|
||||||
ngram_size = 1
|
ngram_size = 1
|
||||||
no_output_layer = false
|
no_output_layer = false
|
||||||
|
@ -44,7 +44,7 @@ DEFAULT_MULTI_TEXTCAT_MODEL = Config().from_str(multi_label_default_config)["mod
|
||||||
|
|
||||||
multi_label_bow_config = """
|
multi_label_bow_config = """
|
||||||
[model]
|
[model]
|
||||||
@architectures = "spacy.TextCatBOW.v1"
|
@architectures = "spacy.TextCatBOW.v2"
|
||||||
exclusive_classes = false
|
exclusive_classes = false
|
||||||
ngram_size = 1
|
ngram_size = 1
|
||||||
no_output_layer = false
|
no_output_layer = false
|
||||||
|
@ -52,7 +52,7 @@ no_output_layer = false
|
||||||
|
|
||||||
multi_label_cnn_config = """
|
multi_label_cnn_config = """
|
||||||
[model]
|
[model]
|
||||||
@architectures = "spacy.TextCatCNN.v1"
|
@architectures = "spacy.TextCatCNN.v2"
|
||||||
exclusive_classes = false
|
exclusive_classes = false
|
||||||
|
|
||||||
[model.tok2vec]
|
[model.tok2vec]
|
||||||
|
|
|
@ -213,7 +213,12 @@ cdef class TrainablePipe(Pipe):
|
||||||
|
|
||||||
def _allow_extra_label(self) -> None:
|
def _allow_extra_label(self) -> None:
|
||||||
"""Raise an error if the component can not add any more labels."""
|
"""Raise an error if the component can not add any more labels."""
|
||||||
if self.model.has_dim("nO") and self.model.get_dim("nO") == len(self.labels):
|
nO = None
|
||||||
|
if self.model.has_dim("nO"):
|
||||||
|
nO = self.model.get_dim("nO")
|
||||||
|
elif self.model.has_ref("output_layer") and self.model.get_ref("output_layer").has_dim("nO"):
|
||||||
|
nO = self.model.get_ref("output_layer").get_dim("nO")
|
||||||
|
if nO is not None and nO == len(self.labels):
|
||||||
if not self.is_resizable:
|
if not self.is_resizable:
|
||||||
raise ValueError(Errors.E922.format(name=self.name, nO=self.model.get_dim("nO")))
|
raise ValueError(Errors.E922.format(name=self.name, nO=self.model.get_dim("nO")))
|
||||||
|
|
||||||
|
|
|
@ -160,7 +160,7 @@ def test_pipe_class_component_model():
|
||||||
"@architectures": "spacy.TextCatEnsemble.v2",
|
"@architectures": "spacy.TextCatEnsemble.v2",
|
||||||
"tok2vec": DEFAULT_TOK2VEC_MODEL,
|
"tok2vec": DEFAULT_TOK2VEC_MODEL,
|
||||||
"linear_model": {
|
"linear_model": {
|
||||||
"@architectures": "spacy.TextCatBOW.v1",
|
"@architectures": "spacy.TextCatBOW.v2",
|
||||||
"exclusive_classes": False,
|
"exclusive_classes": False,
|
||||||
"ngram_size": 1,
|
"ngram_size": 1,
|
||||||
"no_output_layer": False,
|
"no_output_layer": False,
|
||||||
|
|
|
@ -131,19 +131,129 @@ def test_implicit_label(name, get_examples):
|
||||||
nlp.initialize(get_examples=get_examples(nlp))
|
nlp.initialize(get_examples=get_examples(nlp))
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("name", ["textcat", "textcat_multilabel"])
|
#fmt: off
|
||||||
def test_no_resize(name):
|
@pytest.mark.parametrize(
|
||||||
|
"name,textcat_config",
|
||||||
|
[
|
||||||
|
# BOW
|
||||||
|
("textcat", {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "no_output_layer": False, "ngram_size": 3}),
|
||||||
|
("textcat", {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "no_output_layer": True, "ngram_size": 3}),
|
||||||
|
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
|
||||||
|
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
|
||||||
|
# ENSEMBLE
|
||||||
|
("textcat", {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "no_output_layer": False, "ngram_size": 3}}),
|
||||||
|
("textcat", {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "no_output_layer": True, "ngram_size": 3}}),
|
||||||
|
("textcat_multilabel", {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}}),
|
||||||
|
("textcat_multilabel", {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}}),
|
||||||
|
# CNN
|
||||||
|
("textcat", {"@architectures": "spacy.TextCatCNN.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
|
||||||
|
("textcat_multilabel", {"@architectures": "spacy.TextCatCNN.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
#fmt: on
|
||||||
|
def test_no_resize(name, textcat_config):
|
||||||
|
"""The old textcat architectures weren't resizable"""
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
textcat = nlp.add_pipe(name)
|
pipe_config = {"model": textcat_config}
|
||||||
|
textcat = nlp.add_pipe(name, config=pipe_config)
|
||||||
textcat.add_label("POSITIVE")
|
textcat.add_label("POSITIVE")
|
||||||
textcat.add_label("NEGATIVE")
|
textcat.add_label("NEGATIVE")
|
||||||
nlp.initialize()
|
nlp.initialize()
|
||||||
assert textcat.model.get_dim("nO") >= 2
|
assert textcat.model.maybe_get_dim("nO") in [2, None]
|
||||||
# this throws an error because the textcat can't be resized after initialization
|
# this throws an error because the textcat can't be resized after initialization
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
textcat.add_label("NEUTRAL")
|
textcat.add_label("NEUTRAL")
|
||||||
|
|
||||||
|
|
||||||
|
#fmt: off
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"name,textcat_config",
|
||||||
|
[
|
||||||
|
# BOW
|
||||||
|
("textcat", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "no_output_layer": False, "ngram_size": 3}),
|
||||||
|
("textcat", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "no_output_layer": True, "ngram_size": 3}),
|
||||||
|
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
|
||||||
|
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
|
||||||
|
# CNN
|
||||||
|
("textcat", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
|
||||||
|
("textcat_multilabel", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
#fmt: on
|
||||||
|
def test_resize(name, textcat_config):
|
||||||
|
"""The new textcat architectures are resizable"""
|
||||||
|
nlp = Language()
|
||||||
|
pipe_config = {"model": textcat_config}
|
||||||
|
textcat = nlp.add_pipe(name, config=pipe_config)
|
||||||
|
textcat.add_label("POSITIVE")
|
||||||
|
textcat.add_label("NEGATIVE")
|
||||||
|
assert textcat.model.maybe_get_dim("nO") in [2, None]
|
||||||
|
nlp.initialize()
|
||||||
|
assert textcat.model.maybe_get_dim("nO") in [2, None]
|
||||||
|
textcat.add_label("NEUTRAL")
|
||||||
|
assert textcat.model.maybe_get_dim("nO") in [3, None]
|
||||||
|
|
||||||
|
|
||||||
|
#fmt: off
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"name,textcat_config",
|
||||||
|
[
|
||||||
|
# BOW
|
||||||
|
("textcat", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "no_output_layer": False, "ngram_size": 3}),
|
||||||
|
("textcat", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "no_output_layer": True, "ngram_size": 3}),
|
||||||
|
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "no_output_layer": False, "ngram_size": 3}),
|
||||||
|
("textcat_multilabel", {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "no_output_layer": True, "ngram_size": 3}),
|
||||||
|
# CNN
|
||||||
|
("textcat", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
|
||||||
|
("textcat_multilabel", {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
#fmt: on
|
||||||
|
def test_resize_same_results(name, textcat_config):
|
||||||
|
# Ensure that the resized textcat classifiers still produce the same results for old labels
|
||||||
|
fix_random_seed(0)
|
||||||
|
nlp = English()
|
||||||
|
pipe_config = {"model": textcat_config}
|
||||||
|
textcat = nlp.add_pipe(name, config=pipe_config)
|
||||||
|
|
||||||
|
train_examples = []
|
||||||
|
for text, annotations in TRAIN_DATA_SINGLE_LABEL:
|
||||||
|
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
||||||
|
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
||||||
|
assert textcat.model.maybe_get_dim("nO") in [2, None]
|
||||||
|
|
||||||
|
for i in range(5):
|
||||||
|
losses = {}
|
||||||
|
nlp.update(train_examples, sgd=optimizer, losses=losses)
|
||||||
|
|
||||||
|
# test the trained model before resizing
|
||||||
|
test_text = "I am happy."
|
||||||
|
doc = nlp(test_text)
|
||||||
|
assert len(doc.cats) == 2
|
||||||
|
pos_pred = doc.cats["POSITIVE"]
|
||||||
|
neg_pred = doc.cats["NEGATIVE"]
|
||||||
|
|
||||||
|
# test the trained model again after resizing
|
||||||
|
textcat.add_label("NEUTRAL")
|
||||||
|
doc = nlp(test_text)
|
||||||
|
assert len(doc.cats) == 3
|
||||||
|
assert doc.cats["POSITIVE"] == pos_pred
|
||||||
|
assert doc.cats["NEGATIVE"] == neg_pred
|
||||||
|
assert doc.cats["NEUTRAL"] <= 1
|
||||||
|
|
||||||
|
for i in range(5):
|
||||||
|
losses = {}
|
||||||
|
nlp.update(train_examples, sgd=optimizer, losses=losses)
|
||||||
|
|
||||||
|
# test the trained model again after training further with new label
|
||||||
|
doc = nlp(test_text)
|
||||||
|
assert len(doc.cats) == 3
|
||||||
|
assert doc.cats["POSITIVE"] != pos_pred
|
||||||
|
assert doc.cats["NEGATIVE"] != neg_pred
|
||||||
|
for cat in doc.cats:
|
||||||
|
assert doc.cats[cat] <= 1
|
||||||
|
|
||||||
|
|
||||||
def test_error_with_multi_labels():
|
def test_error_with_multi_labels():
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
nlp.add_pipe("textcat")
|
nlp.add_pipe("textcat")
|
||||||
|
@ -286,14 +396,14 @@ def test_overfitting_IO_multi():
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"name,train_data,textcat_config",
|
"name,train_data,textcat_config",
|
||||||
[
|
[
|
||||||
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False}),
|
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False}),
|
||||||
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 4, "no_output_layer": False}),
|
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "ngram_size": 4, "no_output_layer": False}),
|
||||||
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 3, "no_output_layer": True}),
|
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "ngram_size": 3, "no_output_layer": True}),
|
||||||
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 2, "no_output_layer": True}),
|
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "ngram_size": 2, "no_output_layer": True}),
|
||||||
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False}}),
|
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False}}),
|
||||||
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 5, "no_output_layer": False}}),
|
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v2", "exclusive_classes": True, "ngram_size": 5, "no_output_layer": False}}),
|
||||||
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatCNN.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
|
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
|
||||||
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatCNN.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
|
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatCNN.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
|
|
|
@ -297,7 +297,7 @@ def test_util_dot_section():
|
||||||
factory = "textcat"
|
factory = "textcat"
|
||||||
|
|
||||||
[components.textcat.model]
|
[components.textcat.model]
|
||||||
@architectures = "spacy.TextCatBOW.v1"
|
@architectures = "spacy.TextCatBOW.v2"
|
||||||
exclusive_classes = true
|
exclusive_classes = true
|
||||||
ngram_size = 1
|
ngram_size = 1
|
||||||
no_output_layer = false
|
no_output_layer = false
|
||||||
|
|
|
@ -611,7 +611,7 @@ single-label use-cases where `exclusive_classes = true`, while the
|
||||||
> nO = null
|
> nO = null
|
||||||
>
|
>
|
||||||
> [model.linear_model]
|
> [model.linear_model]
|
||||||
> @architectures = "spacy.TextCatBOW.v1"
|
> @architectures = "spacy.TextCatBOW.v2"
|
||||||
> exclusive_classes = true
|
> exclusive_classes = true
|
||||||
> ngram_size = 1
|
> ngram_size = 1
|
||||||
> no_output_layer = false
|
> no_output_layer = false
|
||||||
|
@ -666,13 +666,13 @@ taking it as argument:
|
||||||
|
|
||||||
</Accordion>
|
</Accordion>
|
||||||
|
|
||||||
### spacy.TextCatCNN.v1 {#TextCatCNN}
|
### spacy.TextCatCNN.v2 {#TextCatCNN}
|
||||||
|
|
||||||
> #### Example Config
|
> #### Example Config
|
||||||
>
|
>
|
||||||
> ```ini
|
> ```ini
|
||||||
> [model]
|
> [model]
|
||||||
> @architectures = "spacy.TextCatCNN.v1"
|
> @architectures = "spacy.TextCatCNN.v2"
|
||||||
> exclusive_classes = false
|
> exclusive_classes = false
|
||||||
> nO = null
|
> nO = null
|
||||||
>
|
>
|
||||||
|
@ -698,13 +698,20 @@ architecture is usually less accurate than the ensemble, but runs faster.
|
||||||
| `nO` | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `initialize` is called. ~~Optional[int]~~ |
|
| `nO` | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `initialize` is called. ~~Optional[int]~~ |
|
||||||
| **CREATES** | The model using the architecture. ~~Model[List[Doc], Floats2d]~~ |
|
| **CREATES** | The model using the architecture. ~~Model[List[Doc], Floats2d]~~ |
|
||||||
|
|
||||||
### spacy.TextCatBOW.v1 {#TextCatBOW}
|
<Accordion title="spacy.TextCatCNN.v1 definition" spaced>
|
||||||
|
|
||||||
|
[TextCatCNN.v1](/api/legacy#TextCatCNN_v1) had the exact same signature, but was not yet resizable.
|
||||||
|
Since v2, new labels can be added to this component, even after training.
|
||||||
|
|
||||||
|
</Accordion>
|
||||||
|
|
||||||
|
### spacy.TextCatBOW.v2 {#TextCatBOW}
|
||||||
|
|
||||||
> #### Example Config
|
> #### Example Config
|
||||||
>
|
>
|
||||||
> ```ini
|
> ```ini
|
||||||
> [model]
|
> [model]
|
||||||
> @architectures = "spacy.TextCatBOW.v1"
|
> @architectures = "spacy.TextCatBOW.v2"
|
||||||
> exclusive_classes = false
|
> exclusive_classes = false
|
||||||
> ngram_size = 1
|
> ngram_size = 1
|
||||||
> no_output_layer = false
|
> no_output_layer = false
|
||||||
|
@ -722,6 +729,13 @@ the others, but may not be as accurate, especially if texts are short.
|
||||||
| `nO` | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `initialize` is called. ~~Optional[int]~~ |
|
| `nO` | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `initialize` is called. ~~Optional[int]~~ |
|
||||||
| **CREATES** | The model using the architecture. ~~Model[List[Doc], Floats2d]~~ |
|
| **CREATES** | The model using the architecture. ~~Model[List[Doc], Floats2d]~~ |
|
||||||
|
|
||||||
|
<Accordion title="spacy.TextCatBOW.v1 definition" spaced>
|
||||||
|
|
||||||
|
[TextCatBOW.v1](/api/legacy#TextCatBOW_v1) had the exact same signature, but was not yet resizable.
|
||||||
|
Since v2, new labels can be added to this component, even after training.
|
||||||
|
|
||||||
|
</Accordion>
|
||||||
|
|
||||||
## Entity linking architectures {#entitylinker source="spacy/ml/models/entity_linker.py"}
|
## Entity linking architectures {#entitylinker source="spacy/ml/models/entity_linker.py"}
|
||||||
|
|
||||||
An [`EntityLinker`](/api/entitylinker) component disambiguates textual mentions
|
An [`EntityLinker`](/api/entitylinker) component disambiguates textual mentions
|
||||||
|
|
|
@ -93,7 +93,7 @@ Defines the `nlp` object, its tokenizer and
|
||||||
> labels = ["POSITIVE", "NEGATIVE"]
|
> labels = ["POSITIVE", "NEGATIVE"]
|
||||||
>
|
>
|
||||||
> [components.textcat.model]
|
> [components.textcat.model]
|
||||||
> @architectures = "spacy.TextCatBOW.v1"
|
> @architectures = "spacy.TextCatBOW.v2"
|
||||||
> exclusive_classes = true
|
> exclusive_classes = true
|
||||||
> ngram_size = 1
|
> ngram_size = 1
|
||||||
> no_output_layer = false
|
> no_output_layer = false
|
||||||
|
|
|
@ -176,6 +176,68 @@ added to an existing vectors table. See more details in
|
||||||
|
|
||||||
</Infobox>
|
</Infobox>
|
||||||
|
|
||||||
|
### spacy.TextCatCNN.v1 {#TextCatCNN_v1}
|
||||||
|
|
||||||
|
Since `spacy.TextCatCNN.v2`, this architecture has become resizable, which means that you can add
|
||||||
|
labels to a previously trained textcat. `TextCatCNN` v1 did not yet support that.
|
||||||
|
|
||||||
|
> #### Example Config
|
||||||
|
>
|
||||||
|
> ```ini
|
||||||
|
> [model]
|
||||||
|
> @architectures = "spacy.TextCatCNN.v1"
|
||||||
|
> exclusive_classes = false
|
||||||
|
> nO = null
|
||||||
|
>
|
||||||
|
> [model.tok2vec]
|
||||||
|
> @architectures = "spacy.HashEmbedCNN.v1"
|
||||||
|
> pretrained_vectors = null
|
||||||
|
> width = 96
|
||||||
|
> depth = 4
|
||||||
|
> embed_size = 2000
|
||||||
|
> window_size = 1
|
||||||
|
> maxout_pieces = 3
|
||||||
|
> subword_features = true
|
||||||
|
> ```
|
||||||
|
|
||||||
|
A neural network model where token vectors are calculated using a CNN. The
|
||||||
|
vectors are mean pooled and used as features in a feed-forward network. This
|
||||||
|
architecture is usually less accurate than the ensemble, but runs faster.
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| `exclusive_classes` | Whether or not categories are mutually exclusive. ~~bool~~ |
|
||||||
|
| `tok2vec` | The [`tok2vec`](#tok2vec) layer of the model. ~~Model~~ |
|
||||||
|
| `nO` | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `initialize` is called. ~~Optional[int]~~ |
|
||||||
|
| **CREATES** | The model using the architecture. ~~Model[List[Doc], Floats2d]~~ |
|
||||||
|
|
||||||
|
### spacy.TextCatBOW.v1 {#TextCatBOW_v1}
|
||||||
|
|
||||||
|
Since `spacy.TextCatBOW.v2`, this architecture has become resizable, which means that you can add
|
||||||
|
labels to a previously trained textcat. `TextCatBOW` v1 did not yet support that.
|
||||||
|
|
||||||
|
> #### Example Config
|
||||||
|
>
|
||||||
|
> ```ini
|
||||||
|
> [model]
|
||||||
|
> @architectures = "spacy.TextCatBOW.v1"
|
||||||
|
> exclusive_classes = false
|
||||||
|
> ngram_size = 1
|
||||||
|
> no_output_layer = false
|
||||||
|
> nO = null
|
||||||
|
> ```
|
||||||
|
|
||||||
|
An n-gram "bag-of-words" model. This architecture should run much faster than
|
||||||
|
the others, but may not be as accurate, especially if texts are short.
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| `exclusive_classes` | Whether or not categories are mutually exclusive. ~~bool~~ |
|
||||||
|
| `ngram_size` | Determines the maximum length of the n-grams in the BOW model. For instance, `ngram_size=3` would give unigram, bigram and trigram features. ~~int~~ |
|
||||||
|
| `no_output_layer` | Whether or not to add an output layer to the model (`Softmax` activation if `exclusive_classes` is `True`, else `Logistic`). ~~bool~~ |
|
||||||
|
| `nO` | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `initialize` is called. ~~Optional[int]~~ |
|
||||||
|
| **CREATES** | The model using the architecture. ~~Model[List[Doc], Floats2d]~~ |
|
||||||
|
|
||||||
## Loggers {#loggers}
|
## Loggers {#loggers}
|
||||||
|
|
||||||
These functions are available from `@spacy.registry.loggers`.
|
These functions are available from `@spacy.registry.loggers`.
|
||||||
|
|
|
@ -151,7 +151,7 @@ maxout_pieces = 3
|
||||||
depth = 2
|
depth = 2
|
||||||
|
|
||||||
[components.textcat.model.linear_model]
|
[components.textcat.model.linear_model]
|
||||||
@architectures = "spacy.TextCatBOW.v1"
|
@architectures = "spacy.TextCatBOW.v2"
|
||||||
exclusive_classes = true
|
exclusive_classes = true
|
||||||
ngram_size = 1
|
ngram_size = 1
|
||||||
no_output_layer = false
|
no_output_layer = false
|
||||||
|
@ -169,7 +169,7 @@ factory = "textcat"
|
||||||
labels = []
|
labels = []
|
||||||
|
|
||||||
[components.textcat.model]
|
[components.textcat.model]
|
||||||
@architectures = "spacy.TextCatBOW.v1"
|
@architectures = "spacy.TextCatBOW.v2"
|
||||||
exclusive_classes = true
|
exclusive_classes = true
|
||||||
ngram_size = 1
|
ngram_size = 1
|
||||||
no_output_layer = false
|
no_output_layer = false
|
||||||
|
|
|
@ -1324,7 +1324,7 @@ labels = []
|
||||||
# This function is created and then passed to the "textcat" component as
|
# This function is created and then passed to the "textcat" component as
|
||||||
# the argument "model"
|
# the argument "model"
|
||||||
[components.textcat.model]
|
[components.textcat.model]
|
||||||
@architectures = "spacy.TextCatBOW.v1"
|
@architectures = "spacy.TextCatBOW.v2"
|
||||||
exclusive_classes = true
|
exclusive_classes = true
|
||||||
ngram_size = 1
|
ngram_size = 1
|
||||||
no_output_layer = false
|
no_output_layer = false
|
||||||
|
|
Loading…
Reference in New Issue
Block a user