mirror of
https://github.com/explosion/spaCy.git
synced 2024-12-24 17:06:29 +03:00
Tidy up and auto-format [ci skip]
This commit is contained in:
parent
bc0730be3f
commit
2bc31e15c9
|
@ -2,7 +2,6 @@ from typing import Optional
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from wasabi import msg
|
from wasabi import msg
|
||||||
import subprocess
|
import subprocess
|
||||||
import shutil
|
|
||||||
import re
|
import re
|
||||||
|
|
||||||
from ... import about
|
from ... import about
|
||||||
|
|
|
@ -9,7 +9,7 @@ from wasabi import msg
|
||||||
@registry.loggers("spacy.ConsoleLogger.v1")
|
@registry.loggers("spacy.ConsoleLogger.v1")
|
||||||
def console_logger():
|
def console_logger():
|
||||||
def setup_printer(
|
def setup_printer(
|
||||||
nlp: "Language"
|
nlp: "Language",
|
||||||
) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
|
) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
|
||||||
score_cols = list(nlp.config["training"]["score_weights"])
|
score_cols = list(nlp.config["training"]["score_weights"])
|
||||||
score_widths = [max(len(col), 6) for col in score_cols]
|
score_widths = [max(len(col), 6) for col in score_cols]
|
||||||
|
@ -73,7 +73,7 @@ def wandb_logger(project_name: str, remove_config_values: List[str] = []):
|
||||||
console = console_logger()
|
console = console_logger()
|
||||||
|
|
||||||
def setup_logger(
|
def setup_logger(
|
||||||
nlp: "Language"
|
nlp: "Language",
|
||||||
) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
|
) -> Tuple[Callable[[Dict[str, Any]], None], Callable]:
|
||||||
config = nlp.config.interpolate()
|
config = nlp.config.interpolate()
|
||||||
config_dot = util.dict_to_dot(config)
|
config_dot = util.dict_to_dot(config)
|
||||||
|
|
|
@ -242,6 +242,7 @@ class AttributeRuler(Pipe):
|
||||||
|
|
||||||
DOCS: https://spacy.io/api/attributeruler#from_bytes
|
DOCS: https://spacy.io/api/attributeruler#from_bytes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def load_patterns(b):
|
def load_patterns(b):
|
||||||
self.add_patterns(srsly.msgpack_loads(b))
|
self.add_patterns(srsly.msgpack_loads(b))
|
||||||
|
|
||||||
|
@ -275,6 +276,7 @@ class AttributeRuler(Pipe):
|
||||||
exclude (Iterable[str]): String names of serialization fields to exclude.
|
exclude (Iterable[str]): String names of serialization fields to exclude.
|
||||||
DOCS: https://spacy.io/api/attributeruler#from_disk
|
DOCS: https://spacy.io/api/attributeruler#from_disk
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def load_patterns(p):
|
def load_patterns(p):
|
||||||
self.add_patterns(srsly.read_msgpack(p))
|
self.add_patterns(srsly.read_msgpack(p))
|
||||||
|
|
||||||
|
|
|
@ -2,7 +2,7 @@ from typing import Optional, Iterable, Callable, Dict, Iterator, Union, List, Tu
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
import srsly
|
import srsly
|
||||||
import random
|
import random
|
||||||
from thinc.api import CosineDistance, get_array_module, Model, Optimizer, Config
|
from thinc.api import CosineDistance, Model, Optimizer, Config
|
||||||
from thinc.api import set_dropout_rate
|
from thinc.api import set_dropout_rate
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
|
|
|
@ -104,7 +104,11 @@ def test_attributeruler_score(nlp, pattern_dicts):
|
||||||
assert doc[3].lemma_ == "cat"
|
assert doc[3].lemma_ == "cat"
|
||||||
assert doc[3].morph_ == "Case=Nom|Number=Sing"
|
assert doc[3].morph_ == "Case=Nom|Number=Sing"
|
||||||
|
|
||||||
dev_examples = [Example.from_dict(nlp.make_doc("This is a test."), {"lemmas": ["this", "is", "a", "cat", "."]})]
|
dev_examples = [
|
||||||
|
Example.from_dict(
|
||||||
|
nlp.make_doc("This is a test."), {"lemmas": ["this", "is", "a", "cat", "."]}
|
||||||
|
)
|
||||||
|
]
|
||||||
scores = nlp.evaluate(dev_examples)
|
scores = nlp.evaluate(dev_examples)
|
||||||
# "cat" is the only correct lemma
|
# "cat" is the only correct lemma
|
||||||
assert scores["lemma_acc"] == pytest.approx(0.2)
|
assert scores["lemma_acc"] == pytest.approx(0.2)
|
||||||
|
@ -115,20 +119,14 @@ def test_attributeruler_score(nlp, pattern_dicts):
|
||||||
def test_attributeruler_rule_order(nlp):
|
def test_attributeruler_rule_order(nlp):
|
||||||
a = AttributeRuler(nlp.vocab)
|
a = AttributeRuler(nlp.vocab)
|
||||||
patterns = [
|
patterns = [
|
||||||
{
|
{"patterns": [[{"TAG": "VBZ"}]], "attrs": {"POS": "VERB"}},
|
||||||
"patterns": [[{"TAG": "VBZ"}]],
|
{"patterns": [[{"TAG": "VBZ"}]], "attrs": {"POS": "NOUN"}},
|
||||||
"attrs": {"POS": "VERB"},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"patterns": [[{"TAG": "VBZ"}]],
|
|
||||||
"attrs": {"POS": "NOUN"},
|
|
||||||
},
|
|
||||||
]
|
]
|
||||||
a.add_patterns(patterns)
|
a.add_patterns(patterns)
|
||||||
doc = get_doc(
|
doc = get_doc(
|
||||||
nlp.vocab,
|
nlp.vocab,
|
||||||
words=["This", "is", "a", "test", "."],
|
words=["This", "is", "a", "test", "."],
|
||||||
tags=["DT", "VBZ", "DT", "NN", "."]
|
tags=["DT", "VBZ", "DT", "NN", "."],
|
||||||
)
|
)
|
||||||
doc = a(doc)
|
doc = a(doc)
|
||||||
assert doc[1].pos_ == "NOUN"
|
assert doc[1].pos_ == "NOUN"
|
||||||
|
|
|
@ -373,8 +373,7 @@ def test_parse_config_overrides(args, expected):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"args",
|
"args", [["--foo"], ["--x.foo", "bar", "--baz"]],
|
||||||
[["--foo"], ["--x.foo", "bar", "--baz"]],
|
|
||||||
)
|
)
|
||||||
def test_parse_config_overrides_invalid(args):
|
def test_parse_config_overrides_invalid(args):
|
||||||
with pytest.raises(NoSuchOption):
|
with pytest.raises(NoSuchOption):
|
||||||
|
@ -382,8 +381,7 @@ def test_parse_config_overrides_invalid(args):
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"args",
|
"args", [["--x.foo", "bar", "baz"], ["x.foo"]],
|
||||||
[["--x.foo", "bar", "baz"], ["x.foo"]],
|
|
||||||
)
|
)
|
||||||
def test_parse_config_overrides_invalid_2(args):
|
def test_parse_config_overrides_invalid_2(args):
|
||||||
with pytest.raises(SystemExit):
|
with pytest.raises(SystemExit):
|
||||||
|
|
Loading…
Reference in New Issue
Block a user