diff --git a/README.md b/README.md index 1c32006e..76e2db20 100644 --- a/README.md +++ b/README.md @@ -86,6 +86,8 @@ The URL format looks like `https://github.com//bazel/releases/download//` to the base URL instead of using the official release server. Bazelisk will read file [`~/.netrc`](https://everything.curl.dev/usingcurl/netrc) for credentials for Basic authentication. +If you want to use the releases stored on the local disk, set the URL as `file://` followed by the local disk path. On Windows, escape `\` in the path by `%5C`. + If for any reason none of this works, you can also override the URL format altogether by setting the environment variable `$BAZELISK_FORMAT_URL`. This variable takes a format-like string with placeholders and performs the following replacements to compute the download URL: - `%e`: Extension suffix, such as the empty string or `.exe`. diff --git a/httputil/httputil.go b/httputil/httputil.go index 0103d29b..2f7fc482 100644 --- a/httputil/httputil.go +++ b/httputil/httputil.go @@ -14,6 +14,7 @@ import ( "path/filepath" "regexp" "strconv" + "strings" "time" netrc "github.com/bgentry/go-netrc/netrc" @@ -78,7 +79,41 @@ func ReadRemoteFile(url string, auth string) ([]byte, http.Header, error) { return body, res.Header, nil } +type LocalFileError struct{ err error } + +func (e *LocalFileError) Error() string { return e.err.Error() } +func (e *LocalFileError) Unwrap() error { return e.err } + +// Handles file:// URLs by reading files from disk. +func readLocalFile(urlStr string) (*http.Response, error) { + urlStr = strings.TrimPrefix(urlStr, "file://") + path, err := url.PathUnescape(urlStr) + if err != nil { + return nil, &LocalFileError{err: fmt.Errorf("invalid file url %q: %w", urlStr, err)} + } + f, err := os.Open(path) + if err != nil { + return nil, &LocalFileError{err: fmt.Errorf("could not open %q: %w", path, err)} + } + var size int64 = -1 + if fi, statErr := f.Stat(); statErr == nil { + size = fi.Size() + } + return &http.Response{ + StatusCode: 200, + Status: "200 OK", + Header: make(http.Header), + Body: f, + ContentLength: size, + Request: &http.Request{Method: "GET", URL: &url.URL{Scheme: "file", Path: path}}, + }, nil +} + func get(url, auth string) (*http.Response, error) { + if strings.HasPrefix(url, "file://") { + return readLocalFile(url) + } + req, err := http.NewRequest("GET", url, nil) if err != nil { return nil, fmt.Errorf("could not create request: %v", err) @@ -127,6 +162,10 @@ func get(url, auth string) (*http.Response, error) { func shouldRetry(res *http.Response, err error) bool { // Retry if the client failed to speak HTTP. if err != nil { + var nre *LocalFileError + if errors.As(err, &nre) { + return false + } return true } // For HTTP: only retry on non-permanent/fatal errors. diff --git a/httputil/httputil_test.go b/httputil/httputil_test.go index 1a3f2281..d82e826d 100644 --- a/httputil/httputil_test.go +++ b/httputil/httputil_test.go @@ -3,6 +3,9 @@ package httputil import ( "errors" "net/http" + "net/url" + "os" + "path/filepath" "strconv" "strings" "testing" @@ -251,3 +254,42 @@ func TestNoRetryOnPermanentError(t *testing.T) { t.Fatalf("Expected no retries for permanent error, but got %d", clock.TimesSlept()) } } + +func TestReadLocalFile(t *testing.T) { + tmpDir := t.TempDir() + path := filepath.Join(tmpDir, "payload.txt") + want := "hello from disk" + if err := os.WriteFile(path, []byte(want), 0644); err != nil { + t.Fatalf("failed to write temp file: %v", err) + } + fileURL := (&url.URL{Scheme: "file", Path: path}).String() + + body, _, err := ReadRemoteFile(fileURL, "") + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if got := string(body); got != want { + t.Fatalf("expected body %q, but got %q", want, got) + } +} + +func TestReadLocalFileNotFound(t *testing.T) { + clock := newFakeClock() + RetryClock = clock + MaxRetries = 10 + MaxRequestDuration = time.Hour + + missingPath := filepath.Join(t.TempDir(), "does-not-exist.txt") + fileURL := (&url.URL{Scheme: "file", Path: missingPath}).String() + + _, _, err := ReadRemoteFile(fileURL, "") + if err == nil { + t.Fatal("expected error for missing file") + } + if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("expected os.ErrNotExist, got %v", err) + } + if clock.TimesSlept() != 0 { + t.Fatalf("expected no retries for file:// error, but slept %d times", clock.TimesSlept()) + } +}