Before, `Optional` could only be used to have None when no flag is specified. But sometimes you want to set something explicitly to `None`. Now you can do that by using "null". Thomas M (4): Handle tabs in yaml files correctly Change how Optional types work Change how the main test works Format with black tests/test_integration.py | 164 +++++++++++++++++------------------ typed_flags/flags.py | 24 +++-- typed_flags/special_types.py | 17 +++- 3 files changed, 114 insertions(+), 91 deletions(-) -- 2.24.3 (Apple Git-128)
Copy & paste the following snippet into your terminal to import this patchset into git:
curl -s https://lists.sr.ht/~tmk/typed-flags/patches/14036/mbox | git am -3Learn more about email & git
--- typed_flags/flags.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/typed_flags/flags.py b/typed_flags/flags.py index a0069cb..f6f91b0 100644 --- a/typed_flags/flags.py +++ b/typed_flags/flags.py @@ -324,7 +324,7 @@ class TypedFlags(ArgumentParser): key, value = arg_line.split(sep=":", maxsplit=1) key = key.rstrip() value = value.strip() - if key[0] == " ": + if key[0] in (" ", "\t"): # this line is indented key = key.strip() return [f"{key}={value}"] if not value: # no associated value -- 2.24.3 (Apple Git-128)
--- tests/test_integration.py | 18 ++++++++++++++++++ typed_flags/flags.py | 16 ++++++++++++---- typed_flags/special_types.py | 17 ++++++++++++++++- 3 files changed, 46 insertions(+), 5 deletions(-) diff --git a/tests/test_integration.py b/tests/test_integration.py index ae58b8c..83e6bf3 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -340,6 +340,24 @@ class DefaultClassVariableTests(TestCase): self.assertEqual(args.arg_dict_str_int, arg_dict_str_int) self.assertEqual(args.arg_positive_int, arg_positive_int) + def test_set_null(self) -> None: + raw_args: List[str] = [ + "--arg-optional-str", + "null", + "--arg-optional-int", + "null", + "--arg-optional-float", + "null", + "--arg-optional-str-literal", + "null", + ] + args = IntegrationDefaultFlags().parse_args(raw_args) + + self.assertTrue(args.arg_optional_str is None) + self.assertTrue(args.arg_optional_int is None) + self.assertTrue(args.arg_optional_float is None) + self.assertTrue(args.arg_optional_str_literal is None) + class AddArgumentTests(TestCase): def test_positional(self) -> None: diff --git a/typed_flags/flags.py b/typed_flags/flags.py index f6f91b0..a113316 100644 --- a/typed_flags/flags.py +++ b/typed_flags/flags.py @@ -27,7 +27,7 @@ from .utils import ( is_optional_type, type_to_str, ) -from .special_types import StoreDictKeyPair +from .special_types import ParseOptional, StoreDictKeyPair __all__ = ["TypedFlags"] @@ -120,9 +120,6 @@ class TypedFlags(ArgumentParser): # If type is not explicitly provided, set it if it's one of our supported default types if "type" not in kwargs: - if (optional_type := is_optional_type(var_type)) is not None: - var_type = optional_type - origin: Optional[type] = get_origin(var_type) # First check whether it is a literal type or a boxed literal type @@ -150,6 +147,17 @@ class TypedFlags(ArgumentParser): kwargs["key_type"] = key_type kwargs["value_type"] = value_type + elif (optional_type := is_optional_type(var_type)) is not None: + inner_origin = get_origin(optional_type) + if inner_origin is Literal: + kwargs["choices"] = get_string_literals(optional_type, variable) + ["null"] + optional_type = str + elif inner_origin is not None: + raise ValueError(f"{origin} cannot be nested with Optional right now.") + assert optional_type != bool, "no optional bool" # type: ignore[comparison-overlap] + kwargs["action"] = ParseOptional + kwargs["type"] = str + kwargs["type_"] = optional_type elif origin is None: # If bool then set action, otherwise set type if var_type == bool: diff --git a/typed_flags/special_types.py b/typed_flags/special_types.py index 76739cc..d2c3544 100644 --- a/typed_flags/special_types.py +++ b/typed_flags/special_types.py @@ -2,7 +2,7 @@ from argparse import Action from typing import Any -__all__ = ["StoreDictKeyPair"] +__all__ = ["ParseOptional", "StoreDictKeyPair"] class StoreDictKeyPair(Action): @@ -21,3 +21,18 @@ class StoreDictKeyPair(Action): key, value = key_value.split("=") my_dict[self._key_type(key.strip())] = self._value_type(value.strip()) setattr(namespace, self.dest, my_dict) + + +class ParseOptional(Action): + """Action for parsing optionals on the commandline.""" + + def __init__(self, option_strings: Any, type_: type, *args: Any, **kwargs: Any): + self._type_ = type_ + super().__init__(option_strings, *args, **kwargs) + + def __call__(self, parser: Any, namespace: Any, values: Any, option_string: Any = None) -> None: + if values == "null": + value = None + else: + value = self._type_(values.strip()) + setattr(namespace, self.dest, value) -- 2.24.3 (Apple Git-128)
--- tests/test_integration.py | 136 +++++++++++++++++--------------------- 1 file changed, 61 insertions(+), 75 deletions(-) diff --git a/tests/test_integration.py b/tests/test_integration.py index 83e6bf3..1af5702 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -234,88 +234,74 @@ class DefaultClassVariableTests(TestCase): def test_set_default_args(self) -> None: arg_untyped = "yes" arg_str = "goodbye" - arg_int = "2" - arg_float = "1e-2" + arg_int = 2 + arg_float = 1e-2 arg_bool_true = False arg_bool_false = True arg_str_literal = "venus" arg_optional_str = "hello" - arg_optional_int = "77" - arg_optional_float = "7.7" + arg_optional_int = 3 + arg_optional_float = 7.7 arg_optional_str_literal = "spanish" arg_list_str = ["hi", "there", "how", "are", "you"] - arg_list_int = ["1", "2", "3", "10", "-11"] - arg_list_float = ["2.2", "-3.3", "2e20"] + arg_list_int = [1, 2, 3, 10, -11] + arg_list_float = [2.2, -3.3, 2e20] arg_list_str_empty = [] arg_list_str_literal = ["Li", "Be"] arg_set_str = ["hi", "hi", "hi", "how"] - arg_set_int = ["1", "2", "2", "2", "3"] - arg_set_float = ["1.23", "4.4", "1.23"] + arg_set_int = [1, 2, 2, 2, 3] + arg_set_float = [1.23, 4.4, 1.23] arg_set_str_literal = ["C", "He", "C"] - arg_dict_str_int = ["me=2", "you=3", "they=-7"] - arg_positive_int = "7" - - args = IntegrationDefaultFlags().parse_args( - [ - "--arg-untyped", - arg_untyped, - "--arg-str", - arg_str, - "--arg-int", - arg_int, - "--arg-float", - arg_float, - "--arg-bool-true", - str(arg_bool_true), - "--arg-bool-false", - str(arg_bool_false), - "--arg-str-literal", - arg_str_literal, - "--arg-optional-str", - arg_optional_str, - "--arg-optional-int", - arg_optional_int, - "--arg-optional-float", - arg_optional_float, - "--arg-optional-str-literal", - arg_optional_str_literal, - "--arg-list-str", - *arg_list_str, - "--arg-list-int", - *arg_list_int, - "--arg-list-float", - *arg_list_float, - "--arg-list-str-empty", - *arg_list_str_empty, - "--arg-list-str-literal", - *arg_list_str_literal, - "--arg-set-str", - *arg_set_str, - "--arg-set-int", - *arg_set_int, - "--arg-set-float", - *arg_set_float, - "--arg-set-str-literal", - *arg_set_str_literal, - "--arg-dict-str-int", - *arg_dict_str_int, - "--arg-positive-int", - *arg_positive_int, - ] - ) - - arg_int = int(arg_int) - arg_float = float(arg_float) - arg_optional_int = float(arg_optional_int) - arg_optional_float = float(arg_optional_float) - arg_list_int = [int(arg) for arg in arg_list_int] - arg_list_float = [float(arg) for arg in arg_list_float] - arg_set_str = set(arg_set_str) - arg_set_int = {int(arg) for arg in arg_set_int} - arg_set_float = {float(arg) for arg in arg_set_float} - arg_set_str_literal = set(arg_set_str_literal) arg_dict_str_int = {"me": 2, "you": 3, "they": -7} - arg_positive_int = int(arg_positive_int) + arg_positive_int = 7 + + raw_args: List[str] = [ + "--arg-untyped", + arg_untyped, + "--arg-str", + arg_str, + "--arg-int", + str(arg_int), + "--arg-float", + str(arg_float), + "--arg-bool-true", + str(arg_bool_true), + "--arg-bool-false", + str(arg_bool_false), + "--arg-str-literal", + arg_str_literal, + "--arg-optional-str", + arg_optional_str, + "--arg-optional-int", + str(arg_optional_int), + "--arg-optional-float", + str(arg_optional_float), + "--arg-optional-str-literal", + arg_optional_str_literal, + "--arg-list-str", + *arg_list_str, + "--arg-list-int", + *map(str, arg_list_int), + "--arg-list-float", + *map(str, arg_list_float), + "--arg-list-str-empty", + *arg_list_str_empty, + "--arg-list-str-literal", + *arg_list_str_literal, + "--arg-set-str", + *arg_set_str, + "--arg-set-int", + *map(str, arg_set_int), + "--arg-set-float", + *map(str, arg_set_float), + "--arg-set-str-literal", + *arg_set_str_literal, + "--arg-dict-str-int", + *[f"{k}={v}" for k, v in arg_dict_str_int.items()], + "--arg-positive-int", + str(arg_positive_int), + ] + args = IntegrationDefaultFlags().parse_args(raw_args) self.assertEqual(args.arg_untyped, arg_untyped) self.assertEqual(args.arg_str, arg_str) @@ -333,10 +319,10 @@ class DefaultClassVariableTests(TestCase): self.assertEqual(args.arg_list_float, arg_list_float) self.assertEqual(args.arg_list_str_empty, arg_list_str_empty) self.assertEqual(args.arg_list_str_literal, arg_list_str_literal) - self.assertEqual(args.arg_set_str, arg_set_str) - self.assertEqual(args.arg_set_int, arg_set_int) - self.assertEqual(args.arg_set_float, arg_set_float) - self.assertEqual(args.arg_set_str_literal, arg_set_str_literal) + self.assertEqual(args.arg_set_str, set(arg_set_str)) + self.assertEqual(args.arg_set_int, set(arg_set_int)) + self.assertEqual(args.arg_set_float, set(arg_set_float)) + self.assertEqual(args.arg_set_str_literal, set(arg_set_str_literal)) self.assertEqual(args.arg_dict_str_int, arg_dict_str_int) self.assertEqual(args.arg_positive_int, arg_positive_int) -- 2.24.3 (Apple Git-128)
--- tests/test_integration.py | 10 +++------- typed_flags/flags.py | 6 +++--- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/tests/test_integration.py b/tests/test_integration.py index 1af5702..95fd6d6 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -85,19 +85,15 @@ class RequiredClassVariableTests(TestCase): def test_arg_str_required(self): with self.assertRaises(SystemExit): - self.TypedFlags.parse_args( - ["--arg-str-required", "tappy",] - ) + self.TypedFlags.parse_args(["--arg-str-required", "tappy"]) def test_arg_list_str_required(self): with self.assertRaises(SystemExit): - self.TypedFlags.parse_args( - ["--arg-list-str-required", "hi", "there",] - ) + self.TypedFlags.parse_args(["--arg-list-str-required", "hi", "there"]) def test_both_assigned_okay(self): args = self.TypedFlags.parse_args( - ["--arg-str-required", "tappy", "--arg-list-str-required", "hi", "there",] + ["--arg-str-required", "tappy", "--arg-list-str-required", "hi", "there"] ) self.assertEqual(args.arg_str_required, "tappy") self.assertEqual(args.arg_list_str_required, ["hi", "there"]) diff --git a/typed_flags/flags.py b/typed_flags/flags.py index a113316..056406d 100644 --- a/typed_flags/flags.py +++ b/typed_flags/flags.py @@ -285,7 +285,7 @@ class TypedFlags(ArgumentParser): def _get_class_dict(self) -> Dict[str, Any]: """Return a dictionary mapping class variable names to values from the class dict.""" - class_dict = self._get_from_self_and_super(extract_func=vars, dict_type=dict,) + class_dict = self._get_from_self_and_super(extract_func=vars, dict_type=dict) return { var: val for var, val in class_dict.items() @@ -304,8 +304,8 @@ class TypedFlags(ArgumentParser): def as_dict(self) -> Dict[str, Any]: """Returns the member variables corresponding to the class variable arguments. - :return: A dictionary mapping each argument's name to its value. - """ + :return: A dictionary mapping each argument's name to its value. + """ if not self._parsed: raise ValueError("You should call `parse_args` before retrieving arguments.") -- 2.24.3 (Apple Git-128)