Merge pull request #6729 from explosion/chore/tidy-up

This commit is contained in:
Ines Montani 2021-01-15 13:27:59 +11:00 committed by GitHub
commit 330f9818c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
23 changed files with 141 additions and 144 deletions

View File

@ -384,7 +384,10 @@ def debug_data(
# rare labels in projectivized train # rare labels in projectivized train
rare_projectivized_labels = [] rare_projectivized_labels = []
for label in gold_train_data["deps"]: for label in gold_train_data["deps"]:
if gold_train_data["deps"][label] <= DEP_LABEL_THRESHOLD and DELIMITER in label: if (
gold_train_data["deps"][label] <= DEP_LABEL_THRESHOLD
and DELIMITER in label
):
rare_projectivized_labels.append( rare_projectivized_labels.append(
f"{label}: {gold_train_data['deps'][label]}" f"{label}: {gold_train_data['deps'][label]}"
) )

View File

@ -30,7 +30,11 @@ def info_cli(
def info( def info(
model: Optional[str] = None, *, markdown: bool = False, silent: bool = True, exclude: List[str] model: Optional[str] = None,
*,
markdown: bool = False,
silent: bool = True,
exclude: List[str],
) -> Union[str, dict]: ) -> Union[str, dict]:
msg = Printer(no_print=silent, pretty=not silent) msg = Printer(no_print=silent, pretty=not silent)
if model: if model:
@ -98,7 +102,9 @@ def info_model(model: str, *, silent: bool = True) -> Dict[str, Any]:
} }
def get_markdown(data: Dict[str, Any], title: Optional[str] = None, exclude: List[str] = None) -> str: def get_markdown(
data: Dict[str, Any], title: Optional[str] = None, exclude: List[str] = None
) -> str:
"""Get data in GitHub-flavoured Markdown format for issues etc. """Get data in GitHub-flavoured Markdown format for issues etc.
data (dict or list of tuples): Label/value pairs. data (dict or list of tuples): Label/value pairs.
@ -115,7 +121,7 @@ def get_markdown(data: Dict[str, Any], title: Optional[str] = None, exclude: Lis
if isinstance(value, str): if isinstance(value, str):
try: try:
existing_path = Path(value).exists() existing_path = Path(value).exists()
except: except Exception:
# invalid Path, like a URL string # invalid Path, like a URL string
existing_path = False existing_path = False
if existing_path: if existing_path:

View File

@ -36,43 +36,44 @@ _num_words = [
"ትሪሊዮን", "ትሪሊዮን",
"ኳድሪሊዮን", "ኳድሪሊዮን",
"ገጅሊዮን", "ገጅሊዮን",
"ባዝሊዮን" "ባዝሊዮን",
] ]
_ordinal_words = [ _ordinal_words = [
"አንደኛ", "አንደኛ",
"ሁለተኛ", "ሁለተኛ",
"ሶስተኛ", "ሶስተኛ",
"አራተኛ", "አራተኛ",
"አምስተኛ", "አምስተኛ",
"ስድስተኛ", "ስድስተኛ",
"ሰባተኛ", "ሰባተኛ",
"ስምንተኛ", "ስምንተኛ",
"ዘጠነኛ", "ዘጠነኛ",
"አስረኛ", "አስረኛ",
"አስራ አንደኛ", "አስራ አንደኛ",
"አስራ ሁለተኛ", "አስራ ሁለተኛ",
"አስራ ሶስተኛ", "አስራ ሶስተኛ",
"አስራ አራተኛ", "አስራ አራተኛ",
"አስራ አምስተኛ", "አስራ አምስተኛ",
"አስራ ስድስተኛ", "አስራ ስድስተኛ",
"አስራ ሰባተኛ", "አስራ ሰባተኛ",
"አስራ ስምንተኛ", "አስራ ስምንተኛ",
"አስራ ዘጠነኛ", "አስራ ዘጠነኛ",
"ሃያኛ", "ሃያኛ",
"ሰላሳኛ" "ሰላሳኛ" "አርባኛ",
"አርባኛ", "አምሳኛ",
"አምሳኛ", "ስድሳኛ",
"ስድሳኛ", "ሰባኛ",
"ሰባኛ", "ሰማንያኛ",
"ሰማንያኛ", "ዘጠናኛ",
"ዘጠናኛ", "መቶኛ",
"መቶኛ", "ሺኛ",
"ሺኛ", "ሚሊዮንኛ",
"ሚሊዮንኛ", "ቢሊዮንኛ",
"ቢሊዮንኛ", "ትሪሊዮንኛ",
"ትሪሊዮንኛ"
] ]
def like_num(text): def like_num(text):
if text.startswith(("+", "-", "±", "~")): if text.startswith(("+", "-", "±", "~")):
text = text[1:] text = text[1:]

