From b0b743597c078af3c4f67eebae26c5f40869e742 Mon Sep 17 00:00:00 2001
From: Ines Montani
Date: Fri, 15 Jan 2021 11:57:36 +1100
Subject: [PATCH 1/2] Tidy up and auto-format

---
 spacy/cli/debug_data.py                       |  5 +-
 spacy/cli/info.py                             | 12 +++-
 spacy/lang/am/lex_attrs.py                    | 67 +++++++++---------
 spacy/lang/am/tokenizer_exceptions.py         |  3 +-
 spacy/lang/es/syntax_iterators.py             |  2 +-
 spacy/lang/pt/stop_words.py                   |  4 +-
 spacy/lang/ti/lex_attrs.py                    | 69 ++++++++++---------
 spacy/lang/ti/tokenizer_exceptions.py         |  3 +-
 spacy/ml/models/textcat.py                    | 18 ++---
 spacy/ml/models/tok2vec.py                    |  9 ++-
 spacy/pipeline/textcat.py                     |  1 -
 spacy/pipeline/textcat_multilabel.py          |  2 +-
 spacy/tests/conftest.py                       |  4 ++
 spacy/tests/doc/test_graph.py                 | 15 +---
 spacy/tests/doc/test_token_api.py             |  4 +-
 spacy/tests/lang/am/test_text.py              |  7 +-
 spacy/tests/lang/en/test_exceptions.py        |  4 +-
 spacy/tests/lang/ti/test_text.py              |  7 +-
 spacy/tests/parser/test_ner.py                |  2 +-
 spacy/tests/pipeline/test_textcat.py          | 14 ++--
 spacy/tests/regression/test_issue4001-4500.py |  7 +-
 spacy/tests/test_language.py                  |  4 +-
 spacy/tests/training/test_new_example.py      | 21 ++----
 23 files changed, 140 insertions(+), 144 deletions(-)

