Tidy up and auto-format

Ines Montani 2021-01-15 11:57:36 +11:00
parent e8a97a2bd6
commit b0b743597c
23 changed files with 140 additions and 144 deletions

View File

@@ -384,7 +384,10 @@ def debug_data(
     # rare labels in projectivized train
     rare_projectivized_labels = []
     for label in gold_train_data["deps"]:
-        if gold_train_data["deps"][label] <= DEP_LABEL_THRESHOLD and DELIMITER in label:
+        if (
+            gold_train_data["deps"][label] <= DEP_LABEL_THRESHOLD
+            and DELIMITER in label
+        ):
             rare_projectivized_labels.append(
                 f"{label}: {gold_train_data['deps'][label]}"
             )

View File

@@ -30,7 +30,11 @@ def info_cli(
 def info(
-    model: Optional[str] = None, *, markdown: bool = False, silent: bool = True, exclude: List[str]
+    model: Optional[str] = None,
+    *,
+    markdown: bool = False,
+    silent: bool = True,
+    exclude: List[str],
 ) -> Union[str, dict]:
     msg = Printer(no_print=silent, pretty=not silent)
     if model:
@@ -98,7 +102,9 @@ def info_model(model: str, *, silent: bool = True) -> Dict[str, Any]:
     }

-def get_markdown(data: Dict[str, Any], title: Optional[str] = None, exclude: List[str] = None) -> str:
+def get_markdown(
+    data: Dict[str, Any], title: Optional[str] = None, exclude: List[str] = None
+) -> str:
     """Get data in GitHub-flavoured Markdown format for issues etc.
     data (dict or list of tuples): Label/value pairs.
@@ -115,7 +121,7 @@ def get_markdown(data: Dict[str, Any], title: Optional[str] = None, exclude: List[str] = None) -> str:
         if isinstance(value, str):
             try:
                 existing_path = Path(value).exists()
-            except:
+            except Exception:
                 # invalid Path, like a URL string
                 existing_path = False
             if existing_path:
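
Aside: the `except Exception:` fix above is the one change in this file that goes beyond formatting. A bare `except:` also traps `KeyboardInterrupt` and `SystemExit`, while `except Exception:` only catches ordinary errors. A minimal, self-contained sketch of the same pattern (hypothetical helper, not the spaCy code):

    from pathlib import Path

    def looks_like_path(value: str) -> bool:
        try:
            # .exists() can raise, e.g. OSError on Windows for strings with
            # characters that are illegal in paths (like "https://...").
            return Path(value).exists()
        except Exception:
            # Ordinary errors only: Ctrl-C (KeyboardInterrupt) and SystemExit
            # still propagate, unlike with a bare `except:`.
            return False

    print(looks_like_path("."))                 # True: current directory exists
    print(looks_like_path("https://spacy.io"))  # False: treated as a non-path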

View File

@@ -36,43 +36,44 @@ _num_words = [
     "ትሪሊዮን",
     "ኳድሪሊዮን",
     "ገጅሊዮን",
-    "ባዝሊዮን"
+    "ባዝሊዮን",
 ]

 _ordinal_words = [
     "አንደኛ",
     "ሁለተኛ",
     "ሶስተኛ",
     "አራተኛ",
     "አምስተኛ",
     "ስድስተኛ",
     "ሰባተኛ",
     "ስምንተኛ",
     "ዘጠነኛ",
     "አስረኛ",
     "አስራ አንደኛ",
     "አስራ ሁለተኛ",
     "አስራ ሶስተኛ",
     "አስራ አራተኛ",
     "አስራ አምስተኛ",
     "አስራ ስድስተኛ",
     "አስራ ሰባተኛ",
     "አስራ ስምንተኛ",
     "አስራ ዘጠነኛ",
     "ሃያኛ",
-    "ሰላሳኛ"
-    "አርባኛ",
+    "ሰላሳኛ" "አርባኛ",
     "አምሳኛ",
     "ስድሳኛ",
     "ሰባኛ",
     "ሰማንያኛ",
     "ዘጠናኛ",
     "መቶኛ",
     "ሺኛ",
     "ሚሊዮንኛ",
     "ቢሊዮንኛ",
-    "ትሪሊዮንኛ"
+    "ትሪሊዮንኛ",
 ]

 def like_num(text):
     if text.startswith(("+", "-", "±", "~")):
         text = text[1:]
@@ -93,7 +94,7 @@ def like_num(text):
         return True
     if text_lower.endswith(""):
         if text_lower[:-2].isdigit():
             return True
     return False
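
One detail worth noting in the reformatted list above: `"ሰላሳኛ" "አርባኛ",` on a single line is not something the formatter invented. The source list is missing a comma between the two items, so Python's implicit string-literal concatenation joins them into a single element; the auto-formatter merely makes that visible by placing both literals on one line. A small illustration of the behaviour (hypothetical words):

    ordinals = [
        "first",
        "second" "third",  # missing comma: adjacent literals merge into one item
        "fourth",
    ]
    print(len(ordinals))  # 3, not 4
    print(ordinals[1])    # secondthird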

View File

@@ -5,9 +5,8 @@ _exc = {}
 for exc_data in [
     {ORTH: "ት/ቤት"},
     {ORTH: "ወ/ሮ", NORM: "ወይዘሮ"},
 ]:
     _exc[exc_data[ORTH]] = [exc_data]

View File

@@ -1,4 +1,4 @@
-from typing import Union, Iterator, Optional, List, Tuple
+from typing import Union, Iterator
 from ...symbols import NOUN, PROPN, PRON, VERB, AUX
 from ...errors import Errors

View File

@@ -59,8 +59,8 @@ tudo tão têm
 um uma umas uns usa usar último
 vai vais valor veja vem vens ver vez vezes vinda vindo vinte você vocês vos vossa
 vossas vosso vossos vários vão vêm vós
 zero
 """.split()
 )

View File

@@ -36,43 +36,44 @@ _num_words = [
     "ትሪልዮን",
     "ኳድሪልዮን",
     "ገጅልዮን",
-    "ባዝልዮን"
+    "ባዝልዮን",
 ]

 _ordinal_words = [
     "ቀዳማይ",
     "ካልኣይ",
     "ሳልሳይ",
     "ራብኣይ",
     "ሓምሻይ",
     "ሻድሻይ",
     "ሻውዓይ",
     "ሻምናይ",
     "ዘጠነኛ",
     "አስረኛ",
     "ኣሰርተ አንደኛ",
     "ኣሰርተ ሁለተኛ",
     "ኣሰርተ ሶስተኛ",
     "ኣሰርተ አራተኛ",
     "ኣሰርተ አምስተኛ",
     "ኣሰርተ ስድስተኛ",
     "ኣሰርተ ሰባተኛ",
     "ኣሰርተ ስምንተኛ",
     "ኣሰርተ ዘጠነኛ",
     "ሃያኛ",
-    "ሰላሳኛ"
-    "አርባኛ",
+    "ሰላሳኛ" "አርባኛ",
     "አምሳኛ",
     "ስድሳኛ",
     "ሰባኛ",
     "ሰማንያኛ",
     "ዘጠናኛ",
     "መቶኛ",
     "ሺኛ",
     "ሚሊዮንኛ",
     "ቢሊዮንኛ",
-    "ትሪሊዮንኛ"
+    "ትሪሊዮንኛ",
 ]

 def like_num(text):
     if text.startswith(("+", "-", "±", "~")):
         text = text[1:]
@@ -93,7 +94,7 @@ def like_num(text):
         return True
     if text_lower.endswith(""):
         if text_lower[:-2].isdigit():
             return True
     return False

View File

@@ -5,10 +5,9 @@ _exc = {}
 for exc_data in [
     {ORTH: "ት/ቤት"},
     {ORTH: "ወ/ሮ", NORM: "ወይዘሮ"},
     {ORTH: "ወ/ሪ", NORM: "ወይዘሪት"},
 ]:
     _exc[exc_data[ORTH]] = [exc_data]

View File

@@ -71,17 +71,19 @@ def build_text_classifier_v2(
     exclusive_classes = not linear_model.attrs["multi_label"]
     with Model.define_operators({">>": chain, "|": concatenate}):
         width = tok2vec.maybe_get_dim("nO")
-        attention_layer = ParametricAttention(width)  # TODO: benchmark performance difference of this layer
+        attention_layer = ParametricAttention(
+            width
+        )  # TODO: benchmark performance difference of this layer
         maxout_layer = Maxout(nO=width, nI=width)
         linear_layer = Linear(nO=nO, nI=width)
         cnn_model = (
             tok2vec
             >> list2ragged()
             >> attention_layer
             >> reduce_sum()
             >> residual(maxout_layer)
             >> linear_layer
             >> Dropout(0.0)
         )
         nO_double = nO * 2 if nO else None

View File

@@ -89,7 +89,7 @@ def build_hash_embed_cnn_tok2vec(
 # TODO: archive
 @registry.architectures.register("spacy.Tok2Vec.v1")
-def build_Tok2Vec_model(
+def _build_Tok2Vec_model(
     embed: Model[List[Doc], List[Floats2d]],
     encode: Model[List[Floats2d], List[Floats2d]],
 ) -> Model[List[Doc], List[Floats2d]]:
@@ -109,7 +109,6 @@ def build_Tok2Vec_model(
     return tok2vec

 @registry.architectures.register("spacy.Tok2Vec.v2")
 def build_Tok2Vec_model(
     embed: Model[List[Doc], List[Floats2d]],
@@ -130,7 +129,6 @@ def build_Tok2Vec_model(
     return tok2vec

 @registry.architectures.register("spacy.MultiHashEmbed.v1")
 def MultiHashEmbed(
     width: int,
@@ -280,7 +278,7 @@ def CharacterEmbed(
 # TODO: archive
 @registry.architectures.register("spacy.MaxoutWindowEncoder.v1")
-def MaxoutWindowEncoder(
+def _MaxoutWindowEncoder(
     width: int, window_size: int, maxout_pieces: int, depth: int
 ) -> Model[List[Floats2d], List[Floats2d]]:
     """Encode context using convolutions with maxout activation, layer
@@ -310,6 +308,7 @@ def MaxoutWindowEncoder(
     model.attrs["receptive_field"] = window_size * depth
     return model

 @registry.architectures.register("spacy.MaxoutWindowEncoder.v2")
 def MaxoutWindowEncoder(
     width: int, window_size: int, maxout_pieces: int, depth: int
@@ -344,7 +343,7 @@ def MaxoutWindowEncoder(
 # TODO: archive
 @registry.architectures.register("spacy.MishWindowEncoder.v1")
-def MishWindowEncoder(
+def _MishWindowEncoder(
     width: int, window_size: int, depth: int
 ) -> Model[List[Floats2d], List[Floats2d]]:
     """Encode context using convolutions with mish activation, layer

View File

@@ -388,7 +388,6 @@ class TextCategorizer(TrainablePipe):
             **kwargs,
         )

     def _validate_categories(self, examples: List[Example]):
         """Check whether the provided examples all have single-label cats annotations."""
         for ex in examples:

View File

@@ -187,5 +187,5 @@ class MultiLabel_TextCategorizer(TextCategorizer):
     def _validate_categories(self, examples: List[Example]):
         """This component allows any type of single- or multi-label annotations.
-        This method overwrites the more strict one from 'textcat'. """
+        This method overwrites the more strict one from 'textcat'."""
         pass

View File

@@ -28,10 +28,12 @@ def pytest_runtest_setup(item):
 def tokenizer():
     return get_lang_class("xx")().tokenizer

 @pytest.fixture(scope="session")
 def am_tokenizer():
     return get_lang_class("am")().tokenizer

 @pytest.fixture(scope="session")
 def ar_tokenizer():
     return get_lang_class("ar")().tokenizer
@@ -247,10 +249,12 @@ def th_tokenizer():
     pytest.importorskip("pythainlp")
     return get_lang_class("th")().tokenizer

 @pytest.fixture(scope="session")
 def ti_tokenizer():
     return get_lang_class("ti")().tokenizer

 @pytest.fixture(scope="session")
 def tr_tokenizer():
     return get_lang_class("tr")().tokenizer

View File

@@ -17,17 +17,8 @@ def test_graph_edges_and_nodes():
     assert graph.get_node((0,)) == node1
     node2 = graph.add_node((1, 3))
     assert list(node2) == [1, 3]
-    graph.add_edge(
-        node1,
-        node2,
-        label="one",
-        weight=-10.5
-    )
-    assert graph.has_edge(
-        node1,
-        node2,
-        label="one"
-    )
+    graph.add_edge(node1, node2, label="one", weight=-10.5)
+    assert graph.has_edge(node1, node2, label="one")
     assert node1.heads() == []
     assert [tuple(h) for h in node2.heads()] == [(0,)]
     assert [tuple(t) for t in node1.tails()] == [(1, 3)]
@@ -42,7 +33,7 @@ def test_graph_walk():
         nodes=[(0,), (1,), (2,), (3,)],
         edges=[(0, 1), (0, 2), (0, 3), (3, 0)],
         labels=None,
-        weights=None
+        weights=None,
     )
     node0, node1, node2, node3 = list(graph.nodes)
     assert [tuple(h) for h in node0.heads()] == [(3,)]

View File

@@ -255,8 +255,8 @@ def test_token_api_non_conjuncts(en_vocab):
 def test_missing_head_dep(en_vocab):
     """ Check that the Doc constructor and Example.from_dict parse missing information the same"""
     heads = [1, 1, 1, 1, 2, None]  # element 5 is missing
     deps = ["", "ROOT", "dobj", "cc", "conj", None]  # element 0 and 5 are missing
     words = ["I", "like", "London", "and", "Berlin", "."]
     doc = Doc(en_vocab, words=words, heads=heads, deps=deps)
     pred_has_heads = [t.has_head() for t in doc]

View File

@@ -1,17 +1,16 @@
 import pytest
-from spacy.lang.am.lex_attrs import like_num

 def test_am_tokenizer_handles_long_text(am_tokenizer):
     text = """ሆሴ ሙጂካ በበጋ ወቅት በኦክስፎርድ ንግግር አንድያቀርቡ ሲጋበዙ ጭንቅላታቸው "ፈነዳ"
 እጅግ ጥንታዊ የእንግሊዝኛ ተናጋሪ ዩኒቨርስቲ በአስር ሺዎች የሚቆጠሩ ዩሮዎችን ለተማሪዎች በማስተማር የሚያስከፍለው
 እና ከማርጋሬት ታቸር እስከ ስቲቨን ሆኪንግ በአዳራሾቻቸው ውስጥ ንግግር ያደረጉበት የትምህርት ማዕከል በሞንቴቪዴኦ
 በሚገኘው የመንግስት ትምህርት ቤት የሰለጠኑትን የ81 ዓመቱ አዛውንት አገልግሎት ጠየቁ"""
     tokens = am_tokenizer(text)
     assert len(tokens) == 56

View File

@@ -121,9 +121,7 @@ def test_en_tokenizer_norm_exceptions(en_tokenizer, text, norms):
     assert [token.norm_ for token in tokens] == norms

-@pytest.mark.parametrize(
-    "text,norm", [("Jan.", "January"), ("'cuz", "because")]
-)
+@pytest.mark.parametrize("text,norm", [("Jan.", "January"), ("'cuz", "because")])
 def test_en_lex_attrs_norm_exceptions(en_tokenizer, text, norm):
     tokens = en_tokenizer(text)
     assert tokens[0].norm_ == norm

View File

@@ -1,17 +1,16 @@
 import pytest
-from spacy.lang.ti.lex_attrs import like_num

 def test_ti_tokenizer_handles_long_text(ti_tokenizer):
     text = """ቻንስለር ጀርመን ኣንገላ መርከል ኣብታ ሃገር ቁጽሪ መትሓዝቲ ኮቪድ መዓልታዊ ክብረ መዝገብ ድሕሪ ምህራሙ- ጽኑዕ እገዳ ክግበር ጸዊዓ።
 መርከል ሎሚ ንታሕታዋይ ባይቶ ሃገራ ክትገልጽ ከላ ኣብ ወሳኒ ምዕራፍ ቃልሲ ኢና ዘለና-ዳሕራዋይ ማዕበል ካብቲ ቀዳማይ ክገድድ ይኽእል` ኢላ
 ትካል ምክልኻል ተላገብቲ ሕማማት ጀርመን ኣብ ዝሓለፈ 24 ሰዓታት ኣብ ምልእቲ ጀርመር 590 ሰባት ብኮቪድ19 ምሟቶም ኣፍሊጡ`
 ቻንስለር ኣንጀላ መርከል ኣብ እዋን በዓላት ልደት ስድራቤታት ክተኣኻኸባ ዝፍቀደለን` እንተኾነ ድሕሪኡ ኣብ ዘሎ ግዜ ግን እቲ እገዳታት ክትግበር ትደሊ"""
     tokens = ti_tokenizer(text)
     assert len(tokens) == 85

View File

@@ -389,7 +389,7 @@ def test_beam_ner_scores():
     for j in range(len(doc)):
         for label in ner.labels:
-            score = entity_scores[(j, j+1, label)]
+            score = entity_scores[(j, j + 1, label)]
             eps = 0.00001
             assert 0 - eps <= score <= 1 + eps

View File

@@ -146,12 +146,12 @@ def test_no_resize(name):
 def test_error_with_multi_labels():
     nlp = Language()
-    textcat = nlp.add_pipe("textcat")
+    nlp.add_pipe("textcat")
     train_examples = []
     for text, annotations in TRAIN_DATA_MULTI_LABEL:
         train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
     with pytest.raises(ValueError):
-        optimizer = nlp.initialize(get_examples=lambda: train_examples)
+        nlp.initialize(get_examples=lambda: train_examples)

 @pytest.mark.parametrize(
@@ -273,7 +273,7 @@ def test_overfitting_IO_multi():
     assert_equal(batch_cats_1, no_batch_cats)

-def test_overfitting_IO_multi():
+def test_overfitting_IO_multi_2():
     # Simple test to try and quickly overfit the multi-label textcat component - ensuring the ML models work correctly
     fix_random_seed(0)
     nlp = English()
@@ -362,7 +362,9 @@ def test_positive_class():
     textcat_multilabel = nlp.add_pipe("textcat_multilabel")
     get_examples = make_get_examples_multi_label(nlp)
     with pytest.raises(TypeError):
-        textcat_multilabel.initialize(get_examples, labels=["POS", "NEG"], positive_label="POS")
+        textcat_multilabel.initialize(
+            get_examples, labels=["POS", "NEG"], positive_label="POS"
+        )
     textcat_multilabel.initialize(get_examples, labels=["FICTION", "DRAMA"])
     assert textcat_multilabel.labels == ("FICTION", "DRAMA")
     assert "positive_label" not in textcat_multilabel.cfg
@@ -381,7 +383,9 @@ def test_positive_class_not_binary():
     textcat = nlp.add_pipe("textcat")
     get_examples = make_get_examples_multi_label(nlp)
     with pytest.raises(ValueError):
-        textcat.initialize(get_examples, labels=["SOME", "THING", "POS"], positive_label="POS")
+        textcat.initialize(
+            get_examples, labels=["SOME", "THING", "POS"], positive_label="POS"
+        )

 def test_textcat_evaluation():

View File

@@ -13,7 +13,6 @@ from spacy.lang.el import Greek
 from spacy.language import Language
 import spacy
 from thinc.api import compounding
-from collections import defaultdict

 from ..util import make_tempdir
@@ -304,16 +303,14 @@ def test_issue4313():
     doc = nlp("What do you think about Apple ?")
     assert len(ner.labels) == 1
     assert "SOME_LABEL" in ner.labels
     ner.add_label("MY_ORG")  # TODO: not sure if we want this to be necessary...
     apple_ent = Span(doc, 5, 6, label="MY_ORG")
     doc.ents = list(doc.ents) + [apple_ent]

     # ensure the beam_parse still works with the new label
     docs = [doc]
     ner = nlp.get_pipe("beam_ner")
-    beams = ner.beam_parse(
-        docs, drop=0.0, beam_width=beam_width, beam_density=beam_density
-    )
+    ner.beam_parse(docs, drop=0.0, beam_width=beam_width, beam_density=beam_density)

 def test_issue4348():

View File

@@ -251,7 +251,9 @@ def test_language_from_config_before_after_init():
     nlp.initialize()
     assert nlp.meta["before_init"] == "before"
     assert nlp.meta["after_init"] == "after"
-    assert all([ran_before, ran_after, ran_after_pipeline, ran_before_init, ran_after_init])
+    assert all(
+        [ran_before, ran_after, ran_after_pipeline, ran_before_init, ran_after_init]
+    )

 def test_language_from_config_before_after_init_invalid():

View File

@@ -166,19 +166,11 @@ def test_Example_from_dict_with_entities(annots):
     vocab = Vocab()
     predicted = Doc(vocab, words=annots["words"])
     example = Example.from_dict(predicted, annots)
     assert len(list(example.reference.ents)) == 2
-    assert [example.reference[i].ent_iob_ for i in range(7)] == [
-        "O",
-        "O",
-        "B",
-        "I",
-        "O",
-        "B",
-        "O",
-    ]
+    # fmt: off
+    assert [example.reference[i].ent_iob_ for i in range(7)] == ["O", "O", "B", "I", "O", "B", "O"]
     assert example.get_aligned("ENT_IOB") == [2, 2, 3, 1, 2, 3, 2]
+    # fmt: on
     assert example.reference[2].ent_type_ == "LOC"
     assert example.reference[3].ent_type_ == "LOC"
     assert example.reference[5].ent_type_ == "LOC"
@@ -299,7 +291,8 @@ def test_Example_missing_heads():
     assert parsed_heads[2] == heads[2]
     assert parsed_heads[4] == heads[4]
     assert parsed_heads[5] == heads[5]
-    assert [t.has_head() for t in example.reference] == [True, True, True, False, True, True]
+    expected = [True, True, True, False, True, True]
+    assert [t.has_head() for t in example.reference] == expected
     # Ensure that the missing head doesn't create an artificial new sentence start
-    assert example.get_aligned_sent_starts() == [True, False, False, False, False, False]
+    expected = [True, False, False, False, False, False]
+    assert example.get_aligned_sent_starts() == expected
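
For readers unfamiliar with the `# fmt: off` / `# fmt: on` pair used above: these comments tell Black (and compatible formatters) to leave the enclosed lines untouched, which is how the long single-line assertions survive auto-formatting. A minimal sketch (hypothetical values):

    # fmt: off
    EXPECTED_IOB = ["O", "O", "B", "I", "O", "B", "O"]  # kept as one long line on purpose
    # fmt: on

    def check(tags):
        # Black reformats this function as usual but leaves the block above alone.
        return tags == EXPECTED_IOB

    print(check(["O", "O", "B", "I", "O", "B", "O"]))  # True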