diff --git a/utils/unarchive.go b/utils/unarchive.go index 4e729cae..fe359f45 100644 --- a/utils/unarchive.go +++ b/utils/unarchive.go @@ -4,6 +4,7 @@ import ( "archive/tar" "archive/zip" "compress/gzip" + "fmt" "io" "net/http" "os" @@ -72,49 +73,58 @@ func readData(name string) (buffer []byte, err error) { func ExtractZip(path, artifactPath string) error { zipReader, err := zip.OpenReader(artifactPath) - defer func() { - _ = zipReader.Close() - }() - if err != nil { return ErrExtractZip(err, path) } - buffer := make([]byte, 1<<4) + defer zipReader.Close() + destDir, err := filepath.Abs(path) + if err != nil { + return ErrExtractZip(err, path) + } for _, file := range zipReader.File { + err := func() error { + targetPath, err := filepath.Abs(filepath.Join(destDir, file.Name)) + if err != nil { + return err + } - fd, err := file.Open() - defer func() { - _ = fd.Close() - }() - - if err != nil { - return ErrExtractZip(err, path) - } + if !strings.HasPrefix(targetPath, destDir+string(os.PathSeparator)) && targetPath != destDir { + return fmt.Errorf("zipslip: illegal file path: %s", file.Name) + } - filePath := filepath.Join(path, file.Name) + // CHECK for files to skip (macOS metadata) + if strings.HasPrefix(filepath.Base(targetPath), "._") || filepath.Base(targetPath) == "__MACOSX" { + return nil + } - if file.FileInfo().IsDir() { - err := os.Mkdir(file.Name, file.Mode()) + fd, err := file.Open() if err != nil { - return ErrExtractZip(err, path) + return err } - } else { - openedFile, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, file.Mode()) - if err != nil { - return ErrExtractZip(err, path) + defer fd.Close() + + if file.FileInfo().IsDir() { + return os.MkdirAll(targetPath, file.Mode()) + } + + if err := os.MkdirAll(filepath.Dir(targetPath), 0755); err != nil { + return err } - _, err = io.CopyBuffer(openedFile, fd, buffer) + + openedFile, err := os.OpenFile(targetPath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, file.Mode()) if err != nil { - return ErrExtractZip(err, path) + return err } - defer func() { - _ = openedFile.Close() - }() - } + defer openedFile.Close() + _, err = io.Copy(openedFile, fd) + return err + }() + if err != nil { + return ErrExtractZip(err, path) + } } return nil - } func ExtractTarGz(path, downloadfilePath string) error { diff --git a/utils/zip_test.go b/utils/zip_test.go new file mode 100644 index 00000000..a827034c --- /dev/null +++ b/utils/zip_test.go @@ -0,0 +1,139 @@ +package utils + +import ( + "archive/zip" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestExtractZip(t *testing.T) { + tests := []struct { + name string + files map[string]string // fileName: content + wantErr bool + errMatch string + }{ + { + name: "Valid Extraction", + files: map[string]string{ + "test.txt": "hello world", + "subdir/file.txt": "nested content", + }, + wantErr: false, + }, + { + name: "Zip Slip Attack Attempt", + files: map[string]string{ + "../outside.txt": "malicious content", + }, + wantErr: true, + errMatch: "zipslip: illegal file path", + }, + { + name: "Skip macOS Metadata", + files: map[string]string{ + "__MACOSX/secret": "should skip", + "._metadata": "should skip", + "realfile.txt": "should keep", + }, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "extract-test-*") + if err != nil { + t.Fatal(err) + } + defer os.RemoveAll(tmpDir) + zipPath := filepath.Join(t.TempDir(), "test.zip") + f, err := os.Create(zipPath) + if err != nil { + t.Fatal(err) + } + + writer := zip.NewWriter(f) + for name, content := range tt.files { + w, err := writer.Create(name) + if err != nil { + continue + } + w.Write([]byte(content)) + } + writer.Close() + f.Close() + + err = ExtractZip(tmpDir, zipPath) + if (err != nil) != tt.wantErr { + t.Fatalf("ExtractZip() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.wantErr && tt.errMatch != "" { + if !strings.Contains(err.Error(), tt.errMatch) { + t.Errorf("error %q does not match %q", err.Error(), tt.errMatch) + } + return + } + + for name, expectedContent := range tt.files { + if strings.HasPrefix(filepath.Base(name), "._") || filepath.Base(name) == "__MACOSX" { + continue + } + + path := filepath.Join(tmpDir, name) + gotContent, err := os.ReadFile(path) + if err != nil { + t.Errorf("Expected file %s missing: %v", name, err) + continue + } + if string(gotContent) != expectedContent { + t.Errorf("File %s: got %q, want %q", name, string(gotContent), expectedContent) + } + } + + }) + } +} + +func TestExtractZip_Destination(t *testing.T) { + destDir, _ := os.MkdirTemp("", "correct-dest-*") + defer os.RemoveAll(destDir) + + cwd, _ := os.Getwd() + + subDirName := "target-subfolder" + zipFile, _ := os.CreateTemp("", "test-*.zip") + writer := zip.NewWriter(zipFile) + + header := &zip.FileHeader{ + Name: subDirName + "/", + Method: zip.Store, + } + header.SetMode(0755) + _, _ = writer.CreateHeader(header) + + f, _ := writer.Create(filepath.Join(subDirName, "file.txt")) + f.Write([]byte("content")) + + writer.Close() + zipFile.Close() + defer os.Remove(zipFile.Name()) + + extractionErr := ExtractZip(destDir, zipFile.Name()) + wrongPath := filepath.Join(cwd, subDirName) + if _, err := os.Stat(wrongPath); err == nil { + t.Errorf("BUG FOUND: Folder was created in CWD: %s", wrongPath) + os.RemoveAll(wrongPath) + } + + rightPath := filepath.Join(destDir, subDirName) + if _, err := os.Stat(rightPath); os.IsNotExist(err) { + t.Errorf("FAILURE: Folder was NOT created in destination: %s", rightPath) + } + if extractionErr != nil { + t.Fatalf("Extraction failed: %v", extractionErr) + } + +}