diff --git a/python-stdlib/argparse/argparse.py b/python-stdlib/argparse/argparse.py index 5c92887f9..3df696edc 100644 --- a/python-stdlib/argparse/argparse.py +++ b/python-stdlib/argparse/argparse.py @@ -10,8 +10,12 @@ class _ArgError(BaseException): pass +class ArgumentTypeError(BaseException): + pass + + class _Arg: - def __init__(self, names, dest, action, nargs, const, default, help): + def __init__(self, names, dest, action, nargs, const, default, help, type): self.names = names self.dest = dest self.action = action @@ -19,20 +23,28 @@ def __init__(self, names, dest, action, nargs, const, default, help): self.const = const self.default = default self.help = help + self.type = type + + def _apply(self, optname, arg): + if self.type: + try: + return self.type(arg) + except Exception as e: + if isinstance(e, (ArgumentTypeError, TypeError, ValueError)): + raise _ArgError("invalid value for %s: %s (%s)" % (optname, arg, str(e))) + raise + return arg def parse(self, optname, args): # parse args for this arg if self.action == "store": if self.nargs is None: if args: - return args.pop(0) + return self._apply(optname, args.pop(0)) else: raise _ArgError("expecting value for %s" % optname) elif self.nargs == "?": - if args: - return args.pop(0) - else: - return self.default + return self._apply(optname, args.pop(0) if args else self.default) else: if self.nargs == "*": n = -1 @@ -52,7 +64,7 @@ def parse(self, optname, args): else: break else: - ret.append(args.pop(0)) + ret.append(self._apply(optname, args.pop(0))) n -= 1 if n > 0: raise _ArgError("expecting value for %s" % optname) @@ -103,6 +115,10 @@ def add_argument(self, *args, **kwargs): dest = args[0] if not args: args = [dest] + arg_type = kwargs.get("type", None) + if arg_type is not None: + if not callable(arg_type): + raise ValueError("type is not callable") list.append( _Arg( args, @@ -112,6 +128,7 @@ def add_argument(self, *args, **kwargs): const, default, kwargs.get("help", ""), + arg_type, ) ) @@ -176,7 +193,9 @@ def _parse_args(self, args, return_unknown): arg_vals = [] for opt in self.opt: arg_dest.append(opt.dest) - arg_vals.append(opt.default) + arg_vals.append( + opt._apply(opt.dest, opt.default) if isinstance(opt.default, str) else opt.default + ) # deal with unknown arguments, if needed unknown = [] diff --git a/python-stdlib/argparse/manifest.py b/python-stdlib/argparse/manifest.py index 02bf1a22c..a6952c917 100644 --- a/python-stdlib/argparse/manifest.py +++ b/python-stdlib/argparse/manifest.py @@ -1,4 +1,4 @@ -metadata(version="0.4.0") +metadata(version="0.4.1") # Originally written by Damien George. diff --git a/python-stdlib/argparse/test_argparse.py b/python-stdlib/argparse/test_argparse.py index d86e53211..7548d4a92 100644 --- a/python-stdlib/argparse/test_argparse.py +++ b/python-stdlib/argparse/test_argparse.py @@ -66,3 +66,47 @@ args, rest = parser.parse_known_args(["a", "b", "c", "-b", "2", "--x", "5", "1"]) assert args.a == ["a", "b"] and args.b == "2" assert rest == ["c", "--x", "5", "1"] + + +class CustomArgType: + def __init__(self, add): + self.add = add + + def __call__(self, value): + return int(value) + self.add + + +parser = argparse.ArgumentParser() +parser.add_argument("-a", type=int) +args = parser.parse_args(["-a", "123"]) +assert args.a == 123 +parser.add_argument("-b", type=str) +args = parser.parse_args(["-b", "string"]) +assert args.b == "string" +parser.add_argument("-c", type=CustomArgType(1)) +args = parser.parse_args(["-c", "123"]) +assert args.c == 124 +try: + parser.add_argument("-d", type=()) + assert False +except ValueError as e: + assert "not callable" in str(e) +parser.add_argument("-d", type=int, nargs="+") +args = parser.parse_args(["-d", "123", "124", "125"]) +assert args.d == [123, 124, 125] +parser.add_argument("-e", type=CustomArgType(1), nargs="+") +args = parser.parse_args(["-e", "123", "124", "125"]) +assert args.e == [124, 125, 126] +parser.add_argument("-f", type=CustomArgType(1), nargs="?") +args = parser.parse_args(["-f", "123"]) +assert args.f == 124 +parser.add_argument("-g", type=CustomArgType(1), default=1) +parser.add_argument("-i", type=CustomArgType(1), default="1") +args = parser.parse_args([]) +assert args.g == 1 +assert args.i == 2 +parser.add_argument("-j", type=CustomArgType(1), default=1) +args = parser.parse_args(["-j", "3"]) +assert args.g == 1 +assert args.i == 2 +assert args.j == 4