diff --git a/pkg/dotc1z/file.go b/pkg/dotc1z/file.go index b4f9d72a0..4b604a3b2 100644 --- a/pkg/dotc1z/file.go +++ b/pkg/dotc1z/file.go @@ -7,7 +7,6 @@ import ( "os" "path/filepath" "runtime" - "syscall" "github.com/klauspost/compress/zstd" "go.uber.org/zap" @@ -70,24 +69,32 @@ func saveC1z(dbFilePath string, outputFilePath string, encoderConcurrency int) e } defer func() { if dbFile != nil { - err = dbFile.Close() - if err != nil { - zap.L().Error("failed to close db file", zap.Error(err)) + if closeErr := dbFile.Close(); closeErr != nil { + zap.L().Error("failed to close db file", zap.Error(closeErr)) } } }() - outFile, err := os.OpenFile(outputFilePath, os.O_RDWR|os.O_CREATE|syscall.O_TRUNC, 0644) + // Write to temp file first, then atomic rename on success. + // This ensures outputFilePath never contains partial/corrupt data. + // Use CreateTemp for unique filename to prevent concurrent writer races. + outputDir := filepath.Dir(outputFilePath) + outputBase := filepath.Base(outputFilePath) + outFile, err := os.CreateTemp(outputDir, outputBase+".tmp-*") if err != nil { return err } + tmpPath := outFile.Name() + + // Clean up temp file on any failure defer func() { if outFile != nil { - err = outFile.Close() - if err != nil { - zap.L().Error("failed to close out file", zap.Error(err)) + if closeErr := outFile.Close(); closeErr != nil { + zap.L().Error("failed to close temp file", zap.Error(closeErr)) } } + // Remove temp file if it exists (no-op if rename succeeded) + _ = os.Remove(tmpPath) }() // Write the magic file header @@ -125,20 +132,26 @@ func saveC1z(dbFilePath string, outputFilePath string, encoderConcurrency int) e err = outFile.Sync() if err != nil { - return fmt.Errorf("failed to sync out file: %w", err) + return fmt.Errorf("failed to sync temp file: %w", err) } err = outFile.Close() if err != nil { - return fmt.Errorf("failed to close out file: %w", err) + return fmt.Errorf("failed to close temp file: %w", err) } - outFile = nil + outFile = nil // Prevent double-close in defer err = dbFile.Close() if err != nil { return fmt.Errorf("failed to close db file: %w", err) } - dbFile = nil + dbFile = nil // Prevent double-close in defer + + // Atomic rename: outputFilePath now has complete, valid data + // This is the only point where outputFilePath is modified + if err = os.Rename(tmpPath, outputFilePath); err != nil { + return fmt.Errorf("failed to rename temp file to output: %w", err) + } return nil } diff --git a/pkg/dotc1z/file_test.go b/pkg/dotc1z/file_test.go index 3e6756f4b..346748616 100644 --- a/pkg/dotc1z/file_test.go +++ b/pkg/dotc1z/file_test.go @@ -1,47 +1,44 @@ package dotc1z import ( + "fmt" "io" "os" "path/filepath" "testing" "github.com/stretchr/testify/require" - "google.golang.org/grpc/codes" - "google.golang.org/grpc/status" ) -func TestLoadC1z(t *testing.T) { - tmpDir := t.TempDir() - - t.Run("temp directory cleanup on error", func(t *testing.T) { - // Create a file that will cause an error during decoding - invalidFile := filepath.Join(tmpDir, "invalid2.c1z") - err := os.WriteFile(invalidFile, []byte("invalid"), 0600) - require.NoError(t, err) - defer os.Remove(invalidFile) - - // Try to load it - should fail and clean up temp dir - dbPath, err := loadC1z(invalidFile, tmpDir) - require.Error(t, err) - require.Empty(t, dbPath) - }) - - t.Run("custom tmpDir", func(t *testing.T) { - customTmpDir := filepath.Join(tmpDir, "custom") - err := os.MkdirAll(customTmpDir, 0755) - require.NoError(t, err) - defer os.RemoveAll(customTmpDir) - - nonExistentPath := filepath.Join(tmpDir, "nonexistent2.c1z") - dbPath, err := loadC1z(nonExistentPath, customTmpDir) - require.NoError(t, err) - require.NotEmpty(t, dbPath) - require.FileExists(t, dbPath) - - // Verify it was created in the custom tmpDir - require.Contains(t, dbPath, customTmpDir) - }) +func BenchmarkSaveC1z(b *testing.B) { + tmpDir := b.TempDir() + + // Create test data of various sizes + sizes := []int{1024, 100 * 1024, 1024 * 1024} // 1KB, 100KB, 1MB + + for _, size := range sizes { + b.Run(fmt.Sprintf("size_%d", size), func(b *testing.B) { + testData := make([]byte, size) + for i := range testData { + testData[i] = byte(i % 256) + } + dbFile := filepath.Join(tmpDir, fmt.Sprintf("bench_%d.db", size)) + err := os.WriteFile(dbFile, testData, 0600) + if err != nil { + b.Fatal(err) + } + + outputFile := filepath.Join(tmpDir, fmt.Sprintf("bench_%d.c1z", size)) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + err := saveC1z(dbFile, outputFile, 1) + if err != nil { + b.Fatal(err) + } + } + }) + } } func TestSaveC1z(t *testing.T) { @@ -79,101 +76,69 @@ func TestSaveC1z(t *testing.T) { require.NoError(t, err) require.Equal(t, testData, decodedData) }) +} - t.Run("save with empty output path returns error", func(t *testing.T) { - dbFile := filepath.Join(tmpDir, "test.db") - err := os.WriteFile(dbFile, []byte(""), 0600) - require.NoError(t, err) - defer os.Remove(dbFile) - - err = saveC1z(dbFile, "", 1) - require.Error(t, err) - require.True(t, status.Code(err) == codes.InvalidArgument) - require.Contains(t, err.Error(), "output file path not configured") - }) - - t.Run("save with non-existent db file returns error", func(t *testing.T) { - nonExistentDb := filepath.Join(tmpDir, "nonexistent.db") - outputFile := filepath.Join(tmpDir, "output.c1z") - - err := saveC1z(nonExistentDb, outputFile, 1) - require.Error(t, err) - }) - - t.Run("save overwrites existing file", func(t *testing.T) { - testData1 := []byte("first content") - dbFile1 := filepath.Join(tmpDir, "overwrite1.db") - err := os.WriteFile(dbFile1, testData1, 0600) - require.NoError(t, err) - defer os.Remove(dbFile1) - - outputFile := filepath.Join(tmpDir, "overwrite.c1z") - err = saveC1z(dbFile1, outputFile, 1) - require.NoError(t, err) - defer os.Remove(outputFile) - - // Get the size of the first file - stat1, err := os.Stat(outputFile) - require.NoError(t, err) - size1 := stat1.Size() +// TestSaveC1zAtomicWrite verifies that saveC1z uses atomic writes: +// 1. Output file is never partially written (either old data or new data, never corrupt). +// 2. Temp files are cleaned up on failure. +// 3. Existing output file is preserved if saveC1z fails. +func TestSaveC1zAtomicWrite(t *testing.T) { + tmpDir := t.TempDir() - // Save different content to the same file - testData2 := []byte("second content - different") - dbFile2 := filepath.Join(tmpDir, "overwrite2.db") - err = os.WriteFile(dbFile2, testData2, 0600) + t.Run("existing output preserved on failure", func(t *testing.T) { + // Create initial valid c1z + initialData := []byte("initial database content") + dbFile := filepath.Join(tmpDir, "initial.db") + err := os.WriteFile(dbFile, initialData, 0600) require.NoError(t, err) - defer os.Remove(dbFile2) - err = saveC1z(dbFile2, outputFile, 1) + outputFile := filepath.Join(tmpDir, "output.c1z") + err = saveC1z(dbFile, outputFile, 1) require.NoError(t, err) - // Verify the file was overwritten - stat2, err := os.Stat(outputFile) + // Read the valid output + originalContent, err := os.ReadFile(outputFile) require.NoError(t, err) - // Size might be different due to compression, but file should exist and be valid - require.NotEqual(t, size1, stat2.Size()) + require.NotEmpty(t, originalContent) - // Verify the content is the new content - f, err := os.Open(outputFile) - require.NoError(t, err) - defer f.Close() + // Now try to saveC1z with non-existent source - should fail + nonExistentDb := filepath.Join(tmpDir, "does_not_exist.db") + err = saveC1z(nonExistentDb, outputFile, 1) + require.Error(t, err) - decoder, err := NewDecoder(f) + // Output file should be UNCHANGED (still has original content) + afterContent, err := os.ReadFile(outputFile) require.NoError(t, err) - defer decoder.Close() + require.Equal(t, originalContent, afterContent, "output file should be unchanged after failed saveC1z") - decodedData, err := io.ReadAll(decoder) - require.NoError(t, err) - require.Equal(t, testData2, decodedData) + // Verify it's still valid + _, err = loadC1z(outputFile, tmpDir) + require.NoError(t, err, "output file should still be loadable after failed saveC1z") }) - t.Run("save empty db file", func(t *testing.T) { - emptyDbFile := filepath.Join(tmpDir, "empty.db") - err := os.WriteFile(emptyDbFile, []byte{}, 0600) - require.NoError(t, err) + t.Run("no temp file left on failure", func(t *testing.T) { + // Try to save with non-existent source + nonExistentDb := filepath.Join(tmpDir, "does_not_exist.db") + outputFile := filepath.Join(tmpDir, "output2.c1z") - outputFile := filepath.Join(tmpDir, "empty.c1z") - err = saveC1z(emptyDbFile, outputFile, 1) - require.NoError(t, err) - require.FileExists(t, outputFile) + err := saveC1z(nonExistentDb, outputFile, 1) + require.Error(t, err) - // Verify the file has the correct header - fileData, err := os.ReadFile(outputFile) + // Check no temp files left behind + matches, err := filepath.Glob(filepath.Join(tmpDir, "*.tmp-*")) require.NoError(t, err) - require.True(t, len(fileData) >= len(C1ZFileHeader)) - require.Equal(t, C1ZFileHeader, fileData[:len(C1ZFileHeader)]) + require.Empty(t, matches, "no temp files should be left after failed saveC1z") + }) - // Verify we can decode it (should be empty) - f, err := os.Open(outputFile) - require.NoError(t, err) - defer f.Close() + t.Run("no output file created on failure", func(t *testing.T) { + nonExistentDb := filepath.Join(tmpDir, "does_not_exist.db") + outputFile := filepath.Join(tmpDir, "should_not_exist.c1z") - decoder, err := NewDecoder(f) - require.NoError(t, err) - defer decoder.Close() + err := saveC1z(nonExistentDb, outputFile, 1) + require.Error(t, err) - decodedData, err := io.ReadAll(decoder) - require.NoError(t, err) - require.Empty(t, decodedData) + // Output file should not exist + _, statErr := os.Stat(outputFile) + require.True(t, os.IsNotExist(statErr), "output file should not exist after saveC1z error") }) } diff --git a/pkg/dotc1z/manager/local/local.go b/pkg/dotc1z/manager/local/local.go index 3f3190795..807e9e64f 100644 --- a/pkg/dotc1z/manager/local/local.go +++ b/pkg/dotc1z/manager/local/local.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "os" + "path/filepath" "github.com/grpc-ecosystem/go-grpc-middleware/logging/zap/ctxzap" "go.opentelemetry.io/otel" @@ -162,23 +163,45 @@ func (l *localManager) SaveC1Z(ctx context.Context) error { } defer tmpFile.Close() - dstFile, err := os.Create(l.filePath) + // Write to temp file first, then atomic rename on success. + // This ensures filePath never contains partial/corrupt data. + dstDir := filepath.Dir(l.filePath) + dstBase := filepath.Base(l.filePath) + dstFile, err := os.CreateTemp(dstDir, dstBase+".tmp-*") if err != nil { return err } - defer dstFile.Close() + dstTmpPath := dstFile.Name() + + // Clean up temp file on any failure + defer func() { + if dstFile != nil { + dstFile.Close() + } + _ = os.Remove(dstTmpPath) + }() size, err := io.Copy(dstFile, tmpFile) if err != nil { return err } - // CRITICAL: Sync to ensure data is written before function returns. + // CRITICAL: Sync to ensure data is written before rename. // This is especially important on ZFS ARC where writes may be cached. if err := dstFile.Sync(); err != nil { return fmt.Errorf("failed to sync destination file: %w", err) } + if err := dstFile.Close(); err != nil { + return fmt.Errorf("failed to close destination file: %w", err) + } + dstFile = nil // Prevent double-close in defer + + // Atomic rename: filePath now has complete, valid data. + if err := os.Rename(dstTmpPath, l.filePath); err != nil { + return fmt.Errorf("failed to rename temp file to destination: %w", err) + } + log.Debug( "successfully saved c1z locally", zap.String("file_path", l.filePath),