mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 01:16:28 +03:00
multi-label textcat component (#6474)
* multi-label textcat component * formatting * fix comment * cleanup * fix from #6481 * random edit to push the tests * add explicit error when textcat is called with multi-label gold data * fix error nr * small fix
This commit is contained in:
parent
1a77607036
commit
afc5714d32
|
@ -169,7 +169,7 @@ def init_config(
|
||||||
"Hardware": variables["hardware"].upper(),
|
"Hardware": variables["hardware"].upper(),
|
||||||
"Transformer": template_vars.transformer.get("name", False),
|
"Transformer": template_vars.transformer.get("name", False),
|
||||||
}
|
}
|
||||||
msg.info("Generated template specific for your use case")
|
msg.info("Generated config template specific for your use case")
|
||||||
for label, value in use_case.items():
|
for label, value in use_case.items():
|
||||||
msg.text(f"- {label}: {value}")
|
msg.text(f"- {label}: {value}")
|
||||||
with show_validation_error(hint_fill=False):
|
with show_validation_error(hint_fill=False):
|
||||||
|
|
|
@ -149,13 +149,44 @@ grad_factor = 1.0
|
||||||
|
|
||||||
[components.textcat.model.linear_model]
|
[components.textcat.model.linear_model]
|
||||||
@architectures = "spacy.TextCatBOW.v1"
|
@architectures = "spacy.TextCatBOW.v1"
|
||||||
exclusive_classes = false
|
exclusive_classes = true
|
||||||
ngram_size = 1
|
ngram_size = 1
|
||||||
no_output_layer = false
|
no_output_layer = false
|
||||||
|
|
||||||
{% else -%}
|
{% else -%}
|
||||||
[components.textcat.model]
|
[components.textcat.model]
|
||||||
@architectures = "spacy.TextCatBOW.v1"
|
@architectures = "spacy.TextCatBOW.v1"
|
||||||
|
exclusive_classes = true
|
||||||
|
ngram_size = 1
|
||||||
|
no_output_layer = false
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
|
|
||||||
|
{% if "textcat_multilabel" in components %}
|
||||||
|
[components.textcat_multilabel]
|
||||||
|
factory = "textcat_multilabel"
|
||||||
|
|
||||||
|
{% if optimize == "accuracy" %}
|
||||||
|
[components.textcat_multilabel.model]
|
||||||
|
@architectures = "spacy.TextCatEnsemble.v2"
|
||||||
|
nO = null
|
||||||
|
|
||||||
|
[components.textcat_multilabel.model.tok2vec]
|
||||||
|
@architectures = "spacy-transformers.TransformerListener.v1"
|
||||||
|
grad_factor = 1.0
|
||||||
|
|
||||||
|
[components.textcat_multilabel.model.tok2vec.pooling]
|
||||||
|
@layers = "reduce_mean.v1"
|
||||||
|
|
||||||
|
[components.textcat_multilabel.model.linear_model]
|
||||||
|
@architectures = "spacy.TextCatBOW.v1"
|
||||||
|
exclusive_classes = false
|
||||||
|
ngram_size = 1
|
||||||
|
no_output_layer = false
|
||||||
|
|
||||||
|
{% else -%}
|
||||||
|
[components.textcat_multilabel.model]
|
||||||
|
@architectures = "spacy.TextCatBOW.v1"
|
||||||
exclusive_classes = false
|
exclusive_classes = false
|
||||||
ngram_size = 1
|
ngram_size = 1
|
||||||
no_output_layer = false
|
no_output_layer = false
|
||||||
|
@ -288,13 +319,41 @@ width = ${components.tok2vec.model.encode.width}
|
||||||
|
|
||||||
[components.textcat.model.linear_model]
|
[components.textcat.model.linear_model]
|
||||||
@architectures = "spacy.TextCatBOW.v1"
|
@architectures = "spacy.TextCatBOW.v1"
|
||||||
exclusive_classes = false
|
exclusive_classes = true
|
||||||
ngram_size = 1
|
ngram_size = 1
|
||||||
no_output_layer = false
|
no_output_layer = false
|
||||||
|
|
||||||
{% else -%}
|
{% else -%}
|
||||||
[components.textcat.model]
|
[components.textcat.model]
|
||||||
@architectures = "spacy.TextCatBOW.v1"
|
@architectures = "spacy.TextCatBOW.v1"
|
||||||
|
exclusive_classes = true
|
||||||
|
ngram_size = 1
|
||||||
|
no_output_layer = false
|
||||||
|
{%- endif %}
|
||||||
|
{%- endif %}
|
||||||
|
|
||||||
|
{% if "textcat_multilabel" in components %}
|
||||||
|
[components.textcat_multilabel]
|
||||||
|
factory = "textcat_multilabel"
|
||||||
|
|
||||||
|
{% if optimize == "accuracy" %}
|
||||||
|
[components.textcat_multilabel.model]
|
||||||
|
@architectures = "spacy.TextCatEnsemble.v2"
|
||||||
|
nO = null
|
||||||
|
|
||||||
|
[components.textcat_multilabel.model.tok2vec]
|
||||||
|
@architectures = "spacy.Tok2VecListener.v1"
|
||||||
|
width = ${components.tok2vec.model.encode.width}
|
||||||
|
|
||||||
|
[components.textcat_multilabel.model.linear_model]
|
||||||
|
@architectures = "spacy.TextCatBOW.v1"
|
||||||
|
exclusive_classes = false
|
||||||
|
ngram_size = 1
|
||||||
|
no_output_layer = false
|
||||||
|
|
||||||
|
{% else -%}
|
||||||
|
[components.textcat_multilabel.model]
|
||||||
|
@architectures = "spacy.TextCatBOW.v1"
|
||||||
exclusive_classes = false
|
exclusive_classes = false
|
||||||
ngram_size = 1
|
ngram_size = 1
|
||||||
no_output_layer = false
|
no_output_layer = false
|
||||||
|
@ -303,7 +362,7 @@ no_output_layer = false
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
{% for pipe in components %}
|
{% for pipe in components %}
|
||||||
{% if pipe not in ["tagger", "morphologizer", "parser", "ner", "textcat", "entity_linker"] %}
|
{% if pipe not in ["tagger", "morphologizer", "parser", "ner", "textcat", "textcat_multilabel", "entity_linker"] %}
|
||||||
{# Other components defined by the user: we just assume they're factories #}
|
{# Other components defined by the user: we just assume they're factories #}
|
||||||
[components.{{ pipe }}]
|
[components.{{ pipe }}]
|
||||||
factory = "{{ pipe }}"
|
factory = "{{ pipe }}"
|
||||||
|
|
|
@ -463,6 +463,10 @@ class Errors:
|
||||||
"issue tracker: http://github.com/explosion/spaCy/issues")
|
"issue tracker: http://github.com/explosion/spaCy/issues")
|
||||||
|
|
||||||
# TODO: fix numbering after merging develop into master
|
# TODO: fix numbering after merging develop into master
|
||||||
|
E895 = ("The 'textcat' component received gold-standard annotations with "
|
||||||
|
"multiple labels per document. In spaCy 3 you should use the "
|
||||||
|
"'textcat_multilabel' component for this instead. "
|
||||||
|
"Example of an offending annotation: {value}")
|
||||||
E896 = ("There was an error using the static vectors. Ensure that the vectors "
|
E896 = ("There was an error using the static vectors. Ensure that the vectors "
|
||||||
"of the vocab are properly initialized, or set 'include_static_vectors' "
|
"of the vocab are properly initialized, or set 'include_static_vectors' "
|
||||||
"to False.")
|
"to False.")
|
||||||
|
|
|
@ -11,6 +11,7 @@ from .senter import SentenceRecognizer
|
||||||
from .sentencizer import Sentencizer
|
from .sentencizer import Sentencizer
|
||||||
from .tagger import Tagger
|
from .tagger import Tagger
|
||||||
from .textcat import TextCategorizer
|
from .textcat import TextCategorizer
|
||||||
|
from .textcat_multilabel import MultiLabel_TextCategorizer
|
||||||
from .tok2vec import Tok2Vec
|
from .tok2vec import Tok2Vec
|
||||||
from .functions import merge_entities, merge_noun_chunks, merge_subtokens
|
from .functions import merge_entities, merge_noun_chunks, merge_subtokens
|
||||||
|
|
||||||
|
@ -22,13 +23,14 @@ __all__ = [
|
||||||
"EntityRuler",
|
"EntityRuler",
|
||||||
"Morphologizer",
|
"Morphologizer",
|
||||||
"Lemmatizer",
|
"Lemmatizer",
|
||||||
"TrainablePipe",
|
"MultiLabel_TextCategorizer",
|
||||||
"Pipe",
|
"Pipe",
|
||||||
"SentenceRecognizer",
|
"SentenceRecognizer",
|
||||||
"Sentencizer",
|
"Sentencizer",
|
||||||
"Tagger",
|
"Tagger",
|
||||||
"TextCategorizer",
|
"TextCategorizer",
|
||||||
"Tok2Vec",
|
"Tok2Vec",
|
||||||
|
"TrainablePipe",
|
||||||
"merge_entities",
|
"merge_entities",
|
||||||
"merge_noun_chunks",
|
"merge_noun_chunks",
|
||||||
"merge_subtokens",
|
"merge_subtokens",
|
||||||
|
|
|
@ -14,7 +14,7 @@ from ..tokens import Doc
|
||||||
from ..vocab import Vocab
|
from ..vocab import Vocab
|
||||||
|
|
||||||
|
|
||||||
default_model_config = """
|
single_label_default_config = """
|
||||||
[model]
|
[model]
|
||||||
@architectures = "spacy.TextCatEnsemble.v2"
|
@architectures = "spacy.TextCatEnsemble.v2"
|
||||||
|
|
||||||
|
@ -37,24 +37,24 @@ depth = 2
|
||||||
|
|
||||||
[model.linear_model]
|
[model.linear_model]
|
||||||
@architectures = "spacy.TextCatBOW.v1"
|
@architectures = "spacy.TextCatBOW.v1"
|
||||||
exclusive_classes = false
|
exclusive_classes = true
|
||||||
ngram_size = 1
|
ngram_size = 1
|
||||||
no_output_layer = false
|
no_output_layer = false
|
||||||
"""
|
"""
|
||||||
DEFAULT_TEXTCAT_MODEL = Config().from_str(default_model_config)["model"]
|
DEFAULT_SINGLE_TEXTCAT_MODEL = Config().from_str(single_label_default_config)["model"]
|
||||||
|
|
||||||
bow_model_config = """
|
single_label_bow_config = """
|
||||||
[model]
|
[model]
|
||||||
@architectures = "spacy.TextCatBOW.v1"
|
@architectures = "spacy.TextCatBOW.v1"
|
||||||
exclusive_classes = false
|
exclusive_classes = true
|
||||||
ngram_size = 1
|
ngram_size = 1
|
||||||
no_output_layer = false
|
no_output_layer = false
|
||||||
"""
|
"""
|
||||||
|
|
||||||
cnn_model_config = """
|
single_label_cnn_config = """
|
||||||
[model]
|
[model]
|
||||||
@architectures = "spacy.TextCatCNN.v1"
|
@architectures = "spacy.TextCatCNN.v1"
|
||||||
exclusive_classes = false
|
exclusive_classes = true
|
||||||
|
|
||||||
[model.tok2vec]
|
[model.tok2vec]
|
||||||
@architectures = "spacy.HashEmbedCNN.v1"
|
@architectures = "spacy.HashEmbedCNN.v1"
|
||||||
|
@ -71,7 +71,7 @@ subword_features = true
|
||||||
@Language.factory(
|
@Language.factory(
|
||||||
"textcat",
|
"textcat",
|
||||||
assigns=["doc.cats"],
|
assigns=["doc.cats"],
|
||||||
default_config={"threshold": 0.5, "model": DEFAULT_TEXTCAT_MODEL},
|
default_config={"threshold": 0.5, "model": DEFAULT_SINGLE_TEXTCAT_MODEL},
|
||||||
default_score_weights={
|
default_score_weights={
|
||||||
"cats_score": 1.0,
|
"cats_score": 1.0,
|
||||||
"cats_score_desc": None,
|
"cats_score_desc": None,
|
||||||
|
@ -103,7 +103,7 @@ def make_textcat(
|
||||||
|
|
||||||
|
|
||||||
class TextCategorizer(TrainablePipe):
|
class TextCategorizer(TrainablePipe):
|
||||||
"""Pipeline component for text classification.
|
"""Pipeline component for single-label text classification.
|
||||||
|
|
||||||
DOCS: https://nightly.spacy.io/api/textcategorizer
|
DOCS: https://nightly.spacy.io/api/textcategorizer
|
||||||
"""
|
"""
|
||||||
|
@ -111,7 +111,7 @@ class TextCategorizer(TrainablePipe):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, vocab: Vocab, model: Model, name: str = "textcat", *, threshold: float
|
self, vocab: Vocab, model: Model, name: str = "textcat", *, threshold: float
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Initialize a text categorizer.
|
"""Initialize a text categorizer for single-label classification.
|
||||||
|
|
||||||
vocab (Vocab): The shared vocabulary.
|
vocab (Vocab): The shared vocabulary.
|
||||||
model (thinc.api.Model): The Thinc Model powering the pipeline component.
|
model (thinc.api.Model): The Thinc Model powering the pipeline component.
|
||||||
|
@ -214,6 +214,7 @@ class TextCategorizer(TrainablePipe):
|
||||||
losses = {}
|
losses = {}
|
||||||
losses.setdefault(self.name, 0.0)
|
losses.setdefault(self.name, 0.0)
|
||||||
validate_examples(examples, "TextCategorizer.update")
|
validate_examples(examples, "TextCategorizer.update")
|
||||||
|
self._validate_categories(examples)
|
||||||
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
|
if not any(len(eg.predicted) if eg.predicted else 0 for eg in examples):
|
||||||
# Handle cases where there are no tokens in any docs.
|
# Handle cases where there are no tokens in any docs.
|
||||||
return losses
|
return losses
|
||||||
|
@ -256,6 +257,7 @@ class TextCategorizer(TrainablePipe):
|
||||||
if self._rehearsal_model is None:
|
if self._rehearsal_model is None:
|
||||||
return losses
|
return losses
|
||||||
validate_examples(examples, "TextCategorizer.rehearse")
|
validate_examples(examples, "TextCategorizer.rehearse")
|
||||||
|
self._validate_categories(examples)
|
||||||
docs = [eg.predicted for eg in examples]
|
docs = [eg.predicted for eg in examples]
|
||||||
if not any(len(doc) for doc in docs):
|
if not any(len(doc) for doc in docs):
|
||||||
# Handle cases where there are no tokens in any docs.
|
# Handle cases where there are no tokens in any docs.
|
||||||
|
@ -296,6 +298,7 @@ class TextCategorizer(TrainablePipe):
|
||||||
DOCS: https://nightly.spacy.io/api/textcategorizer#get_loss
|
DOCS: https://nightly.spacy.io/api/textcategorizer#get_loss
|
||||||
"""
|
"""
|
||||||
validate_examples(examples, "TextCategorizer.get_loss")
|
validate_examples(examples, "TextCategorizer.get_loss")
|
||||||
|
self._validate_categories(examples)
|
||||||
truths, not_missing = self._examples_to_truth(examples)
|
truths, not_missing = self._examples_to_truth(examples)
|
||||||
not_missing = self.model.ops.asarray(not_missing)
|
not_missing = self.model.ops.asarray(not_missing)
|
||||||
d_scores = (scores - truths) / scores.shape[0]
|
d_scores = (scores - truths) / scores.shape[0]
|
||||||
|
@ -341,6 +344,7 @@ class TextCategorizer(TrainablePipe):
|
||||||
DOCS: https://nightly.spacy.io/api/textcategorizer#initialize
|
DOCS: https://nightly.spacy.io/api/textcategorizer#initialize
|
||||||
"""
|
"""
|
||||||
validate_get_examples(get_examples, "TextCategorizer.initialize")
|
validate_get_examples(get_examples, "TextCategorizer.initialize")
|
||||||
|
self._validate_categories(get_examples())
|
||||||
if labels is None:
|
if labels is None:
|
||||||
for example in get_examples():
|
for example in get_examples():
|
||||||
for cat in example.y.cats:
|
for cat in example.y.cats:
|
||||||
|
@ -373,12 +377,20 @@ class TextCategorizer(TrainablePipe):
|
||||||
DOCS: https://nightly.spacy.io/api/textcategorizer#score
|
DOCS: https://nightly.spacy.io/api/textcategorizer#score
|
||||||
"""
|
"""
|
||||||
validate_examples(examples, "TextCategorizer.score")
|
validate_examples(examples, "TextCategorizer.score")
|
||||||
|
self._validate_categories(examples)
|
||||||
return Scorer.score_cats(
|
return Scorer.score_cats(
|
||||||
examples,
|
examples,
|
||||||
"cats",
|
"cats",
|
||||||
labels=self.labels,
|
labels=self.labels,
|
||||||
multi_label=self.model.attrs["multi_label"],
|
multi_label=False,
|
||||||
positive_label=self.cfg["positive_label"],
|
positive_label=self.cfg["positive_label"],
|
||||||
threshold=self.cfg["threshold"],
|
threshold=self.cfg["threshold"],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_categories(self, examples: List[Example]):
|
||||||
|
"""Check whether the provided examples all have single-label cats annotations."""
|
||||||
|
for ex in examples:
|
||||||
|
if list(ex.reference.cats.values()).count(1.0) > 1:
|
||||||
|
raise ValueError(Errors.E895.format(value=ex.reference.cats))
|
||||||
|
|
191
spacy/pipeline/textcat_multilabel.py
Normal file
191
spacy/pipeline/textcat_multilabel.py
Normal file
|
@ -0,0 +1,191 @@
|
||||||
|
from itertools import islice
|
||||||
|
from typing import Iterable, Optional, Dict, List, Callable, Any
|
||||||
|
|
||||||
|
from thinc.api import Model, Config
|
||||||
|
from thinc.types import Floats2d
|
||||||
|
|
||||||
|
from ..language import Language
|
||||||
|
from ..training import Example, validate_examples, validate_get_examples
|
||||||
|
from ..errors import Errors
|
||||||
|
from ..scorer import Scorer
|
||||||
|
from ..tokens import Doc
|
||||||
|
from ..vocab import Vocab
|
||||||
|
from .textcat import TextCategorizer
|
||||||
|
|
||||||
|
|
||||||
|
multi_label_default_config = """
|
||||||
|
[model]
|
||||||
|
@architectures = "spacy.TextCatEnsemble.v2"
|
||||||
|
|
||||||
|
[model.tok2vec]
|
||||||
|
@architectures = "spacy.Tok2Vec.v1"
|
||||||
|
|
||||||
|
[model.tok2vec.embed]
|
||||||
|
@architectures = "spacy.MultiHashEmbed.v1"
|
||||||
|
width = 64
|
||||||
|
rows = [2000, 2000, 1000, 1000, 1000, 1000]
|
||||||
|
attrs = ["ORTH", "LOWER", "PREFIX", "SUFFIX", "SHAPE", "ID"]
|
||||||
|
include_static_vectors = false
|
||||||
|
|
||||||
|
[model.tok2vec.encode]
|
||||||
|
@architectures = "spacy.MaxoutWindowEncoder.v1"
|
||||||
|
width = ${model.tok2vec.embed.width}
|
||||||
|
window_size = 1
|
||||||
|
maxout_pieces = 3
|
||||||
|
depth = 2
|
||||||
|
|
||||||
|
[model.linear_model]
|
||||||
|
@architectures = "spacy.TextCatBOW.v1"
|
||||||
|
exclusive_classes = false
|
||||||
|
ngram_size = 1
|
||||||
|
no_output_layer = false
|
||||||
|
"""
|
||||||
|
DEFAULT_MULTI_TEXTCAT_MODEL = Config().from_str(multi_label_default_config)["model"]
|
||||||
|
|
||||||
|
multi_label_bow_config = """
|
||||||
|
[model]
|
||||||
|
@architectures = "spacy.TextCatBOW.v1"
|
||||||
|
exclusive_classes = false
|
||||||
|
ngram_size = 1
|
||||||
|
no_output_layer = false
|
||||||
|
"""
|
||||||
|
|
||||||
|
multi_label_cnn_config = """
|
||||||
|
[model]
|
||||||
|
@architectures = "spacy.TextCatCNN.v1"
|
||||||
|
exclusive_classes = false
|
||||||
|
|
||||||
|
[model.tok2vec]
|
||||||
|
@architectures = "spacy.HashEmbedCNN.v1"
|
||||||
|
pretrained_vectors = null
|
||||||
|
width = 96
|
||||||
|
depth = 4
|
||||||
|
embed_size = 2000
|
||||||
|
window_size = 1
|
||||||
|
maxout_pieces = 3
|
||||||
|
subword_features = true
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@Language.factory(
|
||||||
|
"textcat_multilabel",
|
||||||
|
assigns=["doc.cats"],
|
||||||
|
default_config={"threshold": 0.5, "model": DEFAULT_MULTI_TEXTCAT_MODEL},
|
||||||
|
default_score_weights={
|
||||||
|
"cats_score": 1.0,
|
||||||
|
"cats_score_desc": None,
|
||||||
|
"cats_micro_p": None,
|
||||||
|
"cats_micro_r": None,
|
||||||
|
"cats_micro_f": None,
|
||||||
|
"cats_macro_p": None,
|
||||||
|
"cats_macro_r": None,
|
||||||
|
"cats_macro_f": None,
|
||||||
|
"cats_macro_auc": None,
|
||||||
|
"cats_f_per_type": None,
|
||||||
|
"cats_macro_auc_per_type": None,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
def make_multilabel_textcat(
|
||||||
|
nlp: Language, name: str, model: Model[List[Doc], List[Floats2d]], threshold: float
|
||||||
|
) -> "TextCategorizer":
|
||||||
|
"""Create a TextCategorizer compoment. The text categorizer predicts categories
|
||||||
|
over a whole document. It can learn one or more labels, and the labels can
|
||||||
|
be mutually exclusive (i.e. one true label per doc) or non-mutually exclusive
|
||||||
|
(i.e. zero or more labels may be true per doc). The multi-label setting is
|
||||||
|
controlled by the model instance that's provided.
|
||||||
|
|
||||||
|
model (Model[List[Doc], List[Floats2d]]): A model instance that predicts
|
||||||
|
scores for each category.
|
||||||
|
threshold (float): Cutoff to consider a prediction "positive".
|
||||||
|
"""
|
||||||
|
return MultiLabel_TextCategorizer(nlp.vocab, model, name, threshold=threshold)
|
||||||
|
|
||||||
|
|
||||||
|
class MultiLabel_TextCategorizer(TextCategorizer):
|
||||||
|
"""Pipeline component for multi-label text classification.
|
||||||
|
|
||||||
|
DOCS: https://nightly.spacy.io/api/multilabel_textcategorizer
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
vocab: Vocab,
|
||||||
|
model: Model,
|
||||||
|
name: str = "textcat_multilabel",
|
||||||
|
*,
|
||||||
|
threshold: float,
|
||||||
|
) -> None:
|
||||||
|
"""Initialize a text categorizer for multi-label classification.
|
||||||
|
|
||||||
|
vocab (Vocab): The shared vocabulary.
|
||||||
|
model (thinc.api.Model): The Thinc Model powering the pipeline component.
|
||||||
|
name (str): The component instance name, used to add entries to the
|
||||||
|
losses during training.
|
||||||
|
threshold (float): Cutoff to consider a prediction "positive".
|
||||||
|
|
||||||
|
DOCS: https://nightly.spacy.io/api/multilabel_textcategorizer#init
|
||||||
|
"""
|
||||||
|
self.vocab = vocab
|
||||||
|
self.model = model
|
||||||
|
self.name = name
|
||||||
|
self._rehearsal_model = None
|
||||||
|
cfg = {"labels": [], "threshold": threshold}
|
||||||
|
self.cfg = dict(cfg)
|
||||||
|
|
||||||
|
def initialize(
|
||||||
|
self,
|
||||||
|
get_examples: Callable[[], Iterable[Example]],
|
||||||
|
*,
|
||||||
|
nlp: Optional[Language] = None,
|
||||||
|
labels: Optional[Dict] = None,
|
||||||
|
):
|
||||||
|
"""Initialize the pipe for training, using a representative set
|
||||||
|
of data examples.
|
||||||
|
|
||||||
|
get_examples (Callable[[], Iterable[Example]]): Function that
|
||||||
|
returns a representative sample of gold-standard Example objects.
|
||||||
|
nlp (Language): The current nlp object the component is part of.
|
||||||
|
labels: The labels to add to the component, typically generated by the
|
||||||
|
`init labels` command. If no labels are provided, the get_examples
|
||||||
|
callback is used to extract the labels from the data.
|
||||||
|
|
||||||
|
DOCS: https://nightly.spacy.io/api/multilabel_textcategorizer#initialize
|
||||||
|
"""
|
||||||
|
validate_get_examples(get_examples, "MultiLabel_TextCategorizer.initialize")
|
||||||
|
if labels is None:
|
||||||
|
for example in get_examples():
|
||||||
|
for cat in example.y.cats:
|
||||||
|
self.add_label(cat)
|
||||||
|
else:
|
||||||
|
for label in labels:
|
||||||
|
self.add_label(label)
|
||||||
|
subbatch = list(islice(get_examples(), 10))
|
||||||
|
doc_sample = [eg.reference for eg in subbatch]
|
||||||
|
label_sample, _ = self._examples_to_truth(subbatch)
|
||||||
|
self._require_labels()
|
||||||
|
assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
|
||||||
|
assert len(label_sample) > 0, Errors.E923.format(name=self.name)
|
||||||
|
self.model.initialize(X=doc_sample, Y=label_sample)
|
||||||
|
|
||||||
|
def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
|
||||||
|
"""Score a batch of examples.
|
||||||
|
|
||||||
|
examples (Iterable[Example]): The examples to score.
|
||||||
|
RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_cats.
|
||||||
|
|
||||||
|
DOCS: https://nightly.spacy.io/api/multilabel_textcategorizer#score
|
||||||
|
"""
|
||||||
|
validate_examples(examples, "MultiLabel_TextCategorizer.score")
|
||||||
|
return Scorer.score_cats(
|
||||||
|
examples,
|
||||||
|
"cats",
|
||||||
|
labels=self.labels,
|
||||||
|
multi_label=True,
|
||||||
|
threshold=self.cfg["threshold"],
|
||||||
|
**kwargs,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _validate_categories(self, examples: List[Example]):
|
||||||
|
"""This component allows any type of single- or multi-label annotations.
|
||||||
|
This method overwrites the more strict one from 'textcat'. """
|
||||||
|
pass
|
|
@ -460,7 +460,7 @@ class Scorer:
|
||||||
gold_label, gold_score = max(gold_cats, key=lambda it: it[1])
|
gold_label, gold_score = max(gold_cats, key=lambda it: it[1])
|
||||||
if gold_score is not None and gold_score > 0:
|
if gold_score is not None and gold_score > 0:
|
||||||
f_per_type[gold_label].fn += 1
|
f_per_type[gold_label].fn += 1
|
||||||
else:
|
elif pred_cats:
|
||||||
pred_label, pred_score = max(pred_cats, key=lambda it: it[1])
|
pred_label, pred_score = max(pred_cats, key=lambda it: it[1])
|
||||||
if pred_score >= threshold:
|
if pred_score >= threshold:
|
||||||
f_per_type[pred_label].fp += 1
|
f_per_type[pred_label].fp += 1
|
||||||
|
|
|
@ -15,15 +15,31 @@ from spacy.training import Example
|
||||||
from ..util import make_tempdir
|
from ..util import make_tempdir
|
||||||
|
|
||||||
|
|
||||||
TRAIN_DATA = [
|
TRAIN_DATA_SINGLE_LABEL = [
|
||||||
("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}),
|
("I'm so happy.", {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}),
|
||||||
("I'm so angry", {"cats": {"POSITIVE": 0.0, "NEGATIVE": 1.0}}),
|
("I'm so angry", {"cats": {"POSITIVE": 0.0, "NEGATIVE": 1.0}}),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
TRAIN_DATA_MULTI_LABEL = [
|
||||||
|
("I'm angry and confused", {"cats": {"ANGRY": 1.0, "CONFUSED": 1.0, "HAPPY": 0.0}}),
|
||||||
|
("I'm confused but happy", {"cats": {"ANGRY": 0.0, "CONFUSED": 1.0, "HAPPY": 1.0}}),
|
||||||
|
]
|
||||||
|
|
||||||
def make_get_examples(nlp):
|
|
||||||
|
def make_get_examples_single_label(nlp):
|
||||||
train_examples = []
|
train_examples = []
|
||||||
for t in TRAIN_DATA:
|
for t in TRAIN_DATA_SINGLE_LABEL:
|
||||||
|
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
|
||||||
|
|
||||||
|
def get_examples():
|
||||||
|
return train_examples
|
||||||
|
|
||||||
|
return get_examples
|
||||||
|
|
||||||
|
|
||||||
|
def make_get_examples_multi_label(nlp):
|
||||||
|
train_examples = []
|
||||||
|
for t in TRAIN_DATA_MULTI_LABEL:
|
||||||
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
|
train_examples.append(Example.from_dict(nlp.make_doc(t[0]), t[1]))
|
||||||
|
|
||||||
def get_examples():
|
def get_examples():
|
||||||
|
@ -85,49 +101,75 @@ def test_textcat_learns_multilabel():
|
||||||
assert score > 0.5
|
assert score > 0.5
|
||||||
|
|
||||||
|
|
||||||
def test_label_types():
|
@pytest.mark.parametrize("name", ["textcat", "textcat_multilabel"])
|
||||||
|
def test_label_types(name):
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
textcat = nlp.add_pipe("textcat")
|
textcat = nlp.add_pipe(name)
|
||||||
textcat.add_label("answer")
|
textcat.add_label("answer")
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
textcat.add_label(9)
|
textcat.add_label(9)
|
||||||
|
|
||||||
|
|
||||||
def test_no_label():
|
@pytest.mark.parametrize("name", ["textcat", "textcat_multilabel"])
|
||||||
|
def test_no_label(name):
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
nlp.add_pipe("textcat")
|
nlp.add_pipe(name)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
nlp.initialize()
|
nlp.initialize()
|
||||||
|
|
||||||
|
|
||||||
def test_implicit_label():
|
@pytest.mark.parametrize(
|
||||||
|
"name,get_examples",
|
||||||
|
[
|
||||||
|
("textcat", make_get_examples_single_label),
|
||||||
|
("textcat_multilabel", make_get_examples_multi_label),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_implicit_label(name, get_examples):
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
nlp.add_pipe("textcat")
|
nlp.add_pipe(name)
|
||||||
nlp.initialize(get_examples=make_get_examples(nlp))
|
nlp.initialize(get_examples=get_examples(nlp))
|
||||||
|
|
||||||
|
|
||||||
def test_no_resize():
|
@pytest.mark.parametrize("name", ["textcat", "textcat_multilabel"])
|
||||||
|
def test_no_resize(name):
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
textcat = nlp.add_pipe("textcat")
|
textcat = nlp.add_pipe(name)
|
||||||
textcat.add_label("POSITIVE")
|
textcat.add_label("POSITIVE")
|
||||||
textcat.add_label("NEGATIVE")
|
textcat.add_label("NEGATIVE")
|
||||||
nlp.initialize()
|
nlp.initialize()
|
||||||
assert textcat.model.get_dim("nO") == 2
|
assert textcat.model.get_dim("nO") >= 2
|
||||||
# this throws an error because the textcat can't be resized after initialization
|
# this throws an error because the textcat can't be resized after initialization
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
textcat.add_label("NEUTRAL")
|
textcat.add_label("NEUTRAL")
|
||||||
|
|
||||||
|
|
||||||
def test_initialize_examples():
|
def test_error_with_multi_labels():
|
||||||
nlp = Language()
|
nlp = Language()
|
||||||
textcat = nlp.add_pipe("textcat")
|
textcat = nlp.add_pipe("textcat")
|
||||||
for text, annotations in TRAIN_DATA:
|
train_examples = []
|
||||||
|
for text, annotations in TRAIN_DATA_MULTI_LABEL:
|
||||||
|
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"name,get_examples, train_data",
|
||||||
|
[
|
||||||
|
("textcat", make_get_examples_single_label, TRAIN_DATA_SINGLE_LABEL),
|
||||||
|
("textcat_multilabel", make_get_examples_multi_label, TRAIN_DATA_MULTI_LABEL),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_initialize_examples(name, get_examples, train_data):
|
||||||
|
nlp = Language()
|
||||||
|
textcat = nlp.add_pipe(name)
|
||||||
|
for text, annotations in train_data:
|
||||||
for label, value in annotations.get("cats").items():
|
for label, value in annotations.get("cats").items():
|
||||||
textcat.add_label(label)
|
textcat.add_label(label)
|
||||||
# you shouldn't really call this more than once, but for testing it should be fine
|
# you shouldn't really call this more than once, but for testing it should be fine
|
||||||
nlp.initialize()
|
nlp.initialize()
|
||||||
get_examples = make_get_examples(nlp)
|
nlp.initialize(get_examples=get_examples(nlp))
|
||||||
nlp.initialize(get_examples=get_examples)
|
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
nlp.initialize(get_examples=lambda: None)
|
nlp.initialize(get_examples=lambda: None)
|
||||||
with pytest.raises(TypeError):
|
with pytest.raises(TypeError):
|
||||||
|
@ -138,12 +180,10 @@ def test_overfitting_IO():
|
||||||
# Simple test to try and quickly overfit the single-label textcat component - ensuring the ML models work correctly
|
# Simple test to try and quickly overfit the single-label textcat component - ensuring the ML models work correctly
|
||||||
fix_random_seed(0)
|
fix_random_seed(0)
|
||||||
nlp = English()
|
nlp = English()
|
||||||
nlp.config["initialize"]["components"]["textcat"] = {"positive_label": "POSITIVE"}
|
textcat = nlp.add_pipe("textcat")
|
||||||
# Set exclusive labels
|
|
||||||
config = {"model": {"linear_model": {"exclusive_classes": True}}}
|
|
||||||
textcat = nlp.add_pipe("textcat", config=config)
|
|
||||||
train_examples = []
|
train_examples = []
|
||||||
for text, annotations in TRAIN_DATA:
|
for text, annotations in TRAIN_DATA_SINGLE_LABEL:
|
||||||
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
||||||
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
||||||
assert textcat.model.get_dim("nO") == 2
|
assert textcat.model.get_dim("nO") == 2
|
||||||
|
@ -172,6 +212,8 @@ def test_overfitting_IO():
|
||||||
# Test scoring
|
# Test scoring
|
||||||
scores = nlp.evaluate(train_examples)
|
scores = nlp.evaluate(train_examples)
|
||||||
assert scores["cats_micro_f"] == 1.0
|
assert scores["cats_micro_f"] == 1.0
|
||||||
|
assert scores["cats_macro_f"] == 1.0
|
||||||
|
assert scores["cats_macro_auc"] == 1.0
|
||||||
assert scores["cats_score"] == 1.0
|
assert scores["cats_score"] == 1.0
|
||||||
assert "cats_score_desc" in scores
|
assert "cats_score_desc" in scores
|
||||||
|
|
||||||
|
@ -192,7 +234,7 @@ def test_overfitting_IO_multi():
|
||||||
config = {"model": {"linear_model": {"exclusive_classes": False}}}
|
config = {"model": {"linear_model": {"exclusive_classes": False}}}
|
||||||
textcat = nlp.add_pipe("textcat", config=config)
|
textcat = nlp.add_pipe("textcat", config=config)
|
||||||
train_examples = []
|
train_examples = []
|
||||||
for text, annotations in TRAIN_DATA:
|
for text, annotations in TRAIN_DATA_MULTI_LABEL:
|
||||||
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
||||||
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
||||||
assert textcat.model.get_dim("nO") == 2
|
assert textcat.model.get_dim("nO") == 2
|
||||||
|
@ -231,27 +273,75 @@ def test_overfitting_IO_multi():
|
||||||
assert_equal(batch_cats_1, no_batch_cats)
|
assert_equal(batch_cats_1, no_batch_cats)
|
||||||
|
|
||||||
|
|
||||||
|
def test_overfitting_IO_multi():
|
||||||
|
# Simple test to try and quickly overfit the multi-label textcat component - ensuring the ML models work correctly
|
||||||
|
fix_random_seed(0)
|
||||||
|
nlp = English()
|
||||||
|
textcat = nlp.add_pipe("textcat_multilabel")
|
||||||
|
|
||||||
|
train_examples = []
|
||||||
|
for text, annotations in TRAIN_DATA_MULTI_LABEL:
|
||||||
|
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
||||||
|
optimizer = nlp.initialize(get_examples=lambda: train_examples)
|
||||||
|
assert textcat.model.get_dim("nO") == 3
|
||||||
|
|
||||||
|
for i in range(100):
|
||||||
|
losses = {}
|
||||||
|
nlp.update(train_examples, sgd=optimizer, losses=losses)
|
||||||
|
assert losses["textcat_multilabel"] < 0.01
|
||||||
|
|
||||||
|
# test the trained model
|
||||||
|
test_text = "I am confused but happy."
|
||||||
|
doc = nlp(test_text)
|
||||||
|
cats = doc.cats
|
||||||
|
assert cats["HAPPY"] > 0.9
|
||||||
|
assert cats["CONFUSED"] > 0.9
|
||||||
|
|
||||||
|
# Also test the results are still the same after IO
|
||||||
|
with make_tempdir() as tmp_dir:
|
||||||
|
nlp.to_disk(tmp_dir)
|
||||||
|
nlp2 = util.load_model_from_path(tmp_dir)
|
||||||
|
doc2 = nlp2(test_text)
|
||||||
|
cats2 = doc2.cats
|
||||||
|
assert cats2["HAPPY"] > 0.9
|
||||||
|
assert cats2["CONFUSED"] > 0.9
|
||||||
|
|
||||||
|
# Test scoring
|
||||||
|
scores = nlp.evaluate(train_examples)
|
||||||
|
assert scores["cats_micro_f"] == 1.0
|
||||||
|
assert scores["cats_macro_f"] == 1.0
|
||||||
|
assert "cats_score_desc" in scores
|
||||||
|
|
||||||
|
# Make sure that running pipe twice, or comparing to call, always amounts to the same predictions
|
||||||
|
texts = ["Just a sentence.", "I like green eggs.", "I am happy.", "I eat ham."]
|
||||||
|
batch_deps_1 = [doc.cats for doc in nlp.pipe(texts)]
|
||||||
|
batch_deps_2 = [doc.cats for doc in nlp.pipe(texts)]
|
||||||
|
no_batch_deps = [doc.cats for doc in [nlp(text) for text in texts]]
|
||||||
|
assert_equal(batch_deps_1, batch_deps_2)
|
||||||
|
assert_equal(batch_deps_1, no_batch_deps)
|
||||||
|
|
||||||
|
|
||||||
# fmt: off
|
# fmt: off
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"textcat_config",
|
"name,train_data,textcat_config",
|
||||||
[
|
[
|
||||||
{"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False},
|
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False}),
|
||||||
{"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 4, "no_output_layer": False},
|
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 4, "no_output_layer": False}),
|
||||||
{"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 3, "no_output_layer": True},
|
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 3, "no_output_layer": True}),
|
||||||
{"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 2, "no_output_layer": True},
|
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 2, "no_output_layer": True}),
|
||||||
{"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False}},
|
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": False, "ngram_size": 1, "no_output_layer": False}}),
|
||||||
{"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 5, "no_output_layer": False}},
|
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatEnsemble.v2", "tok2vec": DEFAULT_TOK2VEC_MODEL, "linear_model": {"@architectures": "spacy.TextCatBOW.v1", "exclusive_classes": True, "ngram_size": 5, "no_output_layer": False}}),
|
||||||
{"@architectures": "spacy.TextCatCNN.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True},
|
("textcat", TRAIN_DATA_SINGLE_LABEL, {"@architectures": "spacy.TextCatCNN.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": True}),
|
||||||
{"@architectures": "spacy.TextCatCNN.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False},
|
("textcat_multilabel", TRAIN_DATA_MULTI_LABEL, {"@architectures": "spacy.TextCatCNN.v1", "tok2vec": DEFAULT_TOK2VEC_MODEL, "exclusive_classes": False}),
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
# fmt: on
|
# fmt: on
|
||||||
def test_textcat_configs(textcat_config):
|
def test_textcat_configs(name, train_data, textcat_config):
|
||||||
pipe_config = {"model": textcat_config}
|
pipe_config = {"model": textcat_config}
|
||||||
nlp = English()
|
nlp = English()
|
||||||
textcat = nlp.add_pipe("textcat", config=pipe_config)
|
textcat = nlp.add_pipe(name, config=pipe_config)
|
||||||
train_examples = []
|
train_examples = []
|
||||||
for text, annotations in TRAIN_DATA:
|
for text, annotations in train_data:
|
||||||
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
|
||||||
for label, value in annotations.get("cats").items():
|
for label, value in annotations.get("cats").items():
|
||||||
textcat.add_label(label)
|
textcat.add_label(label)
|
||||||
|
@ -264,15 +354,24 @@ def test_textcat_configs(textcat_config):
|
||||||
def test_positive_class():
|
def test_positive_class():
|
||||||
nlp = English()
|
nlp = English()
|
||||||
textcat = nlp.add_pipe("textcat")
|
textcat = nlp.add_pipe("textcat")
|
||||||
get_examples = make_get_examples(nlp)
|
get_examples = make_get_examples_single_label(nlp)
|
||||||
textcat.initialize(get_examples, labels=["POS", "NEG"], positive_label="POS")
|
textcat.initialize(get_examples, labels=["POS", "NEG"], positive_label="POS")
|
||||||
assert textcat.labels == ("POS", "NEG")
|
assert textcat.labels == ("POS", "NEG")
|
||||||
|
assert textcat.cfg["positive_label"] == "POS"
|
||||||
|
|
||||||
|
textcat_multilabel = nlp.add_pipe("textcat_multilabel")
|
||||||
|
get_examples = make_get_examples_multi_label(nlp)
|
||||||
|
with pytest.raises(TypeError):
|
||||||
|
textcat_multilabel.initialize(get_examples, labels=["POS", "NEG"], positive_label="POS")
|
||||||
|
textcat_multilabel.initialize(get_examples, labels=["FICTION", "DRAMA"])
|
||||||
|
assert textcat_multilabel.labels == ("FICTION", "DRAMA")
|
||||||
|
assert "positive_label" not in textcat_multilabel.cfg
|
||||||
|
|
||||||
|
|
||||||
def test_positive_class_not_present():
|
def test_positive_class_not_present():
|
||||||
nlp = English()
|
nlp = English()
|
||||||
textcat = nlp.add_pipe("textcat")
|
textcat = nlp.add_pipe("textcat")
|
||||||
get_examples = make_get_examples(nlp)
|
get_examples = make_get_examples_single_label(nlp)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
textcat.initialize(get_examples, labels=["SOME", "THING"], positive_label="POS")
|
textcat.initialize(get_examples, labels=["SOME", "THING"], positive_label="POS")
|
||||||
|
|
||||||
|
@ -280,11 +379,9 @@ def test_positive_class_not_present():
|
||||||
def test_positive_class_not_binary():
|
def test_positive_class_not_binary():
|
||||||
nlp = English()
|
nlp = English()
|
||||||
textcat = nlp.add_pipe("textcat")
|
textcat = nlp.add_pipe("textcat")
|
||||||
get_examples = make_get_examples(nlp)
|
get_examples = make_get_examples_multi_label(nlp)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
textcat.initialize(
|
textcat.initialize(get_examples, labels=["SOME", "THING", "POS"], positive_label="POS")
|
||||||
get_examples, labels=["SOME", "THING", "POS"], positive_label="POS"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def test_textcat_evaluation():
|
def test_textcat_evaluation():
|
||||||
|
|
|
@ -2,8 +2,11 @@ import pytest
|
||||||
from thinc.api import Config, fix_random_seed
|
from thinc.api import Config, fix_random_seed
|
||||||
|
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from spacy.pipeline.textcat import default_model_config, bow_model_config
|
from spacy.pipeline.textcat import single_label_default_config, single_label_bow_config
|
||||||
from spacy.pipeline.textcat import cnn_model_config
|
from spacy.pipeline.textcat import single_label_cnn_config
|
||||||
|
from spacy.pipeline.textcat_multilabel import multi_label_default_config
|
||||||
|
from spacy.pipeline.textcat_multilabel import multi_label_bow_config
|
||||||
|
from spacy.pipeline.textcat_multilabel import multi_label_cnn_config
|
||||||
from spacy.tokens import Span
|
from spacy.tokens import Span
|
||||||
from spacy import displacy
|
from spacy import displacy
|
||||||
from spacy.pipeline import merge_entities
|
from spacy.pipeline import merge_entities
|
||||||
|
@ -11,7 +14,15 @@ from spacy.training import Example
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"textcat_config", [default_model_config, bow_model_config, cnn_model_config]
|
"textcat_config",
|
||||||
|
[
|
||||||
|
single_label_default_config,
|
||||||
|
single_label_bow_config,
|
||||||
|
single_label_cnn_config,
|
||||||
|
multi_label_default_config,
|
||||||
|
multi_label_bow_config,
|
||||||
|
multi_label_cnn_config,
|
||||||
|
],
|
||||||
)
|
)
|
||||||
def test_issue5551(textcat_config):
|
def test_issue5551(textcat_config):
|
||||||
"""Test that after fixing the random seed, the results of the pipeline are truly identical"""
|
"""Test that after fixing the random seed, the results of the pipeline are truly identical"""
|
||||||
|
|
|
@ -4,7 +4,7 @@ from spacy.pipeline import Tagger, DependencyParser, EntityRecognizer
|
||||||
from spacy.pipeline import TextCategorizer, SentenceRecognizer, TrainablePipe
|
from spacy.pipeline import TextCategorizer, SentenceRecognizer, TrainablePipe
|
||||||
from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL
|
from spacy.pipeline.dep_parser import DEFAULT_PARSER_MODEL
|
||||||
from spacy.pipeline.tagger import DEFAULT_TAGGER_MODEL
|
from spacy.pipeline.tagger import DEFAULT_TAGGER_MODEL
|
||||||
from spacy.pipeline.textcat import DEFAULT_TEXTCAT_MODEL
|
from spacy.pipeline.textcat import DEFAULT_SINGLE_TEXTCAT_MODEL
|
||||||
from spacy.pipeline.senter import DEFAULT_SENTER_MODEL
|
from spacy.pipeline.senter import DEFAULT_SENTER_MODEL
|
||||||
from spacy.lang.en import English
|
from spacy.lang.en import English
|
||||||
from thinc.api import Linear
|
from thinc.api import Linear
|
||||||
|
@ -188,7 +188,7 @@ def test_serialize_tagger_strings(en_vocab, de_vocab, taggers):
|
||||||
|
|
||||||
def test_serialize_textcat_empty(en_vocab):
|
def test_serialize_textcat_empty(en_vocab):
|
||||||
# See issue #1105
|
# See issue #1105
|
||||||
cfg = {"model": DEFAULT_TEXTCAT_MODEL}
|
cfg = {"model": DEFAULT_SINGLE_TEXTCAT_MODEL}
|
||||||
model = registry.resolve(cfg, validate=True)["model"]
|
model = registry.resolve(cfg, validate=True)["model"]
|
||||||
textcat = TextCategorizer(en_vocab, model, threshold=0.5)
|
textcat = TextCategorizer(en_vocab, model, threshold=0.5)
|
||||||
textcat.to_bytes(exclude=["vocab"])
|
textcat.to_bytes(exclude=["vocab"])
|
||||||
|
|
|
@ -51,7 +51,7 @@ def test_readers():
|
||||||
for example in train_corpus(nlp):
|
for example in train_corpus(nlp):
|
||||||
nlp.update([example], sgd=optimizer)
|
nlp.update([example], sgd=optimizer)
|
||||||
scores = nlp.evaluate(list(dev_corpus(nlp)))
|
scores = nlp.evaluate(list(dev_corpus(nlp)))
|
||||||
assert scores["cats_score"] == 0.0
|
assert scores["cats_macro_auc"] == 0.0
|
||||||
# ensure the pipeline runs
|
# ensure the pipeline runs
|
||||||
doc = nlp("Quick test")
|
doc = nlp("Quick test")
|
||||||
assert doc.cats
|
assert doc.cats
|
||||||
|
|
|
@ -94,7 +94,7 @@ Defines the `nlp` object, its tokenizer and
|
||||||
>
|
>
|
||||||
> [components.textcat.model]
|
> [components.textcat.model]
|
||||||
> @architectures = "spacy.TextCatBOW.v1"
|
> @architectures = "spacy.TextCatBOW.v1"
|
||||||
> exclusive_classes = false
|
> exclusive_classes = true
|
||||||
> ngram_size = 1
|
> ngram_size = 1
|
||||||
> no_output_layer = false
|
> no_output_layer = false
|
||||||
> ```
|
> ```
|
||||||
|
|
454
website/docs/api/multilabel_textcategorizer.md
Normal file
454
website/docs/api/multilabel_textcategorizer.md
Normal file
|
@ -0,0 +1,454 @@
|
||||||
|
---
|
||||||
|
title: Multi-label TextCategorizer
|
||||||
|
tag: class
|
||||||
|
source: spacy/pipeline/textcat_multilabel.py
|
||||||
|
new: 3
|
||||||
|
teaser: 'Pipeline component for multi-label text classification'
|
||||||
|
api_base_class: /api/pipe
|
||||||
|
api_string_name: textcat_multilabel
|
||||||
|
api_trainable: true
|
||||||
|
---
|
||||||
|
|
||||||
|
The text categorizer predicts **categories over a whole document**. It
|
||||||
|
learns non-mutually exclusive labels, which means that zero or more labels
|
||||||
|
may be true per document.
|
||||||
|
|
||||||
|
## Config and implementation {#config}
|
||||||
|
|
||||||
|
The default config is defined by the pipeline component factory and describes
|
||||||
|
how the component should be configured. You can override its settings via the
|
||||||
|
`config` argument on [`nlp.add_pipe`](/api/language#add_pipe) or in your
|
||||||
|
[`config.cfg` for training](/usage/training#config). See the
|
||||||
|
[model architectures](/api/architectures) documentation for details on the
|
||||||
|
architectures and their arguments and hyperparameters.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> from spacy.pipeline.textcat_multilabel import DEFAULT_MULTI_TEXTCAT_MODEL
|
||||||
|
> config = {
|
||||||
|
> "threshold": 0.5,
|
||||||
|
> "model": DEFAULT_MULTI_TEXTCAT_MODEL,
|
||||||
|
> }
|
||||||
|
> nlp.add_pipe("textcat_multilabel", config=config)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Setting | Description |
|
||||||
|
| ----------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ |
|
||||||
|
| `model` | A model instance that predicts scores for each category. Defaults to [TextCatEnsemble](/api/architectures#TextCatEnsemble). ~~Model[List[Doc], List[Floats2d]]~~ |
|
||||||
|
|
||||||
|
```python
|
||||||
|
%%GITHUB_SPACY/spacy/pipeline/textcat_multilabel.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## MultiLabel_TextCategorizer.\_\_init\_\_ {#init tag="method"}
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> # Construction via add_pipe with default model
|
||||||
|
> textcat = nlp.add_pipe("textcat_multilabel")
|
||||||
|
>
|
||||||
|
> # Construction via add_pipe with custom model
|
||||||
|
> config = {"model": {"@architectures": "my_textcat"}}
|
||||||
|
> parser = nlp.add_pipe("textcat_multilabel", config=config)
|
||||||
|
>
|
||||||
|
> # Construction from class
|
||||||
|
> from spacy.pipeline import MultiLabel_TextCategorizer
|
||||||
|
> textcat = MultiLabel_TextCategorizer(nlp.vocab, model, threshold=0.5)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
Create a new pipeline instance. In your application, you would normally use a
|
||||||
|
shortcut for this and instantiate the component using its string name and
|
||||||
|
[`nlp.add_pipe`](/api/language#create_pipe).
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| -------------- | -------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| `vocab` | The shared vocabulary. ~~Vocab~~ |
|
||||||
|
| `model` | The Thinc [`Model`](https://thinc.ai/docs/api-model) powering the pipeline component. ~~Model[List[Doc], List[Floats2d]]~~ |
|
||||||
|
| `name` | String name of the component instance. Used to add entries to the `losses` during training. ~~str~~ |
|
||||||
|
| _keyword-only_ | |
|
||||||
|
| `threshold` | Cutoff to consider a prediction "positive", relevant when printing accuracy results. ~~float~~ |
|
||||||
|
|
||||||
|
## MultiLabel_TextCategorizer.\_\_call\_\_ {#call tag="method"}
|
||||||
|
|
||||||
|
Apply the pipe to one document. The document is modified in place, and returned.
|
||||||
|
This usually happens under the hood when the `nlp` object is called on a text
|
||||||
|
and all pipeline components are applied to the `Doc` in order. Both
|
||||||
|
[`__call__`](/api/multilabel_textcategorizer#call) and [`pipe`](/api/multilabel_textcategorizer#pipe)
|
||||||
|
delegate to the [`predict`](/api/multilabel_textcategorizer#predict) and
|
||||||
|
[`set_annotations`](/api/multilabel_textcategorizer#set_annotations) methods.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> doc = nlp("This is a sentence.")
|
||||||
|
> textcat = nlp.add_pipe("textcat_multilabel")
|
||||||
|
> # This usually happens under the hood
|
||||||
|
> processed = textcat(doc)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ----------- | -------------------------------- |
|
||||||
|
| `doc` | The document to process. ~~Doc~~ |
|
||||||
|
| **RETURNS** | The processed document. ~~Doc~~ |
|
||||||
|
|
||||||
|
## MultiLabel_TextCategorizer.pipe {#pipe tag="method"}
|
||||||
|
|
||||||
|
Apply the pipe to a stream of documents. This usually happens under the hood
|
||||||
|
when the `nlp` object is called on a text and all pipeline components are
|
||||||
|
applied to the `Doc` in order. Both [`__call__`](/api/multilabel_textcategorizer#call) and
|
||||||
|
[`pipe`](/api/multilabel_textcategorizer#pipe) delegate to the
|
||||||
|
[`predict`](/api/multilabel_textcategorizer#predict) and
|
||||||
|
[`set_annotations`](/api/multilabel_textcategorizer#set_annotations) methods.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> textcat = nlp.add_pipe("textcat_multilabel")
|
||||||
|
> for doc in textcat.pipe(docs, batch_size=50):
|
||||||
|
> pass
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| -------------- | ------------------------------------------------------------- |
|
||||||
|
| `stream` | A stream of documents. ~~Iterable[Doc]~~ |
|
||||||
|
| _keyword-only_ | |
|
||||||
|
| `batch_size` | The number of documents to buffer. Defaults to `128`. ~~int~~ |
|
||||||
|
| **YIELDS** | The processed documents in order. ~~Doc~~ |
|
||||||
|
|
||||||
|
## MultiLabel_TextCategorizer.initialize {#initialize tag="method" new="3"}
|
||||||
|
|
||||||
|
Initialize the component for training. `get_examples` should be a function that
|
||||||
|
returns an iterable of [`Example`](/api/example) objects. The data examples are
|
||||||
|
used to **initialize the model** of the component and can either be the full
|
||||||
|
training data or a representative sample. Initialization includes validating the
|
||||||
|
network,
|
||||||
|
[inferring missing shapes](https://thinc.ai/docs/usage-models#validation) and
|
||||||
|
setting up the label scheme based on the data. This method is typically called
|
||||||
|
by [`Language.initialize`](/api/language#initialize) and lets you customize
|
||||||
|
arguments it receives via the
|
||||||
|
[`[initialize.components]`](/api/data-formats#config-initialize) block in the
|
||||||
|
config.
|
||||||
|
|
||||||
|
<Infobox variant="warning" title="Changed in v3.0" id="begin_training">
|
||||||
|
|
||||||
|
This method was previously called `begin_training`.
|
||||||
|
|
||||||
|
</Infobox>
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> textcat = nlp.add_pipe("textcat_multilabel")
|
||||||
|
> textcat.initialize(lambda: [], nlp=nlp)
|
||||||
|
> ```
|
||||||
|
>
|
||||||
|
> ```ini
|
||||||
|
> ### config.cfg
|
||||||
|
> [initialize.components.textcat_multilabel]
|
||||||
|
>
|
||||||
|
> [initialize.components.textcat_multilabel.labels]
|
||||||
|
> @readers = "spacy.read_labels.v1"
|
||||||
|
> path = "corpus/labels/textcat.json
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ---------------- | ---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| `get_examples` | Function that returns gold-standard annotations in the form of [`Example`](/api/example) objects. ~~Callable[[], Iterable[Example]]~~ |
|
||||||
|
| _keyword-only_ | |
|
||||||
|
| `nlp` | The current `nlp` object. Defaults to `None`. ~~Optional[Language]~~ |
|
||||||
|
| `labels` | The label information to add to the component, as provided by the [`label_data`](#label_data) property after initialization. To generate a reusable JSON file from your data, you should run the [`init labels`](/api/cli#init-labels) command. If no labels are provided, the `get_examples` callback is used to extract the labels from the data, which may be a lot slower. ~~Optional[Iterable[str]]~~ |
|
||||||
|
|
||||||
|
## MultiLabel_TextCategorizer.predict {#predict tag="method"}
|
||||||
|
|
||||||
|
Apply the component's model to a batch of [`Doc`](/api/doc) objects without
|
||||||
|
modifying them.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> textcat = nlp.add_pipe("textcat_multilabel")
|
||||||
|
> scores = textcat.predict([doc1, doc2])
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ----------- | ------------------------------------------- |
|
||||||
|
| `docs` | The documents to predict. ~~Iterable[Doc]~~ |
|
||||||
|
| **RETURNS** | The model's prediction for each document. |
|
||||||
|
|
||||||
|
## MultiLabel_TextCategorizer.set_annotations {#set_annotations tag="method"}
|
||||||
|
|
||||||
|
Modify a batch of [`Doc`](/api/doc) objects using pre-computed scores.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> textcat = nlp.add_pipe("textcat_multilabel")
|
||||||
|
> scores = textcat.predict(docs)
|
||||||
|
> textcat.set_annotations(docs, scores)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| -------- | --------------------------------------------------------- |
|
||||||
|
| `docs` | The documents to modify. ~~Iterable[Doc]~~ |
|
||||||
|
| `scores` | The scores to set, produced by `MultiLabel_TextCategorizer.predict`. |
|
||||||
|
|
||||||
|
## MultiLabel_TextCategorizer.update {#update tag="method"}
|
||||||
|
|
||||||
|
Learn from a batch of [`Example`](/api/example) objects containing the
|
||||||
|
predictions and gold-standard annotations, and update the component's model.
|
||||||
|
Delegates to [`predict`](/api/multilabel_textcategorizer#predict) and
|
||||||
|
[`get_loss`](/api/multilabel_textcategorizer#get_loss).
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> textcat = nlp.add_pipe("textcat_multilabel")
|
||||||
|
> optimizer = nlp.initialize()
|
||||||
|
> losses = textcat.update(examples, sgd=optimizer)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ----------------- | ---------------------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ |
|
||||||
|
| _keyword-only_ | |
|
||||||
|
| `drop` | The dropout rate. ~~float~~ |
|
||||||
|
| `set_annotations` | Whether or not to update the `Example` objects with the predictions, delegating to [`set_annotations`](#set_annotations). ~~bool~~ |
|
||||||
|
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
|
||||||
|
| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
|
||||||
|
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
|
||||||
|
|
||||||
|
## MultiLabel_TextCategorizer.rehearse {#rehearse tag="method,experimental" new="3"}
|
||||||
|
|
||||||
|
Perform a "rehearsal" update from a batch of data. Rehearsal updates teach the
|
||||||
|
current model to make predictions similar to an initial model to try to address
|
||||||
|
the "catastrophic forgetting" problem. This feature is experimental.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> textcat = nlp.add_pipe("textcat_multilabel")
|
||||||
|
> optimizer = nlp.resume_training()
|
||||||
|
> losses = textcat.rehearse(examples, sgd=optimizer)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| -------------- | ------------------------------------------------------------------------------------------------------------------------ |
|
||||||
|
| `examples` | A batch of [`Example`](/api/example) objects to learn from. ~~Iterable[Example]~~ |
|
||||||
|
| _keyword-only_ | |
|
||||||
|
| `drop` | The dropout rate. ~~float~~ |
|
||||||
|
| `sgd` | An optimizer. Will be created via [`create_optimizer`](#create_optimizer) if not set. ~~Optional[Optimizer]~~ |
|
||||||
|
| `losses` | Optional record of the loss during training. Updated using the component name as the key. ~~Optional[Dict[str, float]]~~ |
|
||||||
|
| **RETURNS** | The updated `losses` dictionary. ~~Dict[str, float]~~ |
|
||||||
|
|
||||||
|
## MultiLabel_TextCategorizer.get_loss {#get_loss tag="method"}
|
||||||
|
|
||||||
|
Find the loss and gradient of loss for the batch of documents and their
|
||||||
|
predicted scores.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> textcat = nlp.add_pipe("textcat_multilabel")
|
||||||
|
> scores = textcat.predict([eg.predicted for eg in examples])
|
||||||
|
> loss, d_loss = textcat.get_loss(examples, scores)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ----------- | --------------------------------------------------------------------------- |
|
||||||
|
| `examples` | The batch of examples. ~~Iterable[Example]~~ |
|
||||||
|
| `scores` | Scores representing the model's predictions. |
|
||||||
|
| **RETURNS** | The loss and the gradient, i.e. `(loss, gradient)`. ~~Tuple[float, float]~~ |
|
||||||
|
|
||||||
|
## MultiLabel_TextCategorizer.score {#score tag="method" new="3"}
|
||||||
|
|
||||||
|
Score a batch of examples.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> scores = textcat.score(examples)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ---------------- | -------------------------------------------------------------------------------------------------------------------- |
|
||||||
|
| `examples` | The examples to score. ~~Iterable[Example]~~ |
|
||||||
|
| _keyword-only_ | |
|
||||||
|
| **RETURNS** | The scores, produced by [`Scorer.score_cats`](/api/scorer#score_cats). ~~Dict[str, Union[float, Dict[str, float]]]~~ |
|
||||||
|
|
||||||
|
## MultiLabel_TextCategorizer.create_optimizer {#create_optimizer tag="method"}
|
||||||
|
|
||||||
|
Create an optimizer for the pipeline component.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> textcat = nlp.add_pipe("textcat")
|
||||||
|
> optimizer = textcat.create_optimizer()
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ----------- | ---------------------------- |
|
||||||
|
| **RETURNS** | The optimizer. ~~Optimizer~~ |
|
||||||
|
|
||||||
|
## MultiLabel_TextCategorizer.use_params {#use_params tag="method, contextmanager"}
|
||||||
|
|
||||||
|
Modify the pipe's model to use the given parameter values.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> textcat = nlp.add_pipe("textcat")
|
||||||
|
> with textcat.use_params(optimizer.averages):
|
||||||
|
> textcat.to_disk("/best_model")
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| -------- | -------------------------------------------------- |
|
||||||
|
| `params` | The parameter values to use in the model. ~~dict~~ |
|
||||||
|
|
||||||
|
## MultiLabel_TextCategorizer.add_label {#add_label tag="method"}
|
||||||
|
|
||||||
|
Add a new label to the pipe. Raises an error if the output dimension is already
|
||||||
|
set, or if the model has already been fully [initialized](#initialize). Note
|
||||||
|
that you don't have to call this method if you provide a **representative data
|
||||||
|
sample** to the [`initialize`](#initialize) method. In this case, all labels
|
||||||
|
found in the sample will be automatically added to the model, and the output
|
||||||
|
dimension will be [inferred](/usage/layers-architectures#thinc-shape-inference)
|
||||||
|
automatically.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> textcat = nlp.add_pipe("textcat")
|
||||||
|
> textcat.add_label("MY_LABEL")
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ----------- | ----------------------------------------------------------- |
|
||||||
|
| `label` | The label to add. ~~str~~ |
|
||||||
|
| **RETURNS** | `0` if the label is already present, otherwise `1`. ~~int~~ |
|
||||||
|
|
||||||
|
## MultiLabel_TextCategorizer.to_disk {#to_disk tag="method"}
|
||||||
|
|
||||||
|
Serialize the pipe to disk.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> textcat = nlp.add_pipe("textcat")
|
||||||
|
> textcat.to_disk("/path/to/textcat")
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| -------------- | ------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||||
|
| `path` | A path to a directory, which will be created if it doesn't exist. Paths may be either strings or `Path`-like objects. ~~Union[str, Path]~~ |
|
||||||
|
| _keyword-only_ | |
|
||||||
|
| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ |
|
||||||
|
|
||||||
|
## MultiLabel_TextCategorizer.from_disk {#from_disk tag="method"}
|
||||||
|
|
||||||
|
Load the pipe from disk. Modifies the object in place and returns it.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> textcat = nlp.add_pipe("textcat")
|
||||||
|
> textcat.from_disk("/path/to/textcat")
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| -------------- | ----------------------------------------------------------------------------------------------- |
|
||||||
|
| `path` | A path to a directory. Paths may be either strings or `Path`-like objects. ~~Union[str, Path]~~ |
|
||||||
|
| _keyword-only_ | |
|
||||||
|
| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ |
|
||||||
|
| **RETURNS** | The modified `MultiLabel_TextCategorizer` object. ~~MultiLabel_TextCategorizer~~ |
|
||||||
|
|
||||||
|
## MultiLabel_TextCategorizer.to_bytes {#to_bytes tag="method"}
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> textcat = nlp.add_pipe("textcat")
|
||||||
|
> textcat_bytes = textcat.to_bytes()
|
||||||
|
> ```
|
||||||
|
|
||||||
|
Serialize the pipe to a bytestring.
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| -------------- | ------------------------------------------------------------------------------------------- |
|
||||||
|
| _keyword-only_ | |
|
||||||
|
| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ |
|
||||||
|
| **RETURNS** | The serialized form of the `MultiLabel_TextCategorizer` object. ~~bytes~~ |
|
||||||
|
|
||||||
|
## MultiLabel_TextCategorizer.from_bytes {#from_bytes tag="method"}
|
||||||
|
|
||||||
|
Load the pipe from a bytestring. Modifies the object in place and returns it.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> textcat_bytes = textcat.to_bytes()
|
||||||
|
> textcat = nlp.add_pipe("textcat")
|
||||||
|
> textcat.from_bytes(textcat_bytes)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| -------------- | ------------------------------------------------------------------------------------------- |
|
||||||
|
| `bytes_data` | The data to load from. ~~bytes~~ |
|
||||||
|
| _keyword-only_ | |
|
||||||
|
| `exclude` | String names of [serialization fields](#serialization-fields) to exclude. ~~Iterable[str]~~ |
|
||||||
|
| **RETURNS** | The `MultiLabel_TextCategorizer` object. ~~MultiLabel_TextCategorizer~~ |
|
||||||
|
|
||||||
|
## MultiLabel_TextCategorizer.labels {#labels tag="property"}
|
||||||
|
|
||||||
|
The labels currently added to the component.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> textcat.add_label("MY_LABEL")
|
||||||
|
> assert "MY_LABEL" in textcat.labels
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ----------- | ------------------------------------------------------ |
|
||||||
|
| **RETURNS** | The labels added to the component. ~~Tuple[str, ...]~~ |
|
||||||
|
|
||||||
|
## MultiLabel_TextCategorizer.label_data {#label_data tag="property" new="3"}
|
||||||
|
|
||||||
|
The labels currently added to the component and their internal meta information.
|
||||||
|
This is the data generated by [`init labels`](/api/cli#init-labels) and used by
|
||||||
|
[`MultiLabel_TextCategorizer.initialize`](/api/multilabel_textcategorizer#initialize) to initialize
|
||||||
|
the model with a pre-defined label set.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> labels = textcat.label_data
|
||||||
|
> textcat.initialize(lambda: [], nlp=nlp, labels=labels)
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ----------- | ---------------------------------------------------------- |
|
||||||
|
| **RETURNS** | The label data added to the component. ~~Tuple[str, ...]~~ |
|
||||||
|
|
||||||
|
## Serialization fields {#serialization-fields}
|
||||||
|
|
||||||
|
During serialization, spaCy will export several data fields used to restore
|
||||||
|
different aspects of the object. If needed, you can exclude them from
|
||||||
|
serialization by passing in the string names via the `exclude` argument.
|
||||||
|
|
||||||
|
> #### Example
|
||||||
|
>
|
||||||
|
> ```python
|
||||||
|
> data = textcat.to_disk("/path", exclude=["vocab"])
|
||||||
|
> ```
|
||||||
|
|
||||||
|
| Name | Description |
|
||||||
|
| ------- | -------------------------------------------------------------- |
|
||||||
|
| `vocab` | The shared [`Vocab`](/api/vocab). |
|
||||||
|
| `cfg` | The config file. You usually don't want to exclude this. |
|
||||||
|
| `model` | The binary model data. You usually don't want to exclude this. |
|
|
@ -3,17 +3,15 @@ title: TextCategorizer
|
||||||
tag: class
|
tag: class
|
||||||
source: spacy/pipeline/textcat.py
|
source: spacy/pipeline/textcat.py
|
||||||
new: 2
|
new: 2
|
||||||
teaser: 'Pipeline component for text classification'
|
teaser: 'Pipeline component for single-label text classification'
|
||||||
api_base_class: /api/pipe
|
api_base_class: /api/pipe
|
||||||
api_string_name: textcat
|
api_string_name: textcat
|
||||||
api_trainable: true
|
api_trainable: true
|
||||||
---
|
---
|
||||||
|
|
||||||
The text categorizer predicts **categories over a whole document**. It can learn
|
The text categorizer predicts **categories over a whole document**. It can learn
|
||||||
one or more labels, and the labels can be mutually exclusive (i.e. one true
|
one or more labels, and the labels are mutually exclusive - there is exactly one
|
||||||
label per document) or non-mutually exclusive (i.e. zero or more labels may be
|
true label per document.
|
||||||
true per document). The multi-label setting is controlled by the model instance
|
|
||||||
that's provided.
|
|
||||||
|
|
||||||
## Config and implementation {#config}
|
## Config and implementation {#config}
|
||||||
|
|
||||||
|
@ -27,10 +25,10 @@ architectures and their arguments and hyperparameters.
|
||||||
> #### Example
|
> #### Example
|
||||||
>
|
>
|
||||||
> ```python
|
> ```python
|
||||||
> from spacy.pipeline.textcat import DEFAULT_TEXTCAT_MODEL
|
> from spacy.pipeline.textcat import DEFAULT_SINGLE_TEXTCAT_MODEL
|
||||||
> config = {
|
> config = {
|
||||||
> "threshold": 0.5,
|
> "threshold": 0.5,
|
||||||
> "model": DEFAULT_TEXTCAT_MODEL,
|
> "model": DEFAULT_SINGLE_TEXTCAT_MODEL,
|
||||||
> }
|
> }
|
||||||
> nlp.add_pipe("textcat", config=config)
|
> nlp.add_pipe("textcat", config=config)
|
||||||
> ```
|
> ```
|
||||||
|
@ -280,7 +278,6 @@ Score a batch of examples.
|
||||||
| ---------------- | -------------------------------------------------------------------------------------------------------------------- |
|
| ---------------- | -------------------------------------------------------------------------------------------------------------------- |
|
||||||
| `examples` | The examples to score. ~~Iterable[Example]~~ |
|
| `examples` | The examples to score. ~~Iterable[Example]~~ |
|
||||||
| _keyword-only_ | |
|
| _keyword-only_ | |
|
||||||
| `positive_label` | Optional positive label. ~~Optional[str]~~ |
|
|
||||||
| **RETURNS** | The scores, produced by [`Scorer.score_cats`](/api/scorer#score_cats). ~~Dict[str, Union[float, Dict[str, float]]]~~ |
|
| **RETURNS** | The scores, produced by [`Scorer.score_cats`](/api/scorer#score_cats). ~~Dict[str, Union[float, Dict[str, float]]]~~ |
|
||||||
|
|
||||||
## TextCategorizer.create_optimizer {#create_optimizer tag="method"}
|
## TextCategorizer.create_optimizer {#create_optimizer tag="method"}
|
||||||
|
|
|
@ -152,7 +152,7 @@ depth = 2
|
||||||
|
|
||||||
[components.textcat.model.linear_model]
|
[components.textcat.model.linear_model]
|
||||||
@architectures = "spacy.TextCatBOW.v1"
|
@architectures = "spacy.TextCatBOW.v1"
|
||||||
exclusive_classes = false
|
exclusive_classes = true
|
||||||
ngram_size = 1
|
ngram_size = 1
|
||||||
no_output_layer = false
|
no_output_layer = false
|
||||||
```
|
```
|
||||||
|
@ -170,7 +170,7 @@ labels = []
|
||||||
|
|
||||||
[components.textcat.model]
|
[components.textcat.model]
|
||||||
@architectures = "spacy.TextCatBOW.v1"
|
@architectures = "spacy.TextCatBOW.v1"
|
||||||
exclusive_classes = false
|
exclusive_classes = true
|
||||||
ngram_size = 1
|
ngram_size = 1
|
||||||
no_output_layer = false
|
no_output_layer = false
|
||||||
nO = null
|
nO = null
|
||||||
|
|
Loading…
Reference in New Issue
Block a user