More can be added easily. Thomas M (2): Add a new type `PositiveInt` for use with TypedFlags Add meta variables for float and int tests/test_integration.py | 18 +++++++++++++++-- typed_flags/__init__.py | 5 ++--- typed_flags/flags.py | 16 ++++++++++----- typed_flags/types.py | 42 +++++++++++++++++++++++++++++++++++++++ typed_flags/utils.py | 18 +---------------- 5 files changed, 72 insertions(+), 27 deletions(-) create mode 100644 typed_flags/types.py -- 2.24.2 (Apple Git-127)
Copy & paste the following snippet into your terminal to import this patchset into git:
curl -s https://lists.sr.ht/~tmk/typed-flags/patches/11208/mbox | git am -3Learn more about email & git
The new type asserts that the given integer is positive. A new file was created for this: types.py. --- tests/test_integration.py | 18 +++++++++++++++-- typed_flags/__init__.py | 5 ++--- typed_flags/flags.py | 12 ++++++----- typed_flags/types.py | 42 +++++++++++++++++++++++++++++++++++++++ typed_flags/utils.py | 18 +---------------- 5 files changed, 68 insertions(+), 27 deletions(-) create mode 100644 typed_flags/types.py diff --git a/tests/test_integration.py b/tests/test_integration.py index 11481dd..ade6fc2 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -3,7 +3,7 @@ import unittest from typing import List, Literal, Optional, Set, Dict from unittest import TestCase -from typed_flags import TypedFlags +from typed_flags import TypedFlags, PositiveInt class EdgeCaseTests(TestCase): @@ -100,6 +100,14 @@ class RequiredClassVariableTests(TestCase): self.assertEqual(args.arg_str_required, "tappy") self.assertEqual(args.arg_list_str_required, ["hi", "there"]) + def test_constraint_violation(self) -> None: + class ConstraintViolation(TypedFlags): + pos_num: PositiveInt + + args = ConstraintViolation() + with self.assertRaises(SystemExit): + args.parse_args(["--pos-num", "0"]) + def tearDown(self) -> None: sys.stderr = self.stderr @@ -139,6 +147,7 @@ class IntegrationDefaultFlags(TypedFlags): arg_set_float: Set[float] = {3.14, 6.28} arg_set_str_literal: Set[Literal["H", "He", "Li", "Be", "B", "C"]] = {"H", "He"} arg_dict_str_int: Dict[str, int] = {} + arg_positive_int: PositiveInt = 3 class SubclassTests(TestCase): @@ -218,6 +227,7 @@ class DefaultClassVariableTests(TestCase): self.assertEqual(args.arg_set_float, {3.14, 6.28}) self.assertEqual(args.arg_set_str_literal, {"H", "He"}) self.assertEqual(args.arg_dict_str_int, {}) + self.assertEqual(args.arg_positive_int, 3) def test_set_default_args(self) -> None: arg_untyped = "yes" @@ -241,6 +251,7 @@ class DefaultClassVariableTests(TestCase): arg_set_float = ["1.23", "4.4", "1.23"] arg_set_str_literal = ["C", "He", "C"] arg_dict_str_int = ["me=2", "you=3", "they=-7"] + arg_positive_int = "7" args = IntegrationDefaultFlags().parse_args( [ @@ -286,6 +297,8 @@ class DefaultClassVariableTests(TestCase): *arg_set_str_literal, "--arg-dict-str-int", *arg_dict_str_int, + "--arg-positive-int", + *arg_positive_int, ] ) @@ -300,12 +313,12 @@ class DefaultClassVariableTests(TestCase): arg_set_float = {float(arg) for arg in arg_set_float} arg_set_str_literal = set(arg_set_str_literal) arg_dict_str_int = {"me": 2, "you": 3, "they": -7} + arg_positive_int = int(arg_positive_int) self.assertEqual(args.arg_untyped, arg_untyped) self.assertEqual(args.arg_str, arg_str) self.assertEqual(args.arg_int, arg_int) self.assertEqual(args.arg_float, arg_float) - # Note: setting the bools as flags results in the opposite of their default self.assertEqual(args.arg_bool_true, arg_bool_true) self.assertEqual(args.arg_bool_false, arg_bool_false) self.assertEqual(args.arg_str_literal, arg_str_literal) @@ -323,6 +336,7 @@ class DefaultClassVariableTests(TestCase): self.assertEqual(args.arg_set_float, arg_set_float) self.assertEqual(args.arg_set_str_literal, arg_set_str_literal) self.assertEqual(args.arg_dict_str_int, arg_dict_str_int) + self.assertEqual(args.arg_positive_int, arg_positive_int) class AddArgumentTests(TestCase): diff --git a/typed_flags/__init__.py b/typed_flags/__init__.py index ee6bbf0..1e7a3e9 100644 --- a/typed_flags/__init__.py +++ b/typed_flags/__init__.py @@ -1,3 +1,2 @@ -from .flags import TypedFlags - -__all__ = ["TypedFlags"] +from .flags import * +from .types import * diff --git a/typed_flags/flags.py b/typed_flags/flags.py index 2f44ede..73264e1 100644 --- a/typed_flags/flags.py +++ b/typed_flags/flags.py @@ -20,7 +20,6 @@ from typing import ( ) from .utils import ( - StoreDictKeyPair, get_dest, get_string_literals, is_literal_type, @@ -29,6 +28,9 @@ from .utils import ( is_union_type, type_to_str, ) +from .types import StoreDictKeyPair, maybe_convert + +__all__ = ["TypedFlags"] TFType = TypeVar("TFType", bound="TypedFlags") _DictType = TypeVar("_DictType", Dict[str, Any], "OrderedDict[str, Any]") @@ -138,7 +140,7 @@ class TypedFlags(ArgumentParser): kwargs["nargs"] = kwargs.get("nargs", "*") else: # If List type, extract type of elements in list and set nargs - kwargs["type"] = arg + kwargs["type"] = maybe_convert(arg) kwargs["nargs"] = kwargs.get("nargs", "*") elif origin is dict: @@ -146,8 +148,8 @@ class TypedFlags(ArgumentParser): kwargs["action"] = StoreDictKeyPair kwargs["nargs"] = kwargs.get("nargs", "*") kwargs["type"] = str - kwargs["key_type"] = key_type - kwargs["value_type"] = value_type + kwargs["key_type"] = maybe_convert(key_type) + kwargs["value_type"] = maybe_convert(value_type) elif origin is None: # If bool then set action, otherwise set type @@ -155,7 +157,7 @@ class TypedFlags(ArgumentParser): kwargs["type"] = eval kwargs["choices"] = [True, False] else: - kwargs["type"] = var_type + kwargs["type"] = maybe_convert(var_type) else: raise ValueError( f'Variable "{variable}" has type "{var_type}" which is not supported by default.\n' diff --git a/typed_flags/types.py b/typed_flags/types.py new file mode 100644 index 0000000..cc9b4a9 --- /dev/null +++ b/typed_flags/types.py @@ -0,0 +1,42 @@ +"""Custom types for use with TypedFlags.""" +from argparse import Action, ArgumentTypeError +from typing import Any, Callable, NewType, Mapping, Union + +__all__ = ["StoreDictKeyPair", "PositiveInt", "maybe_convert"] + + +class StoreDictKeyPair(Action): + """Action for parsing dictionaries on the commandline.""" + + def __init__( + self, option_strings: Any, key_type: type, value_type: type, *args: Any, **kwargs: Any + ): + self._key_type = key_type + self._value_type = value_type + super().__init__(option_strings, *args, **kwargs) + + def __call__(self, parser: Any, namespace: Any, values: Any, option_string: Any = None) -> None: + my_dict = {} + for kv in values: + k, v = kv.split("=") + my_dict[self._key_type(k.strip())] = self._value_type(v.strip()) + setattr(namespace, self.dest, my_dict) + + +PositiveInt = NewType("PositiveInt", int) + + +def _to_positive_int(arg: str) -> PositiveInt: + num = int(arg) + if num > 0: + return PositiveInt(num) + raise ArgumentTypeError(f"{num} is not positive") + + +TYPE_TO_CONSTRUCTOR: Mapping[type, Callable[[str], Any]] = { + PositiveInt: _to_positive_int, +} + + +def maybe_convert(arg_type: type) -> Union[type, Callable]: + return TYPE_TO_CONSTRUCTOR.get(arg_type, arg_type) diff --git a/typed_flags/utils.py b/typed_flags/utils.py index f403f01..402b3cd 100644 --- a/typed_flags/utils.py +++ b/typed_flags/utils.py @@ -1,4 +1,4 @@ -from argparse import ArgumentParser, Action +from argparse import ArgumentParser from typing import Any, List, Literal, Optional, Union, get_args, get_origin @@ -72,19 +72,3 @@ def is_optional_type(tp: type) -> Optional[type]: elif args[1] is type(None): # noqa return args[0] return None - - -class StoreDictKeyPair(Action): - def __init__( - self, option_strings: Any, key_type: type, value_type: type, *args: Any, **kwargs: Any - ): - self._key_type = key_type - self._value_type = value_type - super().__init__(option_strings, *args, **kwargs) - - def __call__(self, parser: Any, namespace: Any, values: Any, option_string: Any = None) -> None: - my_dict = {} - for kv in values: - k, v = kv.split("=") - my_dict[self._key_type(k.strip())] = self._value_type(v.strip()) - setattr(namespace, self.dest, my_dict) -- 2.24.2 (Apple Git-127)
--- typed_flags/flags.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/typed_flags/flags.py b/typed_flags/flags.py index 73264e1..c74eff1 100644 --- a/typed_flags/flags.py +++ b/typed_flags/flags.py @@ -158,6 +158,10 @@ class TypedFlags(ArgumentParser): kwargs["choices"] = [True, False] else: kwargs["type"] = maybe_convert(var_type) + if var_type == float: + kwargs["metavar"] = "FLOAT" + elif var_type == int: + kwargs["metavar"] = "INT" else: raise ValueError( f'Variable "{variable}" has type "{var_type}" which is not supported by default.\n' -- 2.24.2 (Apple Git-127)