Tidy up and auto-format

This commit is contained in:
Ines Montani 2019-09-18 20:27:03 +02:00
parent f2c8b1e362
commit 00a8cbc306
5 changed files with 20 additions and 20 deletions

View File

@ -293,12 +293,13 @@ def debug_data(
"The train and dev labels are not the same. " "The train and dev labels are not the same. "
"Train labels: {}. " "Train labels: {}. "
"Dev labels: {}.".format( "Dev labels: {}.".format(
_format_labels(gold_train_data["cats"]), _format_labels(gold_train_data["cats"]),
_format_labels(gold_dev_data["cats"]), _format_labels(gold_dev_data["cats"]),
) )
) )
if gold_train_data["n_cats_multilabel"] > 0: if gold_train_data["n_cats_multilabel"] > 0:
msg.info("The train data contains instances without " msg.info(
"The train data contains instances without "
"mutually-exclusive classes. Use '--textcat-multilabel' " "mutually-exclusive classes. Use '--textcat-multilabel' "
"when training." "when training."
) )
@ -481,7 +482,6 @@ def debug_data(
) )
) )
msg.divider("Summary") msg.divider("Summary")
good_counts = msg.counts[MESSAGES.GOOD] good_counts = msg.counts[MESSAGES.GOOD]
warn_counts = msg.counts[MESSAGES.WARN] warn_counts = msg.counts[MESSAGES.WARN]

View File