diff --git a/spacy/cli/debug_data.py b/spacy/cli/debug_data.py
index 8eabf1f8f..c04647fde 100644
--- a/spacy/cli/debug_data.py
+++ b/spacy/cli/debug_data.py
@@ -384,7 +384,10 @@ def debug_data(
     # rare labels in projectivized train
     rare_projectivized_labels = []
     for label in gold_train_data["deps"]:
-        if gold_train_data["deps"][label] <= DEP_LABEL_THRESHOLD and DELIMITER in label:
+        if (
+            gold_train_data["deps"][label] <= DEP_LABEL_THRESHOLD
+            and DELIMITER in label
+        ):
             rare_projectivized_labels.append(
                 f"{label}: {gold_train_data['deps'][label]}"
             )
diff --git a/spacy/cli/info.py b/spacy/cli/info.py
index 19a380eb9..350e673ac 100644
--- a/spacy/cli/info.py
+++ b/spacy/cli/info.py
@@ -30,7 +30,11 @@ def info_cli(
 
 
 def info(
-    model: Optional[str] = None, *, markdown: bool = False, silent: bool = True, exclude: List[str]
+    model: Optional[str] = None,
+    *,
+    markdown: bool = False,
+    silent: bool = True,
+    exclude: List[str],
 ) -> Union[str, dict]:
     msg = Printer(no_print=silent, pretty=not silent)
     if model:
@@ -98,7 +102,9 @@ def info_model(model: str, *, silent: bool = True) -> Dict[str, Any]:
     }
 
 
-def get_markdown(data: Dict[str, Any], title: Optional[str] = None, exclude: List[str] = None) -> str:
+def get_markdown(
+    data: Dict[str, Any], title: Optional[str] = None, exclude: List[str] = None
+) -> str:
     """Get data in GitHub-flavoured Markdown format for issues etc.
 
     data (dict or list of tuples): Label/value pairs.
@@ -115,7 +121,7 @@ def get_markdown(data: Dict[str, Any], title: Optional[str] = None, exclude: Lis
         if isinstance(value, str):
             try:
                 existing_path = Path(value).exists()
-            except:
+            except Exception:  # invalid Path, like a URL string
                 existing_path = False
             if existing_path:
 
diff --git a/spacy/lang/am/lex_attrs.py b/spacy/lang/am/lex_attrs.py
index d53b227eb..9e111b8d5 100644
--- a/spacy/lang/am/lex_attrs.py
+++ b/spacy/lang/am/lex_attrs.py
@@ -36,43 +36,44 @@ _num_words = [
     "ትሪሊዮን",
     "ኳድሪሊዮን",
     "ገጅሊዮን",
-    "ባዝሊዮን"
+    "ባዝሊዮን",
 ]
 
 _ordinal_words = [
     "አንደኛ",
-    "ሁለተኛ",
-    "ሶስተኛ",
-    "አራተኛ",
-    "አምስተኛ",
-    "ስድስተኛ",
-    "ሰባተኛ",
-    "ስምንተኛ",
-    "ዘጠነኛ",
-    "አስረኛ",
-    "አስራ አንደኛ",
-    "አስራ ሁለተኛ",
-    "አስራ ሶስተኛ",
-    "አስራ አራተኛ",
-    "አስራ አምስተኛ",
-    "አስራ ስድስተኛ",
-    "አስራ ሰባተኛ",
-    "አስራ ስምንተኛ",
-    "አስራ ዘጠነኛ",
-    "ሃያኛ",
-    "ሰላሳኛ"
-    "አርባኛ",
-    "አምሳኛ",
-    "ስድሳኛ",
-    "ሰባኛ",
-    "ሰማንያኛ",
-    "ዘጠናኛ",
-    "መቶኛ",
-    "ሺኛ",
-    "ሚሊዮንኛ",
-    "ቢሊዮንኛ",
-    "ትሪሊዮንኛ"
+    "ሁለተኛ",
+    "ሶስተኛ",
+    "አራተኛ",
+    "አምስተኛ",
+    "ስድስተኛ",
+    "ሰባተኛ",
+    "ስምንተኛ",
+    "ዘጠነኛ",
+    "አስረኛ",
+    "አስራ አንደኛ",
+    "አስራ ሁለተኛ",
+    "አስራ ሶስተኛ",
+    "አስራ አራተኛ",
+    "አስራ አምስተኛ",
+    "አስራ ስድስተኛ",
+    "አስራ ሰባተኛ",
+    "አስራ ስምንተኛ",
+    "አስራ ዘጠነኛ",
+    "ሃያኛ",
+    "ሰላሳኛ", "አርባኛ",
+    "አምሳኛ",
+    "ስድሳኛ",
+    "ሰባኛ",
+    "ሰማንያኛ",
+    "ዘጠናኛ",
+    "መቶኛ",
+    "ሺኛ",
+    "ሚሊዮንኛ",
+    "ቢሊዮንኛ",
+    "ትሪሊዮንኛ",
 ]
+
+
 def like_num(text):
     if text.startswith(("+", "-", "±", "~")):
         text = text[1:]
@@ -93,7 +94,7 @@ def like_num(text):
         return True
     if text_lower.endswith("ኛ"):
         if text_lower[:-2].isdigit():
-            return True
+            return True
     return False
 
diff --git a/spacy/lang/am/tokenizer_exceptions.py b/spacy/lang/am/tokenizer_exceptions.py
index c5624ea39..9472fe918 100644
--- a/spacy/lang/am/tokenizer_exceptions.py
+++ b/spacy/lang/am/tokenizer_exceptions.py
@@ -5,9 +5,8 @@
 
 _exc = {}
 
 for exc_data in [
-    {ORTH: "ት/ቤት"},
+    {ORTH: "ት/ቤት"},
     {ORTH: "ወ/ሮ", NORM: "ወይዘሮ"},
-
 ]:
     _exc[exc_data[ORTH]] = [exc_data]
diff --git a/spacy/lang/es/syntax_iterators.py b/spacy/lang/es/syntax_iterators.py
index 4c49dcd44..e753a3f98 100644
--- a/spacy/lang/es/syntax_iterators.py
+++ b/spacy/lang/es/syntax_iterators.py
@@ -1,4 +1,4 @@
-from typing import Union, Iterator, Optional, List, Tuple
+from typing import Union, Iterator
 
 from ...symbols import NOUN, PROPN, PRON, VERB, AUX
 from ...errors import Errors
diff --git a/spacy/lang/pt/stop_words.py b/spacy/lang/pt/stop_words.py
index 909d8bf97..ce3c86ff5 100644
--- a/spacy/lang/pt/stop_words.py
+++ b/spacy/lang/pt/stop_words.py
@@ -59,8 +59,8 @@ tudo tão têm
 um uma umas uns usa usar último
 
 vai vais valor veja vem vens ver vez vezes vinda vindo vinte você vocês vos vossa
-vossas vosso vossos vários vão vêm vós
+vossas vosso vossos vários vão vêm vós
 
 zero
 """.split()
-)
+)
diff --git a/spacy/lang/ti/lex_attrs.py b/spacy/lang/ti/lex_attrs.py
index 61a0b8516..ed094de3b 100644
--- a/spacy/lang/ti/lex_attrs.py
+++ b/spacy/lang/ti/lex_attrs.py
@@ -36,43 +36,44 @@ _num_words = [
     "ትሪልዮን",
     "ኳድሪልዮን",
     "ገጅልዮን",
-    "ባዝልዮን"
+    "ባዝልዮን",
 ]
 
 _ordinal_words = [
-    "ቀዳማይ",
-    "ካልኣይ",
-    "ሳልሳይ",
-    "ራብኣይ",
-    "ሓምሻይ",
-    "ሻድሻይ",
-    "ሻውዓይ",
-    "ሻምናይ",
-    "ዘጠነኛ",
-    "አስረኛ",
-    "ኣሰርተ አንደኛ",
-    "ኣሰርተ ሁለተኛ",
-    "ኣሰርተ ሶስተኛ",
-    "ኣሰርተ አራተኛ",
-    "ኣሰርተ አምስተኛ",
-    "ኣሰርተ ስድስተኛ",
-    "ኣሰርተ ሰባተኛ",
-    "ኣሰርተ ስምንተኛ",
-    "ኣሰርተ ዘጠነኛ",
-    "ሃያኛ",
-    "ሰላሳኛ"
-    "አርባኛ",
-    "አምሳኛ",
-    "ስድሳኛ",
-    "ሰባኛ",
-    "ሰማንያኛ",
-    "ዘጠናኛ",
-    "መቶኛ",
-    "ሺኛ",
-    "ሚሊዮንኛ",
-    "ቢሊዮንኛ",
-    "ትሪሊዮንኛ"
+    "ቀዳማይ",
+    "ካልኣይ",
+    "ሳልሳይ",
+    "ራብኣይ",
+    "ሓምሻይ",
+    "ሻድሻይ",
+    "ሻውዓይ",
+    "ሻምናይ",
+    "ዘጠነኛ",
+    "አስረኛ",
+    "ኣሰርተ አንደኛ",
+    "ኣሰርተ ሁለተኛ",
+    "ኣሰርተ ሶስተኛ",
+    "ኣሰርተ አራተኛ",
+    "ኣሰርተ አምስተኛ",
+    "ኣሰርተ ስድስተኛ",
+    "ኣሰርተ ሰባተኛ",
+    "ኣሰርተ ስምንተኛ",
+    "ኣሰርተ ዘጠነኛ",
+    "ሃያኛ",
"ሰላሳኛ" "አርባኛ", + "አምሳኛ", + "ስድሳኛ", + "ሰባኛ", + "ሰማንያኛ", + "ዘጠናኛ", + "መቶኛ", + "ሺኛ", + "ሚሊዮንኛ", + "ቢሊዮንኛ", + "ትሪሊዮንኛ", ] + + def like_num(text): if text.startswith(("+", "-", "±", "~")): text = text[1:] @@ -93,7 +94,7 @@ def like_num(text): return True if text_lower.endswith("ኛ"): if text_lower[:-2].isdigit(): - return True + return True return False diff --git a/spacy/lang/ti/tokenizer_exceptions.py b/spacy/lang/ti/tokenizer_exceptions.py index 57006f605..3d79cd84b 100644 --- a/spacy/lang/ti/tokenizer_exceptions.py +++ b/spacy/lang/ti/tokenizer_exceptions.py @@ -5,10 +5,9 @@ _exc = {} for exc_data in [ - {ORTH: "ት/ቤት"}, + {ORTH: "ት/ቤት"}, {ORTH: "ወ/ሮ", NORM: "ወይዘሮ"}, {ORTH: "ወ/ሪ", NORM: "ወይዘሪት"}, - ]: _exc[exc_data[ORTH]] = [exc_data] diff --git a/spacy/ml/models/textcat.py b/spacy/ml/models/textcat.py index 8c7316f62..000ca5066 100644 --- a/spacy/ml/models/textcat.py +++ b/spacy/ml/models/textcat.py @@ -71,17 +71,19 @@ def build_text_classifier_v2( exclusive_classes = not linear_model.attrs["multi_label"] with Model.define_operators({">>": chain, "|": concatenate}): width = tok2vec.maybe_get_dim("nO") - attention_layer = ParametricAttention(width) # TODO: benchmark performance difference of this layer + attention_layer = ParametricAttention( + width + ) # TODO: benchmark performance difference of this layer maxout_layer = Maxout(nO=width, nI=width) linear_layer = Linear(nO=nO, nI=width) cnn_model = ( - tok2vec - >> list2ragged() - >> attention_layer - >> reduce_sum() - >> residual(maxout_layer) - >> linear_layer - >> Dropout(0.0) + tok2vec + >> list2ragged() + >> attention_layer + >> reduce_sum() + >> residual(maxout_layer) + >> linear_layer + >> Dropout(0.0) ) nO_double = nO * 2 if nO else None diff --git a/spacy/ml/models/tok2vec.py b/spacy/ml/models/tok2vec.py index 2bb420004..dd4b6deee 100644 --- a/spacy/ml/models/tok2vec.py +++ b/spacy/ml/models/tok2vec.py @@ -89,7 +89,7 @@ def build_hash_embed_cnn_tok2vec( # TODO: archive @registry.architectures.register("spacy.Tok2Vec.v1") -def build_Tok2Vec_model( +def _build_Tok2Vec_model( embed: Model[List[Doc], List[Floats2d]], encode: Model[List[Floats2d], List[Floats2d]], ) -> Model[List[Doc], List[Floats2d]]: @@ -109,7 +109,6 @@ def build_Tok2Vec_model( return tok2vec - @registry.architectures.register("spacy.Tok2Vec.v2") def build_Tok2Vec_model( embed: Model[List[Doc], List[Floats2d]], @@ -130,7 +129,6 @@ def build_Tok2Vec_model( return tok2vec - @registry.architectures.register("spacy.MultiHashEmbed.v1") def MultiHashEmbed( width: int, @@ -280,7 +278,7 @@ def CharacterEmbed( # TODO: archive @registry.architectures.register("spacy.MaxoutWindowEncoder.v1") -def MaxoutWindowEncoder( +def _MaxoutWindowEncoder( width: int, window_size: int, maxout_pieces: int, depth: int ) -> Model[List[Floats2d], List[Floats2d]]: """Encode context using convolutions with maxout activation, layer @@ -310,6 +308,7 @@ def MaxoutWindowEncoder( model.attrs["receptive_field"] = window_size * depth return model + @registry.architectures.register("spacy.MaxoutWindowEncoder.v2") def MaxoutWindowEncoder( width: int, window_size: int, maxout_pieces: int, depth: int @@ -344,7 +343,7 @@ def MaxoutWindowEncoder( # TODO: archive @registry.architectures.register("spacy.MishWindowEncoder.v1") -def MishWindowEncoder( +def _MishWindowEncoder( width: int, window_size: int, depth: int ) -> Model[List[Floats2d], List[Floats2d]]: """Encode context using convolutions with mish activation, layer diff --git a/spacy/pipeline/textcat.py b/spacy/pipeline/textcat.py index 456a8bb38..c09533319 
--- a/spacy/pipeline/textcat.py
+++ b/spacy/pipeline/textcat.py
@@ -388,7 +388,6 @@ class TextCategorizer(TrainablePipe):
             **kwargs,
         )
 
