diff --git a/spacy/tests/training/test_rehearse.py b/spacy/tests/training/test_rehearse.py
index 5ac7fc217..432047a73 100644
--- a/spacy/tests/training/test_rehearse.py
+++ b/spacy/tests/training/test_rehearse.py
@@ -1,6 +1,8 @@
 import pytest
 import spacy
 
+from thinc.api import Config
+
 from typing import List
 from spacy.training import Example
 
@@ -148,6 +150,157 @@ REHEARSE_DATA = [
     ),
 ]
 
+TEXTCAT_MULTILABEL_LISTENER_CONFIG = """
+[nlp]
+lang = "en"
+pipeline = ["tok2vec","textcat_multilabel"]
+disabled = []
+before_creation = null
+after_creation = null
+after_pipeline_creation = null
+batch_size = 1000
+tokenizer = {"@tokenizers":"spacy.Tokenizer.v1"}
+
+[components]
+
+[components.textcat_multilabel]
+factory = "textcat_multilabel"
+threshold = 0.5
+
+[components.textcat_multilabel.model]
+@architectures = "spacy.TextCatEnsemble.v2"
+nO = null
+
+[components.textcat_multilabel.model.linear_model]
+@architectures = "spacy.TextCatBOW.v2"
+exclusive_classes = false
+ngram_size = 1
+no_output_layer = false
+
+[components.textcat_multilabel.model.tok2vec]
+@architectures = "spacy.Tok2VecListener.v1"
+width = 64
+upstream = "*"
+
+[components.tok2vec]
+factory = "tok2vec"
+
+[components.tok2vec.model]
+@architectures = "spacy.Tok2Vec.v2"
+
+[components.tok2vec.model.embed]
+@architectures = "spacy.MultiHashEmbed.v2"
+width = 64
+attrs = ["ORTH", "SHAPE"]
+rows = [5000, 2500]
+include_static_vectors = false
+
+[components.tok2vec.model.encode]
+@architectures = "spacy.MishWindowEncoder.v2"
+width = 64
+depth = 4
+window_size = 1
+"""
+
+TEXTCAT_LISTENER_CONFIG = """
+[nlp]
+lang = "en"
+pipeline = ["tok2vec","textcat"]
+batch_size = 1000
+
+[components]
+
+[components.tok2vec]
+factory = "tok2vec"
+
+[components.tok2vec.model]
+@architectures = "spacy.Tok2Vec.v2"
+
+[components.tok2vec.model.embed]
+@architectures = "spacy.MultiHashEmbed.v2"
+width = ${components.tok2vec.model.encode.width}
+attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
+rows = [5000, 1000, 2500, 2500]
+include_static_vectors = false
+
+[components.tok2vec.model.encode]
+@architectures = "spacy.MaxoutWindowEncoder.v2"
+width = 256
+depth = 8
+window_size = 1
+maxout_pieces = 3
+
+[components.textcat]
+factory = "textcat"
+
+[components.textcat.model]
+@architectures = "spacy.TextCatEnsemble.v2"
+nO = null
+
+[components.textcat.model.tok2vec]
+@architectures = "spacy.Tok2VecListener.v1"
+width = ${components.tok2vec.model.encode.width}
+
+[components.textcat.model.linear_model]
+@architectures = "spacy.TextCatBOW.v2"
+exclusive_classes = true
+ngram_size = 1
+no_output_layer = false
+"""
+
+TEXTCAT_CONFIG = """
+[nlp]
+lang = "en"
+pipeline = ["textcat"]
+disabled = []
+before_creation = null
+after_creation = null
+after_pipeline_creation = null
+batch_size = 1000
+tokenizer = {"@tokenizers":"spacy.Tokenizer.v1"}
+
+[components]
+
+[components.textcat]
+factory = "textcat"
+threshold = 0.5
+
+[components.textcat.model]
+@architectures = "spacy.TextCatEnsemble.v2"
+nO = null
+
+[components.textcat.model.linear_model]
+@architectures = "spacy.TextCatBOW.v1"
+exclusive_classes = true
+ngram_size = 1
+no_output_layer = false
+nO = null
+
+[components.textcat.model.tok2vec]
+@architectures = "spacy.Tok2Vec.v2"
+
+[components.textcat.model.tok2vec.embed]
+@architectures = "spacy.MultiHashEmbed.v1"
+width = 64
+rows = [2000,2000,1000,1000,1000,1000]
+attrs = ["ORTH","LOWER","PREFIX","SUFFIX","SHAPE","ID"]
+include_static_vectors = false
+
+[components.textcat.model.tok2vec.encode]
+@architectures = "spacy.MaxoutWindowEncoder.v2"
+width = 64
+window_size = 1
+maxout_pieces = 3
+depth = 2
+"""
+
+TEXTCAT_EXAMPLE_TEXTS = [
+    ("This is a sentence for LABEL_A.", {"cats": {"LABEL_A": 1, "LABEL_B": 0}}),
+    ("A sentence for the label LABEL_B.", {"cats": {"LABEL_A": 0, "LABEL_B": 1}}),
+]
+
+TEXTCAT_LABELS = ["LABEL_A", "LABEL_B"]
+
 
 def _add_ner_label(ner, data):
     for _, annotations in data:
@@ -209,3 +362,63 @@ def test_rehearse(component):
     nlp.add_pipe(component)
     nlp = _optimize(nlp, component, TRAIN_DATA, False)
     _optimize(nlp, component, REHEARSE_DATA, True)
+
+
+@pytest.mark.issue(12044)
+def test_rehearse_textcat_multilabel_listener():
+    """Test nlp.rehearse on a textcat_multilabel pipeline with a tok2vec listener"""
+    config = Config().from_str(TEXTCAT_MULTILABEL_LISTENER_CONFIG)
+    nlp = spacy.blank("en").from_config(config)
+    textcat_multilabel = nlp.get_pipe("textcat_multilabel")
+    for label in TEXTCAT_LABELS:
+        textcat_multilabel.add_label(label)
+    nlp.initialize()
+
+    examples = []
+    for example in TEXTCAT_EXAMPLE_TEXTS:
+        example = Example.from_dict(nlp.make_doc(example[0]), example[1])
+        examples.append(example)
+        nlp.update([example])
+
+    optimizer = nlp.resume_training()
+    nlp.rehearse(examples, sgd=optimizer)
+
+
+@pytest.mark.issue(12044)
+def test_rehearse_textcat_listener():
+    """Test nlp.rehearse on a textcat pipeline with a tok2vec listener"""
+    config = Config().from_str(TEXTCAT_LISTENER_CONFIG)
+    nlp = spacy.blank("en").from_config(config)
+    textcat = nlp.get_pipe("textcat")
+    for label in TEXTCAT_LABELS:
+        textcat.add_label(label)
+    nlp.initialize()
+
+    examples = []
+    for example in TEXTCAT_EXAMPLE_TEXTS:
+        example = Example.from_dict(nlp.make_doc(example[0]), example[1])
+        examples.append(example)
+        nlp.update([example])
+
+    optimizer = nlp.resume_training()
+    nlp.rehearse(examples, sgd=optimizer)
+
+
+@pytest.mark.issue(12044)
+def test_rehearse_textcat():
+    """Test nlp.rehearse on a textcat pipeline with an inline tok2vec component"""
+    config = Config().from_str(TEXTCAT_CONFIG)
+    nlp = spacy.blank("en").from_config(config)
+    textcat = nlp.get_pipe("textcat")
+    for label in TEXTCAT_LABELS:
+        textcat.add_label(label)
+    nlp.initialize()
+
+    examples = []
+    for example in TEXTCAT_EXAMPLE_TEXTS:
+        example = Example.from_dict(nlp.make_doc(example[0]), example[1])
+        examples.append(example)
+        nlp.update([example])
+
+    optimizer = nlp.resume_training()
+    nlp.rehearse(examples, sgd=optimizer)
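
For reference, the rehearsal pattern these tests exercise looks like this outside the test harness. This is a minimal sketch, not part of the patch: it uses the default textcat model added via nlp.add_pipe rather than the listener configs above, and reuses the LABEL_A/LABEL_B examples purely for illustration.

import spacy
from spacy.training import Example

nlp = spacy.blank("en")
textcat = nlp.add_pipe("textcat")  # default textcat model, no tok2vec listener
for label in ("LABEL_A", "LABEL_B"):
    textcat.add_label(label)
nlp.initialize()

examples = [
    Example.from_dict(nlp.make_doc(text), annots)
    for text, annots in [
        ("This is a sentence for LABEL_A.", {"cats": {"LABEL_A": 1, "LABEL_B": 0}}),
        ("A sentence for the label LABEL_B.", {"cats": {"LABEL_A": 0, "LABEL_B": 1}}),
    ]
]
nlp.update(examples)  # a regular training step so the weights move away from init

# resume_training() snapshots the current weights as the rehearsal target;
# rehearse() then nudges the model back towards its own earlier predictions
# on the given examples, which is the call the tests above exercise.
optimizer = nlp.resume_training()
losses = {}
nlp.rehearse(examples, sgd=optimizer, losses=losses)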