from itertools import islice
from typing import Iterable, Optional, Dict, List, Callable, Any

from thinc.api import Model, Config
from thinc.types import Floats2d

from ..language import Language
from import Example, validate_examples, validate_get_examples
from ..errors import Errors
from ..scorer import Scorer
from ..tokens import Doc
from ..vocab import Vocab
from .textcat import TextCategorizer

# Default model config for the multi-label text categorizer, written in
# thinc's config format. It describes an ensemble architecture that combines
# a Tok2Vec sub-network (hash embedding + maxout window encoder) with a
# bag-of-words linear model. `exclusive_classes = false` makes the labels
# non-mutually exclusive.
#
# The section headers are required: the encoder width is interpolated from
# `model.tok2vec.embed.width`, and the parsed config is indexed by "model"
# below.
multi_label_default_config = """
[model]
@architectures = "spacy.TextCatEnsemble.v2"

[model.tok2vec]
@architectures = "spacy.Tok2Vec.v1"

[model.tok2vec.embed]
@architectures = "spacy.MultiHashEmbed.v2"
width = 64
rows = [2000, 2000, 1000, 1000, 1000, 1000]
attrs = ["ORTH", "LOWER", "PREFIX", "SUFFIX", "SHAPE", "ID"]
include_static_vectors = false

[model.tok2vec.encode]
@architectures = "spacy.MaxoutWindowEncoder.v1"
width = ${model.tok2vec.embed.width}
window_size = 1
maxout_pieces = 3
depth = 2

[model.linear_model]
@architectures = "spacy.TextCatBOW.v2"
exclusive_classes = false
ngram_size = 1
no_output_layer = false
"""
# Parsed once at import time; only the "model" section is kept so it can be
# used directly as the factory's default "model" setting.
DEFAULT_MULTI_TEXTCAT_MODEL = Config().from_str(multi_label_default_config)["model"]

# Alternative model config: a plain bag-of-words text classifier with
# non-mutually exclusive labels (`exclusive_classes = false`).
multi_label_bow_config = """
[model]
@architectures = "spacy.TextCatBOW.v2"
exclusive_classes = false
ngram_size = 1
no_output_layer = false
"""

# Alternative model config: a CNN-based text classifier on top of a
# HashEmbedCNN tok2vec layer, with non-mutually exclusive labels
# (`exclusive_classes = false`).
multi_label_cnn_config = """
[model]
@architectures = "spacy.TextCatCNN.v2"
exclusive_classes = false

[model.tok2vec]
@architectures = "spacy.HashEmbedCNN.v2"
pretrained_vectors = null
width = 96
depth = 4
embed_size = 2000
window_size = 1
maxout_pieces = 3
subword_features = true
"""


@Language.factory(
    "textcat_multilabel",
    # Text categorizers store their predictions in `doc.cats`.
    assigns=["doc.cats"],
    default_config={"threshold": 0.5, "model": DEFAULT_MULTI_TEXTCAT_MODEL},
    # Only `cats_score` carries a weight (1.0); every other reported metric
    # is given no weight (None).
    default_score_weights={
        "cats_score": 1.0,
        "cats_score_desc": None,
        "cats_micro_p": None,
        "cats_micro_r": None,
        "cats_micro_f": None,
        "cats_macro_p": None,
        "cats_macro_r": None,
        "cats_macro_f": None,
        "cats_macro_auc": None,
        "cats_f_per_type": None,
        "cats_macro_auc_per_type": None,
    },
)
def make_multilabel_textcat(
    nlp: Language, name: str, model: Model[List[Doc], List[Floats2d]], threshold: float
) -> "TextCategorizer":
    """Create a multi-label TextCategorizer component. The text categorizer
    predicts categories over a whole document. It can learn one or more
    labels, and the labels are considered to be non-mutually exclusive, which
    means that there can be zero or more labels per doc.

    nlp (Language): The pipeline the component is created for; its vocab is
        passed on to the component.
    name (str): The component instance name.
    model (Model[List[Doc], List[Floats2d]]): A model instance that predicts
        scores for each category.
    threshold (float): Cutoff to consider a prediction "positive".
    RETURNS (TextCategorizer): The constructed MultiLabel_TextCategorizer.
    """
    return MultiLabel_TextCategorizer(nlp.vocab, model, name, threshold=threshold)

class MultiLabel_TextCategorizer(TextCategorizer):
    """Pipeline component for multi-label text classification.

    Subclass of TextCategorizer that scores with `multi_label=True` and
    accepts any single- or multi-label annotation (see
    `_validate_categories`).
    """

    def __init__(
        self,
        vocab: Vocab,
        model: Model,
        name: str = "textcat_multilabel",
        *,
        threshold: float,
    ) -> None:
        """Initialize a text categorizer for multi-label classification.

        vocab (Vocab): The shared vocabulary.
        model (thinc.api.Model): The Thinc Model powering the pipeline component.
        name (str): The component instance name, used to add entries to the
            losses during training.
        threshold (float): Cutoff to consider a prediction "positive".
            Keyword-only, because it has no default and follows the
            defaulted `name` parameter.
        """
        self.vocab = vocab
        self.model = model
        # `name` must be stored on its own attribute: a chained
        # `self.model = model = name` would overwrite the model with the name.
        self.name = name
        self._rehearsal_model = None
        # Labels start out empty and are filled in by `initialize`.
        self.cfg = {"labels": [], "threshold": threshold}

    def initialize(  # type: ignore[override]
        self,
        get_examples: Callable[[], Iterable[Example]],
        nlp: Optional[Language] = None,
        labels: Optional[Iterable[str]] = None,
    ):
        """Initialize the pipe for training, using a representative set
        of data examples.

        get_examples (Callable[[], Iterable[Example]]): Function that
            returns a representative sample of gold-standard Example objects.
        nlp (Language): The current nlp object the component is part of.
        labels: The labels to add to the component, typically generated by the
            `init labels` command. If no labels are provided, the get_examples
            callback is used to extract the labels from the data.
        """
        validate_get_examples(get_examples, "MultiLabel_TextCategorizer.initialize")
        if labels is None:
            # No explicit label set: collect every category that occurs in
            # the gold-standard annotations.
            for example in get_examples():
                for cat in example.y.cats:
                    self.add_label(cat)
        else:
            for label in labels:
                self.add_label(label)
        # A sample of at most 10 examples is used to initialize the model.
        subbatch = list(islice(get_examples(), 10))
        doc_sample = [eg.reference for eg in subbatch]
        label_sample, _ = self._examples_to_truth(subbatch)
        assert len(doc_sample) > 0, Errors.E923.format(name=self.name)
        assert len(label_sample) > 0, Errors.E923.format(name=self.name)
        self.model.initialize(X=doc_sample, Y=label_sample)

    def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]:
        """Score a batch of examples.

        examples (Iterable[Example]): The examples to score.
        kwargs: Extra settings forwarded to Scorer.score_cats. "threshold"
            defaults to the component's configured threshold when not given.
        RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_cats.
        """
        validate_examples(examples, "MultiLabel_TextCategorizer.score")
        kwargs.setdefault("threshold", self.cfg["threshold"])
        return Scorer.score_cats(
            examples,
            "cats",
            labels=self.labels,
            multi_label=True,
            **kwargs,
        )

    def _validate_categories(self, examples: Iterable[Example]):
        """This component allows any type of single- or multi-label annotations.
        This method overwrites the more strict one from 'textcat'."""