TextCat updates and fixes (#6263)

* small fix in example imports

* throw error when train_corpus or dev_corpus is not a string

* small fix in custom logger example

* limit macro_auc to labels with both classes in the gold annotations

* fix typo

* also create parents of output_dir if need be

* update documentation of textcat scores

* refactor TextCatEnsemble

* fix tests for new AUC definition

* bump to 3.0.0a42

* update docs

* rename to spacy.TextCatEnsemble.v2

* spacy.TextCatEnsemble.v1 in legacy

* cleanup

* small fix

* update to 3.0.0rc2

* fix import that got lost in merge

* cursed IDE

* fix two typos
Sofie Van Landeghem 2020-10-18 14:50:41 +02:00 committed by GitHub
parent e2f3c4e12d
commit 75a202ce65
20 changed files with 235 additions and 127 deletions


@@ -1,6 +1,6 @@
# fmt: off
__title__ = "spacy-nightly"
__version__ = "3.0.0rc1"
__version__ = "3.0.0rc2"
__download_url__ = "https://github.com/explosion/spacy-models/releases/download"
__compatibility__ = "https://raw.githubusercontent.com/explosion/spacy-models/master/compatibility.json"
__projects__ = "https://github.com/explosion/projects"


@@ -100,7 +100,7 @@ def init_labels_cli(
extract the labels."""
util.logger.setLevel(logging.DEBUG if verbose else logging.INFO)
if not output_path.exists():
output_path.mkdir()
output_path.mkdir(parents=True)
overrides = parse_config_overrides(ctx.args)
import_code(code_path)
setup_gpu(use_gpu)
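
A minimal sketch of what this change enables (the path is hypothetical): with `parents=True`, intermediate directories are created as needed instead of raising `FileNotFoundError`.

```python
from pathlib import Path

# Hypothetical nested output directory. Without parents=True, mkdir()
# raises FileNotFoundError when the intermediate "training/run1" is missing.
output_path = Path("training/run1/labels")
if not output_path.exists():
    output_path.mkdir(parents=True)
```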


@@ -136,15 +136,19 @@ factory = "textcat"
{% if optimize == "accuracy" %}
[components.textcat.model]
@architectures = "spacy.TextCatEnsemble.v1"
exclusive_classes = false
width = 64
conv_depth = 2
embed_size = 2000
window_size = 1
ngram_size = 1
@architectures = "spacy.TextCatEnsemble.v2"
nO = null
[components.textcat.model.tok2vec]
@architectures = "spacy-transformers.TransformerListener.v1"
grad_factor = 1.0
[components.textcat.model.linear_model]
@architectures = "spacy.TextCatBOW.v1"
exclusive_classes = false
ngram_size = 1
no_output_layer = false
{% else -%}
[components.textcat.model]
@architectures = "spacy.TextCatBOW.v1"
@@ -271,15 +275,19 @@ factory = "textcat"
{% if optimize == "accuracy" %}
[components.textcat.model]
@architectures = "spacy.TextCatEnsemble.v1"
exclusive_classes = false
width = 64
conv_depth = 2
embed_size = 2000
window_size = 1
ngram_size = 1
@architectures = "spacy.TextCatEnsemble.v2"
nO = null
[components.textcat.model.tok2vec]
@architectures = "spacy.Tok2VecListener.v1"
width = ${components.tok2vec.model.encode.width}
[components.textcat.model.linear_model]
@architectures = "spacy.TextCatBOW.v1"
exclusive_classes = false
ngram_size = 1
no_output_layer = false
{% else -%}
[components.textcat.model]
@architectures = "spacy.TextCatBOW.v1"


@@ -44,7 +44,7 @@ def train_cli(
if not config_path or not config_path.exists():
msg.fail("Config file not found", config_path, exits=1)
if output_path is not None and not output_path.exists():
output_path.mkdir()
output_path.mkdir(parents=True)
msg.good(f"Created output directory: {output_path}")
overrides = parse_config_overrides(ctx.args)
import_code(code_path)


@@ -398,8 +398,8 @@ class Errors:
E163 = ("cumsum was found to be unstable: its last element does not "
"correspond to sum")
E164 = ("x is neither increasing nor decreasing: {x}.")
E165 = ("Only one class present in y_true. ROC AUC score is not defined in "
"that case.")
E165 = ("Only one class present in the gold labels: {label}. "
"ROC AUC score is not defined in that case.")
E166 = ("Can only merge DocBins with the same value for '{param}'.\n"
"Current DocBin: {current}\nOther DocBin: {other}")
E169 = ("Can't find module: {module}")
@@ -456,6 +456,8 @@ class Errors:
"issue tracker: http://github.com/explosion/spaCy/issues")
# TODO: fix numbering after merging develop into master
E897 = ("Field '{field}' should be a dot-notation string referring to the "
"relevant section in the config, but found type {type} instead.")
E898 = ("Can't serialize trainable pipe '{name}': the `model` attribute "
"is not set or None. If you've implemented a custom component, make "
"sure to store the component model as `self.model` in your "


@@ -1,4 +1,6 @@
from typing import Optional
from typing import Optional, List
from thinc.types import Floats2d
from thinc.api import Model, reduce_mean, Linear, list2ragged, Logistic
from thinc.api import chain, concatenate, clone, Dropout, ParametricAttention
from thinc.api import SparseLinear, Softmax, softmax_activation, Maxout, reduce_sum
@@ -10,12 +12,13 @@ from ...util import registry
from ..extract_ngrams import extract_ngrams
from ..staticvectors import StaticVectors
from ..featureextractor import FeatureExtractor
from ...tokens import Doc
@registry.architectures.register("spacy.TextCatCNN.v1")
def build_simple_cnn_text_classifier(
tok2vec: Model, exclusive_classes: bool, nO: Optional[int] = None
) -> Model:
) -> Model[List[Doc], Floats2d]:
"""
Build a simple CNN text classifier, given a token-to-vector model as inputs.
If exclusive_classes=True, a softmax non-linearity is applied, so that the
@@ -23,15 +26,14 @@ def build_simple_cnn_text_classifier(
is applied instead, so that outputs are in the range [0, 1].
"""
with Model.define_operators({">>": chain}):
cnn = tok2vec >> list2ragged() >> reduce_mean()
if exclusive_classes:
output_layer = Softmax(nO=nO, nI=tok2vec.maybe_get_dim("nO"))
model = tok2vec >> list2ragged() >> reduce_mean() >> output_layer
model = cnn >> output_layer
model.set_ref("output_layer", output_layer)
else:
linear_layer = Linear(nO=nO, nI=tok2vec.maybe_get_dim("nO"))
model = (
tok2vec >> list2ragged() >> reduce_mean() >> linear_layer >> Logistic()
)
model = cnn >> linear_layer >> Logistic()
model.set_ref("output_layer", linear_layer)
model.set_ref("tok2vec", tok2vec)
model.set_dim("nO", nO)
@@ -45,8 +47,7 @@ def build_bow_text_classifier(
ngram_size: int,
no_output_layer: bool,
nO: Optional[int] = None,
) -> Model:
# Don't document this yet, I'm not sure it's right.
) -> Model[List[Doc], Floats2d]:
with Model.define_operators({">>": chain}):
sparse_linear = SparseLinear(nO)
model = extract_ngrams(ngram_size, attr=ORTH) >> sparse_linear
@@ -59,6 +60,39 @@ def build_bow_text_classifier(
return model
@registry.architectures.register("spacy.TextCatEnsemble.v2")
def build_text_classifier(
tok2vec: Model[List[Doc], List[Floats2d]],
linear_model: Model[List[Doc], Floats2d],
nO: Optional[int] = None,
) -> Model[List[Doc], Floats2d]:
exclusive_classes = not linear_model.attrs["multi_label"]
with Model.define_operators({">>": chain, "|": concatenate}):
width = tok2vec.get_dim("nO")
cnn_model = (
tok2vec
>> list2ragged()
>> ParametricAttention(width) # TODO: benchmark performance difference of this layer
>> reduce_sum()
>> residual(Maxout(nO=width, nI=width))
>> Linear(nO=nO, nI=width)
>> Dropout(0.0)
)
nO_double = nO * 2 if nO else None
if exclusive_classes:
output_layer = Softmax(nO=nO, nI=nO_double)
else:
output_layer = Linear(nO=nO, nI=nO_double) >> Dropout(0.0) >> Logistic()
model = (linear_model | cnn_model) >> output_layer
model.set_ref("tok2vec", tok2vec)
if model.has_dim("nO") is not False:
model.set_dim("nO", nO)
model.set_ref("output_layer", linear_model.get_ref("output_layer"))
model.attrs["multi_label"] = not exclusive_classes
return model
# TODO: move to legacy
@registry.architectures.register("spacy.TextCatEnsemble.v1")
def build_text_classifier(
width: int,
@@ -158,11 +192,8 @@ def build_text_classifier(
@registry.architectures.register("spacy.TextCatLowData.v1")
def build_text_classifier_lowdata(
width: int,
pretrained_vectors: Optional[bool],
dropout: Optional[float],
nO: Optional[int] = None,
) -> Model:
width: int, dropout: Optional[float], nO: Optional[int] = None
) -> Model[List[Doc], Floats2d]:
# Don't document this yet, I'm not sure it's right.
# Note, before v.3, this was the default if setting "low_data" and "pretrained_dims"
with Model.define_operators({">>": chain, "**": clone}):
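
To make the new `spacy.TextCatEnsemble.v2` signature concrete, here is a minimal sketch that resolves a config into the ensemble model; the `tok2vec` and BOW settings below are illustrative choices, not the shipped defaults:

```python
from thinc.api import Config
from spacy.util import registry

# Illustrative config: TextCatEnsemble.v2 receives a tok2vec layer and a
# linear bag-of-words model as sub-networks instead of building its own.
cfg = Config().from_str("""
[model]
@architectures = "spacy.TextCatEnsemble.v2"
nO = null

[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v1"
pretrained_vectors = null
width = 64
depth = 2
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true

[model.linear_model]
@architectures = "spacy.TextCatBOW.v1"
exclusive_classes = false
ngram_size = 1
no_output_layer = false
""")
model = registry.resolve(cfg)["model"]
# The multi-label setting is taken over from the linear model.
assert model.attrs["multi_label"] is True
```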


@@ -106,7 +106,7 @@ def MultiHashEmbed(
) -> Model[List[Doc], List[Floats2d]]:
"""Construct an embedding layer that separately embeds a number of lexical
attributes using hash embedding, concatenates the results, and passes it
through a feed-forward subnetwork to build a mixed representations.
through a feed-forward subnetwork to build a mixed representation.
The features used can be configured with the 'attrs' argument. The suggested
attributes are NORM, PREFIX, SUFFIX and SHAPE. This lets the model take into


@@ -16,15 +16,30 @@ from ..vocab import Vocab
default_model_config = """
[model]
@architectures = "spacy.TextCatEnsemble.v1"
exclusive_classes = false
pretrained_vectors = null
@architectures = "spacy.TextCatEnsemble.v2"
[model.tok2vec]
@architectures = "spacy.Tok2Vec.v1"
[model.tok2vec.embed]
@architectures = "spacy.MultiHashEmbed.v1"
width = 64
conv_depth = 2
embed_size = 2000
rows = [2000, 2000, 1000, 1000, 1000, 1000]
attrs = ["ORTH", "LOWER", "PREFIX", "SUFFIX", "SHAPE", "ID"]
include_static_vectors = false
[model.tok2vec.encode]
@architectures = "spacy.MaxoutWindowEncoder.v1"
width = ${model.tok2vec.embed.width}
window_size = 1
maxout_pieces = 3
depth = 2
[model.linear_model]
@architectures = "spacy.TextCatBOW.v1"
exclusive_classes = false
ngram_size = 1
dropout = null
no_output_layer = false
"""
DEFAULT_TEXTCAT_MODEL = Config().from_str(default_model_config)["model"]
@@ -60,9 +75,11 @@ subword_features = true
default_score_weights={
"cats_score": 1.0,
"cats_score_desc": None,
"cats_p": None,
"cats_r": None,
"cats_f": None,
"cats_micro_p": None,
"cats_micro_r": None,
"cats_micro_f": None,
"cats_macro_p": None,
"cats_macro_r": None,
"cats_macro_f": None,
"cats_macro_auc": None,
"cats_f_per_type": None,


@@ -59,7 +59,9 @@ class PRFScore:
class ROCAUCScore:
"""An AUC ROC score."""
"""An AUC ROC score. This is only defined for binary classification.
Use the method is_binary before calculating the score, otherwise it
may throw an error."""
def __init__(self) -> None:
self.golds = []
@@ -71,16 +73,16 @@ class ROCAUCScore:
self.cands.append(cand)
self.golds.append(gold)
def is_binary(self):
return len(np.unique(self.golds)) == 2
@property
def score(self):
if not self.is_binary():
raise ValueError(Errors.E165.format(label=set(self.golds)))
if len(self.golds) == self.saved_score_at_len:
return self.saved_score
try:
self.saved_score = _roc_auc_score(self.golds, self.cands)
# catch ValueError: Only one class present in y_true.
# ROC AUC score is not defined in that case.
except ValueError:
self.saved_score = -float("inf")
self.saved_score = _roc_auc_score(self.golds, self.cands)
self.saved_score_at_len = len(self.golds)
return self.saved_score
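
A toy sketch of the new guard: `is_binary()` reports whether both classes occur in the gold labels, and `score` now raises `E165` instead of returning `-inf` when they do not.

```python
from spacy.scorer import ROCAUCScore

score = ROCAUCScore()
score.score_set(0.75, 1)  # (candidate score, gold label)
score.score_set(0.25, 0)
assert score.is_binary()  # both classes present in the gold labels
print(score.score)        # 1.0 here: the positive example is ranked higher
```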
@@ -362,9 +364,13 @@ class Scorer:
for all:
attr_score (one of attr_micro_f / attr_macro_f / attr_macro_auc),
attr_score_desc (text description of the overall score),
attr_micro_p,
attr_micro_r,
attr_micro_f,
attr_macro_p,
attr_macro_r,
attr_macro_f,
attr_auc,
attr_macro_auc,
attr_f_per_type,
attr_auc_per_type
@@ -431,7 +437,9 @@ class Scorer:
macro_p = sum(prf.precision for prf in f_per_type.values()) / n_cats
macro_r = sum(prf.recall for prf in f_per_type.values()) / n_cats
macro_f = sum(prf.fscore for prf in f_per_type.values()) / n_cats
macro_auc = sum(auc.score for auc in auc_per_type.values()) / n_cats
# Limit macro_auc to those labels with gold annotations,
# but still divide by all cats to avoid artificial boosting of datasets with missing labels
macro_auc = sum(auc.score if auc.is_binary() else 0.0 for auc in auc_per_type.values()) / n_cats
results = {
f"{attr}_score": None,
f"{attr}_score_desc": None,
@@ -443,7 +451,7 @@ class Scorer:
f"{attr}_macro_f": macro_f,
f"{attr}_macro_auc": macro_auc,
f"{attr}_f_per_type": {k: v.to_dict() for k, v in f_per_type.items()},
f"{attr}_auc_per_type": {k: v.score for k, v in auc_per_type.items()},
f"{attr}_auc_per_type": {k: v.score if v.is_binary() else None for k, v in auc_per_type.items()},
}
if len(labels) == 2 and not multi_label and positive_label:
positive_label_f = results[f"{attr}_f_per_type"][positive_label]["f"]
@@ -726,7 +734,7 @@ def _roc_auc_score(y_true, y_score):
<https://www.ncbi.nlm.nih.gov/pubmed/2668680>`_
"""
if len(np.unique(y_true)) != 2:
raise ValueError(Errors.E165)
raise ValueError(Errors.E165.format(label=np.unique(y_true)))
fpr, tpr, _ = _roc_curve(y_true, y_score)
return _auc(fpr, tpr)
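
A worked sketch of the macro AUC rule with toy numbers: a label that is not binary in the gold data contributes 0.0 but still counts in the denominator, so sparse gold labels lower the average rather than inflating it.

```python
# Toy per-label AUC values, with None for a label whose gold annotations
# contain only one class (mirroring {attr}_auc_per_type above).
auc_per_type = {"POS": 0.9, "NEG": 0.7, "RARE": None}
n_cats = len(auc_per_type)
macro_auc = sum(s if s is not None else 0.0 for s in auc_per_type.values()) / n_cats
assert abs(macro_auc - (0.9 + 0.7) / 3) < 1e-12
```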


@@ -2,6 +2,7 @@ import pytest
from spacy.language import Language
from spacy.lang.en import English
from spacy.lang.de import German
from spacy.pipeline.tok2vec import DEFAULT_TOK2VEC_MODEL
from spacy.tokens import Doc
from spacy.util import registry, SimpleFrozenDict, combine_score_weights
from thinc.api import Model, Linear, ConfigValidationError
@@ -156,15 +157,10 @@ def test_pipe_class_component_model():
name = "test_class_component_model"
default_config = {
"model": {
"@architectures": "spacy.TextCatEnsemble.v1",
"exclusive_classes": False,
"pretrained_vectors": None,
"width": 64,
"embed_size": 2000,
"window_size": 1,
"conv_depth": 2,
"ngram_size": 1,
"dropout": None,
"@architectures": "spacy.TextCatEnsemble.v2",
"tok2vec": DEFAULT_TOK2VEC_MODEL,
"linear_model": {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 1,
"no_output_layer": False},
},
"value1": 10,
}


@@ -140,7 +140,7 @@ def test_overfitting_IO():
nlp = English()
nlp.config["initialize"]["components"]["textcat"] = {"positive_label": "POSITIVE"}
# Set exclusive labels
config = {"model": {"exclusive_classes": True}}
config = {"model": {"linear_model": {"exclusive_classes": True}}}
textcat = nlp.add_pipe("textcat", config=config)
train_examples = []
for text, annotations in TRAIN_DATA:
@@ -192,9 +192,8 @@ def test_overfitting_IO():
{"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 4, "no_output_layer": False},
{"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 3, "no_output_layer": True},
{"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 2, "no_output_layer": True},
{"@architectures": "spacy.TextCatEnsemble.v1", "exclusive_classes": False, "ngram_size": 1, "pretrained_vectors": False, "width": 64, "conv_depth": 2, "embed_size": 2000, "window_size": 2, "dropout": None},
{"@architectures": "spacy.TextCatEnsemble.v1", "exclusive_classes": True, "ngram_size": 5, "pretrained_vectors": False, "width": 128, "conv_depth": 2, "embed_size": 2000, "window_size": 1, "dropout": None},
{"@architectures": "spacy.TextCatEnsemble.v1", "exclusive_classes": True, "ngram_size": 2, "pretrained_vectors": False, "width": 32, "conv_depth": 3, "embed_size": 500, "window_size": 3, "dropout": None},
{"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False}},
{"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 5, "no_output_layer": False}},
{"@architectures": "spacy.TextCatCNN.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True},
{"@architectures": "spacy.TextCatCNN.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False},
],


@@ -4,32 +4,23 @@ from thinc.api import fix_random_seed, Adam, set_dropout_rate
from numpy.testing import assert_array_equal
import numpy
from spacy.ml.models import build_Tok2Vec_model, MultiHashEmbed, MaxoutWindowEncoder
from spacy.ml.models import build_text_classifier, build_simple_cnn_text_classifier
from spacy.ml.models import build_bow_text_classifier, build_simple_cnn_text_classifier
from spacy.ml.staticvectors import StaticVectors
from spacy.lang.en import English
from spacy.lang.en.examples import sentences as EN_SENTENCES
def get_textcat_kwargs():
def get_textcat_bow_kwargs():
return {
"width": 64,
"embed_size": 2000,
"pretrained_vectors": None,
"exclusive_classes": False,
"exclusive_classes": True,
"ngram_size": 1,
"window_size": 1,
"conv_depth": 2,
"dropout": None,
"nO": 7,
"no_output_layer": False,
"nO": 34,
}
def get_textcat_cnn_kwargs():
return {
"tok2vec": test_tok2vec(),
"exclusive_classes": False,
"nO": 13,
}
return {"tok2vec": test_tok2vec(), "exclusive_classes": False, "nO": 13}
def get_all_params(model):
@@ -105,7 +96,7 @@ def test_multi_hash_embed():
"seed,model_func,kwargs",
[
(0, build_Tok2Vec_model, get_tok2vec_kwargs()),
(0, build_text_classifier, get_textcat_kwargs()),
(0, build_bow_text_classifier, get_textcat_bow_kwargs()),
(0, build_simple_cnn_text_classifier, get_textcat_cnn_kwargs()),
],
)
@@ -125,7 +116,7 @@ def test_models_initialize_consistently(seed, model_func, kwargs):
"seed,model_func,kwargs,get_X",
[
(0, build_Tok2Vec_model, get_tok2vec_kwargs(), get_docs),
(0, build_text_classifier, get_textcat_kwargs(), get_docs),
(0, build_bow_text_classifier, get_textcat_bow_kwargs(), get_docs),
(0, build_simple_cnn_text_classifier, get_textcat_cnn_kwargs(), get_docs),
],
)
@@ -160,7 +151,7 @@ def test_models_predict_consistently(seed, model_func, kwargs, get_X):
"seed,dropout,model_func,kwargs,get_X",
[
(0, 0.2, build_Tok2Vec_model, get_tok2vec_kwargs(), get_docs),
(0, 0.2, build_text_classifier, get_textcat_kwargs(), get_docs),
(0, 0.2, build_bow_text_classifier, get_textcat_bow_kwargs(), get_docs),
(0, 0.2, build_simple_cnn_text_classifier, get_textcat_cnn_kwargs(), get_docs),
],
)


@@ -334,7 +334,8 @@ def test_roc_auc_score():
score = ROCAUCScore()
score.score_set(0.25, 0)
score.score_set(0.75, 0)
assert score.score == -float("inf")
with pytest.raises(ValueError):
s = score.score
y_true = [1, 1]
y_score = [0.25, 0.75]
@@ -344,4 +345,5 @@ def test_roc_auc_score():
score = ROCAUCScore()
score.score_set(0.25, 1)
score.score_set(0.75, 1)
assert score.score == -float("inf")
with pytest.raises(ValueError):
s = score.score


@@ -51,7 +51,7 @@ def test_readers():
for example in train_corpus(nlp):
nlp.update([example], sgd=optimizer)
scores = nlp.evaluate(list(dev_corpus(nlp)))
assert scores["cats_score"]
assert scores["cats_score"] == 0.0
# ensure the pipeline runs
doc = nlp("Quick test")
assert doc.cats


@@ -36,6 +36,10 @@ def init_nlp(config: Config, *, use_gpu: int = -1) -> "Language":
# Resolve all training-relevant sections using the filled nlp config
T = registry.resolve(config["training"], schema=ConfigSchemaTraining)
dot_names = [T["train_corpus"], T["dev_corpus"]]
if not isinstance(T["train_corpus"], str):
raise ConfigValidationError(desc=Errors.E897.format(field="training.train_corpus", type=type(T["train_corpus"])))
if not isinstance(T["dev_corpus"], str):
raise ConfigValidationError(desc=Errors.E897.format(field="training.dev_corpus", type=type(T["dev_corpus"])))
train_corpus, dev_corpus = resolve_dot_names(config, dot_names)
optimizer = T["optimizer"]
# Components that shouldn't be updated during training
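
For reference, a config excerpt that passes the new checks; `train_corpus` and `dev_corpus` must be dot-notation strings naming sections such as `[corpora.train]` (the corpora sections themselves are omitted here):

```python
from thinc.api import Config

# Sketch: dot-notation strings, not inline blocks, are expected here.
cfg = Config().from_str("""
[training]
train_corpus = "corpora.train"
dev_corpus = "corpora.dev"
""")
assert isinstance(cfg["training"]["train_corpus"], str)
```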


@@ -143,7 +143,7 @@ argument that connects to the shared `tok2vec` component in the pipeline.
Construct an embedding layer that separately embeds a number of lexical
attributes using hash embedding, concatenates the results, and passes it through
a feed-forward subnetwork to build a mixed representations. The features used
a feed-forward subnetwork to build a mixed representation. The features used
can be configured with the `attrs` argument. The suggested attributes are
`NORM`, `PREFIX`, `SUFFIX` and `SHAPE`. This lets the model take into account
some subword information, without constructing a fully character-based
@@ -516,26 +516,54 @@ several different built-in architectures. It is recommended to experiment with
different architectures and settings to determine what works best on your
specific data and challenge.
### spacy.TextCatEnsemble.v1 {#TextCatEnsemble}
### spacy.TextCatEnsemble.v2 {#TextCatEnsemble}
> #### Example Config
>
> ```ini
> [model]
> @architectures = "spacy.TextCatEnsemble.v1"
> exclusive_classes = false
> pretrained_vectors = null
> width = 64
> embed_size = 2000
> conv_depth = 2
> window_size = 1
> ngram_size = 1
> dropout = null
> @architectures = "spacy.TextCatEnsemble.v2"
> nO = null
>
> [model.linear_model]
> @architectures = "spacy.TextCatBOW.v1"
> exclusive_classes = true
> ngram_size = 1
> no_output_layer = false
>
> [model.tok2vec]
> @architectures = "spacy.Tok2Vec.v1"
>
> [model.tok2vec.embed]
> @architectures = "spacy.MultiHashEmbed.v1"
> width = 64
> rows = [2000, 2000, 1000, 1000, 1000, 1000]
> attrs = ["ORTH", "LOWER", "PREFIX", "SUFFIX", "SHAPE", "ID"]
> include_static_vectors = false
>
> [model.tok2vec.encode]
> @architectures = "spacy.MaxoutWindowEncoder.v1"
> width = ${model.tok2vec.embed.width}
> window_size = 1
> maxout_pieces = 3
> depth = 2
> ```
Stacked ensemble of a bag-of-words model and a neural network model. The neural
network has an internal CNN Tok2Vec layer and uses attention.
Stacked ensemble of a linear bag-of-words model and a neural network model. The
neural network is built upon a Tok2Vec layer and uses attention. The setting for
whether or not this model should cater for multi-label classification is taken
from the linear model, where it is stored in `model.attrs["multi_label"]`.
| Name | Description |
| -------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| `linear_model` | The linear bag-of-words model. ~~Model[List[Doc], Floats2d]~~ |
| `tok2vec` | The `tok2vec` layer to build the neural network upon. ~~Model[List[Doc], List[Floats2d]]~~ |
| `nO` | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `initialize` is called. ~~Optional[int]~~ |
| **CREATES** | The model using the architecture. ~~Model[List[Doc], Floats2d]~~ |
<Accordion title="spacy.TextCatEnsemble.v1 definition" spaced>
The v1 architecture was functionally similar, but used an internal `tok2vec` layer instead of taking it as an argument.
| Name | Description |
| -------------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
@@ -550,6 +578,8 @@ network has an internal CNN Tok2Vec layer and uses attention.
| `nO` | Output dimension, determined by the number of different labels. If not set, the [`TextCategorizer`](/api/textcategorizer) component will set it when `initialize` is called. ~~Optional[int]~~ |
| **CREATES** | The model using the architecture. ~~Model[List[Doc], Floats2d]~~ |
</Accordion>
### spacy.TextCatCNN.v1 {#TextCatCNN}
> #### Example Config


@@ -174,15 +174,25 @@ Calculate the UAS, LAS, and LAS per type scores for dependency parses.
## Scorer.score_cats {#score_cats tag="staticmethod" new="3"}
Calculate PRF and ROC AUC scores for a doc-level attribute that is a dict
containing scores for each label like `Doc.cats`. The reported overall score
depends on the scorer settings:
containing scores for each label like `Doc.cats`. The returned dictionary
contains the following scores:
1. **all:** `{attr}_score` (one of `{attr}_f` / `{attr}_macro_f` /
`{attr}_macro_auc`), `{attr}_score_desc` (text description of the overall
score), `{attr}_f_per_type`, `{attr}_auc_per_type`
2. **binary exclusive with positive label:** `{attr}_p`, `{attr}_r`, `{attr}_f`
3. **3+ exclusive classes**, macro-averaged F-score: `{attr}_macro_f`;
4. **multilabel**, macro-averaged AUC: `{attr}_macro_auc`
- `{attr}_micro_p`, `{attr}_micro_r` and `{attr}_micro_f`: each instance across
each label is weighted equally
- `{attr}_macro_p`, `{attr}_macro_r` and `{attr}_macro_f`: the average values
across evaluations per label
- `{attr}_f_per_type` and `{attr}_auc_per_type`: each contains a dictionary of
scores, keyed by label
- A final `{attr}_score` and corresponding `{attr}_score_desc` (text
description)
The reported `{attr}_score` depends on the classification properties:
- **binary exclusive with positive label:** `{attr}_score` is set to the F-score
of the positive label
- **3+ exclusive classes**, macro-averaged F-score:
`{attr}_score = {attr}_macro_f`
- **multilabel**, macro-averaged AUC: `{attr}_score = {attr}_macro_auc`
> #### Example
>
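
A toy sketch of these keys in practice (labels and scores invented):

```python
from spacy.lang.en import English
from spacy.scorer import Scorer
from spacy.training import Example

# Two toy docs with one label; predictions live on doc.cats.
nlp = English()
examples = []
for text, pred, gold in [("good", 0.9, 1.0), ("bad", 0.1, 0.0)]:
    doc = nlp.make_doc(text)
    doc.cats = {"POS": pred}
    examples.append(Example.from_dict(doc, {"cats": {"POS": gold}}))
scores = Scorer.score_cats(examples, "cats", labels=["POS"], multi_label=True)
print(sorted(scores))  # cats_macro_auc, cats_micro_f, ... as listed above
```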


@@ -130,16 +130,31 @@ factory = "textcat"
labels = []
[components.textcat.model]
@architectures = "spacy.TextCatEnsemble.v1"
exclusive_classes = false
pretrained_vectors = null
width = 64
conv_depth = 2
embed_size = 2000
window_size = 1
ngram_size = 1
dropout = 0
@architectures = "spacy.TextCatEnsemble.v2"
nO = null
[components.textcat.model.tok2vec]
@architectures = "spacy.Tok2Vec.v1"
[components.textcat.model.tok2vec.embed]
@architectures = "spacy.MultiHashEmbed.v1"
width = 64
rows = [2000, 2000, 1000, 1000, 1000, 1000]
attrs = ["ORTH", "LOWER", "PREFIX", "SUFFIX", "SHAPE", "ID"]
include_static_vectors = false
[components.textcat.model.tok2vec.encode]
@architectures = "spacy.MaxoutWindowEncoder.v1"
width = ${components.textcat.model.tok2vec.embed.width}
window_size = 1
maxout_pieces = 3
depth = 2
[components.textcat.model.linear_model]
@architectures = "spacy.TextCatBOW.v1"
exclusive_classes = false
ngram_size = 1
no_output_layer = false
```
spaCy has two additional built-in `textcat` architectures, and you can easily


@@ -1244,15 +1244,10 @@ labels = []
# This function is created and then passed to the "textcat" component as
# the argument "model"
[components.textcat.model]
@architectures = "spacy.TextCatEnsemble.v1"
@architectures = "spacy.TextCatBOW.v1"
exclusive_classes = false
pretrained_vectors = null
width = 64
conv_depth = 2
embed_size = 2000
window_size = 1
ngram_size = 1
dropout = null
no_output_layer = false
[components.other_textcat]
factory = "textcat"


@@ -717,7 +717,7 @@ tabular results to a file:
```python
### functions.py
import sys
from typing import IO, Tuple, Callable, Dict, Any
from typing import IO, Tuple, Callable, Dict, Any, Optional
import spacy
from spacy import Language
from pathlib import Path
@@ -729,7 +729,7 @@ def custom_logger(log_path):
stdout: IO=sys.stdout,
stderr: IO=sys.stderr
) -> Tuple[Callable, Callable]:
stdout.write(f"Logging to {log_path}\n")
stdout.write(f"Logging to {log_path}\\n")
log_file = Path(log_path).open("w", encoding="utf8")
log_file.write("step\\t")
log_file.write("score\\t")