Tidy up and auto-format [ci skip]

This commit is contained in:
Ines Montani 2020-08-29 13:01:10 +02:00
parent bc0730be3f
commit 2bc31e15c9
6 changed files with 15 additions and 18 deletions

View File

@ -2,7 +2,6 @@ from typing import Optional
from pathlib import Path from pathlib import Path
from wasabi import msg from wasabi import msg
import subprocess import subprocess
import shutil
import re import re
from ... import about from ... import about

View File

@ -9,7 +9,7 @@ from wasabi import msg
@registry.loggers("spacy.ConsoleLogger.v1") @registry.loggers("spacy.ConsoleLogger.v1")
def console_logger(): def console_logger():
def setup_printer( def setup_printer(
nlp: "Language" nlp: "Language",
) -> Tuple[Callable[[Dict[str, Any]], None], Callable]: ) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
score_cols = list(nlp.config["training"]["score_weights"]) score_cols = list(nlp.config["training"]["score_weights"])
score_widths = [max(len(col), 6) for col in score_cols] score_widths = [max(len(col), 6) for col in score_cols]
@ -73,7 +73,7 @@ def wandb_logger(project_name: str, remove_config_values: List[str] = []):
console = console_logger() console = console_logger()
def setup_logger( def setup_logger(
nlp: "Language" nlp: "Language",
) -> Tuple[Callable[[Dict[str, Any]], None], Callable]: ) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
config = nlp.config.interpolate() config = nlp.config.interpolate()
config_dot = util.dict_to_dot(config) config_dot = util.dict_to_dot(config)

View File

@ -242,6 +242,7 @@ class AttributeRuler(Pipe):
DOCS: https://spacy.io/api/attributeruler#from_bytes DOCS: https://spacy.io/api/attributeruler#from_bytes
""" """
def load_patterns(b): def load_patterns(b):
self.add_patterns(srsly.msgpack_loads(b)) self.add_patterns(srsly.msgpack_loads(b))
@ -275,6 +276,7 @@ class AttributeRuler(Pipe):
exclude (Iterable[str]): String names of serialization fields to exclude. exclude (Iterable[str]): String names of serialization fields to exclude.
DOCS: https://spacy.io/api/attributeruler#from_disk DOCS: https://spacy.io/api/attributeruler#from_disk
""" """
def load_patterns(p): def load_patterns(p):
self.add_patterns(srsly.read_msgpack(p)) self.add_patterns(srsly.read_msgpack(p))

View File

@ -2,7 +2,7 @@ from typing import Optional, Iterable, Callable, Dict, Iterator, Union, List, Tu
from pathlib import Path from pathlib import Path
import srsly import srsly
import random import random
from thinc.api import CosineDistance, get_array_module, Model, Optimizer, Config from thinc.api import CosineDistance, Model, Optimizer, Config
from thinc.api import set_dropout_rate from thinc.api import set_dropout_rate
import warnings import warnings

View File

@ -104,7 +104,11 @@ def test_attributeruler_score(nlp, pattern_dicts):
assert doc[3].lemma_ == "cat" assert doc[3].lemma_ == "cat"
assert doc[3].morph_ == "Case=Nom|Number=Sing" assert doc[3].morph_ == "Case=Nom|Number=Sing"
dev_examples = [Example.from_dict(nlp.make_doc("This is a test."), {"lemmas": ["this", "is", "a", "cat", "."]})] dev_examples = [
Example.from_dict(
nlp.make_doc("This is a test."), {"lemmas": ["this", "is", "a", "cat", "."]}
)
]
scores = nlp.evaluate(dev_examples) scores = nlp.evaluate(dev_examples)
# "cat" is the only correct lemma # "cat" is the only correct lemma
assert scores["lemma_acc"] == pytest.approx(0.2) assert scores["lemma_acc"] == pytest.approx(0.2)
@ -115,20 +119,14 @@ def test_attributeruler_score(nlp, pattern_dicts):
def test_attributeruler_rule_order(nlp): def test_attributeruler_rule_order(nlp):
a = AttributeRuler(nlp.vocab) a = AttributeRuler(nlp.vocab)
patterns = [ patterns = [
{ {"patterns": [[{"TAG": "VBZ"}]], "attrs": {"POS": "VERB"}},
"patterns": [[{"TAG": "VBZ"}]], {"patterns": [[{"TAG": "VBZ"}]], "attrs": {"POS": "NOUN"}},
"attrs": {"POS": "VERB"},
},
{
"patterns": [[{"TAG": "VBZ"}]],
"attrs": {"POS": "NOUN"},
},
] ]
a.add_patterns(patterns) a.add_patterns(patterns)
doc = get_doc( doc = get_doc(
nlp.vocab, nlp.vocab,
words=["This", "is", "a", "test", "."], words=["This", "is", "a", "test", "."],
tags=["DT", "VBZ", "DT", "NN", "."] tags=["DT", "VBZ", "DT", "NN", "."],
) )
doc = a(doc) doc = a(doc)
assert doc[1].pos_ == "NOUN" assert doc[1].pos_ == "NOUN"

View File

@ -373,8 +373,7 @@ def test_parse_config_overrides(args, expected):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"args", "args", [["--foo"], ["--x.foo", "bar", "--baz"]],
[["--foo"], ["--x.foo", "bar", "--baz"]],
) )
def test_parse_config_overrides_invalid(args): def test_parse_config_overrides_invalid(args):
with pytest.raises(NoSuchOption): with pytest.raises(NoSuchOption):
@ -382,8 +381,7 @@ def test_parse_config_overrides_invalid(args):
@pytest.mark.parametrize( @pytest.mark.parametrize(
"args", "args", [["--x.foo", "bar", "baz"], ["x.foo"]],
[["--x.foo", "bar", "baz"], ["x.foo"]],
) )
def test_parse_config_overrides_invalid_2(args): def test_parse_config_overrides_invalid_2(args):
with pytest.raises(SystemExit): with pytest.raises(SystemExit):