Tidy up and auto-format [ci skip]

This commit is contained in:
Ines Montani 2020-08-29 13:01:10 +02:00
parent bc0730be3f
commit 2bc31e15c9
6 changed files with 15 additions and 18 deletions

View File

@ -2,7 +2,6 @@ from typing import Optional
from pathlib import Path
from wasabi import msg
import subprocess
import shutil
import re
from ... import about

View File

@ -9,7 +9,7 @@ from wasabi import msg
@registry.loggers("spacy.ConsoleLogger.v1")
def console_logger():
def setup_printer(
nlp: "Language"
nlp: "Language",
) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
score_cols = list(nlp.config["training"]["score_weights"])
score_widths = [max(len(col), 6) for col in score_cols]
@ -73,7 +73,7 @@ def wandb_logger(project_name: str, remove_config_values: List[str] = []):
console = console_logger()
def setup_logger(
nlp: "Language"
nlp: "Language",
) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
config = nlp.config.interpolate()
config_dot = util.dict_to_dot(config)

View File

@ -242,6 +242,7 @@ class AttributeRuler(Pipe):
DOCS: https://spacy.io/api/attributeruler#from_bytes
"""
def load_patterns(b):
self.add_patterns(srsly.msgpack_loads(b))
@ -275,6 +276,7 @@ class AttributeRuler(Pipe):
exclude (Iterable[str]): String names of serialization fields to exclude.
DOCS: https://spacy.io/api/attributeruler#from_disk
"""
def load_patterns(p):
self.add_patterns(srsly.read_msgpack(p))

View File

@ -2,7 +2,7 @@ from typing import Optional, Iterable, Callable, Dict, Iterator, Union, List, Tu
from pathlib import Path
import srsly
import random
from thinc.api import CosineDistance, get_array_module, Model, Optimizer, Config
from thinc.api import CosineDistance, Model, Optimizer, Config
from thinc.api import set_dropout_rate
import warnings

View File

@ -104,7 +104,11 @@ def test_attributeruler_score(nlp, pattern_dicts):
assert doc[3].lemma_ == "cat"
assert doc[3].morph_ == "Case=Nom|Number=Sing"
dev_examples = [Example.from_dict(nlp.make_doc("This is a test."), {"lemmas": ["this", "is", "a", "cat", "."]})]
dev_examples = [
Example.from_dict(
nlp.make_doc("This is a test."), {"lemmas": ["this", "is", "a", "cat", "."]}
)
]
scores = nlp.evaluate(dev_examples)
# "cat" is the only correct lemma
assert scores["lemma_acc"] == pytest.approx(0.2)
@ -115,20 +119,14 @@ def test_attributeruler_score(nlp, pattern_dicts):
def test_attributeruler_rule_order(nlp):
a = AttributeRuler(nlp.vocab)
patterns = [
{
"patterns": [[{"TAG": "VBZ"}]],
"attrs": {"POS": "VERB"},
},
{
"patterns": [[{"TAG": "VBZ"}]],
"attrs": {"POS": "NOUN"},
},
{"patterns": [[{"TAG": "VBZ"}]], "attrs": {"POS": "VERB"}},
{"patterns": [[{"TAG": "VBZ"}]], "attrs": {"POS": "NOUN"}},
]
a.add_patterns(patterns)
doc = get_doc(
nlp.vocab,
words=["This", "is", "a", "test", "."],
tags=["DT", "VBZ", "DT", "NN", "."]
tags=["DT", "VBZ", "DT", "NN", "."],
)
doc = a(doc)
assert doc[1].pos_ == "NOUN"

View File

@ -373,8 +373,7 @@ def test_parse_config_overrides(args, expected):
@pytest.mark.parametrize(
"args",
[["--foo"], ["--x.foo", "bar", "--baz"]],
"args", [["--foo"], ["--x.foo", "bar", "--baz"]],
)
def test_parse_config_overrides_invalid(args):
with pytest.raises(NoSuchOption):
@ -382,8 +381,7 @@ def test_parse_config_overrides_invalid(args):
@pytest.mark.parametrize(
"args",
[["--x.foo", "bar", "baz"], ["x.foo"]],
"args", [["--x.foo", "bar", "baz"], ["x.foo"]],
)
def test_parse_config_overrides_invalid_2(args):
with pytest.raises(SystemExit):