diff --git a/cmd/prism-loadtest/cmd/serve.go b/cmd/prism-loadtest/cmd/serve.go new file mode 100644 index 00000000..acbc42db --- /dev/null +++ b/cmd/prism-loadtest/cmd/serve.go @@ -0,0 +1,103 @@ +package cmd + +import ( + "context" + "fmt" + "os" + "os/signal" + "syscall" + "time" + + "github.com/spf13/cobra" + + "github.com/jrepp/prism-data-layer/cmd/prism-loadtest/server" +) + +var ( + servePort int +) + +var serveCmd = &cobra.Command{ + Use: "serve", + Short: "Run prism-loadtest as a server with HTTP API and WebSocket streaming", + Long: `Start prism-loadtest in server mode, exposing: + - HTTP API for starting/stopping load tests + - WebSocket endpoint for real-time metrics streaming + - Embedded dashboard for visualization + +The server allows remote control of load tests and provides real-time +metrics via WebSocket for D3-based visualizations. + +Example: + # Start server on default port (8091) + prism-loadtest serve + + # Start server on custom port + prism-loadtest serve --port 9000 + + # Access dashboard + open http://localhost:8091/dashboard + +API Endpoints: + POST /api/loadtest/start - Start a new load test + POST /api/loadtest/stop/:id - Stop a running test + GET /api/loadtest/status/:id - Get test status + GET /api/loadtest/list - List all tests + WS /ws/metrics/:id - Stream metrics for test + GET /dashboard - Embedded dashboard UI +`, + RunE: runServe, +} + +func init() { + rootCmd.AddCommand(serveCmd) + serveCmd.Flags().IntVar(&servePort, "port", 8091, "HTTP server port") +} + +func runServe(cmd *cobra.Command, args []string) error { + fmt.Printf("Starting prism-loadtest server on port %d...\n", servePort) + fmt.Printf("Dashboard: http://localhost:%d/dashboard\n", servePort) + fmt.Printf("API: http://localhost:%d/api/loadtest/...\n", servePort) + fmt.Printf("WebSocket: ws://localhost:%d/ws/metrics/:testId\n\n", servePort) + + // Create backend configuration + backendConfig := server.BackendConfig{ + RedisAddr: redisAddr, + RedisPassword: redisPassword, + RedisDB: redisDB, + NATSServers: natsServers, + } + + // Create and start server + srv := server.NewServer(servePort, backendConfig) + + // Start server in background + errChan := make(chan error, 1) + go func() { + if err := srv.Start(); err != nil { + errChan <- err + } + }() + + // Wait for interrupt signal + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + + select { + case err := <-errChan: + return fmt.Errorf("server failed to start: %w", err) + case sig := <-sigChan: + fmt.Printf("\nReceived signal %v, shutting down...\n", sig) + } + + // Graceful shutdown + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + if err := srv.Shutdown(ctx); err != nil { + return fmt.Errorf("server shutdown failed: %w", err) + } + + fmt.Println("Server stopped gracefully") + return nil +} diff --git a/cmd/prism-loadtest/server/dashboard.go b/cmd/prism-loadtest/server/dashboard.go new file mode 100644 index 00000000..0748902e --- /dev/null +++ b/cmd/prism-loadtest/server/dashboard.go @@ -0,0 +1,393 @@ +package server + +// dashboardHTML is the embedded dashboard +const dashboardHTML = ` + + + + + Prism Load Test Dashboard + + + + +
+

🚀 Prism Load Test Dashboard

+
Real-Time Performance Monitoring
+
+ +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ + +
+
+ +
+
+ +
+
+ +
+
+
+ Disconnected +
+
+
+ +
+
+
Total Requests
+
0
+
+
+
Throughput
+
0
+
req/s
+
+
+
Success Rate
+
100%
+
+
+
P50 Latency
+
0
+
ms
+
+
+ +
+
+
Throughput Over Time
+
+
+
+
Latency Percentiles
+
+
+
+ + + +` diff --git a/cmd/prism-loadtest/server/executor.go b/cmd/prism-loadtest/server/executor.go new file mode 100644 index 00000000..40c7d7f0 --- /dev/null +++ b/cmd/prism-loadtest/server/executor.go @@ -0,0 +1,414 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "math/rand" + "sync" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" + "golang.org/x/time/rate" + + "github.com/jrepp/prism-data-layer/patterns/multicast_registry" + "github.com/jrepp/prism-data-layer/patterns/multicast_registry/backends" +) + +// TestExecutor runs a load test and streams metrics to WebSocket clients +type TestExecutor struct { + testID string + config TestConfig + backendConfig BackendConfig + status string + startedAt time.Time + stoppedAt *time.Time + + // Metrics + registerMetrics *MetricsCollector + enumerateMetrics *MetricsCollector + multicastMetrics *MetricsCollector + + // Counters + registerCount atomic.Int64 + enumerateCount atomic.Int64 + multicastCount atomic.Int64 + identityCounter atomic.Int64 + + // WebSocket clients + clients []*websocket.Conn + clientsMu sync.RWMutex + + // Control + stopChan chan struct{} + stopped atomic.Bool + mu sync.RWMutex +} + +// MetricsMessage is sent to WebSocket clients +type MetricsMessage struct { + Timestamp time.Time `json:"timestamp"` + Throughput float64 `json:"throughput"` + LatencyP50 float64 `json:"latency_p50"` + LatencyP95 float64 `json:"latency_p95"` + LatencyP99 float64 `json:"latency_p99"` + SuccessRate float64 `json:"success_rate"` + TotalRequests int64 `json:"total_requests"` + FailedRequests int64 `json:"failed_requests"` +} + +// NewTestExecutor creates a new test executor +func NewTestExecutor(testID string, config TestConfig, backendConfig BackendConfig) *TestExecutor { + return &TestExecutor{ + testID: testID, + config: config, + backendConfig: backendConfig, + status: "running", + startedAt: time.Now(), + registerMetrics: NewMetricsCollector(), + enumerateMetrics: NewMetricsCollector(), + multicastMetrics: NewMetricsCollector(), + stopChan: make(chan struct{}), + clients: make([]*websocket.Conn, 0), + } +} + +// Run starts the load test +func (e *TestExecutor) Run() { + defer func() { + e.mu.Lock() + e.status = "completed" + now := time.Now() + e.stoppedAt = &now + e.mu.Unlock() + + // Close all WebSocket connections + e.closeAllClients() + }() + + // Parse duration + duration, err := time.ParseDuration(e.config.Duration) + if err != nil { + e.mu.Lock() + e.status = "failed" + e.mu.Unlock() + return + } + + // Setup coordinator + coordinator, err := e.setupCoordinator() + if err != nil { + fmt.Printf("Failed to setup coordinator: %v\n", err) + e.mu.Lock() + e.status = "failed" + e.mu.Unlock() + return + } + defer coordinator.Close() + + // Create context with timeout + ctx, cancel := context.WithTimeout(context.Background(), duration) + defer cancel() + + // Start metrics broadcaster + go e.broadcastMetrics(ctx) + + // Run load test + e.runLoadTest(ctx, coordinator) + + fmt.Printf("Test %s completed\n", e.testID) +} + +// runLoadTest executes the actual load test +func (e *TestExecutor) runLoadTest(ctx context.Context, coordinator *multicast_registry.Coordinator) { + limiter := rate.NewLimiter(rate.Limit(e.config.Rate), e.config.Rate) + rng := rand.New(rand.NewSource(time.Now().UnixNano())) + + var wg sync.WaitGroup + + for { + select { + case <-ctx.Done(): + wg.Wait() + return + case <-e.stopChan: + wg.Wait() + return + default: + } + + // Rate limit + if err := limiter.Wait(ctx); err != nil { + wg.Wait() + return + } + + // Select operation based on percentages + roll := rng.Intn(100) + var operation string + + if roll < e.config.RegisterPct { + operation = "register" + } else if roll < e.config.RegisterPct+e.config.EnumeratePct { + operation = "enumerate" + } else { + operation = "multicast" + } + + // Launch worker + wg.Add(1) + go func(op string) { + defer wg.Done() + + switch op { + case "register": + e.registerCount.Add(1) + e.executeRegister(ctx, coordinator) + case "enumerate": + e.enumerateCount.Add(1) + e.executeEnumerate(ctx, coordinator) + case "multicast": + e.multicastCount.Add(1) + e.executeMulticast(ctx, coordinator) + } + }(operation) + } +} + +// executeRegister performs a register operation +func (e *TestExecutor) executeRegister(ctx context.Context, coordinator *multicast_registry.Coordinator) { + idNum := e.identityCounter.Add(1) + identity := fmt.Sprintf("loadtest-user-%d", idNum) + + metadata := map[string]interface{}{ + "status": "online", + "loadtest": true, + "timestamp": time.Now().Unix(), + "worker_id": idNum % 100, + } + + start := time.Now() + err := coordinator.Register(ctx, identity, metadata, 300*time.Second) + latency := time.Since(start) + + if err != nil { + e.registerMetrics.RecordFailure() + } else { + e.registerMetrics.RecordSuccess(latency) + } +} + +// executeEnumerate performs an enumerate operation +func (e *TestExecutor) executeEnumerate(ctx context.Context, coordinator *multicast_registry.Coordinator) { + filter := multicast_registry.NewFilter(map[string]interface{}{ + "status": "online", + }) + + start := time.Now() + _, err := coordinator.Enumerate(ctx, filter) + latency := time.Since(start) + + if err != nil { + e.enumerateMetrics.RecordFailure() + } else { + e.enumerateMetrics.RecordSuccess(latency) + } +} + +// executeMulticast performs a multicast operation +func (e *TestExecutor) executeMulticast(ctx context.Context, coordinator *multicast_registry.Coordinator) { + filter := multicast_registry.NewFilter(map[string]interface{}{ + "status": "online", + }) + + payload := []byte(fmt.Sprintf(`{"type":"loadtest","timestamp":%d}`, time.Now().Unix())) + + start := time.Now() + _, err := coordinator.Multicast(ctx, filter, payload) + latency := time.Since(start) + + if err != nil { + e.multicastMetrics.RecordFailure() + } else { + e.multicastMetrics.RecordSuccess(latency) + } +} + +// setupCoordinator creates a coordinator with backend connections +func (e *TestExecutor) setupCoordinator() (*multicast_registry.Coordinator, error) { + // Create config + config := multicast_registry.DefaultConfig() + config.DefaultTTL = 300 * time.Second + config.MaxIdentities = 1000000 // Allow large number for load testing + + // Create Redis registry backend + registryBackend, err := backends.NewRedisRegistryBackend( + e.backendConfig.RedisAddr, + e.backendConfig.RedisPassword, + e.backendConfig.RedisDB, + "loadtest:", + ) + if err != nil { + return nil, fmt.Errorf("failed to create Redis backend: %w", err) + } + + // Create NATS messaging backend + messagingBackend, err := backends.NewNATSMessagingBackend(e.backendConfig.NATSServers) + if err != nil { + registryBackend.Close() + return nil, fmt.Errorf("failed to create NATS backend: %w", err) + } + + // Create coordinator + coordinator, err := multicast_registry.NewCoordinator(config, registryBackend, messagingBackend, nil) + if err != nil { + registryBackend.Close() + messagingBackend.Close() + return nil, fmt.Errorf("failed to create coordinator: %w", err) + } + + return coordinator, nil +} + +// broadcastMetrics periodically broadcasts metrics to WebSocket clients +func (e *TestExecutor) broadcastMetrics(ctx context.Context) { + ticker := time.NewTicker(1 * time.Second) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-e.stopChan: + return + case <-ticker.C: + metrics := e.collectMetrics() + e.broadcast(metrics) + } + } +} + +// collectMetrics aggregates current metrics +func (e *TestExecutor) collectMetrics() MetricsMessage { + // Combine metrics from all operations + combined := NewMetricsCollector() + + e.registerMetrics.Mu.Lock() + combined.TotalRequests += e.registerMetrics.TotalRequests + combined.SuccessfulReqs += e.registerMetrics.SuccessfulReqs + combined.FailedReqs += e.registerMetrics.FailedReqs + combined.TotalLatencyNs += e.registerMetrics.TotalLatencyNs + for bucket, count := range e.registerMetrics.LatencyBuckets { + combined.LatencyBuckets[bucket] += count + } + e.registerMetrics.Mu.Unlock() + + e.enumerateMetrics.Mu.Lock() + combined.TotalRequests += e.enumerateMetrics.TotalRequests + combined.SuccessfulReqs += e.enumerateMetrics.SuccessfulReqs + combined.FailedReqs += e.enumerateMetrics.FailedReqs + combined.TotalLatencyNs += e.enumerateMetrics.TotalLatencyNs + for bucket, count := range e.enumerateMetrics.LatencyBuckets { + combined.LatencyBuckets[bucket] += count + } + e.enumerateMetrics.Mu.Unlock() + + e.multicastMetrics.Mu.Lock() + combined.TotalRequests += e.multicastMetrics.TotalRequests + combined.SuccessfulReqs += e.multicastMetrics.SuccessfulReqs + combined.FailedReqs += e.multicastMetrics.FailedReqs + combined.TotalLatencyNs += e.multicastMetrics.TotalLatencyNs + for bucket, count := range e.multicastMetrics.LatencyBuckets { + combined.LatencyBuckets[bucket] += count + } + e.multicastMetrics.Mu.Unlock() + + // Calculate metrics + elapsed := time.Since(e.startedAt) + throughput := float64(combined.TotalRequests) / elapsed.Seconds() + successRate := float64(100) + if combined.TotalRequests > 0 { + successRate = float64(combined.SuccessfulReqs) / float64(combined.TotalRequests) * 100 + } + + p50, p95, p99 := combined.CalculatePercentiles() + + return MetricsMessage{ + Timestamp: time.Now(), + Throughput: throughput, + LatencyP50: float64(p50.Microseconds()) / 1000.0, // Convert to ms + LatencyP95: float64(p95.Microseconds()) / 1000.0, + LatencyP99: float64(p99.Microseconds()) / 1000.0, + SuccessRate: successRate, + TotalRequests: combined.TotalRequests, + FailedRequests: combined.FailedReqs, + } +} + +// broadcast sends metrics to all connected WebSocket clients +func (e *TestExecutor) broadcast(metrics MetricsMessage) { + e.clientsMu.Lock() + defer e.clientsMu.Unlock() + + if len(e.clients) == 0 { + return + } + + data, err := json.Marshal(metrics) + if err != nil { + return + } + + // Send to all clients, removing dead ones + activeClients := make([]*websocket.Conn, 0, len(e.clients)) + for _, client := range e.clients { + if err := client.WriteMessage(websocket.TextMessage, data); err != nil { + // Remove dead client + client.Close() + } else { + activeClients = append(activeClients, client) + } + } + e.clients = activeClients +} + +// AddClient adds a WebSocket client +func (e *TestExecutor) AddClient(conn *websocket.Conn) { + e.clientsMu.Lock() + defer e.clientsMu.Unlock() + e.clients = append(e.clients, conn) +} + +// closeAllClients closes all WebSocket connections +func (e *TestExecutor) closeAllClients() { + e.clientsMu.Lock() + defer e.clientsMu.Unlock() + + for _, client := range e.clients { + client.Close() + } + e.clients = nil +} + +// Stop stops the test +func (e *TestExecutor) Stop() { + if e.stopped.Swap(true) { + return // Already stopped + } + close(e.stopChan) +} + +// GetInfo returns test information +func (e *TestExecutor) GetInfo() TestInfo { + e.mu.RLock() + defer e.mu.RUnlock() + + return TestInfo{ + TestID: e.testID, + Status: e.status, + StartedAt: e.startedAt, + StoppedAt: e.stoppedAt, + Config: e.config, + } +} diff --git a/cmd/prism-loadtest/server/server.go b/cmd/prism-loadtest/server/server.go new file mode 100644 index 00000000..138126dc --- /dev/null +++ b/cmd/prism-loadtest/server/server.go @@ -0,0 +1,292 @@ +package server + +import ( + "context" + "encoding/json" + "fmt" + "net/http" + "sync" + "time" + + "github.com/gorilla/mux" + "github.com/gorilla/websocket" +) + +// Server manages HTTP API and WebSocket connections for load testing +type Server struct { + port int + backendConfig BackendConfig + httpServer *http.Server + router *mux.Router + tests map[string]*TestExecutor + testsMu sync.RWMutex + upgrader websocket.Upgrader +} + +// BackendConfig holds configuration for backend connections +type BackendConfig struct { + RedisAddr string + RedisPassword string + RedisDB int + NATSServers []string +} + +// TestConfig holds configuration for a load test +type TestConfig struct { + Mix string `json:"mix"` // "mixed", "register", etc. + Duration string `json:"duration"` // "60s", "5m" + Rate int `json:"rate"` // req/sec + RegisterPct int `json:"register_pct"` // 0-100 + EnumeratePct int `json:"enumerate_pct"` // 0-100 + MulticastPct int `json:"multicast_pct"` // 0-100 +} + +// TestInfo contains information about a test +type TestInfo struct { + TestID string `json:"test_id"` + Status string `json:"status"` + StartedAt time.Time `json:"started_at"` + StoppedAt *time.Time `json:"stopped_at,omitempty"` + Config TestConfig `json:"config"` +} + +// NewServer creates a new load test server +func NewServer(port int, backendConfig BackendConfig) *Server { + s := &Server{ + port: port, + backendConfig: backendConfig, + router: mux.NewRouter(), + tests: make(map[string]*TestExecutor), + upgrader: websocket.Upgrader{ + CheckOrigin: func(r *http.Request) bool { + return true // Allow all origins for now + }, + }, + } + + s.setupRoutes() + return s +} + +// setupRoutes configures HTTP routes +func (s *Server) setupRoutes() { + // Middleware (must be set before routes) + s.router.Use(corsMiddleware) + s.router.Use(loggingMiddleware) + + // API routes + api := s.router.PathPrefix("/api/loadtest").Subrouter() + api.HandleFunc("/start", s.handleStartTest).Methods("POST", "OPTIONS") + api.HandleFunc("/stop/{testId}", s.handleStopTest).Methods("POST", "OPTIONS") + api.HandleFunc("/status/{testId}", s.handleGetStatus).Methods("GET", "OPTIONS") + api.HandleFunc("/list", s.handleListTests).Methods("GET", "OPTIONS") + + // WebSocket route + s.router.HandleFunc("/ws/metrics/{testId}", s.handleWebSocket) + + // Dashboard route + s.router.HandleFunc("/dashboard", s.handleDashboard).Methods("GET") + s.router.HandleFunc("/", s.handleDashboard).Methods("GET") +} + +// handleStartTest starts a new load test +func (s *Server) handleStartTest(w http.ResponseWriter, r *http.Request) { + var config TestConfig + if err := json.NewDecoder(r.Body).Decode(&config); err != nil { + sendError(w, "invalid_request", "Failed to parse request body", http.StatusBadRequest) + return + } + + // Validate config + if config.Mix == "" { + config.Mix = "mixed" + } + if config.Duration == "" { + config.Duration = "60s" + } + if config.Rate == 0 { + config.Rate = 100 + } + + // Generate test ID with nanosecond precision to avoid collisions + testID := fmt.Sprintf("test-%d", time.Now().UnixNano()) + + // Create test executor + executor := NewTestExecutor(testID, config, s.backendConfig) + + // Store test + s.testsMu.Lock() + s.tests[testID] = executor + s.testsMu.Unlock() + + // Start test in background + go executor.Run() + + // Return test info + info := TestInfo{ + TestID: testID, + Status: "running", + StartedAt: time.Now(), + Config: config, + } + + sendJSON(w, info, http.StatusCreated) +} + +// handleStopTest stops a running test +func (s *Server) handleStopTest(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + testID := vars["testId"] + + s.testsMu.RLock() + executor, ok := s.tests[testID] + s.testsMu.RUnlock() + + if !ok { + sendError(w, "not_found", "Test not found", http.StatusNotFound) + return + } + + executor.Stop() + + info := executor.GetInfo() + sendJSON(w, info, http.StatusOK) +} + +// handleGetStatus gets the status of a test +func (s *Server) handleGetStatus(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + testID := vars["testId"] + + s.testsMu.RLock() + executor, ok := s.tests[testID] + s.testsMu.RUnlock() + + if !ok { + sendError(w, "not_found", "Test not found", http.StatusNotFound) + return + } + + info := executor.GetInfo() + sendJSON(w, info, http.StatusOK) +} + +// handleListTests lists all tests +func (s *Server) handleListTests(w http.ResponseWriter, r *http.Request) { + s.testsMu.RLock() + defer s.testsMu.RUnlock() + + tests := make([]TestInfo, 0) + for _, executor := range s.tests { + tests = append(tests, executor.GetInfo()) + } + + sendJSON(w, tests, http.StatusOK) +} + +// handleWebSocket handles WebSocket connections for metrics streaming +func (s *Server) handleWebSocket(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + testID := vars["testId"] + + s.testsMu.RLock() + executor, ok := s.tests[testID] + s.testsMu.RUnlock() + + if !ok { + http.Error(w, "Test not found", http.StatusNotFound) + return + } + + // Upgrade to WebSocket + conn, err := s.upgrader.Upgrade(w, r, nil) + if err != nil { + fmt.Printf("WebSocket upgrade failed: %v\n", err) + return + } + + // Subscribe to metrics + executor.AddClient(conn) + + fmt.Printf("WebSocket client connected for test %s\n", testID) +} + +// handleDashboard serves the embedded dashboard +func (s *Server) handleDashboard(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Content-Type", "text/html") + w.Write([]byte(dashboardHTML)) +} + +// Start starts the HTTP server +func (s *Server) Start() error { + addr := fmt.Sprintf(":%d", s.port) + + s.httpServer = &http.Server{ + Addr: addr, + Handler: s.router, + ReadTimeout: 15 * time.Second, + WriteTimeout: 15 * time.Second, + IdleTimeout: 60 * time.Second, + } + + fmt.Printf("Server listening on %s\n", addr) + return s.httpServer.ListenAndServe() +} + +// Shutdown gracefully shuts down the server +func (s *Server) Shutdown(ctx context.Context) error { + // Stop all running tests + s.testsMu.Lock() + for _, executor := range s.tests { + executor.Stop() + } + s.testsMu.Unlock() + + // Shutdown HTTP server if it was started + if s.httpServer != nil { + return s.httpServer.Shutdown(ctx) + } + return nil +} + +// sendJSON sends a JSON response +func sendJSON(w http.ResponseWriter, data interface{}, statusCode int) { + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(statusCode) + json.NewEncoder(w).Encode(data) +} + +// sendError sends an error response +func sendError(w http.ResponseWriter, errorType, message string, code int) { + resp := map[string]interface{}{ + "error": errorType, + "message": message, + "code": code, + } + sendJSON(w, resp, code) +} + +// corsMiddleware adds CORS headers +func corsMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, OPTIONS") + w.Header().Set("Access-Control-Allow-Headers", "Content-Type") + + if r.Method == "OPTIONS" { + w.WriteHeader(http.StatusOK) + return + } + + next.ServeHTTP(w, r) + }) +} + +// loggingMiddleware logs HTTP requests +func loggingMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + start := time.Now() + next.ServeHTTP(w, r) + fmt.Printf("%s %s %v\n", r.Method, r.URL.Path, time.Since(start)) + }) +}