Tidy up and auto-format

Ines Montani 2020-10-03 17:20:18 +02:00
parent 7c4ab7e82c
commit 3bc3c05fcc
14 changed files with 40 additions and 26 deletions


@@ -171,7 +171,7 @@ def debug_data(
         n_missing_vectors = sum(gold_train_data["words_missing_vectors"].values())
         msg.warn(
             "{} words in training data without vectors ({:0.2f}%)".format(
-                n_missing_vectors, n_missing_vectors / gold_train_data["n_words"],
+                n_missing_vectors, n_missing_vectors / gold_train_data["n_words"]
             ),
         )
         msg.text(


@@ -8,7 +8,6 @@ from .stop_words import STOP_WORDS
 from .lex_attrs import LEX_ATTRS
 from .lemmatizer import PolishLemmatizer
 from ..tokenizer_exceptions import BASE_EXCEPTIONS
-from ...lookups import Lookups
 from ...language import Language


@@ -47,7 +47,7 @@ class Segmenter(str, Enum):
 @registry.tokenizers("spacy.zh.ChineseTokenizer")
-def create_chinese_tokenizer(segmenter: Segmenter = Segmenter.char,):
+def create_chinese_tokenizer(segmenter: Segmenter = Segmenter.char):
     def chinese_tokenizer_factory(nlp):
         return ChineseTokenizer(nlp, segmenter=segmenter)
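
Usage note (not part of the diff): the registered tokenizer above is normally selected through the tokenizer config block; a minimal sketch, with the config path assumed from the test fixtures further down in this commit.

from spacy.lang.zh import Chinese

# Sketch only: choose the word segmenter via the tokenizer config.
# "char" needs no extra dependencies; "jieba" and "pkuseg" require
# the corresponding packages to be installed.
nlp = Chinese.from_config({"nlp": {"tokenizer": {"segmenter": "char"}}})
doc = nlp("我喜欢音乐")
print([token.text for token in doc])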


@@ -165,8 +165,12 @@ def MultiHashEmbed(
 @registry.architectures.register("spacy.CharacterEmbed.v1")
 def CharacterEmbed(
-    width: int, rows: int, nM: int, nC: int, also_use_static_vectors: bool,
-    feature: Union[int, str]="LOWER"
+    width: int,
+    rows: int,
+    nM: int,
+    nC: int,
+    also_use_static_vectors: bool,
+    feature: Union[int, str] = "LOWER",
 ) -> Model[List[Doc], List[Floats2d]]:
     """Construct an embedded representation based on character embeddings, using
     a feed-forward network. A fixed number of UTF-8 byte characters are used for
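
Usage note (not part of the diff): a minimal sketch of building the layer directly with the parameters from the signature above; the import path and the example sizes are assumptions, not values from this commit.

from spacy.ml.models import CharacterEmbed  # import path assumed

# Sketch only: construct the embedding layer with illustrative sizes.
embed = CharacterEmbed(
    width=128,             # width of the output vectors
    rows=7000,             # rows in the hash-embed table for the extra feature
    nM=64,                 # dimension of the character embeddings
    nC=8,                  # characters taken from the start and end of each word
    also_use_static_vectors=False,
    feature="LOWER",       # lexical attribute embedded alongside the characters
)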


@@ -70,7 +70,7 @@ subword_features = true
     },
 )
 def make_textcat(
-    nlp: Language, name: str, model: Model[List[Doc], List[Floats2d]], threshold: float,
+    nlp: Language, name: str, model: Model[List[Doc], List[Floats2d]], threshold: float
 ) -> "TextCategorizer":
     """Create a TextCategorizer compoment. The text categorizer predicts categories
     over a whole document. It can learn one or more labels, and the labels can
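
Usage note (not part of the diff): the factory above is normally reached through nlp.add_pipe; a minimal sketch with an assumed threshold override and placeholder label names.

import spacy

# Sketch only: add the text categorizer via its factory name and register labels.
nlp = spacy.blank("en")
textcat = nlp.add_pipe("textcat", config={"threshold": 0.5})
textcat.add_label("POSITIVE")
textcat.add_label("NEGATIVE")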


@@ -294,7 +294,8 @@ def zh_tokenizer_pkuseg():
                 "segmenter": "pkuseg",
             }
         },
-        "initialize": {"tokenizer": {
+        "initialize": {
+            "tokenizer": {
                 "pkuseg_model": "default",
             }
         },


@@ -5,12 +5,14 @@ import pytest
 def i_has(en_tokenizer):
     doc = en_tokenizer("I has")
     doc[0].set_morph({"PronType": "prs"})
-    doc[1].set_morph({
-        "VerbForm": "fin",
-        "Tense": "pres",
-        "Number": "sing",
-        "Person": "three",
-    })
+    doc[1].set_morph(
+        {
+            "VerbForm": "fin",
+            "Tense": "pres",
+            "Number": "sing",
+            "Person": "three",
+        }
+    )
     return doc
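
Usage note (not part of the diff): after set_morph, the features can be read back from Token.morph; a small sketch mirroring the fixture above, using a blank English pipeline rather than the test's tokenizer fixture.

import spacy

# Sketch only: set morphological features from a dict and read them back.
nlp = spacy.blank("en")
doc = nlp("I has")
doc[1].set_morph({"VerbForm": "fin", "Tense": "pres", "Number": "sing", "Person": "three"})
print(doc[1].morph.get("Tense"))  # ["pres"]
print(str(doc[1].morph))          # feature=value pairs joined with "|"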


@@ -34,7 +34,8 @@ def test_zh_tokenizer_serialize_pkuseg_with_processors(zh_tokenizer_pkuseg):
                 "segmenter": "pkuseg",
             }
         },
-        "initialize": {"tokenizer": {
+        "initialize": {
+            "tokenizer": {
                 "pkuseg_model": "medicine",
             }
         },


@@ -139,7 +139,8 @@ def test_overfitting_IO():
     nlp = English()
     nlp.config["initialize"]["components"]["textcat"] = {"positive_label": "POSITIVE"}
     # Set exclusive labels
-    textcat = nlp.add_pipe("textcat", config={"model": {"exclusive_classes": True}},)
+    config = {"model": {"exclusive_classes": True}}
+    textcat = nlp.add_pipe("textcat", config=config)
     train_examples = []
     for text, annotations in TRAIN_DATA:
         train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
@@ -226,7 +227,9 @@ def test_positive_class_not_binary():
     textcat = nlp.add_pipe("textcat")
     get_examples = make_get_examples(nlp)
     with pytest.raises(ValueError):
-        textcat.initialize(get_examples, labels=["SOME", "THING", "POS"], positive_label="POS")
+        textcat.initialize(
+            get_examples, labels=["SOME", "THING", "POS"], positive_label="POS"
+        )


 def test_textcat_evaluation():


@@ -92,7 +92,13 @@ def test_serialize_doc_bin_unknown_spaces(en_vocab):
 @pytest.mark.parametrize(
-    "writer_flag,reader_flag,reader_value", [(True, True, "bar"), (True, False, "bar"), (False, True, "nothing"), (False, False, "nothing")]
+    "writer_flag,reader_flag,reader_value",
+    [
+        (True, True, "bar"),
+        (True, False, "bar"),
+        (False, True, "nothing"),
+        (False, False, "nothing"),
+    ],
 )
 def test_serialize_custom_extension(en_vocab, writer_flag, reader_flag, reader_value):
     """Test that custom extensions are correctly serialized in DocBin."""


@@ -158,7 +158,7 @@ def test_las_per_type(en_vocab):
     examples = []
     for input_, annot in test_las_apple:
         doc = Doc(
-            en_vocab, words=input_.split(" "), heads=annot["heads"], deps=annot["deps"],
+            en_vocab, words=input_.split(" "), heads=annot["heads"], deps=annot["deps"]
         )
         gold = {"heads": annot["heads"], "deps": annot["deps"]}
         doc[0].dep_ = "compound"
@@ -182,9 +182,7 @@ def test_ner_per_type(en_vocab):
     examples = []
     for input_, annot in test_ner_cardinal:
         doc = Doc(
-            en_vocab,
-            words=input_.split(" "),
-            ents=["B-CARDINAL", "O", "B-CARDINAL"],
+            en_vocab, words=input_.split(" "), ents=["B-CARDINAL", "O", "B-CARDINAL"]
         )
         entities = offsets_to_biluo_tags(doc, annot["entities"])
         example = Example.from_dict(doc, {"entities": entities})


@@ -30,7 +30,7 @@ class OrthVariants(BaseModel):
 @registry.augmenters("spacy.orth_variants.v1")
 def create_orth_variants_augmenter(
-    level: float, lower: float, orth_variants: OrthVariants,
+    level: float, lower: float, orth_variants: OrthVariants
 ) -> Callable[["Language", Example], Iterator[Example]]:
     """Create a data augmentation callback that uses orth-variant replacement.
     The callback can be added to a corpus or other data iterator during training.
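
Usage note (not part of the diff): the callback produced by the factory above can be attached to a training corpus; a sketch with an empty placeholder variant table and a stand-in data path, with the module path assumed.

from spacy.training import Corpus
from spacy.training.augment import create_orth_variants_augmenter  # module path assumed

# Sketch only: build the augmenter and attach it to a corpus reader.
# The variant table is an empty placeholder; "./train.spacy" is a stand-in path.
augmenter = create_orth_variants_augmenter(
    level=0.1, lower=0.5, orth_variants={"single": [], "paired": []}
)
corpus = Corpus("./train.spacy", augmenter=augmenter)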


@@ -21,8 +21,8 @@ def train(
     output_path: Optional[Path] = None,
     *,
     use_gpu: int = -1,
-    stdout: IO=sys.stdout,
-    stderr: IO=sys.stderr
+    stdout: IO = sys.stdout,
+    stderr: IO = sys.stderr,
 ) -> None:
     """Train a pipeline.
@@ -34,7 +34,7 @@ def train(
         printing, set to io.StringIO.
     stderr (file): A second file-like object to write output messages. To disable
         printing, set to io.StringIO.
     RETURNS (Path / None): The path to the final exported model.
     """
     # We use no_print here so we can respect the stdout/stderr options.


@@ -16,7 +16,7 @@ from .errors import Errors
 from .attrs import intify_attrs, NORM, IS_STOP
 from .vectors import Vectors
 from .util import registry
-from .lookups import Lookups, load_lookups
+from .lookups import Lookups
 from . import util
 from .lang.norm_exceptions import BASE_NORMS
 from .lang.lex_attrs import LEX_ATTRS, is_stop, get_lang