Refactor util.to_ternary_int (#7944)

* Refactor to avoid literal comparison with `is`
* Extend tests
This commit is contained in:
Adriane Boyd 2021-04-29 16:58:54 +02:00 committed by GitHub
parent 49aed683cc
commit 7cf5bd072f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 24 additions and 4 deletions

View File

@ -8,6 +8,7 @@ from spacy import prefer_gpu, require_gpu, require_cpu
from spacy.ml._precomputable_affine import PrecomputableAffine from spacy.ml._precomputable_affine import PrecomputableAffine
from spacy.ml._precomputable_affine import _backprop_precomputable_affine_padding from spacy.ml._precomputable_affine import _backprop_precomputable_affine_padding
from spacy.util import dot_to_object, SimpleFrozenList, import_file from spacy.util import dot_to_object, SimpleFrozenList, import_file
from spacy.util import to_ternary_int
from thinc.api import Config, Optimizer, ConfigValidationError, get_current_ops from thinc.api import Config, Optimizer, ConfigValidationError, get_current_ops
from thinc.api import set_current_ops from thinc.api import set_current_ops
from spacy.training.batchers import minibatch_by_words from spacy.training.batchers import minibatch_by_words
@ -386,3 +387,18 @@ def make_dummy_component(
nlp = English.from_config(config) nlp = English.from_config(config)
nlp.add_pipe("dummy_component") nlp.add_pipe("dummy_component")
nlp.initialize() nlp.initialize()
def test_to_ternary_int():
assert to_ternary_int(True) == 1
assert to_ternary_int(None) == 0
assert to_ternary_int(False) == -1
assert to_ternary_int(1) == 1
assert to_ternary_int(1.0) == 1
assert to_ternary_int(0) == 0
assert to_ternary_int(0.0) == 0
assert to_ternary_int(-1) == -1
assert to_ternary_int(5) == -1
assert to_ternary_int(-10) == -1
assert to_ternary_int("string") == -1
assert to_ternary_int([0, "string"]) == -1

View File

@ -1533,11 +1533,15 @@ def to_ternary_int(val) -> int:
attributes such as SENT_START: True/1/1.0 is 1 (True), None/0/0.0 is 0 attributes such as SENT_START: True/1/1.0 is 1 (True), None/0/0.0 is 0
(None), any other values are -1 (False). (None), any other values are -1 (False).
""" """
if isinstance(val, float): if val is True:
val = int(val)
if val is True or val is 1:
return 1 return 1
elif val is None or val is 0: elif val is None:
return 0
elif val is False:
return -1
elif val == 1:
return 1
elif val == 0:
return 0 return 0
else: else:
return -1 return -1