from itertools import islice from typing import Iterable, Optional, Dict, List, Callable, Any from thinc.api import Model, Config from thinc.types import Floats2d from ..language import Language from ..training import Example, validate_examples, validate_get_examples from ..errors import Errors from ..scorer import Scorer from ..tokens import Doc from ..vocab import Vocab from .textcat import TextCategorizer multi_label_default_config = """ [model] @architectures = "spacy.TextCatEnsemble.v2" [model.tok2vec] @architectures = "spacy.Tok2Vec.v1" [model.tok2vec.embed] @architectures = "spacy.MultiHashEmbed.v1" width = 64 rows = [2000, 2000, 1000, 1000, 1000, 1000] attrs = ["ORTH", "LOWER", "PREFIX", "SUFFIX", "SHAPE", "ID"] include_static_vectors = false [model.tok2vec.encode] @architectures = "spacy.MaxoutWindowEncoder.v1" width = ${model.tok2vec.embed.width} window_size = 1 maxout_pieces = 3 depth = 2 [model.linear_model] @architectures = "spacy.TextCatBOW.v1" exclusive_classes = false ngram_size = 1 no_output_layer = false """ DEFAULT_MULTI_TEXTCAT_MODEL = Config().from_str(multi_label_default_config)["model"] multi_label_bow_config = """ [model] @architectures = "spacy.TextCatBOW.v1" exclusive_classes = false ngram_size = 1 no_output_layer = false """ multi_label_cnn_config = """ [model] @architectures = "spacy.TextCatCNN.v1" exclusive_classes = false [model.tok2vec] @architectures = "spacy.HashEmbedCNN.v1" pretrained_vectors = null width = 96 depth = 4 embed_size = 2000 window_size = 1 maxout_pieces = 3 subword_features = true """ @Language.factory( "textcat_multilabel", assigns=["doc.cats"], default_config={"threshold": 0.5, "model": DEFAULT_MULTI_TEXTCAT_MODEL}, default_score_weights={ "cats_score": 1.0, "cats_score_desc": None, "cats_micro_p": None, "cats_micro_r": None, "cats_micro_f": None, "cats_macro_p": None, "cats_macro_r": None, "cats_macro_f": None, "cats_macro_auc": None, "cats_f_per_type": None, "cats_macro_auc_per_type": None, }, ) def make_multilabel_textcat( nlp: Language, name: str, model: Model[List[Doc], List[Floats2d]], threshold: float ) -> "TextCategorizer": """Create a TextCategorizer compoment. The text categorizer predicts categories over a whole document. It can learn one or more labels, and the labels can be mutually exclusive (i.e. one true label per doc) or non-mutually exclusive (i.e. zero or more labels may be true per doc). The multi-label setting is controlled by the model instance that's provided. model (Model[List[Doc], List[Floats2d]]): A model instance that predicts scores for each category. threshold (float): Cutoff to consider a prediction "positive". """ return MultiLabel_TextCategorizer(nlp.vocab, model, name, threshold=threshold) class MultiLabel_TextCategorizer(TextCategorizer): """Pipeline component for multi-label text classification. DOCS: https://nightly.spacy.io/api/multilabel_textcategorizer """ def __init__( self, vocab: Vocab, model: Model, name: str = "textcat_multilabel", *, threshold: float, ) -> None: """Initialize a text categorizer for multi-label classification. vocab (Vocab): The shared vocabulary. model (thinc.api.Model): The Thinc Model powering the pipeline component. name (str): The component instance name, used to add entries to the losses during training. threshold (float): Cutoff to consider a prediction "positive". DOCS: https://nightly.spacy.io/api/multilabel_textcategorizer#init """ self.vocab = vocab self.model = model self.name = name self._rehearsal_model = None cfg = {"labels": [], "threshold": threshold} self.cfg = dict(cfg) def initialize( self, get_examples: Callable[[], Iterable[Example]], *, nlp: Optional[Language] = None, labels: Optional[Dict] = None, ): """Initialize the pipe for training, using a representative set of data examples. get_examples (Callable[[], Iterable[Example]]): Function that returns a representative sample of gold-standard Example objects. nlp (Language): The current nlp object the component is part of. labels: The labels to add to the component, typically generated by the `init labels` command. If no labels are provided, the get_examples callback is used to extract the labels from the data. DOCS: https://nightly.spacy.io/api/multilabel_textcategorizer#initialize """ validate_get_examples(get_examples, "MultiLabel_TextCategorizer.initialize") if labels is None: for example in get_examples(): for cat in example.y.cats: self.add_label(cat) else: for label in labels: self.add_label(label) subbatch = list(islice(get_examples(), 10)) doc_sample = [eg.reference for eg in subbatch] label_sample, _ = self._examples_to_truth(subbatch) self._require_labels() assert len(doc_sample) > 0, Errors.E923.format(name=self.name) assert len(label_sample) > 0, Errors.E923.format(name=self.name) self.model.initialize(X=doc_sample, Y=label_sample) def score(self, examples: Iterable[Example], **kwargs) -> Dict[str, Any]: """Score a batch of examples. examples (Iterable[Example]): The examples to score. RETURNS (Dict[str, Any]): The scores, produced by Scorer.score_cats. DOCS: https://nightly.spacy.io/api/multilabel_textcategorizer#score """ validate_examples(examples, "MultiLabel_TextCategorizer.score") return Scorer.score_cats( examples, "cats", labels=self.labels, multi_label=True, threshold=self.cfg["threshold"], **kwargs, ) def _validate_categories(self, examples: List[Example]): """This component allows any type of single- or multi-label annotations. This method overwrites the more strict one from 'textcat'.""" pass