---
tests/test_integration.py | 22 ++++++++++++++++++++++
typed_flags/flags.py | 16 ++++++++++++----
typed_flags/special_types.py | 28 +++++++++++++++++++++++++++-
3 files changed, 61 insertions(+), 5 deletions(-)
diff --git a/tests/test_integration.py b/tests/test_integration.py
index ae58b8c..83e6bf3 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -340,6 +340,28 @@ class DefaultClassVariableTests(TestCase):
self.assertEqual(args.arg_dict_str_int, arg_dict_str_int)
self.assertEqual(args.arg_positive_int, arg_positive_int)
+ def test_set_null(self) -> None:
+ raw_args: List[str] = [
+ "--arg-optional-str",
+ "null",
+ "--arg-optional-int",
+ "null",
+ "--arg-optional-float",
+ "null",
+ "--arg-optional-str-literal",
+ "null",
+ ]
+ args = IntegrationDefaultFlags().parse_args(raw_args)
+
+ self.assertIsNone(args.arg_optional_str)
+ self.assertIsNone(args.arg_optional_int)
+ self.assertIsNone(args.arg_optional_float)
+ self.assertIsNone(args.arg_optional_str_literal)
+
+ def test_invalid_optional_value(self) -> None:
+ with self.assertRaises(SystemExit):
+ IntegrationDefaultFlags().parse_args(["--arg-optional-int", "not-an-int"])
+
class AddArgumentTests(TestCase):
def test_positional(self) -> None:
diff --git a/typed_flags/flags.py b/typed_flags/flags.py
index f6f91b0..a113316 100644
--- a/typed_flags/flags.py
+++ b/typed_flags/flags.py
@@ -27,7 +27,7 @@ from .utils import (
is_optional_type,
type_to_str,
)
-from .special_types import StoreDictKeyPair
+from .special_types import ParseOptional, StoreDictKeyPair
__all__ = ["TypedFlags"]
@@ -120,9 +120,6 @@ class TypedFlags(ArgumentParser):
# If type is not explicitly provided, set it if it's one of our supported default types
if "type" not in kwargs:
- if (optional_type := is_optional_type(var_type)) is not None:
- var_type = optional_type
-
origin: Optional[type] = get_origin(var_type)
# First check whether it is a literal type or a boxed literal type
@@ -150,6 +147,17 @@ class TypedFlags(ArgumentParser):
kwargs["key_type"] = key_type
kwargs["value_type"] = value_type
+ elif (optional_type := is_optional_type(var_type)) is not None:
+ inner_origin = get_origin(optional_type)
+ if inner_origin is Literal:
+ kwargs["choices"] = get_string_literals(optional_type, variable) + ["null"]
+ optional_type = str
+ elif inner_origin is not None:
+ raise ValueError(f"{optional_type} cannot be nested with Optional right now.")
+ assert optional_type != bool, "no optional bool" # type: ignore[comparison-overlap]
+ kwargs["action"] = ParseOptional
+ kwargs["type"] = str
+ kwargs["type_"] = optional_type
elif origin is None:
# If bool then set action, otherwise set type
if var_type == bool:
diff --git a/typed_flags/special_types.py b/typed_flags/special_types.py
index 76739cc..d2c3544 100644
--- a/typed_flags/special_types.py
+++ b/typed_flags/special_types.py
@@ -2,7 +2,7 @@
from argparse import Action
from typing import Any
-__all__ = ["StoreDictKeyPair"]
+__all__ = ["ParseOptional", "StoreDictKeyPair"]
class StoreDictKeyPair(Action):
@@ -21,3 +21,29 @@ class StoreDictKeyPair(Action):
key, value = key_value.split("=")
my_dict[self._key_type(key.strip())] = self._value_type(value.strip())
setattr(namespace, self.dest, my_dict)
+
+
+class ParseOptional(Action):
+ """Action for parsing optionals on the commandline.
+
+ The literal string "null" is stored as None; any other value is stripped and
+ converted with ``type_``. A failed conversion is raised as an ArgumentError so
+ argparse reports it like its own invalid-value errors.
+ """
+
+ def __init__(self, option_strings: Any, type_: type, *args: Any, **kwargs: Any):
+ self._type_ = type_
+ super().__init__(option_strings, *args, **kwargs)
+
+ def __call__(self, parser: Any, namespace: Any, values: Any, option_string: Any = None) -> None:
+ if values == "null":
+ value = None
+ else:
+ try:
+ value = self._type_(values.strip())
+ except (TypeError, ValueError) as err:
+ # argparse only turns ArgumentError (not ValueError) raised by an Action into a usage error.
+ from argparse import ArgumentError
+ type_name = getattr(self._type_, "__name__", repr(self._type_))
+ raise ArgumentError(self, f"invalid {type_name} value: {values!r}") from err
+ setattr(namespace, self.dest, value)
--
2.24.3 (Apple Git-128)
---
tests/test_integration.py | 136 +++++++++++++++++---------------------
1 file changed, 61 insertions(+), 75 deletions(-)
Note: test-only refactor. Expected values are now native Python values and argv is built by stringifying them (set-typed expectations are compared via set(...)); the old float() re-conversion of arg_optional_int is dropped.
diff --git a/tests/test_integration.py b/tests/test_integration.py
index 83e6bf3..1af5702 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -234,88 +234,74 @@ class DefaultClassVariableTests(TestCase):
def test_set_default_args(self) -> None:
arg_untyped = "yes"
arg_str = "goodbye"
- arg_int = "2"
- arg_float = "1e-2"
+ arg_int = 2
+ arg_float = 1e-2
arg_bool_true = False
arg_bool_false = True
arg_str_literal = "venus"
arg_optional_str = "hello"
- arg_optional_int = "77"
- arg_optional_float = "7.7"
+ arg_optional_int = 3
+ arg_optional_float = 7.7
arg_optional_str_literal = "spanish"
arg_list_str = ["hi", "there", "how", "are", "you"]
- arg_list_int = ["1", "2", "3", "10", "-11"]
- arg_list_float = ["2.2", "-3.3", "2e20"]
+ arg_list_int = [1, 2, 3, 10, -11]
+ arg_list_float = [2.2, -3.3, 2e20]
arg_list_str_empty = []
arg_list_str_literal = ["Li", "Be"]
arg_set_str = ["hi", "hi", "hi", "how"]
- arg_set_int = ["1", "2", "2", "2", "3"]
- arg_set_float = ["1.23", "4.4", "1.23"]
+ arg_set_int = [1, 2, 2, 2, 3]
+ arg_set_float = [1.23, 4.4, 1.23]
arg_set_str_literal = ["C", "He", "C"]
- arg_dict_str_int = ["me=2", "you=3", "they=-7"]
- arg_positive_int = "7"
-
- args = IntegrationDefaultFlags().parse_args(
- [
- "--arg-untyped",
- arg_untyped,
- "--arg-str",
- arg_str,
- "--arg-int",
- arg_int,
- "--arg-float",
- arg_float,
- "--arg-bool-true",
- str(arg_bool_true),
- "--arg-bool-false",
- str(arg_bool_false),
- "--arg-str-literal",
- arg_str_literal,
- "--arg-optional-str",
- arg_optional_str,
- "--arg-optional-int",
- arg_optional_int,
- "--arg-optional-float",
- arg_optional_float,
- "--arg-optional-str-literal",
- arg_optional_str_literal,
- "--arg-list-str",
- *arg_list_str,
- "--arg-list-int",
- *arg_list_int,
- "--arg-list-float",
- *arg_list_float,
- "--arg-list-str-empty",
- *arg_list_str_empty,
- "--arg-list-str-literal",
- *arg_list_str_literal,
- "--arg-set-str",
- *arg_set_str,
- "--arg-set-int",
- *arg_set_int,
- "--arg-set-float",
- *arg_set_float,
- "--arg-set-str-literal",
- *arg_set_str_literal,
- "--arg-dict-str-int",
- *arg_dict_str_int,
- "--arg-positive-int",
- *arg_positive_int,
- ]
- )
-
- arg_int = int(arg_int)
- arg_float = float(arg_float)
- arg_optional_int = float(arg_optional_int)
- arg_optional_float = float(arg_optional_float)
- arg_list_int = [int(arg) for arg in arg_list_int]
- arg_list_float = [float(arg) for arg in arg_list_float]
- arg_set_str = set(arg_set_str)
- arg_set_int = {int(arg) for arg in arg_set_int}
- arg_set_float = {float(arg) for arg in arg_set_float}
- arg_set_str_literal = set(arg_set_str_literal)
arg_dict_str_int = {"me": 2, "you": 3, "they": -7}
- arg_positive_int = int(arg_positive_int)
+ arg_positive_int = 7
+
+ raw_args: List[str] = [
+ "--arg-untyped",
+ arg_untyped,
+ "--arg-str",
+ arg_str,
+ "--arg-int",
+ str(arg_int),
+ "--arg-float",
+ str(arg_float),
+ "--arg-bool-true",
+ str(arg_bool_true),
+ "--arg-bool-false",
+ str(arg_bool_false),
+ "--arg-str-literal",
+ arg_str_literal,
+ "--arg-optional-str",
+ arg_optional_str,
+ "--arg-optional-int",
+ str(arg_optional_int),
+ "--arg-optional-float",
+ str(arg_optional_float),
+ "--arg-optional-str-literal",
+ arg_optional_str_literal,
+ "--arg-list-str",
+ *arg_list_str,
+ "--arg-list-int",
+ *map(str, arg_list_int),
+ "--arg-list-float",
+ *map(str, arg_list_float),
+ "--arg-list-str-empty",
+ *arg_list_str_empty,
+ "--arg-list-str-literal",
+ *arg_list_str_literal,
+ "--arg-set-str",
+ *arg_set_str,
+ "--arg-set-int",
+ *map(str, arg_set_int),
+ "--arg-set-float",
+ *map(str, arg_set_float),
+ "--arg-set-str-literal",
+ *arg_set_str_literal,
+ "--arg-dict-str-int",
+ *[f"{k}={v}" for k, v in arg_dict_str_int.items()],
+ "--arg-positive-int",
+ str(arg_positive_int),
+ ]
+ args = IntegrationDefaultFlags().parse_args(raw_args)
self.assertEqual(args.arg_untyped, arg_untyped)
self.assertEqual(args.arg_str, arg_str)
@@ -333,10 +319,10 @@ class DefaultClassVariableTests(TestCase):
self.assertEqual(args.arg_list_float, arg_list_float)
self.assertEqual(args.arg_list_str_empty, arg_list_str_empty)
self.assertEqual(args.arg_list_str_literal, arg_list_str_literal)
- self.assertEqual(args.arg_set_str, arg_set_str)
- self.assertEqual(args.arg_set_int, arg_set_int)
- self.assertEqual(args.arg_set_float, arg_set_float)
- self.assertEqual(args.arg_set_str_literal, arg_set_str_literal)
+ self.assertEqual(args.arg_set_str, set(arg_set_str))
+ self.assertEqual(args.arg_set_int, set(arg_set_int))
+ self.assertEqual(args.arg_set_float, set(arg_set_float))
+ self.assertEqual(args.arg_set_str_literal, set(arg_set_str_literal))
self.assertEqual(args.arg_dict_str_int, arg_dict_str_int)
self.assertEqual(args.arg_positive_int, arg_positive_int)
--
2.24.3 (Apple Git-128)
---
tests/test_integration.py | 10 +++-------
typed_flags/flags.py | 6 +++---
2 files changed, 6 insertions(+), 10 deletions(-)
Note: formatting only. Drops trailing commas inside single-line brackets, joins the wrapped parse_args calls onto one line, and changes whitespace in the as_dict docstring; no behavior change intended.
diff --git a/tests/test_integration.py b/tests/test_integration.py
index 1af5702..95fd6d6 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -85,19 +85,15 @@ class RequiredClassVariableTests(TestCase):
def test_arg_str_required(self):
with self.assertRaises(SystemExit):
- self.TypedFlags.parse_args(
- ["--arg-str-required", "tappy",]
- )
+ self.TypedFlags.parse_args(["--arg-str-required", "tappy"])
def test_arg_list_str_required(self):
with self.assertRaises(SystemExit):
- self.TypedFlags.parse_args(
- ["--arg-list-str-required", "hi", "there",]
- )
+ self.TypedFlags.parse_args(["--arg-list-str-required", "hi", "there"])
def test_both_assigned_okay(self):
args = self.TypedFlags.parse_args(
- ["--arg-str-required", "tappy", "--arg-list-str-required", "hi", "there",]
+ ["--arg-str-required", "tappy", "--arg-list-str-required", "hi", "there"]
)
self.assertEqual(args.arg_str_required, "tappy")
self.assertEqual(args.arg_list_str_required, ["hi", "there"])
diff --git a/typed_flags/flags.py b/typed_flags/flags.py
index a113316..056406d 100644
--- a/typed_flags/flags.py
+++ b/typed_flags/flags.py
@@ -285,7 +285,7 @@ class TypedFlags(ArgumentParser):
def _get_class_dict(self) -> Dict[str, Any]:
"""Return a dictionary mapping class variable names to values from the class dict."""
- class_dict = self._get_from_self_and_super(extract_func=vars, dict_type=dict,)
+ class_dict = self._get_from_self_and_super(extract_func=vars, dict_type=dict)
return {
var: val
for var, val in class_dict.items()
@@ -304,8 +304,8 @@ class TypedFlags(ArgumentParser):
def as_dict(self) -> Dict[str, Any]:
"""Returns the member variables corresponding to the class variable arguments.
- :return: A dictionary mapping each argument's name to its value.
- """
+ :return: A dictionary mapping each argument's name to its value.
+ """
if not self._parsed:
raise ValueError("You should call `parse_args` before retrieving arguments.")
--
2.24.3 (Apple Git-128)