From bdc88813ed2dc180f321321e0587d6d58d68d453 Mon Sep 17 00:00:00 2001 From: asimfarooq5 Date: Sat, 27 Sep 2025 17:49:45 +0500 Subject: [PATCH 1/3] Add wamp procedure to sync keys with daemon --- cmd/wshd/main.go | 77 ++++++++++++++++++++++++++++++++++++++++++++++++ keystore.go | 21 +++++++++++++ 2 files changed, 98 insertions(+) diff --git a/cmd/wshd/main.go b/cmd/wshd/main.go index dd33cf1..9b43815 100644 --- a/cmd/wshd/main.go +++ b/cmd/wshd/main.go @@ -29,6 +29,7 @@ const ( procedureExec = "wampshell.shell.exec" procedureFileUpload = "wampshell.shell.upload" procedureFileDownload = "wampshell.shell.download" + procedureSyncKeys = "wampshell.shell.keys.sync" procedureWebRTCOffer = "wampshell.webrtc.offer" topicOffererOnCandidate = "wampshell.webrtc.offerer.on_candidate" topicAnswererOnCandidate = "wampshell.webrtc.answerer.on_candidate" @@ -294,6 +295,72 @@ func addRealm(router *xconn.Router, realm string) { log.Printf("Adding realm: %s", realm) } +func SyncAuthorizedKeys(session *xconn.Session, keys *wampshell.KeyPair, keyStore *wampshell.KeyStore) error { + lines, err := keyStore.AuthorizedKeys() + if err != nil { + return fmt.Errorf("failed to get authorized keys: %w", err) + } + + plaintext := strings.Join(lines, "\n") + if plaintext == "" { + return fmt.Errorf("no keys to sync") + } + + ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305([]byte(plaintext), keys.Send) + if err != nil { + return fmt.Errorf("failed to encrypt keys: %w", err) + } + payload := append(nonce, ciphertext...) + + callResponse := session.Call(procedureSyncKeys).Args(payload).Do() + if callResponse.Err != nil { + return fmt.Errorf("sync keys call failed: %w", callResponse.Err) + } + + return nil +} + +func handleSyncKeys(keyStore *wampshell.KeyStore, e *wampshell.EncryptionManager) xconn.InvocationHandler { + return func(_ context.Context, inv *xconn.Invocation) *xconn.InvocationResult { + encryptedPayload, err := inv.ArgBytes(0) + if err != nil { + return xconn.NewInvocationError("wamp.error.invalid_argument", err.Error()) + } + + key, ok := e.Key(inv.Caller()) + if !ok { + return xconn.NewInvocationError("wamp.error.unavailable", "no encryption key for caller") + } + + if len(encryptedPayload) < 12 { + return xconn.NewInvocationError("wamp.error.invalid_argument", "payload too short") + } + plaintext, err := berncrypt.DecryptChaCha20Poly1305(encryptedPayload[12:], encryptedPayload[:12], key.Receive) + if err != nil { + return xconn.NewInvocationError("wamp.error.internal_error", err.Error()) + } + + allKeys := strings.Split(string(plaintext), "\n") + newKeys := make(map[string][]string) + for _, line := range allKeys { + parts := strings.Fields(strings.TrimSpace(line)) + if len(parts) == 0 { + continue + } + k := parts[0] + realm := "wampshell" + if len(parts) > 1 { + realm = parts[1] + } + newKeys[realm] = append(newKeys[realm], k) + } + keyStore.Update(newKeys) + + log.Printf("Synced %d keys from caller %d", len(allKeys), inv.Caller()) + return xconn.NewInvocationResult("ok") + } +} + func main() { loadConfig, err := wampshell.LoadConfig() if err != nil { @@ -342,6 +409,7 @@ func main() { {procedureExec, handleRunCommand(encryption)}, {procedureFileUpload, handleFileUpload(encryption)}, {procedureFileDownload, handleFileDownload(encryption)}, + {procedureSyncKeys, handleSyncKeys(keyStore, encryption)}, } server := xconn.NewServer(router, authenticator, nil) @@ -374,6 +442,15 @@ func main() { } sessions = append(sessions, sess) + + keys, err := wampshell.ExchangeKeys(session) + if err != nil { + log.Fatalf("Failed to exchange keys: %v", err) + } + + if err := SyncAuthorizedKeys(sess, keys, keyStore); err != nil { + log.Printf("failed to sync authorized keys with %s: %v", p.URL, err) + } } for _, sess := range sessions { diff --git a/keystore.go b/keystore.go index 537cd28..b64f6a3 100644 --- a/keystore.go +++ b/keystore.go @@ -142,3 +142,24 @@ func (k *KeyStore) watch(filePath string, watcher *fsnotify.Watcher) { } } } + +func (k *KeyStore) AuthorizedKeys() ([]string, error) { + k.RLock() + defer k.RUnlock() + + if len(k.keys) == 0 { + return nil, fmt.Errorf("no keys in KeyStore") + } + + var keys []string + for realm, ks := range k.keys { + for _, key := range ks { + if realm == "wampshell" { + keys = append(keys, key) + } else { + keys = append(keys, fmt.Sprintf("%s %s", key, realm)) + } + } + } + return keys, nil +} From 1a5511cb2119be94df5c8e3d512feca0dbe4deb6 Mon Sep 17 00:00:00 2001 From: Omer Akram Date: Sat, 27 Sep 2025 19:26:27 +0500 Subject: [PATCH 2/3] fix stuff --- cmd/wshd/main.go | 23 +++++++++++++++++------ 1 file changed, 17 insertions(+), 6 deletions(-) diff --git a/cmd/wshd/main.go b/cmd/wshd/main.go index 9b43815..22fef8c 100644 --- a/cmd/wshd/main.go +++ b/cmd/wshd/main.go @@ -320,7 +320,7 @@ func SyncAuthorizedKeys(session *xconn.Session, keys *wampshell.KeyPair, keyStor return nil } -func handleSyncKeys(keyStore *wampshell.KeyStore, e *wampshell.EncryptionManager) xconn.InvocationHandler { +func handleSyncKeys(realm string, keyStore *wampshell.KeyStore, e *wampshell.EncryptionManager) xconn.InvocationHandler { return func(_ context.Context, inv *xconn.Invocation) *xconn.InvocationResult { encryptedPayload, err := inv.ArgBytes(0) if err != nil { @@ -348,10 +348,6 @@ func handleSyncKeys(keyStore *wampshell.KeyStore, e *wampshell.EncryptionManager continue } k := parts[0] - realm := "wampshell" - if len(parts) > 1 { - realm = parts[1] - } newKeys[realm] = append(newKeys[realm], k) } keyStore.Update(newKeys) @@ -409,7 +405,22 @@ func main() { {procedureExec, handleRunCommand(encryption)}, {procedureFileUpload, handleFileUpload(encryption)}, {procedureFileDownload, handleFileDownload(encryption)}, - {procedureSyncKeys, handleSyncKeys(keyStore, encryption)}, + } + + for realm := range authenticator.Realms() { + if realm != defaultRealm { + c, err := xconn.ConnectInMemory(router, realm) + if err != nil { + log.Fatalf("Error connecting to realm %s: %v", realm, err) + } + + r := c.Register(procedureSyncKeys, handleSyncKeys(realm, keyStore, encryption)).Do() + if r.Err != nil { + log.Fatalf("Error registering realm %s: %v", realm, r.Err) + } + + fmt.Printf("registered %s\n", procedureSyncKeys) + } } server := xconn.NewServer(router, authenticator, nil) From cf8b17b03527f3b60e7b47e83411e6439b683e88 Mon Sep 17 00:00:00 2001 From: asimfarooq5 Date: Sat, 27 Sep 2025 20:06:37 +0500 Subject: [PATCH 3/3] move encryption to per session --- authenticator.go | 2 +- cmd/wshd/main.go | 117 +++++++++++++++++++---------------------------- encryption.go | 17 +++---- keystore.go | 9 +++- 4 files changed, 61 insertions(+), 84 deletions(-) diff --git a/authenticator.go b/authenticator.go index 7b38db1..7b25f44 100644 --- a/authenticator.go +++ b/authenticator.go @@ -28,7 +28,7 @@ func (a *ServerAuthenticator) Authenticate(request auth.Request) (auth.Response, return auth.NewResponse("", "anonymous", 0) } - return nil, fmt.Errorf("unauthorized") + return nil, fmt.Errorf("invalid credentials") } func (a *ServerAuthenticator) Realms() map[string][]string { diff --git a/cmd/wshd/main.go b/cmd/wshd/main.go index 22fef8c..d9a9bec 100644 --- a/cmd/wshd/main.go +++ b/cmd/wshd/main.go @@ -295,24 +295,13 @@ func addRealm(router *xconn.Router, realm string) { log.Printf("Adding realm: %s", realm) } -func SyncAuthorizedKeys(session *xconn.Session, keys *wampshell.KeyPair, keyStore *wampshell.KeyStore) error { +func SyncAuthorizedKeys(session *xconn.Session, keyStore *wampshell.KeyStore) error { lines, err := keyStore.AuthorizedKeys() if err != nil { return fmt.Errorf("failed to get authorized keys: %w", err) } - plaintext := strings.Join(lines, "\n") - if plaintext == "" { - return fmt.Errorf("no keys to sync") - } - - ciphertext, nonce, err := berncrypt.EncryptChaCha20Poly1305([]byte(plaintext), keys.Send) - if err != nil { - return fmt.Errorf("failed to encrypt keys: %w", err) - } - payload := append(nonce, ciphertext...) - - callResponse := session.Call(procedureSyncKeys).Args(payload).Do() + callResponse := session.Call(procedureSyncKeys).Arg(lines).Do() if callResponse.Err != nil { return fmt.Errorf("sync keys call failed: %w", callResponse.Err) } @@ -320,39 +309,30 @@ func SyncAuthorizedKeys(session *xconn.Session, keys *wampshell.KeyPair, keyStor return nil } -func handleSyncKeys(realm string, keyStore *wampshell.KeyStore, e *wampshell.EncryptionManager) xconn.InvocationHandler { +func handleSyncKeys(realm string, keyStore *wampshell.KeyStore) xconn.InvocationHandler { return func(_ context.Context, inv *xconn.Invocation) *xconn.InvocationResult { - encryptedPayload, err := inv.ArgBytes(0) + argList, err := inv.ArgList(0) if err != nil { return xconn.NewInvocationError("wamp.error.invalid_argument", err.Error()) } - key, ok := e.Key(inv.Caller()) - if !ok { - return xconn.NewInvocationError("wamp.error.unavailable", "no encryption key for caller") + keys := make([]string, 0, len(argList)) + for _, item := range argList { + if s, ok := item.(string); ok && s != "" { + keys = append(keys, s) + } } - if len(encryptedPayload) < 12 { - return xconn.NewInvocationError("wamp.error.invalid_argument", "payload too short") - } - plaintext, err := berncrypt.DecryptChaCha20Poly1305(encryptedPayload[12:], encryptedPayload[:12], key.Receive) - if err != nil { - return xconn.NewInvocationError("wamp.error.internal_error", err.Error()) + if len(keys) == 0 { + return xconn.NewInvocationError("wamp.error.invalid_argument", "no valid keys provided") } - allKeys := strings.Split(string(plaintext), "\n") - newKeys := make(map[string][]string) - for _, line := range allKeys { - parts := strings.Fields(strings.TrimSpace(line)) - if len(parts) == 0 { - continue - } - k := parts[0] - newKeys[realm] = append(newKeys[realm], k) + newKeys := map[string][]string{ + realm: keys, } keyStore.Update(newKeys) - log.Printf("Synced %d keys from caller %d", len(allKeys), inv.Caller()) + log.Printf("Synced %d keys for realm %s from caller %d", len(keys), realm, inv.Caller()) return xconn.NewInvocationResult("ok") } } @@ -384,45 +364,25 @@ func main() { addRealm(router, defaultRealm) for realm := range authenticator.Realms() { addRealm(router, realm) - } - - keyStore.OnUpdate(func(keys map[string][]string) { - for realm := range keys { - addRealm(router, realm) - } - }) - - encryption := wampshell.NewEncryptionManager(router) - if err = encryption.Setup(); err != nil { - log.Fatal(err) - } - - procedures := []struct { - name string - handler xconn.InvocationHandler - }{ - {procedureInteractive, newInteractiveShellSession().handleShell(encryption)}, - {procedureExec, handleRunCommand(encryption)}, - {procedureFileUpload, handleFileUpload(encryption)}, - {procedureFileDownload, handleFileDownload(encryption)}, - } - - for realm := range authenticator.Realms() { if realm != defaultRealm { c, err := xconn.ConnectInMemory(router, realm) if err != nil { log.Fatalf("Error connecting to realm %s: %v", realm, err) } - r := c.Register(procedureSyncKeys, handleSyncKeys(realm, keyStore, encryption)).Do() + r := c.Register(procedureSyncKeys, handleSyncKeys(realm, keyStore)).Do() if r.Err != nil { log.Fatalf("Error registering realm %s: %v", realm, r.Err) } - - fmt.Printf("registered %s\n", procedureSyncKeys) } } + keyStore.OnUpdate(func(keys map[string][]string) { + for realm := range keys { + addRealm(router, realm) + } + }) + server := xconn.NewServer(router, authenticator, nil) if server == nil { log.Fatal("failed to create server") @@ -453,15 +413,6 @@ func main() { } sessions = append(sessions, sess) - - keys, err := wampshell.ExchangeKeys(session) - if err != nil { - log.Fatalf("Failed to exchange keys: %v", err) - } - - if err := SyncAuthorizedKeys(sess, keys, keyStore); err != nil { - log.Printf("failed to sync authorized keys with %s: %v", p.URL, err) - } } for _, sess := range sessions { @@ -481,6 +432,21 @@ func main() { return } + encryption := wampshell.NewEncryptionManager(sess) + if err = encryption.Setup(); err != nil { + log.Fatal(err) + } + + procedures := []struct { + name string + handler xconn.InvocationHandler + }{ + {procedureInteractive, newInteractiveShellSession().handleShell(encryption)}, + {procedureExec, handleRunCommand(encryption)}, + {procedureFileUpload, handleFileUpload(encryption)}, + {procedureFileDownload, handleFileDownload(encryption)}, + } + for _, proc := range procedures { registerResponse := sess.Register(proc.name, proc.handler).Do() if registerResponse.Err != nil { @@ -488,6 +454,17 @@ func main() { } log.Printf("Procedure registered: %s", proc.name) } + + if sess.Details().Realm() != defaultRealm { + _, err := wampshell.ExchangeKeys(sess) + if err != nil { + log.Fatalf("Failed to exchange keys: %v", err) + } + + if err := SyncAuthorizedKeys(sess, keyStore); err != nil { + log.Printf("failed to sync authorized keys with : %v", err) + } + } } log.Printf("listening on rs://%s", address) diff --git a/encryption.go b/encryption.go index b108deb..8f38da2 100644 --- a/encryption.go +++ b/encryption.go @@ -14,32 +14,27 @@ type KeyPair struct { } type EncryptionManager struct { - router *xconn.Router + session *xconn.Session keys map[uint64]*KeyPair sync.Mutex } -func NewEncryptionManager(router *xconn.Router) *EncryptionManager { +func NewEncryptionManager(session *xconn.Session) *EncryptionManager { return &EncryptionManager{ - router: router, - keys: make(map[uint64]*KeyPair), + session: session, + keys: make(map[uint64]*KeyPair), } } func (e *EncryptionManager) Setup() error { - session, err := xconn.ConnectInMemory(e.router, "wampshell") - if err != nil { - return err - } - - response := session.Register("wampshell.key.exchange", e.HandleKeyExchange).Do() + response := e.session.Register("wampshell.key.exchange", e.HandleKeyExchange).Do() if response.Err != nil { return response.Err } - response = session.Register("wampshell.payload.echo", e.TestEcho).Do() + response = e.session.Register("wampshell.payload.echo", e.TestEcho).Do() if response.Err != nil { return response.Err } diff --git a/keystore.go b/keystore.go index b64f6a3..6d9b7b3 100644 --- a/keystore.go +++ b/keystore.go @@ -45,9 +45,14 @@ func (k *KeyStore) OnUpdate(cb func(map[string][]string)) { func (k *KeyStore) Update(keys map[string][]string) { k.Lock() defer k.Unlock() - k.keys = keys + + for realm, newList := range keys { + existing := k.keys[realm] + k.keys[realm] = append(existing[:0:0], newList...) + } + if k.onUpdate != nil { - go k.onUpdate(keys) + go k.onUpdate(k.keys) } }