Thomas M: 1 Add flags with dictionary type 3 files changed, 55 insertions(+), 7 deletions(-)
Copy & paste the following snippet into your terminal to import this patchset into git:
curl -s https://lists.sr.ht/~tmk/typed-flags/patches/10762/mbox | git am -3Learn more about email & git
Also some other changes. --- tests/test_integration.py | 20 +++++++++++++++++++- typed_flags/flags.py | 24 +++++++++++++++++++----- typed_flags/utils.py | 18 +++++++++++++++++- 3 files changed, 55 insertions(+), 7 deletions(-) diff --git a/tests/test_integration.py b/tests/test_integration.py index aa098cc..11481dd 100644 --- a/tests/test_integration.py +++ b/tests/test_integration.py @@ -1,6 +1,6 @@ import sys import unittest -from typing import List, Literal, Optional, Set +from typing import List, Literal, Optional, Set, Dict from unittest import TestCase from typed_flags import TypedFlags @@ -138,6 +138,7 @@ class IntegrationDefaultFlags(TypedFlags): arg_set_int: Set[int] = {10, -11} arg_set_float: Set[float] = {3.14, 6.28} arg_set_str_literal: Set[Literal["H", "He", "Li", "Be", "B", "C"]] = {"H", "He"} + arg_dict_str_int: Dict[str, int] = {} class SubclassTests(TestCase): @@ -180,6 +181,17 @@ class YamlConfigTests(TestCase): self.assertEqual(args.arg_str_literal, "venus") self.assertEqual(args.arg_list_int, [10, 11, 12]) + def test_yaml_edge_cases(self) -> None: + class _DifficultFlags(TypedFlags): + numbers: List[int] = [2, 3] + my_dict: Dict[str, float] = {} + + flags = _DifficultFlags(fromfile_prefix_chars="@").parse_args(["@tests/flags2.yaml"]) + self.assertListEqual(flags.numbers, []) + self.assertDictEqual( + flags.my_dict, {"one": 1.0, "quarter": 0.25, "half": 0.5, "minus_two": -2.0} + ) + class DefaultClassVariableTests(TestCase): def test_get_default_args(self) -> None: @@ -205,6 +217,7 @@ class DefaultClassVariableTests(TestCase): self.assertEqual(args.arg_set_int, {10, -11}) self.assertEqual(args.arg_set_float, {3.14, 6.28}) self.assertEqual(args.arg_set_str_literal, {"H", "He"}) + self.assertEqual(args.arg_dict_str_int, {}) def test_set_default_args(self) -> None: arg_untyped = "yes" @@ -227,6 +240,7 @@ class DefaultClassVariableTests(TestCase): arg_set_int = ["1", "2", "2", "2", "3"] arg_set_float = ["1.23", "4.4", "1.23"] arg_set_str_literal = ["C", "He", "C"] + arg_dict_str_int = ["me=2", "you=3", "they=-7"] args = IntegrationDefaultFlags().parse_args( [ @@ -270,6 +284,8 @@ class DefaultClassVariableTests(TestCase): *arg_set_float, "--arg-set-str-literal", *arg_set_str_literal, + "--arg-dict-str-int", + *arg_dict_str_int, ] ) @@ -283,6 +299,7 @@ class DefaultClassVariableTests(TestCase): arg_set_int = {int(arg) for arg in arg_set_int} arg_set_float = {float(arg) for arg in arg_set_float} arg_set_str_literal = set(arg_set_str_literal) + arg_dict_str_int = {"me": 2, "you": 3, "they": -7} self.assertEqual(args.arg_untyped, arg_untyped) self.assertEqual(args.arg_str, arg_str) @@ -305,6 +322,7 @@ class DefaultClassVariableTests(TestCase): self.assertEqual(args.arg_set_int, arg_set_int) self.assertEqual(args.arg_set_float, arg_set_float) self.assertEqual(args.arg_set_str_literal, arg_set_str_literal) + self.assertEqual(args.arg_dict_str_int, arg_dict_str_int) class AddArgumentTests(TestCase): diff --git a/typed_flags/flags.py b/typed_flags/flags.py index 698a256..2f44ede 100644 --- a/typed_flags/flags.py +++ b/typed_flags/flags.py @@ -20,6 +20,7 @@ from typing import ( ) from .utils import ( + StoreDictKeyPair, get_dest, get_string_literals, is_literal_type, @@ -127,22 +128,30 @@ class TypedFlags(ArgumentParser): if origin is Literal: # type: ignore[comparison-overlap] kwargs["choices"] = get_string_literals(var_type, variable) var_type = str + elif origin in (list, set): arg = get_args(var_type)[0] if is_literal_type(arg): # unpack the outer type and then the literal kwargs["choices"] = get_string_literals(get_args(var_type)[0], variable) - var_type = str + kwargs["type"] = str kwargs["nargs"] = kwargs.get("nargs", "*") else: # If List type, extract type of elements in list and set nargs kwargs["type"] = arg kwargs["nargs"] = kwargs.get("nargs", "*") + elif origin is dict: + key_type, value_type = get_args(var_type) + kwargs["action"] = StoreDictKeyPair + kwargs["nargs"] = kwargs.get("nargs", "*") + kwargs["type"] = str + kwargs["key_type"] = key_type + kwargs["value_type"] = value_type + elif origin is None: # If bool then set action, otherwise set type if var_type == bool: - # kwargs['action'] = kwargs.get('action', f'store_{"true" if kwargs["required"] or not kwargs["default"] else "false"}') kwargs["type"] = eval kwargs["choices"] = [True, False] else: @@ -298,7 +307,7 @@ class TypedFlags(ArgumentParser): formatted = [] for k in sorted(args_dict): v = args_dict[k] - formatted.append(f'{k}="{v}"' if isinstance(args_dict[k], str) else f"{k}={v}") + formatted.append(f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}") return "Namespace(" + ", ".join(formatted) + ")" def convert_arg_line_to_args(self, arg_line: str) -> List[str]: @@ -307,9 +316,14 @@ class TypedFlags(ArgumentParser): if not arg_line.strip(): # empty line return [] key, value = arg_line.split(sep=":", maxsplit=1) - key = key.strip() + key = key.rstrip() value = value.strip() - if value[0] == '"' and value[-1] == '"': # if wrapped in quotes, don't split further + if key[0] == " ": + key = key.strip() + return [f"{key}={value}"] + if not value: # no associated value + values = [] + elif value[0] == '"' and value[-1] == '"': # if wrapped in quotes, don't split further values = [value[1:-1]] else: values = value.split() diff --git a/typed_flags/utils.py b/typed_flags/utils.py index 402b3cd..f403f01 100644 --- a/typed_flags/utils.py +++ b/typed_flags/utils.py @@ -1,4 +1,4 @@ -from argparse import ArgumentParser +from argparse import ArgumentParser, Action from typing import Any, List, Literal, Optional, Union, get_args, get_origin @@ -72,3 +72,19 @@ def is_optional_type(tp: type) -> Optional[type]: elif args[1] is type(None): # noqa return args[0] return None + + +class StoreDictKeyPair(Action): + def __init__( + self, option_strings: Any, key_type: type, value_type: type, *args: Any, **kwargs: Any + ): + self._key_type = key_type + self._value_type = value_type + super().__init__(option_strings, *args, **kwargs) + + def __call__(self, parser: Any, namespace: Any, values: Any, option_string: Any = None) -> None: + my_dict = {} + for kv in values: + k, v = kv.split("=") + my_dict[self._key_type(k.strip())] = self._value_type(v.strip()) + setattr(namespace, self.dest, my_dict) -- 2.24.2 (Apple Git-127)