~tmk/typed-flags

typed-flags: Add flags with dictionary type v2 APPLIED

Thomas M: 1
 Add flags with dictionary type

 4 files changed, 61 insertions(+), 7 deletions(-)
Export patchset (mbox)
How do I use this?

Copy & paste the following snippet into your terminal to import this patchset into git:

curl -s https://lists.sr.ht/~tmk/typed-flags/patches/10763/mbox | git am -3
Learn more about email & git

[PATCH typed-flags v2] Add flags with dictionary type Export this patch

Forgot to add the new yaml file!

---
 tests/flags2.yaml         |  6 ++++++
 tests/test_integration.py | 20 +++++++++++++++++++-
 typed_flags/flags.py      | 24 +++++++++++++++++++-----
 typed_flags/utils.py      | 18 +++++++++++++++++-
 4 files changed, 61 insertions(+), 7 deletions(-)
 create mode 100644 tests/flags2.yaml

diff --git a/tests/flags2.yaml b/tests/flags2.yaml
new file mode 100644
index 0000000..1f284b0
--- /dev/null
+++ b/tests/flags2.yaml
@@ -0,0 +1,6 @@
numbers:
my_dict:
  one: 1.0
  quarter: 0.25
  half: 0.5
  minus_two: -2
diff --git a/tests/test_integration.py b/tests/test_integration.py
index aa098cc..11481dd 100644
--- a/tests/test_integration.py
+++ b/tests/test_integration.py
@@ -1,6 +1,6 @@
import sys
import unittest
from typing import List, Literal, Optional, Set
from typing import List, Literal, Optional, Set, Dict
from unittest import TestCase

from typed_flags import TypedFlags
@@ -138,6 +138,7 @@ class IntegrationDefaultFlags(TypedFlags):
    arg_set_int: Set[int] = {10, -11}
    arg_set_float: Set[float] = {3.14, 6.28}
    arg_set_str_literal: Set[Literal["H", "He", "Li", "Be", "B", "C"]] = {"H", "He"}
    arg_dict_str_int: Dict[str, int] = {}


class SubclassTests(TestCase):
@@ -180,6 +181,17 @@ class YamlConfigTests(TestCase):
        self.assertEqual(args.arg_str_literal, "venus")
        self.assertEqual(args.arg_list_int, [10, 11, 12])

    def test_yaml_edge_cases(self) -> None:
        class _DifficultFlags(TypedFlags):
            numbers: List[int] = [2, 3]
            my_dict: Dict[str, float] = {}

        flags = _DifficultFlags(fromfile_prefix_chars="@").parse_args(["@tests/flags2.yaml"])
        self.assertListEqual(flags.numbers, [])
        self.assertDictEqual(
            flags.my_dict, {"one": 1.0, "quarter": 0.25, "half": 0.5, "minus_two": -2.0}
        )


class DefaultClassVariableTests(TestCase):
    def test_get_default_args(self) -> None:
@@ -205,6 +217,7 @@ class DefaultClassVariableTests(TestCase):
        self.assertEqual(args.arg_set_int, {10, -11})
        self.assertEqual(args.arg_set_float, {3.14, 6.28})
        self.assertEqual(args.arg_set_str_literal, {"H", "He"})
        self.assertEqual(args.arg_dict_str_int, {})

    def test_set_default_args(self) -> None:
        arg_untyped = "yes"
@@ -227,6 +240,7 @@ class DefaultClassVariableTests(TestCase):
        arg_set_int = ["1", "2", "2", "2", "3"]
        arg_set_float = ["1.23", "4.4", "1.23"]
        arg_set_str_literal = ["C", "He", "C"]
        arg_dict_str_int = ["me=2", "you=3", "they=-7"]

        args = IntegrationDefaultFlags().parse_args(
            [
@@ -270,6 +284,8 @@ class DefaultClassVariableTests(TestCase):
                *arg_set_float,
                "--arg-set-str-literal",
                *arg_set_str_literal,
                "--arg-dict-str-int",
                *arg_dict_str_int,
            ]
        )

@@ -283,6 +299,7 @@ class DefaultClassVariableTests(TestCase):
        arg_set_int = {int(arg) for arg in arg_set_int}
        arg_set_float = {float(arg) for arg in arg_set_float}
        arg_set_str_literal = set(arg_set_str_literal)
        arg_dict_str_int = {"me": 2, "you": 3, "they": -7}

        self.assertEqual(args.arg_untyped, arg_untyped)
        self.assertEqual(args.arg_str, arg_str)
