diff --git a/main.py b/main.py index c8283fb..9109db8 100644 --- a/main.py +++ b/main.py @@ -110,7 +110,14 @@ def parse_args(): parser.add_argument("--top-p", type=float, default=0.95) parser.add_argument("--temperature", type=float, default=0.6) parser.add_argument("--global-topk", type=int) - parser.add_argument("--do-sample", type=bool, default=True) + parser.add_argument( + "--do-sample", + action=argparse.BooleanOptionalAction, + default=True, + help="Enable multinomial sampling on the on-device sampler. " + "Pass --no-do-sample to force greedy/argmax (required for " + "deterministic logit_validation accuracy checks).", + ) parser.add_argument("--dynamic", action="store_true") parser.add_argument("--pad-token-id", type=int, default=2) parser.add_argument("--top-k-kernel-enabled", action="store_true", default=False) @@ -154,7 +161,12 @@ def parse_args(): parser.add_argument("--on-device-sampling", action="store_true") # Bucketing - parser.add_argument("--enable-bucketing", type=bool, default=True) + parser.add_argument( + "--enable-bucketing", + action=argparse.BooleanOptionalAction, + default=True, + help="Enable bucketed compilation. Pass --no-enable-bucketing to disable.", + ) parser.add_argument("--bucket-n-active-tokens", action="store_true") parser.add_argument("--context-encoding-buckets", nargs="+", type=int) parser.add_argument("--prefix-buckets", nargs="+", type=int)