Tidy up and auto-format

This commit is contained in:
Ines Montani 2019-09-18 20:27:03 +02:00
parent f2c8b1e362
commit 00a8cbc306
5 changed files with 20 additions and 20 deletions

View File

@ -295,10 +295,11 @@ def debug_data(
"Dev labels: {}.".format(
_format_labels(gold_train_data["cats"]),
_format_labels(gold_dev_data["cats"]),
)
)
)
if gold_train_data["n_cats_multilabel"] > 0:
msg.info("The train data contains instances without "
msg.info(
"The train data contains instances without "
"mutually-exclusive classes. Use '--textcat-multilabel' "
"when training."
)
@ -481,7 +482,6 @@ def debug_data(
)
)
msg.divider("Summary")
good_counts = msg.counts[MESSAGES.GOOD]
warn_counts = msg.counts[MESSAGES.WARN]

View File

@ -182,7 +182,7 @@ def train(
base_cfg = {
"exclusive_classes": textcat_cfg["exclusive_classes"],
"architecture": textcat_cfg["architecture"],
"positive_label": textcat_cfg["positive_label"]
"positive_label": textcat_cfg["positive_label"],
}
pipe_cfg = {
"exclusive_classes": not textcat_multilabel,
@ -190,12 +190,13 @@ def train(
"positive_label": textcat_positive_label,
}
if base_cfg != pipe_cfg:
msg.fail("The base textcat model configuration does"
msg.fail(
"The base textcat model configuration does"
"not match the provided training options. "
"Existing cfg: {}, provided cfg: {}".format(
base_cfg, pipe_cfg
),
exits=1
exits=1,
)
else:
msg.text("Starting with blank model '{}'".format(lang))
@ -298,9 +299,10 @@ def train(
break
if base_model and set(textcat_labels) != train_labels:
msg.fail(
"Cannot extend textcat model using data with different "
"labels. Base model labels: {}, training data labels: "
"{}.".format(textcat_labels, list(train_labels)), exits=1
"Cannot extend textcat model using data with different "
"labels. Base model labels: {}, training data labels: "
"{}.".format(textcat_labels, list(train_labels)),
exits=1,
)
if textcat_multilabel:
msg.text(

View File

@ -201,7 +201,9 @@ _ukrainian = r"а-щюяіїєґА-ЩЮЯІЇЄҐ"
_upper = LATIN_UPPER + _russian_upper + _tatar_upper + _greek_upper + _ukrainian_upper
_lower = LATIN_LOWER + _russian_lower + _tatar_lower + _greek_lower + _ukrainian_lower
_uncased = _bengali + _hebrew + _persian + _sinhala + _hindi + _kannada + _tamil + _telugu
_uncased = (
_bengali + _hebrew + _persian + _sinhala + _hindi + _kannada + _tamil + _telugu
)
ALPHA = group_chars(LATIN + _russian + _tatar + _greek + _ukrainian + _uncased)
ALPHA_LOWER = group_chars(_lower + _uncased)

View File

@ -32,7 +32,6 @@ from .lang.tokenizer_exceptions import TOKEN_MATCH
from .lang.tag_map import TAG_MAP
from .lang.lex_attrs import LEX_ATTRS, is_stop
from .errors import Errors, Warnings, deprecation_warning
from .strings import hash_string
from . import util
from . import about

View File

@ -1,11 +1,9 @@
# coding: utf-8
from __future__ import unicode_literals
import numpy as np
from numpy.testing import assert_almost_equal, assert_array_almost_equal
import pytest
from pytest import approx
from spacy.errors import Errors
from spacy.gold import GoldParse
from spacy.scorer import Scorer, ROCAUCScore
from spacy.scorer import _roc_auc_score, _roc_curve
@ -81,7 +79,7 @@ def test_roc_auc_score():
roc_auc = _roc_auc_score(y_true, y_score)
assert_array_almost_equal(tpr, [0, 0, 1])
assert_array_almost_equal(fpr, [0, 1, 1])
assert_almost_equal(roc_auc, 1.)
assert_almost_equal(roc_auc, 1.0)
y_true = [0, 1]
y_score = [1, 0]
@ -89,7 +87,7 @@ def test_roc_auc_score():
roc_auc = _roc_auc_score(y_true, y_score)
assert_array_almost_equal(tpr, [0, 1, 1])
assert_array_almost_equal(fpr, [0, 0, 1])
assert_almost_equal(roc_auc, 0.)
assert_almost_equal(roc_auc, 0.0)
y_true = [1, 0]
y_score = [1, 1]
@ -105,7 +103,7 @@ def test_roc_auc_score():
roc_auc = _roc_auc_score(y_true, y_score)
assert_array_almost_equal(tpr, [0, 0, 1])
assert_array_almost_equal(fpr, [0, 1, 1])
assert_almost_equal(roc_auc, 1.)
assert_almost_equal(roc_auc, 1.0)
y_true = [1, 0]
y_score = [0.5, 0.5]
@ -113,14 +111,13 @@ def test_roc_auc_score():
roc_auc = _roc_auc_score(y_true, y_score)
assert_array_almost_equal(tpr, [0, 1])
assert_array_almost_equal(fpr, [0, 1])
assert_almost_equal(roc_auc, .5)
assert_almost_equal(roc_auc, 0.5)
# same result as above with ROCAUCScore wrapper
score = ROCAUCScore()
score.score_set(0.5, 1)
score.score_set(0.5, 0)
assert_almost_equal(score.score, .5)
assert_almost_equal(score.score, 0.5)
# check that errors are raised in undefined cases and score is -inf
y_true = [0, 0]