mirror of
https://github.com/explosion/spaCy.git
synced 2025-08-07 13:44:55 +03:00
Move pipeline definitions into tests
This commit is contained in:
parent
38abc802a6
commit
ae919e9907
|
@ -58,29 +58,6 @@ def nlp():
|
||||||
return nlp
|
return nlp
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def nlp_tcm():
|
|
||||||
nlp = Language(Vocab())
|
|
||||||
textcat_multilabel = nlp.add_pipe("textcat_multilabel")
|
|
||||||
for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
|
|
||||||
textcat_multilabel.add_label(label)
|
|
||||||
nlp.initialize()
|
|
||||||
return nlp
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def nlp_tc_tcm():
|
|
||||||
nlp = Language(Vocab())
|
|
||||||
textcat = nlp.add_pipe("textcat")
|
|
||||||
for label in ("POSITIVE", "NEGATIVE"):
|
|
||||||
textcat.add_label(label)
|
|
||||||
textcat_multilabel = nlp.add_pipe("textcat_multilabel")
|
|
||||||
for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
|
|
||||||
textcat_multilabel.add_label(label)
|
|
||||||
nlp.initialize()
|
|
||||||
return nlp
|
|
||||||
|
|
||||||
|
|
||||||
def test_language_update(nlp):
|
def test_language_update(nlp):
|
||||||
text = "hello world"
|
text = "hello world"
|
||||||
annots = {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}
|
annots = {"cats": {"POSITIVE": 1.0, "NEGATIVE": 0.0}}
|
||||||
|
@ -149,14 +126,19 @@ def test_evaluate_no_pipe(nlp):
|
||||||
nlp.evaluate([Example.from_dict(doc, annots)])
|
nlp.evaluate([Example.from_dict(doc, annots)])
|
||||||
|
|
||||||
|
|
||||||
def test_evaluate_textcat_multilabel(nlp_tcm):
|
def test_evaluate_textcat_multilabel(en_vocab):
|
||||||
"""Test that evaluate works with a multilabel textcat pipe."""
|
"""Test that evaluate works with a multilabel textcat pipe."""
|
||||||
text = "hello world"
|
nlp = Language(en_vocab)
|
||||||
|
textcat_multilabel = nlp.add_pipe("textcat_multilabel")
|
||||||
|
for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
|
||||||
|
textcat_multilabel.add_label(label)
|
||||||
|
nlp.initialize()
|
||||||
|
|
||||||
annots = {"cats": {"FEATURE": 1.0, "QUESTION": 1.0}}
|
annots = {"cats": {"FEATURE": 1.0, "QUESTION": 1.0}}
|
||||||
doc = Doc(nlp_tcm.vocab, words=text.split(" "))
|
doc = nlp.make_doc("hello world")
|
||||||
example = Example.from_dict(doc, annots)
|
example = Example.from_dict(doc, annots)
|
||||||
scores = nlp_tcm.evaluate([example])
|
scores = nlp.evaluate([example])
|
||||||
labels = nlp_tcm.get_pipe("textcat_multilabel").labels
|
labels = nlp.get_pipe("textcat_multilabel").labels
|
||||||
for label in labels:
|
for label in labels:
|
||||||
assert scores["cats_f_per_type"].get(label) is not None
|
assert scores["cats_f_per_type"].get(label) is not None
|
||||||
for key in example.reference.cats.keys():
|
for key in example.reference.cats.keys():
|
||||||
|
@ -164,10 +146,18 @@ def test_evaluate_textcat_multilabel(nlp_tcm):
|
||||||
assert scores["cats_f_per_type"].get(key) is None
|
assert scores["cats_f_per_type"].get(key) is None
|
||||||
|
|
||||||
|
|
||||||
def test_evaluate_multiple_textcat(nlp_tc_tcm):
|
def test_evaluate_multiple_textcat_final(en_vocab):
|
||||||
"""Test that evaluate evaluates the final textcat component in a pipeline
|
"""Test that evaluate evaluates the final textcat component in a pipeline
|
||||||
with more than one textcat or textcat_multilabel."""
|
with more than one textcat or textcat_multilabel."""
|
||||||
text = "hello world"
|
nlp = Language(en_vocab)
|
||||||
|
textcat = nlp.add_pipe("textcat")
|
||||||
|
for label in ("POSITIVE", "NEGATIVE"):
|
||||||
|
textcat.add_label(label)
|
||||||
|
textcat_multilabel = nlp.add_pipe("textcat_multilabel")
|
||||||
|
for label in ("FEATURE", "REQUEST", "BUG", "QUESTION"):
|
||||||
|
textcat_multilabel.add_label(label)
|
||||||
|
nlp.initialize()
|
||||||
|
|
||||||
annots = {
|
annots = {
|
||||||
"cats": {
|
"cats": {
|
||||||
"POSITIVE": 1.0,
|
"POSITIVE": 1.0,
|
||||||
|
@ -178,11 +168,11 @@ def test_evaluate_multiple_textcat(nlp_tc_tcm):
|
||||||
"NEGATIVE": 0.0,
|
"NEGATIVE": 0.0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
doc = Doc(nlp_tc_tcm.vocab, words=text.split(" "))
|
doc = nlp.make_doc("hello world")
|
||||||
example = Example.from_dict(doc, annots)
|
example = Example.from_dict(doc, annots)
|
||||||
scores = nlp_tc_tcm.evaluate([example])
|
scores = nlp.evaluate([example])
|
||||||
# get the labels from the final pipe
|
# get the labels from the final pipe
|
||||||
labels = nlp_tc_tcm.get_pipe(nlp_tc_tcm.pipe_names[-1]).labels
|
labels = nlp.get_pipe(nlp.pipe_names[-1]).labels
|
||||||
for label in labels:
|
for label in labels:
|
||||||
assert scores["cats_f_per_type"].get(label) is not None
|
assert scores["cats_f_per_type"].get(label) is not None
|
||||||
for key in example.reference.cats.keys():
|
for key in example.reference.cats.keys():
|
||||||
|
|
Loading…
Reference in New Issue
Block a user