Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion authenticator.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ func (a *ServerAuthenticator) Authenticate(request auth.Request) (auth.Response,
return auth.NewResponse("", "anonymous", 0)
}

return nil, fmt.Errorf("unauthorized")
return nil, fmt.Errorf("invalid credentials")
}

func (a *ServerAuthenticator) Realms() map[string][]string {
Expand Down
95 changes: 80 additions & 15 deletions cmd/wshd/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ const (
procedureExec = "wampshell.shell.exec"
procedureFileUpload = "wampshell.shell.upload"
procedureFileDownload = "wampshell.shell.download"
procedureSyncKeys = "wampshell.shell.keys.sync"
procedureWebRTCOffer = "wampshell.webrtc.offer"
topicOffererOnCandidate = "wampshell.webrtc.offerer.on_candidate"
topicAnswererOnCandidate = "wampshell.webrtc.answerer.on_candidate"
Expand Down Expand Up @@ -294,6 +295,48 @@ func addRealm(router *xconn.Router, realm string) {
log.Printf("Adding realm: %s", realm)
}

func SyncAuthorizedKeys(session *xconn.Session, keyStore *wampshell.KeyStore) error {
lines, err := keyStore.AuthorizedKeys()
if err != nil {
return fmt.Errorf("failed to get authorized keys: %w", err)
}

callResponse := session.Call(procedureSyncKeys).Arg(lines).Do()
if callResponse.Err != nil {
return fmt.Errorf("sync keys call failed: %w", callResponse.Err)
}

return nil
}

func handleSyncKeys(realm string, keyStore *wampshell.KeyStore) xconn.InvocationHandler {
return func(_ context.Context, inv *xconn.Invocation) *xconn.InvocationResult {
argList, err := inv.ArgList(0)
if err != nil {
return xconn.NewInvocationError("wamp.error.invalid_argument", err.Error())
}

keys := make([]string, 0, len(argList))
for _, item := range argList {
if s, ok := item.(string); ok && s != "" {
keys = append(keys, s)
}
}

if len(keys) == 0 {
return xconn.NewInvocationError("wamp.error.invalid_argument", "no valid keys provided")
}

newKeys := map[string][]string{
realm: keys,
}
keyStore.Update(newKeys)

log.Printf("Synced %d keys for realm %s from caller %d", len(keys), realm, inv.Caller())
return xconn.NewInvocationResult("ok")
}
}

func main() {
loadConfig, err := wampshell.LoadConfig()
if err != nil {
Expand Down Expand Up @@ -321,6 +364,17 @@ func main() {
addRealm(router, defaultRealm)
for realm := range authenticator.Realms() {
addRealm(router, realm)
if realm != defaultRealm {
c, err := xconn.ConnectInMemory(router, realm)
if err != nil {
log.Fatalf("Error connecting to realm %s: %v", realm, err)
}

r := c.Register(procedureSyncKeys, handleSyncKeys(realm, keyStore)).Do()
if r.Err != nil {
log.Fatalf("Error registering realm %s: %v", realm, r.Err)
}
}
}

keyStore.OnUpdate(func(keys map[string][]string) {
Expand All @@ -329,21 +383,6 @@ func main() {
}
})

encryption := wampshell.NewEncryptionManager(router)
if err = encryption.Setup(); err != nil {
log.Fatal(err)
}

procedures := []struct {
name string
handler xconn.InvocationHandler
}{
{procedureInteractive, newInteractiveShellSession().handleShell(encryption)},
{procedureExec, handleRunCommand(encryption)},
{procedureFileUpload, handleFileUpload(encryption)},
{procedureFileDownload, handleFileDownload(encryption)},
}

server := xconn.NewServer(router, authenticator, nil)
if server == nil {
log.Fatal("failed to create server")
Expand Down Expand Up @@ -393,13 +432,39 @@ func main() {
return
}

encryption := wampshell.NewEncryptionManager(sess)
if err = encryption.Setup(); err != nil {
log.Fatal(err)
}

procedures := []struct {
name string
handler xconn.InvocationHandler
}{
{procedureInteractive, newInteractiveShellSession().handleShell(encryption)},
{procedureExec, handleRunCommand(encryption)},
{procedureFileUpload, handleFileUpload(encryption)},
{procedureFileDownload, handleFileDownload(encryption)},
}

for _, proc := range procedures {
registerResponse := sess.Register(proc.name, proc.handler).Do()
if registerResponse.Err != nil {
log.Fatalln(registerResponse.Err)
}
log.Printf("Procedure registered: %s", proc.name)
}

if sess.Details().Realm() != defaultRealm {
_, err := wampshell.ExchangeKeys(sess)
if err != nil {
log.Fatalf("Failed to exchange keys: %v", err)
}

if err := SyncAuthorizedKeys(sess, keyStore); err != nil {
log.Printf("failed to sync authorized keys with : %v", err)
}
}
}

log.Printf("listening on rs://%s", address)
Expand Down
17 changes: 6 additions & 11 deletions encryption.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,32 +14,27 @@ type KeyPair struct {
}

type EncryptionManager struct {
router *xconn.Router
session *xconn.Session

keys map[uint64]*KeyPair

sync.Mutex
}

func NewEncryptionManager(router *xconn.Router) *EncryptionManager {
func NewEncryptionManager(session *xconn.Session) *EncryptionManager {
return &EncryptionManager{
router: router,
keys: make(map[uint64]*KeyPair),
session: session,
keys: make(map[uint64]*KeyPair),
}
}

func (e *EncryptionManager) Setup() error {
session, err := xconn.ConnectInMemory(e.router, "wampshell")
if err != nil {
return err
}

response := session.Register("wampshell.key.exchange", e.HandleKeyExchange).Do()
response := e.session.Register("wampshell.key.exchange", e.HandleKeyExchange).Do()
if response.Err != nil {
return response.Err
}

response = session.Register("wampshell.payload.echo", e.TestEcho).Do()
response = e.session.Register("wampshell.payload.echo", e.TestEcho).Do()
if response.Err != nil {
return response.Err
}
Expand Down
30 changes: 28 additions & 2 deletions keystore.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,14 @@ func (k *KeyStore) OnUpdate(cb func(map[string][]string)) {
func (k *KeyStore) Update(keys map[string][]string) {
k.Lock()
defer k.Unlock()
k.keys = keys

for realm, newList := range keys {
existing := k.keys[realm]
k.keys[realm] = append(existing[:0:0], newList...)
}

if k.onUpdate != nil {
go k.onUpdate(keys)
go k.onUpdate(k.keys)
}
}

Expand Down Expand Up @@ -142,3 +147,24 @@ func (k *KeyStore) watch(filePath string, watcher *fsnotify.Watcher) {
}
}
}

func (k *KeyStore) AuthorizedKeys() ([]string, error) {
k.RLock()
defer k.RUnlock()

if len(k.keys) == 0 {
return nil, fmt.Errorf("no keys in KeyStore")
}

var keys []string
for realm, ks := range k.keys {
for _, key := range ks {
if realm == "wampshell" {
keys = append(keys, key)
} else {
keys = append(keys, fmt.Sprintf("%s %s", key, realm))
}
}
}
return keys, nil
}
Loading