diff --git a/cmd/wsh/main.go b/cmd/wsh/main.go index 5b406c7..f124500 100644 --- a/cmd/wsh/main.go +++ b/cmd/wsh/main.go @@ -8,6 +8,7 @@ import ( "strings" "github.com/jessevdk/go-flags" + "golang.org/x/term" "github.com/xconnio/berncrypt/go" "github.com/xconnio/wamp-webrtc-go" @@ -18,6 +19,8 @@ import ( const ( defaultRealm = "wampshell" + procedureInteractive = "wampshell.shell.interactive" + procedureExec = "wampshell.shell.exec" procedureWebRTCOffer = "wampshell.webrtc.offer" topicOffererOnCandidate = "wampshell.webrtc.offerer.on_candidate" topicAnswererOnCandidate = "wampshell.webrtc.answerer.on_candidate" @@ -65,11 +68,103 @@ func exchangeKeys(session *xconn.Session) (*keyPair, error) { }, nil } +func startInteractiveShell(session *xconn.Session, keys *keyPair) error { + fd := int(os.Stdin.Fd()) + oldState, err := term.MakeRaw(fd) + if err != nil { + return fmt.Errorf("failed to set raw mode: %w", err) + } + defer func() { _ = term.Restore(fd, oldState) }() + + firstProgress := true + + call := session.Call(procedureInteractive). + ProgressSender(func(ctx context.Context) *xconn.Progress { + if firstProgress { + firstProgress = false + return xconn.NewProgress() + } + + buf := make([]byte, 1024) + n, err := os.Stdin.Read(buf) + if err != nil { + return xconn.NewFinalProgress() + } + + ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(buf[:n], keys.send) + if err != nil { + fmt.Printf("encryption error: %s", err) + os.Exit(1) + } + payload := append(nonce, ciphertext...) + + return xconn.NewProgress(payload) + }). + ProgressReceiver(func(result *xconn.InvocationResult) { + if len(result.Args) > 0 { + encData := result.Args[0].([]byte) + + if len(encData) < 12 { + fmt.Fprintln(os.Stderr, "invalid payload from server") + os.Exit(1) + } + + plain, err := berncrypt.DecryptChaCha20Poly1305(encData[12:], encData[:12], keys.receive) + if err != nil { + _ = fmt.Errorf("decryption error: %w", err) + } + + os.Stdout.Write(plain) + } else { + err = term.Restore(fd, oldState) + if err != nil { + return + } + os.Exit(0) + } + }).Do() + + if call.Err != nil { + log.Fatalf("Shell error: %s", call.Err) + } + return nil +} + +func runCommand(session *xconn.Session, keys *keyPair, args []string) error { + b := []byte(strings.Join(args, " ")) + + ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(b, keys.send) + if err != nil { + return fmt.Errorf("encryption error: %w", err) + } + + payload := append(nonce, ciphertext...) + + cmdResponse := session.Call(procedureExec).Args(payload).Do() + if cmdResponse.Err != nil { + return fmt.Errorf("command execution error: %w", cmdResponse.Err) + } + + output, err := cmdResponse.Args.Bytes(0) + if err != nil { + fmt.Printf("Output parsing error: %v", err) + os.Exit(1) + } + + plain, err := berncrypt.DecryptChaCha20Poly1305(output[12:], output[:12], keys.receive) + if err != nil { + return fmt.Errorf("decryption error: %w", err) + } + fmt.Print(string(plain)) + return nil +} + type Options struct { - PeerToPeer bool `long:"p2p" description:"Use WebRTC for peer-to-peer connection"` - Args struct { + Interactive bool `short:"i" long:"interactive" description:"Force interactive shell"` + PeerToPeer bool `long:"p2p" description:"Use WebRTC for peer-to-peer connection"` + Args struct { Target string `positional-arg-name:"host" required:"true"` - Cmd []string `positional-arg-name:"command" required:"true"` + Cmd []string `positional-arg-name:"command"` } `positional-args:"yes"` } @@ -105,11 +200,6 @@ func main() { port = "8022" } - anyArgs := make([]any, len(args)) - for i, a := range args { - anyArgs[i] = a - } - privateKey, err := wampshell.ReadPrivateKeyFromFile() if err != nil { fmt.Printf("Error reading private key: %v\n", err) @@ -118,8 +208,7 @@ func main() { authenticator, err := auth.NewCryptoSignAuthenticator("", privateKey, nil) if err != nil { - fmt.Printf("Error creating crypto sign authenticator: %v\n", err) - os.Exit(1) + log.Fatal("Error creating crypto sign authenticator:", err) } client := xconn.Client{ @@ -156,33 +245,15 @@ func main() { panic(err) } - b := []byte(strings.Join(args, " ")) - - ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305(b, keys.send) - if err != nil { - panic(err) - } - - payload := make([]byte, len(nonce)+len(ciphertext)) - copy(payload, nonce) - copy(payload[len(nonce):], ciphertext) - - cmdResponse := session.Call("wampshell.shell.exec").Args(payload).Do() - if cmdResponse.Err != nil { - fmt.Printf("Command execution error: %v\n", cmdResponse.Err) - os.Exit(1) - } - - output, err := cmdResponse.Args.Bytes(0) - if err != nil { - fmt.Printf("Output parsing error: %v\n", err) - os.Exit(1) + if opts.Interactive || len(args) == 0 { + err := startInteractiveShell(session, keys) + if err != nil { + log.Fatal(err) + } } - a, err := berncrypt.DecryptChaCha20Poly1305(output[12:], output[:12], keys.receive) + err = runCommand(session, keys, args) if err != nil { - panic(err) + log.Fatal(err) } - fmt.Print(string(a)) - } diff --git a/cmd/wshd/main.go b/cmd/wshd/main.go index edb37b6..d398046 100644 --- a/cmd/wshd/main.go +++ b/cmd/wshd/main.go @@ -1,4 +1,3 @@ -// wshd.go package main import ( @@ -11,9 +10,9 @@ import ( "os/signal" "path/filepath" "strings" + "sync" "github.com/creack/pty" - "github.com/jessevdk/go-flags" "github.com/xconnio/berncrypt/go" "github.com/xconnio/wamp-webrtc-go" @@ -26,6 +25,7 @@ const ( defaultRealm = "wampshell" defaultPort = 8022 defaultHost = "0.0.0.0" + procedureInteractive = "wampshell.shell.interactive" procedureExec = "wampshell.shell.exec" procedureFileUpload = "wampshell.shell.upload" procedureFileDownload = "wampshell.shell.download" @@ -34,6 +34,124 @@ const ( topicAnswererOnCandidate = "wampshell.webrtc.answerer.on_candidate" ) +type interactiveShellSession struct { + ptmx map[uint64]*os.File + sync.Mutex +} + +func newInteractiveShellSession() *interactiveShellSession { + return &interactiveShellSession{ + ptmx: make(map[uint64]*os.File), + } +} + +func (p *interactiveShellSession) startPtySession(inv *xconn.Invocation, sendKey []byte) (*os.File, error) { + cmd := exec.Command("bash") + ptmx, err := pty.Start(cmd) + if err != nil { + return nil, fmt.Errorf("failed to start PTY: %w", err) + } + p.Lock() + p.ptmx[inv.Caller()] = ptmx + p.Unlock() + + go p.startOutputReader(inv, ptmx, sendKey) + + return ptmx, nil +} + +func (p *interactiveShellSession) startOutputReader(inv *xconn.Invocation, ptmx *os.File, sendKey []byte) { + caller := inv.Caller() + defer func() { + p.Lock() + delete(p.ptmx, caller) + p.Unlock() + if err := ptmx.Close(); err != nil { + log.Printf("Error closing PTY for caller %d: %v", caller, err) + } + }() + buf := make([]byte, 4096) + for { + n, err := ptmx.Read(buf) + if n > 0 { + ciphertext, nonce, errEnc := berncrypt.EncryptChaCha20Poly1305(buf[:n], sendKey) + if errEnc != nil { + log.Printf("Encryption failed in shell output for caller %d: %v", caller, errEnc) + return + } + payload := append(nonce, ciphertext...) + _ = inv.SendProgress([]any{payload}, nil) + } + if err != nil { + _ = inv.SendProgress(nil, nil) + return + } + } +} + +func (p *interactiveShellSession) handleShell(e *wampshell.EncryptionManager) func(_ context.Context, + inv *xconn.Invocation) *xconn.InvocationResult { + return func(_ context.Context, inv *xconn.Invocation) *xconn.InvocationResult { + caller := inv.Caller() + + e.Lock() + key, ok := e.Keys()[inv.Caller()] + e.Unlock() + if !ok { + return xconn.NewInvocationError("wamp.error.unavailable", "unavailable") + } + + p.Lock() + ptmx, ok := p.ptmx[caller] + p.Unlock() + + if !ok { + _, err := p.startPtySession(inv, key.Send) + if err != nil { + return xconn.NewInvocationError("io.xconn.error", err.Error()) + } + return xconn.NewInvocationError(xconn.ErrNoResult) + } + + if inv.Progress() { + payload, err := inv.ArgBytes(0) + if err != nil { + return xconn.NewInvocationError("wamp.error.invalid_argument", err.Error()) + } + if len(payload) < 12 { + return xconn.NewInvocationError("wamp.error.invalid_argument", "payload too short") + } + + decrypted, err := berncrypt.DecryptChaCha20Poly1305(payload[12:], payload[:12], key.Receive) + if err != nil { + p.Lock() + if storedPtmx, exists := p.ptmx[caller]; exists { + storedPtmx.Close() + delete(p.ptmx, caller) + } + p.Unlock() + return xconn.NewInvocationError("io.xconn.error", err.Error()) + } + + _, err = ptmx.Write(decrypted) + if err != nil { + log.Printf("Failed to write to PTY for caller %d: %v", caller, err) + return xconn.NewInvocationError("io.xconn.error", err.Error()) + } + return xconn.NewInvocationError(xconn.ErrNoResult) + } + + p.Lock() + delete(p.ptmx, caller) + p.Unlock() + if ok { + ptmx.Close() + } + + return xconn.NewInvocationResult() + } +} + func runCommand(cmd string, args ...string) ([]byte, error) { fullCmd := cmd if len(args) > 0 { @@ -78,7 +196,6 @@ func handleRunCommand(e *wampshell.EncryptionManager) func(_ context.Context, newStrs := strings.Split(s, " ") cmd := newStrs[0] - rawArgs := newStrs[1:] output, err := runCommand(cmd, rawArgs...) @@ -88,7 +205,8 @@ func handleRunCommand(e *wampshell.EncryptionManager) func(_ context.Context, ciphertext1, nonce1, err1 := berncrypt.EncryptChaCha20Poly1305(output, key.Send) if err1 != nil { - panic(err) + log.Printf("Encryption failed in runCommand: %v", err1) + return xconn.NewInvocationError("wamp.error.internal_error", err1.Error()) } payload1 := make([]byte, len(nonce1)+len(ciphertext1)) @@ -102,6 +220,8 @@ func handleRunCommand(e *wampshell.EncryptionManager) func(_ context.Context, func handleFileUpload(e *wampshell.EncryptionManager) func(_ context.Context, inv *xconn.Invocation) *xconn.InvocationResult { return func(_ context.Context, inv *xconn.Invocation) *xconn.InvocationResult { + log.Printf("handleFileUpload called for caller: %d", inv.Caller()) + if len(inv.Args()) < 2 { return xconn.NewInvocationError("wamp.error.invalid_argument", "expected filename + encrypted data") } @@ -146,6 +266,8 @@ func handleFileUpload(e *wampshell.EncryptionManager) func(_ context.Context, func handleFileDownload(e *wampshell.EncryptionManager) func(_ context.Context, inv *xconn.Invocation) *xconn.InvocationResult { return func(_ context.Context, inv *xconn.Invocation) *xconn.InvocationResult { + log.Printf("handleFileDownload called for caller: %d", inv.Caller()) + filename, err := inv.ArgString(0) if err != nil { return xconn.NewInvocationError("wamp.error.invalid_argument", err.Error()) @@ -196,18 +318,7 @@ func addRealm(router *xconn.Router, realm string) { log.Printf("Adding realm: %s", realm) } -type Options struct { -} - func main() { - var opts Options - parser := flags.NewParser(&opts, flags.Default) - - _, err := parser.Parse() - if err != nil { - os.Exit(1) - } - address := fmt.Sprintf("%s:%d", defaultHost, defaultPort) path := os.ExpandEnv("$HOME/.wampshell/authorized_keys") @@ -240,6 +351,7 @@ func main() { name string handler xconn.InvocationHandler }{ + {procedureInteractive, newInteractiveShellSession().handleShell(encryption)}, {procedureExec, handleRunCommand(encryption)}, {procedureFileUpload, handleFileUpload(encryption)}, {procedureFileDownload, handleFileDownload(encryption)}, @@ -277,6 +389,7 @@ func main() { } err = webRtcManager.Setup(cfg) if err != nil { + log.Printf("Failed to setup WebRTC: %v", err) return } diff --git a/go.mod b/go.mod index 9d4c430..0dde3e1 100644 --- a/go.mod +++ b/go.mod @@ -11,6 +11,7 @@ require ( github.com/xconnio/wampproto-capnproto/go v0.0.0-20250921183631-6decd38ce372 github.com/xconnio/wampproto-go v0.0.0-20250915142018-1ae321b40fec github.com/xconnio/xconn-go v0.0.0-20250918124058-95e16bcd2454 + golang.org/x/term v0.35.0 ) require ( diff --git a/go.sum b/go.sum index ae4d4ee..7fd7d34 100644 --- a/go.sum +++ b/go.sum @@ -126,6 +126,8 @@ golang.org/x/sys v0.0.0-20220715151400-c0bba94af5f8/go.mod h1:oPkhp1MJrh7nUepCBc golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.36.0 h1:KVRy2GtZBrk1cBYA7MKu5bEZFxQk4NIDV6RLVcC8o0k= golang.org/x/sys v0.36.0/go.mod h1:OgkHotnGiDImocRcuBABYBEXf8A9a87e/uXjp9XT3ks= +golang.org/x/term v0.35.0 h1:bZBVKBudEyhRcajGcNc3jIfWPqV4y/Kt2XcoigOWtDQ= +golang.org/x/term v0.35.0/go.mod h1:TPGtkTLesOwf2DE8CgVYiZinHAOuy5AYUYT1lENIZnA= golang.org/x/time v0.12.0 h1:ScB/8o8olJvc+CQPWrK3fPZNfh7qgwCrY0zJmoEQLSE= golang.org/x/time v0.12.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=