-
     def _validate_categories(self, examples: List[Example]):
         """Check whether the provided examples all have single-label cats annotations."""
         for ex in examples:
diff --git a/spacy/pipeline/textcat_multilabel.py b/spacy/pipeline/textcat_multilabel.py
index fc8920414..41c5a1335 100644
--- a/spacy/pipeline/textcat_multilabel.py
+++ b/spacy/pipeline/textcat_multilabel.py
@@ -187,5 +187,5 @@ class MultiLabel_TextCategorizer(TextCategorizer):
 
     def _validate_categories(self, examples: List[Example]):
         """This component allows any type of single- or multi-label annotations.
-        This method overwrites the more strict one from 'textcat'. """
+        This method overwrites the more strict one from 'textcat'."""
         pass
diff --git a/spacy/tests/conftest.py b/spacy/tests/conftest.py
index 0d32c22e2..bd323e2d5 100644
--- a/spacy/tests/conftest.py
+++ b/spacy/tests/conftest.py
@@ -28,10 +28,12 @@ def pytest_runtest_setup(item):
 def tokenizer():
     return get_lang_class("xx")().tokenizer
 
+
 @pytest.fixture(scope="session")
 def am_tokenizer():
     return get_lang_class("am")().tokenizer
 
+
 @pytest.fixture(scope="session")
 def ar_tokenizer():
     return get_lang_class("ar")().tokenizer
@@ -247,10 +249,12 @@ def th_tokenizer():
     pytest.importorskip("pythainlp")
     return get_lang_class("th")().tokenizer
 
+
 @pytest.fixture(scope="session")
 def ti_tokenizer():
     return get_lang_class("ti")().tokenizer
 
+
 @pytest.fixture(scope="session")
 def tr_tokenizer():
     return get_lang_class("tr")().tokenizer
diff --git a/spacy/tests/doc/test_graph.py b/spacy/tests/doc/test_graph.py
index d5e2c05d1..e464b0058 100644
--- a/spacy/tests/doc/test_graph.py
+++ b/spacy/tests/doc/test_graph.py
@@ -17,17 +17,8 @@ def test_graph_edges_and_nodes():
     assert graph.get_node((0,)) == node1
     node2 = graph.add_node((1, 3))
     assert list(node2) == [1, 3]
-    graph.add_edge(
-        node1,
-        node2,
-        label="one",
-        weight=-10.5
-    )
-    assert graph.has_edge(
-        node1,
-        node2,
-        label="one"
-    )
+    graph.add_edge(node1, node2, label="one", weight=-10.5)
+    assert graph.has_edge(node1, node2, label="one")
     assert node1.heads() == []
     assert [tuple(h) for h in node2.heads()] == [(0,)]
     assert [tuple(t) for t in node1.tails()] == [(1, 3)]
@@ -42,7 +33,7 @@ def test_graph_walk():
         nodes=[(0,), (1,), (2,), (3,)],
         edges=[(0, 1), (0, 2), (0, 3), (3, 0)],
         labels=None,
-        weights=None
+        weights=None,
     )
     node0, node1, node2, node3 = list(graph.nodes)
     assert [tuple(h) for h in node0.heads()] == [(3,)]
diff --git a/spacy/tests/doc/test_token_api.py b/spacy/tests/doc/test_token_api.py
index dda28809d..1e13882c5 100644
--- a/spacy/tests/doc/test_token_api.py
+++ b/spacy/tests/doc/test_token_api.py
@@ -255,8 +255,8 @@ def test_token_api_non_conjuncts(en_vocab):
 
 def test_missing_head_dep(en_vocab):
     """ Check that the Doc constructor and Example.from_dict parse missing information the same"""
-    heads = [1, 1, 1, 1, 2, None] # element 5 is missing
-    deps = ["", "ROOT", "dobj", "cc", "conj", None] # element 0 and 5 are missing
+    heads = [1, 1, 1, 1, 2, None]  # element 5 is missing
+    deps = ["", "ROOT", "dobj", "cc", "conj", None]  # element 0 and 5 are missing
     words = ["I", "like", "London", "and", "Berlin", "."]
     doc = Doc(en_vocab, words=words, heads=heads, deps=deps)
     pred_has_heads = [t.has_head() for t in doc]
diff --git a/spacy/tests/lang/am/test_text.py b/spacy/tests/lang/am/test_text.py
index e7529363e..407f19e46 100644
--- a/spacy/tests/lang/am/test_text.py
+++ b/spacy/tests/lang/am/test_text.py
@@ -1,17 +1,16 @@
 import pytest
-from spacy.lang.am.lex_attrs import like_num
 
 
 def test_am_tokenizer_handles_long_text(am_tokenizer):
     text = """ሆሴ ሙጂካ በበጋ ወቅት በኦክስፎርድ ንግግር አንድያቀርቡ ሲጋበዙ ጭንቅላታቸው "ፈነዳ"።
 