View File

@ -7,7 +7,6 @@ _exc = {}
for exc_data in [ for exc_data in [
{ORTH: "ት/ቤት"}, {ORTH: "ት/ቤት"},
{ORTH: "ወ/ሮ", NORM: "ወይዘሮ"}, {ORTH: "ወ/ሮ", NORM: "ወይዘሮ"},
]: ]:
_exc[exc_data[ORTH]] = [exc_data] _exc[exc_data[ORTH]] = [exc_data]

View File

@ -1,4 +1,4 @@
from typing import Union, Iterator, Optional, List, Tuple from typing import Union, Iterator
from ...symbols import NOUN, PROPN, PRON, VERB, AUX from ...symbols import NOUN, PROPN, PRON, VERB, AUX
from ...errors import Errors from ...errors import Errors

View File

@ -36,43 +36,44 @@ _num_words = [
"ትሪልዮን", "ትሪልዮን",
"ኳድሪልዮን", "ኳድሪልዮን",
"ገጅልዮን", "ገጅልዮን",
"ባዝልዮን" "ባዝልዮን",
] ]
_ordinal_words = [ _ordinal_words = [
"ቀዳማይ", "ቀዳማይ",
"ካልኣይ", "ካልኣይ",
"ሳልሳይ", "ሳልሳይ",
"ራብኣይ", "ራብኣይ",
"ሓምሻይ", "ሓምሻይ",
"ሻድሻይ", "ሻድሻይ",
"ሻውዓይ", "ሻውዓይ",
"ሻምናይ", "ሻምናይ",
"ዘጠነኛ", "ዘጠነኛ",
"አስረኛ", "አስረኛ",
"ኣሰርተ አንደኛ", "ኣሰርተ አንደኛ",
"ኣሰርተ ሁለተኛ", "ኣሰርተ ሁለተኛ",
"ኣሰርተ ሶስተኛ", "ኣሰርተ ሶስተኛ",
"ኣሰርተ አራተኛ", "ኣሰርተ አራተኛ",
"ኣሰርተ አምስተኛ", "ኣሰርተ አምስተኛ",
"ኣሰርተ ስድስተኛ", "ኣሰርተ ስድስተኛ",
"ኣሰርተ ሰባተኛ", "ኣሰርተ ሰባተኛ",
"ኣሰርተ ስምንተኛ", "ኣሰርተ ስምንተኛ",
"ኣሰርተ ዘጠነኛ", "ኣሰርተ ዘጠነኛ",
"ሃያኛ", "ሃያኛ",
"ሰላሳኛ" "ሰላሳኛ" "አርባኛ",
"አርባኛ", "አምሳኛ",
"አምሳኛ", "ስድሳኛ",
"ስድሳኛ", "ሰባኛ",
"ሰባኛ", "ሰማንያኛ",
"ሰማንያኛ", "ዘጠናኛ",
"ዘጠናኛ", "መቶኛ",
"መቶኛ", "ሺኛ",
"ሺኛ", "ሚሊዮንኛ",
"ሚሊዮንኛ", "ቢሊዮንኛ",
"ቢሊዮንኛ", "ትሪሊዮንኛ",
"ትሪሊዮንኛ"
] ]
def like_num(text): def like_num(text):
if text.startswith(("+", "-", "±", "~")): if text.startswith(("+", "-", "±", "~")):
text = text[1:] text = text[1:]

View File

@ -8,7 +8,6 @@ for exc_data in [
{ORTH: "ት/ቤት"}, {ORTH: "ት/ቤት"},
{ORTH: "ወ/ሮ", NORM: "ወይዘሮ"}, {ORTH: "ወ/ሮ", NORM: "ወይዘሮ"},
{ORTH: "ወ/ሪ", NORM: "ወይዘሪት"}, {ORTH: "ወ/ሪ", NORM: "ወይዘሪት"},
]: ]:
_exc[exc_data[ORTH]] = [exc_data] _exc[exc_data[ORTH]] = [exc_data]

View File

