From e8e0ac60e227d194eff1e21f9505e3f252f175e0 Mon Sep 17 00:00:00 2001 From: Ryan Fowler Date: Wed, 11 Feb 2026 00:36:40 -0800 Subject: [PATCH] Redesign CLI parsing subsystem to reduce boilerplate and improve readability - Add core.CutTrimmed to deduplicate cut() between cli and config packages - Consolidate error types (MissingEnvVarError, fileIsDirError) into errors.go - Extract parseURL to deduplicate URL normalization between ArgFn and applyFromCurl - Replace dataSet/jsonSet/xmlSet booleans with bodySource enum - Create flags.go with declarative helpers (boolFlag, ptrBoolFlag, stringFlag, cfgFlag) - Extract 15 flag handler methods and buildAWSConfig to deduplicate AWS config construction - Rewrite CLI() using helpers, reducing it from ~1,000 to ~230 lines - Unify validation with declarative SchemeExclusiveFlags and FromCurlExclusiveFlags --- internal/cli/app.go | 1414 +++++++++++-------------------------- internal/cli/cli.go | 161 ++--- internal/cli/errors.go | 47 ++ internal/cli/flags.go | 100 +++ internal/config/config.go | 10 +- internal/core/core.go | 9 + 6 files changed, 631 insertions(+), 1110 deletions(-) create mode 100644 internal/cli/flags.go diff --git a/internal/cli/app.go b/internal/cli/app.go index c740731..0f890a4 100644 --- a/internal/cli/app.go +++ b/internal/cli/app.go @@ -51,7 +51,9 @@ type App struct { Update bool Version bool - dataSet, jsonSet, xmlSet bool + dataSet bool + jsonSet bool + xmlSet bool } func (a *App) PrintHelp(p *core.Printer) { @@ -80,36 +82,12 @@ func (a *App) CLI() *CLI { if a.URL != nil { return fmt.Errorf("unexpected argument: %q", s) } - if s == "" { - return fmt.Errorf("empty URL provided") - } - - // For URLs that have the scheme omitted, add two - // slashes so it can be parsed correctly. - if !strings.Contains(s, "://") && s[0] != '/' { - s = "//" + s - } - - u, err := url.Parse(s) + u, isWS, err := parseURL(s) if err != nil { - return fmt.Errorf("invalid url: %w", err) + return err } - - // Lowercase the scheme, and validate. - u.Scheme = strings.ToLower(u.Scheme) - switch u.Scheme { - case "", "http", "https": - case "ws": - u.Scheme = "http" - a.WS = true - case "wss": - u.Scheme = "https" - a.WS = true - default: - return fmt.Errorf("unsupported url scheme: %s", u.Scheme) - } - a.URL = u + a.WS = a.WS || isWS return nil }, ExclusiveFlags: [][]string{ @@ -128,993 +106,509 @@ func (a *App) CLI() *CLI { {Key: "proto-import", Val: []string{"proto-file"}}, {Key: "remote-header-name", Val: []string{"remote-name"}}, }, + SchemeExclusiveFlags: map[string][]string{ + "ws": {"discard", "grpc", "form", "multipart", "xml", "edit"}, + "wss": {"discard", "grpc", "form", "multipart", "xml", "edit"}, + }, + FromCurlExclusiveFlags: []string{ + "method", "header", "data", "json", "xml", + "form", "multipart", "basic", "bearer", "aws-sigv4", + "output", "remote-name", "remote-header-name", + "range", "unix", "timeout", "connect-timeout", + "redirects", "proxy", "insecure", "tls", "http", + "cert", "key", "ca-cert", "dns-server", + "retry", "retry-delay", "grpc", "query", + }, Flags: []Flag{ + // cfgFlag: delegates to config parser + cfgFlag("auto-update", "", "(ENABLED|INTERVAL)", "Enable/disable auto-updates", + func() bool { return a.Cfg.AutoUpdate != nil }, a.Cfg.ParseAutoUpdate). + WithHidden(true), + + // Custom: AWS signature V4 with env var lookups { - Short: "", - Long: "auto-update", - Args: "(ENABLED|INTERVAL)", - IsHidden: true, - Description: "Enable/disable auto-updates", - Default: "", - IsSet: func() bool { - return a.Cfg.AutoUpdate != nil - }, - Fn: func(value string) error { - return a.Cfg.ParseAutoUpdate(value) - }, - }, - { - Short: "", Long: "aws-sigv4", Args: "REGION/SERVICE", Description: "Sign the request using AWS signature V4", - Default: "", - IsSet: func() bool { - return a.AWSSigv4 != nil - }, - Fn: func(value string) error { - region, service, ok := cut(value, "/") - if !ok { - const usage = "format must be " - return core.NewValueError("aws-sigv4", value, usage, false) - } - - accessKey := os.Getenv("AWS_ACCESS_KEY_ID") - if accessKey == "" { - return missingEnvVarErr("AWS_ACCESS_KEY_ID", "aws-sigv4") - } - secretKey := os.Getenv("AWS_SECRET_ACCESS_KEY") - if secretKey == "" { - return missingEnvVarErr("AWS_SECRET_ACCESS_KEY", "aws-sigv4") - } - - a.AWSSigv4 = &aws.Config{ - Region: region, - Service: service, - AccessKey: accessKey, - SecretKey: secretKey, - } - return nil - }, + IsSet: func() bool { return a.AWSSigv4 != nil }, + Fn: a.parseAWSSigv4Flag, }, + + // Custom: basic auth parsing { - Short: "", Long: "basic", Args: "USER:PASS", Description: "Enable HTTP basic authentication", - Default: "", - IsSet: func() bool { - return a.Basic != nil - }, - Fn: func(value string) error { - user, pass, ok := cut(value, ":") - if !ok { - const usage = "format must be " - return core.NewValueError("basic", value, usage, false) - } - a.Basic = &core.KeyVal[string]{Key: user, Val: pass} - return nil - }, - }, - { - Short: "", - Long: "bearer", - Args: "TOKEN", - Description: "Enable HTTP bearer authentication", - Default: "", - IsSet: func() bool { - return a.Bearer != "" - }, - Fn: func(value string) error { - a.Bearer = value - return nil - }, - }, - { - Short: "", - Long: "buildinfo", - Args: "", - Description: "Print the build information", - Default: "", - IsSet: func() bool { - return a.BuildInfo - }, - Fn: func(value string) error { - a.BuildInfo = true - return nil - }, - }, - { - Short: "", - Long: "ca-cert", - Args: "PATH", - Description: "CA certificate file path", - Default: "", - IsSet: func() bool { - return len(a.Cfg.CACerts) > 0 - }, - Fn: func(value string) error { - return a.Cfg.ParseCACerts(value) - }, + IsSet: func() bool { return a.Basic != nil }, + Fn: a.parseBasicFlag, }, + + // stringFlag: simple string value + stringFlag(&a.Bearer, "bearer", "", "TOKEN", "Enable HTTP bearer authentication"), + boolFlag(&a.BuildInfo, "buildinfo", "", "Print the build information"), + + // cfgFlag: delegates to config parser + cfgFlag("ca-cert", "", "PATH", "CA certificate file path", + func() bool { return len(a.Cfg.CACerts) > 0 }, a.Cfg.ParseCACerts), + + // Custom: file check + config parse { - Short: "", Long: "cert", Args: "PATH", Description: "Client certificate for mTLS", - Default: "", - IsSet: func() bool { - return a.Cfg.CertPath != "" - }, - Fn: func(value string) error { - if err := checkFileExists(value); err != nil { - return err - } - return a.Cfg.ParseCert(value) - }, - }, - { - Short: "", - Long: "clobber", - Args: "", - Description: "Overwrite existing output file", - Default: "", - IsSet: func() bool { - return a.Clobber - }, - Fn: func(value string) error { - a.Clobber = true - return nil - }, - }, - { - Short: "", - Long: "color", - Args: "OPTION", - Description: "Enable/disable color", - Default: "", - Aliases: []string{"colour"}, - Values: []core.KeyVal[string]{ - { - Key: "auto", - Val: "Automatically determine color", - }, - { - Key: "off", - Val: "Disable color output", - }, - { - Key: "on", - Val: "Enable color output", - }, - }, - IsSet: func() bool { - return a.Cfg.Color != core.ColorUnknown - }, - Fn: func(value string) error { - return a.Cfg.ParseColor(value) - }, - }, - { - Short: "", - Long: "complete", - Args: "SHELL", - Description: "Output shell completion", - Default: "", - Values: []core.KeyVal[string]{ - {Key: "bash"}, - {Key: "fish"}, - {Key: "zsh"}, - }, - HideValues: true, - IsSet: func() bool { - return a.Complete != "" - }, - Fn: func(value string) error { - a.Complete = value - return nil - }, - }, - { - Short: "c", - Long: "config", - Args: "PATH", - Description: "Path to config file", - Default: "", - IsSet: func() bool { - return a.ConfigPath != "" - }, - Fn: func(value string) error { - a.ConfigPath = value - return nil - }, - }, - { - Short: "", - Long: "connect-timeout", - Args: "SECONDS", - Description: "Timeout for connection establishment", - Default: "", - IsSet: func() bool { - return a.Cfg.ConnectTimeout != nil - }, - Fn: func(value string) error { - return a.Cfg.ParseConnectTimeout(value) - }, - }, - { - Short: "", - Long: "copy", - Args: "", - Description: "Copy the response body to clipboard", - Default: "", - IsSet: func() bool { - return a.Cfg.Copy != nil - }, - Fn: func(value string) error { - v := true - a.Cfg.Copy = &v - return nil - }, + IsSet: func() bool { return a.Cfg.CertPath != "" }, + Fn: a.parseCertFlag, }, + + boolFlag(&a.Clobber, "clobber", "", "Overwrite existing output file"), + + cfgFlag("color", "", "OPTION", "Enable/disable color", + func() bool { return a.Cfg.Color != core.ColorUnknown }, a.Cfg.ParseColor). + WithAliases("colour"). + WithValues([]core.KeyVal[string]{ + {Key: "auto", Val: "Automatically determine color"}, + {Key: "off", Val: "Disable color output"}, + {Key: "on", Val: "Enable color output"}, + }), + + stringFlag(&a.Complete, "complete", "", "SHELL", "Output shell completion"). + WithValues([]core.KeyVal[string]{ + {Key: "bash"}, {Key: "fish"}, {Key: "zsh"}, + }). + WithHideValues(), + + stringFlag(&a.ConfigPath, "config", "c", "PATH", "Path to config file"), + + cfgFlag("connect-timeout", "", "SECONDS", "Timeout for connection establishment", + func() bool { return a.Cfg.ConnectTimeout != nil }, a.Cfg.ParseConnectTimeout), + + ptrBoolFlag(&a.Cfg.Copy, "copy", "", "Copy the response body to clipboard"), + + // Custom: data with content type detection { Short: "d", Long: "data", Args: "[@]VALUE", Description: "Send a request body", - Default: "", - IsSet: func() bool { - return a.dataSet - }, - Fn: func(value string) error { - r, path, err := RequestBody(value) - if err != nil { - return err - } - a.Data, a.ContentType, err = core.DetectContentType(r, path) - if err != nil { - return err - } - a.dataSet = true - return nil - }, - }, - { - Short: "", - Long: "discard", - Args: "", - Description: "Discard the response body", - Default: "", - IsSet: func() bool { - return a.Discard - }, - Fn: func(value string) error { - a.Discard = true - return nil - }, - }, - { - Short: "", - Long: "dns-server", - Args: "IP[:PORT]|URL", - Description: "DNS server IP or DoH URL", - Default: "", - IsSet: func() bool { - return a.Cfg.DNSServer != nil - }, - Fn: func(value string) error { - return a.Cfg.ParseDNSServer(value) - }, - }, - { - Short: "", - Long: "dry-run", - Args: "", - Description: "Print out the request info and exit", - Default: "", - IsSet: func() bool { - return a.DryRun - }, - Fn: func(value string) error { - a.DryRun = true - return nil - }, - }, - { - Short: "e", - Long: "edit", - Args: "", - Description: "Use an editor to modify the request body", - Default: "", - IsSet: func() bool { - return a.Edit - }, - Fn: func(value string) error { - a.Edit = true - return nil - }, + IsSet: func() bool { return a.dataSet }, + Fn: a.parseDataFlag, }, + + boolFlag(&a.Discard, "discard", "", "Discard the response body"), + + cfgFlag("dns-server", "", "IP[:PORT]|URL", "DNS server IP or DoH URL", + func() bool { return a.Cfg.DNSServer != nil }, a.Cfg.ParseDNSServer), + + boolFlag(&a.DryRun, "dry-run", "", "Print out the request info and exit"), + boolFlag(&a.Edit, "edit", "e", "Use an editor to modify the request body"), + + // Custom: form key=value parsing { Short: "f", Long: "form", Args: "KEY=VALUE", Description: "Send a urlencoded form body", - Default: "", - IsSet: func() bool { - return len(a.Form) > 0 - }, - Fn: func(value string) error { - key, val, _ := cut(value, "=") - a.Form = append(a.Form, core.KeyVal[string]{Key: key, Val: val}) - return nil - }, - }, - { - Short: "", - Long: "format", - Args: "OPTION", - Description: "Enable/disable formatting", - Default: "", - Values: []core.KeyVal[string]{ - { - Key: "auto", - Val: "Automatically determine whether to format", - }, - { - Key: "off", - Val: "Disable output formatting", - }, - { - Key: "on", - Val: "Enable output formatting", - }, - }, - IsSet: func() bool { - return a.Cfg.Format != core.FormatUnknown - }, - Fn: func(value string) error { - return a.Cfg.ParseFormat(value) - }, - }, - { - Short: "", - Long: "from-curl", - Args: "COMMAND", - Description: "Execute a curl command using fetch", - Default: "", - IsSet: func() bool { - return a.FromCurl != "" - }, - Fn: func(value string) error { - a.FromCurl = value - return nil - }, - }, - { - Short: "", - Long: "grpc", - Args: "", - Description: "Enable gRPC mode", - Default: "", - IsSet: func() bool { - return a.GRPC - }, - Fn: func(value string) error { - a.GRPC = true - return nil - }, - }, - { - Short: "H", - Long: "header", - Args: "NAME:VALUE", - Description: "Set headers for the request", - Default: "", - IsSet: func() bool { - return len(a.Cfg.Headers) > 0 - }, - Fn: func(value string) error { - return a.Cfg.ParseHeader(value) - }, - }, - { - Short: "h", - Long: "help", - Args: "", - Description: "Print help", - Default: "", - IsSet: func() bool { - return a.Help - }, - Fn: func(value string) error { - a.Help = true - return nil - }, - }, - { - Short: "", - Long: "http", - Args: "VERSION", - Description: "HTTP version to use", - Default: "", - Values: []core.KeyVal[string]{ - { - Key: "1", - Val: "HTTP/1.1", - }, - { - Key: "2", - Val: "HTTP/2.0", - }, - { - Key: "3", - Val: "HTTP/3.0", - }, - }, - IsSet: func() bool { - return a.Cfg.HTTP != core.HTTPDefault - }, - Fn: func(value string) error { - return a.Cfg.ParseHTTP(value) - }, - }, - { - Short: "", - Long: "ignore-status", - Args: "", - Description: "Exit code unaffected by HTTP status", - Default: "", - IsSet: func() bool { - return a.Cfg.IgnoreStatus != nil - }, - Fn: func(value string) error { - v := true - a.Cfg.IgnoreStatus = &v - return nil - }, - }, - { - Short: "", - Long: "image", - Args: "OPTION", - Description: "Image rendering", - Default: "", - Values: []core.KeyVal[string]{ - { - Key: "auto", - Val: "Automatically decide image display", - }, - { - Key: "native", - Val: "Only use builtin decoders", - }, - { - Key: "off", - Val: "Disable image display", - }, - }, - IsSet: func() bool { - return a.Cfg.Image != core.ImageUnknown - }, - Fn: func(value string) error { - return a.Cfg.ParseImageSetting(value) - }, - }, - { - Short: "", - Long: "insecure", - Args: "", - Description: "Accept invalid TLS certs (!)", - Default: "", - IsSet: func() bool { - return a.Cfg.Insecure != nil - }, - Fn: func(value string) error { - v := true - a.Cfg.Insecure = &v - return nil - }, - }, - { - Short: "", - Long: "inspect-tls", - Args: "", - Description: "Inspect the TLS certificate chain", - Default: "", - IsSet: func() bool { - return a.InspectTLS - }, - Fn: func(value string) error { - a.InspectTLS = true - return nil - }, + IsSet: func() bool { return len(a.Form) > 0 }, + Fn: a.parseFormFlag, }, + + cfgFlag("format", "", "OPTION", "Enable/disable formatting", + func() bool { return a.Cfg.Format != core.FormatUnknown }, a.Cfg.ParseFormat). + WithValues([]core.KeyVal[string]{ + {Key: "auto", Val: "Automatically determine whether to format"}, + {Key: "off", Val: "Disable output formatting"}, + {Key: "on", Val: "Enable output formatting"}, + }), + + stringFlag(&a.FromCurl, "from-curl", "", "COMMAND", "Execute a curl command using fetch"), + boolFlag(&a.GRPC, "grpc", "", "Enable gRPC mode"), + + cfgFlag("header", "H", "NAME:VALUE", "Set headers for the request", + func() bool { return len(a.Cfg.Headers) > 0 }, a.Cfg.ParseHeader), + + boolFlag(&a.Help, "help", "h", "Print help"), + + cfgFlag("http", "", "VERSION", "HTTP version to use", + func() bool { return a.Cfg.HTTP != core.HTTPDefault }, a.Cfg.ParseHTTP). + WithValues([]core.KeyVal[string]{ + {Key: "1", Val: "HTTP/1.1"}, + {Key: "2", Val: "HTTP/2.0"}, + {Key: "3", Val: "HTTP/3.0"}, + }), + + ptrBoolFlag(&a.Cfg.IgnoreStatus, "ignore-status", "", "Exit code unaffected by HTTP status"), + + cfgFlag("image", "", "OPTION", "Image rendering", + func() bool { return a.Cfg.Image != core.ImageUnknown }, a.Cfg.ParseImageSetting). + WithValues([]core.KeyVal[string]{ + {Key: "auto", Val: "Automatically decide image display"}, + {Key: "native", Val: "Only use builtin decoders"}, + {Key: "off", Val: "Disable image display"}, + }), + + ptrBoolFlag(&a.Cfg.Insecure, "insecure", "", "Accept invalid TLS certs (!)"), + boolFlag(&a.InspectTLS, "inspect-tls", "", "Inspect the TLS certificate chain"), + + // Custom: JSON body { Short: "j", Long: "json", Args: "[@]VALUE", Description: "Send a JSON request body", - Default: "", - IsSet: func() bool { - return a.jsonSet - }, - Fn: func(value string) error { - r, _, err := RequestBody(value) - if err != nil { - return err - } - a.Data = r - a.ContentType = "application/json" - a.jsonSet = true - return nil - }, + IsSet: func() bool { return a.jsonSet }, + Fn: a.parseJSONFlag, }, + + // Custom: file check + config parse { - Short: "", Long: "key", Args: "PATH", Description: "Client private key for mTLS", - Default: "", - IsSet: func() bool { - return a.Cfg.KeyPath != "" - }, - Fn: func(value string) error { - if err := checkFileExists(value); err != nil { - return err - } - return a.Cfg.ParseKey(value) - }, - }, - { - Short: "m", - Long: "method", - Aliases: []string{"X"}, - Args: "METHOD", - Description: "HTTP method to use", - Default: "GET", - IsSet: func() bool { - return a.Method != "" - }, - Fn: func(value string) error { - a.Method = value - return nil - }, + IsSet: func() bool { return a.Cfg.KeyPath != "" }, + Fn: a.parseKeyFlag, }, + + stringFlag(&a.Method, "method", "m", "METHOD", "HTTP method to use"). + WithAliases("X"). + WithDefault("GET"), + + // Custom: multipart with file validation { Short: "F", Long: "multipart", Args: "NAME=[@]VALUE", Description: "Send a multipart form body", - Default: "", - IsSet: func() bool { - return len(a.Multipart) > 0 - }, - Fn: func(value string) error { - key, val, _ := cut(value, "=") - if strings.HasPrefix(val, "@") { - path := val[1:] - - // Expand '~' to the home directory. - if len(path) >= 2 && path[0] == '~' && path[1] == os.PathSeparator { - home, err := os.UserHomeDir() - if err != nil { - return err - } - path = home + path[1:] - val = "@" + path - } - - // Ensure the file exists. - stats, err := os.Stat(path) - if err != nil { - if os.IsNotExist(err) { - return fmt.Errorf("file does not exist: '%s'", path) - } - return err - } - if stats.IsDir() { - return fmt.Errorf("file is a directory: '%s'", path) - } - } - a.Multipart = append(a.Multipart, core.KeyVal[string]{Key: key, Val: val}) - return nil - }, - }, - { - Short: "", - Long: "no-encode", - Args: "", - Description: "Avoid requesting gzip/zstd encoding", - Default: "", - IsSet: func() bool { - return a.Cfg.NoEncode != nil - }, - Fn: func(value string) error { - v := true - a.Cfg.NoEncode = &v - return nil - }, - }, - { - Short: "", - Long: "no-pager", - Args: "", - Description: "Avoid using a pager for the output", - Default: "", - IsSet: func() bool { - return a.Cfg.NoPager != nil - }, - Fn: func(value string) error { - v := true - a.Cfg.NoPager = &v - return nil - }, - }, - { - Short: "o", - Long: "output", - Args: "PATH", - Description: "Write the response body to a file", - Default: "", - IsSet: func() bool { - return a.Output != "" - }, - Fn: func(value string) error { - a.Output = value - return nil - }, + IsSet: func() bool { return len(a.Multipart) > 0 }, + Fn: a.parseMultipartFlag, }, + + ptrBoolFlag(&a.Cfg.NoEncode, "no-encode", "", "Avoid requesting gzip/zstd encoding"), + ptrBoolFlag(&a.Cfg.NoPager, "no-pager", "", "Avoid using a pager for the output"), + stringFlag(&a.Output, "output", "o", "PATH", "Write the response body to a file"), + + // Custom: proto flags with file validation { - Short: "", Long: "proto-desc", Args: "PATH", Description: "Pre-compiled descriptor set file", - Default: "", - IsSet: func() bool { - return a.ProtoDesc != "" - }, - Fn: func(value string) error { - a.ProtoDesc = value - return checkFileExists(value) - }, + IsSet: func() bool { return a.ProtoDesc != "" }, + Fn: a.parseProtoDescFlag, }, { - Short: "", Long: "proto-file", Args: "PATH", Description: "Compile .proto file(s) via protoc", - Default: "", - IsSet: func() bool { - return len(a.ProtoFiles) > 0 - }, - Fn: func(value string) error { - // Support comma-separated paths. - for p := range strings.SplitSeq(value, ",") { - p = strings.TrimSpace(p) - if p == "" { - continue - } - err := checkFileExists(p) - if err != nil { - return err - } - a.ProtoFiles = append(a.ProtoFiles, p) - } - return nil - }, + IsSet: func() bool { return len(a.ProtoFiles) > 0 }, + Fn: a.parseProtoFileFlag, }, { - Short: "", Long: "proto-import", Args: "PATH", Description: "Import path for proto compilation", - Default: "", - IsSet: func() bool { - return len(a.ProtoImports) > 0 - }, - Fn: func(value string) error { - a.ProtoImports = append(a.ProtoImports, value) - return checkFileExists(value) - }, - }, - { - Short: "", - Long: "proxy", - Args: "PROXY", - Description: "Configure a proxy", - Default: "", - IsSet: func() bool { - return a.Cfg.Proxy != nil - }, - Fn: func(value string) error { - return a.Cfg.ParseProxy(value) - }, - }, - { - Short: "q", - Long: "query", - Args: "KEY=VALUE", - Description: "Append query parameters to the url", - Default: "", - IsSet: func() bool { - return len(a.Cfg.QueryParams) > 0 - }, - Fn: func(value string) error { - return a.Cfg.ParseQuery(value) - }, + IsSet: func() bool { return len(a.ProtoImports) > 0 }, + Fn: a.parseProtoImportFlag, }, + + cfgFlag("proxy", "", "PROXY", "Configure a proxy", + func() bool { return a.Cfg.Proxy != nil }, a.Cfg.ParseProxy), + + cfgFlag("query", "q", "KEY=VALUE", "Append query parameters to the url", + func() bool { return len(a.Cfg.QueryParams) > 0 }, a.Cfg.ParseQuery), + + // Custom: range parsing with validation { Short: "r", Long: "range", Args: "RANGE", Description: "Request a specific byte range", - Default: "", - IsSet: func() bool { - return len(a.Range) > 0 - }, - Fn: func(value string) error { - value = strings.TrimSpace(value) - start, end, ok := strings.Cut(value, "-") - start = strings.TrimSpace(start) - end = strings.TrimSpace(end) - if !ok || (start == "" && end == "") { - const usage = "invalid byte range" - return core.NewValueError("range", value, usage, false) - } - if !isValidRangeValue(start) { - usage := fmt.Sprintf("invalid range start '%s'", start) - return core.NewValueError("range", value, usage, false) - } - if !isValidRangeValue(end) { - usage := fmt.Sprintf("invalid range end '%s'", end) - return core.NewValueError("range", value, usage, false) - } - - a.Range = append(a.Range, start+"-"+end) - return nil - }, - }, - { - Short: "", - Long: "redirects", - Args: "NUM", - Description: "Maximum number of redirects", - Default: "", - IsSet: func() bool { - return a.Cfg.Redirects != nil - }, - Fn: func(value string) error { - return a.Cfg.ParseRedirects(value) - }, - }, - { - Short: "J", - Long: "remote-header-name", - Args: "", - Description: "Use content-disposition header filename", - Default: "", - IsSet: func() bool { - return a.RemoteHeaderName - }, - Fn: func(value string) error { - a.RemoteHeaderName = true - return nil - }, - }, - { - Short: "O", - Long: "remote-name", - Aliases: []string{"output-current-dir"}, - Args: "", - Description: "Use URL path component as output filename", - Default: "", - IsSet: func() bool { - return a.RemoteName - }, - Fn: func(value string) error { - a.RemoteName = true - return nil - }, - }, - { - Short: "", - Long: "retry", - Args: "NUM", - Description: "Maximum number of retries", - Default: "0", - IsSet: func() bool { - return a.Cfg.Retry != nil - }, - Fn: func(value string) error { - return a.Cfg.ParseRetry(value) - }, - }, - { - Short: "", - Long: "retry-delay", - Args: "SECONDS", - Description: "Initial delay between retries", - Default: "1", - IsSet: func() bool { - return a.Cfg.RetryDelay != nil - }, - Fn: func(value string) error { - return a.Cfg.ParseRetryDelay(value) - }, - }, - { - Short: "S", - Long: "session", - Args: "NAME", - Description: "Use a named session for cookies", - Default: "", - IsSet: func() bool { - return a.Cfg.Session != nil - }, - Fn: func(value string) error { - return a.Cfg.ParseSession(value) - }, - }, - { - Short: "s", - Long: "silent", - Args: "", - Description: "Print only errors to stderr", - Default: "", - IsSet: func() bool { - return a.Cfg.Silent != nil - }, - Fn: func(value string) error { - v := true - a.Cfg.Silent = &v - return nil - }, - }, - { - Short: "t", - Long: "timeout", - Args: "SECONDS", - Description: "Timeout applied to the request", - Default: "", - IsSet: func() bool { - return a.Cfg.Timeout != nil - }, - Fn: func(value string) error { - return a.Cfg.ParseTimeout(value) - }, - }, - { - Short: "T", - Long: "timing", - Args: "", - Description: "Display a timing waterfall chart", - Default: "", - IsSet: func() bool { - return a.Cfg.Timing != nil - }, - Fn: func(value string) error { - v := true - a.Cfg.Timing = &v - return nil - }, - }, - { - Short: "", - Long: "tls", - Args: "VERSION", - Description: "Minimum TLS version", - Default: "", - Values: []core.KeyVal[string]{ - { - Key: "1.0", - Val: "TLS v1.0", - }, - { - Key: "1.1", - Val: "TLS v1.1", - }, - { - Key: "1.2", - Val: "TLS v1.2", - }, - { - Key: "1.3", - Val: "TLS v1.3", - }, - }, - IsSet: func() bool { - return a.Cfg.TLS != nil - }, - Fn: func(value string) error { - return a.Cfg.ParseTLS(value) - }, - }, - { - Short: "", - Long: "unix", - Args: "PATH", - Description: "Make the request over a unix socket", - Default: "", - OS: unixOS, - IsSet: func() bool { - return a.UnixSocket != "" - }, - Fn: func(value string) error { - a.UnixSocket = value - return nil - }, - }, - { - Short: "", - Long: "update", - Args: "", - IsHidden: core.NoSelfUpdate, - Description: "Update the fetch binary in place", - Default: "", - IsSet: func() bool { - return a.Update - }, - Fn: func(value string) error { - a.Update = true - return nil - }, + IsSet: func() bool { return len(a.Range) > 0 }, + Fn: a.parseRangeFlag, }, + + cfgFlag("redirects", "", "NUM", "Maximum number of redirects", + func() bool { return a.Cfg.Redirects != nil }, a.Cfg.ParseRedirects), + + boolFlag(&a.RemoteHeaderName, "remote-header-name", "J", "Use content-disposition header filename"), + boolFlag(&a.RemoteName, "remote-name", "O", "Use URL path component as output filename"). + WithAliases("output-current-dir"), + + cfgFlag("retry", "", "NUM", "Maximum number of retries", + func() bool { return a.Cfg.Retry != nil }, a.Cfg.ParseRetry). + WithDefault("0"), + cfgFlag("retry-delay", "", "SECONDS", "Initial delay between retries", + func() bool { return a.Cfg.RetryDelay != nil }, a.Cfg.ParseRetryDelay). + WithDefault("1"), + + cfgFlag("session", "S", "NAME", "Use a named session for cookies", + func() bool { return a.Cfg.Session != nil }, a.Cfg.ParseSession), + + ptrBoolFlag(&a.Cfg.Silent, "silent", "s", "Print only errors to stderr"), + + cfgFlag("timeout", "t", "SECONDS", "Timeout applied to the request", + func() bool { return a.Cfg.Timeout != nil }, a.Cfg.ParseTimeout), + + ptrBoolFlag(&a.Cfg.Timing, "timing", "T", "Display a timing waterfall chart"), + + cfgFlag("tls", "", "VERSION", "Minimum TLS version", + func() bool { return a.Cfg.TLS != nil }, a.Cfg.ParseTLS). + WithValues([]core.KeyVal[string]{ + {Key: "1.0", Val: "TLS v1.0"}, + {Key: "1.1", Val: "TLS v1.1"}, + {Key: "1.2", Val: "TLS v1.2"}, + {Key: "1.3", Val: "TLS v1.3"}, + }), + + stringFlag(&a.UnixSocket, "unix", "", "PATH", "Make the request over a unix socket"). + WithOS(unixOS), + + boolFlag(&a.Update, "update", "", "Update the fetch binary in place"). + WithHidden(core.NoSelfUpdate), + + // Custom: verbose increments verbosity { Short: "v", Long: "verbose", - Args: "", Description: "Verbosity of the output", - Default: "", - IsSet: func() bool { - return a.Cfg.Verbosity != nil - }, - Fn: func(value string) error { - if a.Cfg.Verbosity == nil { - a.Cfg.Verbosity = core.PointerTo(1) - } else { - (*a.Cfg.Verbosity)++ - } - return nil - }, - }, - { - Short: "V", - Long: "version", - Args: "", - Description: "Print version", - Default: "", - IsSet: func() bool { - return a.Version - }, - Fn: func(value string) error { - a.Version = true - return nil - }, + IsSet: func() bool { return a.Cfg.Verbosity != nil }, + Fn: a.parseVerboseFlag, }, + + boolFlag(&a.Version, "version", "V", "Print version"), + + // Custom: XML body { Short: "x", Long: "xml", Args: "[@]VALUE", Description: "Send an XML request body", - Default: "", - IsSet: func() bool { - return a.xmlSet - }, - Fn: func(value string) error { - r, _, err := RequestBody(value) - if err != nil { - return err - } - a.Data = r - a.ContentType = "application/xml" - a.xmlSet = true - return nil - }, + IsSet: func() bool { return a.xmlSet }, + Fn: a.parseXMLFlag, }, }, } } +func (a *App) parseAWSSigv4Flag(value string) error { + region, service, ok := core.CutTrimmed(value, "/") + if !ok { + const usage = "format must be " + return core.NewValueError("aws-sigv4", value, usage, false) + } + cfg, err := buildAWSConfig(region, service) + if err != nil { + return err + } + a.AWSSigv4 = cfg + return nil +} + +func (a *App) parseBasicFlag(value string) error { + user, pass, ok := core.CutTrimmed(value, ":") + if !ok { + const usage = "format must be " + return core.NewValueError("basic", value, usage, false) + } + a.Basic = &core.KeyVal[string]{Key: user, Val: pass} + return nil +} + +func (a *App) parseCertFlag(value string) error { + if err := checkFileExists(value); err != nil { + return err + } + return a.Cfg.ParseCert(value) +} + +func (a *App) parseDataFlag(value string) error { + r, path, err := RequestBody(value) + if err != nil { + return err + } + a.Data, a.ContentType, err = core.DetectContentType(r, path) + if err != nil { + return err + } + a.dataSet = true + return nil +} + +func (a *App) parseFormFlag(value string) error { + key, val, _ := core.CutTrimmed(value, "=") + a.Form = append(a.Form, core.KeyVal[string]{Key: key, Val: val}) + return nil +} + +func (a *App) parseJSONFlag(value string) error { + r, _, err := RequestBody(value) + if err != nil { + return err + } + a.Data = r + a.ContentType = "application/json" + a.jsonSet = true + return nil +} + +func (a *App) parseKeyFlag(value string) error { + if err := checkFileExists(value); err != nil { + return err + } + return a.Cfg.ParseKey(value) +} + +func (a *App) parseMultipartFlag(value string) error { + key, val, _ := core.CutTrimmed(value, "=") + if strings.HasPrefix(val, "@") { + path := val[1:] + + // Expand '~' to the home directory. + if len(path) >= 2 && path[0] == '~' && path[1] == os.PathSeparator { + home, err := os.UserHomeDir() + if err != nil { + return err + } + path = home + path[1:] + val = "@" + path + } + + // Ensure the file exists. + stats, err := os.Stat(path) + if err != nil { + if os.IsNotExist(err) { + return fmt.Errorf("file does not exist: '%s'", path) + } + return err + } + if stats.IsDir() { + return fmt.Errorf("file is a directory: '%s'", path) + } + } + a.Multipart = append(a.Multipart, core.KeyVal[string]{Key: key, Val: val}) + return nil +} + +func (a *App) parseProtoDescFlag(value string) error { + a.ProtoDesc = value + return checkFileExists(value) +} + +func (a *App) parseProtoFileFlag(value string) error { + // Support comma-separated paths. + for p := range strings.SplitSeq(value, ",") { + p = strings.TrimSpace(p) + if p == "" { + continue + } + err := checkFileExists(p) + if err != nil { + return err + } + a.ProtoFiles = append(a.ProtoFiles, p) + } + return nil +} + +func (a *App) parseProtoImportFlag(value string) error { + a.ProtoImports = append(a.ProtoImports, value) + return checkFileExists(value) +} + +func (a *App) parseRangeFlag(value string) error { + value = strings.TrimSpace(value) + start, end, ok := strings.Cut(value, "-") + start = strings.TrimSpace(start) + end = strings.TrimSpace(end) + if !ok || (start == "" && end == "") { + const usage = "invalid byte range" + return core.NewValueError("range", value, usage, false) + } + if !isValidRangeValue(start) { + usage := fmt.Sprintf("invalid range start '%s'", start) + return core.NewValueError("range", value, usage, false) + } + if !isValidRangeValue(end) { + usage := fmt.Sprintf("invalid range end '%s'", end) + return core.NewValueError("range", value, usage, false) + } + a.Range = append(a.Range, start+"-"+end) + return nil +} + +func (a *App) parseVerboseFlag(string) error { + if a.Cfg.Verbosity == nil { + a.Cfg.Verbosity = core.PointerTo(1) + } else { + (*a.Cfg.Verbosity)++ + } + return nil +} + +func (a *App) parseXMLFlag(value string) error { + r, _, err := RequestBody(value) + if err != nil { + return err + } + a.Data = r + a.ContentType = "application/xml" + a.xmlSet = true + return nil +} + +// buildAWSConfig creates an AWS configuration from region and service, +// reading credentials from environment variables. +func buildAWSConfig(region, service string) (*aws.Config, error) { + accessKey := os.Getenv("AWS_ACCESS_KEY_ID") + if accessKey == "" { + return nil, missingEnvVarErr("AWS_ACCESS_KEY_ID", "aws-sigv4") + } + secretKey := os.Getenv("AWS_SECRET_ACCESS_KEY") + if secretKey == "" { + return nil, missingEnvVarErr("AWS_SECRET_ACCESS_KEY", "aws-sigv4") + } + return &aws.Config{ + Region: region, + Service: service, + AccessKey: accessKey, + SecretKey: secretKey, + }, nil +} + +// parseURL normalizes a raw URL string: adds "//" when the scheme is +// omitted, rewrites ws/wss schemes to http/https, and validates the scheme. +// It returns the parsed URL, whether it was a WebSocket URL, and any error. +func parseURL(rawURL string) (*url.URL, bool, error) { + if rawURL == "" { + return nil, false, fmt.Errorf("empty URL provided") + } + + // For URLs that have the scheme omitted, add two + // slashes so it can be parsed correctly. + if !strings.Contains(rawURL, "://") && rawURL[0] != '/' { + rawURL = "//" + rawURL + } + + u, err := url.Parse(rawURL) + if err != nil { + return nil, false, fmt.Errorf("invalid url: %w", err) + } + + // Lowercase the scheme, and validate. + var isWS bool + u.Scheme = strings.ToLower(u.Scheme) + switch u.Scheme { + case "", "http", "https": + case "ws": + u.Scheme = "http" + isWS = true + case "wss": + u.Scheme = "https" + isWS = true + default: + return nil, false, fmt.Errorf("unsupported url scheme: %s", u.Scheme) + } + return u, isWS, nil +} + func RequestBody(value string) (io.Reader, string, error) { switch { case len(value) == 0 || value[0] != '@': @@ -1162,13 +656,6 @@ func checkFileExists(value string) error { return err } -func cut(s, sep string) (string, string, bool) { - key, val, ok := strings.Cut(s, sep) - key = strings.TrimSpace(key) - val = strings.TrimSpace(val) - return key, val, ok -} - func isValidRangeValue(value string) bool { if value == "" { return true @@ -1176,48 +663,3 @@ func isValidRangeValue(value string) bool { _, err := strconv.Atoi(value) return err == nil } - -type MissingEnvVarError struct { - EnvVar string - Flag string -} - -type fileIsDirError string - -func (err fileIsDirError) Error() string { - return fmt.Sprintf("file '%s' is a directory", string(err)) -} - -func (err fileIsDirError) PrintTo(p *core.Printer) { - p.WriteString("file '") - p.Set(core.Dim) - p.WriteString(string(err)) - p.Reset() - p.WriteString("' is a directory") -} - -func missingEnvVarErr(envVar, flag string) *MissingEnvVarError { - return &MissingEnvVarError{ - EnvVar: envVar, - Flag: flag, - } -} - -func (err *MissingEnvVarError) Error() string { - return fmt.Sprintf("missing environment variable '%s' required for option '--%s'", err.EnvVar, err.Flag) -} - -func (err *MissingEnvVarError) PrintTo(p *core.Printer) { - p.WriteString("missing environment variable '") - p.Set(core.Yellow) - p.WriteString(err.EnvVar) - p.Reset() - - p.WriteString("' required for option '") - p.Set(core.Bold) - p.WriteString("--") - p.WriteString(err.Flag) - p.Reset() - - p.WriteString("'") -} diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 762e8b7..11bc1e2 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -11,7 +11,6 @@ import ( "strings" "time" - "github.com/ryanfowler/fetch/internal/aws" "github.com/ryanfowler/fetch/internal/core" "github.com/ryanfowler/fetch/internal/curl" ) @@ -25,6 +24,14 @@ type CLI struct { Flags []Flag ExclusiveFlags [][]string RequiredFlags []core.KeyVal[[]string] + + // SchemeExclusiveFlags maps URL schemes (e.g. "ws", "wss") to flags + // that cannot be used with that scheme. + SchemeExclusiveFlags map[string][]string + + // FromCurlExclusiveFlags lists flags that cannot be used alongside + // --from-curl. + FromCurlExclusiveFlags []string } type Arguments struct { @@ -47,7 +54,9 @@ type Flag struct { Fn func(value string) error } -func parse(cli *CLI, args []string) error { +// parseWithFlags parses the CLI arguments and returns the long flag map for +// use in post-parse validation. +func parseWithFlags(cli *CLI, args []string) (map[string]Flag, error) { short := make(map[string]Flag) long := make(map[string]Flag) for _, flag := range cli.Flags { @@ -75,13 +84,6 @@ func parse(cli *CLI, args []string) error { } } - exclusives := make(map[string][][]string) - for _, fs := range cli.ExclusiveFlags { - for _, f := range fs { - exclusives[f] = append(exclusives[f], fs) - } - } - var err error for len(args) > 0 { arg := args[0] @@ -91,7 +93,7 @@ func parse(cli *CLI, args []string) error { if len(arg) <= 1 || arg[0] != '-' { err = cli.ArgFn(arg) if err != nil { - return err + return nil, err } continue } @@ -100,7 +102,7 @@ func parse(cli *CLI, args []string) error { if arg[1] != '-' { args, err = parseShortFlag(arg, args, short) if err != nil { - return err + return nil, err } continue } @@ -109,7 +111,7 @@ func parse(cli *CLI, args []string) error { if len(arg) > 2 { args, err = parseLongFlag(arg, args, long) if err != nil { - return err + return nil, err } continue } @@ -117,12 +119,12 @@ func parse(cli *CLI, args []string) error { // "--" means consider everything else arguments. err = cli.ArgFn("--") if err != nil { - return err + return nil, err } for _, arg := range args { err = cli.ArgFn(arg) if err != nil { - return err + return nil, err } } break @@ -132,7 +134,7 @@ func parse(cli *CLI, args []string) error { for _, exc := range cli.ExclusiveFlags { err = validateExclusives(exc, long) if err != nil { - return err + return nil, err } } @@ -140,11 +142,11 @@ func parse(cli *CLI, args []string) error { for _, req := range cli.RequiredFlags { err = validateRequired(req, long) if err != nil { - return err + return nil, err } } - return nil + return long, nil } func parseShortFlag(arg string, args []string, short map[string]Flag) ([]string, error) { @@ -261,13 +263,13 @@ func Parse(args []string) (*App, error) { var app App cli := app.CLI() - err := parse(cli, args) + long, err := parseWithFlags(cli, args) if err != nil { return &app, err } if app.FromCurl != "" { - if err := app.validateFromCurlExclusives(); err != nil { + if err := validateFromCurlExclusives(&app, cli, long); err != nil { return &app, err } result, err := curl.Parse(app.FromCurl) @@ -279,43 +281,31 @@ func Parse(args []string) (*App, error) { } } - if err := app.validateWSExclusives(); err != nil { + if err := validateSchemeExclusives(&app, cli, long); err != nil { return &app, err } return &app, nil } -// validateWSExclusives checks that ws:// / wss:// scheme is not combined -// with incompatible flags. -func (a *App) validateWSExclusives() error { - if !a.WS { +// validateSchemeExclusives checks that scheme-specific exclusive flags +// (e.g. ws:// / wss:// flags) are not combined with incompatible flags. +func validateSchemeExclusives(app *App, cli *CLI, long map[string]Flag) error { + if !app.WS { return nil } - type flagCheck struct { - name string - isSet bool - } - conflicts := []flagCheck{ - {"discard", a.Discard}, - {"grpc", a.GRPC}, - {"form", len(a.Form) > 0}, - {"multipart", len(a.Multipart) > 0}, - {"xml", a.xmlSet}, - {"edit", a.Edit}, - } - // The URL scheme was rewritten from ws->http / wss->https during // parsing, so reverse the mapping for the error message. scheme := "ws" - if a.URL != nil && a.URL.Scheme == "https" { + if app.URL != nil && app.URL.Scheme == "https" { scheme = "wss" } - for _, c := range conflicts { - if c.isSet { - return schemeExclusiveError{scheme: scheme, flag: c.name} + exclusives := cli.SchemeExclusiveFlags[scheme] + for _, name := range exclusives { + if flag, ok := long[name]; ok && flag.IsSet() { + return schemeExclusiveError{scheme: scheme, flag: name} } } return nil @@ -462,51 +452,14 @@ func assertFlagNotExists(m map[string]Flag, value string) { // validateFromCurlExclusives checks that no request-specifying flags are used // alongside --from-curl. -func (a *App) validateFromCurlExclusives() error { - type flagCheck struct { - name string - isSet bool - } - conflicts := []flagCheck{ - {"method", a.Method != ""}, - {"header", len(a.Cfg.Headers) > 0}, - {"data", a.dataSet}, - {"json", a.jsonSet}, - {"xml", a.xmlSet}, - {"form", len(a.Form) > 0}, - {"multipart", len(a.Multipart) > 0}, - {"basic", a.Basic != nil}, - {"bearer", a.Bearer != ""}, - {"aws-sigv4", a.AWSSigv4 != nil}, - {"output", a.Output != ""}, - {"remote-name", a.RemoteName}, - {"remote-header-name", a.RemoteHeaderName}, - {"range", len(a.Range) > 0}, - {"unix", a.UnixSocket != ""}, - {"timeout", a.Cfg.Timeout != nil}, - {"connect-timeout", a.Cfg.ConnectTimeout != nil}, - {"redirects", a.Cfg.Redirects != nil}, - {"proxy", a.Cfg.Proxy != nil}, - {"insecure", a.Cfg.Insecure != nil}, - {"tls", a.Cfg.TLS != nil}, - {"http", a.Cfg.HTTP != core.HTTPDefault}, - {"cert", a.Cfg.CertPath != ""}, - {"key", a.Cfg.KeyPath != ""}, - {"ca-cert", len(a.Cfg.CACerts) > 0}, - {"dns-server", a.Cfg.DNSServer != nil}, - {"retry", a.Cfg.Retry != nil}, - {"retry-delay", a.Cfg.RetryDelay != nil}, - {"grpc", a.GRPC}, - {"query", len(a.Cfg.QueryParams) > 0}, - } - - if a.URL != nil { +func validateFromCurlExclusives(app *App, cli *CLI, long map[string]Flag) error { + if app.URL != nil { return fromCurlExclusiveError{flag: "URL", positional: true} } - for _, c := range conflicts { - if c.isSet { - return fromCurlExclusiveError{flag: c.name} + for _, name := range cli.FromCurlExclusiveFlags { + if flag, ok := long[name]; ok && flag.IsSet() { + return fromCurlExclusiveError{flag: name} } } return nil @@ -514,30 +467,15 @@ func (a *App) validateFromCurlExclusives() error { // applyFromCurl maps a parsed curl Result onto the App fields. func (a *App) applyFromCurl(r *curl.Result) error { - // Parse the URL using the same normalization logic as ArgFn. - rawURL := r.URL - if rawURL == "" { + // Parse the URL using the shared normalization logic. + if r.URL == "" { return fmt.Errorf("no URL provided") } - if !strings.Contains(rawURL, "://") && rawURL[0] != '/' { - rawURL = "//" + rawURL - } - u, err := url.Parse(rawURL) + u, isWS, err := parseURL(r.URL) if err != nil { - return fmt.Errorf("invalid url: %w", err) - } - u.Scheme = strings.ToLower(u.Scheme) - switch u.Scheme { - case "", "http", "https": - case "ws": - u.Scheme = "http" - a.WS = true - case "wss": - u.Scheme = "https" - a.WS = true - default: - return fmt.Errorf("unsupported url scheme: %s", u.Scheme) + return err } + a.WS = a.WS || isWS // Apply --proto restrictions. if r.AllowedProto != "" { @@ -667,20 +605,11 @@ func (a *App) applyFromCurl(r *curl.Result) error { if err != nil { return err } - accessKey := os.Getenv("AWS_ACCESS_KEY_ID") - if accessKey == "" { - return missingEnvVarErr("AWS_ACCESS_KEY_ID", "aws-sigv4") - } - secretKey := os.Getenv("AWS_SECRET_ACCESS_KEY") - if secretKey == "" { - return missingEnvVarErr("AWS_SECRET_ACCESS_KEY", "aws-sigv4") - } - a.AWSSigv4 = &aws.Config{ - Region: region, - Service: service, - AccessKey: accessKey, - SecretKey: secretKey, + cfg, err := buildAWSConfig(region, service) + if err != nil { + return err } + a.AWSSigv4 = cfg } // Output. diff --git a/internal/cli/errors.go b/internal/cli/errors.go index 1205953..d43cd3d 100644 --- a/internal/cli/errors.go +++ b/internal/cli/errors.go @@ -130,6 +130,53 @@ func (err fromCurlExclusiveError) PrintTo(p *core.Printer) { } } +type fileIsDirError string + +func (err fileIsDirError) Error() string { + return fmt.Sprintf("file '%s' is a directory", string(err)) +} + +func (err fileIsDirError) PrintTo(p *core.Printer) { + p.WriteString("file '") + p.Set(core.Dim) + p.WriteString(string(err)) + p.Reset() + p.WriteString("' is a directory") +} + +// MissingEnvVarError is returned when a required environment variable is not +// set for a given flag. +type MissingEnvVarError struct { + EnvVar string + Flag string +} + +func missingEnvVarErr(envVar, flag string) *MissingEnvVarError { + return &MissingEnvVarError{ + EnvVar: envVar, + Flag: flag, + } +} + +func (err *MissingEnvVarError) Error() string { + return fmt.Sprintf("missing environment variable '%s' required for option '--%s'", err.EnvVar, err.Flag) +} + +func (err *MissingEnvVarError) PrintTo(p *core.Printer) { + p.WriteString("missing environment variable '") + p.Set(core.Yellow) + p.WriteString(err.EnvVar) + p.Reset() + + p.WriteString("' required for option '") + p.Set(core.Bold) + p.WriteString("--") + p.WriteString(err.Flag) + p.Reset() + + p.WriteString("'") +} + type requiredFlagError struct { flag string required []string diff --git a/internal/cli/flags.go b/internal/cli/flags.go new file mode 100644 index 0000000..5000286 --- /dev/null +++ b/internal/cli/flags.go @@ -0,0 +1,100 @@ +package cli + +import "github.com/ryanfowler/fetch/internal/core" + +// boolFlag creates a Flag that sets a bool to true when present. +func boolFlag(target *bool, long, short, desc string) Flag { + return Flag{ + Long: long, + Short: short, + Description: desc, + IsSet: func() bool { + return *target + }, + Fn: func(string) error { + *target = true + return nil + }, + } +} + +// ptrBoolFlag creates a Flag that sets a *bool pointer to &true when present. +func ptrBoolFlag(target **bool, long, short, desc string) Flag { + return Flag{ + Long: long, + Short: short, + Description: desc, + IsSet: func() bool { + return *target != nil + }, + Fn: func(string) error { + *target = core.PointerTo(true) + return nil + }, + } +} + +// stringFlag creates a Flag that stores a string value. +func stringFlag(target *string, long, short, args, desc string) Flag { + return Flag{ + Long: long, + Short: short, + Args: args, + Description: desc, + IsSet: func() bool { + return *target != "" + }, + Fn: func(value string) error { + *target = value + return nil + }, + } +} + +// cfgFlag creates a Flag that delegates to an isSet check and a parse function. +func cfgFlag(long, short, args, desc string, isSet func() bool, parse func(string) error) Flag { + return Flag{ + Long: long, + Short: short, + Args: args, + Description: desc, + IsSet: isSet, + Fn: parse, + } +} + +// WithAliases adds aliases to the Flag. +func (f Flag) WithAliases(aliases ...string) Flag { + f.Aliases = aliases + return f +} + +// WithValues sets the accepted values for the Flag. +func (f Flag) WithValues(values []core.KeyVal[string]) Flag { + f.Values = values + return f +} + +// WithHideValues hides the accepted values from help output. +func (f Flag) WithHideValues() Flag { + f.HideValues = true + return f +} + +// WithDefault sets the default value shown in help. +func (f Flag) WithDefault(def string) Flag { + f.Default = def + return f +} + +// WithHidden marks the flag as hidden from help output. +func (f Flag) WithHidden(hidden bool) Flag { + f.IsHidden = hidden + return f +} + +// WithOS restricts the flag to specific operating systems. +func (f Flag) WithOS(os []string) Flag { + f.OS = os + return f +} diff --git a/internal/config/config.go b/internal/config/config.go index 44f9561..dd52ca6 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -359,7 +359,7 @@ func (c *Config) ParseFormat(value string) error { } func (c *Config) ParseHeader(value string) error { - key, val, _ := cut(value, ":") + key, val, _ := core.CutTrimmed(value, ":") c.Headers = append(c.Headers, core.KeyVal[string]{Key: key, Val: val}) return nil @@ -471,7 +471,7 @@ func (c *Config) ParseProxy(value string) error { } func (c *Config) ParseQuery(value string) error { - key, val, _ := cut(value, "=") + key, val, _ := core.CutTrimmed(value, "=") c.QueryParams = append(c.QueryParams, core.KeyVal[string]{Key: key, Val: val}) return nil } @@ -595,12 +595,6 @@ func (c *Config) ClientCert() (*tls.Certificate, error) { return nil, missingClientKeyError{certPath: c.CertPath, err: err} } -func cut(s, sep string) (string, string, bool) { - key, val, ok := strings.Cut(s, sep) - key, val = strings.TrimSpace(key), strings.TrimSpace(val) - return key, val, ok -} - type invalidOptionError string func (err invalidOptionError) Error() string { diff --git a/internal/core/core.go b/internal/core/core.go index 5a62b21..a486f6e 100644 --- a/internal/core/core.go +++ b/internal/core/core.go @@ -1,5 +1,7 @@ package core +import "strings" + // Color represents the options for enabling or disabling color output. type Color int @@ -75,3 +77,10 @@ type KeyVal[T any] struct { func PointerTo[T any](t T) *T { return &t } + +// CutTrimmed splits s around the first instance of sep, returning the +// trimmed text before and after sep. +func CutTrimmed(s, sep string) (string, string, bool) { + key, val, ok := strings.Cut(s, sep) + return strings.TrimSpace(key), strings.TrimSpace(val), ok +}