diff --git a/internal/zip/fixtures/myproject.zip b/internal/zip/fixtures/myproject.zip new file mode 100644 index 00000000000..2fdf3f90c6e Binary files /dev/null and b/internal/zip/fixtures/myproject.zip differ diff --git a/pkg/cmd/run/download/zip.go b/internal/zip/zip.go similarity index 80% rename from pkg/cmd/run/download/zip.go rename to internal/zip/zip.go index bb504dde193..8cef5c30bfe 100644 --- a/pkg/cmd/run/download/zip.go +++ b/internal/zip/zip.go @@ -1,4 +1,4 @@ -package download +package zip import ( "archive/zip" @@ -17,7 +17,11 @@ const ( execMode os.FileMode = 0755 ) -func extractZip(zr *zip.Reader, destDir safepaths.Absolute) error { +// ExtractZip extracts the contents of a zip archive to destDir. +// Files that would result in path traversal are silently skipped. +// Files that would produce any other error cause the extraction to be aborted, +// and the error is returned. +func ExtractZip(zr *zip.Reader, destDir safepaths.Absolute) error { for _, zf := range zr.File { fpath, err := destDir.Join(zf.Name) if err != nil { diff --git a/pkg/cmd/run/download/zip_test.go b/internal/zip/zip_test.go similarity index 89% rename from pkg/cmd/run/download/zip_test.go rename to internal/zip/zip_test.go index 2584371b4f8..37e83661cf0 100644 --- a/pkg/cmd/run/download/zip_test.go +++ b/internal/zip/zip_test.go @@ -1,4 +1,4 @@ -package download +package zip import ( "archive/zip" @@ -19,7 +19,7 @@ func Test_extractZip(t *testing.T) { require.NoError(t, err) defer zipFile.Close() - err = extractZip(&zipFile.Reader, extractPath) + err = ExtractZip(&zipFile.Reader, extractPath) require.NoError(t, err) _, err = os.Stat(filepath.Join(extractPath.String(), "src", "main.go")) diff --git a/pkg/cmd/copilot/copilot.go b/pkg/cmd/copilot/copilot.go new file mode 100644 index 00000000000..1195accc488 --- /dev/null +++ b/pkg/cmd/copilot/copilot.go @@ -0,0 +1,455 @@ +package copilot + +import ( + "archive/tar" + "archive/zip" + "bufio" + "compress/gzip" + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "net/http" + "os" + "os/exec" + "path/filepath" + "runtime" + "slices" + "strings" + + "github.com/MakeNowJust/heredoc" + "github.com/cli/cli/v2/internal/config" + "github.com/cli/cli/v2/internal/prompter" + "github.com/cli/cli/v2/internal/safepaths" + "github.com/cli/cli/v2/internal/update" + ghzip "github.com/cli/cli/v2/internal/zip" + "github.com/cli/cli/v2/pkg/cmdutil" + "github.com/cli/cli/v2/pkg/iostreams" + "github.com/spf13/cobra" +) + +type CopilotOptions struct { + IO *iostreams.IOStreams + HttpClient func() (*http.Client, error) + Prompter prompter.Prompter + + CopilotArgs []string + Remove bool +} + +func NewCmdCopilot(f *cmdutil.Factory, runF func(*CopilotOptions) error) *cobra.Command { + opts := &CopilotOptions{ + IO: f.IOStreams, + HttpClient: f.HttpClient, + Prompter: f.Prompter, + } + + cmd := &cobra.Command{ + Use: "copilot [flags] [args]", + Short: "Run the GitHub Copilot CLI (preview)", + Long: heredoc.Docf(` + Runs the GitHub Copilot CLI. + + Executing the Copilot CLI through %[1]sgh%[1]s is currently in preview and subject to change. + + If already installed, %[1]sgh%[1]s will execute the Copilot CLI found in your %[1]sPATH%[1]s. + If the Copilot CLI is not installed, it will be downloaded to %[2]s. + + Use %[1]s--remove%[1]s to remove the downloaded Copilot CLI. + + This command is only supported on Windows, Linux, and Darwin, on amd64/x64 + or arm64 architectures. + + To prevent %[1]sgh%[1]s from interpreting flags intended for Copilot, + use %[1]s--%[1]s before Copilot flags and args. + + Learn more at https://gh.io/copilot-cli + `, "`", copilotInstallDir()), + Example: heredoc.Doc(` + # Download and run the Copilot CLI + $ gh copilot + + # Run the Copilot CLI + $ gh copilot -p "Summarize this week's commits" --allow-tool 'shell(git)' + + # Remove the Copilot CLI (if installed through gh) + $ gh copilot --remove + + # Run the Copilot CLI help command + $ gh copilot -- --help + `), + DisableFlagParsing: true, + RunE: func(cmd *cobra.Command, args []string) error { + stopParsePos := -1 + for i, arg := range args { + if arg == "--" { + stopParsePos = i + break + } + } + + ghArgs := args + opts.CopilotArgs = args + if stopParsePos >= 0 { + ghArgs = args[:stopParsePos] + opts.CopilotArgs = args[stopParsePos+1:] // +1 to skip the "--" itself + } + + if slices.Contains(ghArgs, "--help") || slices.Contains(ghArgs, "-h") { + return cmd.Help() + } + + if slices.Contains(ghArgs, "--remove") { + hasOtherArgs := len(ghArgs) > 1 + if stopParsePos >= 0 { + hasOtherArgs = hasOtherArgs || len(opts.CopilotArgs) > 0 + } + if hasOtherArgs { + return cmdutil.FlagErrorf("cannot use --remove with args") + } + opts.Remove = true + opts.CopilotArgs = nil + } + + if runF != nil { + return runF(opts) + } + + return runCopilot(opts) + }, + } + + cmdutil.DisableAuthCheck(cmd) + + // We add this flag, even though flag parsing is disabled for this command + // so the flag still appears in the help text. + cmd.Flags().Bool("remove", false, "Remove the downloaded Copilot CLI") + return cmd +} + +func runCopilot(opts *CopilotOptions) error { + if opts.Remove { + if err := removeCopilot(copilotInstallDir()); err != nil { + return err + } + + if opts.IO.IsStdoutTTY() { + fmt.Fprintln(opts.IO.ErrOut, "Copilot CLI removed successfully") + } + return nil + } + + copilotPath := findCopilotBinary() + if copilotPath == "" { + if opts.IO.CanPrompt() { + confirmed, err := opts.Prompter.Confirm("GitHub Copilot CLI is not installed. Would you like to install it?", true) + if err != nil { + return err + } + if !confirmed { + fmt.Fprintf(opts.IO.ErrOut, "%s Copilot CLI was not installed", opts.IO.ColorScheme().WarningIcon()) + return cmdutil.SilentError + } + } else if !update.IsCI() { + fmt.Fprintf(opts.IO.ErrOut, "%s Copilot CLI not installed", opts.IO.ColorScheme().WarningIcon()) + return cmdutil.SilentError + } + + httpClient, err := opts.HttpClient() + if err != nil { + return err + } + + copilotPath, err = downloadCopilot(httpClient, opts.IO, copilotInstallDir(), copilotBinaryPath()) + if err != nil { + return err + } + } + + externalCmd := exec.Command(copilotPath, opts.CopilotArgs...) + externalCmd.Stdin = opts.IO.In + externalCmd.Stdout = opts.IO.Out + externalCmd.Stderr = opts.IO.ErrOut + + if err := externalCmd.Run(); err != nil { + if exitErr, ok := err.(*exec.ExitError); ok { + // We terminate with os.Exit here, preserving the exit code from Copilot CLI, + // and also preventing stdio writes by callers up the stack. + os.Exit(exitErr.ExitCode()) + } + return err + } + return nil +} + +const copilotBinaryName = "copilot" + +func copilotInstallDir() string { + return filepath.Join(config.DataDir(), "copilot") +} + +func copilotBinaryPath() string { + binaryName := copilotBinaryName + if runtime.GOOS == "windows" { + binaryName += ".exe" + } + return filepath.Join(copilotInstallDir(), binaryName) +} + +// findCopilotBinary returns the path to the Copilot CLI binary, if installed, +// with the following order of precedence: +// 1. `copilot` in the PATH +// 2. `copilot` in gh's data directory +// +// If not installed, it returns an empty string. +func findCopilotBinary() string { + if path, err := exec.LookPath(copilotBinaryName); err == nil { + return path + } + + localPath := copilotBinaryPath() + if _, err := os.Stat(localPath); err != nil { + return "" + } + return localPath +} + +// downloadCopilot downloads and installs the Copilot CLI to installDir. +// It returns the path to the installed Copilot binary. +func downloadCopilot(httpClient *http.Client, ios *iostreams.IOStreams, installDir, localPath string) (string, error) { + platform := runtime.GOOS + if platform == "windows" { + platform = "win32" + } + + arch := runtime.GOARCH + if arch == "amd64" { + arch = "x64" + } + + if arch != "x64" && arch != "arm64" { + return "", fmt.Errorf("unsupported architecture: %s (supported: x64, arm64)", arch) + } + + var archiveURL string + var archiveName string + var isZip bool + switch platform { + case "win32": + archiveName = fmt.Sprintf("copilot-%s-%s.zip", platform, arch) + archiveURL = fmt.Sprintf("https://github.com/github/copilot-cli/releases/latest/download/%s", archiveName) + isZip = true + case "linux", "darwin": + archiveName = fmt.Sprintf("copilot-%s-%s.tar.gz", platform, arch) + archiveURL = fmt.Sprintf("https://github.com/github/copilot-cli/releases/latest/download/%s", archiveName) + default: + return "", fmt.Errorf("unsupported platform: %s (supported: linux, darwin, windows)", platform) + } + + checksumsURL := "https://github.com/github/copilot-cli/releases/latest/download/SHA256SUMS.txt" + + expectedChecksum, err := fetchExpectedChecksum(httpClient, checksumsURL, archiveName) + if err != nil { + return "", fmt.Errorf("failed to fetch checksums: %w", err) + } + + ios.StartProgressIndicatorWithLabel(fmt.Sprintf("Downloading Copilot CLI from %s", archiveURL)) + defer ios.StopProgressIndicator() + + resp, err := httpClient.Get(archiveURL) + if err != nil { + return "", fmt.Errorf("failed to download: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("download failed with status: %s", resp.Status) + } + + // Download to temp file while calculating checksum + tmpFile, err := os.CreateTemp("", "copilot-download-*") + if err != nil { + return "", fmt.Errorf("failed to create temp file: %w", err) + } + defer os.Remove(tmpFile.Name()) + defer tmpFile.Close() + + hasher := sha256.New() + if _, err := io.Copy(tmpFile, io.TeeReader(resp.Body, hasher)); err != nil { + return "", fmt.Errorf("failed to download: %w", err) + } + + ios.StopProgressIndicator() + + // Validate checksum + actualChecksumHex := hex.EncodeToString(hasher.Sum(nil)) + if actualChecksumHex != expectedChecksum { + return "", fmt.Errorf("checksum mismatch: expected %s, got %s", expectedChecksum, actualChecksumHex) + } + + if _, err := tmpFile.Seek(0, io.SeekStart); err != nil { + return "", fmt.Errorf("failed to seek temp file: %w", err) + } + + if err := os.MkdirAll(installDir, 0755); err != nil { + return "", fmt.Errorf("failed to create install directory: %w", err) + } + + // Extract from the downloaded data + if isZip { + err = extractZip(tmpFile.Name(), installDir) + } else { + err = extractTarGz(tmpFile, installDir) + } + if err != nil { + return "", err + } + + if _, err := os.Stat(localPath); err != nil { + return "", fmt.Errorf("copilot binary unavailable: %w", err) + } + + fmt.Fprintf(ios.ErrOut, "%s Copilot CLI installed successfully\n", ios.ColorScheme().SuccessIcon()) + return localPath, nil +} + +// fetchExpectedChecksum downloads the SHA256SUMS.txt file and returns the expected checksum for the given archive name. +func fetchExpectedChecksum(httpClient *http.Client, checksumsURL, archiveName string) (string, error) { + resp, err := httpClient.Get(checksumsURL) + if err != nil { + return "", err + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("failed to download checksums: %s", resp.Status) + } + + // Parse the checksums file. Possible formats are: + // - " " (two whitespaces) + // - " " + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + fields := strings.Fields(line) + if len(fields) >= 2 { + checksum := fields[0] + filename := fields[1] + if filename == archiveName { + return checksum, nil + } + } + } + if err := scanner.Err(); err != nil { + return "", fmt.Errorf("failed to read checksums: %w", err) + } + + return "", fmt.Errorf("checksum not found for %s", archiveName) +} + +// extractZip reads a ZIP archive at path and extracts its contents into destDir. +// It returns an error if the archive cannot be read, +// or if any file or directory within the archive cannot be created or written. +func extractZip(path, destDir string) error { + zipReader, err := zip.OpenReader(path) + if err != nil { + return fmt.Errorf("failed to open zip: %w", err) + } + defer zipReader.Close() + + absPath, err := safepaths.ParseAbsolute(destDir) + if err != nil { + return err + } + + // As of the time of writing, ghzip.ExtractZip will safely skip files that + // would result in path traversal. This is an issue for our use-case because + // we want to error out before extracting if there's any such file. + // To avoid breaking the shared ghzip.ExtractZip code that expects unsafe + // paths to be ignored and no error produced, we pre-validate here, + // producing an error if any such file is found. + for _, f := range zipReader.File { + _, err := absPath.Join(f.Name) + if err != nil { + return err + } + } + + if err := ghzip.ExtractZip(&zipReader.Reader, absPath); err != nil { + return err + } + + return nil +} + +// extractTarGz reads a TAR.GZ archive from r and extracts its contents into destDir. +// It returns an error if the archive cannot be read, +// or if any file or directory within the archive cannot be created or written. +func extractTarGz(r io.Reader, destDir string) error { + gzr, err := gzip.NewReader(r) + if err != nil { + return fmt.Errorf("failed to create gzip reader: %w", err) + } + defer gzr.Close() + + absDestDirPath, err := safepaths.ParseAbsolute(destDir) + if err != nil { + return err + } + + tr := tar.NewReader(gzr) + for { + header, err := tr.Next() + if err == io.EOF { + break + } + if err != nil { + return fmt.Errorf("failed to read tar: %w", err) + } + + absFilePath, err := absDestDirPath.Join(header.Name) + if err != nil { + return err + } + target := absFilePath.String() + + if header.Typeflag == tar.TypeReg { + if err := os.MkdirAll(filepath.Dir(target), 0755); err != nil { + return fmt.Errorf("failed to create parent directory: %w", err) + } + if err := extractFile(target, os.FileMode(header.Mode)&0777, tr); err != nil { + return err + } + } + } + return nil +} + +// extractFile creates a file at target with the given mode and copies content from r. +func extractFile(target string, mode os.FileMode, r io.Reader) (err error) { + out, err := os.OpenFile(target, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, mode) + if err != nil { + return fmt.Errorf("failed to create file: %w", err) + } + defer func() { + if cerr := out.Close(); err == nil && cerr != nil { + err = fmt.Errorf("failed to close file: %w", cerr) + } + }() + if _, err := io.Copy(out, r); err != nil { + return fmt.Errorf("failed to write file: %w", err) + } + return nil +} + +func removeCopilot(installDir string) error { + if _, err := os.Stat(installDir); os.IsNotExist(err) { + return fmt.Errorf("failed to remove Copilot CLI: Copilot CLI not installed through `gh`") + } + + if err := os.RemoveAll(installDir); err != nil { + return fmt.Errorf("failed to remove Copilot CLI: %w", err) + } + + return nil +} diff --git a/pkg/cmd/copilot/copilot_test.go b/pkg/cmd/copilot/copilot_test.go new file mode 100644 index 00000000000..f377f171dc0 --- /dev/null +++ b/pkg/cmd/copilot/copilot_test.go @@ -0,0 +1,588 @@ +package copilot + +import ( + "archive/tar" + "archive/zip" + "bytes" + "compress/gzip" + "crypto/sha256" + "encoding/hex" + "fmt" + "net/http" + "os" + "path/filepath" + "runtime" + "testing" + + "github.com/cli/cli/v2/pkg/cmdutil" + "github.com/cli/cli/v2/pkg/httpmock" + "github.com/cli/cli/v2/pkg/iostreams" + "github.com/google/shlex" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewCmdCopilot(t *testing.T) { + tests := []struct { + name string + args string + wantOpts CopilotOptions + wantErrString string + wantHelp bool + }{ + { + name: "no argument", + args: "", + wantOpts: CopilotOptions{ + CopilotArgs: []string{}, + }, + wantErrString: "", + }, + { + name: "with arguments", + args: "some-arg some-other-arg", + wantOpts: CopilotOptions{ + CopilotArgs: []string{"some-arg", "some-other-arg"}, + }, + }, + { + name: "with --remove alone", + args: "--remove", + wantOpts: CopilotOptions{ + Remove: true, + }, + }, + { + name: "with non-gh flags passed to copilot", + args: "-p testing --something-flag", + wantOpts: CopilotOptions{ + CopilotArgs: []string{"-p", "testing", "--something-flag"}, + }, + }, + { + name: "with --remove and arguments", + args: "--remove some-arg", + wantErrString: "cannot use --remove with args", + }, + { + name: "with --remove passed to copilot using --", + args: "-- --remove", + wantOpts: CopilotOptions{ + CopilotArgs: []string{"--remove"}, + }, + }, + { + name: "with --remove and -- alone", + args: "--remove --", + wantOpts: CopilotOptions{ + Remove: true, + }, + }, + { + name: "with --remove, some invalid arg, and --", + args: "--remove invalid-arg --", + wantErrString: "cannot use --remove with args", + }, + { + name: "with --remove and -- and random arguments", + args: "--remove -- some-arg", + wantErrString: "cannot use --remove with args", + }, + { + name: "with --help, shows gh help", + args: "--help", + wantErrString: "", + wantHelp: true, + }, + { + name: "with --help and --, shows copilot help", + args: "-- --help", + wantOpts: CopilotOptions{ + CopilotArgs: []string{"--help"}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &cmdutil.Factory{} + + argv, err := shlex.Split(tt.args) + assert.NoError(t, err) + + var gotOpts *CopilotOptions + cmd := NewCmdCopilot(f, func(opts *CopilotOptions) error { + gotOpts = opts + return nil + }) + + cmd.SetArgs(argv) + cmd.SetIn(&bytes.Buffer{}) + cmd.SetOut(&bytes.Buffer{}) + cmd.SetErr(&bytes.Buffer{}) + + _, err = cmd.ExecuteC() + if tt.wantErrString != "" { + require.EqualError(t, err, tt.wantErrString) + return + } + + if tt.wantHelp { + require.NoError(t, err) + return + } + + require.NoError(t, err) + assert.Equal(t, tt.wantOpts.CopilotArgs, gotOpts.CopilotArgs, "opts.CopilotArgs not as expected") + assert.Equal(t, tt.wantOpts.Remove, gotOpts.Remove, "opts.Remove not as expected") + }) + } +} + +func TestRemoveCopilot(t *testing.T) { + t.Run("removes existing install directory", func(t *testing.T) { + // Create a temporary directory to simulate the install directory + tmpDir := t.TempDir() + installDir := filepath.Join(tmpDir, "copilot") + require.NoError(t, os.MkdirAll(installDir, 0755), "failed to create test directory") + // Create a dummy file in the directory + dummyFile := filepath.Join(installDir, "copilot") + require.NoError(t, os.WriteFile(dummyFile, []byte("test"), 0755), "failed to create test file") + + err := removeCopilot(installDir) + require.NoError(t, err, "unexpected error") + + _, err = os.Stat(installDir) + require.True(t, os.IsNotExist(err), "expected install directory to be removed") + }) + + t.Run("handles non-existent directory", func(t *testing.T) { + tmpDir := t.TempDir() + installDir := filepath.Join(tmpDir, "copilot") + + require.ErrorContains(t, removeCopilot(installDir), "failed to remove Copilot CLI") + }) +} + +// createTarGzBuffer creates a tar.gz archive in memory with the given files. +func createTarGzBuffer(t *testing.T, files map[string][]byte) []byte { + t.Helper() + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gw) + + for name, content := range files { + hdr := &tar.Header{ + Name: name, + Mode: 0755, + Size: int64(len(content)), + } + require.NoError(t, tw.WriteHeader(hdr), "failed to write tar header") + _, err := tw.Write(content) + require.NoError(t, err, "failed to write tar content") + } + + require.NoError(t, tw.Close(), "failed to close tar writer") + require.NoError(t, gw.Close(), "failed to close gzip writer") + return buf.Bytes() +} + +// createZipBuffer creates a zip archive in memory with the given files. +func createZipBuffer(t *testing.T, files map[string][]byte) []byte { + t.Helper() + var buf bytes.Buffer + zw := zip.NewWriter(&buf) + + for name, content := range files { + fw, err := zw.Create(name) + require.NoError(t, err, "failed to create zip entry") + _, err = fw.Write(content) + require.NoError(t, err, "failed to write zip content") + } + + require.NoError(t, zw.Close(), "failed to close zip writer") + return buf.Bytes() +} + +func TestExtractTarGz(t *testing.T) { + t.Run("extracts files correctly", func(t *testing.T) { + content := []byte("hello world") + archive := createTarGzBuffer(t, map[string][]byte{ + "copilot": content, + }) + + destDir := t.TempDir() + + err := extractTarGz(bytes.NewReader(archive), destDir) + require.NoError(t, err, "extractTarGz() error") + + extracted, err := os.ReadFile(filepath.Join(destDir, "copilot")) + require.NoError(t, err, "failed to read extracted file") + require.Equal(t, content, extracted, "extracted content mismatch") + }) + + t.Run("extracts nested files", func(t *testing.T) { + content := []byte("nested content") + archive := createTarGzBuffer(t, map[string][]byte{ + "subdir/file.txt": content, + }) + + destDir := t.TempDir() + + err := extractTarGz(bytes.NewReader(archive), destDir) + require.NoError(t, err, "extractTarGz() error") + + extracted, err := os.ReadFile(filepath.Join(destDir, "subdir", "file.txt")) + require.NoError(t, err, "failed to read extracted file") + require.Equal(t, content, extracted, "extracted content mismatch") + }) + + t.Run("rejects path traversal", func(t *testing.T) { + // Manually create a malicious tar.gz with path traversal + var buf bytes.Buffer + gw := gzip.NewWriter(&buf) + tw := tar.NewWriter(gw) + + hdr := &tar.Header{ + Name: "../evil.txt", + Mode: 0755, + Size: 4, + } + _ = tw.WriteHeader(hdr) + _, _ = tw.Write([]byte("evil")) + _ = tw.Close() + _ = gw.Close() + + destDir := t.TempDir() + + err := extractTarGz(bytes.NewReader(buf.Bytes()), destDir) + require.Error(t, err, "expected error for path traversal, got nil") + }) + + t.Run("handles invalid gzip", func(t *testing.T) { + destDir := t.TempDir() + + err := extractTarGz(bytes.NewReader([]byte("not valid gzip")), destDir) + require.Error(t, err, "expected error for invalid gzip, got nil") + }) +} + +func TestExtractZip(t *testing.T) { + t.Run("extracts files correctly", func(t *testing.T) { + zipDir := t.TempDir() + zipPath := filepath.Join(zipDir, "archive.zip") + content := []byte("hello world") + archive := createZipBuffer(t, map[string][]byte{ + "copilot.exe": content, + }) + require.NoError(t, os.WriteFile(zipPath, archive, 0x755)) + + destDir := t.TempDir() + + err := extractZip(zipPath, destDir) + require.NoError(t, err, "extractZip() error") + + extracted, err := os.ReadFile(filepath.Join(destDir, "copilot.exe")) + require.NoError(t, err, "failed to read extracted file") + require.Equal(t, content, extracted, "extracted content mismatch") + }) + + t.Run("extracts nested files", func(t *testing.T) { + zipDir := t.TempDir() + zipPath := filepath.Join(zipDir, "archive.zip") + content := []byte("hello world") + archive := createZipBuffer(t, map[string][]byte{ + "subdir/file.txt": content, + }) + require.NoError(t, os.WriteFile(zipPath, archive, 0x755)) + + destDir := t.TempDir() + + err := extractZip(zipPath, destDir) + require.NoError(t, err, "extractZip() error") + + extracted, err := os.ReadFile(filepath.Join(destDir, "subdir", "file.txt")) + require.NoError(t, err, "failed to read extracted file") + require.Equal(t, content, extracted, "extracted content mismatch") + }) + + t.Run("rejects path traversal", func(t *testing.T) { + zipDir := t.TempDir() + zipPath := filepath.Join(zipDir, "archive.zip") + + var buf bytes.Buffer + zw := zip.NewWriter(&buf) + + fh := &zip.FileHeader{ + Name: "../evil.txt", + Method: zip.Store, + } + fw, _ := zw.CreateHeader(fh) + _, _ = fw.Write([]byte("evil")) + _ = zw.Close() + + require.NoError(t, os.WriteFile(zipPath, buf.Bytes(), 0x755)) + destDir := t.TempDir() + + err := extractZip(zipPath, destDir) + require.Error(t, err, "expected error for path traversal, got nil") + }) +} + +func TestFetchExpectedChecksum(t *testing.T) { + t.Run("parses checksums file correctly", func(t *testing.T) { + reg := &httpmock.Registry{} + checksums := "abc123def456 copilot-linux-x64.tar.gz\n789xyz copilot-darwin-arm64.tar.gz\n" + reg.Register( + httpmock.MatchAny, + httpmock.StringResponse(checksums), + ) + + client := &http.Client{Transport: reg} + checksum, err := fetchExpectedChecksum(client, "https://example.com/checksums", "copilot-linux-x64.tar.gz") + require.NoError(t, err, "unexpected error") + require.Equal(t, "abc123def456", checksum, "checksum mismatch") + }) + + t.Run("returns error for missing archive", func(t *testing.T) { + reg := &httpmock.Registry{} + checksums := "abc123 copilot-linux-x64.tar.gz\n" + reg.Register( + httpmock.MatchAny, + httpmock.StringResponse(checksums), + ) + + client := &http.Client{Transport: reg} + _, err := fetchExpectedChecksum(client, "https://example.com/checksums", "copilot-windows-x64.zip") + require.Error(t, err, "expected error for missing archive") + require.Equal(t, "checksum not found for copilot-windows-x64.zip", err.Error(), "unexpected error") + }) + + t.Run("handles single space separator", func(t *testing.T) { + reg := &httpmock.Registry{} + checksums := "abc123 copilot-darwin-x64.tar.gz\n" + reg.Register( + httpmock.MatchAny, + httpmock.StringResponse(checksums), + ) + + client := &http.Client{Transport: reg} + checksum, err := fetchExpectedChecksum(client, "https://example.com/checksums", "copilot-darwin-x64.tar.gz") + require.NoError(t, err, "unexpected error") + require.Equal(t, "abc123", checksum, "checksum mismatch") + }) + + t.Run("handles HTTP error", func(t *testing.T) { + reg := &httpmock.Registry{} + reg.Register( + httpmock.MatchAny, + httpmock.StatusStringResponse(http.StatusNotFound, "not found"), + ) + + client := &http.Client{Transport: reg} + _, err := fetchExpectedChecksum(client, "https://example.com/checksums", "copilot-linux-x64.tar.gz") + require.Error(t, err, "expected error for HTTP 404") + }) +} + +func archString() string { + arch := runtime.GOARCH + if arch == "amd64" { + return "x64" + } + return arch +} + +func TestDownloadCopilot(t *testing.T) { + // Skip on unsupported architectures + if runtime.GOARCH != "amd64" && runtime.GOARCH != "arm64" { + t.Skip("skipping test on unsupported architecture") + } + + t.Run("downloads and extracts tar.gz with valid checksum", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipping tar.gz test on windows") + } + + ios, _, _, stderr := iostreams.Test() + tmpDir := t.TempDir() + installDir := filepath.Join(tmpDir, "copilot") + localPath := filepath.Join(installDir, "copilot") + + // Create mock archive with copilot binary + binaryContent := []byte("#!/bin/sh\necho copilot") + archive := createTarGzBuffer(t, map[string][]byte{ + "copilot": binaryContent, + }) + + // Calculate checksum + checksum := sha256.Sum256(archive) + checksumHex := hex.EncodeToString(checksum[:]) + archiveName := fmt.Sprintf("copilot-%s-%s.tar.gz", runtime.GOOS, archString()) + checksumFile := fmt.Sprintf("%s %s\n", checksumHex, archiveName) + + reg := &httpmock.Registry{} + // Register checksum endpoint + reg.Register( + httpmock.REST("GET", "github/copilot-cli/releases/latest/download/SHA256SUMS.txt"), + httpmock.StringResponse(checksumFile), + ) + // Register archive endpoint + reg.Register( + httpmock.REST("GET", fmt.Sprintf("github/copilot-cli/releases/latest/download/%s", archiveName)), + httpmock.BinaryResponse(archive), + ) + + httpClient := &http.Client{Transport: reg} + + path, err := downloadCopilot(httpClient, ios, installDir, localPath) + require.NoError(t, err, "downloadCopilot() error") + require.Equal(t, localPath, path, "downloadCopilot() path mismatch") + + // Verify binary was extracted + extracted, err := os.ReadFile(localPath) + require.NoError(t, err, "failed to read extracted binary") + require.Equal(t, binaryContent, extracted, "extracted content mismatch") + + // Verify output messages + require.Contains(t, stderr.String(), "installed successfully", "expected success message in stderr") + }) + + t.Run("fails with checksum mismatch", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipping tar.gz test on windows") + } + + ios, _, _, _ := iostreams.Test() + tmpDir := t.TempDir() + installDir := filepath.Join(tmpDir, "copilot") + localPath := filepath.Join(installDir, "copilot") + + binaryContent := []byte("#!/bin/sh\necho copilot") + archive := createTarGzBuffer(t, map[string][]byte{ + "copilot": binaryContent, + }) + + // Use wrong checksum + archiveName := fmt.Sprintf("copilot-%s-%s.tar.gz", runtime.GOOS, archString()) + checksumFile := fmt.Sprintf("%s %s\n", "0000000000000000000000000000000000000000000000000000000000000000", archiveName) + + reg := &httpmock.Registry{} + reg.Register( + httpmock.REST("GET", "github/copilot-cli/releases/latest/download/SHA256SUMS.txt"), + httpmock.StringResponse(checksumFile), + ) + reg.Register( + httpmock.REST("GET", fmt.Sprintf("github/copilot-cli/releases/latest/download/%s", archiveName)), + httpmock.BinaryResponse(archive), + ) + + httpClient := &http.Client{Transport: reg} + + _, err := downloadCopilot(httpClient, ios, installDir, localPath) + require.Error(t, err, "expected error for checksum mismatch, got nil") + require.Contains(t, err.Error(), "checksum mismatch", "expected checksum mismatch error") + }) + + t.Run("handles HTTP error on archive download", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipping tar.gz test on windows") + } + + ios, _, _, _ := iostreams.Test() + tmpDir := t.TempDir() + installDir := filepath.Join(tmpDir, "copilot") + localPath := filepath.Join(installDir, "copilot") + + archiveName := fmt.Sprintf("copilot-%s-%s.tar.gz", runtime.GOOS, archString()) + checksumFile := fmt.Sprintf("%s %s\n", "abc123", archiveName) + + reg := &httpmock.Registry{} + reg.Register( + httpmock.REST("GET", "github/copilot-cli/releases/latest/download/SHA256SUMS.txt"), + httpmock.StringResponse(checksumFile), + ) + reg.Register( + httpmock.REST("GET", fmt.Sprintf("github/copilot-cli/releases/latest/download/%s", archiveName)), + httpmock.StatusStringResponse(http.StatusNotFound, "not found"), + ) + + httpClient := &http.Client{Transport: reg} + + _, err := downloadCopilot(httpClient, ios, installDir, localPath) + require.Error(t, err, "expected error for HTTP 404, got nil") + require.Contains(t, err.Error(), "download failed", "expected error to contain 'download failed'") + }) + + t.Run("handles missing binary after extraction", func(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("skipping tar.gz test on windows") + } + + ios, _, _, _ := iostreams.Test() + tmpDir := t.TempDir() + installDir := filepath.Join(tmpDir, "copilot") + localPath := filepath.Join(installDir, "copilot") + + // Create archive without the expected binary name + archive := createTarGzBuffer(t, map[string][]byte{ + "wrong-name": []byte("content"), + }) + + checksum := sha256.Sum256(archive) + checksumHex := hex.EncodeToString(checksum[:]) + archiveName := fmt.Sprintf("copilot-%s-%s.tar.gz", runtime.GOOS, archString()) + checksumFile := fmt.Sprintf("%s %s\n", checksumHex, archiveName) + + reg := &httpmock.Registry{} + reg.Register( + httpmock.REST("GET", "github/copilot-cli/releases/latest/download/SHA256SUMS.txt"), + httpmock.StringResponse(checksumFile), + ) + reg.Register( + httpmock.REST("GET", fmt.Sprintf("github/copilot-cli/releases/latest/download/%s", archiveName)), + httpmock.BinaryResponse(archive), + ) + + httpClient := &http.Client{Transport: reg} + + _, err := downloadCopilot(httpClient, ios, installDir, localPath) + assert.ErrorContains(t, err, "copilot binary unavailable") + }) + + t.Run("downloads and extracts zip on windows", func(t *testing.T) { + if runtime.GOOS != "windows" { + t.Skip("skipping zip test on non-windows") + } + + ios, _, _, _ := iostreams.Test() + tmpDir := t.TempDir() + installDir := filepath.Join(tmpDir, "copilot") + localPath := filepath.Join(installDir, "copilot.exe") + + binaryContent := []byte("MZ fake exe content") + archive := createZipBuffer(t, map[string][]byte{ + "copilot.exe": binaryContent, + }) + + checksum := sha256.Sum256(archive) + checksumHex := hex.EncodeToString(checksum[:]) + archiveName := fmt.Sprintf("copilot-%s-%s.zip", "win32", archString()) + checksumFile := fmt.Sprintf("%s %s\n", checksumHex, archiveName) + + reg := &httpmock.Registry{} + reg.Register( + httpmock.REST("GET", "github/copilot-cli/releases/latest/download/SHA256SUMS.txt"), + httpmock.StringResponse(checksumFile), + ) + reg.Register( + httpmock.REST("GET", fmt.Sprintf("github/copilot-cli/releases/latest/download/%s", archiveName)), + httpmock.BinaryResponse(archive), + ) + + httpClient := &http.Client{Transport: reg} + + path, err := downloadCopilot(httpClient, ios, installDir, localPath) + require.NoError(t, err, "downloadCopilot() error") + require.Equal(t, localPath, path, "downloadCopilot() path mismatch") + }) +} diff --git a/pkg/cmd/root/extension_registration_test.go b/pkg/cmd/root/extension_registration_test.go new file mode 100644 index 00000000000..90b836e4a47 --- /dev/null +++ b/pkg/cmd/root/extension_registration_test.go @@ -0,0 +1,97 @@ +package root + +import ( + "testing" + + "github.com/cli/cli/v2/internal/browser" + "github.com/cli/cli/v2/internal/config" + "github.com/cli/cli/v2/internal/gh" + "github.com/cli/cli/v2/pkg/cmdutil" + "github.com/cli/cli/v2/pkg/extensions" + "github.com/cli/cli/v2/pkg/iostreams" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestNewCmdRoot_ExtensionRegistration(t *testing.T) { + tests := []struct { + name string + extensions []string + wantRegistered []string + wantSkipped []string + }{ + { + name: "extension conflicts with core command 'copilot'", + extensions: []string{"copilot"}, + wantSkipped: []string{"copilot"}, + wantRegistered: []string{}, + }, + { + name: "extension does not conflict with any core command", + extensions: []string{"my-custom-extension"}, + wantSkipped: []string{}, + wantRegistered: []string{"my-custom-extension"}, + }, + { + name: "extension that conflicts with a core command's alias", + extensions: []string{"agent"}, + wantSkipped: []string{"agent"}, + wantRegistered: []string{}, + }, + { + name: "multiple extensions with some conflicts", + extensions: []string{"pr", "custom-ext", "issue", "another-ext"}, + wantSkipped: []string{"pr", "issue"}, + wantRegistered: []string{"custom-ext", "another-ext"}, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ios, _, _, _ := iostreams.Test() + + var extMocks []extensions.Extension + for _, extName := range tt.extensions { + extMocks = append(extMocks, &extensions.ExtensionMock{ + NameFunc: func() string { + return extName + }, + }) + } + + em := &extensions.ExtensionManagerMock{ + ListFunc: func() []extensions.Extension { + return extMocks + }, + } + + f := &cmdutil.Factory{ + IOStreams: ios, + Config: func() (gh.Config, error) { + return config.NewBlankConfig(), nil + }, + Browser: &browser.Stub{}, + ExtensionManager: em, + } + + cmd, err := NewCmdRoot(f, "", "") + require.NoError(t, err) + + // Verify skipped extensions (should find core command registered, not extension) + for _, extName := range tt.wantSkipped { + foundCmd, _, findErr := cmd.Find([]string{extName}) + assert.NoError(t, findErr, "command %q should be found", extName) + assert.NotNil(t, foundCmd, "command %q should exist", extName) + assert.NotEqual(t, "extension", foundCmd.GroupID, "command %q should be core command, not extension", extName) + } + + // Verify registered extensions (should find extension command registered) + for _, extName := range tt.wantRegistered { + foundCmd, _, findErr := cmd.Find([]string{extName}) + assert.NoError(t, findErr, "extension %q should be found", extName) + assert.NotNil(t, foundCmd, "extension %q should exist", extName) + assert.Equal(t, "extension", foundCmd.GroupID, "command %q should be extension command", extName) + } + }) + } +} diff --git a/pkg/cmd/root/root.go b/pkg/cmd/root/root.go index 0a4f04e35e6..de359bd3eca 100644 --- a/pkg/cmd/root/root.go +++ b/pkg/cmd/root/root.go @@ -19,6 +19,7 @@ import ( codespaceCmd "github.com/cli/cli/v2/pkg/cmd/codespace" completionCmd "github.com/cli/cli/v2/pkg/cmd/completion" configCmd "github.com/cli/cli/v2/pkg/cmd/config" + copilotCmd "github.com/cli/cli/v2/pkg/cmd/copilot" extensionCmd "github.com/cli/cli/v2/pkg/cmd/extension" "github.com/cli/cli/v2/pkg/cmd/factory" gistCmd "github.com/cli/cli/v2/pkg/cmd/gist" @@ -131,7 +132,6 @@ func NewCmdRoot(f *cmdutil.Factory, version, buildDate string) (*cobra.Command, cmd.AddCommand(authCmd.NewCmdAuth(f)) cmd.AddCommand(attestationCmd.NewCmdAttestation(f)) cmd.AddCommand(configCmd.NewCmdConfig(f)) - cmd.AddCommand(creditsCmd.NewCmdCredits(f, nil)) cmd.AddCommand(gistCmd.NewCmdGist(f)) cmd.AddCommand(gpgKeyCmd.NewCmdGPGKey(f)) cmd.AddCommand(completionCmd.NewCmdCompletion(f.IOStreams)) @@ -140,11 +140,15 @@ func NewCmdRoot(f *cmdutil.Factory, version, buildDate string) (*cobra.Command, cmd.AddCommand(secretCmd.NewCmdSecret(f)) cmd.AddCommand(variableCmd.NewCmdVariable(f)) cmd.AddCommand(sshKeyCmd.NewCmdSSHKey(f)) - cmd.AddCommand(statusCmd.NewCmdStatus(f, nil)) cmd.AddCommand(codespaceCmd.NewCmdCodespace(f)) cmd.AddCommand(projectCmd.NewCmdProject(f)) cmd.AddCommand(previewCmd.NewCmdPreview(f)) + // Root commands with standalone functionality and no subcommands + cmd.AddCommand(copilotCmd.NewCmdCopilot(f, nil)) + cmd.AddCommand(statusCmd.NewCmdStatus(f, nil)) + cmd.AddCommand(creditsCmd.NewCmdCredits(f, nil)) + // below here at the commands that require the "intelligent" BaseRepo resolver repoResolvingCmdFactory := *f repoResolvingCmdFactory.BaseRepo = factory.SmartBaseRepoFunc(f) @@ -179,6 +183,13 @@ func NewCmdRoot(f *cmdutil.Factory, version, buildDate string) (*cobra.Command, em := f.ExtensionManager for _, e := range em.List() { extensionCmd := NewCmdExtension(io, em, e, nil) + // Don't register an extension command if it would + // conflict with a core command. + _, _, err := cmd.Find([]string{extensionCmd.Name()}) + if err == nil { + continue + } + cmd.AddCommand(extensionCmd) } diff --git a/pkg/cmd/run/download/http.go b/pkg/cmd/run/download/http.go index 783c8495e6d..09293b056d6 100644 --- a/pkg/cmd/run/download/http.go +++ b/pkg/cmd/run/download/http.go @@ -10,6 +10,7 @@ import ( "github.com/cli/cli/v2/api" "github.com/cli/cli/v2/internal/ghrepo" "github.com/cli/cli/v2/internal/safepaths" + ghzip "github.com/cli/cli/v2/internal/zip" "github.com/cli/cli/v2/pkg/cmd/run/shared" ) @@ -62,7 +63,7 @@ func downloadArtifact(httpClient *http.Client, url string, destDir safepaths.Abs if err != nil { return fmt.Errorf("error extracting zip archive: %w", err) } - if err := extractZip(zipfile, destDir); err != nil { + if err := ghzip.ExtractZip(zipfile, destDir); err != nil { return fmt.Errorf("error extracting zip archive: %w", err) }