@ -182,7 +182,7 @@ def train(
base_cfg = { base_cfg = {
"exclusive_classes": textcat_cfg["exclusive_classes"], "exclusive_classes": textcat_cfg["exclusive_classes"],
"architecture": textcat_cfg["architecture"], "architecture": textcat_cfg["architecture"],
"positive_label": textcat_cfg["positive_label"] "positive_label": textcat_cfg["positive_label"],
} }
pipe_cfg = { pipe_cfg = {
"exclusive_classes": not textcat_multilabel, "exclusive_classes": not textcat_multilabel,
@ -190,12 +190,13 @@ def train(
"positive_label": textcat_positive_label, "positive_label": textcat_positive_label,
} }
if base_cfg != pipe_cfg: if base_cfg != pipe_cfg:
msg.fail("The base textcat model configuration does" msg.fail(
"The base textcat model configuration does"
"not match the provided training options. " "not match the provided training options. "
"Existing cfg: {}, provided cfg: {}".format( "Existing cfg: {}, provided cfg: {}".format(
base_cfg, pipe_cfg base_cfg, pipe_cfg
), ),
exits=1 exits=1,
) )
else: else:
msg.text("Starting with blank model '{}'".format(lang)) msg.text("Starting with blank model '{}'".format(lang))
@ -298,9 +299,10 @@ def train(
break break
if base_model and set(textcat_labels) != train_labels: if base_model and set(textcat_labels) != train_labels:
msg.fail( msg.fail(
"Cannot extend textcat model using data with different " "Cannot extend textcat model using data with different "
"labels. Base model labels: {}, training data labels: " "labels. Base model labels: {}, training data labels: "
"{}.".format(textcat_labels, list(train_labels)), exits=1 "{}.".format(textcat_labels, list(train_labels)),
exits=1,
) )
if textcat_multilabel: if textcat_multilabel:
msg.text( msg.text(

View File

@ -201,7 +201,9 @@ _ukrainian = r"а-щюяіїєґА-ЩЮЯІЇЄҐ"
_upper = LATIN_UPPER + _russian_upper + _tatar_upper + _greek_upper + _ukrainian_upper _upper = LATIN_UPPER + _russian_upper + _tatar_upper + _greek_upper + _ukrainian_upper
_lower = LATIN_LOWER + _russian_lower + _tatar_lower + _greek_lower + _ukrainian_lower _lower = LATIN_LOWER + _russian_lower + _tatar_lower + _greek_lower + _ukrainian_lower
_uncased = _bengali + _hebrew + _persian + _sinhala + _hindi + _kannada + _tamil + _telugu _uncased = (
_bengali + _hebrew + _persian + _sinhala + _hindi + _kannada + _tamil + _telugu
)
ALPHA = group_chars(LATIN + _russian + _tatar + _greek + _ukrainian + _uncased) ALPHA = group_chars(LATIN + _russian + _tatar + _greek + _ukrainian + _uncased)
ALPHA_LOWER = group_chars(_lower + _uncased) ALPHA_LOWER = group_chars(_lower + _uncased)

View File

@ -32,7 +32,6 @@ from .lang.tokenizer_exceptions import TOKEN_MATCH
from .lang.tag_map import TAG_MAP from .lang.tag_map import TAG_MAP
from .lang.lex_attrs import LEX_ATTRS, is_stop from .lang.lex_attrs import LEX_ATTRS, is_stop
from .errors import Errors, Warnings, deprecation_warning from .errors import Errors, Warnings, deprecation_warning
from .strings import hash_string
from . import util from . import util
from . import about from . import about

View File

@ -1,11 +1,9 @@
# coding: utf-8 # coding: utf-8
from __future__ import unicode_literals from __future__ import unicode_literals
import numpy as np
from numpy.testing import assert_almost_equal, assert_array_almost_equal from numpy.testing import assert_almost_equal, assert_array_almost_equal
import pytest import pytest
from pytest import approx from pytest import approx
from spacy.errors import Errors
from spacy.gold import GoldParse from spacy.gold import GoldParse
from spacy.scorer import Scorer, ROCAUCScore from spacy.scorer import Scorer, ROCAUCScore
from spacy.scorer import _roc_auc_score, _roc_curve from spacy.scorer import _roc_auc_score, _roc_curve
@ -81,7 +79,7 @@ def test_roc_auc_score():
roc_auc = _roc_auc_score(y_true, y_score) roc_auc = _roc_auc_score(y_true, y_score)
assert_array_almost_equal(tpr, [0, 0, 1]) assert_array_almost_equal(tpr, [0, 0, 1])
assert_array_almost_equal(fpr, [0, 1, 1]) assert_array_almost_equal(fpr, [0, 1, 1])
assert_almost_equal(roc_auc, 1.) assert_almost_equal(roc_auc, 1.0)
y_true = [0, 1] y_true = [0, 1]
y_score = [1, 0] y_score = [1, 0]
@ -89,7 +87,7 @@ def test_roc_auc_score():
roc_auc = _roc_auc_score(y_true, y_score) roc_auc = _roc_auc_score(y_true, y_score)
assert_array_almost_equal(tpr, [0, 1, 1]) assert_array_almost_equal(tpr, [0, 1, 1])
assert_array_almost_equal(fpr, [0, 0, 1]) assert_array_almost_equal(fpr, [0, 0, 1])
assert_almost_equal(roc_auc, 0.) assert_almost_equal(roc_auc, 0.0)
y_true = [1, 0] y_true = [1, 0]
y_score = [1, 1] y_score = [1, 1]
@ -105,7 +103,7 @@ def test_roc_auc_score():
roc_auc = _roc_auc_score(y_true, y_score) roc_auc = _roc_auc_score(y_true, y_score)
assert_array_almost_equal(tpr, [0, 0, 1]) assert_array_almost_equal(tpr, [0, 0, 1])
assert_array_almost_equal(fpr, [0, 1, 1]) assert_array_almost_equal(fpr, [0, 1, 1])
assert_almost_equal(roc_auc, 1.) assert_almost_equal(roc_auc, 1.0)
y_true = [1, 0] y_true = [1, 0]
y_score = [0.5, 0.5] y_score = [0.5, 0.5]
@ -113,14 +111,13 @@ def test_roc_auc_score():
roc_auc = _roc_auc_score(y_true, y_score) roc_auc = _roc_auc_score(y_true, y_score)
assert_array_almost_equal(tpr, [0, 1]) assert_array_almost_equal(tpr, [0, 1])
assert_array_almost_equal(fpr, [0, 1]) assert_array_almost_equal(fpr, [0, 1])
assert_almost_equal(roc_auc, .5) assert_almost_equal(roc_auc, 0.5)
# same result as above with ROCAUCScore wrapper # same result as above with ROCAUCScore wrapper
score = ROCAUCScore() score = ROCAUCScore()
score.score_set(0.5, 1) score.score_set(0.5, 1)
score.score_set(0.5, 0) score.score_set(0.5, 0)
assert_almost_equal(score.score, .5) assert_almost_equal(score.score, 0.5)
# check that errors are raised in undefined cases and score is -inf # check that errors are raised in undefined cases and score is -inf
y_true = [0, 0] y_true = [0, 0]