-“እጅግ ጥንታዊ” የእንግሊዝኛ ተናጋሪ ዩኒቨርስቲ፣ በአስር ሺዎች የሚቆጠሩ ዩሮዎችን ለተማሪዎች በማስተማር የሚያስከፍለው
+“እጅግ ጥንታዊ” የእንግሊዝኛ ተናጋሪ ዩኒቨርስቲ፣ በአስር ሺዎች የሚቆጠሩ ዩሮዎችን ለተማሪዎች በማስተማር የሚያስከፍለው
 
-እና ከማርጋሬት ታቸር እስከ ስቲቨን ሆኪንግ በአዳራሾቻቸው ውስጥ ንግግር ያደረጉበት የትምህርት ማዕከል፣ በሞንቴቪዴኦ
+እና ከማርጋሬት ታቸር እስከ ስቲቨን ሆኪንግ በአዳራሾቻቸው ውስጥ ንግግር ያደረጉበት የትምህርት ማዕከል፣ በሞንቴቪዴኦ
 
 በሚገኘው የመንግስት ትምህርት ቤት የሰለጠኑትን የ81 ዓመቱ አዛውንት አገልግሎት ጠየቁ።"""
     tokens = am_tokenizer(text)
-
+
     assert len(tokens) == 56
diff --git a/spacy/tests/lang/en/test_exceptions.py b/spacy/tests/lang/en/test_exceptions.py
index f5345cbe2..02ecaed6e 100644
--- a/spacy/tests/lang/en/test_exceptions.py
+++ b/spacy/tests/lang/en/test_exceptions.py
@@ -121,9 +121,7 @@ def test_en_tokenizer_norm_exceptions(en_tokenizer, text, norms):
     assert [token.norm_ for token in tokens] == norms
 
 
-@pytest.mark.parametrize(
-    "text,norm", [("Jan.", "January"), ("'cuz", "because")]
-)
+@pytest.mark.parametrize("text,norm", [("Jan.", "January"), ("'cuz", "because")])
 def test_en_lex_attrs_norm_exceptions(en_tokenizer, text, norm):
     tokens = en_tokenizer(text)
     assert tokens[0].norm_ == norm
diff --git a/spacy/tests/lang/ti/test_text.py b/spacy/tests/lang/ti/test_text.py
index eb1b515eb..177a9e4b2 100644
--- a/spacy/tests/lang/ti/test_text.py
+++ b/spacy/tests/lang/ti/test_text.py
@@ -1,17 +1,16 @@
 import pytest
-from spacy.lang.ti.lex_attrs import like_num
 
 
 def test_ti_tokenizer_handles_long_text(ti_tokenizer):
     text = """ቻንስለር ጀርመን ኣንገላ መርከል ኣብታ ሃገር ቁጽሪ መትሓዝቲ ኮቪድ መዓልታዊ ክብረ መዝገብ ድሕሪ ምህራሙ- ጽኑዕ እገዳ ክግበር ጸዊዓ።
 
