mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-25 17:36:30 +03:00
Tidy up and auto-format
This commit is contained in:
parent
f2c8b1e362
commit
00a8cbc306
|
@ -295,10 +295,11 @@ def debug_data(
|
|||
"Dev labels: {}.".format(
|
||||
_format_labels(gold_train_data["cats"]),
|
||||
_format_labels(gold_dev_data["cats"]),
|
||||
)
|
||||
)
|
||||
)
|
||||
if gold_train_data["n_cats_multilabel"] > 0:
|
||||
msg.info("The train data contains instances without "
|
||||
msg.info(
|
||||
"The train data contains instances without "
|
||||
"mutually-exclusive classes. Use '--textcat-multilabel' "
|
||||
"when training."
|
||||
)
|
||||
|
@ -481,7 +482,6 @@ def debug_data(
|
|||
)
|
||||
)
|
||||
|
||||
|
||||
msg.divider("Summary")
|
||||
good_counts = msg.counts[MESSAGES.GOOD]
|
||||
warn_counts = msg.counts[MESSAGES.WARN]
|
||||
|
|
|
@ -182,7 +182,7 @@ def train(
|
|||
base_cfg = {
|
||||
"exclusive_classes": textcat_cfg["exclusive_classes"],
|
||||
"architecture": textcat_cfg["architecture"],
|
||||
"positive_label": textcat_cfg["positive_label"]
|
||||
"positive_label": textcat_cfg["positive_label"],
|
||||
}
|
||||
pipe_cfg = {
|
||||
"exclusive_classes": not textcat_multilabel,
|
||||
|
@ -190,12 +190,13 @@ def train(
|
|||
"positive_label": textcat_positive_label,
|
||||
}
|
||||
if base_cfg != pipe_cfg:
|
||||
msg.fail("The base textcat model configuration does"
|
||||
msg.fail(
|
||||
"The base textcat model configuration does"
|
||||
"not match the provided training options. "
|
||||
"Existing cfg: {}, provided cfg: {}".format(
|
||||
base_cfg, pipe_cfg
|
||||
),
|
||||
exits=1
|
||||
exits=1,
|
||||
)
|
||||
else:
|
||||
msg.text("Starting with blank model '{}'".format(lang))
|
||||
|
@ -298,9 +299,10 @@ def train(
|
|||
break
|
||||
if base_model and set(textcat_labels) != train_labels:
|
||||
msg.fail(
|
||||
"Cannot extend textcat model using data with different "
|
||||
"labels. Base model labels: {}, training data labels: "
|
||||
"{}.".format(textcat_labels, list(train_labels)), exits=1
|
||||
"Cannot extend textcat model using data with different "
|
||||
"labels. Base model labels: {}, training data labels: "
|
||||
"{}.".format(textcat_labels, list(train_labels)),
|
||||
exits=1,
|
||||
)
|
||||
if textcat_multilabel:
|
||||
msg.text(
|
||||
|
|
|
@ -201,7 +201,9 @@ _ukrainian = r"а-щюяіїєґА-ЩЮЯІЇЄҐ"
|
|||
_upper = LATIN_UPPER + _russian_upper + _tatar_upper + _greek_upper + _ukrainian_upper
|
||||
_lower = LATIN_LOWER + _russian_lower + _tatar_lower + _greek_lower + _ukrainian_lower
|
||||
|
||||
_uncased = _bengali + _hebrew + _persian + _sinhala + _hindi + _kannada + _tamil + _telugu
|
||||
_uncased = (
|
||||
_bengali + _hebrew + _persian + _sinhala + _hindi + _kannada + _tamil + _telugu
|
||||
)
|
||||
|
||||
ALPHA = group_chars(LATIN + _russian + _tatar + _greek + _ukrainian + _uncased)
|
||||
ALPHA_LOWER = group_chars(_lower + _uncased)
|
||||
|
|
|
@ -32,7 +32,6 @@ from .lang.tokenizer_exceptions import TOKEN_MATCH
|
|||
from .lang.tag_map import TAG_MAP
|
||||
from .lang.lex_attrs import LEX_ATTRS, is_stop
|
||||
from .errors import Errors, Warnings, deprecation_warning
|
||||
from .strings import hash_string
|
||||
from . import util
|
||||
from . import about
|
||||
|
||||
|
|
|
@ -1,11 +1,9 @@
|
|||
# coding: utf-8
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import numpy as np
|
||||
from numpy.testing import assert_almost_equal, assert_array_almost_equal
|
||||
import pytest
|
||||
from pytest import approx
|
||||
from spacy.errors import Errors
|
||||
from spacy.gold import GoldParse
|
||||
from spacy.scorer import Scorer, ROCAUCScore
|
||||
from spacy.scorer import _roc_auc_score, _roc_curve
|
||||
|
@ -81,7 +79,7 @@ def test_roc_auc_score():
|
|||
roc_auc = _roc_auc_score(y_true, y_score)
|
||||
assert_array_almost_equal(tpr, [0, 0, 1])
|
||||
assert_array_almost_equal(fpr, [0, 1, 1])
|
||||
assert_almost_equal(roc_auc, 1.)
|
||||
assert_almost_equal(roc_auc, 1.0)
|
||||
|
||||
y_true = [0, 1]
|
||||
y_score = [1, 0]
|
||||
|
@ -89,7 +87,7 @@ def test_roc_auc_score():
|
|||
roc_auc = _roc_auc_score(y_true, y_score)
|
||||
assert_array_almost_equal(tpr, [0, 1, 1])
|
||||
assert_array_almost_equal(fpr, [0, 0, 1])
|
||||
assert_almost_equal(roc_auc, 0.)
|
||||
assert_almost_equal(roc_auc, 0.0)
|
||||
|
||||
y_true = [1, 0]
|
||||
y_score = [1, 1]
|
||||
|
@ -105,7 +103,7 @@ def test_roc_auc_score():
|
|||
roc_auc = _roc_auc_score(y_true, y_score)
|
||||
assert_array_almost_equal(tpr, [0, 0, 1])
|
||||
assert_array_almost_equal(fpr, [0, 1, 1])
|
||||
assert_almost_equal(roc_auc, 1.)
|
||||
assert_almost_equal(roc_auc, 1.0)
|
||||
|
||||
y_true = [1, 0]
|
||||
y_score = [0.5, 0.5]
|
||||
|
@ -113,14 +111,13 @@ def test_roc_auc_score():
|
|||
roc_auc = _roc_auc_score(y_true, y_score)
|
||||
assert_array_almost_equal(tpr, [0, 1])
|
||||
assert_array_almost_equal(fpr, [0, 1])
|
||||
assert_almost_equal(roc_auc, .5)
|
||||
assert_almost_equal(roc_auc, 0.5)
|
||||
|
||||
# same result as above with ROCAUCScore wrapper
|
||||
score = ROCAUCScore()
|
||||
score.score_set(0.5, 1)
|
||||
score.score_set(0.5, 0)
|
||||
assert_almost_equal(score.score, .5)
|
||||
|
||||
assert_almost_equal(score.score, 0.5)
|
||||
|
||||
# check that errors are raised in undefined cases and score is -inf
|
||||
y_true = [0, 0]
|
||||
|
|
Loading…
Reference in New Issue
Block a user