@ -71,17 +71,19 @@ def build_text_classifier_v2(
exclusive_classes = not linear_model.attrs["multi_label"] exclusive_classes = not linear_model.attrs["multi_label"]
with Model.define_operators({">>": chain, "|": concatenate}): with Model.define_operators({">>": chain, "|": concatenate}):
width = tok2vec.maybe_get_dim("nO") width = tok2vec.maybe_get_dim("nO")
attention_layer = ParametricAttention(width) # TODO: benchmark performance difference of this layer attention_layer = ParametricAttention(
width
) # TODO: benchmark performance difference of this layer
maxout_layer = Maxout(nO=width, nI=width) maxout_layer = Maxout(nO=width, nI=width)
linear_layer = Linear(nO=nO, nI=width) linear_layer = Linear(nO=nO, nI=width)
cnn_model = ( cnn_model = (
tok2vec tok2vec
>> list2ragged() >> list2ragged()
>> attention_layer >> attention_layer
>> reduce_sum() >> reduce_sum()
>> residual(maxout_layer) >> residual(maxout_layer)
>> linear_layer >> linear_layer
>> Dropout(0.0) >> Dropout(0.0)
) )
nO_double = nO * 2 if nO else None nO_double = nO * 2 if nO else None

View File

@ -89,7 +89,7 @@ def build_hash_embed_cnn_tok2vec(
# TODO: archive # TODO: archive
@registry.architectures.register("spacy.Tok2Vec.v1") @registry.architectures.register("spacy.Tok2Vec.v1")
def build_Tok2Vec_model( def _build_Tok2Vec_model(
embed: Model[List[Doc], List[Floats2d]], embed: Model[List[Doc], List[Floats2d]],
encode: Model[List[Floats2d], List[Floats2d]], encode: Model[List[Floats2d], List[Floats2d]],
) -> Model[List[Doc], List[Floats2d]]: ) -> Model[List[Doc], List[Floats2d]]:
@ -109,7 +109,6 @@ def build_Tok2Vec_model(
return tok2vec return tok2vec
@registry.architectures.register("spacy.Tok2Vec.v2") @registry.architectures.register("spacy.Tok2Vec.v2")
def build_Tok2Vec_model( def build_Tok2Vec_model(
embed: Model[List[Doc], List[Floats2d]], embed: Model[List[Doc], List[Floats2d]],
@ -130,7 +129,6 @@ def build_Tok2Vec_model(
return tok2vec return tok2vec
@registry.architectures.register("spacy.MultiHashEmbed.v1") @registry.architectures.register("spacy.MultiHashEmbed.v1")
def MultiHashEmbed( def MultiHashEmbed(
width: int, width: int,
@ -280,7 +278,7 @@ def CharacterEmbed(
# TODO: archive # TODO: archive
@registry.architectures.register("spacy.MaxoutWindowEncoder.v1") @registry.architectures.register("spacy.MaxoutWindowEncoder.v1")
def MaxoutWindowEncoder( def _MaxoutWindowEncoder(
width: int, window_size: int, maxout_pieces: int, depth: int width: int, window_size: int, maxout_pieces: int, depth: int
) -> Model[List[Floats2d], List[Floats2d]]: ) -> Model[List[Floats2d], List[Floats2d]]:
"""Encode context using convolutions with maxout activation, layer """Encode context using convolutions with maxout activation, layer
@ -310,6 +308,7 @@ def MaxoutWindowEncoder(
model.attrs["receptive_field"] = window_size * depth model.attrs["receptive_field"] = window_size * depth
return model return model
@registry.architectures.register("spacy.MaxoutWindowEncoder.v2") @registry.architectures.register("spacy.MaxoutWindowEncoder.v2")
def MaxoutWindowEncoder( def MaxoutWindowEncoder(
width: int, window_size: int, maxout_pieces: int, depth: int width: int, window_size: int, maxout_pieces: int, depth: int
@ -344,7 +343,7 @@ def MaxoutWindowEncoder(
# TODO: archive # TODO: archive
@registry.architectures.register("spacy.MishWindowEncoder.v1") @registry.architectures.register("spacy.MishWindowEncoder.v1")
def MishWindowEncoder( def _MishWindowEncoder(
width: int, window_size: int, depth: int width: int, window_size: int, depth: int
) -> Model[List[Floats2d], List[Floats2d]]: ) -> Model[List[Floats2d], List[Floats2d]]:
"""Encode context using convolutions with mish activation, layer """Encode context using convolutions with mish activation, layer

View File

@ -388,7 +388,6 @@ class TextCategorizer(TrainablePipe):
**kwargs, **kwargs,
) )
def _validate_categories(self, examples: List[Example]): def _validate_categories(self, examples: List[Example]):
"""Check whether the provided examples all have single-label cats annotations.""" """Check whether the provided examples all have single-label cats annotations."""
for ex in examples: for ex in examples:

View File

@ -187,5 +187,5 @@ class MultiLabel_TextCategorizer(TextCategorizer):
def _validate_categories(self, examples: List[Example]): def _validate_categories(self, examples: List[Example]):
"""This component allows any type of single- or multi-label annotations. """This component allows any type of single- or multi-label annotations.
This method overwrites the more strict one from 'textcat'. """ This method overwrites the more strict one from 'textcat'."""
pass pass

View File

@ -28,10 +28,12 @@ def pytest_runtest_setup(item):
def tokenizer(): def tokenizer():
return get_lang_class("xx")().tokenizer return get_lang_class("xx")().tokenizer
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def am_tokenizer(): def am_tokenizer():
return get_lang_class("am")().tokenizer return get_lang_class("am")().tokenizer
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def ar_tokenizer(): def ar_tokenizer():
return get_lang_class("ar")().tokenizer return get_lang_class("ar")().tokenizer
@ -247,10 +249,12 @@ def th_tokenizer():
pytest.importorskip("pythainlp") pytest.importorskip("pythainlp")
return get_lang_class("th")().tokenizer return get_lang_class("th")().tokenizer
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def ti_tokenizer(): def ti_tokenizer():
return get_lang_class("ti")().tokenizer return get_lang_class("ti")().tokenizer
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def tr_tokenizer(): def tr_tokenizer():
return get_lang_class("tr")().tokenizer return get_lang_class("tr")().tokenizer

View File

@ -17,17 +17,8 @@ def test_graph_edges_and_nodes():
assert graph.get_node((0,)) == node1 assert graph.get_node((0,)) == node1
node2 = graph.add_node((1, 3)) node2 = graph.add_node((1, 3))
assert list(node2) == [1, 3] assert list(node2) == [1, 3]
graph.add_edge( graph.add_edge(node1, node2, label="one", weight=-10.5)
node1, assert graph.has_edge(node1, node2, label="one")
node2,
label="one",
weight=-10.5
)
assert graph.has_edge(
node1,
node2,
label="one"
)
assert node1.heads() == [] assert node1.heads() == []
assert [tuple(h) for h in node2.heads()] == [(0,)] assert [tuple(h) for h in node2.heads()] == [(0,)]
assert [tuple(t) for t in node1.tails()] == [(1, 3)] assert [tuple(t) for t in node1.tails()] == [(1, 3)]
@ -42,7 +33,7 @@ def test_graph_walk():
nodes=[(0,), (1,), (2,), (3,)], nodes=[(0,), (1,), (2,), (3,)],
edges=[(0, 1), (0, 2), (0, 3), (3, 0)], edges=[(0, 1), (0, 2), (0, 3), (3, 0)],
labels=None, labels=None,
weights=None weights=None,
) )
node0, node1, node2, node3 = list(graph.nodes) node0, node1, node2, node3 = list(graph.nodes)
assert [tuple(h) for h in node0.heads()] == [(3,)] assert [tuple(h) for h in node0.heads()] == [(3,)]

View File

@ -255,8 +255,8 @@ def test_token_api_non_conjuncts(en_vocab):
def test_missing_head_dep(en_vocab): def test_missing_head_dep(en_vocab):
""" Check that the Doc constructor and Example.from_dict parse missing information the same""" """ Check that the Doc constructor and Example.from_dict parse missing information the same"""
heads = [1, 1, 1, 1, 2, None] # element 5 is missing heads = [1, 1, 1, 1, 2, None] # element 5 is missing
deps = ["", "ROOT", "dobj", "cc", "conj", None] # element 0 and 5 are missing deps = ["", "ROOT", "dobj", "cc", "conj", None] # element 0 and 5 are missing
words = ["I", "like", "London", "and", "Berlin", "."] words = ["I", "like", "London", "and", "Berlin", "."]
doc = Doc(en_vocab, words=words, heads=heads, deps=deps) doc = Doc(en_vocab, words=words, heads=heads, deps=deps)
pred_has_heads = [t.has_head() for t in doc] pred_has_heads = [t.has_head() for t in doc]

View File

@ -1,5 +1,4 @@
import pytest import pytest
from spacy.lang.am.lex_attrs import like_num
def test_am_tokenizer_handles_long_text(am_tokenizer): def test_am_tokenizer_handles_long_text(am_tokenizer):

View File

@ -121,9 +121,7 @@ def test_en_tokenizer_norm_exceptions(en_tokenizer, text, norms):
assert [token.norm_ for token in tokens] == norms assert [token.norm_ for token in tokens] == norms
@pytest.mark.parametrize( @pytest.mark.parametrize("text,norm", [("Jan.", "January"), ("'cuz", "because")])
"text,norm", [("Jan.", "January"), ("'cuz", "because")]
)
def test_en_lex_attrs_norm_exceptions(en_tokenizer, text, norm): def test_en_lex_attrs_norm_exceptions(en_tokenizer, text, norm):
tokens = en_tokenizer(text) tokens = en_tokenizer(text)
assert tokens[0].norm_ == norm assert tokens[0].norm_ == norm

View File

@ -1,5 +1,4 @@
import pytest import pytest
from spacy.lang.ti.lex_attrs import like_num
def test_ti_tokenizer_handles_long_text(ti_tokenizer): def test_ti_tokenizer_handles_long_text(ti_tokenizer):

View File

@ -389,7 +389,7 @@ def test_beam_ner_scores():
for j in range(len(doc)): for j in range(len(doc)):
for label in ner.labels: for label in ner.labels:
score = entity_scores[(j, j+1, label)] score = entity_scores[(j, j + 1, label)]
eps = 0.00001 eps = 0.00001
assert 0 - eps <= score <= 1 + eps assert 0 - eps <= score <= 1 + eps

View File

@ -146,12 +146,12 @@ def test_no_resize(name):
def test_error_with_multi_labels(): def test_error_with_multi_labels():
nlp = Language() nlp = Language()
textcat = nlp.add_pipe("textcat") nlp.add_pipe("textcat")
train_examples = [] train_examples = []
for text, annotations in TRAIN_DATA_MULTI_LABEL: for text, annotations in TRAIN_DATA_MULTI_LABEL:
train_examples.append(Example.from_dict(nlp.make_doc(text), annotations)) train_examples.append(Example.from_dict(nlp.make_doc(text), annotations))
with pytest.raises(ValueError): with pytest.raises(ValueError):
optimizer = nlp.initialize(get_examples=lambda: train_examples) nlp.initialize(get_examples=lambda: train_examples)
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -226,7 +226,8 @@ def test_overfitting_IO():
assert_equal(batch_cats_1, no_batch_cats) assert_equal(batch_cats_1, no_batch_cats)
def test_overfitting_IO_multi(): @pytest.mark.skip(reason="TODO: Can this be removed?")
def test_overfitting_IO_multi_old():
# Simple test to try and quickly overfit the multi-label textcat component - ensuring the ML models work correctly # Simple test to try and quickly overfit the multi-label textcat component - ensuring the ML models work correctly
fix_random_seed(0) fix_random_seed(0)
nlp = English() nlp = English()
@ -362,7 +363,9 @@ def test_positive_class():
textcat_multilabel = nlp.add_pipe("textcat_multilabel") textcat_multilabel = nlp.add_pipe("textcat_multilabel")
get_examples = make_get_examples_multi_label(nlp) get_examples = make_get_examples_multi_label(nlp)
with pytest.raises(TypeError): with pytest.raises(TypeError):
textcat_multilabel.initialize(get_examples, labels=["POS", "NEG"], positive_label="POS") textcat_multilabel.initialize(
get_examples, labels=["POS", "NEG"], positive_label="POS"
)
textcat_multilabel.initialize(get_examples, labels=["FICTION", "DRAMA"]) textcat_multilabel.initialize(get_examples, labels=["FICTION", "DRAMA"])
assert textcat_multilabel.labels == ("FICTION", "DRAMA") assert textcat_multilabel.labels == ("FICTION", "DRAMA")
assert "positive_label" not in textcat_multilabel.cfg assert "positive_label" not in textcat_multilabel.cfg
@ -381,7 +384,9 @@ def test_positive_class_not_binary():
textcat = nlp.add_pipe("textcat") textcat = nlp.add_pipe("textcat")
get_examples = make_get_examples_multi_label(nlp) get_examples = make_get_examples_multi_label(nlp)
with pytest.raises(ValueError): with pytest.raises(ValueError):
textcat.initialize(get_examples, labels=["SOME", "THING", "POS"], positive_label="POS") textcat.initialize(
get_examples, labels=["SOME", "THING", "POS"], positive_label="POS"
)
def test_textcat_evaluation(): def test_textcat_evaluation():

View File

@ -13,7 +13,6 @@ from spacy.lang.el import Greek
from spacy.language import Language from spacy.language import Language
import spacy import spacy
from thinc.api import compounding from thinc.api import compounding
from collections import defaultdict
from ..util import make_tempdir from ..util import make_tempdir
@ -304,16 +303,14 @@ def test_issue4313():
doc = nlp("What do you think about Apple ?") doc = nlp("What do you think about Apple ?")
assert len(ner.labels) == 1 assert len(ner.labels) == 1
assert "SOME_LABEL" in ner.labels assert "SOME_LABEL" in ner.labels
ner.add_label("MY_ORG") # TODO: not sure if we want this to be necessary... ner.add_label("MY_ORG") # TODO: not sure if we want this to be necessary...
apple_ent = Span(doc, 5, 6, label="MY_ORG") apple_ent = Span(doc, 5, 6, label="MY_ORG")
doc.ents = list(doc.ents) + [apple_ent] doc.ents = list(doc.ents) + [apple_ent]
# ensure the beam_parse still works with the new label # ensure the beam_parse still works with the new label
docs = [doc] docs = [doc]
ner = nlp.get_pipe("beam_ner") ner = nlp.get_pipe("beam_ner")
beams = ner.beam_parse( ner.beam_parse(docs, drop=0.0, beam_width=beam_width, beam_density=beam_density)
docs, drop=0.0, beam_width=beam_width, beam_density=beam_density
)
def test_issue4348(): def test_issue4348():

View File

@ -251,7 +251,9 @@ def test_language_from_config_before_after_init():
nlp.initialize() nlp.initialize()
assert nlp.meta["before_init"] == "before" assert nlp.meta["before_init"] == "before"
assert nlp.meta["after_init"] == "after" assert nlp.meta["after_init"] == "after"
assert all([ran_before, ran_after, ran_after_pipeline, ran_before_init, ran_after_init]) assert all(
[ran_before, ran_after, ran_after_pipeline, ran_before_init, ran_after_init]
)
def test_language_from_config_before_after_init_invalid(): def test_language_from_config_before_after_init_invalid():

View File

@ -166,19 +166,11 @@ def test_Example_from_dict_with_entities(annots):
vocab = Vocab() vocab = Vocab()
predicted = Doc(vocab, words=annots["words"]) predicted = Doc(vocab, words=annots["words"])
example = Example.from_dict(predicted, annots) example = Example.from_dict(predicted, annots)
assert len(list(example.reference.ents)) == 2 assert len(list(example.reference.ents)) == 2
assert [example.reference[i].ent_iob_ for i in range(7)] == [ # fmt: off
"O", assert [example.reference[i].ent_iob_ for i in range(7)] == ["O", "O", "B", "I", "O", "B", "O"]
"O",
"B",
"I",
"O",
"B",
"O",
]
assert example.get_aligned("ENT_IOB") == [2, 2, 3, 1, 2, 3, 2] assert example.get_aligned("ENT_IOB") == [2, 2, 3, 1, 2, 3, 2]
# fmt: on
assert example.reference[2].ent_type_ == "LOC" assert example.reference[2].ent_type_ == "LOC"
assert example.reference[3].ent_type_ == "LOC" assert example.reference[3].ent_type_ == "LOC"
assert example.reference[5].ent_type_ == "LOC" assert example.reference[5].ent_type_ == "LOC"
@ -299,7 +291,8 @@ def test_Example_missing_heads():
assert parsed_heads[2] == heads[2] assert parsed_heads[2] == heads[2]
assert parsed_heads[4] == heads[4] assert parsed_heads[4] == heads[4]
assert parsed_heads[5] == heads[5] assert parsed_heads[5] == heads[5]
assert [t.has_head() for t in example.reference] == [True, True, True, False, True, True] expected = [True, True, True, False, True, True]
assert [t.has_head() for t in example.reference] == expected
# Ensure that the missing head doesn't create an artificial new sentence start # Ensure that the missing head doesn't create an artificial new sentence start
assert example.get_aligned_sent_starts() == [True, False, False, False, False, False] expected = [True, False, False, False, False, False]
assert example.get_aligned_sent_starts() == expected