diff --git a/authenticator.go b/authenticator.go index 7b38db1..7b25f44 100644 --- a/authenticator.go +++ b/authenticator.go @@ -28,7 +28,7 @@ func (a *ServerAuthenticator) Authenticate(request auth.Request) (auth.Response, return auth.NewResponse("", "anonymous", 0) } - return nil, fmt.Errorf("unauthorized") + return nil, fmt.Errorf("invalid credentials") } func (a *ServerAuthenticator) Realms() map[string][]string { diff --git a/cmd/wshd/main.go b/cmd/wshd/main.go index dd33cf1..d9a9bec 100644 --- a/cmd/wshd/main.go +++ b/cmd/wshd/main.go @@ -29,6 +29,7 @@ const ( procedureExec = "wampshell.shell.exec" procedureFileUpload = "wampshell.shell.upload" procedureFileDownload = "wampshell.shell.download" + procedureSyncKeys = "wampshell.shell.keys.sync" procedureWebRTCOffer = "wampshell.webrtc.offer" topicOffererOnCandidate = "wampshell.webrtc.offerer.on_candidate" topicAnswererOnCandidate = "wampshell.webrtc.answerer.on_candidate" @@ -294,6 +295,48 @@ func addRealm(router *xconn.Router, realm string) { log.Printf("Adding realm: %s", realm) } +func SyncAuthorizedKeys(session *xconn.Session, keyStore *wampshell.KeyStore) error { + lines, err := keyStore.AuthorizedKeys() + if err != nil { + return fmt.Errorf("failed to get authorized keys: %w", err) + } + + callResponse := session.Call(procedureSyncKeys).Arg(lines).Do() + if callResponse.Err != nil { + return fmt.Errorf("sync keys call failed: %w", callResponse.Err) + } + + return nil +} + +func handleSyncKeys(realm string, keyStore *wampshell.KeyStore) xconn.InvocationHandler { + return func(_ context.Context, inv *xconn.Invocation) *xconn.InvocationResult { + argList, err := inv.ArgList(0) + if err != nil { + return xconn.NewInvocationError("wamp.error.invalid_argument", err.Error()) + } + + keys := make([]string, 0, len(argList)) + for _, item := range argList { + if s, ok := item.(string); ok && s != "" { + keys = append(keys, s) + } + } + + if len(keys) == 0 { + return xconn.NewInvocationError("wamp.error.invalid_argument", "no valid keys provided") + } + + newKeys := map[string][]string{ + realm: keys, + } + keyStore.Update(newKeys) + + log.Printf("Synced %d keys for realm %s from caller %d", len(keys), realm, inv.Caller()) + return xconn.NewInvocationResult("ok") + } +} + func main() { loadConfig, err := wampshell.LoadConfig() if err != nil { @@ -321,6 +364,17 @@ func main() { addRealm(router, defaultRealm) for realm := range authenticator.Realms() { addRealm(router, realm) + if realm != defaultRealm { + c, err := xconn.ConnectInMemory(router, realm) + if err != nil { + log.Fatalf("Error connecting to realm %s: %v", realm, err) + } + + r := c.Register(procedureSyncKeys, handleSyncKeys(realm, keyStore)).Do() + if r.Err != nil { + log.Fatalf("Error registering realm %s: %v", realm, r.Err) + } + } } keyStore.OnUpdate(func(keys map[string][]string) { @@ -329,21 +383,6 @@ func main() { } }) - encryption := wampshell.NewEncryptionManager(router) - if err = encryption.Setup(); err != nil { - log.Fatal(err) - } - - procedures := []struct { - name string - handler xconn.InvocationHandler - }{ - {procedureInteractive, newInteractiveShellSession().handleShell(encryption)}, - {procedureExec, handleRunCommand(encryption)}, - {procedureFileUpload, handleFileUpload(encryption)}, - {procedureFileDownload, handleFileDownload(encryption)}, - } - server := xconn.NewServer(router, authenticator, nil) if server == nil { log.Fatal("failed to create server") @@ -393,6 +432,21 @@ func main() { return } + encryption := wampshell.NewEncryptionManager(sess) + if err = encryption.Setup(); err != nil { + log.Fatal(err) + } + + procedures := []struct { + name string + handler xconn.InvocationHandler + }{ + {procedureInteractive, newInteractiveShellSession().handleShell(encryption)}, + {procedureExec, handleRunCommand(encryption)}, + {procedureFileUpload, handleFileUpload(encryption)}, + {procedureFileDownload, handleFileDownload(encryption)}, + } + for _, proc := range procedures { registerResponse := sess.Register(proc.name, proc.handler).Do() if registerResponse.Err != nil { @@ -400,6 +454,17 @@ func main() { } log.Printf("Procedure registered: %s", proc.name) } + + if sess.Details().Realm() != defaultRealm { + _, err := wampshell.ExchangeKeys(sess) + if err != nil { + log.Fatalf("Failed to exchange keys: %v", err) + } + + if err := SyncAuthorizedKeys(sess, keyStore); err != nil { + log.Printf("failed to sync authorized keys with : %v", err) + } + } } log.Printf("listening on rs://%s", address) diff --git a/encryption.go b/encryption.go index b108deb..8f38da2 100644 --- a/encryption.go +++ b/encryption.go @@ -14,32 +14,27 @@ type KeyPair struct { } type EncryptionManager struct { - router *xconn.Router + session *xconn.Session keys map[uint64]*KeyPair sync.Mutex } -func NewEncryptionManager(router *xconn.Router) *EncryptionManager { +func NewEncryptionManager(session *xconn.Session) *EncryptionManager { return &EncryptionManager{ - router: router, - keys: make(map[uint64]*KeyPair), + session: session, + keys: make(map[uint64]*KeyPair), } } func (e *EncryptionManager) Setup() error { - session, err := xconn.ConnectInMemory(e.router, "wampshell") - if err != nil { - return err - } - - response := session.Register("wampshell.key.exchange", e.HandleKeyExchange).Do() + response := e.session.Register("wampshell.key.exchange", e.HandleKeyExchange).Do() if response.Err != nil { return response.Err } - response = session.Register("wampshell.payload.echo", e.TestEcho).Do() + response = e.session.Register("wampshell.payload.echo", e.TestEcho).Do() if response.Err != nil { return response.Err } diff --git a/keystore.go b/keystore.go index 537cd28..6d9b7b3 100644 --- a/keystore.go +++ b/keystore.go @@ -45,9 +45,14 @@ func (k *KeyStore) OnUpdate(cb func(map[string][]string)) { func (k *KeyStore) Update(keys map[string][]string) { k.Lock() defer k.Unlock() - k.keys = keys + + for realm, newList := range keys { + existing := k.keys[realm] + k.keys[realm] = append(existing[:0:0], newList...) + } + if k.onUpdate != nil { - go k.onUpdate(keys) + go k.onUpdate(k.keys) } } @@ -142,3 +147,24 @@ func (k *KeyStore) watch(filePath string, watcher *fsnotify.Watcher) { } } } + +func (k *KeyStore) AuthorizedKeys() ([]string, error) { + k.RLock() + defer k.RUnlock() + + if len(k.keys) == 0 { + return nil, fmt.Errorf("no keys in KeyStore") + } + + var keys []string + for realm, ks := range k.keys { + for _, key := range ks { + if realm == "wampshell" { + keys = append(keys, key) + } else { + keys = append(keys, fmt.Sprintf("%s %s", key, realm)) + } + } + } + return keys, nil +}