From 41348469ea6388b8b39858b00e52f7161e9e9ec7 Mon Sep 17 00:00:00 2001 From: fujiwara Date: Sat, 22 Feb 2025 03:56:00 +0900 Subject: [PATCH] add --http-headers option. --- http.go | 3 +-- manifest.go | 4 ++-- stretcher.go | 16 ++++++++++------ 3 files changed, 13 insertions(+), 10 deletions(-) diff --git a/http.go b/http.go index 44e96fc..01ef8d6 100644 --- a/http.go +++ b/http.go @@ -1,6 +1,5 @@ package stretcher type HTTPOptions struct { - Headers map[string]string `yaml:"headers"` - RetryMax int `yaml:"retry_max"` + Headers map[string]string `help:"HTTP request headers(key=value) for download src archives with HTTP or HTTPS"` } diff --git a/manifest.go b/manifest.go index f242775..0d98df5 100644 --- a/manifest.go +++ b/manifest.go @@ -155,13 +155,13 @@ func (m *Manifest) Deploy(ctx context.Context, conf *Config) error { func (m *Manifest) fetchSrc(ctx context.Context, conf *Config, tmp *os.File) error { begin := time.Now() - src, err := getURL(ctx, m.Src) + src, err := getURL(ctx, m.Src, conf) if err != nil { for i := 0; i < conf.Retry; i++ { log.Printf("Get src failed: %s", err) log.Printf("Try again. Waiting: %s", conf.RetryWait) time.Sleep(conf.RetryWait) - src, err = getURL(ctx, m.Src) + src, err = getURL(ctx, m.Src, conf) if err == nil { break } diff --git a/stretcher.go b/stretcher.go index aeb14f9..0503a7b 100644 --- a/stretcher.go +++ b/stretcher.go @@ -36,6 +36,7 @@ type Config struct { RetryWait time.Duration `help:"wait for retry download src archives"` RsyncVerbose string `help:"rsync verbose option (default: -v)" default:"-v"` Version kong.VersionFlag `short:"v" help:"Show version and exit."` + HTTP HTTPOptions `embed prefix:"http-"` maxbw uint64 initSleep time.Duration @@ -86,7 +87,7 @@ func Run(ctx context.Context, conf *Config) error { } log.Println("Loading manifest:", manifestURL) - m, err := getManifest(ctx, manifestURL) + m, err := getManifest(ctx, manifestURL, conf) if err != nil { return fmt.Errorf("load manifest failed: %w", err) } @@ -144,12 +145,15 @@ func getFile(_ context.Context, u *url.URL) (io.ReadCloser, error) { return os.Open(u.Path) } -func getHTTP(ctx context.Context, u *url.URL) (io.ReadCloser, error) { +func getHTTP(ctx context.Context, u *url.URL, opt *HTTPOptions) (io.ReadCloser, error) { req, err := http.NewRequestWithContext(ctx, "GET", u.String(), nil) if err != nil { return nil, err } req.Header.Add("User-Agent", "Stretcher/"+Version) + for k, v := range opt.Headers { + req.Header.Add(k, v) + } resp, err := http.DefaultClient.Do(req) if err != nil { @@ -158,7 +162,7 @@ func getHTTP(ctx context.Context, u *url.URL) (io.ReadCloser, error) { return resp.Body, nil } -func getURL(ctx context.Context, urlStr string) (io.ReadCloser, error) { +func getURL(ctx context.Context, urlStr string, conf *Config) (io.ReadCloser, error) { log.Println("Loading URL", urlStr) u, err := url.Parse(urlStr) if err != nil { @@ -170,7 +174,7 @@ func getURL(ctx context.Context, urlStr string) (io.ReadCloser, error) { case "gs": return getGS(ctx, u) case "http", "https": - return getHTTP(ctx, u) + return getHTTP(ctx, u, &conf.HTTP) case "file": return getFile(ctx, u) default: @@ -178,8 +182,8 @@ func getURL(ctx context.Context, urlStr string) (io.ReadCloser, error) { } } -func getManifest(ctx context.Context, manifestURL string) (*Manifest, error) { - rc, err := getURL(ctx, manifestURL) +func getManifest(ctx context.Context, manifestURL string, conf *Config) (*Manifest, error) { + rc, err := getURL(ctx, manifestURL, conf) if err != nil { return nil, err }