From 22625954e940dcf714e57d65971a92ea6d6f6252 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Victor=20Anderss=C3=A9n?= Date: Fri, 13 Feb 2026 11:58:56 +0200 Subject: [PATCH 1/6] test: add comprehensive tests for voice daemon functionality and IPC commands --- cmd/client/voice_daemon_test.go | 377 ++++++++++++++++++++++++++++++++ cmd/voiced/daemon_test.go | 302 +++++++++++++++++++++++++ 2 files changed, 679 insertions(+) create mode 100644 cmd/voiced/daemon_test.go diff --git a/cmd/client/voice_daemon_test.go b/cmd/client/voice_daemon_test.go index ead8839..fbeabd2 100644 --- a/cmd/client/voice_daemon_test.go +++ b/cmd/client/voice_daemon_test.go @@ -1,10 +1,387 @@ package main import ( + "errors" + "fmt" + "os" + "os/exec" + "path/filepath" + "runtime" "strings" + "syscall" "testing" + "time" ) +func makeExecutableFile(t *testing.T, path string, content string) { + t.Helper() + if err := os.WriteFile(path, []byte(content), 0o700); err != nil { + t.Fatalf("write executable %s: %v", path, err) + } +} + +func waitForFile(t *testing.T, path string, timeout time.Duration) string { + t.Helper() + deadline := time.Now().Add(timeout) + for { + data, err := os.ReadFile(path) + if err == nil { + return string(data) + } + if !errors.Is(err, os.ErrNotExist) { + t.Fatalf("read %s: %v", path, err) + } + if time.Now().After(deadline) { + t.Fatalf("timed out waiting for %s", path) + } + time.Sleep(20 * time.Millisecond) + } +} + +func hasLine(content, want string) bool { + normalized := strings.ReplaceAll(content, "\r\n", "\n") + for _, line := range strings.Split(normalized, "\n") { + if line == want { + return true + } + } + return false +} + +func TestStartVoiceDaemonValidation(t *testing.T) { + tests := []struct { + name string + m *chatModel + want string + }{ + { + name: "autostart disabled", + m: &chatModel{}, + want: "voice auto-start disabled", + }, + { + name: "missing auth context", + m: &chatModel{ + voiceAutoStart: true, + }, + want: "missing auth", + }, + { + name: "missing ipc address", + m: &chatModel{ + voiceAutoStart: true, + api: &APIClient{serverURL: "http://server"}, + auth: newTestAuth(), + }, + want: "voice ipc address is required", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + err := tt.m.startVoiceDaemon() + if err == nil || !strings.Contains(err.Error(), tt.want) { + t.Fatalf("startVoiceDaemon() error = %v, want contains %q", err, tt.want) + } + }) + } +} + +func TestStartVoiceDaemonAlreadyRunningNoop(t *testing.T) { + m := &chatModel{ + voiceAutoStart: true, + voiceProc: &voiceAutoProcess{}, + api: &APIClient{serverURL: "http://server"}, + auth: newTestAuth(), + voiceIPCAddr: "/tmp/dialtone-voice.sock", + } + if err := m.startVoiceDaemon(); err != nil { + t.Fatalf("startVoiceDaemon should no-op when already running: %v", err) + } +} + +func TestStartVoiceDaemonLaunchesProcessAndStops(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("shell-script process test is unix-only") + } + + tmp := t.TempDir() + argsPath := filepath.Join(tmp, "args.txt") + envPath := filepath.Join(tmp, "env.txt") + logPath := filepath.Join(tmp, "voiced.log") + binPath := filepath.Join(tmp, "dialtone-voiced") + + script := fmt.Sprintf("#!/bin/sh\nprintf '%%s\\n' \"$@\" > %q\nenv > %q\nsleep 60\n", argsPath, envPath) + makeExecutableFile(t, binPath, script) + + t.Setenv("XDG_ACTIVATION_TOKEN", "launcher-token") + t.Setenv("DESKTOP_STARTUP_ID", "launcher-id") + t.Setenv("DIALTONE_KEEP_ME", "ok") + + m := &chatModel{ + voiceAutoStart: true, + api: &APIClient{serverURL: "http://dialtone.test"}, + auth: &AuthResponse{Token: "secret-token"}, + voiceIPCAddr: "/tmp/dialtone-test.sock", + voicedPath: binPath, + voiceArgs: []string{"--meter"}, + voiceLogPath: logPath, + } + + if err := m.startVoiceDaemon(); err != nil { + t.Fatalf("startVoiceDaemon: %v", err) + } + t.Cleanup(func() { + m.stopVoiceDaemon() + }) + + if m.voiceProc == nil { + t.Fatalf("expected voice process state after start") + } + if !m.voiceAutoStarting { + t.Fatalf("expected voiceAutoStarting=true after start") + } + + argsOut := waitForFile(t, argsPath, 2*time.Second) + for _, arg := range []string{"-server", "http://dialtone.test", "-token", "secret-token", "-ipc", "/tmp/dialtone-test.sock", "--meter"} { + if !hasLine(argsOut, arg) { + t.Fatalf("expected daemon arg %q in:\n%s", arg, argsOut) + } + } + + envOut := waitForFile(t, envPath, 2*time.Second) + if strings.Contains(envOut, "XDG_ACTIVATION_TOKEN=") { + t.Fatalf("expected launcher token removed from daemon environment") + } + if strings.Contains(envOut, "DESKTOP_STARTUP_ID=") { + t.Fatalf("expected desktop startup id removed from daemon environment") + } + if !strings.Contains(envOut, "DIALTONE_KEEP_ME=ok") { + t.Fatalf("expected unrelated env var to be preserved") + } + + m.stopVoiceDaemon() + if m.voiceProc != nil { + t.Fatalf("expected voice process cleared after stop") + } + if m.voiceAutoStarting { + t.Fatalf("expected voiceAutoStarting=false after stop") + } +} + +func TestResolveVoicedPathCandidateOrder(t *testing.T) { + tmp := t.TempDir() + hintPath := filepath.Join(tmp, "hint-voiced") + envPrimary := filepath.Join(tmp, "env-primary-voiced") + envSecondary := filepath.Join(tmp, "env-secondary-voiced") + + makeExecutableFile(t, hintPath, "#!/bin/sh\nexit 0\n") + makeExecutableFile(t, envPrimary, "#!/bin/sh\nexit 0\n") + makeExecutableFile(t, envSecondary, "#!/bin/sh\nexit 0\n") + + t.Setenv("DIALTONE_VOICE_DAEMON", envPrimary) + t.Setenv("DIALTONE_VOICED", envSecondary) + + path, err := resolveVoicedPath(hintPath) + if err != nil { + t.Fatalf("resolveVoicedPath hint: %v", err) + } + if path != hintPath { + t.Fatalf("expected hint path %q, got %q", hintPath, path) + } + + path, err = resolveVoicedPath("") + if err != nil { + t.Fatalf("resolveVoicedPath env primary: %v", err) + } + if path != envPrimary { + t.Fatalf("expected primary env path %q, got %q", envPrimary, path) + } + + if err := os.Remove(envPrimary); err != nil { + t.Fatalf("remove env primary: %v", err) + } + path, err = resolveVoicedPath("") + if err != nil { + t.Fatalf("resolveVoicedPath env secondary: %v", err) + } + if path != envSecondary { + t.Fatalf("expected secondary env path %q, got %q", envSecondary, path) + } +} + +func TestResolveVoicedPathNotFound(t *testing.T) { + tmp := t.TempDir() + oldWD, err := os.Getwd() + if err != nil { + t.Fatalf("getwd: %v", err) + } + if err := os.Chdir(tmp); err != nil { + t.Fatalf("chdir: %v", err) + } + t.Cleanup(func() { + _ = os.Chdir(oldWD) + }) + + t.Setenv("DIALTONE_VOICE_DAEMON", "") + t.Setenv("DIALTONE_VOICED", "") + t.Setenv("PATH", tmp) + + path, err := resolveVoicedPath("") + if err == nil { + t.Fatalf("expected not found error, got path %q", path) + } + if path != "" { + t.Fatalf("expected empty path on error, got %q", path) + } +} + +func TestOpenVoiceLogFileBehavior(t *testing.T) { + m := &chatModel{} + file, err := m.openVoiceLogFile() + if err != nil { + t.Fatalf("openVoiceLogFile: %v", err) + } + if file != nil { + t.Fatalf("expected no log file when debug disabled and path unset") + } + + t.Setenv("XDG_CACHE_HOME", t.TempDir()) + m.voiceDebug = true + file, err = m.openVoiceLogFile() + if err != nil { + t.Fatalf("openVoiceLogFile debug: %v", err) + } + if file == nil { + t.Fatalf("expected log file when debug is enabled") + } + _ = file.Close() + if !strings.HasSuffix(filepath.ToSlash(m.voiceLogPath), "dialtone/voiced.log") { + t.Fatalf("unexpected default log path: %q", m.voiceLogPath) + } + if _, err := os.Stat(m.voiceLogPath); err != nil { + t.Fatalf("stat default log path: %v", err) + } + + explicit := filepath.Join(t.TempDir(), "logs", "voice.log") + m2 := &chatModel{voiceLogPath: explicit} + file, err = m2.openVoiceLogFile() + if err != nil { + t.Fatalf("openVoiceLogFile explicit: %v", err) + } + if file == nil { + t.Fatalf("expected explicit log file to open") + } + _ = file.Close() + if m2.voiceLogPath != explicit { + t.Fatalf("expected explicit log path preserved, got %q", m2.voiceLogPath) + } +} + +func TestResolveExecutableCandidate(t *testing.T) { + tmp := t.TempDir() + path := filepath.Join(tmp, "voiced") + makeExecutableFile(t, path, "#!/bin/sh\nexit 0\n") + + if got := resolveExecutableCandidate(path); got != path { + t.Fatalf("expected executable candidate %q, got %q", path, got) + } + if got := resolveExecutableCandidate(filepath.Join(tmp, "missing")); got != "" { + t.Fatalf("expected missing candidate to resolve empty string, got %q", got) + } +} + +func TestFileIsExecutable(t *testing.T) { + tmp := t.TempDir() + path := filepath.Join(tmp, "plain-file") + if err := os.WriteFile(path, []byte("x"), 0o600); err != nil { + t.Fatalf("write file: %v", err) + } + + if runtime.GOOS == "windows" { + if !fileIsExecutable(path) { + t.Fatalf("expected windows file to be treated as executable") + } + return + } + + if fileIsExecutable(path) { + t.Fatalf("expected non-executable mode to be false") + } + if err := os.Chmod(path, 0o755); err != nil { + t.Fatalf("chmod executable: %v", err) + } + if !fileIsExecutable(path) { + t.Fatalf("expected executable mode to be true") + } +} + +func TestIsVoiceIPCNotRunning(t *testing.T) { + tests := []struct { + name string + err error + want bool + }{ + {name: "nil", err: nil, want: false}, + {name: "os err not exist", err: os.ErrNotExist, want: true}, + {name: "enoent", err: syscall.ENOENT, want: true}, + {name: "conn refused", err: syscall.ECONNREFUSED, want: true}, + {name: "string no such file", err: errors.New("No such file or directory"), want: true}, + {name: "string conn refused", err: errors.New("connection refused by peer"), want: true}, + {name: "other", err: errors.New("permission denied"), want: false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := isVoiceIPCNotRunning(tt.err) + if got != tt.want { + t.Fatalf("isVoiceIPCNotRunning(%v) = %v, want %v", tt.err, got, tt.want) + } + }) + } +} + +func TestSignalCommandShutdownNilAndNoProcess(t *testing.T) { + if signalCommandShutdown(nil) { + t.Fatalf("expected nil command to return false") + } + cmd := exec.Command(os.Args[0]) + if signalCommandShutdown(cmd) { + t.Fatalf("expected command without process to return false") + } +} + +func TestStopCommandGracefullyKillsUnresponsiveProcess(t *testing.T) { + if runtime.GOOS == "windows" { + t.Skip("signal behavior test is unix-only") + } + + tmp := t.TempDir() + binPath := filepath.Join(tmp, "ignore-int") + makeExecutableFile(t, binPath, "#!/bin/sh\ntrap '' INT\nwhile true; do sleep 1; done\n") + + cmd := exec.Command(binPath) + if err := cmd.Start(); err != nil { + t.Fatalf("start helper process: %v", err) + } + + start := time.Now() + stopCommandGracefully(cmd, 150*time.Millisecond) + if cmd.ProcessState == nil { + t.Fatalf("expected process state after graceful stop") + } + status, ok := cmd.ProcessState.Sys().(syscall.WaitStatus) + if !ok { + t.Fatalf("expected unix wait status, got %T", cmd.ProcessState.Sys()) + } + if !status.Exited() && !status.Signaled() { + t.Fatalf("expected process to terminate (exit or signal), status=%v", status) + } + if elapsed := time.Since(start); elapsed > 2*time.Second { + t.Fatalf("expected forced shutdown quickly, took %v", elapsed) + } +} + func TestVoiceDaemonEnvFiltersLauncherActivationVars(t *testing.T) { t.Setenv("XDG_ACTIVATION_TOKEN", "token-from-launcher") t.Setenv("DESKTOP_STARTUP_ID", "launcher-startup-id") diff --git a/cmd/voiced/daemon_test.go b/cmd/voiced/daemon_test.go new file mode 100644 index 0000000..2101095 --- /dev/null +++ b/cmd/voiced/daemon_test.go @@ -0,0 +1,302 @@ +package main + +import ( + "context" + "errors" + "path/filepath" + "testing" + "time" + + "github.com/Avicted/dialtone/internal/ipc" + "github.com/pion/webrtc/v4" +) + +func TestNewVoiceDaemonDefaultsVADThreshold(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "", pttBackendAuto, webrtc.Configuration{}, 0, false) + if d.vadThreshold != defaultVADThreshold { + t.Fatalf("expected default VAD threshold %d, got %d", defaultVADThreshold, d.vadThreshold) + } + + d = newVoiceDaemon("http://server", "token", "", pttBackendAuto, webrtc.Configuration{}, 1234, false) + if d.vadThreshold != 1234 { + t.Fatalf("expected explicit VAD threshold to be preserved, got %d", d.vadThreshold) + } +} + +func TestWSBackoffClamp(t *testing.T) { + tests := []struct { + attempt int + want time.Duration + }{ + {attempt: 0, want: 2 * time.Second}, + {attempt: 1, want: 2 * time.Second}, + {attempt: 2, want: 4 * time.Second}, + {attempt: 5, want: 32 * time.Second}, + {attempt: 99, want: 32 * time.Second}, + } + + for _, tt := range tests { + if got := wsBackoff(tt.attempt); got != tt.want { + t.Fatalf("wsBackoff(%d) = %v, want %v", tt.attempt, got, tt.want) + } + } +} + +func TestHandleIPCCommandVoiceJoinDeferredWithoutWebsocket(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "", pttBackendAuto, webrtc.Configuration{}, defaultVADThreshold, false) + d.local = "alice" + + resp, err := d.handleIPCCommand(context.Background(), ipc.Message{Cmd: ipc.CommandVoiceJoin, Room: " room-1 "}) + if err != nil { + t.Fatalf("handleIPCCommand voice join: %v", err) + } + if resp.Event != ipc.EventVoiceConnected { + t.Fatalf("expected %q event, got %q", ipc.EventVoiceConnected, resp.Event) + } + if resp.Room != "room-1" { + t.Fatalf("expected joined room room-1, got %q", resp.Room) + } + if got := d.currentRoom(); got != "room-1" { + t.Fatalf("expected daemon room to be set, got %q", got) + } + + d.mu.Lock() + _, hasLocal := d.memb["alice"] + d.mu.Unlock() + if !hasLocal { + t.Fatalf("expected local user to be included in voice members after join") + } +} + +func TestHandleIPCCommandVoiceJoinClearsRoomOnSendError(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "", pttBackendAuto, webrtc.Configuration{}, defaultVADThreshold, false) + d.setWS(&WSClient{closed: true}) + + resp, err := d.handleIPCCommand(context.Background(), ipc.Message{Cmd: ipc.CommandVoiceJoin, Room: "room-1"}) + if err == nil { + t.Fatalf("expected voice join send failure") + } + if resp.Event != "" { + t.Fatalf("expected empty response event on join failure, got %q", resp.Event) + } + if got := d.currentRoom(); got != "" { + t.Fatalf("expected room to be cleared on join failure, got %q", got) + } + + d.mu.Lock() + membersLen := len(d.memb) + d.mu.Unlock() + if membersLen != 0 { + t.Fatalf("expected voice members cleared on join failure, got %d members", membersLen) + } +} + +func TestHandleIPCCommandValidationAndBasics(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "", pttBackendAuto, webrtc.Configuration{}, defaultVADThreshold, false) + + if _, err := d.handleIPCCommand(context.Background(), ipc.Message{Cmd: ipc.CommandVoiceJoin, Room: " "}); err == nil { + t.Fatalf("expected join with blank room to fail") + } + if _, err := d.handleIPCCommand(context.Background(), ipc.Message{Cmd: ipc.CommandVoiceLeave}); err == nil { + t.Fatalf("expected leave without room to fail") + } + if _, err := d.handleIPCCommand(context.Background(), ipc.Message{Cmd: ipc.CommandIdentify, User: ""}); err == nil { + t.Fatalf("expected identify without user to fail") + } + + if _, err := d.handleIPCCommand(context.Background(), ipc.Message{Cmd: ipc.CommandIdentify, User: "alice"}); err != nil { + t.Fatalf("identify failed: %v", err) + } + if got := d.localUser(); got != "alice" { + t.Fatalf("expected local user alice, got %q", got) + } + + resp, err := d.handleIPCCommand(context.Background(), ipc.Message{Cmd: ipc.CommandPing}) + if err != nil { + t.Fatalf("ping failed: %v", err) + } + if resp.Event != ipc.EventPong { + t.Fatalf("expected pong event, got %q", resp.Event) + } + + if _, err := d.handleIPCCommand(context.Background(), ipc.Message{Cmd: "unknown"}); err == nil { + t.Fatalf("expected unknown command to fail") + } +} + +func TestMuteUnmuteAndSpeakingState(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "", pttBackendAuto, webrtc.Configuration{}, defaultVADThreshold, false) + d.local = "alice" + d.setSpeaking(true) + if !d.isSpeaking() { + t.Fatalf("expected speaking=true after setSpeaking(true)") + } + + if _, err := d.handleIPCCommand(context.Background(), ipc.Message{Cmd: ipc.CommandMute}); err != nil { + t.Fatalf("mute failed: %v", err) + } + if d.isSpeaking() { + t.Fatalf("expected muted daemon to report not speaking") + } + + if _, err := d.handleIPCCommand(context.Background(), ipc.Message{Cmd: ipc.CommandUnmute}); err != nil { + t.Fatalf("unmute failed: %v", err) + } + if d.isSpeaking() { + t.Fatalf("expected speaking to remain false after unmute") + } +} + +func TestUpdateVADAndDisablePTT(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "ctrl+v", pttBackendAuto, webrtc.Configuration{}, defaultVADThreshold, false) + d.local = "alice" + + d.updateVAD(true) + if d.isSpeaking() { + t.Fatalf("expected VAD to be ignored while PTT binding is active") + } + + d.disablePTT() + if d.hasPTTBinding() { + t.Fatalf("expected PTT binding to be cleared") + } + + d.updateVAD(true) + if !d.isSpeaking() { + t.Fatalf("expected VAD to control speaking after PTT is disabled") + } +} + +func TestSendSignalWithoutWebsocket(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "", pttBackendAuto, webrtc.Configuration{}, defaultVADThreshold, false) + err := d.sendSignal(VoiceSignal{Type: "voice_join", ChannelID: "room-1"}) + if !errors.Is(err, errWebsocketUnavailable) { + t.Fatalf("expected errWebsocketUnavailable, got %v", err) + } +} + +func TestOnWSDisconnectResetsVoiceState(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "", pttBackendAuto, webrtc.Configuration{}, defaultVADThreshold, false) + d.local = "alice" + d.room = "room-1" + d.ws = &WSClient{closed: true} + d.rem = map[string]bool{"bob": true} + d.memb = map[string]struct{}{"alice": {}, "bob": {}} + + d.onWSDisconnect() + + if d.currentWS() != nil { + t.Fatalf("expected websocket reference to be cleared") + } + + d.mu.Lock() + remLen := len(d.rem) + _, hasAlice := d.memb["alice"] + _, hasBob := d.memb["bob"] + membersLen := len(d.memb) + d.mu.Unlock() + + if remLen != 0 { + t.Fatalf("expected remote speaking state cleared, got %d entries", remLen) + } + if !hasAlice || hasBob || membersLen != 1 { + t.Fatalf("expected members reset to only local user, got alice=%v bob=%v len=%d", hasAlice, hasBob, membersLen) + } +} + +func TestHandleWSMessageVoiceRosterUpdatesMembersForCurrentRoom(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "", pttBackendAuto, webrtc.Configuration{}, defaultVADThreshold, false) + d.local = "alice" + d.room = "room-1" + + d.handleWSMessage(VoiceSignal{Type: "voice_roster", ChannelID: "room-1", Users: []string{"bob", "", "bob"}}) + + d.mu.Lock() + _, hasAlice := d.memb["alice"] + _, hasBob := d.memb["bob"] + membersLen := len(d.memb) + d.mu.Unlock() + + if !hasAlice || !hasBob || membersLen != 2 { + t.Fatalf("expected roster to include local+remote members once, got alice=%v bob=%v len=%d", hasAlice, hasBob, membersLen) + } +} + +func TestHandleICECandidateAndPeerStateFailure(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "", pttBackendAuto, webrtc.Configuration{}, defaultVADThreshold, false) + + d.setRemoteSpeaking("peer-1", true) + d.handleICECandidate("peer-1", "candidate") + d.room = "room-1" + d.handleICECandidate("peer-1", "candidate") + d.handlePeerState("peer-1", webrtc.PeerConnectionStateFailed) + + d.mu.Lock() + active := d.rem["peer-1"] + d.mu.Unlock() + if active { + t.Fatalf("expected failed peer state to clear remote speaking") + } +} + +func TestRunWSLoopAndConnectWSHonorCanceledContext(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "", pttBackendAuto, webrtc.Configuration{}, defaultVADThreshold, false) + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + out := make(chan VoiceSignal) + if err := d.runWSLoop(ctx, out); err != nil { + t.Fatalf("runWSLoop canceled context: %v", err) + } + if _, ok := <-out; ok { + t.Fatalf("expected runWSLoop output channel to be closed") + } + + if _, err := d.connectWSWithRetry(ctx, 0); !errors.Is(err, context.Canceled) { + t.Fatalf("expected canceled context from connectWSWithRetry, got %v", err) + } +} + +func TestRunPTTInvalidBackendReturnsError(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "ctrl+v", "invalid-backend", webrtc.Configuration{}, defaultVADThreshold, false) + if err := d.runPTT(context.Background()); err == nil { + t.Fatalf("expected runPTT to fail for unsupported backend") + } +} + +func TestPlaybackHelpersNoopWithoutPlayback(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "", pttBackendAuto, webrtc.Configuration{}, defaultVADThreshold, false) + d.writePlayback(nil) + d.writePlayback([]int16{1, 2, 3}) + d.closePlayback() +} + +func TestStartPlaybackReturnsPromptly(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "", pttBackendAuto, webrtc.Configuration{}, defaultVADThreshold, false) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + done := make(chan struct{}) + go func() { + d.startPlayback(ctx) + close(done) + }() + + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatalf("startPlayback did not return promptly") + } +} + +func TestRunReturnsNilWhenContextCanceled(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "", pttBackendAuto, webrtc.Configuration{}, defaultVADThreshold, false) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + ipcAddr := filepath.Join(t.TempDir(), "voice.sock") + if err := d.Run(ctx, ipcAddr); err != nil { + t.Fatalf("Run canceled context: %v", err) + } +} From b28b0b67f514c75a35e4c333c5925743d49d7c26 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Victor=20Anderss=C3=A9n?= Date: Fri, 13 Feb 2026 12:22:45 +0200 Subject: [PATCH 2/6] Add tests for login model, voice daemon, IPC, audio playback, and WebRTC manager - Implement TestLoginConfirmPasswordGetter to verify confirm password retrieval. - Add TestAddAndRemoveVoiceMember to ensure proper member management in voice daemon. - Introduce TestPTTControllerRun to validate PTT controller behavior. - Create voice IPC tests including connection handling and message sending. - Add audio playback tests for buffer management and safety on nil instances. - Implement IPC tests for message encoding/decoding and connection lifecycle. - Add WebRTC manager tests for peer validation, offer/answer flow, and ICE candidate handling. --- cmd/client/chat_test.go | 205 ++++++++++++++++++++++++++++++ cmd/client/login_test.go | 8 ++ cmd/client/voice_ipc_test.go | 176 +++++++++++++++++++++++++ cmd/voiced/audio_pipeline_test.go | 88 +++++++++++++ cmd/voiced/daemon_test.go | 33 +++++ cmd/voiced/ipc_server_test.go | 187 +++++++++++++++++++++++++++ cmd/voiced/main_test.go | 93 ++++++++++++++ cmd/voiced/ptt_test.go | 12 ++ cmd/voiced/stats_test.go | 51 ++++++++ cmd/voiced/ws_client_test.go | 129 +++++++++++++++++++ internal/audio/playback.go | 25 ++-- internal/audio/playback_test.go | 67 ++++++++++ internal/ipc/ipc_test.go | 32 +++++ internal/ipc/ipc_unix_test.go | 71 +++++++++++ internal/webrtc/manager_test.go | 103 +++++++++++++++ internal/ws/ws_test.go | 107 ++++++++++++++++ scripts/coverage.sh | 23 +++- 17 files changed, 1393 insertions(+), 17 deletions(-) create mode 100644 cmd/client/voice_ipc_test.go create mode 100644 cmd/voiced/audio_pipeline_test.go create mode 100644 cmd/voiced/ipc_server_test.go create mode 100644 cmd/voiced/main_test.go create mode 100644 cmd/voiced/stats_test.go create mode 100644 cmd/voiced/ws_client_test.go create mode 100644 internal/audio/playback_test.go create mode 100644 internal/ipc/ipc_test.go create mode 100644 internal/ipc/ipc_unix_test.go create mode 100644 internal/webrtc/manager_test.go diff --git a/cmd/client/chat_test.go b/cmd/client/chat_test.go index d974f09..f1fed37 100644 --- a/cmd/client/chat_test.go +++ b/cmd/client/chat_test.go @@ -1358,3 +1358,208 @@ func TestChatModelDispatchVoiceLeaveDoesNotClearRoomBeforeAck(t *testing.T) { t.Fatalf("timed out waiting for leave command on ipc") } } + +func TestWaitForVoiceMsg(t *testing.T) { + ch := make(chan ipc.Message, 1) + ch <- ipc.Message{Event: ipc.EventPong} + + msg := waitForVoiceMsg(ch)() + voiceMsg, ok := msg.(voiceIPCMsg) + if !ok { + t.Fatalf("expected voiceIPCMsg, got %T", msg) + } + if ipc.Message(voiceMsg).Event != ipc.EventPong { + t.Fatalf("expected pong event, got %#v", ipc.Message(voiceMsg)) + } + + closed := make(chan ipc.Message) + close(closed) + msg = waitForVoiceMsg(closed)() + errMsg, ok := msg.(voiceIPCErrorMsg) + if !ok { + t.Fatalf("expected voiceIPCErrorMsg, got %T", msg) + } + if !strings.Contains(errMsg.err.Error(), "voice daemon disconnected") { + t.Fatalf("unexpected disconnect error: %v", errMsg.err) + } +} + +func TestChatModelConnectVoiceIPC(t *testing.T) { + m := newChatForTest(t, &APIClient{serverURL: "http://server", httpClient: http.DefaultClient}) + m.voiceIPC = nil + if cmd := m.connectVoiceIPC(); cmd != nil { + t.Fatalf("expected nil connect command when voice IPC is not configured") + } + + m.voiceIPC = newVoiceIPC("") + cmd := m.connectVoiceIPC() + if cmd == nil { + t.Fatalf("expected connect command") + } + msg := cmd() + errMsg, ok := msg.(voiceIPCErrorMsg) + if !ok { + t.Fatalf("expected voiceIPCErrorMsg for invalid address, got %T", msg) + } + if !strings.Contains(errMsg.err.Error(), "voice ipc address is empty") { + t.Fatalf("unexpected connect error: %v", errMsg.err) + } + + addr := filepath.Join(t.TempDir(), "voice.sock") + listener, err := ipc.Listen(addr) + if err != nil { + t.Fatalf("listen voice ipc: %v", err) + } + defer listener.Close() + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + _ = conn.Close() + }() + + m.voiceIPC = newVoiceIPC(addr) + cmd = m.connectVoiceIPC() + msg = cmd() + connected, ok := msg.(voiceIPCConnectedMsg) + if !ok { + t.Fatalf("expected voiceIPCConnectedMsg, got %T", msg) + } + if connected.ch == nil { + t.Fatalf("expected non-nil voice IPC channel") + } +} + +func TestChatModelHandleVoiceCommandPaths(t *testing.T) { + m := newChatForTest(t, &APIClient{serverURL: "http://server", httpClient: http.DefaultClient}) + if cmd := m.handleVoiceCommand("/voice", []string{"/voice"}); cmd != nil { + t.Fatalf("expected no command for usage help") + } + if !strings.Contains(lastSystemMessage(m), "voice commands") { + t.Fatalf("expected voice help message") + } + + m.voiceIPC = nil + if cmd := m.handleVoiceCommand("/voice mute", []string{"/voice", "mute"}); cmd != nil { + t.Fatalf("expected no command when voice IPC is nil") + } + if !strings.Contains(lastSystemMessage(m), "voice daemon not configured") { + t.Fatalf("expected missing daemon message") + } + + m.voiceIPC = &voiceIPC{addr: ""} + m.voiceRoom = "room-1" + _ = m.handleVoiceCommand("/voice leave", []string{"/voice", "leave"}) + if !strings.Contains(lastSystemMessage(m), "voice leave failed") { + t.Fatalf("expected leave failure message") + } + _ = m.handleVoiceCommand("/voice mute", []string{"/voice", "mute"}) + if !strings.Contains(lastSystemMessage(m), "voice mute failed") { + t.Fatalf("expected mute failure message") + } + _ = m.handleVoiceCommand("/voice unmute", []string{"/voice", "unmute"}) + if !strings.Contains(lastSystemMessage(m), "voice unmute failed") { + t.Fatalf("expected unmute failure message") + } + + addr := filepath.Join(t.TempDir(), "voice.sock") + listener, err := ipc.Listen(addr) + if err != nil { + t.Fatalf("listen voice ipc: %v", err) + } + defer listener.Close() + + recv := make(chan ipc.Message, 1) + go func() { + conn, err := listener.Accept() + if err != nil { + return + } + defer conn.Close() + var msg ipc.Message + if err := ipc.NewDecoder(conn).Decode(&msg); err == nil { + recv <- msg + } + }() + + m.voiceIPC = newVoiceIPC(addr) + m.voiceCh = make(chan ipc.Message, 1) + m.channels["ch-1"] = channelInfo{ID: "ch-1", Name: "general"} + m.activeChannel = "ch-1" + + if cmd := m.handleVoiceCommand("/voice join", []string{"/voice", "join"}); cmd != nil { + t.Fatalf("expected no reconnect cmd when voice channel already connected") + } + if m.voiceRoom != "ch-1" { + t.Fatalf("expected voice room set to joined channel, got %q", m.voiceRoom) + } + if !strings.Contains(lastSystemMessage(m), "voice join requested") { + t.Fatalf("expected join notice") + } + + select { + case sent := <-recv: + if sent.Cmd != ipc.CommandVoiceJoin || sent.Room != "ch-1" { + t.Fatalf("unexpected join IPC message: %#v", sent) + } + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting for join message") + } +} + +func TestChatModelVoiceHelperMethods(t *testing.T) { + m := newChatForTest(t, &APIClient{serverURL: "http://server", httpClient: http.DefaultClient}) + + pending := ipc.Message{Cmd: ipc.CommandVoiceJoin, Room: "room-1"} + m.queueVoiceCommand(pending, "room-1", "queued") + m.clearPendingVoiceCommand() + if m.voicePendingCmd != nil || m.voicePendingRoom != "" || m.voicePendingNotice != "" { + t.Fatalf("expected pending voice command state to be cleared") + } + + if cmd := m.scheduleVoicePing(); cmd == nil { + t.Fatalf("expected non-nil voice ping schedule command") + } + + m.voiceMembers = nil + m.clearVoiceMembers() + if m.voiceMembers == nil { + t.Fatalf("expected voiceMembers map initialized") + } + m.voiceMembers["user-1"] = true + m.clearVoiceMembers() + if len(m.voiceMembers) != 0 { + t.Fatalf("expected voiceMembers map cleared") + } +} + +func TestChatModelShareDirectoryKeyCmdGuards(t *testing.T) { + m := newChatForTest(t, &APIClient{serverURL: "http://server", httpClient: http.DefaultClient}) + m.auth.IsTrusted = false + if cmd := m.shareDirectoryKeyCmd(); cmd != nil { + t.Fatalf("expected nil cmd when user is not trusted") + } + + m.auth.IsTrusted = true + m.directoryKey = nil + if cmd := m.shareDirectoryKeyCmd(); cmd != nil { + t.Fatalf("expected nil cmd without directory key") + } + + m.directoryKey = bytes.Repeat([]byte{7}, crypto.KeySize) + m.api = nil + m.kp = nil + cmd := m.shareDirectoryKeyCmd() + if cmd == nil { + t.Fatalf("expected cmd when trusted and directory key exists") + } + msg := cmd() + result, ok := msg.(shareDirectoryMsg) + if !ok { + t.Fatalf("expected shareDirectoryMsg, got %T", msg) + } + if result.err != nil { + t.Fatalf("expected nil error when shareDirectoryKey short-circuits missing deps, got %v", result.err) + } +} diff --git a/cmd/client/login_test.go b/cmd/client/login_test.go index ac25e27..07ac36b 100644 --- a/cmd/client/login_test.go +++ b/cmd/client/login_test.go @@ -146,3 +146,11 @@ func TestLoginViewWithErrorsAndLoading(t *testing.T) { t.Fatalf("expected view") } } + +func TestLoginConfirmPasswordGetter(t *testing.T) { + m := newLoginModel("http://server") + m.confirmInput.SetValue("secret-confirm") + if got := m.confirmPassword(); got != "secret-confirm" { + t.Fatalf("confirmPassword() = %q, want %q", got, "secret-confirm") + } +} diff --git a/cmd/client/voice_ipc_test.go b/cmd/client/voice_ipc_test.go new file mode 100644 index 0000000..b07b7f1 --- /dev/null +++ b/cmd/client/voice_ipc_test.go @@ -0,0 +1,176 @@ +package main + +import ( + "encoding/json" + "net" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/Avicted/dialtone/internal/ipc" +) + +func TestVoiceIPCEnsureConnAndSend(t *testing.T) { + addr := filepath.Join(t.TempDir(), "voice.sock") + ln, err := ipc.Listen(addr) + if err != nil { + t.Fatalf("listen ipc: %v", err) + } + defer ln.Close() + + received := make(chan ipc.Message, 1) + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + defer conn.Close() + dec := ipc.NewDecoder(conn) + var msg ipc.Message + if err := dec.Decode(&msg); err != nil { + return + } + received <- msg + }() + + v := newVoiceIPC(addr) + if err := v.ensureConn(); err != nil { + t.Fatalf("ensureConn: %v", err) + } + if err := v.send(ipc.Message{Cmd: ipc.CommandPing}); err != nil { + t.Fatalf("send: %v", err) + } + if err := v.ensureConn(); err != nil { + t.Fatalf("ensureConn second call: %v", err) + } + + select { + case msg := <-received: + if msg.Cmd != ipc.CommandPing { + t.Fatalf("expected ping command, got %#v", msg) + } + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting for IPC message") + } +} + +func TestVoiceIPCReadLoopReportsEnsureConnError(t *testing.T) { + v := newVoiceIPC("") + ch := make(chan ipc.Message, 2) + go v.readLoop(ch) + + select { + case msg, ok := <-ch: + if !ok { + t.Fatalf("expected error message before channel close") + } + if msg.Event != ipc.EventError { + t.Fatalf("expected error event, got %#v", msg) + } + if !strings.Contains(msg.Error, "voice ipc address is empty") { + t.Fatalf("unexpected error payload: %q", msg.Error) + } + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting for error event") + } + + select { + case _, ok := <-ch: + if ok { + t.Fatalf("expected channel to be closed") + } + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting for channel close") + } +} + +func TestVoiceIPCReadLoopReceivesMessageThenResetsOnDisconnect(t *testing.T) { + addr := filepath.Join(t.TempDir(), "voice.sock") + ln, err := ipc.Listen(addr) + if err != nil { + t.Fatalf("listen ipc: %v", err) + } + defer ln.Close() + + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + enc := ipc.NewEncoder(conn) + _ = enc.Encode(ipc.Message{Event: ipc.EventPong}) + _ = conn.Close() + }() + + v := newVoiceIPC(addr) + ch := make(chan ipc.Message, 4) + go v.readLoop(ch) + + select { + case msg := <-ch: + if msg.Event != ipc.EventPong { + t.Fatalf("expected pong event, got %#v", msg) + } + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting for pong") + } + + select { + case msg := <-ch: + if msg.Event != ipc.EventError { + t.Fatalf("expected error event after disconnect, got %#v", msg) + } + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting for disconnect error") + } + + select { + case _, ok := <-ch: + if ok { + t.Fatalf("expected read loop channel to be closed") + } + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting for channel close") + } + + v.mu.Lock() + defer v.mu.Unlock() + if v.conn != nil || v.enc != nil || v.dec != nil { + t.Fatalf("expected IPC connection state reset after read failure") + } +} + +func TestVoiceIPCEnsureConnFailsWithoutCodecState(t *testing.T) { + client, server := net.Pipe() + defer client.Close() + defer server.Close() + + v := &voiceIPC{addr: "in-memory", conn: client} + err := v.ensureConn() + if err == nil || !strings.Contains(err.Error(), "voice ipc encoder not available") { + t.Fatalf("expected codec availability error, got %v", err) + } +} + +func TestVoiceIPCSendEncodeErrorResetsConnection(t *testing.T) { + client, server := net.Pipe() + _ = server.Close() + + v := &voiceIPC{ + addr: "in-memory", + conn: client, + enc: json.NewEncoder(client), + dec: json.NewDecoder(client), + } + err := v.send(ipc.Message{Cmd: ipc.CommandPing}) + if err == nil { + t.Fatalf("expected encode error when peer is closed") + } + + v.mu.Lock() + defer v.mu.Unlock() + if v.conn != nil || v.enc != nil || v.dec != nil { + t.Fatalf("expected connection state reset after encode failure") + } +} diff --git a/cmd/voiced/audio_pipeline_test.go b/cmd/voiced/audio_pipeline_test.go new file mode 100644 index 0000000..0a434ba --- /dev/null +++ b/cmd/voiced/audio_pipeline_test.go @@ -0,0 +1,88 @@ +//go:build linux + +package main + +import ( + "testing" + "time" + + "github.com/pion/webrtc/v4" +) + +func TestAudioBackoffClamp(t *testing.T) { + tests := []struct { + attempt int + want time.Duration + }{ + {attempt: 0, want: 2 * time.Second}, + {attempt: 1, want: 2 * time.Second}, + {attempt: 3, want: 8 * time.Second}, + {attempt: 5, want: 32 * time.Second}, + {attempt: 100, want: 32 * time.Second}, + } + + for _, tt := range tests { + if got := audioBackoff(tt.attempt); got != tt.want { + t.Fatalf("audioBackoff(%d) = %v, want %v", tt.attempt, got, tt.want) + } + } +} + +func TestVoiceLevelAndIsVoiceActive(t *testing.T) { + if got := voiceLevel(nil); got != 0 { + t.Fatalf("voiceLevel(nil) = %d, want 0", got) + } + if got := voiceLevel([]int16{-10, 20, -30, 40}); got != 25 { + t.Fatalf("voiceLevel abs average mismatch: got %d, want 25", got) + } + + if !isVoiceActive(25, 25) { + t.Fatalf("expected voice active at threshold") + } + if isVoiceActive(24, 25) { + t.Fatalf("expected voice inactive below threshold") + } +} + +func TestUpdateVADFromFrameAndMeterBehavior(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "", pttBackendAuto, webrtc.Configuration{}, 20, true) + d.local = "alice" + + active := d.updateVADFromFrame([]int16{30, -30, 30, -30}) + if !active { + t.Fatalf("expected frame to be active") + } + if !d.isSpeaking() { + t.Fatalf("expected speaking after active VAD frame") + } + + d.muted = true + inactive := d.updateVADFromFrame([]int16{5, -5, 5, -5}) + if inactive { + t.Fatalf("expected low-level frame to be inactive") + } + if d.isSpeaking() { + t.Fatalf("expected muted daemon to report not speaking") + } + + d.muted = false + d.meter = true + d.vadThreshold = 15 + d.meterNext = time.Time{} + d.maybeLogMeter(16) + firstNext := d.meterNext + if firstNext.IsZero() { + t.Fatalf("expected meterNext to be scheduled when meter is enabled") + } + d.maybeLogMeter(16) + if !d.meterNext.Equal(firstNext) { + t.Fatalf("expected meterNext unchanged inside interval") + } + + d.pttBind = "ctrl+v" + d.setSpeaking(false) + d.updateVADFromFrame([]int16{100, 100, 100, 100}) + if d.isSpeaking() { + t.Fatalf("expected VAD updates ignored while PTT binding is active") + } +} diff --git a/cmd/voiced/daemon_test.go b/cmd/voiced/daemon_test.go index 2101095..1924eb5 100644 --- a/cmd/voiced/daemon_test.go +++ b/cmd/voiced/daemon_test.go @@ -300,3 +300,36 @@ func TestRunReturnsNilWhenContextCanceled(t *testing.T) { t.Fatalf("Run canceled context: %v", err) } } + +func TestAddAndRemoveVoiceMember(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "", pttBackendAuto, webrtc.Configuration{}, defaultVADThreshold, false) + d.room = "room-1" + d.local = "alice" + d.memb = map[string]struct{}{"alice": {}} + + d.addVoiceMember("") + d.addVoiceMember("bob") + d.addVoiceMember("bob") + + d.mu.Lock() + _, hasAlice := d.memb["alice"] + _, hasBob := d.memb["bob"] + membersLen := len(d.memb) + d.mu.Unlock() + if !hasAlice || !hasBob || membersLen != 2 { + t.Fatalf("expected alice and bob members only, got alice=%v bob=%v len=%d", hasAlice, hasBob, membersLen) + } + + d.removeVoiceMember("charlie") + d.removeVoiceMember("") + d.removeVoiceMember("bob") + + d.mu.Lock() + _, hasAlice = d.memb["alice"] + _, hasBob = d.memb["bob"] + membersLen = len(d.memb) + d.mu.Unlock() + if !hasAlice || hasBob || membersLen != 1 { + t.Fatalf("expected only local member to remain, got alice=%v bob=%v len=%d", hasAlice, hasBob, membersLen) + } +} diff --git a/cmd/voiced/ipc_server_test.go b/cmd/voiced/ipc_server_test.go new file mode 100644 index 0000000..0112eb5 --- /dev/null +++ b/cmd/voiced/ipc_server_test.go @@ -0,0 +1,187 @@ +package main + +import ( + "context" + "encoding/json" + "fmt" + "net" + "testing" + "time" + + "github.com/Avicted/dialtone/internal/ipc" +) + +func TestIPCConnSend(t *testing.T) { + serverConn, clientConn := net.Pipe() + defer serverConn.Close() + defer clientConn.Close() + + state := &ipcConn{conn: serverConn, enc: json.NewEncoder(serverConn)} + errCh := make(chan error, 1) + go func() { + errCh <- state.send(ipc.Message{Event: ipc.EventVoiceReady}) + }() + + var msg ipc.Message + if err := json.NewDecoder(clientConn).Decode(&msg); err != nil { + t.Fatalf("decode sent message: %v", err) + } + if msg.Event != ipc.EventVoiceReady { + t.Fatalf("unexpected sent message: %#v", msg) + } + if err := <-errCh; err != nil { + t.Fatalf("ipcConn.send failed: %v", err) + } +} + +func TestIPCServerHandleCommandPaths(t *testing.T) { + t.Run("missing handler emits error", func(t *testing.T) { + serverConn, clientConn := net.Pipe() + defer serverConn.Close() + defer clientConn.Close() + + s := &ipcServer{} + state := &ipcConn{conn: serverConn, enc: json.NewEncoder(serverConn)} + go s.handleCommand(context.Background(), ipc.Message{Cmd: ipc.CommandPing}, state) + + var msg ipc.Message + if err := json.NewDecoder(clientConn).Decode(&msg); err != nil { + t.Fatalf("decode error payload: %v", err) + } + if msg.Event != ipc.EventError || msg.Error != "ipc handler unavailable" { + t.Fatalf("unexpected error payload: %#v", msg) + } + }) + + t.Run("handler response emits event", func(t *testing.T) { + serverConn, clientConn := net.Pipe() + defer serverConn.Close() + defer clientConn.Close() + + s := &ipcServer{h: func(context.Context, ipc.Message) (ipc.Message, error) { + return ipc.Message{Event: ipc.EventPong}, nil + }} + state := &ipcConn{conn: serverConn, enc: json.NewEncoder(serverConn)} + go s.handleCommand(context.Background(), ipc.Message{Cmd: ipc.CommandPing}, state) + + var msg ipc.Message + if err := json.NewDecoder(clientConn).Decode(&msg); err != nil { + t.Fatalf("decode handler response: %v", err) + } + if msg.Event != ipc.EventPong { + t.Fatalf("unexpected response payload: %#v", msg) + } + }) + + t.Run("handler error emits error", func(t *testing.T) { + serverConn, clientConn := net.Pipe() + defer serverConn.Close() + defer clientConn.Close() + + s := &ipcServer{h: func(context.Context, ipc.Message) (ipc.Message, error) { + return ipc.Message{}, fmt.Errorf("boom") + }} + state := &ipcConn{conn: serverConn, enc: json.NewEncoder(serverConn)} + go s.handleCommand(context.Background(), ipc.Message{Cmd: ipc.CommandPing}, state) + + var msg ipc.Message + if err := json.NewDecoder(clientConn).Decode(&msg); err != nil { + t.Fatalf("decode handler error: %v", err) + } + if msg.Event != ipc.EventError || msg.Error != "boom" { + t.Fatalf("unexpected handler error payload: %#v", msg) + } + }) +} + +func TestIPCServerTrackUntrackAndBroadcast(t *testing.T) { + serverConn, clientConn := net.Pipe() + defer serverConn.Close() + defer clientConn.Close() + + s := &ipcServer{} + state := &ipcConn{conn: serverConn, enc: json.NewEncoder(serverConn)} + s.trackConn(state) + if len(s.conns) != 1 { + t.Fatalf("expected tracked conn count 1, got %d", len(s.conns)) + } + + errCh := make(chan error, 1) + go func() { + var msg ipc.Message + err := json.NewDecoder(clientConn).Decode(&msg) + if err == nil && msg.Event != ipc.EventInfo { + err = fmt.Errorf("unexpected broadcast payload: %#v", msg) + } + errCh <- err + }() + + s.Broadcast(ipc.Message{Event: ipc.EventInfo, Error: "hello"}) + if err := <-errCh; err != nil { + t.Fatalf("broadcast receive: %v", err) + } + + s.untrackConn(serverConn) + if len(s.conns) != 0 { + t.Fatalf("expected tracked conn count 0, got %d", len(s.conns)) + } +} + +func TestIPCServerHandleConnLifecycle(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + s := newIPCServer("", func(_ context.Context, msg ipc.Message) (ipc.Message, error) { + if msg.Cmd == ipc.CommandPing { + return ipc.Message{Event: ipc.EventPong}, nil + } + return ipc.Message{}, nil + }) + + serverConn, clientConn := net.Pipe() + defer clientConn.Close() + + done := make(chan struct{}) + go func() { + s.handleConn(ctx, serverConn) + close(done) + }() + + dec := json.NewDecoder(clientConn) + enc := json.NewEncoder(clientConn) + + var ready ipc.Message + if err := dec.Decode(&ready); err != nil { + t.Fatalf("decode ready event: %v", err) + } + if ready.Event != ipc.EventVoiceReady { + t.Fatalf("expected ready event, got %#v", ready) + } + + if err := enc.Encode(ipc.Message{Cmd: ipc.CommandPing}); err != nil { + t.Fatalf("encode ping: %v", err) + } + var resp ipc.Message + if err := dec.Decode(&resp); err != nil { + t.Fatalf("decode ping response: %v", err) + } + if resp.Event != ipc.EventPong { + t.Fatalf("expected pong response, got %#v", resp) + } + + if err := enc.Encode(ipc.Message{Event: ipc.EventInfo}); err != nil { + t.Fatalf("encode no-op message: %v", err) + } + + _ = clientConn.Close() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting for handleConn to exit") + } + + if len(s.conns) != 0 { + t.Fatalf("expected all conns untracked after close, got %d", len(s.conns)) + } +} diff --git a/cmd/voiced/main_test.go b/cmd/voiced/main_test.go new file mode 100644 index 0000000..bd87a50 --- /dev/null +++ b/cmd/voiced/main_test.go @@ -0,0 +1,93 @@ +package main + +import ( + "os" + "runtime" + "strings" + "testing" +) + +func TestDefaultIPCAddr(t *testing.T) { + addr := defaultIPCAddr() + if runtime.GOOS == "windows" { + if addr != `\\.\pipe\dialtone-voice` { + t.Fatalf("unexpected windows IPC addr: %q", addr) + } + return + } + if addr != "/tmp/dialtone-voice.sock" { + t.Fatalf("unexpected unix IPC addr: %q", addr) + } +} + +func TestSplitCSV(t *testing.T) { + if got := splitCSV(""); got != nil { + t.Fatalf("expected nil for empty CSV, got %#v", got) + } + got := splitCSV(" stun1.example, ,stun2.example ,, turn.example ") + if len(got) != 3 || got[0] != "stun1.example" || got[1] != "stun2.example" || got[2] != "turn.example" { + t.Fatalf("unexpected splitCSV result: %#v", got) + } +} + +func TestNormalizeICEURLs(t *testing.T) { + urls := normalizeICEURLs([]string{"stun:one", "two", "turn:three", "turns:four"}, "stun:") + if len(urls) != 4 { + t.Fatalf("unexpected normalized URL count: %d", len(urls)) + } + if urls[0] != "stun:one" || urls[1] != "stun:two" || urls[2] != "turn:three" || urls[3] != "turns:four" { + t.Fatalf("unexpected normalized URLs: %#v", urls) + } +} + +func TestBuildICEConfig(t *testing.T) { + config := buildICEConfig("stun1.example,stun:stun2.example", "turn.example", "user", "pass") + if len(config.ICEServers) != 2 { + t.Fatalf("expected 2 ICE server entries, got %d", len(config.ICEServers)) + } + stun := config.ICEServers[0] + if len(stun.URLs) != 2 || stun.URLs[0] != "stun:stun1.example" || stun.URLs[1] != "stun:stun2.example" { + t.Fatalf("unexpected STUN config: %#v", stun) + } + turn := config.ICEServers[1] + if len(turn.URLs) != 1 || turn.URLs[0] != "turn:turn.example" || turn.Username != "user" || turn.Credential != "pass" { + t.Fatalf("unexpected TURN config: %#v", turn) + } + + empty := buildICEConfig("", "", "", "") + if len(empty.ICEServers) != 0 { + t.Fatalf("expected empty ICE config, got %#v", empty) + } +} + +func TestRunValidationErrors(t *testing.T) { + originalArgs := os.Args + t.Cleanup(func() { + os.Args = originalArgs + }) + + tests := []struct { + name string + args []string + wantErr string + }{ + {name: "missing server", args: []string{"dialtone-voiced", "-token", "t"}, wantErr: "server address is required"}, + {name: "missing token", args: []string{"dialtone-voiced", "-server", "http://s"}, wantErr: "auth token is required"}, + {name: "empty ipc", args: []string{"dialtone-voiced", "-server", "http://s", "-token", "t", "-ipc", ""}, wantErr: "ipc address is required"}, + {name: "invalid vad", args: []string{"dialtone-voiced", "-server", "http://s", "-token", "t", "-vad-threshold", "0"}, wantErr: "vad-threshold must be > 0"}, + {name: "invalid backend", args: []string{"dialtone-voiced", "-server", "http://s", "-token", "t", "-ptt-backend", "bad"}, wantErr: "invalid ptt-backend"}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + os.Args = append([]string(nil), tt.args...) + err := run() + if err == nil || err.Error() == "" { + t.Fatalf("expected validation error containing %q", tt.wantErr) + } + if tt.wantErr != "" && !strings.Contains(err.Error(), tt.wantErr) { + t.Fatalf("run() error = %q, want contains %q", err.Error(), tt.wantErr) + } + }) + } +} diff --git a/cmd/voiced/ptt_test.go b/cmd/voiced/ptt_test.go index 93c3850..e66b824 100644 --- a/cmd/voiced/ptt_test.go +++ b/cmd/voiced/ptt_test.go @@ -134,3 +134,15 @@ func TestPTTStartupDiagnostic(t *testing.T) { t.Fatalf("expected reason field in diagnostic: %q", msg) } } + +func TestPTTControllerRun(t *testing.T) { + var nilController *pttController + if err := nilController.Run(context.Background(), func() {}, func() {}); err == nil { + t.Fatalf("expected nil controller run to fail") + } + + controller := &pttController{backend: testPTTBackend{}} + if err := controller.Run(context.Background(), func() {}, func() {}); err != nil { + t.Fatalf("expected backend run to succeed: %v", err) + } +} diff --git a/cmd/voiced/stats_test.go b/cmd/voiced/stats_test.go new file mode 100644 index 0000000..cb4d969 --- /dev/null +++ b/cmd/voiced/stats_test.go @@ -0,0 +1,51 @@ +package main + +import ( + "context" + "testing" + "time" +) + +func TestVoiceStatsRecordSentAndDrop(t *testing.T) { + var nilStats *voiceStats + nilStats.RecordSent(100) + nilStats.RecordDrop() + + s := newVoiceStats() + s.RecordSent(0) + s.RecordSent(-1) + s.RecordSent(120) + s.RecordDrop() + s.RecordDrop() + + if got := s.bytesSent.Load(); got != 120 { + t.Fatalf("bytesSent = %d, want 120", got) + } + if got := s.framesSent.Load(); got != 1 { + t.Fatalf("framesSent = %d, want 1", got) + } + if got := s.framesDropped.Load(); got != 2 { + t.Fatalf("framesDropped = %d, want 2", got) + } +} + +func TestVoiceStatsLogLoopStopsOnContextCancel(t *testing.T) { + var nilStats *voiceStats + nilStats.LogLoop(context.Background()) + + s := newVoiceStats() + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + s.LogLoop(ctx) + close(done) + }() + + cancel() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatalf("LogLoop did not stop after cancel") + } +} diff --git a/cmd/voiced/ws_client_test.go b/cmd/voiced/ws_client_test.go new file mode 100644 index 0000000..1b3636c --- /dev/null +++ b/cmd/voiced/ws_client_test.go @@ -0,0 +1,129 @@ +package main + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "testing" + "time" + + "nhooyr.io/websocket" +) + +func TestVoicedWSClientConnectSendReadClose(t *testing.T) { + recv := make(chan VoiceSignal, 1) + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/ws" { + t.Errorf("expected /ws, got %s", r.URL.Path) + return + } + if got := r.Header.Get("Authorization"); got != "Bearer token" { + t.Errorf("expected auth header, got %q", got) + return + } + conn, err := websocket.Accept(w, r, nil) + if err != nil { + t.Errorf("accept websocket: %v", err) + return + } + defer conn.Close(websocket.StatusNormalClosure, "done") + + _, data, err := conn.Read(context.Background()) + if err != nil { + t.Errorf("read client signal: %v", err) + return + } + var in VoiceSignal + if err := json.Unmarshal(data, &in); err != nil { + t.Errorf("decode client signal: %v", err) + return + } + recv <- in + + _ = conn.Write(context.Background(), websocket.MessageText, []byte(`{"type":"voice_join","channel_id":"room-1"}`)) + _ = conn.Write(context.Background(), websocket.MessageText, []byte(`{"type":" "}`)) + _ = conn.Write(context.Background(), websocket.MessageText, []byte(`not-json`)) + })) + defer server.Close() + + client, err := ConnectWS(server.URL, "token") + if err != nil { + t.Fatalf("ConnectWS: %v", err) + } + defer client.Close() + + if err := client.Send(VoiceSignal{Type: "ping", ChannelID: "room-1"}); err != nil { + t.Fatalf("Send: %v", err) + } + select { + case msg := <-recv: + if msg.Type != "ping" || msg.ChannelID != "room-1" { + t.Fatalf("unexpected sent signal: %#v", msg) + } + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting for sent signal") + } + + ch := make(chan VoiceSignal, 2) + go client.ReadLoop(ch) + select { + case msg := <-ch: + if msg.Type != "voice_join" || msg.ChannelID != "room-1" { + t.Fatalf("unexpected read signal: %#v", msg) + } + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting for server signal") + } +} + +func TestVoicedWSClientSendClosedAndCloseIdempotent(t *testing.T) { + client := &WSClient{closed: true} + if err := client.Send(VoiceSignal{Type: "ping"}); err == nil { + t.Fatalf("expected Send to fail when closed") + } + + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + if err != nil { + return + } + defer conn.Close(websocket.StatusNormalClosure, "done") + })) + defer server.Close() + + realClient, err := ConnectWS(server.URL, "token") + if err != nil { + t.Fatalf("ConnectWS: %v", err) + } + realClient.Close() + realClient.Close() +} + +func TestVoicedWSClientReadLoopClosesChannelOnSocketClose(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + if err != nil { + return + } + _ = conn.Close(websocket.StatusNormalClosure, "done") + })) + defer server.Close() + + client, err := ConnectWS(server.URL, "token") + if err != nil { + t.Fatalf("ConnectWS: %v", err) + } + defer client.Close() + + ch := make(chan VoiceSignal, 1) + go client.ReadLoop(ch) + select { + case _, ok := <-ch: + if ok { + t.Fatalf("expected closed read channel") + } + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting for read loop close") + } +} diff --git a/internal/audio/playback.go b/internal/audio/playback.go index 525ac45..83adf68 100644 --- a/internal/audio/playback.go +++ b/internal/audio/playback.go @@ -17,9 +17,10 @@ type Playback struct { ctx *malgo.AllocatedContext device *malgo.Device - mu sync.Mutex - buf []int16 - maxBuf int + mu sync.Mutex + buf []int16 + maxBuf int + closeOnce sync.Once } func StartPlayback(ctx context.Context) (*Playback, error) { @@ -115,13 +116,15 @@ func (p *Playback) Close() error { if p == nil { return nil } - if p.device != nil { - p.device.Uninit() - p.device = nil - } - if p.ctx != nil { - p.ctx.Uninit() - p.ctx = nil - } + p.closeOnce.Do(func() { + if p.device != nil { + p.device.Uninit() + p.device = nil + } + if p.ctx != nil { + p.ctx.Uninit() + p.ctx = nil + } + }) return nil } diff --git a/internal/audio/playback_test.go b/internal/audio/playback_test.go new file mode 100644 index 0000000..699b20b --- /dev/null +++ b/internal/audio/playback_test.go @@ -0,0 +1,67 @@ +//go:build linux + +package audio + +import ( + "encoding/binary" + "testing" +) + +func decodeSample(t *testing.T, out []byte, idx int) int16 { + t.Helper() + return int16(binary.LittleEndian.Uint16(out[idx*2:])) +} + +func TestPlaybackWriteAndFillOutput(t *testing.T) { + p := &Playback{maxBuf: 4} + p.Write([]int16{1, 2, 3}) + p.Write([]int16{4, 5}) + + if len(p.buf) != 4 { + t.Fatalf("expected bounded playback buffer length 4, got %d", len(p.buf)) + } + if p.buf[0] != 2 || p.buf[3] != 5 { + t.Fatalf("unexpected buffer contents after overflow write: %#v", p.buf) + } + + out := make([]byte, 6) + p.fillOutput(out) + if got := decodeSample(t, out, 0); got != 2 { + t.Fatalf("sample0 = %d, want 2", got) + } + if got := decodeSample(t, out, 1); got != 3 { + t.Fatalf("sample1 = %d, want 3", got) + } + if got := decodeSample(t, out, 2); got != 4 { + t.Fatalf("sample2 = %d, want 4", got) + } + + out = make([]byte, 4) + p.fillOutput(out) + if got := decodeSample(t, out, 0); got != 5 { + t.Fatalf("sample0 = %d, want 5", got) + } + if got := decodeSample(t, out, 1); got != 0 { + t.Fatalf("sample1 = %d, want zero padding", got) + } + + if len(p.buf) != 0 { + t.Fatalf("expected buffer fully drained, got %d", len(p.buf)) + } +} + +func TestPlaybackNilAndCloseSafety(t *testing.T) { + var p *Playback + p.Write([]int16{1, 2, 3}) + p.fillOutput(make([]byte, 4)) + if err := p.Close(); err != nil { + t.Fatalf("nil playback close: %v", err) + } + + instance := &Playback{} + instance.Write(nil) + instance.fillOutput(nil) + if err := instance.Close(); err != nil { + t.Fatalf("zero playback close: %v", err) + } +} diff --git a/internal/ipc/ipc_test.go b/internal/ipc/ipc_test.go new file mode 100644 index 0000000..1992963 --- /dev/null +++ b/internal/ipc/ipc_test.go @@ -0,0 +1,32 @@ +package ipc + +import ( + "bytes" + "testing" +) + +func TestNewEncoderDecoderRoundTrip(t *testing.T) { + var buf bytes.Buffer + enc := NewEncoder(&buf) + if enc == nil { + t.Fatalf("expected non-nil encoder") + } + + want := Message{Cmd: CommandPing, Room: "room-1"} + if err := enc.Encode(want); err != nil { + t.Fatalf("encode message: %v", err) + } + + dec := NewDecoder(&buf) + if dec == nil { + t.Fatalf("expected non-nil decoder") + } + + var got Message + if err := dec.Decode(&got); err != nil { + t.Fatalf("decode message: %v", err) + } + if got.Cmd != want.Cmd || got.Room != want.Room { + t.Fatalf("unexpected round-trip payload: %#v", got) + } +} diff --git a/internal/ipc/ipc_unix_test.go b/internal/ipc/ipc_unix_test.go new file mode 100644 index 0000000..05c4ae7 --- /dev/null +++ b/internal/ipc/ipc_unix_test.go @@ -0,0 +1,71 @@ +//go:build !windows + +package ipc + +import ( + "errors" + "os" + "path/filepath" + "testing" + "time" +) + +func TestListenDialRoundTrip(t *testing.T) { + addr := filepath.Join(t.TempDir(), "ipc.sock") + ln, err := Listen(addr) + if err != nil { + t.Fatalf("Listen: %v", err) + } + defer ln.Close() + + recv := make(chan Message, 1) + go func() { + conn, err := ln.Accept() + if err != nil { + return + } + defer conn.Close() + var msg Message + if err := NewDecoder(conn).Decode(&msg); err != nil { + return + } + recv <- msg + _ = NewEncoder(conn).Encode(Message{Event: EventPong}) + }() + + conn, err := Dial(addr) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + + if err := NewEncoder(conn).Encode(Message{Cmd: CommandPing, Room: "room-1"}); err != nil { + t.Fatalf("encode to server: %v", err) + } + + select { + case got := <-recv: + if got.Cmd != CommandPing || got.Room != "room-1" { + t.Fatalf("unexpected server payload: %#v", got) + } + case <-time.After(2 * time.Second): + t.Fatalf("timeout waiting for server receive") + } + + var pong Message + if err := NewDecoder(conn).Decode(&pong); err != nil { + t.Fatalf("decode server reply: %v", err) + } + if pong.Event != EventPong { + t.Fatalf("unexpected server reply: %#v", pong) + } +} + +func TestListenDialEmptyAddr(t *testing.T) { + if _, err := Listen(""); !errors.Is(err, os.ErrInvalid) { + t.Fatalf("Listen empty addr error = %v, want %v", err, os.ErrInvalid) + } + if _, err := Dial(""); !errors.Is(err, os.ErrInvalid) { + t.Fatalf("Dial empty addr error = %v, want %v", err, os.ErrInvalid) + } +} diff --git a/internal/webrtc/manager_test.go b/internal/webrtc/manager_test.go new file mode 100644 index 0000000..c0fea91 --- /dev/null +++ b/internal/webrtc/manager_test.go @@ -0,0 +1,103 @@ +package webrtc + +import ( + "strings" + "testing" + "time" + + pionwebrtc "github.com/pion/webrtc/v4" +) + +func TestManagerEnsurePeerValidationAndReuse(t *testing.T) { + m, err := NewManager(pionwebrtc.Configuration{}, nil, nil, nil) + if err != nil { + t.Fatalf("NewManager: %v", err) + } + t.Cleanup(func() { + m.CloseAll() + }) + + if _, err := m.ensurePeer(""); err == nil || !strings.Contains(err.Error(), "peer id is required") { + t.Fatalf("expected peer id validation error, got %v", err) + } + + p1, err := m.ensurePeer("peer-1") + if err != nil { + t.Fatalf("ensurePeer create: %v", err) + } + p2, err := m.ensurePeer("peer-1") + if err != nil { + t.Fatalf("ensurePeer reuse: %v", err) + } + if p1 != p2 { + t.Fatalf("expected ensurePeer to reuse existing peer instance") + } +} + +func TestManagerOfferAnswerFlow(t *testing.T) { + offerer, err := NewManager(pionwebrtc.Configuration{}, nil, nil, nil) + if err != nil { + t.Fatalf("NewManager offerer: %v", err) + } + t.Cleanup(func() { offerer.CloseAll() }) + + answerer, err := NewManager(pionwebrtc.Configuration{}, nil, nil, nil) + if err != nil { + t.Fatalf("NewManager answerer: %v", err) + } + t.Cleanup(func() { answerer.CloseAll() }) + + offer, err := offerer.CreateOffer("peer-a") + if err != nil { + t.Fatalf("CreateOffer: %v", err) + } + if strings.TrimSpace(offer) == "" { + t.Fatalf("expected non-empty SDP offer") + } + + answer, err := answerer.HandleOffer("peer-a", offer) + if err != nil { + t.Fatalf("HandleOffer: %v", err) + } + if strings.TrimSpace(answer) == "" { + t.Fatalf("expected non-empty SDP answer") + } + + if err := offerer.HandleAnswer("peer-a", answer); err != nil { + t.Fatalf("HandleAnswer: %v", err) + } + + if err := offerer.HandleAnswer("peer-a", "invalid-sdp"); err == nil { + t.Fatalf("expected invalid answer SDP to fail") + } +} + +func TestManagerAddICECandidateWriteAndClose(t *testing.T) { + m, err := NewManager(pionwebrtc.Configuration{}, nil, nil, nil) + if err != nil { + t.Fatalf("NewManager: %v", err) + } + t.Cleanup(func() { + m.CloseAll() + }) + + if err := m.WriteSample([]byte{1, 2, 3}, 20*time.Millisecond); err != nil { + t.Fatalf("WriteSample with no peers should be nil, got %v", err) + } + + if _, err := m.CreateOffer("peer-ws"); err != nil { + t.Fatalf("CreateOffer peer-ws: %v", err) + } + + if err := m.WriteSample([]byte{1, 2, 3}, 20*time.Millisecond); err != nil { + t.Fatalf("WriteSample with peer: %v", err) + } + + if err := m.AddICECandidate("peer-ws", "not-a-valid-candidate"); err == nil { + t.Fatalf("expected invalid ICE candidate to fail") + } + + m.ClosePeer("peer-ws") + m.ClosePeer("peer-ws") + m.CloseAll() +} diff --git a/internal/ws/ws_test.go b/internal/ws/ws_test.go index 33cdd09..e848700 100644 --- a/internal/ws/ws_test.go +++ b/internal/ws/ws_test.go @@ -1724,3 +1724,110 @@ func TestHub_HandleVoiceLeave_BroadcastsPresenceToAllClients(t *testing.T) { t.Fatalf("unexpected presence event: %+v", event) } } + +func TestHub_HandleVoiceSignal_ValidationErrors(t *testing.T) { + hub := NewHub(nil, nil, nil) + + sender := &Client{send: make(chan []byte, 8), userID: "user-1"} + hub.handleVoiceSignal(sender, inboundMessage{Type: "webrtc_offer", ChannelID: ""}) + errEvent := readEvent[errorEvent](t, sender.send) + if errEvent.Type != "error" || errEvent.Code != "invalid_message" { + t.Fatalf("unexpected missing channel error event: %+v", errEvent) + } + + hub.handleVoiceSignal(sender, inboundMessage{Type: "webrtc_offer", ChannelID: "ch-1"}) + errEvent = readEvent[errorEvent](t, sender.send) + if errEvent.Type != "error" || errEvent.Code != "voice_not_joined" { + t.Fatalf("unexpected not-joined error event: %+v", errEvent) + } +} + +func TestHub_HandleVoiceSignal_BroadcastToRoomPeers(t *testing.T) { + hub := NewHub(nil, nil, nil) + + sender := &Client{send: make(chan []byte, 8), userID: "user-1"} + peerA := &Client{send: make(chan []byte, 8), userID: "user-2"} + peerB := &Client{send: make(chan []byte, 8), userID: "user-3"} + + hub.mu.Lock() + hub.voiceRooms["ch-1"] = map[*Client]struct{}{ + sender: {}, + peerA: {}, + peerB: {}, + } + hub.voiceRoom[sender] = "ch-1" + hub.voiceRoom[peerA] = "ch-1" + hub.voiceRoom[peerB] = "ch-1" + hub.mu.Unlock() + + hub.handleVoiceSignal(sender, inboundMessage{ + Type: "webrtc_offer", + ChannelID: "ch-1", + SDP: "offer-sdp", + }) + + msgA := readEvent[voiceSignalEvent](t, peerA.send) + msgB := readEvent[voiceSignalEvent](t, peerB.send) + + if msgA.Type != "webrtc_offer" || msgA.Sender != "user-1" || msgA.ChannelID != "ch-1" || msgA.SDP != "offer-sdp" { + t.Fatalf("unexpected peerA signal event: %+v", msgA) + } + if msgB.Type != "webrtc_offer" || msgB.Sender != "user-1" || msgB.ChannelID != "ch-1" || msgB.SDP != "offer-sdp" { + t.Fatalf("unexpected peerB signal event: %+v", msgB) + } + + select { + case data := <-sender.send: + t.Fatalf("sender should not receive own signal, got %s", string(data)) + default: + } +} + +func TestHub_HandleVoiceSignal_RecipientRouting(t *testing.T) { + hub := NewHub(nil, nil, nil) + + sender := &Client{send: make(chan []byte, 8), userID: "user-1"} + recipientInRoom := &Client{send: make(chan []byte, 8), userID: "user-2"} + recipientOutOfRoom := &Client{send: make(chan []byte, 8), userID: "user-2"} + unrelatedInRoom := &Client{send: make(chan []byte, 8), userID: "user-3"} + + hub.mu.Lock() + hub.voiceRooms["ch-1"] = map[*Client]struct{}{ + sender: {}, + recipientInRoom: {}, + unrelatedInRoom: {}, + } + hub.voiceRoom[sender] = "ch-1" + hub.voiceRoom[recipientInRoom] = "ch-1" + hub.voiceRoom[unrelatedInRoom] = "ch-1" + hub.voiceRooms["ch-2"] = map[*Client]struct{}{recipientOutOfRoom: {}} + hub.voiceRoom[recipientOutOfRoom] = "ch-2" + hub.byUser["user-2"] = map[*Client]struct{}{ + recipientInRoom: {}, + recipientOutOfRoom: {}, + } + hub.mu.Unlock() + + hub.handleVoiceSignal(sender, inboundMessage{ + Type: "ice_candidate", + ChannelID: "ch-1", + Recipient: "user-2", + Candidate: "candidate:1 1 udp 2122260223 127.0.0.1 40000 typ host", + }) + + msg := readEvent[voiceSignalEvent](t, recipientInRoom.send) + if msg.Type != "ice_candidate" || msg.Sender != "user-1" || msg.Recipient != "user-2" || msg.Candidate == "" { + t.Fatalf("unexpected routed signal event: %+v", msg) + } + + select { + case data := <-recipientOutOfRoom.send: + t.Fatalf("out-of-room recipient should not receive signal, got %s", string(data)) + default: + } + select { + case data := <-unrelatedInRoom.send: + t.Fatalf("unrelated in-room peer should not receive recipient-targeted signal, got %s", string(data)) + default: + } +} diff --git a/scripts/coverage.sh b/scripts/coverage.sh index eab3ee9..c39adff 100755 --- a/scripts/coverage.sh +++ b/scripts/coverage.sh @@ -47,14 +47,20 @@ while [[ $# -gt 0 ]]; do done # EXCLUDE_REGEX='internal|cmd/server' -# PKGS=$(go list ./... | grep -vE "$EXCLUDE_REGEX") -PKGS=$(go list ./...) +# mapfile -t PKGS < <(go list ./... | grep -vE "$EXCLUDE_REGEX") +mapfile -t PKGS < <(go list ./...) -# Silence go test output when we only care about zero coverage +# Run tests with coverage and preserve failure output for diagnosability. if [[ "$SHOW_ZERO_ONLY" == true || "$FAIL_ON_ZERO" == true ]]; then - go test $PKGS -cover -coverprofile=coverage.out > /dev/null + TEST_LOG=$(mktemp) + if ! go test "${PKGS[@]}" -cover -coverprofile=coverage.out >"$TEST_LOG" 2>&1; then + cat "$TEST_LOG" >&2 + rm -f "$TEST_LOG" + exit 1 + fi + rm -f "$TEST_LOG" else - go test $PKGS -cover -coverprofile=coverage.out + go test "${PKGS[@]}" -cover -coverprofile=coverage.out fi if [[ "$FAIL_ON_ZERO" == true ]]; then @@ -63,7 +69,12 @@ if [[ "$FAIL_ON_ZERO" == true ]]; then | awk '$NF=="0.0%" { found=1; print } END { exit found }' elif [[ "$SHOW_ZERO_ONLY" == true ]]; then # Just print zero-coverage functions - go tool cover -func=coverage.out | awk '$NF=="0.0%"' + ZERO_OUTPUT=$(go tool cover -func=coverage.out | awk '$NF=="0.0%"') + if [[ -n "$ZERO_OUTPUT" ]]; then + printf '%s\n' "$ZERO_OUTPUT" + else + echo "No functions at 0.0% coverage" + fi else # Full coverage output go tool cover -func=coverage.out From 10d0397828ba1a400740749ebc5c316a6ceffa86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Victor=20Anderss=C3=A9n?= Date: Fri, 13 Feb 2026 12:27:55 +0200 Subject: [PATCH 3/6] test: add error handling tests for server and client main functions --- cmd/client/main_entry_test.go | 35 +++++++++++++++++++++++++++++++++++ cmd/server/main_test.go | 32 ++++++++++++++++++++++++++++++++ cmd/voiced/main_test.go | 28 ++++++++++++++++++++++++++++ 3 files changed, 95 insertions(+) create mode 100644 cmd/client/main_entry_test.go diff --git a/cmd/client/main_entry_test.go b/cmd/client/main_entry_test.go new file mode 100644 index 0000000..d160339 --- /dev/null +++ b/cmd/client/main_entry_test.go @@ -0,0 +1,35 @@ +package main + +import ( + "bytes" + "errors" + "os" + "os/exec" + "strings" + "testing" +) + +func TestClientMainExitsOnRunError(t *testing.T) { + if os.Getenv("DIALTONE_TEST_CLIENT_MAIN_HELPER") == "1" { + os.Args = []string{"dialtone", "-voice-vad", "0"} + main() + os.Exit(0) + } + + cmd := exec.Command(os.Args[0], "-test.run=TestClientMainExitsOnRunError") + cmd.Env = append(os.Environ(), "DIALTONE_TEST_CLIENT_MAIN_HELPER=1") + var stderr bytes.Buffer + cmd.Stderr = &stderr + + err := cmd.Run() + var exitErr *exec.ExitError + if !errors.As(err, &exitErr) { + t.Fatalf("expected subprocess exit error, got %v", err) + } + if exitErr.ExitCode() != 1 { + t.Fatalf("expected exit code 1, got %d", exitErr.ExitCode()) + } + if !strings.Contains(stderr.String(), "error: voice-vad must be > 0") { + t.Fatalf("expected main stderr to include run error, got %q", stderr.String()) + } +} diff --git a/cmd/server/main_test.go b/cmd/server/main_test.go index 4844655..ee7d175 100644 --- a/cmd/server/main_test.go +++ b/cmd/server/main_test.go @@ -1,6 +1,7 @@ package main import ( + "bytes" "context" "crypto/ecdsa" "crypto/elliptic" @@ -16,6 +17,7 @@ import ( "net/http" "net/http/httptest" "os" + "os/exec" "path/filepath" "strings" "testing" @@ -238,6 +240,36 @@ func TestRun_FailsWithBadDBURL(t *testing.T) { } } +func TestServerMainExitsOnRunError(t *testing.T) { + if os.Getenv("DIALTONE_TEST_SERVER_MAIN_HELPER") == "1" { + _ = os.Unsetenv("DIALTONE_LISTEN_ADDR") + _ = os.Unsetenv("DIALTONE_DB_URL") + _ = os.Unsetenv("DIALTONE_USERNAME_PEPPER") + _ = os.Unsetenv("DIALTONE_ADMIN_TOKEN") + _ = os.Unsetenv("DIALTONE_TLS_CERT") + _ = os.Unsetenv("DIALTONE_TLS_KEY") + main() + os.Exit(0) + } + + cmd := exec.Command(os.Args[0], "-test.run=TestServerMainExitsOnRunError") + cmd.Env = append(os.Environ(), "DIALTONE_TEST_SERVER_MAIN_HELPER=1") + var stderr bytes.Buffer + cmd.Stderr = &stderr + + err := cmd.Run() + var exitErr *exec.ExitError + if !errors.As(err, &exitErr) { + t.Fatalf("expected subprocess exit error, got %v", err) + } + if exitErr.ExitCode() != 1 { + t.Fatalf("expected exit code 1, got %d", exitErr.ExitCode()) + } + if !strings.Contains(stderr.String(), "fatal: config invalid") { + t.Fatalf("expected fatal config error in stderr, got %q", stderr.String()) + } +} + // --------------------------------------------------------------------------- // serve() tests – exercise everything after config/store init // --------------------------------------------------------------------------- diff --git a/cmd/voiced/main_test.go b/cmd/voiced/main_test.go index bd87a50..120087f 100644 --- a/cmd/voiced/main_test.go +++ b/cmd/voiced/main_test.go @@ -1,7 +1,10 @@ package main import ( + "bytes" + "errors" "os" + "os/exec" "runtime" "strings" "testing" @@ -91,3 +94,28 @@ func TestRunValidationErrors(t *testing.T) { }) } } + +func TestVoicedMainExitsOnRunError(t *testing.T) { + if os.Getenv("DIALTONE_TEST_VOICED_MAIN_HELPER") == "1" { + os.Args = []string{"dialtone-voiced"} + main() + os.Exit(0) + } + + cmd := exec.Command(os.Args[0], "-test.run=TestVoicedMainExitsOnRunError") + cmd.Env = append(os.Environ(), "DIALTONE_TEST_VOICED_MAIN_HELPER=1") + var stderr bytes.Buffer + cmd.Stderr = &stderr + + err := cmd.Run() + var exitErr *exec.ExitError + if !errors.As(err, &exitErr) { + t.Fatalf("expected subprocess exit error, got %v", err) + } + if exitErr.ExitCode() != 1 { + t.Fatalf("expected exit code 1, got %d", exitErr.ExitCode()) + } + if !strings.Contains(stderr.String(), "fatal: server address is required") { + t.Fatalf("expected fatal server-address error in stderr, got %q", stderr.String()) + } +} From c1e5766243076b487b92847287e5cc17074eef37 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Victor=20Anderss=C3=A9n?= Date: Fri, 13 Feb 2026 12:38:35 +0200 Subject: [PATCH 4/6] test: add unit tests for Postgres store and migration functionality --- go.mod | 1 + go.sum | 3 + internal/storage/migrate_unit_test.go | 215 +++++++++++++++++++++++++ internal/storage/postgres_unit_test.go | 92 +++++++++++ 4 files changed, 311 insertions(+) create mode 100644 internal/storage/migrate_unit_test.go create mode 100644 internal/storage/postgres_unit_test.go diff --git a/go.mod b/go.mod index 748ae55..7f1c58e 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( require ( dario.cat/mergo v1.0.0 // indirect github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 // indirect + github.com/DATA-DOG/go-sqlmock v1.5.2 // indirect github.com/cenkalti/backoff/v4 v4.2.1 // indirect github.com/containerd/containerd v1.7.18 // indirect github.com/containerd/log v0.1.0 // indirect diff --git a/go.sum b/go.sum index 2522be4..e897f50 100644 --- a/go.sum +++ b/go.sum @@ -4,6 +4,8 @@ github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9 github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1 h1:UQHMgLO+TxOElx5B5HZ4hJQsoJ/PvUvKRhJHDQXO8P8= github.com/Azure/go-ansiterm v0.0.0-20210617225240-d185dfc1b5a1/go.mod h1:xomTg63KZ2rFqZQzSB4Vz2SUXa1BpHTVz9L5PTmPC4E= +github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU= +github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU= github.com/Microsoft/go-winio v0.6.2 h1:F2VQgta7ecxGYO8k3ZZz3RS8fVIXVxONVUPlNERoyfY= github.com/Microsoft/go-winio v0.6.2/go.mod h1:yd8OoFMLzJbo9gZq8j5qaps8bJ9aShtEA8Ipt1oGCvU= github.com/atotto/clipboard v0.1.4 h1:EH0zSVneZPSuFR11BlR9YppQTVDbh5+16AmcJi4g1z4= @@ -91,6 +93,7 @@ github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= +github.com/kisielk/sqlstruct v0.0.0-20201105191214-5f3e10d3ab46/go.mod h1:yyMNCyc/Ib3bDTKd379tNMpB/7/H5TjM2Y9QJ5THLbE= github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4= github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM= github.com/kr/pretty v0.3.0 h1:WgNl7dwNpEZ6jJ9k1snq4pZsg7DOEN8hP9Xw0Tsjwk0= diff --git a/internal/storage/migrate_unit_test.go b/internal/storage/migrate_unit_test.go new file mode 100644 index 0000000..325da5d --- /dev/null +++ b/internal/storage/migrate_unit_test.go @@ -0,0 +1,215 @@ +package storage + +import ( + "context" + "errors" + "io/fs" + "path/filepath" + "strings" + "testing" + "testing/fstest" + + "github.com/DATA-DOG/go-sqlmock" +) + +func newSQLMockDB(t *testing.T) (sqlmock.Sqlmock, func()) { + t.Helper() + db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherRegexp)) + if err != nil { + t.Fatalf("sqlmock.New: %v", err) + } + cleanup := func() { + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatalf("sqlmock expectations: %v", err) + } + _ = db.Close() + } + return mock, cleanup +} + +func newMigratorWithMock(t *testing.T, migrationFS fs.FS) (*Migrator, sqlmock.Sqlmock, func()) { + t.Helper() + db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherRegexp)) + if err != nil { + t.Fatalf("sqlmock.New: %v", err) + } + cleanup := func() { + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatalf("sqlmock expectations: %v", err) + } + _ = db.Close() + } + return NewMigrator(db, migrationFS), mock, cleanup +} + +func TestMigratorUpRequiresDB(t *testing.T) { + m := NewMigrator(nil, fstest.MapFS{}) + err := m.Up(context.Background()) + if err == nil || !strings.Contains(err.Error(), "db is required") { + t.Fatalf("expected db required error, got %v", err) + } +} + +func TestMigratorUpNoMigrations(t *testing.T) { + m, mock, cleanup := newMigratorWithMock(t, fstest.MapFS{}) + defer cleanup() + + mock.ExpectExec(`CREATE TABLE IF NOT EXISTS schema_migrations`).WillReturnResult(sqlmock.NewResult(0, 0)) + + if err := m.Up(context.Background()); err != nil { + t.Fatalf("Up() no migrations: %v", err) + } +} + +func TestMigratorUpAppliesAndRecordsCommentOnly(t *testing.T) { + migrationFS := fstest.MapFS{ + "migrations/0002_apply.sql": &fstest.MapFile{Data: []byte("CREATE TABLE demo_table (id INT);\n")}, + "migrations/0001_already.sql": &fstest.MapFile{Data: []byte("CREATE TABLE already_table (id INT);\n")}, + "migrations/0003_comment.sql": &fstest.MapFile{Data: []byte("-- comment only\n -- still comment\n")}, + } + m, mock, cleanup := newMigratorWithMock(t, migrationFS) + defer cleanup() + + mock.ExpectExec(`CREATE TABLE IF NOT EXISTS schema_migrations`).WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectQuery(`SELECT id FROM schema_migrations`).WillReturnRows( + sqlmock.NewRows([]string{"id"}).AddRow("0001_already.sql"), + ) + + mock.ExpectBegin() + mock.ExpectExec(`CREATE TABLE demo_table`).WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec(`INSERT INTO schema_migrations`).WithArgs("0002_apply.sql", sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit() + + mock.ExpectExec(`INSERT INTO schema_migrations`).WithArgs("0003_comment.sql", sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1)) + + if err := m.Up(context.Background()); err != nil { + t.Fatalf("Up() apply + comment-only record: %v", err) + } +} + +func TestMigratorUpApplyExecErrorRollsBack(t *testing.T) { + migrationFS := fstest.MapFS{ + "migrations/0001_fail.sql": &fstest.MapFile{Data: []byte("CREATE TABLE broken_table (id INT);")}, + } + m, mock, cleanup := newMigratorWithMock(t, migrationFS) + defer cleanup() + + mock.ExpectExec(`CREATE TABLE IF NOT EXISTS schema_migrations`).WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectQuery(`SELECT id FROM schema_migrations`).WillReturnRows(sqlmock.NewRows([]string{"id"})) + mock.ExpectBegin() + mock.ExpectExec(`CREATE TABLE broken_table`).WillReturnError(errors.New("exec boom")) + mock.ExpectRollback() + + err := m.Up(context.Background()) + if err == nil || !strings.Contains(err.Error(), "exec migration 0001_fail.sql") { + t.Fatalf("expected exec migration error, got %v", err) + } +} + +func TestMigratorEnsureTableAndAppliedErrors(t *testing.T) { + t.Run("ensureTable exec error", func(t *testing.T) { + m, mock, cleanup := newMigratorWithMock(t, fstest.MapFS{}) + defer cleanup() + + mock.ExpectExec(`CREATE TABLE IF NOT EXISTS schema_migrations`).WillReturnError(errors.New("create failed")) + err := m.ensureTable(context.Background()) + if err == nil || !strings.Contains(err.Error(), "create schema_migrations") { + t.Fatalf("expected ensureTable error, got %v", err) + } + }) + + t.Run("appliedMigrations query error", func(t *testing.T) { + m, mock, cleanup := newMigratorWithMock(t, fstest.MapFS{}) + defer cleanup() + + mock.ExpectQuery(`SELECT id FROM schema_migrations`).WillReturnError(errors.New("query failed")) + _, err := m.appliedMigrations(context.Background()) + if err == nil || !strings.Contains(err.Error(), "list schema_migrations") { + t.Fatalf("expected appliedMigrations query error, got %v", err) + } + }) + + t.Run("appliedMigrations scan error", func(t *testing.T) { + m, mock, cleanup := newMigratorWithMock(t, fstest.MapFS{}) + defer cleanup() + + rows := sqlmock.NewRows([]string{"id", "extra"}).AddRow("0001", "x") + mock.ExpectQuery(`SELECT id FROM schema_migrations`).WillReturnRows(rows) + _, err := m.appliedMigrations(context.Background()) + if err == nil || !strings.Contains(err.Error(), "scan schema_migrations") { + t.Fatalf("expected appliedMigrations scan error, got %v", err) + } + }) +} + +func TestMigratorApplyOneAndRecordAppliedErrors(t *testing.T) { + t.Run("applyOne record insert error rolls back", func(t *testing.T) { + m, mock, cleanup := newMigratorWithMock(t, fstest.MapFS{}) + defer cleanup() + + mock.ExpectBegin() + mock.ExpectExec(`CREATE TABLE t`).WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec(`INSERT INTO schema_migrations`).WithArgs("0001.sql", sqlmock.AnyArg()).WillReturnError(errors.New("insert failed")) + mock.ExpectRollback() + + err := m.applyOne(context.Background(), "0001.sql", "CREATE TABLE t (id INT)") + if err == nil || !strings.Contains(err.Error(), "record migration 0001.sql") { + t.Fatalf("expected applyOne record error, got %v", err) + } + }) + + t.Run("applyOne commit error", func(t *testing.T) { + m, mock, cleanup := newMigratorWithMock(t, fstest.MapFS{}) + defer cleanup() + + mock.ExpectBegin() + mock.ExpectExec(`CREATE TABLE t2`).WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectExec(`INSERT INTO schema_migrations`).WithArgs("0002.sql", sqlmock.AnyArg()).WillReturnResult(sqlmock.NewResult(0, 1)) + mock.ExpectCommit().WillReturnError(errors.New("commit failed")) + + err := m.applyOne(context.Background(), "0002.sql", "CREATE TABLE t2 (id INT)") + if err == nil || !strings.Contains(err.Error(), "commit migration 0002.sql") { + t.Fatalf("expected applyOne commit error, got %v", err) + } + }) + + t.Run("recordApplied exec error", func(t *testing.T) { + m, mock, cleanup := newMigratorWithMock(t, fstest.MapFS{}) + defer cleanup() + + mock.ExpectExec(`INSERT INTO schema_migrations`).WithArgs("0003.sql", sqlmock.AnyArg()).WillReturnError(errors.New("insert failed")) + err := m.recordApplied(context.Background(), "0003.sql") + if err == nil || !strings.Contains(err.Error(), "record migration 0003.sql") { + t.Fatalf("expected recordApplied error, got %v", err) + } + }) +} + +func TestStripLineComments(t *testing.T) { + input := strings.Join([]string{ + "-- top comment", + "CREATE TABLE x (id INT);", + " -- indented comment", + "INSERT INTO x VALUES (1);", + }, "\n") + + out := stripLineComments(input) + if strings.Contains(out, "--") { + t.Fatalf("expected line comments removed, got %q", out) + } + if !strings.Contains(out, "CREATE TABLE x") || !strings.Contains(out, "INSERT INTO x") { + t.Fatalf("expected SQL statements preserved, got %q", out) + } +} + +func TestEmbeddedMigrationsBasenameSanity(t *testing.T) { + files, err := fs.Glob(migrationsFS, "migrations/*.sql") + if err != nil { + t.Fatalf("glob embedded migrations: %v", err) + } + for _, file := range files { + if base := filepath.Base(file); base == "" || base == "." || base == "/" { + t.Fatalf("invalid migration basename for %q", file) + } + } +} diff --git a/internal/storage/postgres_unit_test.go b/internal/storage/postgres_unit_test.go new file mode 100644 index 0000000..2a26945 --- /dev/null +++ b/internal/storage/postgres_unit_test.go @@ -0,0 +1,92 @@ +package storage + +import ( + "context" + "io/fs" + "path/filepath" + "strings" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" +) + +func TestNewPostgresStoreValidationAndPingFailure(t *testing.T) { + ctx := context.Background() + + store, err := NewPostgresStore(ctx, "") + if err == nil || !strings.Contains(err.Error(), "db url is required") { + t.Fatalf("expected db url required error, got store=%v err=%v", store, err) + } + + expiredCtx, cancel := context.WithDeadline(context.Background(), time.Unix(0, 0)) + defer cancel() + store, err = NewPostgresStore(expiredCtx, "postgres://dialtone:dialtone@127.0.0.1:1/dialtone?sslmode=disable") + if err == nil || !strings.Contains(err.Error(), "ping db") { + t.Fatalf("expected ping db failure, got store=%v err=%v", store, err) + } +} + +func TestPostgresStoreMigrateAndAccessors(t *testing.T) { + db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherRegexp)) + if err != nil { + t.Fatalf("sqlmock.New: %v", err) + } + defer db.Close() + + store := &PostgresStore{ + db: db, + users: &userRepo{db: db}, + devices: &deviceRepo{db: db}, + broadcasts: &broadcastRepo{db: db}, + channels: &channelRepo{db: db}, + serverInvites: &serverInviteRepo{db: db}, + } + + files, err := fs.Glob(migrationsFS, "migrations/*.sql") + if err != nil { + t.Fatalf("glob migrations: %v", err) + } + rows := sqlmock.NewRows([]string{"id"}) + for _, file := range files { + rows.AddRow(filepath.Base(file)) + } + + mock.ExpectExec(`CREATE TABLE IF NOT EXISTS schema_migrations`).WillReturnResult(sqlmock.NewResult(0, 0)) + mock.ExpectQuery(`SELECT id FROM schema_migrations`).WillReturnRows(rows) + mock.ExpectClose() + + if err := store.Migrate(context.Background()); err != nil { + t.Fatalf("Migrate() error: %v", err) + } + if got := store.Users(); got != store.users { + t.Fatalf("Users() returned unexpected repository pointer") + } + if got := store.Devices(); got != store.devices { + t.Fatalf("Devices() returned unexpected repository pointer") + } + if got := store.Broadcasts(); got != store.broadcasts { + t.Fatalf("Broadcasts() returned unexpected repository pointer") + } + if got := store.Channels(); got != store.channels { + t.Fatalf("Channels() returned unexpected repository pointer") + } + if got := store.ServerInvites(); got != store.serverInvites { + t.Fatalf("ServerInvites() returned unexpected repository pointer") + } + if err := store.Close(context.Background()); err != nil { + t.Fatalf("Close() error: %v", err) + } + + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatalf("sqlmock expectations: %v", err) + } +} + +func TestPostgresStoreMigrateNilDB(t *testing.T) { + store := &PostgresStore{} + err := store.Migrate(context.Background()) + if err == nil || !strings.Contains(err.Error(), "db is required") { + t.Fatalf("expected nil-db migrate error, got %v", err) + } +} From 1401654d8d89654bbecf21846bfc3b2f8a2da5c3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Victor=20Anderss=C3=A9n?= Date: Fri, 13 Feb 2026 13:01:11 +0200 Subject: [PATCH 5/6] Add tests for audio capture and playback, implement safety checks on close - Introduced unit tests for audio capture and playback functionalities, ensuring proper handling of nil and zero values during close operations. - Enhanced the Capture struct with a sync.Once to ensure safe closure of audio devices. - Added tests for hotkey parsing and server invite repository, covering various scenarios including validation and success cases. - Implemented mock database interactions for user, device, broadcast, channel, and server invite repositories to validate SQL operations without a real database. --- cmd/voiced/audio_pipeline_test.go | 18 + cmd/voiced/ptt_hotkey_cgo_test.go | 95 +++ cmd/voiced/ptt_portal_linux.go | 25 + cmd/voiced/ptt_portal_linux_test.go | 135 ++- cmd/voiced/remote_track_cgo_test.go | 54 ++ internal/audio/capture.go | 21 +- internal/audio/capture_test.go | 41 + internal/audio/playback_test.go | 23 + internal/storage/postgres_repos_unit_test.go | 830 +++++++++++++++++++ 9 files changed, 1233 insertions(+), 9 deletions(-) create mode 100644 cmd/voiced/ptt_hotkey_cgo_test.go create mode 100644 cmd/voiced/remote_track_cgo_test.go create mode 100644 internal/audio/capture_test.go create mode 100644 internal/storage/postgres_repos_unit_test.go diff --git a/cmd/voiced/audio_pipeline_test.go b/cmd/voiced/audio_pipeline_test.go index 0a434ba..de802a5 100644 --- a/cmd/voiced/audio_pipeline_test.go +++ b/cmd/voiced/audio_pipeline_test.go @@ -3,6 +3,7 @@ package main import ( + "context" "testing" "time" @@ -86,3 +87,20 @@ func TestUpdateVADFromFrameAndMeterBehavior(t *testing.T) { t.Fatalf("expected VAD updates ignored while PTT binding is active") } } + +func TestRunAudioSessionCanceledContextReturns(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "", pttBackendAuto, webrtc.Configuration{}, 20, false) + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + done := make(chan error, 1) + go func() { + done <- d.runAudioSession(ctx) + }() + + select { + case <-time.After(3 * time.Second): + t.Fatal("runAudioSession did not return promptly for canceled context") + case <-done: + } +} diff --git a/cmd/voiced/ptt_hotkey_cgo_test.go b/cmd/voiced/ptt_hotkey_cgo_test.go new file mode 100644 index 0000000..8829121 --- /dev/null +++ b/cmd/voiced/ptt_hotkey_cgo_test.go @@ -0,0 +1,95 @@ +package main + +import ( + "context" + "testing" + "time" +) + +func TestParseHotkeyVariants(t *testing.T) { + tests := []struct { + name string + binding string + wantErr bool + }{ + {name: "ctrl+v", binding: "ctrl+v"}, + {name: "control+v", binding: "control+v"}, + {name: "shift+space", binding: "shift+space"}, + {name: "caps", binding: "caps"}, + {name: "missing key", binding: "ctrl", wantErr: true}, + {name: "unsupported key", binding: "alt+v", wantErr: true}, + {name: "empty", binding: "", wantErr: true}, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + mods, key, err := parseHotkey(tc.binding) + if tc.wantErr { + if err == nil { + t.Fatalf("parseHotkey(%q) expected error", tc.binding) + } + return + } + if err != nil { + t.Fatalf("parseHotkey(%q) error: %v", tc.binding, err) + } + if key == 0 { + t.Fatalf("parseHotkey(%q) returned empty key", tc.binding) + } + if tc.binding == "ctrl+v" && len(mods) == 0 { + t.Fatalf("expected ctrl modifier for %q", tc.binding) + } + }) + } +} + +func TestHotkeyCodeHelpers(t *testing.T) { + if got := hotkeyModifierFromCode(123); uint32(got) != 123 { + t.Fatalf("hotkeyModifierFromCode mismatch: got %d", got) + } + if got := hotkeyKeyFromCode(456); uint32(got) != 456 { + t.Fatalf("hotkeyKeyFromCode mismatch: got %d", got) + } + + if _, err := hotkeyModifierCtrl(); err != nil { + t.Fatalf("hotkeyModifierCtrl error: %v", err) + } + if _, err := hotkeyModifierShift(); err != nil { + t.Fatalf("hotkeyModifierShift error: %v", err) + } + if _, err := hotkeySpaceKey(); err != nil { + t.Fatalf("hotkeySpaceKey error: %v", err) + } + if _, err := hotkeyVKey(); err != nil { + t.Fatalf("hotkeyVKey error: %v", err) + } + if _, err := capsLockHotkeyKey(); err != nil { + t.Fatalf("capsLockHotkeyKey error: %v", err) + } +} + +func TestNewHotkeyPTTBackendAndRunCanceled(t *testing.T) { + b, err := newHotkeyPTTBackend("ctrl+v") + if err != nil { + t.Fatalf("newHotkeyPTTBackend error: %v", err) + } + + hkBackend, ok := b.(*hotkeyPTTBackend) + if !ok || hkBackend.hk == nil { + t.Fatalf("unexpected backend type/value: %#v", b) + } + + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + done := make(chan error, 1) + go func() { + done <- hkBackend.Run(ctx, nil, nil) + }() + + select { + case <-time.After(3 * time.Second): + t.Fatal("hotkey backend Run did not return promptly") + case <-done: + } +} diff --git a/cmd/voiced/ptt_portal_linux.go b/cmd/voiced/ptt_portal_linux.go index 3617036..a75662d 100644 --- a/cmd/voiced/ptt_portal_linux.go +++ b/cmd/voiced/ptt_portal_linux.go @@ -144,6 +144,10 @@ func (p *portalPTTBackend) Run(ctx context.Context, onDown, onUp func()) error { } func createPortalSession(ctx context.Context, conn *dbus.Conn) (dbus.ObjectPath, error) { + if conn == nil { + return "", fmt.Errorf("portal dbus connection is required") + } + handleToken := portalToken("request") options := map[string]dbus.Variant{ "handle_token": dbus.MakeVariant(handleToken), @@ -202,6 +206,13 @@ func portalSessionPath(raw dbus.Variant) (dbus.ObjectPath, error) { } func bindPortalShortcut(ctx context.Context, conn *dbus.Conn, sessionPath dbus.ObjectPath, binding string) error { + if !sessionPath.IsValid() { + return fmt.Errorf("portal session path is required") + } + if conn == nil { + return fmt.Errorf("portal dbus connection is required") + } + trigger := portalPreferredTrigger(binding) if trigger != "" { if err := bindPortalShortcutWithTrigger(ctx, conn, sessionPath, trigger); err == nil { @@ -212,6 +223,13 @@ func bindPortalShortcut(ctx context.Context, conn *dbus.Conn, sessionPath dbus.O } func bindPortalShortcutWithTrigger(ctx context.Context, conn *dbus.Conn, sessionPath dbus.ObjectPath, trigger string) error { + if !sessionPath.IsValid() { + return fmt.Errorf("portal session path is required") + } + if conn == nil { + return fmt.Errorf("portal dbus connection is required") + } + details := map[string]dbus.Variant{ "description": dbus.MakeVariant("Dialtone push-to-talk"), } @@ -256,6 +274,13 @@ func bindPortalShortcutWithTrigger(ctx context.Context, conn *dbus.Conn, session } func waitPortalResponse(ctx context.Context, conn *dbus.Conn, requestPath dbus.ObjectPath) (uint32, map[string]dbus.Variant, error) { + if !requestPath.IsValid() { + return 0, nil, fmt.Errorf("portal request path is required") + } + if conn == nil { + return 0, nil, fmt.Errorf("portal dbus connection is required") + } + signals := make(chan *dbus.Signal, 8) conn.Signal(signals) defer conn.RemoveSignal(signals) diff --git a/cmd/voiced/ptt_portal_linux_test.go b/cmd/voiced/ptt_portal_linux_test.go index 05c56dd..1b17cc8 100644 --- a/cmd/voiced/ptt_portal_linux_test.go +++ b/cmd/voiced/ptt_portal_linux_test.go @@ -2,7 +2,14 @@ package main -import "testing" +import ( + "context" + "strings" + "testing" + "time" + + "github.com/godbus/dbus/v5" +) func TestPortalSessionHandleTokenStable(t *testing.T) { if portalSessionHandleToken != "dialtone_session" { @@ -81,3 +88,129 @@ func TestPortalPreferredTrigger(t *testing.T) { }) } } + +func TestPortalSessionPath(t *testing.T) { + validPath := dbus.ObjectPath("/org/freedesktop/portal/desktop/session/1") + + t.Run("object path variant", func(t *testing.T) { + got, err := portalSessionPath(dbus.MakeVariant(validPath)) + if err != nil { + t.Fatalf("portalSessionPath(object path) error: %v", err) + } + if got != validPath { + t.Fatalf("unexpected path: %q", got) + } + }) + + t.Run("string variant", func(t *testing.T) { + got, err := portalSessionPath(dbus.MakeVariant(string(validPath))) + if err != nil { + t.Fatalf("portalSessionPath(string) error: %v", err) + } + if got != validPath { + t.Fatalf("unexpected path: %q", got) + } + }) + + t.Run("invalid string", func(t *testing.T) { + _, err := portalSessionPath(dbus.MakeVariant("not/a/valid/path")) + if err == nil { + t.Fatal("expected invalid path error") + } + }) + + t.Run("unexpected type", func(t *testing.T) { + _, err := portalSessionPath(dbus.MakeVariant(123)) + if err == nil { + t.Fatal("expected unexpected type error") + } + }) +} + +func TestPortalToken(t *testing.T) { + t1 := portalToken("request") + time.Sleep(time.Microsecond) + t2 := portalToken("request") + + if !strings.HasPrefix(t1, "dialtone_request_") { + t.Fatalf("unexpected token prefix: %q", t1) + } + if !strings.HasPrefix(t2, "dialtone_request_") { + t.Fatalf("unexpected token prefix: %q", t2) + } + if t1 == t2 { + t.Fatalf("expected unique tokens, got %q and %q", t1, t2) + } +} + +func TestClosePortalSessionNilSafe(t *testing.T) { + if err := closePortalSession(nil, ""); err != nil { + t.Fatalf("closePortalSession(nil, empty) error: %v", err) + } + if err := closePortalSession(nil, dbus.ObjectPath("/org/freedesktop/portal/desktop/session/1")); err != nil { + t.Fatalf("closePortalSession(nil, path) error: %v", err) + } +} + +func TestPortalBackendGuards(t *testing.T) { + ctx := context.Background() + validPath := dbus.ObjectPath("/org/freedesktop/portal/desktop/request/1") + + t.Run("newPortalPTTBackend requires binding", func(t *testing.T) { + backend, err := newPortalPTTBackend(" ") + if err == nil { + t.Fatalf("expected binding validation error, got backend=%v", backend) + } + }) + + t.Run("Run requires initialized conn", func(t *testing.T) { + p := &portalPTTBackend{} + err := p.Run(ctx, nil, nil) + if err == nil { + t.Fatal("expected Run() init error") + } + }) + + t.Run("createPortalSession nil conn", func(t *testing.T) { + _, err := createPortalSession(ctx, nil) + if err == nil { + t.Fatal("expected createPortalSession nil-conn error") + } + }) + + t.Run("bindPortalShortcut guards", func(t *testing.T) { + err := bindPortalShortcut(ctx, nil, validPath, "ctrl+v") + if err == nil { + t.Fatal("expected bindPortalShortcut nil-conn error") + } + + err = bindPortalShortcut(ctx, nil, dbus.ObjectPath(""), "ctrl+v") + if err == nil { + t.Fatal("expected bindPortalShortcut invalid-path error") + } + }) + + t.Run("bindPortalShortcutWithTrigger guards", func(t *testing.T) { + err := bindPortalShortcutWithTrigger(ctx, nil, validPath, "v") + if err == nil { + t.Fatal("expected bindPortalShortcutWithTrigger nil-conn error") + } + + err = bindPortalShortcutWithTrigger(ctx, nil, dbus.ObjectPath(""), "v") + if err == nil { + t.Fatal("expected bindPortalShortcutWithTrigger invalid-path error") + } + }) + + t.Run("waitPortalResponse guards", func(t *testing.T) { + _, _, err := waitPortalResponse(ctx, nil, validPath) + if err == nil { + t.Fatal("expected waitPortalResponse nil-conn error") + } + + _, _, err = waitPortalResponse(ctx, nil, dbus.ObjectPath("")) + if err == nil { + t.Fatal("expected waitPortalResponse invalid-path error") + } + }) +} diff --git a/cmd/voiced/remote_track_cgo_test.go b/cmd/voiced/remote_track_cgo_test.go new file mode 100644 index 0000000..e7682f8 --- /dev/null +++ b/cmd/voiced/remote_track_cgo_test.go @@ -0,0 +1,54 @@ +//go:build linux + +package main + +import ( + "testing" + "time" + + "github.com/pion/webrtc/v4" +) + +func TestHandleRemoteTrackNilAndNonAudioNoPanic(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "", pttBackendAuto, webrtc.Configuration{}, 20, false) + + d.handleRemoteTrack("peer-1", nil) +} + +func TestTrackRemoteSpeakingPulseAndTimeout(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "", pttBackendAuto, webrtc.Configuration{}, 20, false) + pulse := make(chan struct{}, 1) + done := make(chan struct{}) + + finished := make(chan struct{}) + go func() { + d.trackRemoteSpeaking("peer-1", pulse, done) + close(finished) + }() + + pulse <- struct{}{} + time.Sleep(50 * time.Millisecond) + + d.mu.Lock() + active := d.rem["peer-1"] + d.mu.Unlock() + if !active { + t.Fatalf("expected peer to become active after pulse") + } + + time.Sleep(remoteSpeakingTimeout + 150*time.Millisecond) + + d.mu.Lock() + active = d.rem["peer-1"] + d.mu.Unlock() + if active { + t.Fatalf("expected peer to become inactive after timeout") + } + + close(done) + select { + case <-time.After(2 * time.Second): + t.Fatal("trackRemoteSpeaking did not stop after done") + case <-finished: + } +} diff --git a/internal/audio/capture.go b/internal/audio/capture.go index 42cce31..16430bd 100644 --- a/internal/audio/capture.go +++ b/internal/audio/capture.go @@ -6,6 +6,7 @@ import ( "context" "encoding/binary" "fmt" + "sync" "github.com/gen2brain/malgo" ) @@ -18,6 +19,8 @@ const ( type Capture struct { ctx *malgo.AllocatedContext device *malgo.Device + + closeOnce sync.Once } func StartCapture(ctx context.Context) (*Capture, <-chan []int16, error) { @@ -73,13 +76,15 @@ func (c *Capture) Close() error { if c == nil { return nil } - if c.device != nil { - c.device.Uninit() - c.device = nil - } - if c.ctx != nil { - c.ctx.Uninit() - c.ctx = nil - } + c.closeOnce.Do(func() { + if c.device != nil { + c.device.Uninit() + c.device = nil + } + if c.ctx != nil { + c.ctx.Uninit() + c.ctx = nil + } + }) return nil } diff --git a/internal/audio/capture_test.go b/internal/audio/capture_test.go new file mode 100644 index 0000000..d77f05b --- /dev/null +++ b/internal/audio/capture_test.go @@ -0,0 +1,41 @@ +//go:build linux + +package audio + +import ( + "context" + "testing" + "time" +) + +func TestStartCaptureSmokeAndCloseSafety(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + capture, samples, err := StartCapture(ctx) + if err != nil { + return + } + if capture == nil || samples == nil { + t.Fatalf("StartCapture() returned capture=%v samples=%v", capture, samples) + } + + cancel() + time.Sleep(20 * time.Millisecond) + + if err := capture.Close(); err != nil { + t.Fatalf("capture.Close() error: %v", err) + } +} + +func TestCaptureCloseNilAndZeroSafety(t *testing.T) { + var capture *Capture + if err := capture.Close(); err != nil { + t.Fatalf("nil capture close: %v", err) + } + + empty := &Capture{} + if err := empty.Close(); err != nil { + t.Fatalf("empty capture close: %v", err) + } +} diff --git a/internal/audio/playback_test.go b/internal/audio/playback_test.go index 699b20b..4fb1394 100644 --- a/internal/audio/playback_test.go +++ b/internal/audio/playback_test.go @@ -3,8 +3,10 @@ package audio import ( + "context" "encoding/binary" "testing" + "time" ) func decodeSample(t *testing.T, out []byte, idx int) int16 { @@ -65,3 +67,24 @@ func TestPlaybackNilAndCloseSafety(t *testing.T) { t.Fatalf("zero playback close: %v", err) } } + +func TestStartPlaybackSmoke(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + playback, err := StartPlayback(ctx) + if err != nil { + return + } + if playback == nil { + t.Fatal("StartPlayback() returned nil playback without error") + } + + playback.Write([]int16{1, 2, 3}) + cancel() + time.Sleep(20 * time.Millisecond) + + if err := playback.Close(); err != nil { + t.Fatalf("playback.Close() error: %v", err) + } +} diff --git a/internal/storage/postgres_repos_unit_test.go b/internal/storage/postgres_repos_unit_test.go new file mode 100644 index 0000000..24e4a7f --- /dev/null +++ b/internal/storage/postgres_repos_unit_test.go @@ -0,0 +1,830 @@ +package storage + +import ( + "context" + "database/sql" + "errors" + "strings" + "testing" + "time" + + "github.com/DATA-DOG/go-sqlmock" + + "github.com/Avicted/dialtone/internal/channel" + "github.com/Avicted/dialtone/internal/device" + "github.com/Avicted/dialtone/internal/message" + "github.com/Avicted/dialtone/internal/serverinvite" + "github.com/Avicted/dialtone/internal/user" +) + +func newRepoSQLMock(t *testing.T) (*sql.DB, sqlmock.Sqlmock, func()) { + t.Helper() + db, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherRegexp)) + if err != nil { + t.Fatalf("sqlmock.New: %v", err) + } + cleanup := func() { + if err := mock.ExpectationsWereMet(); err != nil { + t.Fatalf("sqlmock expectations: %v", err) + } + _ = db.Close() + } + return db, mock, cleanup +} + +func TestUserRepoSQL(t *testing.T) { + ctx := context.Background() + now := time.Now().UTC() + + t.Run("Create validation", func(t *testing.T) { + repo := &userRepo{} + err := repo.Create(ctx, user.User{}) + if err == nil || !strings.Contains(err.Error(), "required") { + t.Fatalf("expected validation error, got %v", err) + } + }) + + t.Run("Create success", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &userRepo{db: db} + u := user.User{ID: "u1", UsernameHash: "h1", PasswordHash: "pw", CreatedAt: now} + mock.ExpectExec(`INSERT INTO users`). + WithArgs(u.ID, u.UsernameHash, u.PasswordHash, false, false, u.CreatedAt). + WillReturnResult(sqlmock.NewResult(1, 1)) + + if err := repo.Create(ctx, u); err != nil { + t.Fatalf("Create() error: %v", err) + } + }) + + t.Run("GetByID success", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &userRepo{db: db} + rows := sqlmock.NewRows([]string{"id", "username_hash", "password_hash", "is_admin", "is_trusted", "created_at"}). + AddRow("u1", "h1", "pw", true, false, now) + mock.ExpectQuery(`FROM users WHERE id = \$1`).WithArgs(user.ID("u1")).WillReturnRows(rows) + + u, err := repo.GetByID(ctx, "u1") + if err != nil { + t.Fatalf("GetByID() error: %v", err) + } + if u.UsernameHash != "h1" || u.PasswordHash != "pw" { + t.Fatalf("unexpected user: %+v", u) + } + }) + + t.Run("GetByID not found", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &userRepo{db: db} + rows := sqlmock.NewRows([]string{"id", "username_hash", "password_hash", "is_admin", "is_trusted", "created_at"}) + mock.ExpectQuery(`FROM users WHERE id = \$1`).WithArgs(user.ID("missing")).WillReturnRows(rows) + + _, err := repo.GetByID(ctx, "missing") + if !errors.Is(err, ErrNotFound) { + t.Fatalf("expected ErrNotFound, got %v", err) + } + }) + + t.Run("GetByUsernameHash success", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &userRepo{db: db} + rows := sqlmock.NewRows([]string{"id", "username_hash", "password_hash", "is_admin", "is_trusted", "created_at"}). + AddRow("u1", "h1", "pw", false, true, now) + mock.ExpectQuery(`FROM users WHERE username_hash = \$1`).WithArgs("h1").WillReturnRows(rows) + + u, err := repo.GetByUsernameHash(ctx, "h1") + if err != nil { + t.Fatalf("GetByUsernameHash() error: %v", err) + } + if u.ID != "u1" { + t.Fatalf("unexpected user: %+v", u) + } + }) + + t.Run("GetByUsernameHash not found", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &userRepo{db: db} + rows := sqlmock.NewRows([]string{"id", "username_hash", "password_hash", "is_admin", "is_trusted", "created_at"}) + mock.ExpectQuery(`FROM users WHERE username_hash = \$1`).WithArgs("missing").WillReturnRows(rows) + + _, err := repo.GetByUsernameHash(ctx, "missing") + if !errors.Is(err, ErrNotFound) { + t.Fatalf("expected ErrNotFound, got %v", err) + } + }) + + t.Run("Count success", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &userRepo{db: db} + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM users`).WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(3)) + + count, err := repo.Count(ctx) + if err != nil { + t.Fatalf("Count() error: %v", err) + } + if count != 3 { + t.Fatalf("count=%d, want 3", count) + } + }) + + t.Run("UpsertProfile validation", func(t *testing.T) { + repo := &userRepo{} + err := repo.UpsertProfile(ctx, user.Profile{}) + if err == nil || !strings.Contains(err.Error(), "required") { + t.Fatalf("expected validation error, got %v", err) + } + }) + + t.Run("UpsertProfile success", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &userRepo{db: db} + profile := user.Profile{UserID: "u1", NameEnc: "name", UpdatedAt: now} + mock.ExpectExec(`INSERT INTO user_profiles`).WithArgs(profile.UserID, profile.NameEnc, profile.UpdatedAt). + WillReturnResult(sqlmock.NewResult(1, 1)) + + if err := repo.UpsertProfile(ctx, profile); err != nil { + t.Fatalf("UpsertProfile() error: %v", err) + } + }) + + t.Run("ListProfiles success", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &userRepo{db: db} + rows := sqlmock.NewRows([]string{"user_id", "name_enc", "updated_at"}).AddRow("u1", "name", now) + mock.ExpectQuery(`SELECT user_id, name_enc, updated_at FROM user_profiles`).WillReturnRows(rows) + + profiles, err := repo.ListProfiles(ctx) + if err != nil { + t.Fatalf("ListProfiles() error: %v", err) + } + if len(profiles) != 1 || profiles[0].UserID != "u1" { + t.Fatalf("unexpected profiles: %+v", profiles) + } + }) + + t.Run("UpsertDirectoryKeyEnvelope validation", func(t *testing.T) { + repo := &userRepo{} + err := repo.UpsertDirectoryKeyEnvelope(ctx, user.DirectoryKeyEnvelope{}) + if err == nil || !strings.Contains(err.Error(), "required") { + t.Fatalf("expected validation error, got %v", err) + } + }) + + t.Run("UpsertDirectoryKeyEnvelope success", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &userRepo{db: db} + env := user.DirectoryKeyEnvelope{DeviceID: "d1", SenderDeviceID: "d2", SenderPublicKey: "pk", Envelope: "env", CreatedAt: now} + mock.ExpectExec(`INSERT INTO directory_key_envelopes`). + WithArgs(env.DeviceID, env.SenderDeviceID, env.SenderPublicKey, env.Envelope, env.CreatedAt). + WillReturnResult(sqlmock.NewResult(1, 1)) + + if err := repo.UpsertDirectoryKeyEnvelope(ctx, env); err != nil { + t.Fatalf("UpsertDirectoryKeyEnvelope() error: %v", err) + } + }) + + t.Run("GetDirectoryKeyEnvelope success", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &userRepo{db: db} + rows := sqlmock.NewRows([]string{"device_id", "sender_device_id", "sender_public_key", "envelope", "created_at"}). + AddRow("d1", "d2", "pk", "env", now) + mock.ExpectQuery(`FROM directory_key_envelopes WHERE device_id = \$1`).WithArgs("d1").WillReturnRows(rows) + + env, err := repo.GetDirectoryKeyEnvelope(ctx, "d1") + if err != nil { + t.Fatalf("GetDirectoryKeyEnvelope() error: %v", err) + } + if env.Envelope != "env" { + t.Fatalf("unexpected envelope: %+v", env) + } + }) + + t.Run("GetDirectoryKeyEnvelope not found", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &userRepo{db: db} + rows := sqlmock.NewRows([]string{"device_id", "sender_device_id", "sender_public_key", "envelope", "created_at"}) + mock.ExpectQuery(`FROM directory_key_envelopes WHERE device_id = \$1`).WithArgs("missing").WillReturnRows(rows) + + _, err := repo.GetDirectoryKeyEnvelope(ctx, "missing") + if !errors.Is(err, ErrNotFound) { + t.Fatalf("expected ErrNotFound, got %v", err) + } + }) +} + +func TestDeviceRepoSQL(t *testing.T) { + ctx := context.Background() + now := time.Now().UTC() + + t.Run("Create validation", func(t *testing.T) { + repo := &deviceRepo{} + err := repo.Create(ctx, device.Device{}) + if err == nil || !strings.Contains(err.Error(), "required") { + t.Fatalf("expected validation error, got %v", err) + } + }) + + t.Run("Create success", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &deviceRepo{db: db} + d := device.Device{ID: "d1", UserID: "u1", PublicKey: "pk", CreatedAt: now} + mock.ExpectExec(`INSERT INTO devices`).WithArgs(d.ID, d.UserID, d.PublicKey, d.CreatedAt, nil). + WillReturnResult(sqlmock.NewResult(1, 1)) + + if err := repo.Create(ctx, d); err != nil { + t.Fatalf("Create() error: %v", err) + } + }) + + t.Run("GetByID success", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &deviceRepo{db: db} + lastSeen := now.Add(-time.Minute) + rows := sqlmock.NewRows([]string{"id", "user_id", "public_key", "created_at", "last_seen_at"}). + AddRow("d1", "u1", "pk", now, lastSeen) + mock.ExpectQuery(`FROM devices WHERE id = \$1`).WithArgs(device.ID("d1")).WillReturnRows(rows) + + d, err := repo.GetByID(ctx, "d1") + if err != nil { + t.Fatalf("GetByID() error: %v", err) + } + if d.PublicKey != "pk" || d.LastSeenAt == nil { + t.Fatalf("unexpected device: %+v", d) + } + }) + + t.Run("GetByID not found", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &deviceRepo{db: db} + rows := sqlmock.NewRows([]string{"id", "user_id", "public_key", "created_at", "last_seen_at"}) + mock.ExpectQuery(`FROM devices WHERE id = \$1`).WithArgs(device.ID("missing")).WillReturnRows(rows) + + _, err := repo.GetByID(ctx, "missing") + if !errors.Is(err, device.ErrNotFound) { + t.Fatalf("expected device.ErrNotFound, got %v", err) + } + }) + + t.Run("GetByUserAndPublicKey success", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &deviceRepo{db: db} + rows := sqlmock.NewRows([]string{"id", "user_id", "public_key", "created_at", "last_seen_at"}). + AddRow("d1", "u1", "pk", now, nil) + mock.ExpectQuery(`FROM devices WHERE user_id = \$1 AND public_key = \$2`).WithArgs(user.ID("u1"), "pk").WillReturnRows(rows) + + d, err := repo.GetByUserAndPublicKey(ctx, "u1", "pk") + if err != nil { + t.Fatalf("GetByUserAndPublicKey() error: %v", err) + } + if d.ID != "d1" { + t.Fatalf("unexpected device: %+v", d) + } + }) + + t.Run("GetByUserAndPublicKey not found", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &deviceRepo{db: db} + rows := sqlmock.NewRows([]string{"id", "user_id", "public_key", "created_at", "last_seen_at"}) + mock.ExpectQuery(`FROM devices WHERE user_id = \$1 AND public_key = \$2`).WithArgs(user.ID("u1"), "missing").WillReturnRows(rows) + + _, err := repo.GetByUserAndPublicKey(ctx, "u1", "missing") + if !errors.Is(err, device.ErrNotFound) { + t.Fatalf("expected device.ErrNotFound, got %v", err) + } + }) + + t.Run("ListByUser filters empty public key", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &deviceRepo{db: db} + rows := sqlmock.NewRows([]string{"id", "user_id", "public_key", "created_at", "last_seen_at"}). + AddRow("d1", "u1", "pk", now, nil). + AddRow("d2", "u1", nil, now, nil) + mock.ExpectQuery(`FROM devices WHERE user_id = \$1 ORDER BY created_at`).WithArgs(user.ID("u1")).WillReturnRows(rows) + + devices, err := repo.ListByUser(ctx, "u1") + if err != nil { + t.Fatalf("ListByUser() error: %v", err) + } + if len(devices) != 1 || devices[0].ID != "d1" { + t.Fatalf("unexpected devices: %+v", devices) + } + }) + + t.Run("ListAll success", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &deviceRepo{db: db} + rows := sqlmock.NewRows([]string{"id", "user_id", "public_key", "created_at", "last_seen_at"}). + AddRow("d1", "u1", "pk", now, nil) + mock.ExpectQuery(`FROM devices ORDER BY created_at`).WillReturnRows(rows) + + devices, err := repo.ListAll(ctx) + if err != nil { + t.Fatalf("ListAll() error: %v", err) + } + if len(devices) != 1 { + t.Fatalf("unexpected devices len=%d", len(devices)) + } + }) + + t.Run("UpdateLastSeen validation", func(t *testing.T) { + repo := &deviceRepo{} + err := repo.UpdateLastSeen(ctx, "", time.Time{}) + if err == nil || !strings.Contains(err.Error(), "required") { + t.Fatalf("expected validation error, got %v", err) + } + }) + + t.Run("UpdateLastSeen not found", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &deviceRepo{db: db} + mock.ExpectExec(`UPDATE devices SET last_seen_at = \$2 WHERE id = \$1`).WithArgs(device.ID("d1"), now). + WillReturnResult(sqlmock.NewResult(0, 0)) + + err := repo.UpdateLastSeen(ctx, "d1", now) + if !errors.Is(err, device.ErrNotFound) { + t.Fatalf("expected device.ErrNotFound, got %v", err) + } + }) + + t.Run("UpdateLastSeen success", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &deviceRepo{db: db} + mock.ExpectExec(`UPDATE devices SET last_seen_at = \$2 WHERE id = \$1`).WithArgs(device.ID("d1"), now). + WillReturnResult(sqlmock.NewResult(0, 1)) + + if err := repo.UpdateLastSeen(ctx, "d1", now); err != nil { + t.Fatalf("UpdateLastSeen() error: %v", err) + } + }) +} + +func TestBroadcastRepoSQL(t *testing.T) { + ctx := context.Background() + now := time.Now().UTC() + + t.Run("Save validation", func(t *testing.T) { + repo := &broadcastRepo{} + err := repo.Save(ctx, message.BroadcastMessage{}) + if err == nil || !strings.Contains(err.Error(), "required") { + t.Fatalf("expected validation error, got %v", err) + } + }) + + t.Run("Save success", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &broadcastRepo{db: db} + msg := message.BroadcastMessage{ + ID: "b1", + SenderID: "u1", + SenderPublicKey: "pk", + SenderNameEnc: "name", + Body: "body", + Envelopes: map[string]string{"d1": "env"}, + SentAt: now, + } + mock.ExpectExec(`INSERT INTO broadcast_messages`). + WithArgs(msg.ID, msg.SenderID, msg.SenderPublicKey, msg.SenderNameEnc, msg.Body, sqlmock.AnyArg(), msg.SentAt). + WillReturnResult(sqlmock.NewResult(1, 1)) + + if err := repo.Save(ctx, msg); err != nil { + t.Fatalf("Save() error: %v", err) + } + }) + + t.Run("ListRecent validation", func(t *testing.T) { + repo := &broadcastRepo{} + msgs, err := repo.ListRecent(ctx, 0) + if err == nil || msgs != nil { + t.Fatalf("expected validation error with nil messages, got msgs=%v err=%v", msgs, err) + } + }) + + t.Run("ListRecent success and chronological order", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &broadcastRepo{db: db} + newer := now + older := now.Add(-time.Minute) + rows := sqlmock.NewRows([]string{"id", "sender_id", "sender_public_key", "sender_name_enc", "body", "key_envelopes", "sent_at"}). + AddRow("new", "u1", "pk", "n1", "b1", `{"d1":"e1"}`, newer). + AddRow("old", "u1", "pk", "n1", "b0", `{"d2":"e2"}`, older) + mock.ExpectQuery(`FROM broadcast_messages ORDER BY sent_at DESC LIMIT \$1`).WithArgs(2).WillReturnRows(rows) + + msgs, err := repo.ListRecent(ctx, 2) + if err != nil { + t.Fatalf("ListRecent() error: %v", err) + } + if len(msgs) != 2 || msgs[0].ID != "old" || msgs[1].ID != "new" { + t.Fatalf("unexpected order/messages: %+v", msgs) + } + if msgs[0].Envelopes["d2"] != "e2" { + t.Fatalf("expected decoded envelopes, got %+v", msgs[0].Envelopes) + } + }) +} + +func TestChannelRepoSQL(t *testing.T) { + ctx := context.Background() + now := time.Now().UTC() + + t.Run("CreateChannel validation", func(t *testing.T) { + repo := &channelRepo{} + err := repo.CreateChannel(ctx, channel.Channel{}) + if err == nil || !strings.Contains(err.Error(), "required") { + t.Fatalf("expected validation error, got %v", err) + } + }) + + t.Run("CreateChannel success", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &channelRepo{db: db} + ch := channel.Channel{ID: "c1", NameEnc: "name", CreatedBy: "u1", CreatedAt: now} + mock.ExpectExec(`INSERT INTO channels`).WithArgs(ch.ID, ch.NameEnc, ch.CreatedBy, ch.CreatedAt). + WillReturnResult(sqlmock.NewResult(1, 1)) + + if err := repo.CreateChannel(ctx, ch); err != nil { + t.Fatalf("CreateChannel() error: %v", err) + } + }) + + t.Run("GetChannel not found", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &channelRepo{db: db} + rows := sqlmock.NewRows([]string{"id", "name_enc", "created_by", "created_at"}) + mock.ExpectQuery(`FROM channels WHERE id = \$1`).WithArgs(channel.ID("missing")).WillReturnRows(rows) + + _, err := repo.GetChannel(ctx, "missing") + if !errors.Is(err, ErrNotFound) { + t.Fatalf("expected ErrNotFound, got %v", err) + } + }) + + t.Run("GetChannel success", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &channelRepo{db: db} + rows := sqlmock.NewRows([]string{"id", "name_enc", "created_by", "created_at"}).AddRow("c1", "name", "u1", now) + mock.ExpectQuery(`FROM channels WHERE id = \$1`).WithArgs(channel.ID("c1")).WillReturnRows(rows) + + ch, err := repo.GetChannel(ctx, "c1") + if err != nil { + t.Fatalf("GetChannel() error: %v", err) + } + if ch.ID != "c1" { + t.Fatalf("unexpected channel: %+v", ch) + } + }) + + t.Run("ListChannels success", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &channelRepo{db: db} + rows := sqlmock.NewRows([]string{"id", "name_enc", "created_by", "created_at"}).AddRow("c1", "n1", "u1", now) + mock.ExpectQuery(`FROM channels ORDER BY created_at DESC`).WillReturnRows(rows) + + channels, err := repo.ListChannels(ctx) + if err != nil { + t.Fatalf("ListChannels() error: %v", err) + } + if len(channels) != 1 { + t.Fatalf("unexpected channels len=%d", len(channels)) + } + }) + + t.Run("DeleteChannel validation", func(t *testing.T) { + repo := &channelRepo{} + err := repo.DeleteChannel(ctx, "") + if err == nil || !strings.Contains(err.Error(), "required") { + t.Fatalf("expected validation error, got %v", err) + } + }) + + t.Run("DeleteChannel not found", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &channelRepo{db: db} + mock.ExpectExec(`DELETE FROM channels WHERE id = \$1`).WithArgs(channel.ID("c1")). + WillReturnResult(sqlmock.NewResult(0, 0)) + + err := repo.DeleteChannel(ctx, "c1") + if !errors.Is(err, ErrNotFound) { + t.Fatalf("expected ErrNotFound, got %v", err) + } + }) + + t.Run("DeleteChannel success", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &channelRepo{db: db} + mock.ExpectExec(`DELETE FROM channels WHERE id = \$1`).WithArgs(channel.ID("c1")). + WillReturnResult(sqlmock.NewResult(0, 1)) + + if err := repo.DeleteChannel(ctx, "c1"); err != nil { + t.Fatalf("DeleteChannel() error: %v", err) + } + }) + + t.Run("UpdateChannelName validation", func(t *testing.T) { + repo := &channelRepo{} + err := repo.UpdateChannelName(ctx, "", "") + if err == nil || !strings.Contains(err.Error(), "required") { + t.Fatalf("expected validation error, got %v", err) + } + }) + + t.Run("UpdateChannelName not found", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &channelRepo{db: db} + mock.ExpectExec(`UPDATE channels SET name_enc = \$2 WHERE id = \$1`).WithArgs(channel.ID("c1"), "name"). + WillReturnResult(sqlmock.NewResult(0, 0)) + + err := repo.UpdateChannelName(ctx, "c1", "name") + if !errors.Is(err, ErrNotFound) { + t.Fatalf("expected ErrNotFound, got %v", err) + } + }) + + t.Run("UpdateChannelName success", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &channelRepo{db: db} + mock.ExpectExec(`UPDATE channels SET name_enc = \$2 WHERE id = \$1`).WithArgs(channel.ID("c1"), "name"). + WillReturnResult(sqlmock.NewResult(0, 1)) + + if err := repo.UpdateChannelName(ctx, "c1", "name"); err != nil { + t.Fatalf("UpdateChannelName() error: %v", err) + } + }) + + t.Run("SaveMessage validation", func(t *testing.T) { + repo := &channelRepo{} + err := repo.SaveMessage(ctx, channel.Message{}) + if err == nil || !strings.Contains(err.Error(), "required") { + t.Fatalf("expected validation error, got %v", err) + } + }) + + t.Run("SaveMessage success", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &channelRepo{db: db} + msg := channel.Message{ID: "m1", ChannelID: "c1", SenderID: "u1", SenderNameEnc: "name", Body: "body", SentAt: now} + mock.ExpectExec(`INSERT INTO channel_messages`).WithArgs(msg.ID, msg.ChannelID, msg.SenderID, msg.SenderNameEnc, msg.Body, msg.SentAt). + WillReturnResult(sqlmock.NewResult(1, 1)) + + if err := repo.SaveMessage(ctx, msg); err != nil { + t.Fatalf("SaveMessage() error: %v", err) + } + }) + + t.Run("ListRecentMessages validation", func(t *testing.T) { + repo := &channelRepo{} + msgs, err := repo.ListRecentMessages(ctx, "", 0) + if err == nil || msgs != nil { + t.Fatalf("expected validation error with nil messages, got msgs=%v err=%v", msgs, err) + } + }) + + t.Run("ListRecentMessages success and chronological order", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &channelRepo{db: db} + newer := now + older := now.Add(-time.Minute) + rows := sqlmock.NewRows([]string{"id", "channel_id", "sender_id", "sender_name_enc", "body", "sent_at"}). + AddRow("new", "c1", "u1", "n1", "b1", newer). + AddRow("old", "c1", "u1", "n1", "b0", older) + mock.ExpectQuery(`FROM channel_messages WHERE channel_id = \$1 ORDER BY sent_at DESC LIMIT \$2`).WithArgs(channel.ID("c1"), 2). + WillReturnRows(rows) + + msgs, err := repo.ListRecentMessages(ctx, "c1", 2) + if err != nil { + t.Fatalf("ListRecentMessages() error: %v", err) + } + if len(msgs) != 2 || msgs[0].ID != "old" || msgs[1].ID != "new" { + t.Fatalf("unexpected order/messages: %+v", msgs) + } + }) + + t.Run("UpsertKeyEnvelope validation", func(t *testing.T) { + repo := &channelRepo{} + err := repo.UpsertKeyEnvelope(ctx, channel.KeyEnvelope{}) + if err == nil || !strings.Contains(err.Error(), "required") { + t.Fatalf("expected validation error, got %v", err) + } + }) + + t.Run("UpsertKeyEnvelope success", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &channelRepo{db: db} + env := channel.KeyEnvelope{ChannelID: "c1", DeviceID: "d1", SenderDeviceID: "d2", SenderPublicKey: "pk", Envelope: "env", CreatedAt: now} + mock.ExpectExec(`INSERT INTO channel_key_envelopes`). + WithArgs(env.ChannelID, env.DeviceID, env.SenderDeviceID, env.SenderPublicKey, env.Envelope, env.CreatedAt). + WillReturnResult(sqlmock.NewResult(1, 1)) + + if err := repo.UpsertKeyEnvelope(ctx, env); err != nil { + t.Fatalf("UpsertKeyEnvelope() error: %v", err) + } + }) + + t.Run("GetKeyEnvelope validation", func(t *testing.T) { + repo := &channelRepo{} + _, err := repo.GetKeyEnvelope(ctx, "", "") + if err == nil || !strings.Contains(err.Error(), "required") { + t.Fatalf("expected validation error, got %v", err) + } + }) + + t.Run("GetKeyEnvelope not found", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &channelRepo{db: db} + rows := sqlmock.NewRows([]string{"channel_id", "device_id", "sender_device_id", "sender_public_key", "envelope", "created_at"}) + mock.ExpectQuery(`FROM channel_key_envelopes WHERE channel_id = \$1 AND device_id = \$2`).WithArgs(channel.ID("c1"), device.ID("d1")).WillReturnRows(rows) + + _, err := repo.GetKeyEnvelope(ctx, "c1", "d1") + if !errors.Is(err, ErrNotFound) { + t.Fatalf("expected ErrNotFound, got %v", err) + } + }) + + t.Run("GetKeyEnvelope success", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &channelRepo{db: db} + rows := sqlmock.NewRows([]string{"channel_id", "device_id", "sender_device_id", "sender_public_key", "envelope", "created_at"}). + AddRow("c1", "d1", "d2", "pk", "env", now) + mock.ExpectQuery(`FROM channel_key_envelopes WHERE channel_id = \$1 AND device_id = \$2`).WithArgs(channel.ID("c1"), device.ID("d1")).WillReturnRows(rows) + + env, err := repo.GetKeyEnvelope(ctx, "c1", "d1") + if err != nil { + t.Fatalf("GetKeyEnvelope() error: %v", err) + } + if env.Envelope != "env" { + t.Fatalf("unexpected envelope: %+v", env) + } + }) +} + +func TestServerInviteRepoSQL(t *testing.T) { + ctx := context.Background() + now := time.Now().UTC() + + t.Run("Create validation", func(t *testing.T) { + repo := &serverInviteRepo{} + err := repo.Create(ctx, serverinvite.Invite{}) + if err == nil || !strings.Contains(err.Error(), "required") { + t.Fatalf("expected validation error, got %v", err) + } + }) + + t.Run("Create success", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &serverInviteRepo{db: db} + invite := serverinvite.Invite{Token: "tok", CreatedAt: now, ExpiresAt: now.Add(time.Hour)} + mock.ExpectExec(`INSERT INTO server_invites`).WithArgs(invite.Token, invite.CreatedAt, invite.ExpiresAt). + WillReturnResult(sqlmock.NewResult(1, 1)) + + if err := repo.Create(ctx, invite); err != nil { + t.Fatalf("Create() error: %v", err) + } + }) + + t.Run("Consume invalid input", func(t *testing.T) { + repo := &serverInviteRepo{} + _, err := repo.Consume(ctx, "", "", time.Time{}) + if !errors.Is(err, serverinvite.ErrInvalidInput) { + t.Fatalf("expected ErrInvalidInput, got %v", err) + } + }) + + t.Run("Consume success", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &serverInviteRepo{db: db} + rows := sqlmock.NewRows([]string{"id", "created_at", "expires_at", "consumed_at", "consumed_by"}). + AddRow("tok", now.Add(-time.Minute), now.Add(time.Hour), now, nil) + mock.ExpectQuery(`UPDATE server_invites`).WithArgs("tok", now).WillReturnRows(rows) + + invite, err := repo.Consume(ctx, "tok", "u1", now) + if err != nil { + t.Fatalf("Consume() error: %v", err) + } + if invite.Token != "tok" || invite.ConsumedAt == nil { + t.Fatalf("unexpected invite: %+v", invite) + } + }) + + t.Run("Consume not found after fallback", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &serverInviteRepo{db: db} + emptyUpdate := sqlmock.NewRows([]string{"id", "created_at", "expires_at", "consumed_at", "consumed_by"}) + emptySelect := sqlmock.NewRows([]string{"id", "created_at", "expires_at", "consumed_at", "consumed_by"}) + mock.ExpectQuery(`UPDATE server_invites`).WithArgs("missing", now).WillReturnRows(emptyUpdate) + mock.ExpectQuery(`FROM server_invites WHERE id = \$1`).WithArgs("missing").WillReturnRows(emptySelect) + + _, err := repo.Consume(ctx, "missing", "u1", now) + if !errors.Is(err, serverinvite.ErrNotFound) { + t.Fatalf("expected ErrNotFound, got %v", err) + } + }) + + t.Run("Consume consumed after fallback", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &serverInviteRepo{db: db} + emptyUpdate := sqlmock.NewRows([]string{"id", "created_at", "expires_at", "consumed_at", "consumed_by"}) + consumed := sqlmock.NewRows([]string{"id", "created_at", "expires_at", "consumed_at", "consumed_by"}). + AddRow("tok", now.Add(-time.Hour), now.Add(time.Hour), now.Add(-time.Minute), "u1") + mock.ExpectQuery(`UPDATE server_invites`).WithArgs("tok", now).WillReturnRows(emptyUpdate) + mock.ExpectQuery(`FROM server_invites WHERE id = \$1`).WithArgs("tok").WillReturnRows(consumed) + + _, err := repo.Consume(ctx, "tok", "u1", now) + if !errors.Is(err, serverinvite.ErrConsumed) { + t.Fatalf("expected ErrConsumed, got %v", err) + } + }) + + t.Run("Consume expired after fallback", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &serverInviteRepo{db: db} + emptyUpdate := sqlmock.NewRows([]string{"id", "created_at", "expires_at", "consumed_at", "consumed_by"}) + expired := sqlmock.NewRows([]string{"id", "created_at", "expires_at", "consumed_at", "consumed_by"}). + AddRow("tok", now.Add(-2*time.Hour), now.Add(-time.Hour), nil, nil) + mock.ExpectQuery(`UPDATE server_invites`).WithArgs("tok", now).WillReturnRows(emptyUpdate) + mock.ExpectQuery(`FROM server_invites WHERE id = \$1`).WithArgs("tok").WillReturnRows(expired) + + _, err := repo.Consume(ctx, "tok", "u1", now) + if !errors.Is(err, serverinvite.ErrExpired) { + t.Fatalf("expected ErrExpired, got %v", err) + } + }) +} From 2da777ce455a6e866bbb9264c339bc65f5921cf0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Victor=20Anderss=C3=A9n?= Date: Fri, 13 Feb 2026 14:27:19 +0200 Subject: [PATCH 6/6] Add tests for hotkey parsing, PTT controller modes, and voice stats - Enhance hotkey parsing tests with additional cases for capslock aliases and error scenarios. - Introduce tests for PTT controller in portal and hotkey modes, including validation error handling. - Add tests for voice stats log loop to ensure counters are flushed correctly on tick. - Implement tests for WebSocket client connection handling and read loop cancellation. - Expand channel service tests to verify default service initialization. - Add error handling tests for user and device repositories in PostgreSQL. - Introduce tests for voice members management in the voice daemon, ensuring proper member handling and broadcasting. - Create tests for CPU usage logging to validate context cancellation behavior. --- .github/workflows/badge.yml | 2 +- cmd/voiced/cpu_stats_test.go | 25 ++ cmd/voiced/daemon_test.go | 121 +++++++++ cmd/voiced/ptt_hotkey_cgo_test.go | 39 +++ cmd/voiced/ptt_test.go | 205 +++++++++++++++ cmd/voiced/stats_test.go | 34 +++ cmd/voiced/voice_members_test.go | 124 ++++++++++ cmd/voiced/ws_client_test.go | 42 ++++ internal/channel/service_test.go | 22 ++ internal/storage/postgres_repos_unit_test.go | 193 +++++++++++++++ internal/user/service_test.go | 21 ++ internal/ws/ws_test.go | 248 +++++++++++++++++++ 12 files changed, 1075 insertions(+), 1 deletion(-) create mode 100644 cmd/voiced/cpu_stats_test.go create mode 100644 cmd/voiced/voice_members_test.go diff --git a/.github/workflows/badge.yml b/.github/workflows/badge.yml index f9df0ca..ae04902 100644 --- a/.github/workflows/badge.yml +++ b/.github/workflows/badge.yml @@ -47,7 +47,7 @@ jobs: fi color="red" if awk "BEGIN {exit !($total >= 80)}"; then - color="brightgreen" + color="green" elif awk "BEGIN {exit !($total >= 60)}"; then color="yellow" fi diff --git a/cmd/voiced/cpu_stats_test.go b/cmd/voiced/cpu_stats_test.go new file mode 100644 index 0000000..f86e540 --- /dev/null +++ b/cmd/voiced/cpu_stats_test.go @@ -0,0 +1,25 @@ +package main + +import ( + "context" + "testing" + "time" +) + +func TestLogCPUUsageStopsOnContextCancel(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + + go func() { + logCPUUsage(ctx) + close(done) + }() + + cancel() + + select { + case <-done: + case <-time.After(3 * time.Second): + t.Fatalf("logCPUUsage did not return after context cancellation") + } +} diff --git a/cmd/voiced/daemon_test.go b/cmd/voiced/daemon_test.go index 1924eb5..0ae51e3 100644 --- a/cmd/voiced/daemon_test.go +++ b/cmd/voiced/daemon_test.go @@ -2,13 +2,18 @@ package main import ( "context" + "encoding/json" "errors" + "net/http" + "net/http/httptest" "path/filepath" + "sync/atomic" "testing" "time" "github.com/Avicted/dialtone/internal/ipc" "github.com/pion/webrtc/v4" + "nhooyr.io/websocket" ) func TestNewVoiceDaemonDefaultsVADThreshold(t *testing.T) { @@ -258,6 +263,122 @@ func TestRunWSLoopAndConnectWSHonorCanceledContext(t *testing.T) { } } +func TestConnectWSWithRetrySuccess(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + if err != nil { + t.Errorf("accept websocket: %v", err) + return + } + defer conn.Close(websocket.StatusNormalClosure, "done") + for { + if _, _, err := conn.Read(context.Background()); err != nil { + return + } + } + })) + defer server.Close() + + d := newVoiceDaemon(server.URL, "token", "", pttBackendAuto, webrtc.Configuration{}, defaultVADThreshold, false) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) + defer cancel() + + ws, err := d.connectWSWithRetry(ctx, 0) + if err != nil { + t.Fatalf("connectWSWithRetry() error: %v", err) + } + if ws == nil { + t.Fatalf("expected websocket client") + } + ws.Close() +} + +func TestRunWSLoopConnectsAndStopsOnCancel(t *testing.T) { + var joinSignals atomic.Int32 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + if err != nil { + t.Errorf("accept websocket: %v", err) + return + } + defer conn.Close(websocket.StatusNormalClosure, "done") + for { + _, data, err := conn.Read(context.Background()) + if err != nil { + return + } + var msg VoiceSignal + if json.Unmarshal(data, &msg) == nil && msg.Type == "voice_join" { + joinSignals.Add(1) + } + } + })) + defer server.Close() + + d := newVoiceDaemon(server.URL, "token", "", pttBackendAuto, webrtc.Configuration{}, defaultVADThreshold, false) + d.room = "room-1" + + ctx, cancel := context.WithCancel(context.Background()) + out := make(chan VoiceSignal, 1) + done := make(chan error, 1) + go func() { + done <- d.runWSLoop(ctx, out) + }() + + deadline := time.Now().Add(3 * time.Second) + for d.currentWS() == nil { + if time.Now().After(deadline) { + cancel() + t.Fatalf("runWSLoop did not establish websocket connection") + } + time.Sleep(10 * time.Millisecond) + } + + cancel() + select { + case err := <-done: + if err != nil { + t.Fatalf("runWSLoop() error: %v", err) + } + case <-time.After(3 * time.Second): + t.Fatalf("runWSLoop did not return after cancel") + } + + if _, ok := <-out; ok { + t.Fatalf("expected output channel to be closed") + } + if joinSignals.Load() == 0 { + t.Fatalf("expected runWSLoop to resend room join after websocket connect") + } +} + +func TestHandleWSMessageGuardPaths(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "", pttBackendAuto, webrtc.Configuration{}, defaultVADThreshold, false) + d.local = "alice" + d.room = "room-1" + d.memb = map[string]struct{}{"alice": {}} + + d.handleWSMessage(VoiceSignal{Type: "voice_join", ChannelID: "room-1", Sender: "alice"}) + d.handleWSMessage(VoiceSignal{Type: "voice_join", ChannelID: "room-1", Sender: "bob"}) + d.handleWSMessage(VoiceSignal{Type: "voice_leave", ChannelID: "room-1", Sender: "bob"}) + d.handleWSMessage(VoiceSignal{Type: "voice_roster", ChannelID: ""}) + d.handleWSMessage(VoiceSignal{Type: "voice.presence", ChannelID: "room-1", Users: []string{"bob"}}) + d.handleWSMessage(VoiceSignal{Type: "voice.presence.snapshot", ChannelID: "room-1", Users: []string{"bob"}}) + d.handleWSMessage(VoiceSignal{Type: "webrtc_offer", ChannelID: "room-1", Sender: "bob", SDP: "offer"}) + d.handleWSMessage(VoiceSignal{Type: "webrtc_answer", ChannelID: "room-1", Sender: "bob", SDP: "answer"}) + d.handleWSMessage(VoiceSignal{Type: "ice_candidate", ChannelID: "room-1", Sender: "bob", Candidate: "cand"}) + d.handleWSMessage(VoiceSignal{Type: "unknown", ChannelID: "room-1", Sender: "bob"}) + + d.mu.Lock() + _, hasAlice := d.memb["alice"] + membersLen := len(d.memb) + d.mu.Unlock() + + if !hasAlice || membersLen != 1 { + t.Fatalf("expected guard-path messages to keep members unchanged, got alice=%v len=%d", hasAlice, membersLen) + } +} + func TestRunPTTInvalidBackendReturnsError(t *testing.T) { d := newVoiceDaemon("http://server", "token", "ctrl+v", "invalid-backend", webrtc.Configuration{}, defaultVADThreshold, false) if err := d.runPTT(context.Background()); err == nil { diff --git a/cmd/voiced/ptt_hotkey_cgo_test.go b/cmd/voiced/ptt_hotkey_cgo_test.go index 8829121..248c254 100644 --- a/cmd/voiced/ptt_hotkey_cgo_test.go +++ b/cmd/voiced/ptt_hotkey_cgo_test.go @@ -16,7 +16,12 @@ func TestParseHotkeyVariants(t *testing.T) { {name: "control+v", binding: "control+v"}, {name: "shift+space", binding: "shift+space"}, {name: "caps", binding: "caps"}, + {name: "capslock alias", binding: "capslock"}, + {name: "caps_lock alias", binding: "caps_lock"}, + {name: "trim and case normalize", binding: " ConTRol + Shift + V "}, {name: "missing key", binding: "ctrl", wantErr: true}, + {name: "modifiers only", binding: "ctrl+shift", wantErr: true}, + {name: "empty segment", binding: "ctrl++v", wantErr: true}, {name: "unsupported key", binding: "alt+v", wantErr: true}, {name: "empty", binding: "", wantErr: true}, } @@ -43,6 +48,30 @@ func TestParseHotkeyVariants(t *testing.T) { } } +func TestParseHotkeyModifierCounts(t *testing.T) { + mods, key, err := parseHotkey("control+shift+space") + if err != nil { + t.Fatalf("parseHotkey(control+shift+space) error: %v", err) + } + if key == 0 { + t.Fatalf("expected non-zero key") + } + if len(mods) != 2 { + t.Fatalf("expected two modifiers, got %d", len(mods)) + } + + mods, key, err = parseHotkey("capslock") + if err != nil { + t.Fatalf("parseHotkey(capslock) error: %v", err) + } + if key == 0 { + t.Fatalf("expected non-zero key for capslock") + } + if len(mods) != 0 { + t.Fatalf("expected no modifiers for capslock, got %d", len(mods)) + } +} + func TestHotkeyCodeHelpers(t *testing.T) { if got := hotkeyModifierFromCode(123); uint32(got) != 123 { t.Fatalf("hotkeyModifierFromCode mismatch: got %d", got) @@ -93,3 +122,13 @@ func TestNewHotkeyPTTBackendAndRunCanceled(t *testing.T) { case <-done: } } + +func TestNewHotkeyPTTBackendValidationError(t *testing.T) { + backend, err := newHotkeyPTTBackend("alt+v") + if err == nil { + t.Fatalf("expected backend validation to fail") + } + if backend != nil { + t.Fatalf("expected nil backend on validation failure, got %#v", backend) + } +} diff --git a/cmd/voiced/ptt_test.go b/cmd/voiced/ptt_test.go index e66b824..177fcde 100644 --- a/cmd/voiced/ptt_test.go +++ b/cmd/voiced/ptt_test.go @@ -116,6 +116,211 @@ func TestNewPTTControllerIncludesStartupInfo(t *testing.T) { } } +func TestNewPTTControllerPortalMode(t *testing.T) { + portalErr := errors.New("portal unavailable") + + t.Run("success", func(t *testing.T) { + portalCalled := false + hotkeyCalled := false + prevPortal := newPortalBackend + prevHotkey := newHotkeyBackend + newPortalBackend = func(binding string) (pttBackend, error) { + portalCalled = true + if binding != "caps" { + t.Fatalf("expected binding caps, got %q", binding) + } + return testPTTBackend{}, nil + } + newHotkeyBackend = func(string) (pttBackend, error) { + hotkeyCalled = true + return testPTTBackend{}, nil + } + t.Cleanup(func() { + newPortalBackend = prevPortal + newHotkeyBackend = prevHotkey + }) + + controller, err := newPTTController("caps", pttBackendPortal) + if err != nil { + t.Fatalf("newPTTController portal mode success error: %v", err) + } + if controller == nil || controller.backend == nil { + t.Fatalf("expected configured controller") + } + if !portalCalled || hotkeyCalled { + t.Fatalf("expected only portal backend to be used") + } + if !strings.Contains(controller.startupInfo, "selected=portal") { + t.Fatalf("expected portal startup diagnostic, got %q", controller.startupInfo) + } + }) + + t.Run("failure wraps", func(t *testing.T) { + prevPortal := newPortalBackend + newPortalBackend = func(string) (pttBackend, error) { + return nil, portalErr + } + t.Cleanup(func() { + newPortalBackend = prevPortal + }) + + controller, err := newPTTController("caps", pttBackendPortal) + if err == nil { + t.Fatalf("expected portal mode failure") + } + if controller != nil { + t.Fatalf("expected nil controller on portal mode failure") + } + if !errors.Is(err, portalErr) { + t.Fatalf("expected wrapped portal error, got %v", err) + } + if !strings.Contains(err.Error(), "ptt portal backend unavailable") { + t.Fatalf("unexpected portal failure error: %v", err) + } + }) +} + +func TestNewPTTControllerHotkeyModeFailure(t *testing.T) { + hotkeyErr := errors.New("hotkey unavailable") + portalCalled := false + prevPortal := newPortalBackend + prevHotkey := newHotkeyBackend + newPortalBackend = func(string) (pttBackend, error) { + portalCalled = true + return testPTTBackend{}, nil + } + newHotkeyBackend = func(string) (pttBackend, error) { + return nil, hotkeyErr + } + t.Cleanup(func() { + newPortalBackend = prevPortal + newHotkeyBackend = prevHotkey + }) + + controller, err := newPTTController("caps", pttBackendHotkey) + if err == nil { + t.Fatalf("expected hotkey mode failure") + } + if controller != nil { + t.Fatalf("expected nil controller on hotkey mode failure") + } + if !errors.Is(err, hotkeyErr) { + t.Fatalf("expected wrapped hotkey error, got %v", err) + } + if portalCalled { + t.Fatalf("did not expect portal backend call in hotkey mode") + } +} + +func TestNewPTTControllerAutoLinuxPortalAndFallbackFailures(t *testing.T) { + t.Setenv("WAYLAND_DISPLAY", "") + t.Setenv("XDG_SESSION_TYPE", "x11") + t.Setenv("DIALTONE_PTT_WAYLAND_HOTKEY_FALLBACK", "") + + t.Run("portal success preferred", func(t *testing.T) { + portalCalled := false + hotkeyCalled := false + prevPortal := newPortalBackend + prevHotkey := newHotkeyBackend + newPortalBackend = func(string) (pttBackend, error) { + portalCalled = true + return testPTTBackend{}, nil + } + newHotkeyBackend = func(string) (pttBackend, error) { + hotkeyCalled = true + return testPTTBackend{}, nil + } + t.Cleanup(func() { + newPortalBackend = prevPortal + newHotkeyBackend = prevHotkey + }) + + controller, err := newPTTController("caps", pttBackendAuto) + if err != nil { + t.Fatalf("auto mode portal success error: %v", err) + } + if controller == nil || controller.backend == nil { + t.Fatalf("expected configured controller") + } + if !portalCalled || hotkeyCalled { + t.Fatalf("expected portal preferred without hotkey fallback") + } + if !strings.Contains(controller.startupInfo, "selected=portal") { + t.Fatalf("expected portal startup diagnostic, got %q", controller.startupInfo) + } + }) + + t.Run("both backends fail", func(t *testing.T) { + portalErr := errors.New("portal unavailable") + hotkeyErr := errors.New("hotkey unavailable") + portalCalled := false + hotkeyCalled := false + prevPortal := newPortalBackend + prevHotkey := newHotkeyBackend + newPortalBackend = func(string) (pttBackend, error) { + portalCalled = true + return nil, portalErr + } + newHotkeyBackend = func(string) (pttBackend, error) { + hotkeyCalled = true + return nil, hotkeyErr + } + t.Cleanup(func() { + newPortalBackend = prevPortal + newHotkeyBackend = prevHotkey + }) + + controller, err := newPTTController("caps", pttBackendAuto) + if err == nil { + t.Fatalf("expected auto mode failure when both backends fail") + } + if controller != nil { + t.Fatalf("expected nil controller when both backends fail") + } + if !errors.Is(err, hotkeyErr) { + t.Fatalf("expected hotkey failure to be returned, got %v", err) + } + if !portalCalled || !hotkeyCalled { + t.Fatalf("expected both portal and hotkey backends attempted") + } + }) +} + +func TestNormalizePTTBackendMode(t *testing.T) { + mode, err := normalizePTTBackendMode("") + if err != nil || mode != pttBackendAuto { + t.Fatalf("normalize empty mode = %q, err=%v; want %q", mode, err, pttBackendAuto) + } + + mode, err = normalizePTTBackendMode(" HOTKEY ") + if err != nil || mode != pttBackendHotkey { + t.Fatalf("normalize HOTKEY mode = %q, err=%v; want %q", mode, err, pttBackendHotkey) + } + + if _, err = normalizePTTBackendMode("invalid"); err == nil { + t.Fatalf("expected invalid backend mode to fail") + } +} + +func TestIsWaylandSessionLinuxEnvCombos(t *testing.T) { + t.Setenv("WAYLAND_DISPLAY", "") + t.Setenv("XDG_SESSION_TYPE", "x11") + if isWaylandSession() { + t.Fatalf("expected non-wayland session when env indicates x11") + } + + t.Setenv("XDG_SESSION_TYPE", "wayland") + if !isWaylandSession() { + t.Fatalf("expected wayland session from XDG_SESSION_TYPE") + } + + t.Setenv("WAYLAND_DISPLAY", "wayland-1") + t.Setenv("XDG_SESSION_TYPE", "") + if !isWaylandSession() { + t.Fatalf("expected wayland session from WAYLAND_DISPLAY") + } +} + func TestPTTStartupDiagnostic(t *testing.T) { msg := pttStartupDiagnostic(pttBackendAuto, true, "unavailable", "none", errors.New("portal unavailable")) if !strings.Contains(msg, "mode=auto") { diff --git a/cmd/voiced/stats_test.go b/cmd/voiced/stats_test.go index cb4d969..841d1c6 100644 --- a/cmd/voiced/stats_test.go +++ b/cmd/voiced/stats_test.go @@ -49,3 +49,37 @@ func TestVoiceStatsLogLoopStopsOnContextCancel(t *testing.T) { t.Fatalf("LogLoop did not stop after cancel") } } + +func TestVoiceStatsLogLoopFlushesCountersOnTick(t *testing.T) { + s := newVoiceStats() + s.RecordSent(1600) + s.RecordDrop() + + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + s.LogLoop(ctx) + close(done) + }() + + deadline := time.Now().Add(12 * time.Second) + for time.Now().Before(deadline) { + if s.bytesSent.Load() == 0 && s.framesSent.Load() == 0 && s.framesDropped.Load() == 0 { + break + } + time.Sleep(50 * time.Millisecond) + } + + if s.bytesSent.Load() != 0 || s.framesSent.Load() != 0 || s.framesDropped.Load() != 0 { + cancel() + <-done + t.Fatalf("expected counters to be reset after log loop tick") + } + + cancel() + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatalf("LogLoop did not stop after cancel") + } +} diff --git a/cmd/voiced/voice_members_test.go b/cmd/voiced/voice_members_test.go new file mode 100644 index 0000000..f8df021 --- /dev/null +++ b/cmd/voiced/voice_members_test.go @@ -0,0 +1,124 @@ +package main + +import ( + "reflect" + "testing" + + "github.com/Avicted/dialtone/internal/ipc" + "github.com/pion/webrtc/v4" +) + +func TestVoiceMembersPayloadLocked(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "", pttBackendAuto, webrtc.Configuration{}, defaultVADThreshold, false) + d.memb = map[string]struct{}{"bob": {}, "alice": {}} + + if _, _, ok := d.voiceMembersPayloadLocked(); ok { + t.Fatalf("expected no payload when ipc server is unavailable") + } + + d.ipc = &ipcServer{} + if _, _, ok := d.voiceMembersPayloadLocked(); ok { + t.Fatalf("expected no payload when room is empty") + } + + d.room = "room-1" + server, payload, ok := d.voiceMembersPayloadLocked() + if !ok { + t.Fatalf("expected payload to be produced") + } + if server != d.ipc { + t.Fatalf("expected payload server to match daemon ipc server") + } + if payload.Event != ipc.EventVoiceMembers || payload.Room != "room-1" { + t.Fatalf("unexpected payload metadata: %+v", payload) + } + if want := []string{"alice", "bob"}; !reflect.DeepEqual(payload.Users, want) { + t.Fatalf("unexpected sorted users: got=%v want=%v", payload.Users, want) + } +} + +func TestResetVoiceMembersForCurrentRoom(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "", pttBackendAuto, webrtc.Configuration{}, defaultVADThreshold, false) + d.local = " alice " + d.memb = map[string]struct{}{"bob": {}} + + d.resetVoiceMembersForCurrentRoom() + d.mu.Lock() + emptyRoomMembers := len(d.memb) + d.mu.Unlock() + if emptyRoomMembers != 0 { + t.Fatalf("expected no members when room is empty, got %d", emptyRoomMembers) + } + + d.room = "room-1" + d.memb["bob"] = struct{}{} + d.resetVoiceMembersForCurrentRoom() + + d.mu.Lock() + _, hasAlice := d.memb["alice"] + _, hasBob := d.memb["bob"] + membersLen := len(d.memb) + d.mu.Unlock() + + if !hasAlice || hasBob || membersLen != 1 { + t.Fatalf("expected only local member after reset, got alice=%v bob=%v len=%d", hasAlice, hasBob, membersLen) + } +} + +func TestSetVoiceMembersForRoomGuardsAndPopulate(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "", pttBackendAuto, webrtc.Configuration{}, defaultVADThreshold, false) + d.room = "room-1" + d.local = "alice" + d.memb = map[string]struct{}{"stale": {}} + + d.setVoiceMembersForRoom(" ", []string{"bob"}) + d.setVoiceMembersForRoom("other-room", []string{"bob"}) + + d.mu.Lock() + _, stillStale := d.memb["stale"] + d.mu.Unlock() + if !stillStale { + t.Fatalf("expected stale members unchanged for ignored rooms") + } + + d.setVoiceMembersForRoom("room-1", []string{" bob ", "", "bob", "carol"}) + d.mu.Lock() + _, hasAlice := d.memb["alice"] + _, hasBob := d.memb["bob"] + _, hasCarol := d.memb["carol"] + _, hasStale := d.memb["stale"] + membersLen := len(d.memb) + d.mu.Unlock() + + if !hasAlice || !hasBob || !hasCarol || hasStale || membersLen != 3 { + t.Fatalf("unexpected members after roster apply: alice=%v bob=%v carol=%v stale=%v len=%d", hasAlice, hasBob, hasCarol, hasStale, membersLen) + } +} + +func TestBroadcastVoiceMembersGuards(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "", pttBackendAuto, webrtc.Configuration{}, defaultVADThreshold, false) + payload := ipc.Message{Event: ipc.EventVoiceMembers, Room: "room-1", Users: []string{"alice"}} + + d.broadcastVoiceMembers(nil, payload, true) + d.broadcastVoiceMembers(&ipcServer{}, ipc.Message{}, true) + d.broadcastVoiceMembers(&ipcServer{}, payload, false) + d.broadcastVoiceMembers(&ipcServer{}, payload, true) +} + +func TestAddRemoveVoiceMemberGuardPaths(t *testing.T) { + d := newVoiceDaemon("http://server", "token", "", pttBackendAuto, webrtc.Configuration{}, defaultVADThreshold, false) + + d.addVoiceMember("bob") + d.removeVoiceMember("bob") + + d.room = "room-1" + d.removeVoiceMember("bob") + + d.mu.Lock() + membersLen := len(d.memb) + d.mu.Unlock() + + if membersLen != 0 { + t.Fatalf("expected no members after guard-path operations, got %d", membersLen) + } +} diff --git a/cmd/voiced/ws_client_test.go b/cmd/voiced/ws_client_test.go index 1b3636c..a7965d4 100644 --- a/cmd/voiced/ws_client_test.go +++ b/cmd/voiced/ws_client_test.go @@ -127,3 +127,45 @@ func TestVoicedWSClientReadLoopClosesChannelOnSocketClose(t *testing.T) { t.Fatalf("timeout waiting for read loop close") } } + +func TestVoicedWSClientConnectInvalidURL(t *testing.T) { + client, err := ConnectWS("://bad-url", "token") + if err == nil || client != nil { + t.Fatalf("expected malformed URL dial error, got client=%v err=%v", client, err) + } +} + +func TestVoicedWSClientReadLoopStopsOnCancelWhenChannelBlocked(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + conn, err := websocket.Accept(w, r, nil) + if err != nil { + return + } + defer conn.Close(websocket.StatusNormalClosure, "done") + _ = conn.Write(context.Background(), websocket.MessageText, []byte(`{"type":"voice_join","channel_id":"room-1"}`)) + <-r.Context().Done() + })) + defer server.Close() + + client, err := ConnectWS(server.URL, "token") + if err != nil { + t.Fatalf("ConnectWS: %v", err) + } + defer client.Close() + + ch := make(chan VoiceSignal) + done := make(chan struct{}) + go func() { + client.ReadLoop(ch) + close(done) + }() + + time.Sleep(100 * time.Millisecond) + client.cancel() + + select { + case <-done: + case <-time.After(2 * time.Second): + t.Fatalf("ReadLoop did not stop after cancel") + } +} diff --git a/internal/channel/service_test.go b/internal/channel/service_test.go index 4bd804a..ffa05d5 100644 --- a/internal/channel/service_test.go +++ b/internal/channel/service_test.go @@ -108,6 +108,28 @@ func newServiceWithAdmin(isAdmin bool) (*Service, *fakeRepo) { return svc, repo } +func TestNewServiceDefaults(t *testing.T) { + repo := &fakeRepo{} + users := user.NewService(&fakeUserRepo{}, "pepper") + svc := NewService(repo, users) + + if svc == nil { + t.Fatalf("expected service instance") + } + if svc.repo != repo { + t.Fatalf("expected repository to be stored on service") + } + if svc.users != users { + t.Fatalf("expected user service to be stored on service") + } + if svc.idGen == nil || svc.now == nil { + t.Fatalf("expected idGen and now defaults to be initialized") + } + if id := svc.idGen(); id == "" { + t.Fatalf("expected non-empty default generated channel id") + } +} + func TestCreateChannel_Admin(t *testing.T) { svc, repo := newServiceWithAdmin(true) diff --git a/internal/storage/postgres_repos_unit_test.go b/internal/storage/postgres_repos_unit_test.go index 24e4a7f..0bb3e16 100644 --- a/internal/storage/postgres_repos_unit_test.go +++ b/internal/storage/postgres_repos_unit_test.go @@ -139,6 +139,19 @@ func TestUserRepoSQL(t *testing.T) { } }) + t.Run("Count query error", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &userRepo{db: db} + mock.ExpectQuery(`SELECT COUNT\(\*\) FROM users`).WillReturnError(errors.New("boom")) + + _, err := repo.Count(ctx) + if err == nil || !strings.Contains(err.Error(), "count users") { + t.Fatalf("expected wrapped count users error, got %v", err) + } + }) + t.Run("UpsertProfile validation", func(t *testing.T) { repo := &userRepo{} err := repo.UpsertProfile(ctx, user.Profile{}) @@ -178,6 +191,35 @@ func TestUserRepoSQL(t *testing.T) { } }) + t.Run("ListProfiles query error", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &userRepo{db: db} + mock.ExpectQuery(`SELECT user_id, name_enc, updated_at FROM user_profiles`).WillReturnError(errors.New("boom")) + + _, err := repo.ListProfiles(ctx) + if err == nil || !strings.Contains(err.Error(), "list user profiles") { + t.Fatalf("expected wrapped list user profiles error, got %v", err) + } + }) + + t.Run("ListProfiles iterate error", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &userRepo{db: db} + rows := sqlmock.NewRows([]string{"user_id", "name_enc", "updated_at"}). + AddRow("u1", "name", now). + RowError(0, errors.New("boom")) + mock.ExpectQuery(`SELECT user_id, name_enc, updated_at FROM user_profiles`).WillReturnRows(rows) + + _, err := repo.ListProfiles(ctx) + if err == nil || !strings.Contains(err.Error(), "iterate user profiles") { + t.Fatalf("expected wrapped iterate user profiles error, got %v", err) + } + }) + t.Run("UpsertDirectoryKeyEnvelope validation", func(t *testing.T) { repo := &userRepo{} err := repo.UpsertDirectoryKeyEnvelope(ctx, user.DirectoryKeyEnvelope{}) @@ -186,6 +228,15 @@ func TestUserRepoSQL(t *testing.T) { } }) + t.Run("UpsertDirectoryKeyEnvelope validation missing envelope payload", func(t *testing.T) { + repo := &userRepo{} + env := user.DirectoryKeyEnvelope{DeviceID: "d1", SenderDeviceID: "d2", CreatedAt: now} + err := repo.UpsertDirectoryKeyEnvelope(ctx, env) + if err == nil || !strings.Contains(err.Error(), "sender_public_key and envelope are required") { + t.Fatalf("expected payload validation error, got %v", err) + } + }) + t.Run("UpsertDirectoryKeyEnvelope success", func(t *testing.T) { db, mock, cleanup := newRepoSQLMock(t) defer cleanup() @@ -260,6 +311,21 @@ func TestDeviceRepoSQL(t *testing.T) { } }) + t.Run("Create success with last_seen_at", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &deviceRepo{db: db} + lastSeen := now.Add(-time.Minute) + d := device.Device{ID: "d1", UserID: "u1", PublicKey: "pk", CreatedAt: now, LastSeenAt: &lastSeen} + mock.ExpectExec(`INSERT INTO devices`).WithArgs(d.ID, d.UserID, d.PublicKey, d.CreatedAt, lastSeen). + WillReturnResult(sqlmock.NewResult(1, 1)) + + if err := repo.Create(ctx, d); err != nil { + t.Fatalf("Create() error: %v", err) + } + }) + t.Run("GetByID success", func(t *testing.T) { db, mock, cleanup := newRepoSQLMock(t) defer cleanup() @@ -325,6 +391,21 @@ func TestDeviceRepoSQL(t *testing.T) { } }) + t.Run("GetByUserAndPublicKey empty stored key treated as not found", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &deviceRepo{db: db} + rows := sqlmock.NewRows([]string{"id", "user_id", "public_key", "created_at", "last_seen_at"}). + AddRow("d1", "u1", nil, now, nil) + mock.ExpectQuery(`FROM devices WHERE user_id = \$1 AND public_key = \$2`).WithArgs(user.ID("u1"), "pk").WillReturnRows(rows) + + _, err := repo.GetByUserAndPublicKey(ctx, "u1", "pk") + if !errors.Is(err, device.ErrNotFound) { + t.Fatalf("expected device.ErrNotFound, got %v", err) + } + }) + t.Run("ListByUser filters empty public key", func(t *testing.T) { db, mock, cleanup := newRepoSQLMock(t) defer cleanup() @@ -344,6 +425,19 @@ func TestDeviceRepoSQL(t *testing.T) { } }) + t.Run("ListByUser query error", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &deviceRepo{db: db} + mock.ExpectQuery(`FROM devices WHERE user_id = \$1 ORDER BY created_at`).WithArgs(user.ID("u1")).WillReturnError(errors.New("boom")) + + _, err := repo.ListByUser(ctx, "u1") + if err == nil || !strings.Contains(err.Error(), "list devices by user") { + t.Fatalf("expected wrapped list devices by user error, got %v", err) + } + }) + t.Run("ListAll success", func(t *testing.T) { db, mock, cleanup := newRepoSQLMock(t) defer cleanup() @@ -362,6 +456,19 @@ func TestDeviceRepoSQL(t *testing.T) { } }) + t.Run("ListAll query error", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &deviceRepo{db: db} + mock.ExpectQuery(`FROM devices ORDER BY created_at`).WillReturnError(errors.New("boom")) + + _, err := repo.ListAll(ctx) + if err == nil || !strings.Contains(err.Error(), "list devices") { + t.Fatalf("expected wrapped list devices error, got %v", err) + } + }) + t.Run("UpdateLastSeen validation", func(t *testing.T) { repo := &deviceRepo{} err := repo.UpdateLastSeen(ctx, "", time.Time{}) @@ -396,6 +503,20 @@ func TestDeviceRepoSQL(t *testing.T) { t.Fatalf("UpdateLastSeen() error: %v", err) } }) + + t.Run("UpdateLastSeen rows affected error", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &deviceRepo{db: db} + mock.ExpectExec(`UPDATE devices SET last_seen_at = \$2 WHERE id = \$1`).WithArgs(device.ID("d1"), now). + WillReturnResult(sqlmock.NewErrorResult(errors.New("boom"))) + + err := repo.UpdateLastSeen(ctx, "d1", now) + if err == nil || !strings.Contains(err.Error(), "rows affected") { + t.Fatalf("expected wrapped rows affected error, got %v", err) + } + }) } func TestBroadcastRepoSQL(t *testing.T) { @@ -540,6 +661,19 @@ func TestChannelRepoSQL(t *testing.T) { } }) + t.Run("ListChannels query error", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &channelRepo{db: db} + mock.ExpectQuery(`FROM channels ORDER BY created_at DESC`).WillReturnError(errors.New("boom")) + + _, err := repo.ListChannels(ctx) + if err == nil || !strings.Contains(err.Error(), "list channels") { + t.Fatalf("expected wrapped list channels error, got %v", err) + } + }) + t.Run("DeleteChannel validation", func(t *testing.T) { repo := &channelRepo{} err := repo.DeleteChannel(ctx, "") @@ -575,6 +709,20 @@ func TestChannelRepoSQL(t *testing.T) { } }) + t.Run("DeleteChannel rows affected error", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &channelRepo{db: db} + mock.ExpectExec(`DELETE FROM channels WHERE id = \$1`).WithArgs(channel.ID("c1")). + WillReturnResult(sqlmock.NewErrorResult(errors.New("boom"))) + + err := repo.DeleteChannel(ctx, "c1") + if err == nil || !strings.Contains(err.Error(), "rows affected") { + t.Fatalf("expected wrapped rows affected error, got %v", err) + } + }) + t.Run("UpdateChannelName validation", func(t *testing.T) { repo := &channelRepo{} err := repo.UpdateChannelName(ctx, "", "") @@ -610,6 +758,19 @@ func TestChannelRepoSQL(t *testing.T) { } }) + t.Run("UpdateChannelName trims input", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &channelRepo{db: db} + mock.ExpectExec(`UPDATE channels SET name_enc = \$2 WHERE id = \$1`).WithArgs(channel.ID("c1"), "name"). + WillReturnResult(sqlmock.NewResult(0, 1)) + + if err := repo.UpdateChannelName(ctx, "c1", " name "); err != nil { + t.Fatalf("UpdateChannelName() error: %v", err) + } + }) + t.Run("SaveMessage validation", func(t *testing.T) { repo := &channelRepo{} err := repo.SaveMessage(ctx, channel.Message{}) @@ -827,4 +988,36 @@ func TestServerInviteRepoSQL(t *testing.T) { t.Fatalf("expected ErrExpired, got %v", err) } }) + + t.Run("Consume fallback select error", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &serverInviteRepo{db: db} + emptyUpdate := sqlmock.NewRows([]string{"id", "created_at", "expires_at", "consumed_at", "consumed_by"}) + mock.ExpectQuery(`UPDATE server_invites`).WithArgs("tok", now).WillReturnRows(emptyUpdate) + mock.ExpectQuery(`FROM server_invites WHERE id = \$1`).WithArgs("tok").WillReturnError(errors.New("boom")) + + _, err := repo.Consume(ctx, "tok", "u1", now) + if err == nil || !strings.Contains(err.Error(), "select server invite") { + t.Fatalf("expected wrapped select server invite error, got %v", err) + } + }) + + t.Run("Consume fallback still not found", func(t *testing.T) { + db, mock, cleanup := newRepoSQLMock(t) + defer cleanup() + + repo := &serverInviteRepo{db: db} + emptyUpdate := sqlmock.NewRows([]string{"id", "created_at", "expires_at", "consumed_at", "consumed_by"}) + available := sqlmock.NewRows([]string{"id", "created_at", "expires_at", "consumed_at", "consumed_by"}). + AddRow("tok", now.Add(-time.Hour), now.Add(time.Hour), nil, nil) + mock.ExpectQuery(`UPDATE server_invites`).WithArgs("tok", now).WillReturnRows(emptyUpdate) + mock.ExpectQuery(`FROM server_invites WHERE id = \$1`).WithArgs("tok").WillReturnRows(available) + + _, err := repo.Consume(ctx, "tok", "u1", now) + if !errors.Is(err, serverinvite.ErrNotFound) { + t.Fatalf("expected ErrNotFound, got %v", err) + } + }) } diff --git a/internal/user/service_test.go b/internal/user/service_test.go index a4b9b2c..c8501ab 100644 --- a/internal/user/service_test.go +++ b/internal/user/service_test.go @@ -84,6 +84,27 @@ func newTestService() (*Service, *fakeRepo) { return svc, repo } +func TestNewServiceDefaults(t *testing.T) { + repo := newFakeRepo() + svc := NewService(repo, "pepper") + + if svc == nil { + t.Fatalf("expected service instance") + } + if svc.repo != repo { + t.Fatalf("expected repository to be stored on service") + } + if svc.idGen == nil || svc.now == nil { + t.Fatalf("expected idGen and now defaults to be initialized") + } + if len(svc.pepper) == 0 { + t.Fatalf("expected pepper bytes to be initialized") + } + if id := svc.NewID(); id == "" { + t.Fatalf("expected NewID to return a non-empty value") + } +} + func TestCreate_Success(t *testing.T) { svc, _ := newTestService() diff --git a/internal/ws/ws_test.go b/internal/ws/ws_test.go index e848700..b1ba808 100644 --- a/internal/ws/ws_test.go +++ b/internal/ws/ws_test.go @@ -1831,3 +1831,251 @@ func TestHub_HandleVoiceSignal_RecipientRouting(t *testing.T) { default: } } + +func TestDecodeIncoming_ChannelMessageValidationBranches(t *testing.T) { + tooLongBody := strings.Repeat("a", maxMessageLen+1) + + tests := []struct { + name string + payload map[string]any + }{ + { + name: "missing channel id", + payload: map[string]any{ + "type": "channel.message.send", + "body": "hello", + "sender_name_enc": "enc", + }, + }, + { + name: "missing body", + payload: map[string]any{ + "type": "channel.message.send", + "channel_id": "ch-1", + "body": " ", + "sender_name_enc": "enc", + }, + }, + { + name: "message too long", + payload: map[string]any{ + "type": "channel.message.send", + "channel_id": "ch-1", + "body": tooLongBody, + "sender_name_enc": "enc", + }, + }, + { + name: "missing sender name", + payload: map[string]any{ + "type": "channel.message.send", + "channel_id": "ch-1", + "body": "hello", + "sender_name_enc": " ", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.payload) + if err != nil { + t.Fatalf("json.Marshal() error = %v", err) + } + _, err = decodeIncoming(data) + if err == nil { + t.Fatalf("expected decodeIncoming to fail for %s", tt.name) + } + }) + } +} + +func TestDecodeIncoming_VoiceSignalValidationBranches(t *testing.T) { + tests := []struct { + name string + payload map[string]any + }{ + { + name: "voice_join missing channel", + payload: map[string]any{ + "type": "voice_join", + }, + }, + { + name: "voice_leave missing channel", + payload: map[string]any{ + "type": "voice_leave", + }, + }, + { + name: "webrtc_offer missing recipient", + payload: map[string]any{ + "type": "webrtc_offer", + "channel_id": "ch-1", + "sdp": "offer", + }, + }, + { + name: "webrtc_answer missing sdp", + payload: map[string]any{ + "type": "webrtc_answer", + "channel_id": "ch-1", + "recipient": "user-2", + }, + }, + { + name: "ice_candidate missing candidate", + payload: map[string]any{ + "type": "ice_candidate", + "channel_id": "ch-1", + "recipient": "user-2", + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + data, err := json.Marshal(tt.payload) + if err != nil { + t.Fatalf("json.Marshal() error = %v", err) + } + _, err = decodeIncoming(data) + if err == nil { + t.Fatalf("expected decodeIncoming to fail for %s", tt.name) + } + }) + } +} + +func TestHub_HandleVoiceJoin_ValidationAndStorageErrors(t *testing.T) { + t.Run("missing channel id", func(t *testing.T) { + hub := NewHub(nil, nil, &fakeChannelRepo{ch: channel.Channel{ID: "ch-1"}}) + client := &Client{send: make(chan []byte, 1), userID: "user-1"} + + hub.handleVoiceJoin(context.Background(), client, inboundMessage{Type: "voice_join"}) + + event := readEvent[errorEvent](t, client.send) + if event.Code != "invalid_message" { + t.Fatalf("code = %q, want %q", event.Code, "invalid_message") + } + }) + + t.Run("channel repo unavailable", func(t *testing.T) { + hub := NewHub(nil, nil, nil) + client := &Client{send: make(chan []byte, 1), userID: "user-1"} + + hub.handleVoiceJoin(context.Background(), client, inboundMessage{Type: "voice_join", ChannelID: "ch-1"}) + + event := readEvent[errorEvent](t, client.send) + if event.Code != "server_error" { + t.Fatalf("code = %q, want %q", event.Code, "server_error") + } + }) + + t.Run("channel not found", func(t *testing.T) { + hub := NewHub(nil, nil, &fakeChannelRepo{getErr: storage.ErrNotFound}) + client := &Client{send: make(chan []byte, 1), userID: "user-1"} + + hub.handleVoiceJoin(context.Background(), client, inboundMessage{Type: "voice_join", ChannelID: "missing"}) + + event := readEvent[errorEvent](t, client.send) + if event.Code != "channel_not_found" { + t.Fatalf("code = %q, want %q", event.Code, "channel_not_found") + } + }) + + t.Run("channel lookup failure", func(t *testing.T) { + hub := NewHub(nil, nil, &fakeChannelRepo{getErr: errors.New("boom")}) + client := &Client{send: make(chan []byte, 1), userID: "user-1"} + + hub.handleVoiceJoin(context.Background(), client, inboundMessage{Type: "voice_join", ChannelID: "ch-1"}) + + event := readEvent[errorEvent](t, client.send) + if event.Code != "server_error" { + t.Fatalf("code = %q, want %q", event.Code, "server_error") + } + }) +} + +func TestHub_HandleVoiceLeave_ValidationAndPeerEvent(t *testing.T) { + t.Run("missing channel id", func(t *testing.T) { + hub := NewHub(nil, nil, nil) + client := &Client{send: make(chan []byte, 1), userID: "user-1"} + + hub.handleVoiceLeave(client, inboundMessage{Type: "voice_leave"}) + + event := readEvent[errorEvent](t, client.send) + if event.Code != "invalid_message" { + t.Fatalf("code = %q, want %q", event.Code, "invalid_message") + } + }) + + t.Run("peer receives leave signal", func(t *testing.T) { + hub := NewHub(nil, nil, nil) + leaver := &Client{send: make(chan []byte, 2), userID: "user-1"} + peer := &Client{send: make(chan []byte, 2), userID: "user-2"} + + hub.mu.Lock() + hub.voiceRooms["ch-1"] = map[*Client]struct{}{leaver: {}, peer: {}} + hub.voiceRoom[leaver] = "ch-1" + hub.voiceRoom[peer] = "ch-1" + hub.mu.Unlock() + + hub.handleVoiceLeave(leaver, inboundMessage{Type: "voice_leave", ChannelID: "ch-1"}) + + event := readEvent[voiceSignalEvent](t, peer.send) + if event.Type != "voice_leave" || event.ChannelID != "ch-1" || event.Sender != "user-1" { + t.Fatalf("unexpected peer leave event: %+v", event) + } + }) +} + +func TestHub_RemoveFromVoiceRoomLocked(t *testing.T) { + hub := NewHub(nil, nil, nil) + missing := &Client{userID: "missing"} + + hub.mu.Lock() + if peers := hub.removeFromVoiceRoomLocked(missing); peers != nil { + hub.mu.Unlock() + t.Fatalf("expected nil peers for client without room, got %d", len(peers)) + } + hub.mu.Unlock() + + owner := &Client{userID: "owner"} + hub.mu.Lock() + hub.voiceRooms["room-a"] = map[*Client]struct{}{owner: {}} + hub.voiceRoom[owner] = "room-a" + peers := hub.removeFromVoiceRoomLocked(owner) + _, hasRoom := hub.voiceRooms["room-a"] + _, tracked := hub.voiceRoom[owner] + hub.mu.Unlock() + + if peers != nil { + t.Fatalf("expected nil peers when room is emptied, got %d", len(peers)) + } + if hasRoom { + t.Fatalf("expected empty room to be removed") + } + if tracked { + t.Fatalf("expected client room tracking to be removed") + } + + target := &Client{userID: "target"} + peerA := &Client{userID: "peer-a"} + peerB := &Client{userID: "peer-b"} + hub.mu.Lock() + hub.voiceRooms["room-b"] = map[*Client]struct{}{target: {}, peerA: {}, peerB: {}} + hub.voiceRoom[target] = "room-b" + hub.voiceRoom[peerA] = "room-b" + hub.voiceRoom[peerB] = "room-b" + peers = hub.removeFromVoiceRoomLocked(target) + roomSize := len(hub.voiceRooms["room-b"]) + hub.mu.Unlock() + + if len(peers) != 2 { + t.Fatalf("expected 2 remaining peers, got %d", len(peers)) + } + if roomSize != 2 { + t.Fatalf("expected room to retain 2 peers, got %d", roomSize) + } +}