diff --git a/internal/collector/http_test.go b/internal/collector/http_test.go index 663ec93..9a8f88e 100644 --- a/internal/collector/http_test.go +++ b/internal/collector/http_test.go @@ -357,7 +357,89 @@ func TestHTTPCollector_MultipleEndpoints(t *testing.T) { } // --------------------------------------------------------------------------- -// 9. splitLabels — quoted values containing commas +// 9. HTTPCollector.Name +// --------------------------------------------------------------------------- + +func TestHTTPCollector_Name(t *testing.T) { + c := NewHTTPCollector(nil, 5*time.Second) + if c.Name() != "http" { + t.Errorf("expected Name() == \"http\", got %q", c.Name()) + } +} + +// --------------------------------------------------------------------------- +// 10. parseMetrics — all type-switch branches +// --------------------------------------------------------------------------- + +// TestParseMetrics_AllTypes exercises the int, int64, int32 and float32 +// branches of parseMetrics, which are not reached through the HTTP path +// because JSON always deserialises numbers as float64. +func TestParseMetrics_AllTypes(t *testing.T) { + col := NewHTTPCollector(nil, 5*time.Second) + + rawMetrics := map[string]interface{}{ + "int_val": int(7), + "int64_val": int64(8), + "int32_val": int32(9), + "float32_val": float32(3.14), + "float64_val": float64(2.71), + "string_val": "skip_me", // should be ignored + } + + metrics := col.parseMetrics("test_ep", rawMetrics) + + // 5 numeric keys → 5 metrics (the string should be skipped). + if len(metrics) != 5 { + t.Fatalf("expected 5 metrics, got %d", len(metrics)) + } + + find := func(name string) *Metric { + for i := range metrics { + if metrics[i].Name == "app_"+name { + return &metrics[i] + } + } + return nil + } + + cases := []struct { + key string + want float64 + }{ + {"int_val", 7}, + {"int64_val", 8}, + {"int32_val", 9}, + } + for _, tc := range cases { + m := find(tc.key) + if m == nil { + t.Errorf("metric app_%s not found", tc.key) + continue + } + if m.Value != tc.want { + t.Errorf("app_%s: expected %v, got %v", tc.key, tc.want, m.Value) + } + if m.Labels["endpoint"] != "test_ep" { + t.Errorf("app_%s: expected endpoint label 'test_ep', got %q", tc.key, m.Labels["endpoint"]) + } + } + + // float32 loses some precision when converted; just check it's close. + f32m := find("float32_val") + if f32m == nil { + t.Error("metric app_float32_val not found") + } else if f32m.Value < 3.0 || f32m.Value > 4.0 { + t.Errorf("app_float32_val: unexpected value %v", f32m.Value) + } + + // string key must NOT appear. + if find("string_val") != nil { + t.Error("app_string_val should have been skipped") + } +} + +// --------------------------------------------------------------------------- +// 11. splitLabels — quoted values containing commas // --------------------------------------------------------------------------- func TestSplitLabels(t *testing.T) { @@ -408,3 +490,35 @@ func TestSplitLabels(t *testing.T) { }) } } + +func TestParsePrometheusLine_EdgeCases(t *testing.T) { + c := &HTTPCollector{} + + t.Run("no value field", func(t *testing.T) { + result := c.parsePrometheusLine("test", "metric_name_only") + if result != nil { + t.Error("expected nil for line with no value") + } + }) + + t.Run("unparseable value", func(t *testing.T) { + result := c.parsePrometheusLine("test", "metric_name notanumber") + if result != nil { + t.Error("expected nil for non-numeric value") + } + }) + + t.Run("malformed labels no closing brace", func(t *testing.T) { + result := c.parsePrometheusLine("test", "metric{label=\"value\" 123") + if result != nil { + t.Error("expected nil for malformed labels") + } + }) + + t.Run("labels with no value after brace", func(t *testing.T) { + result := c.parsePrometheusLine("test", "metric{label=\"value\"}") + if result != nil { + t.Error("expected nil for labels with no value") + } + }) +} diff --git a/internal/config/config.go b/internal/config/config.go index b30bbd1..99ea03b 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -10,10 +10,10 @@ import ( // Config represents the application configuration type Config struct { - Server ServerConfig `json:"server"` - Collector CollectorConfig `json:"collector"` - Shipper ShipperConfig `json:"shipper"` - Endpoints []EndpointConfig `json:"endpoints"` + Server ServerConfig `json:"server"` + Collector CollectorConfig `json:"collector"` + Shipper ShipperConfig `json:"shipper"` + Endpoints []EndpointConfig `json:"endpoints"` } // ServerConfig contains HTTP server settings diff --git a/internal/orchestrator/orchestrator_test.go b/internal/orchestrator/orchestrator_test.go index 3495d28..9ef49e4 100644 --- a/internal/orchestrator/orchestrator_test.go +++ b/internal/orchestrator/orchestrator_test.go @@ -2,6 +2,7 @@ package orchestrator import ( "context" + "fmt" "sync" "testing" "time" @@ -177,3 +178,165 @@ func TestStop(t *testing.T) { t.Fatal("Start did not return after Stop()") } } + +// retryShipper is a shipper that fails for the first N calls then succeeds. +type retryShipper struct { + mu sync.Mutex + callCount int + failUntil int // fail for first N calls + shipped [][]collector.Metric +} + +func (r *retryShipper) Ship(_ context.Context, metrics []collector.Metric) error { + r.mu.Lock() + defer r.mu.Unlock() + r.callCount++ + if r.callCount <= r.failUntil { + return fmt.Errorf("ship error on call %d", r.callCount) + } + r.shipped = append(r.shipped, metrics) + return nil +} + +func (r *retryShipper) Close() error { return nil } + +func (r *retryShipper) calls() int { + r.mu.Lock() + defer r.mu.Unlock() + return r.callCount +} + +func (r *retryShipper) successCount() int { + r.mu.Lock() + defer r.mu.Unlock() + return len(r.shipped) +} + +// TestCollectAndShip_ShipRetry verifies the path where the first Ship call +// fails but the retry succeeds — metrics must still be delivered. +func TestCollectAndShip_ShipRetry(t *testing.T) { + userMetrics := []collector.Metric{ + {Name: "cpu", Value: 1.0, Type: "gauge", Labels: map[string]string{}}, + } + reg := collector.NewRegistry() + reg.Register(&mockCollector{name: "test", metrics: userMetrics}) + + shpr := &retryShipper{failUntil: 1} // fail first call, succeed on retry + + o := NewOrchestrator(reg, shpr, 10*time.Minute) + o.collectAndShip(context.Background()) + + // Ship should have been called exactly twice (original + retry). + if shpr.calls() != 2 { + t.Errorf("expected 2 Ship calls (1 fail + 1 retry), got %d", shpr.calls()) + } + // The retry succeeded, so metrics should have been delivered once. + if shpr.successCount() != 1 { + t.Errorf("expected 1 successful delivery, got %d", shpr.successCount()) + } + // lastShipDuration should be set after a successful retry. + if o.lastShipDuration == 0 { + t.Error("expected lastShipDuration to be non-zero after successful retry") + } +} + +// TestCollectAndShip_ShipRetryFails verifies the path where both the original +// Ship call and the retry fail — no panic, lastShipDuration still set. +func TestCollectAndShip_ShipRetryFails(t *testing.T) { + userMetrics := []collector.Metric{ + {Name: "cpu", Value: 2.0, Type: "gauge", Labels: map[string]string{}}, + } + reg := collector.NewRegistry() + reg.Register(&mockCollector{name: "test", metrics: userMetrics}) + + shpr := &retryShipper{failUntil: 999} // always fail + + o := NewOrchestrator(reg, shpr, 10*time.Minute) + o.collectAndShip(context.Background()) + + // Ship should have been called exactly twice (original + retry). + if shpr.calls() != 2 { + t.Errorf("expected 2 Ship calls (original + retry), got %d", shpr.calls()) + } + // No metrics should have been successfully delivered. + if shpr.successCount() != 0 { + t.Errorf("expected 0 successful deliveries, got %d", shpr.successCount()) + } + // lastShipDuration is still updated even on full failure. + if o.lastShipDuration == 0 { + t.Error("expected lastShipDuration to be non-zero even after failed retry") + } +} + +// TestCollectAndShip_DeadlineWarning verifies that collectAndShip completes +// without panic when the collection duration exceeds 80 % of the interval. +// We can't assert on the log output but we can ensure the cycle still ships. +func TestCollectAndShip_DeadlineWarning(t *testing.T) { + // Use a very short interval so even a trivial collection duration exceeds 80 %. + interval := 1 * time.Nanosecond + + reg := collector.NewRegistry() + reg.Register(&mockCollector{ + name: "slow", + metrics: []collector.Metric{{Name: "m", Value: 1, Type: "gauge", Labels: map[string]string{}}}, + }) + + shpr := &mockShipper{} + o := NewOrchestrator(reg, shpr, interval) + + // collectAndShip must not panic even when the deadline warning fires. + o.collectAndShip(context.Background()) + + if shpr.calls() < 1 { + t.Error("expected at least one Ship call despite deadline warning") + } +} + +// TestCollectAndShip_LastShipDurationIncluded verifies that on the second call +// to collectAndShip the metricsd_ship_duration_seconds internal metric is +// present in the shipped batch (because lastShipDuration was set by the first). +func TestCollectAndShip_LastShipDurationIncluded(t *testing.T) { + reg := collector.NewRegistry() + reg.Register(&mockCollector{ + name: "test", + metrics: []collector.Metric{{Name: "cpu", Value: 1.0, Type: "gauge", Labels: map[string]string{}}}, + }) + + shpr := &mockShipper{} + o := NewOrchestrator(reg, shpr, 10*time.Minute) + + // First cycle — sets lastShipDuration but does NOT include it in shipped metrics. + o.collectAndShip(context.Background()) + + if shpr.calls() != 1 { + t.Fatalf("expected 1 Ship call after first cycle, got %d", shpr.calls()) + } + firstBatch := shpr.firstBatch() + for _, m := range firstBatch { + if m.Name == "metricsd_ship_duration_seconds" { + t.Error("metricsd_ship_duration_seconds should NOT appear in first cycle batch") + } + } + + // Second cycle — lastShipDuration > 0, so ship metric must be included. + o.collectAndShip(context.Background()) + + if shpr.calls() != 2 { + t.Fatalf("expected 2 Ship calls after second cycle, got %d", shpr.calls()) + } + + shpr.mu.Lock() + secondBatch := shpr.shipped[1] + shpr.mu.Unlock() + + found := false + for _, m := range secondBatch { + if m.Name == "metricsd_ship_duration_seconds" { + found = true + break + } + } + if !found { + t.Error("expected metricsd_ship_duration_seconds in second cycle batch") + } +} diff --git a/internal/plugin/config.go b/internal/plugin/config.go index 33f1d3b..544353e 100644 --- a/internal/plugin/config.go +++ b/internal/plugin/config.go @@ -11,10 +11,10 @@ type PluginConfig struct { Name string `json:"name"` Path string `json:"-"` // Set by discovery, not from JSON Args []string `json:"args,omitempty"` - Timeout int `json:"timeout,omitempty"` // Seconds + Timeout int `json:"timeout,omitempty"` // Seconds Env []string `json:"env,omitempty"` WorkingDir string `json:"working_dir,omitempty"` - Enabled *bool `json:"enabled,omitempty"` // Pointer to distinguish unset from false + Enabled *bool `json:"enabled,omitempty"` // Pointer to distinguish unset from false Interval int `json:"interval_seconds,omitempty"` } @@ -45,14 +45,14 @@ type PluginMetric struct { // PluginHealth tracks the runtime health state of a single plugin. // Owned by the Manager, not the plugin itself. type PluginHealth struct { - Name string - Status string // "ok", "failing", "circuit_open" - ConsecutiveFails int - LastError string - LastSuccess time.Time - LastCollect time.Time - LastMetricCount int - CircuitOpenUntil time.Time // Zero means circuit closed + Name string + Status string // "ok", "failing", "circuit_open" + ConsecutiveFails int + LastError string + LastSuccess time.Time + LastCollect time.Time + LastMetricCount int + CircuitOpenUntil time.Time // Zero means circuit closed } // DefaultTimeout is the fallback plugin timeout. diff --git a/internal/plugin/config_test.go b/internal/plugin/config_test.go new file mode 100644 index 0000000..24273b8 --- /dev/null +++ b/internal/plugin/config_test.go @@ -0,0 +1,60 @@ +package plugin + +import ( + "testing" + "time" +) + +func TestPluginConfig_GetTimeout(t *testing.T) { + fallback := 30 * time.Second + + t.Run("returns timeout when set", func(t *testing.T) { + cfg := PluginConfig{Timeout: 10} + got := cfg.GetTimeout(fallback) + want := 10 * time.Second + if got != want { + t.Errorf("GetTimeout: got %v, want %v", got, want) + } + }) + + t.Run("returns fallback when timeout is zero", func(t *testing.T) { + cfg := PluginConfig{Timeout: 0} + got := cfg.GetTimeout(fallback) + if got != fallback { + t.Errorf("GetTimeout: got %v, want fallback %v", got, fallback) + } + }) + + t.Run("returns fallback when timeout is negative", func(t *testing.T) { + cfg := PluginConfig{Timeout: -5} + got := cfg.GetTimeout(fallback) + if got != fallback { + t.Errorf("GetTimeout: got %v, want fallback %v", got, fallback) + } + }) +} + +func TestPluginConfig_IsEnabled(t *testing.T) { + t.Run("nil pointer defaults to true", func(t *testing.T) { + cfg := PluginConfig{Enabled: nil} + if !cfg.IsEnabled() { + t.Error("IsEnabled: expected true when Enabled is nil") + } + }) + + t.Run("explicit true", func(t *testing.T) { + v := true + cfg := PluginConfig{Enabled: &v} + if !cfg.IsEnabled() { + t.Error("IsEnabled: expected true when Enabled is &true") + } + }) + + t.Run("explicit false", func(t *testing.T) { + v := false + cfg := PluginConfig{Enabled: &v} + if cfg.IsEnabled() { + t.Error("IsEnabled: expected false when Enabled is &false") + } + }) +} diff --git a/internal/plugin/exec_plugin.go b/internal/plugin/exec_plugin.go index a77d634..1135059 100644 --- a/internal/plugin/exec_plugin.go +++ b/internal/plugin/exec_plugin.go @@ -18,7 +18,7 @@ import ( ) const defaultMaxOutputBytes = 5 * 1024 * 1024 // 5MB -const maxStderrCapture = 4096 // 4KB +const maxStderrCapture = 4096 // 4KB // ExecPlugin executes a shell script and parses its JSON output. type ExecPlugin struct { diff --git a/internal/plugin/exec_plugin_test.go b/internal/plugin/exec_plugin_test.go index 6a8dc26..1057397 100644 --- a/internal/plugin/exec_plugin_test.go +++ b/internal/plugin/exec_plugin_test.go @@ -151,3 +151,94 @@ func TestExecPlugin_Name(t *testing.T) { // Verify ExecPlugin satisfies collector.Collector interface var _ collector.Collector = (*ExecPlugin)(nil) + +// TestExecPlugin_LastStderr verifies that stderr output is captured and +// returned by LastStderr even when the plugin succeeds. +func TestExecPlugin_LastStderr(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "exec_plugin_stderr_test") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + // Plugin that writes to stderr AND stdout so it succeeds. + path := writeTestPlugin(t, tmpDir, "stderr_plugin", + "#!/bin/bash\necho 'diagnostic info' >&2\necho '[{\"name\":\"ok\",\"value\":1}]'\n") + + ep := NewExecPlugin(PluginConfig{Name: "stderr_test", Path: path, Timeout: 5}) + + metrics, err := ep.Collect(context.Background()) + if err != nil { + t.Fatalf("Collect failed: %v", err) + } + if len(metrics) != 1 { + t.Fatalf("expected 1 metric, got %d", len(metrics)) + } + + stderr := ep.LastStderr() + if stderr == "" { + t.Error("expected non-empty LastStderr()") + } + if !containsSubstring(stderr, "diagnostic info") { + t.Errorf("expected 'diagnostic info' in stderr, got %q", stderr) + } +} + +// TestExecPlugin_LastStderr_OnFailure verifies that stderr is also captured +// when the plugin exits with a non-zero status. +func TestExecPlugin_LastStderr_OnFailure(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "exec_plugin_stderr_fail_test") + if err != nil { + t.Fatalf("failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + path := writeTestPlugin(t, tmpDir, "fail_with_stderr", + "#!/bin/bash\necho 'error message' >&2\nexit 1\n") + + ep := NewExecPlugin(PluginConfig{Name: "fail_stderr", Path: path, Timeout: 5}) + + _, err = ep.Collect(context.Background()) + if err == nil { + t.Fatal("expected error from failing plugin") + } + + stderr := ep.LastStderr() + if !containsSubstring(stderr, "error message") { + t.Errorf("expected 'error message' in LastStderr(), got %q", stderr) + } +} + +func containsSubstring(s, sub string) bool { + return len(s) >= len(sub) && (s == sub || len(sub) == 0 || + func() bool { + for i := 0; i <= len(s)-len(sub); i++ { + if s[i:i+len(sub)] == sub { + return true + } + } + return false + }()) +} + +func TestTruncate(t *testing.T) { + tests := []struct { + name string + input string + maxLen int + expected string + }{ + {"shorter than max", "hello", 10, "hello"}, + {"equal to max", "hello", 5, "hello"}, + {"longer than max", "hello world", 5, "hello..."}, + {"empty string", "", 5, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := truncate(tt.input, tt.maxLen) + if result != tt.expected { + t.Errorf("truncate(%q, %d) = %q, want %q", tt.input, tt.maxLen, result, tt.expected) + } + }) + } +} diff --git a/internal/plugin/manager_test.go b/internal/plugin/manager_test.go index 504b281..141c69f 100644 --- a/internal/plugin/manager_test.go +++ b/internal/plugin/manager_test.go @@ -126,3 +126,35 @@ func TestManager_GetHealth(t *testing.T) { } var _ collector.Collector = (*Manager)(nil) + +// TestManager_AddExecPlugin verifies that AddExecPlugin registers the plugin +// and increments the count. +func TestManager_AddExecPlugin(t *testing.T) { + m := NewManager() + ep := NewExecPlugin(PluginConfig{Name: "my_exec_plugin", Path: "/bin/true", Timeout: 5}) + m.AddExecPlugin(ep) + + if m.PluginCount() != 1 { + t.Errorf("expected PluginCount 1 after AddExecPlugin, got %d", m.PluginCount()) + } + + health := m.GetHealth() + if _, ok := health["my_exec_plugin"]; !ok { + t.Error("expected health entry for 'my_exec_plugin'") + } +} + +// TestManager_PluginCount verifies that PluginCount reflects the total number +// of registered plugins (mix of Go and exec plugins). +func TestManager_PluginCount(t *testing.T) { + m := NewManager() + + m.AddGoPlugin("p1", &mockCollector{name: "p1"}) + m.AddGoPlugin("p2", &mockCollector{name: "p2"}) + ep := NewExecPlugin(PluginConfig{Name: "p3", Path: "/bin/true", Timeout: 5}) + m.AddExecPlugin(ep) + + if m.PluginCount() != 3 { + t.Errorf("expected PluginCount 3, got %d", m.PluginCount()) + } +} diff --git a/internal/plugin/security_test.go b/internal/plugin/security_test.go index 09ce4c7..8d6de19 100644 --- a/internal/plugin/security_test.go +++ b/internal/plugin/security_test.go @@ -154,3 +154,35 @@ func TestValidateMetricOutput(t *testing.T) { } }) } + +func TestValidatePluginPath_BadPluginsDir(t *testing.T) { + // Plugins dir that doesn't exist — EvalSymlinks fails + _, err := ValidatePluginPath("/tmp/nonexistent/plugin", "/tmp/nonexistent") + if err == nil { + t.Error("expected error for non-existent plugins dir") + } +} + +func TestValidateMetricOutput_MultipleLabels(t *testing.T) { + // Test with valid metric having multiple labels + metrics := []PluginMetric{ + {Name: "test_metric", Value: 1, Labels: map[string]string{"a": "1", "b": "2", "c": "3"}}, + } + result := ValidateMetricOutput(metrics, "test") + if len(result) != 1 { + t.Errorf("expected 1 metric, got %d", len(result)) + } + if len(result[0].Labels) != 3 { + t.Errorf("expected 3 labels, got %d", len(result[0].Labels)) + } +} + +func TestValidateMetricOutput_NilLabels(t *testing.T) { + metrics := []PluginMetric{ + {Name: "test_metric", Value: 1, Labels: nil}, + } + result := ValidateMetricOutput(metrics, "test") + if len(result) != 1 { + t.Errorf("expected 1 metric, got %d", len(result)) + } +} diff --git a/internal/server/server_test.go b/internal/server/server_test.go index 8a2e8c9..b4a1ea1 100644 --- a/internal/server/server_test.go +++ b/internal/server/server_test.go @@ -213,3 +213,12 @@ func TestNewServer_NilProvider(t *testing.T) { t.Error("expected nil collectors when provider is nil") } } + +func TestServer_ShutdownWithoutStart(t *testing.T) { + srv := NewServer("localhost", 0, nil) + // Shutdown without Start — server field is nil + err := srv.Shutdown(context.Background()) + if err != nil { + t.Errorf("expected nil error, got %v", err) + } +} diff --git a/internal/shipper/file_test.go b/internal/shipper/file_test.go index f489f64..28a2a88 100644 --- a/internal/shipper/file_test.go +++ b/internal/shipper/file_test.go @@ -440,3 +440,30 @@ func TestFileShipper_DefaultFormat(t *testing.T) { t.Errorf("Expected 2 lines for single format, got %d", lineCount) } } + +func TestNewFileShipper_Defaults(t *testing.T) { + tmpDir, err := os.MkdirTemp("", "metricsd-test") + if err != nil { + t.Fatalf("Failed to create temp dir: %v", err) + } + defer os.RemoveAll(tmpDir) + + filePath := filepath.Join(tmpDir, "metrics.json") + + // Pass 0/0/"" to trigger all default branches + shipper, err := NewFileShipper(filePath, 0, 0, "") + if err != nil { + t.Fatalf("Failed to create shipper: %v", err) + } + defer shipper.Close() + + if shipper.maxSizeBytes != 100*1024*1024 { + t.Errorf("expected default maxSizeBytes 100MB, got %d", shipper.maxSizeBytes) + } + if shipper.maxFiles != 5 { + t.Errorf("expected default maxFiles 5, got %d", shipper.maxFiles) + } + if shipper.format != "single" { + t.Errorf("expected default format 'single', got %s", shipper.format) + } +} diff --git a/internal/shipper/http_json.go b/internal/shipper/http_json.go index 5e03c74..04ba4b8 100644 --- a/internal/shipper/http_json.go +++ b/internal/shipper/http_json.go @@ -66,8 +66,8 @@ func NewHTTPJSONShipper(endpoint string, tlsEnabled bool, certFile, keyFile, caF // MetricPayload represents the JSON structure for shipping metrics type MetricPayload struct { - Timestamp int64 `json:"timestamp"` - Metrics []MetricData `json:"metrics"` + Timestamp int64 `json:"timestamp"` + Metrics []MetricData `json:"metrics"` } // MetricData represents a single metric in JSON format diff --git a/internal/shipper/http_json_test.go b/internal/shipper/http_json_test.go index 80fcc75..a9abbe4 100644 --- a/internal/shipper/http_json_test.go +++ b/internal/shipper/http_json_test.go @@ -186,3 +186,34 @@ func TestHTTPJSONShipper_Close(t *testing.T) { t.Errorf("Close() returned unexpected error: %v", err) } } + +// TestNewHTTPJSONShipper_NoTLS verifies that the constructor succeeds with TLS disabled. +func TestNewHTTPJSONShipper_NoTLS(t *testing.T) { + s, err := NewHTTPJSONShipper("http://localhost:9999", false, "", "", "", false, 5*time.Second) + if err != nil { + t.Fatal(err) + } + if s == nil { + t.Fatal("expected non-nil shipper") + } + s.Close() +} + +// TestNewHTTPJSONShipper_WithTLS verifies that the constructor succeeds with a valid self-signed cert. +func TestNewHTTPJSONShipper_WithTLS(t *testing.T) { + certFile, keyFile, cleanup := generateTestCert(t) + defer cleanup() + s, err := NewHTTPJSONShipper("https://localhost:9999", true, certFile, keyFile, "", false, 5*time.Second) + if err != nil { + t.Fatal(err) + } + s.Close() +} + +// TestNewHTTPJSONShipper_BadCert verifies that a missing cert/key pair causes an error. +func TestNewHTTPJSONShipper_BadCert(t *testing.T) { + _, err := NewHTTPJSONShipper("https://localhost:9999", true, "/nonexistent", "/nonexistent", "", false, 5*time.Second) + if err == nil { + t.Error("expected error for bad cert") + } +} diff --git a/internal/shipper/prometheus_test.go b/internal/shipper/prometheus_test.go index 0008697..98b66d7 100644 --- a/internal/shipper/prometheus_test.go +++ b/internal/shipper/prometheus_test.go @@ -272,3 +272,34 @@ func TestConvertToPrometheusMetrics(t *testing.T) { t.Error("expected non-empty result from ConvertToPrometheusMetrics") } } + +// TestNewPrometheusRemoteWriteShipper_NoTLS verifies that the constructor succeeds with TLS disabled. +func TestNewPrometheusRemoteWriteShipper_NoTLS(t *testing.T) { + s, err := NewPrometheusRemoteWriteShipper("http://localhost:9999", false, "", "", "", false, 5*time.Second) + if err != nil { + t.Fatal(err) + } + if s == nil { + t.Fatal("expected non-nil shipper") + } + s.Close() +} + +// TestNewPrometheusRemoteWriteShipper_WithTLS verifies that the constructor succeeds with a valid self-signed cert. +func TestNewPrometheusRemoteWriteShipper_WithTLS(t *testing.T) { + certFile, keyFile, cleanup := generateTestCert(t) + defer cleanup() + s, err := NewPrometheusRemoteWriteShipper("https://localhost:9999", true, certFile, keyFile, "", false, 5*time.Second) + if err != nil { + t.Fatal(err) + } + s.Close() +} + +// TestNewPrometheusRemoteWriteShipper_BadCert verifies that a missing cert/key pair causes an error. +func TestNewPrometheusRemoteWriteShipper_BadCert(t *testing.T) { + _, err := NewPrometheusRemoteWriteShipper("https://localhost:9999", true, "/nonexistent", "/nonexistent", "", false, 5*time.Second) + if err == nil { + t.Error("expected error for bad cert") + } +} diff --git a/internal/shipper/splunk_hec.go b/internal/shipper/splunk_hec.go index e2ff835..688e9bc 100644 --- a/internal/shipper/splunk_hec.go +++ b/internal/shipper/splunk_hec.go @@ -212,4 +212,3 @@ func (s *SplunkHECShipper) logPayloadToFile(payload string) { log.Error().Err(err).Msg("Failed to write to debug log file") } } - diff --git a/internal/shipper/splunk_hec_test.go b/internal/shipper/splunk_hec_test.go index d80bd5f..94fed12 100644 --- a/internal/shipper/splunk_hec_test.go +++ b/internal/shipper/splunk_hec_test.go @@ -38,10 +38,10 @@ func newTestSplunkShipper(t *testing.T, serverURL string) *SplunkHECShipper { // correct path, carries the expected headers, and contains well-formed JSON events. func TestSplunkHECShipper_ShipSuccess(t *testing.T) { var ( - capturedPath string - capturedAuth string - capturedCT string - capturedBody []byte + capturedPath string + capturedAuth string + capturedCT string + capturedBody []byte ) srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { @@ -219,6 +219,47 @@ func TestSplunkHECShipper_Close(t *testing.T) { } } +// TestNewSplunkHECShipper_NoTLS verifies that the constructor succeeds with TLS disabled. +func TestNewSplunkHECShipper_NoTLS(t *testing.T) { + s, err := NewSplunkHECShipper("http://localhost:9999", "token", false, "", "", "", false, 5*time.Second, "") + if err != nil { + t.Fatal(err) + } + if s == nil { + t.Fatal("expected non-nil shipper") + } + s.Close() +} + +// TestNewSplunkHECShipper_WithTLS verifies that the constructor succeeds with a valid self-signed cert. +func TestNewSplunkHECShipper_WithTLS(t *testing.T) { + certFile, keyFile, cleanup := generateTestCert(t) + defer cleanup() + s, err := NewSplunkHECShipper("https://localhost:9999", "token", true, certFile, keyFile, "", false, 5*time.Second, "") + if err != nil { + t.Fatal(err) + } + s.Close() +} + +// TestNewSplunkHECShipper_BadCert verifies that a missing cert/key pair causes an error. +func TestNewSplunkHECShipper_BadCert(t *testing.T) { + _, err := NewSplunkHECShipper("https://localhost:9999", "token", true, "/nonexistent", "/nonexistent", "", false, 5*time.Second, "") + if err == nil { + t.Error("expected error for bad cert") + } +} + +// TestNewSplunkHECShipper_BadCAFile verifies that a bad CA file path causes an error when TLS is enabled. +func TestNewSplunkHECShipper_BadCAFile(t *testing.T) { + certFile, keyFile, cleanup := generateTestCert(t) + defer cleanup() + _, err := NewSplunkHECShipper("https://localhost:9999", "token", true, certFile, keyFile, "/nonexistent/ca.pem", false, 5*time.Second, "") + if err == nil { + t.Error("expected error for bad CA file path") + } +} + // TestSplunkHECShipper_LogPayloadToFile verifies that, when debugLogFile is set, // Ship writes a payload entry that contains the timestamp header and the // metric_name of each shipped metric. diff --git a/internal/shipper/testhelper_test.go b/internal/shipper/testhelper_test.go new file mode 100644 index 0000000..8d315a1 --- /dev/null +++ b/internal/shipper/testhelper_test.go @@ -0,0 +1,53 @@ +package shipper + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "math/big" + "os" + "path/filepath" + "testing" + "time" +) + +func generateTestCert(t *testing.T) (certFile, keyFile string, cleanup func()) { + t.Helper() + tmpDir, err := os.MkdirTemp("", "shipper-tls-test") + if err != nil { + t.Fatal(err) + } + + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + + template := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "test"}, + NotBefore: time.Now(), + NotAfter: time.Now().Add(time.Hour), + } + + certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key) + if err != nil { + t.Fatal(err) + } + + certPath := filepath.Join(tmpDir, "cert.pem") + certOut, _ := os.Create(certPath) + pem.Encode(certOut, &pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + certOut.Close() + + keyDER, _ := x509.MarshalECPrivateKey(key) + keyPath := filepath.Join(tmpDir, "key.pem") + keyOut, _ := os.Create(keyPath) + pem.Encode(keyOut, &pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) + keyOut.Close() + + return certPath, keyPath, func() { os.RemoveAll(tmpDir) } +}