-መርከል ሎሚ ንታሕታዋይ ባይቶ ሃገራ ክትገልጽ ከላ፡ ኣብ ወሳኒ ምዕራፍ ቃልሲ ኢና ዘለና-ዳሕራዋይ ማዕበል ካብቲ ቀዳማይ ክገድድ ይኽእል`ዩ ኢላ።
+መርከል ሎሚ ንታሕታዋይ ባይቶ ሃገራ ክትገልጽ ከላ፡ ኣብ ወሳኒ ምዕራፍ ቃልሲ ኢና ዘለና-ዳሕራዋይ ማዕበል ካብቲ ቀዳማይ ክገድድ ይኽእል`ዩ ኢላ።
 
-ትካል ምክልኻል ተላገብቲ ሕማማት ጀርመን፡ ኣብ ዝሓለፈ 24 ሰዓታት ኣብ ምልእቲ ጀርመር 590 ሰባት ብኮቪድ19 ምሟቶም ኣፍሊጡ`ሎ።
+ትካል ምክልኻል ተላገብቲ ሕማማት ጀርመን፡ ኣብ ዝሓለፈ 24 ሰዓታት ኣብ ምልእቲ ጀርመር 590 ሰባት ብኮቪድ19 ምሟቶም ኣፍሊጡ`ሎ።
 
 ቻንስለር ኣንጀላ መርከል ኣብ እዋን በዓላት ልደት ስድራቤታት ክተኣኻኸባ ዝፍቀደለን`ኳ እንተኾነ ድሕሪኡ ኣብ ዘሎ ግዜ ግን እቲ እገዳታት ክትግበር ትደሊ።"""
     tokens = ti_tokenizer(text)
-
+
     assert len(tokens) == 85
diff --git a/spacy/tests/parser/test_ner.py b/spacy/tests/parser/test_ner.py
index d110eb11c..dffdff1ec 100644
--- a/spacy/tests/parser/test_ner.py
+++ b/spacy/tests/parser/test_ner.py
@@ -389,7 +389,7 @@ def test_beam_ner_scores():
 
     for j in range(len(doc)):
         for label in ner.labels:
-            score = entity_scores[(j, j+1, label)]
+            score = entity_scores[(j, j + 1, label)]
             eps = 0.00001
             assert 0 - eps <= score <= 1 + eps
 
diff --git a/spacy/tests/pipeline/test_textcat.py b/spacy/tests/pipeline/test_textcat.py
index 04a3eb27d..fd1be53ee 100644
--- a/spacy/tests/pipeline/test_textcat.py
+++ b/spacy/tests/pipeline/test_textcat.py
@@ -146,12 +146,12 @@ def test_no_resize(name):
 
 def test_error_with_multi_labels():
     nlp = Language()
-    textcat = nlp.add_pipe("textcat")
+    nlp.add_pipe("textcat")
     train_examples = []
     for text, annotations in TRAIN_DATA_MULTI_LABEL:
         train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
     with pytest.raises(ValueError):
-        optimizer = nlp.initialize(get_examples=lambda: train_examples)
+        nlp.initialize(get_examples=lambda: train_examples)
 
 
 @pytest.mark.parametrize(
@@ -273,7 +273,7 @@ def test_overfitting_IO_multi():
     assert_equal(batch_cats_1, no_batch_cats)
 
 
-def test_overfitting_IO_multi():
+def test_overfitting_IO_multi_2():
     # Simple test to try and quickly overfit the multi-label textcat component - ensuring the ML models work correctly
     fix_random_seed(0)
     nlp = English()
@@ -362,7 +362,9 @@ def test_positive_class():
     textcat_multilabel = nlp.add_pipe("textcat_multilabel")
     get_examples = make_get_examples_multi_label(nlp)
     with pytest.raises(TypeError):
-        textcat_multilabel.initialize(get_examples, labels=["POS", "NEG"], positive_label="POS")
+        textcat_multilabel.initialize(
+            get_examples, labels=["POS", "NEG"], positive_label="POS"
+        )
     textcat_multilabel.initialize(get_examples, labels=["FICTION", "DRAMA"])
     assert textcat_multilabel.labels == ("FICTION", "DRAMA")
     assert "positive_label" not in textcat_multilabel.cfg
@@ -381,7 +383,9 @@ def test_positive_class_not_binary():
     textcat = nlp.add_pipe("textcat")
     get_examples = make_get_examples_multi_label(nlp)
     with pytest.raises(ValueError):
-        textcat.initialize(get_examples, labels=["SOME", "THING", "POS"], positive_label="POS")
+        textcat.initialize(
+            get_examples, labels=["SOME", "THING", "POS"], positive_label="POS"
+        )
 
 
 def test_textcat_evaluation():
diff --git a/spacy/tests/regression/test_issue4001-4500.py b/spacy/tests/regression/test_issue4001-4500.py
index 521fa0d73..25982623f 100644
--- a/spacy/tests/regression/test_issue4001-4500.py
+++ b/spacy/tests/regression/test_issue4001-4500.py
@@ -13,7 +13,6 @@ from spacy.lang.el import Greek
 from spacy.language import Language
 import spacy
 from thinc.api import compounding
-from collections import defaultdict
 
 from ..util import make_tempdir
 
@@ -304,16 +303,14 @@ def test_issue4313():
     doc = nlp("What do you think about Apple ?")
     assert len(ner.labels) == 1
     assert "SOME_LABEL" in ner.labels
-    ner.add_label("MY_ORG") # TODO: not sure if we want this to be necessary...
+    ner.add_label("MY_ORG")  # TODO: not sure if we want this to be necessary...
     apple_ent = Span(doc, 5, 6, label="MY_ORG")
     doc.ents = list(doc.ents) + [apple_ent]
 
     # ensure the beam_parse still works with the new label
     docs = [doc]
     ner = nlp.get_pipe("beam_ner")
-    beams = ner.beam_parse(
-        docs, drop=0.0, beam_width=beam_width, beam_density=beam_density
-    )
+    ner.beam_parse(docs, drop=0.0, beam_width=beam_width, beam_density=beam_density)
 
 
 def test_issue4348():
diff --git a/spacy/tests/test_language.py b/spacy/tests/test_language.py
index 6ffeeadce..d6efce32f 100644
--- a/spacy/tests/test_language.py
+++ b/spacy/tests/test_language.py
@@ -251,7 +251,9 @@ def test_language_from_config_before_after_init():
     nlp.initialize()
     assert nlp.meta["before_init"] == "before"
     assert nlp.meta["after_init"] == "after"
-    assert all([ran_before, ran_after, ran_after_pipeline, ran_before_init, ran_after_init])
+    assert all(
+        [ran_before, ran_after, ran_after_pipeline, ran_before_init, ran_after_init]
+    )
 
 
 def test_language_from_config_before_after_init_invalid():
diff --git a/spacy/tests/training/test_new_example.py b/spacy/tests/training/test_new_example.py
index 0a3184071..be3419b82 100644
--- a/spacy/tests/training/test_new_example.py
+++ b/spacy/tests/training/test_new_example.py
@@ -166,19 +166,11 @@ def test_Example_from_dict_with_entities(annots):
     vocab = Vocab()
     predicted = Doc(vocab, words=annots["words"])
     example = Example.from_dict(predicted, annots)
-
     assert len(list(example.reference.ents)) == 2
-    assert [example.reference[i].ent_iob_ for i in range(7)] == [
-        "O",
-        "O",
-        "B",
-        "I",
-        "O",
-        "B",
-        "O",
-    ]
+    # fmt: off
+    assert [example.reference[i].ent_iob_ for i in range(7)] == ["O", "O", "B", "I", "O", "B", "O"]
     assert example.get_aligned("ENT_IOB") == [2, 2, 3, 1, 2, 3, 2]
-
+    # fmt: on
     assert example.reference[2].ent_type_ == "LOC"
     assert example.reference[3].ent_type_ == "LOC"
     assert example.reference[5].ent_type_ == "LOC"
@@ -299,7 +291,8 @@ def test_Example_missing_heads():
     assert parsed_heads[2] == heads[2]
     assert parsed_heads[4] == heads[4]
     assert parsed_heads[5] == heads[5]
-    assert [t.has_head() for t in example.reference] == [True, True, True, False, True, True]
-
+    expected = [True, True, True, False, True, True]
+    assert [t.has_head() for t in example.reference] == expected
     # Ensure that the missing head doesn't create an artificial new sentence start
-    assert example.get_aligned_sent_starts() == [True, False, False, False, False, False]
+    expected = [True, False, False, False, False, False]
+    assert example.get_aligned_sent_starts() == expected

From f9e4ac12836dbbe7a702d93db0682dd4cf193d42 Mon Sep 17 00:00:00 2001
From: Ines Montani
Date: Fri, 15 Jan 2021 12:51:02 +1100
Subject: [PATCH 2/2] Fix test

---
 spacy/tests/pipeline/test_textcat.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/spacy/tests/pipeline/test_textcat.py b/spacy/tests/pipeline/test_textcat.py
index fd1be53ee..f41ee4bd2 100644
--- a/spacy/tests/pipeline/test_textcat.py
+++ b/spacy/tests/pipeline/test_textcat.py
@@ -226,7 +226,8 @@ def test_overfitting_IO():
     assert_equal(batch_cats_1, no_batch_cats)
 
 
-def test_overfitting_IO_multi():
+@pytest.mark.skip(reason="TODO: Can this be removed?")
+def test_overfitting_IO_multi_old():
     # Simple test to try and quickly overfit the multi-label textcat component - ensuring the ML models work correctly
     fix_random_seed(0)
     nlp = English()
@@ -273,7 +274,7 @@ def test_overfitting_IO_multi():
     assert_equal(batch_cats_1, no_batch_cats)
 
 
-def test_overfitting_IO_multi_2():
+def test_overfitting_IO_multi():
     # Simple test to try and quickly overfit the multi-label textcat component - ensuring the ML models work correctly
     fix_random_seed(0)
     nlp = English()