spaCy/spacy/pipeline/textcat_multilabel.py
Sofie Van Landeghem 932887b950
textcat scoring fix and multi_label docs (#6974)
* add multi-label textcat to menu

* add infobox on textcat API

* add info to v3 migration guide

* small edits

* further fixes in doc strings

* add infobox to textcat architectures

* add textcat_multilabel to overview of built-in components

* spelling

* fix unrelated warn msg

* Add textcat_multilabel to quickstart [ci skip]

* remove separate documentation page for multilabel_textcategorizer

* small edits

* positive label clarification

* avoid duplicating information in self.cfg and fix textcat.score

* fix multilabel textcat too

* revert threshold to storage in cfg

* revert threshold stuff for multi-textcat

Co-authored-by: Ines Montani <ines@ines.io>
2021-03-09 23:04:22 +11:00

191 lines
6.0 KiB
Python

from itertools import islice
from typing import Iterable, Optional, Dict, List, Callable, Any
from thinc.api import Model, Config
from thinc.types import Floats2d
from ..language import Language
from ..training import Example, validate_examples, validate_get_examples
from ..errors import Errors
from ..scorer import Scorer
from ..tokens import Doc
from ..vocab import Vocab
from .textcat import TextCategorizer
# Default model config for the multi-label text categorizer: a
# "spacy.TextCatEnsemble.v2" model made of a tok2vec sub-model
# (MultiHashEmbed + MaxoutWindowEncoder) and a "spacy.TextCatBOW.v1" linear
# model. The linear model sets `exclusive_classes = false`.
multi_label_default_config = """
[model]
@architectures = "spacy.TextCatEnsemble.v2"
[model.tok2vec]
@architectures = "spacy.Tok2Vec.v1"
[model.tok2vec.embed]
@architectures = "spacy.MultiHashEmbed.v1"
width = 64
rows = [2000, 2000, 1000, 1000, 1000, 1000]
attrs = ["ORTH", "LOWER", "PREFIX", "SUFFIX", "SHAPE", "ID"]
include_static_vectors = false
[model.tok2vec.encode]
@architectures = "spacy.MaxoutWindowEncoder.v1"
width = ${model.tok2vec.embed.width}
window_size = 1
maxout_pieces = 3
depth = 2
[model.linear_model]
@architectures = "spacy.TextCatBOW.v1"
exclusive_classes = false
ngram_size = 1
no_output_layer = false
"""
# The "model" section of the config above, parsed once at import time. It is
# used as the default `model` entry in the "textcat_multilabel" factory's
# default_config.
DEFAULT_MULTI_TEXTCAT_MODEL = Config().from_str(multi_label_default_config)["model"]
# Alternative config string: a standalone "spacy.TextCatBOW.v1" model with
# `exclusive_classes = false`. Not referenced elsewhere in the visible part of
# this file -- presumably provided for importers; verify before removing.
multi_label_bow_config = """
[model]
@architectures = "spacy.TextCatBOW.v1"
exclusive_classes = false
ngram_size = 1
no_output_layer = false
"""
# Alternative config string: a "spacy.TextCatCNN.v1" model on top of a
# "spacy.HashEmbedCNN.v1" tok2vec, with `exclusive_classes = false`. Not
# referenced elsewhere in the visible part of this file -- presumably provided
# for importers; verify before removing.
multi_label_cnn_config = """
[model]
@architectures = "spacy.TextCatCNN.v1"
exclusive_classes = false
[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v1"
pretrained_vectors = null
width = 96
depth = 4
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true
"""
@Language.factory(
    "textcat_multilabel",
    assigns=["doc.cats"],
    default_config={"threshold": 0.5, "model": DEFAULT_MULTI_TEXTCAT_MODEL},
    default_score_weights={
        "cats_score": 1.0,
        "cats_score_desc": None,
        "cats_micro_p": None,
        "cats_micro_r": None,
        "cats_micro_f": None,
        "cats_macro_p": None,
        "cats_macro_r": None,
        "cats_macro_f": None,
        "cats_macro_auc": None,
        "cats_f_per_type": None,
        "cats_macro_auc_per_type": None,
    },
)
def make_multilabel_textcat(
    nlp: Language, name: str, model: Model[List[Doc], List[Floats2d]], threshold: float
) -> "TextCategorizer":
    """Create a multi-label TextCategorizer component.

    The text categorizer predicts categories over a whole document. It can
    learn one or more labels, and the labels are considered to be
    non-mutually exclusive, which means that there can be zero or more labels
    per doc.

    nlp (Language): The nlp object; its vocab is shared with the component.
    name (str): The component instance name.
    model (Model[List[Doc], List[Floats2d]]): A model instance that predicts
        scores for each category.
    threshold (float): Cutoff to consider a prediction "positive".
    RETURNS (TextCategorizer): The newly created MultiLabel_TextCategorizer.
    """
    component = MultiLabel_TextCategorizer(
        nlp.vocab, model, name, threshold=threshold
    )
    return component
class MultiLabel_TextCategorizer(TextCategorizer):
    """Pipeline component for multi-label text classification.

    Builds on TextCategorizer; scoring is done with `multi_label=True` and
    the category validation of the parent class is disabled.

    DOCS: https://spacy.io/api/textcategorizer
    """

    def __init__(
        self,
        vocab: Vocab,
        model: Model,
        name: str = "textcat_multilabel",
        *,
        threshold: float,
    ) -> None:
        """Initialize a text categorizer for multi-label classification.

        vocab (Vocab): The shared vocabulary.
        model (thinc.api.Model): The Thinc Model powering the pipeline
            component.
        name (str): The component instance name, used to add entries to the
            losses during training.
        threshold (float): Cutoff to consider a prediction "positive".

        DOCS: https://spacy.io/api/textcategorizer#init
        """
        self.vocab = vocab
        self.model = model
        self.name = name
        self._rehearsal_model = None
        # Labels start out empty and are filled in through add_label().
        self.cfg = {"labels": [], "threshold": threshold}

    def initialize(
        self,
        get_examples: Callable[[], Iterable[Example]],
        *,
        nlp: Optional[Language] = None,
        labels: Optional[Iterable[str]] = None,
    ):
        """Initialize the pipe for training, using a representative set
        of data examples.

        get_examples (Callable[[], Iterable[Example]]): Function that
            returns a representative sample of gold-standard Example objects.
        nlp (Language): The current nlp object the component is part of.
        labels: The labels to add to the component, typically generated by
            the `init labels` command. If no labels are provided, the
            get_examples callback is used to extract the labels from the
            data.

        DOCS: https://spacy.io/api/textcategorizer#initialize
        """
        validate_get_examples(get_examples, "MultiLabel_TextCategorizer.initialize")
        if labels is not None:
            for label in labels:
                self.add_label(label)
        else:
            # No explicit labels: collect every category key found in the
            # reference annotations.
            for example in get_examples():
                for cat in example.y.cats:
                    self.add_label(cat)
        # The model is initialized from at most the first 10 examples.
        sample = list(islice(get_examples(), 10))
        doc_sample = [eg.reference for eg in sample]
        label_sample, _ = self._examples_to_truth(sample)
        self._require_labels()
        assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
        assert len(label_sample) > 0, Errors.E923.format(name=self.name)
        self.model.initialize(X=doc_sample, Y=label_sample)

    def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
        """Score a batch of examples.

        examples (Iterable[Example]): The examples to score.
        **kwargs: Extra arguments forwarded to Scorer.score_cats. If no
            "threshold" is given, the component's configured threshold is
            used.
        RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_cats.

        DOCS: https://spacy.io/api/textcategorizer#score
        """
        validate_examples(examples, "MultiLabel_TextCategorizer.score")
        # A caller-supplied threshold takes precedence over the configured one.
        scorer_kwargs = {"threshold": self.cfg["threshold"], **kwargs}
        return Scorer.score_cats(
            examples,
            "cats",
            labels=self.labels,
            multi_label=True,
            **scorer_kwargs,
        )

    def _validate_categories(self, examples: List[Example]):
        """Accept any type of single- or multi-label annotations.

        This is intentionally a no-op: it overwrites the more strict check
        from 'textcat'.
        """
        return None