@@ -305,6 +322,7 @@ class DefaultClassVariableTests(TestCase):
        self.assertEqual(args.arg_set_int, arg_set_int)
        self.assertEqual(args.arg_set_float, arg_set_float)
        self.assertEqual(args.arg_set_str_literal, arg_set_str_literal)
        self.assertEqual(args.arg_dict_str_int, arg_dict_str_int)


class AddArgumentTests(TestCase):
diff --git a/typed_flags/flags.py b/typed_flags/flags.py
index 698a256..2f44ede 100644
--- a/typed_flags/flags.py
+++ b/typed_flags/flags.py
@@ -20,6 +20,7 @@ from typing import (
)

from .utils import (
    StoreDictKeyPair,
    get_dest,
    get_string_literals,
    is_literal_type,
@@ -127,22 +128,30 @@ class TypedFlags(ArgumentParser):
                if origin is Literal:  # type: ignore[comparison-overlap]
                    kwargs["choices"] = get_string_literals(var_type, variable)
                    var_type = str

                elif origin in (list, set):
                    arg = get_args(var_type)[0]
                    if is_literal_type(arg):
                        # unpack the outer type and then the literal
                        kwargs["choices"] = get_string_literals(get_args(var_type)[0], variable)
                        var_type = str
                        kwargs["type"] = str
                        kwargs["nargs"] = kwargs.get("nargs", "*")
                    else:
                        # If List type, extract type of elements in list and set nargs
                        kwargs["type"] = arg
                        kwargs["nargs"] = kwargs.get("nargs", "*")

                elif origin is dict:
                    key_type, value_type = get_args(var_type)
                    kwargs["action"] = StoreDictKeyPair
                    kwargs["nargs"] = kwargs.get("nargs", "*")
                    kwargs["type"] = str
                    kwargs["key_type"] = key_type
                    kwargs["value_type"] = value_type

                elif origin is None:
                    # If bool then set action, otherwise set type
                    if var_type == bool:
                        # kwargs['action'] = kwargs.get('action', f'store_{"true" if kwargs["required"] or not kwargs["default"] else "false"}')
                        kwargs["type"] = eval
                        kwargs["choices"] = [True, False]
                    else:
@@ -298,7 +307,7 @@ class TypedFlags(ArgumentParser):
        formatted = []
        for k in sorted(args_dict):
            v = args_dict[k]
            formatted.append(f'{k}="{v}"' if isinstance(args_dict[k], str) else f"{k}={v}")
            formatted.append(f'{k}="{v}"' if isinstance(v, str) else f"{k}={v}")
        return "Namespace(" + ", ".join(formatted) + ")"

    def convert_arg_line_to_args(self, arg_line: str) -> List[str]:
@@ -307,9 +316,14 @@ class TypedFlags(ArgumentParser):
        if not arg_line.strip():  # empty line
            return []
        key, value = arg_line.split(sep=":", maxsplit=1)
        key = key.strip()
        key = key.rstrip()
        value = value.strip()
        if value[0] == '"' and value[-1] == '"':  # if wrapped in quotes, don't split further
        if key[0] == " ":
            key = key.strip()
            return [f"{key}={value}"]
        if not value:  # no associated value
            values = []
        elif value[0] == '"' and value[-1] == '"':  # if wrapped in quotes, don't split further
            values = [value[1:-1]]
        else:
            values = value.split()
diff --git a/typed_flags/utils.py b/typed_flags/utils.py
index 402b3cd..f403f01 100644
--- a/typed_flags/utils.py
+++ b/typed_flags/utils.py
@@ -1,4 +1,4 @@
from argparse import ArgumentParser
from argparse import ArgumentParser, Action
from typing import Any, List, Literal, Optional, Union, get_args, get_origin


@@ -72,3 +72,19 @@ def is_optional_type(tp: type) -> Optional[type]:
        elif args[1] is type(None):  # noqa
            return args[0]
    return None


class StoreDictKeyPair(Action):
    def __init__(
        self, option_strings: Any, key_type: type, value_type: type, *args: Any, **kwargs: Any
    ):
        self._key_type = key_type
        self._value_type = value_type
        super().__init__(option_strings, *args, **kwargs)

    def __call__(self, parser: Any, namespace: Any, values: Any, option_string: Any = None) -> None:
        my_dict = {}
        for kv in values:
            k, v = kv.split("=")
            my_dict[self._key_type(k.strip())] = self._value_type(v.strip())
        setattr(namespace, self.dest, my_dict)
-- 
2.24.2 (Apple Git-127)
Thanks!
Here we could also test the dictionary argument.