mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-26 01:46:28 +03:00
Tidy up and auto-format
This commit is contained in:
parent
f2c8b1e362
commit
00a8cbc306
|
@ -298,7 +298,8 @@ def debug_data(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if gold_train_data["n_cats_multilabel"] > 0:
|
if gold_train_data["n_cats_multilabel"] > 0:
|
||||||
msg.info("The train data contains instances without "
|
msg.info(
|
||||||
|
"The train data contains instances without "
|
||||||
"mutually-exclusive classes. Use '--textcat-multilabel' "
|
"mutually-exclusive classes. Use '--textcat-multilabel' "
|
||||||
"when training."
|
"when training."
|
||||||
)
|
)
|
||||||
|
@ -481,7 +482,6 @@ def debug_data(
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
msg.divider("Summary")
|
msg.divider("Summary")
|
||||||
good_counts = msg.counts[MESSAGES.GOOD]
|
good_counts = msg.counts[MESSAGES.GOOD]
|
||||||
warn_counts = msg.counts[MESSAGES.WARN]
|
warn_counts = msg.counts[MESSAGES.WARN]
|
||||||
|
|
|
@ -182,7 +182,7 @@ def train(
|
||||||
base_cfg = {
|
base_cfg = {
|
||||||
"exclusive_classes": textcat_cfg["exclusive_classes"],
|
"exclusive_classes": textcat_cfg["exclusive_classes"],
|
||||||
"architecture": textcat_cfg["architecture"],
|
"architecture": textcat_cfg["architecture"],
|
||||||
"positive_label": textcat_cfg["positive_label"]
|
"positive_label": textcat_cfg["positive_label"],
|
||||||
}
|
}
|
||||||
pipe_cfg = {
|
pipe_cfg = {
|
||||||
"exclusive_classes": not textcat_multilabel,
|
"exclusive_classes": not textcat_multilabel,
|
||||||
|
@ -190,12 +190,13 @@ def train(
|
||||||
"positive_label": textcat_positive_label,
|
"positive_label": textcat_positive_label,
|
||||||
}
|
}
|
||||||
if base_cfg != pipe_cfg:
|
if base_cfg != pipe_cfg:
|
||||||
msg.fail("The base textcat model configuration does"
|
msg.fail(
|
||||||
|
"The base textcat model configuration does"
|
||||||
"not match the provided training options. "
|
"not match the provided training options. "
|
||||||
"Existing cfg: {}, provided cfg: {}".format(
|
"Existing cfg: {}, provided cfg: {}".format(
|
||||||
base_cfg, pipe_cfg
|
base_cfg, pipe_cfg
|
||||||
),
|
),
|
||||||
exits=1
|
exits=1,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
msg.text("Starting with blank model '{}'".format(lang))
|
msg.text("Starting with blank model '{}'".format(lang))
|
||||||
|
@ -300,7 +301,8 @@ def train(
|
||||||
msg.fail(
|
msg.fail(
|
||||||
"Cannot extend textcat model using data with different "
|
"Cannot extend textcat model using data with different "
|
||||||
"labels. Base model labels: {}, training data labels: "
|
"labels. Base model labels: {}, training data labels: "
|
||||||
"{}.".format(textcat_labels, list(train_labels)), exits=1
|
"{}.".format(textcat_labels, list(train_labels)),
|
||||||
|
exits=1,
|
||||||
)
|
)
|
||||||
if textcat_multilabel:
|
if textcat_multilabel:
|
||||||
msg.text(
|
msg.text(
|
||||||
|
|
|
@ -201,7 +201,9 @@ _ukrainian = r"а-щюяіїєґА-ЩЮЯІЇЄҐ"
|
||||||
_upper = LATIN_UPPER + _russian_upper + _tatar_upper + _greek_upper + _ukrainian_upper
|
_upper = LATIN_UPPER + _russian_upper + _tatar_upper + _greek_upper + _ukrainian_upper
|
||||||
_lower = LATIN_LOWER + _russian_lower + _tatar_lower + _greek_lower + _ukrainian_lower
|
_lower = LATIN_LOWER + _russian_lower + _tatar_lower + _greek_lower + _ukrainian_lower
|
||||||
|
|
||||||
_uncased = _bengali + _hebrew + _persian + _sinhala + _hindi + _kannada + _tamil + _telugu
|
_uncased = (
|
||||||
|
_bengali + _hebrew + _persian + _sinhala + _hindi + _kannada + _tamil + _telugu
|
||||||
|
)
|
||||||
|
|
||||||
ALPHA = group_chars(LATIN + _russian + _tatar + _greek + _ukrainian + _uncased)
|
ALPHA = group_chars(LATIN + _russian + _tatar + _greek + _ukrainian + _uncased)
|
||||||
ALPHA_LOWER = group_chars(_lower + _uncased)
|
ALPHA_LOWER = group_chars(_lower + _uncased)
|
||||||
|
|
|
@ -32,7 +32,6 @@ from .lang.tokenizer_exceptions import TOKEN_MATCH
|
||||||
from .lang.tag_map import TAG_MAP
|
from .lang.tag_map import TAG_MAP
|
||||||
from .lang.lex_attrs import LEX_ATTRS, is_stop
|
from .lang.lex_attrs import LEX_ATTRS, is_stop
|
||||||
from .errors import Errors, Warnings, deprecation_warning
|
from .errors import Errors, Warnings, deprecation_warning
|
||||||
from .strings import hash_string
|
|
||||||
from . import util
|
from . import util
|
||||||
from . import about
|
from . import about
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,9 @@
|
||||||
# coding: utf-8
|
# coding: utf-8
|
||||||
from __future__ import unicode_literals
|
from __future__ import unicode_literals
|
||||||
|
|
||||||
import numpy as np
|
|
||||||
from numpy.testing import assert_almost_equal, assert_array_almost_equal
|
from numpy.testing import assert_almost_equal, assert_array_almost_equal
|
||||||
import pytest
|
import pytest
|
||||||
from pytest import approx
|
from pytest import approx
|
||||||
from spacy.errors import Errors
|
|
||||||
from spacy.gold import GoldParse
|
from spacy.gold import GoldParse
|
||||||
from spacy.scorer import Scorer, ROCAUCScore
|
from spacy.scorer import Scorer, ROCAUCScore
|
||||||
from spacy.scorer import _roc_auc_score, _roc_curve
|
from spacy.scorer import _roc_auc_score, _roc_curve
|
||||||
|
@ -81,7 +79,7 @@ def test_roc_auc_score():
|
||||||
roc_auc = _roc_auc_score(y_true, y_score)
|
roc_auc = _roc_auc_score(y_true, y_score)
|
||||||
assert_array_almost_equal(tpr, [0, 0, 1])
|
assert_array_almost_equal(tpr, [0, 0, 1])
|
||||||
assert_array_almost_equal(fpr, [0, 1, 1])
|
assert_array_almost_equal(fpr, [0, 1, 1])
|
||||||
assert_almost_equal(roc_auc, 1.)
|
assert_almost_equal(roc_auc, 1.0)
|
||||||
|
|
||||||
y_true = [0, 1]
|
y_true = [0, 1]
|
||||||
y_score = [1, 0]
|
y_score = [1, 0]
|
||||||
|
@ -89,7 +87,7 @@ def test_roc_auc_score():
|
||||||
roc_auc = _roc_auc_score(y_true, y_score)
|
roc_auc = _roc_auc_score(y_true, y_score)
|
||||||
assert_array_almost_equal(tpr, [0, 1, 1])
|
assert_array_almost_equal(tpr, [0, 1, 1])
|
||||||
assert_array_almost_equal(fpr, [0, 0, 1])
|
assert_array_almost_equal(fpr, [0, 0, 1])
|
||||||
assert_almost_equal(roc_auc, 0.)
|
assert_almost_equal(roc_auc, 0.0)
|
||||||
|
|
||||||
y_true = [1, 0]
|
y_true = [1, 0]
|
||||||
y_score = [1, 1]
|
y_score = [1, 1]
|
||||||
|
@ -105,7 +103,7 @@ def test_roc_auc_score():
|
||||||
roc_auc = _roc_auc_score(y_true, y_score)
|
roc_auc = _roc_auc_score(y_true, y_score)
|
||||||
assert_array_almost_equal(tpr, [0, 0, 1])
|
assert_array_almost_equal(tpr, [0, 0, 1])
|
||||||
assert_array_almost_equal(fpr, [0, 1, 1])
|
assert_array_almost_equal(fpr, [0, 1, 1])
|
||||||
assert_almost_equal(roc_auc, 1.)
|
assert_almost_equal(roc_auc, 1.0)
|
||||||
|
|
||||||
y_true = [1, 0]
|
y_true = [1, 0]
|
||||||
y_score = [0.5, 0.5]
|
y_score = [0.5, 0.5]
|
||||||
|
@ -113,14 +111,13 @@ def test_roc_auc_score():
|
||||||
roc_auc = _roc_auc_score(y_true, y_score)
|
roc_auc = _roc_auc_score(y_true, y_score)
|
||||||
assert_array_almost_equal(tpr, [0, 1])
|
assert_array_almost_equal(tpr, [0, 1])
|
||||||
assert_array_almost_equal(fpr, [0, 1])
|
assert_array_almost_equal(fpr, [0, 1])
|
||||||
assert_almost_equal(roc_auc, .5)
|
assert_almost_equal(roc_auc, 0.5)
|
||||||
|
|
||||||
# same result as above with ROCAUCScore wrapper
|
# same result as above with ROCAUCScore wrapper
|
||||||
score = ROCAUCScore()
|
score = ROCAUCScore()
|
||||||
score.score_set(0.5, 1)
|
score.score_set(0.5, 1)
|
||||||
score.score_set(0.5, 0)
|
score.score_set(0.5, 0)
|
||||||
assert_almost_equal(score.score, .5)
|
assert_almost_equal(score.score, 0.5)
|
||||||
|
|
||||||
|
|
||||||
# check that errors are raised in undefined cases and score is -inf
|
# check that errors are raised in undefined cases and score is -inf
|
||||||
y_true = [0, 0]
|
y_true = [0, 0]
|
||||||
|
|
Loading…
Reference in New Issue
Block a user