mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-04 20:30:24 +03:00
Add failing unit tests
This commit is contained in:
parent
ef9e504eac
commit
6c8358dcfc
|
@ -1,6 +1,8 @@
|
||||||
import pytest
|
import pytest
|
||||||
import spacy
|
import spacy
|
||||||
|
|
||||||
|
from thinc.api import Config
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
from spacy.training import Example
|
from spacy.training import Example
|
||||||
|
|
||||||
|
@ -148,6 +150,157 @@ REHEARSE_DATA = [
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
TEXTCAT_MULTILABEL_LISTENER_CONFIG = """
|
||||||
|
[nlp]
|
||||||
|
lang = "en"
|
||||||
|
pipeline = ["tok2vec","textcat_multilabel"]
|
||||||
|
disabled = []
|
||||||
|
before_creation = null
|
||||||
|
after_creation = null
|
||||||
|
after_pipeline_creation = null
|
||||||
|
batch_size = 1000
|
||||||
|
tokenizer = {"@tokenizers":"spacy.Tokenizer.v1"}
|
||||||
|
|
||||||
|
[components]
|
||||||
|
|
||||||
|
[components.textcat_multilabel]
|
||||||
|
factory = "textcat_multilabel"
|
||||||
|
threshold = 0.5
|
||||||
|
|
||||||
|
[components.textcat_multilabel.model]
|
||||||
|
@architectures = "spacy.TextCatEnsemble.v2"
|
||||||
|
nO = null
|
||||||
|
|
||||||
|
[components.textcat_multilabel.model.linear_model]
|
||||||
|
@architectures = "spacy.TextCatBOW.v2"
|
||||||
|
exclusive_classes = false
|
||||||
|
ngram_size = 1
|
||||||
|
no_output_layer = false
|
||||||
|
|
||||||
|
[components.textcat_multilabel.model.tok2vec]
|
||||||
|
@architectures = "spacy.Tok2VecListener.v1"
|
||||||
|
width = 64
|
||||||
|
upstream = "*"
|
||||||
|
|
||||||
|
[components.tok2vec]
|
||||||
|
factory = "tok2vec"
|
||||||
|
|
||||||
|
[components.tok2vec.model]
|
||||||
|
@architectures = "spacy.Tok2Vec.v2"
|
||||||
|
|
||||||
|
[components.tok2vec.model.embed]
|
||||||
|
@architectures = "spacy.MultiHashEmbed.v2"
|
||||||
|
width = 64
|
||||||
|
attrs = ["ORTH", "SHAPE"]
|
||||||
|
rows = [5000, 2500]
|
||||||
|
include_static_vectors = true
|
||||||
|
|
||||||
|
[components.tok2vec.model.encode]
|
||||||
|
@architectures = "spacy.MishWindowEncoder.v2"
|
||||||
|
width = 64
|
||||||
|
depth = 4
|
||||||
|
window_size = 1
|
||||||
|
"""
|
||||||
|
|
||||||
|
TEXTCAT_LISTENER_CONFIG = """
|
||||||
|
[nlp]
|
||||||
|
lang = "en"
|
||||||
|
pipeline = ["tok2vec","textcat"]
|
||||||
|
batch_size = 1000
|
||||||
|
|
||||||
|
[components]
|
||||||
|
|
||||||
|
[components.tok2vec]
|
||||||
|
factory = "tok2vec"
|
||||||
|
|
||||||
|
[components.tok2vec.model]
|
||||||
|
@architectures = "spacy.Tok2Vec.v2"
|
||||||
|
|
||||||
|
[components.tok2vec.model.embed]
|
||||||
|
@architectures = "spacy.MultiHashEmbed.v2"
|
||||||
|
width = ${components.tok2vec.model.encode.width}
|
||||||
|
attrs = ["NORM", "PREFIX", "SUFFIX", "SHAPE"]
|
||||||
|
rows = [5000, 1000, 2500, 2500]
|
||||||
|
include_static_vectors = true
|
||||||
|
|
||||||
|
[components.tok2vec.model.encode]
|
||||||
|
@architectures = "spacy.MaxoutWindowEncoder.v2"
|
||||||
|
width = 256
|
||||||
|
depth = 8
|
||||||
|
window_size = 1
|
||||||
|
maxout_pieces = 3
|
||||||
|
|
||||||
|
[components.textcat]
|
||||||
|
factory = "textcat"
|
||||||
|
|
||||||
|
[components.textcat.model]
|
||||||
|
@architectures = "spacy.TextCatEnsemble.v2"
|
||||||
|
nO = null
|
||||||
|
|
||||||
|
[components.textcat.model.tok2vec]
|
||||||
|
@architectures = "spacy.Tok2VecListener.v1"
|
||||||
|
width = ${components.tok2vec.model.encode.width}
|
||||||
|
|
||||||
|
[components.textcat.model.linear_model]
|
||||||
|
@architectures = "spacy.TextCatBOW.v2"
|
||||||
|
exclusive_classes = true
|
||||||
|
ngram_size = 1
|
||||||
|
no_output_layer = false
|
||||||
|
"""
|
||||||
|
|
||||||
|
TEXTCAT_CONFIG = """
|
||||||
|
[nlp]
|
||||||
|
lang = "en"
|
||||||
|
pipeline = ["textcat"]
|
||||||
|
disabled = []
|
||||||
|
before_creation = null
|
||||||
|
after_creation = null
|
||||||
|
after_pipeline_creation = null
|
||||||
|
batch_size = 1000
|
||||||
|
tokenizer = {"@tokenizers":"spacy.Tokenizer.v1"}
|
||||||
|
|
||||||
|
[components]
|
||||||
|
|
||||||
|
[components.textcat]
|
||||||
|
factory = "textcat"
|
||||||
|
threshold = 0.5
|
||||||
|
|
||||||
|
[components.textcat.model]
|
||||||
|
@architectures = "spacy.TextCatEnsemble.v2"
|
||||||
|
nO = null
|
||||||
|
|
||||||
|
[components.textcat.model.linear_model]
|
||||||
|
@architectures = "spacy.TextCatBOW.v1"
|
||||||
|
exclusive_classes = true
|
||||||
|
ngram_size = 1
|
||||||
|
no_output_layer = false
|
||||||
|
nO = null
|
||||||
|
|
||||||
|
[components.textcat.model.tok2vec]
|
||||||
|
@architectures = "spacy.Tok2Vec.v2"
|
||||||
|
|
||||||
|
[components.textcat.model.tok2vec.embed]
|
||||||
|
@architectures = "spacy.MultiHashEmbed.v1"
|
||||||
|
width = 64
|
||||||
|
rows = [2000,2000,1000,1000,1000,1000]
|
||||||
|
attrs = ["ORTH","LOWER","PREFIX","SUFFIX","SHAPE","ID"]
|
||||||
|
include_static_vectors = false
|
||||||
|
|
||||||
|
[components.textcat.model.tok2vec.encode]
|
||||||
|
@architectures = "spacy.MaxoutWindowEncoder.v2"
|
||||||
|
width = 64
|
||||||
|
window_size = 1
|
||||||
|
maxout_pieces = 3
|
||||||
|
depth = 2
|
||||||
|
"""
|
||||||
|
|
||||||
|
# (text, annotations) pairs — one sentence per label, with mutually exclusive
# "cats" scores; fed to Example.from_dict in the rehearse tests below.
TEXTCAT_EXAMPLE_TEXTS = [
    ("This is a sentence for LABEL_A.", {"cats": {"LABEL_A": 1, "LABEL_B": 0}}),
    ("A sentence for the label LABEL_B.", {"cats": {"LABEL_A": 0, "LABEL_B": 1}}),
]

# Labels registered on each textcat component before nlp.initialize().
TEXTCAT_LABELS = ["LABEL_A", "LABEL_B"]
|
||||||
|
|
||||||
|
|
||||||
def _add_ner_label(ner, data):
|
def _add_ner_label(ner, data):
|
||||||
for _, annotations in data:
|
for _, annotations in data:
|
||||||
|
@ -209,3 +362,63 @@ def test_rehearse(component):
|
||||||
nlp.add_pipe(component)
|
nlp.add_pipe(component)
|
||||||
nlp = _optimize(nlp, component, TRAIN_DATA, False)
|
nlp = _optimize(nlp, component, TRAIN_DATA, False)
|
||||||
_optimize(nlp, component, REHEARSE_DATA, True)
|
_optimize(nlp, component, REHEARSE_DATA, True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.issue(12044)
def test_rehearse_textcat_multilabel_listener():
    """Test nlp.rehearse on a textcat_multilabel pipeline with a tok2vec listener"""
    nlp = spacy.blank("en").from_config(
        Config().from_str(TEXTCAT_MULTILABEL_LISTENER_CONFIG)
    )
    pipe = nlp.get_pipe("textcat_multilabel")
    for cat in TEXTCAT_LABELS:
        pipe.add_label(cat)
    nlp.initialize()

    # Train once on each example so there is a model state to rehearse against.
    examples = []
    for text, annotations in TEXTCAT_EXAMPLE_TEXTS:
        eg = Example.from_dict(nlp.make_doc(text), annotations)
        examples.append(eg)
        nlp.update([eg])

    # resume_training() supplies the optimizer and snapshots the teacher model.
    optimizer = nlp.resume_training()
    nlp.rehearse(examples, sgd=optimizer)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.issue(12044)
def test_rehearse_textcat_listener():
    """Test nlp.rehearse on a textcat pipeline with a tok2vec listener"""
    nlp = spacy.blank("en").from_config(Config().from_str(TEXTCAT_LISTENER_CONFIG))
    pipe = nlp.get_pipe("textcat")
    for cat in TEXTCAT_LABELS:
        pipe.add_label(cat)
    nlp.initialize()

    # Train once on each example so there is a model state to rehearse against.
    examples = []
    for text, annotations in TEXTCAT_EXAMPLE_TEXTS:
        eg = Example.from_dict(nlp.make_doc(text), annotations)
        examples.append(eg)
        nlp.update([eg])

    # resume_training() supplies the optimizer and snapshots the teacher model.
    optimizer = nlp.resume_training()
    nlp.rehearse(examples, sgd=optimizer)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.issue(12044)
def test_rehearse_textcat():
    """Test nlp.rehearse on a textcat pipeline with an inline tok2vec component"""
    nlp = spacy.blank("en").from_config(Config().from_str(TEXTCAT_CONFIG))
    pipe = nlp.get_pipe("textcat")
    for cat in TEXTCAT_LABELS:
        pipe.add_label(cat)
    nlp.initialize()

    # Train once on each example so there is a model state to rehearse against.
    examples = []
    for text, annotations in TEXTCAT_EXAMPLE_TEXTS:
        eg = Example.from_dict(nlp.make_doc(text), annotations)
        examples.append(eg)
        nlp.update([eg])

    # resume_training() supplies the optimizer and snapshots the teacher model.
    optimizer = nlp.resume_training()
    nlp.rehearse(examples, sgd=optimizer)
|
||||||
|
|
Loading…
Reference in New Issue
Block a user