Tidy up and auto-format

This commit is contained in:
Ines Montani 2020-10-03 17:20:18 +02:00
parent 7c4ab7e82c
commit 3bc3c05fcc
14 changed files with 40 additions and 26 deletions

View File

@ -171,7 +171,7 @@ def debug_data(
n_missing_vectors = sum(gold_train_data["words_missing_vectors"].values())
msg.warn(
"{} words in training data without vectors ({:0.2f}%)".format(
n_missing_vectors, n_missing_vectors / gold_train_data["n_words"],
n_missing_vectors, n_missing_vectors / gold_train_data["n_words"]
),
)
msg.text(

View File

@ -8,7 +8,6 @@ from .stop_words import STOP_WORDS
from .lex_attrs import LEX_ATTRS
from .lemmatizer import PolishLemmatizer
from ..tokenizer_exceptions import BASE_EXCEPTIONS
from ...lookups import Lookups
from ...language import Language

View File

@ -47,7 +47,7 @@ class Segmenter(str, Enum):
@registry.tokenizers("spacy.zh.ChineseTokenizer")
def create_chinese_tokenizer(segmenter: Segmenter = Segmenter.char,):
def create_chinese_tokenizer(segmenter: Segmenter = Segmenter.char):
def chinese_tokenizer_factory(nlp):
return ChineseTokenizer(nlp, segmenter=segmenter)

View File

@ -165,8 +165,12 @@ def MultiHashEmbed(
@registry.architectures.register("spacy.CharacterEmbed.v1")
def CharacterEmbed(
width: int, rows: int, nM: int, nC: int, also_use_static_vectors: bool,
feature: Union[int, str]="LOWER"
width: int,
rows: int,
nM: int,
nC: int,
also_use_static_vectors: bool,
feature: Union[int, str] = "LOWER",
) -> Model[List[Doc], List[Floats2d]]:
"""Construct an embedded representation based on character embeddings, using
a feed-forward network. A fixed number of UTF-8 byte characters are used for

View File

@ -70,7 +70,7 @@ subword_features = true
},
)
def make_textcat(
nlp: Language, name: str, model: Model[List[Doc], List[Floats2d]], threshold: float,
nlp: Language, name: str, model: Model[List[Doc], List[Floats2d]], threshold: float
) -> "TextCategorizer":
"""Create a TextCategorizer compoment. The text categorizer predicts categories
over a whole document. It can learn one or more labels, and the labels can

View File

@ -294,7 +294,8 @@ def zh_tokenizer_pkuseg():
"segmenter": "pkuseg",
}
},
"initialize": {"tokenizer": {
"initialize": {
"tokenizer": {
"pkuseg_model": "default",
}
},

View File

@ -5,12 +5,14 @@ import pytest
def i_has(en_tokenizer):
doc = en_tokenizer("I has")
doc[0].set_morph({"PronType": "prs"})
doc[1].set_morph({
"VerbForm": "fin",
"Tense": "pres",
"Number": "sing",
"Person": "three",
})
doc[1].set_morph(
{
"VerbForm": "fin",
"Tense": "pres",
"Number": "sing",
"Person": "three",
}
)
return doc

View File

@ -34,7 +34,8 @@ def test_zh_tokenizer_serialize_pkuseg_with_processors(zh_tokenizer_pkuseg):
"segmenter": "pkuseg",
}
},
"initialize": {"tokenizer": {
"initialize": {
"tokenizer": {
"pkuseg_model": "medicine",
}
},

View File

@ -139,7 +139,8 @@ def test_overfitting_IO():
nlp = English()
nlp.config["initialize"]["components"]["textcat"] = {"positive_label": "POSITIVE"}
# Set exclusive labels
textcat = nlp.add_pipe("textcat", config={"model": {"exclusive_classes": True}},)
config = {"model": {"exclusive_classes": True}}
textcat = nlp.add_pipe("textcat", config=config)
train_examples = []
for text, annotations in TRAIN_DATA:
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
@ -226,7 +227,9 @@ def test_positive_class_not_binary():
textcat = nlp.add_pipe("textcat")
get_examples = make_get_examples(nlp)
with pytest.raises(ValueError):
textcat.initialize(get_examples, labels=["SOME", "THING", "POS"], positive_label="POS")
textcat.initialize(
get_examples, labels=["SOME", "THING", "POS"], positive_label="POS"
)
def test_textcat_evaluation():

View File

@ -92,7 +92,13 @@ def test_serialize_doc_bin_unknown_spaces(en_vocab):
@pytest.mark.parametrize(
"writer_flag,reader_flag,reader_value", [(True, True, "bar"), (True, False, "bar"), (False, True, "nothing"), (False, False, "nothing")]
"writer_flag,reader_flag,reader_value",
[
(True, True, "bar"),
(True, False, "bar"),
(False, True, "nothing"),
(False, False, "nothing"),
],
)
def test_serialize_custom_extension(en_vocab, writer_flag, reader_flag, reader_value):
"""Test that custom extensions are correctly serialized in DocBin."""

View File

@ -158,7 +158,7 @@ def test_las_per_type(en_vocab):
examples = []
for input_, annot in test_las_apple:
doc = Doc(
en_vocab, words=input_.split(" "), heads=annot["heads"], deps=annot["deps"],
en_vocab, words=input_.split(" "), heads=annot["heads"], deps=annot["deps"]
)
gold = {"heads": annot["heads"], "deps": annot["deps"]}
doc[0].dep_ = "compound"
@ -182,9 +182,7 @@ def test_ner_per_type(en_vocab):
examples = []
for input_, annot in test_ner_cardinal:
doc = Doc(
en_vocab,
words=input_.split(" "),
ents=["B-CARDINAL", "O", "B-CARDINAL"],
en_vocab, words=input_.split(" "), ents=["B-CARDINAL", "O", "B-CARDINAL"]
)
entities = offsets_to_biluo_tags(doc, annot["entities"])
example = Example.from_dict(doc, {"entities": entities})

View File

@ -30,7 +30,7 @@ class OrthVariants(BaseModel):
@registry.augmenters("spacy.orth_variants.v1")
def create_orth_variants_augmenter(
level: float, lower: float, orth_variants: OrthVariants,
level: float, lower: float, orth_variants: OrthVariants
) -> Callable[["Language", Example], Iterator[Example]]:
"""Create a data augmentation callback that uses orth-variant replacement.
The callback can be added to a corpus or other data iterator during training.

View File

@ -21,8 +21,8 @@ def train(
output_path: Optional[Path] = None,
*,
use_gpu: int = -1,
stdout: IO=sys.stdout,
stderr: IO=sys.stderr
stdout: IO = sys.stdout,
stderr: IO = sys.stderr,
) -> None:
"""Train a pipeline.
@ -34,7 +34,7 @@ def train(
printing, set to io.StringIO.
stderr (file): A second file-like object to write output messages. To disable
printing, set to io.StringIO.
RETURNS (Path / None): The path to the final exported model.
"""
# We use no_print here so we can respect the stdout/stderr options.

View File

@ -16,7 +16,7 @@ from .errors import Errors
from .attrs import intify_attrs, NORM, IS_STOP
from .vectors import Vectors
from .util import registry
from .lookups import Lookups, load_lookups
from .lookups import Lookups
from . import util
from .lang.norm_exceptions import BASE_NORMS
from .lang.lex_attrs import LEX_ATTRS, is_stop, get_lang