diff --git a/README.md b/README.md index acc21ab..9d3dc0d 100644 --- a/README.md +++ b/README.md @@ -1,5 +1,6 @@ -# EIPC — Embedded Inter-Process Communication +# eIPC — Embedded Inter-Process Communication +[![CI](https://github.com/embeddedos-org/eipc/actions/workflows/ci.yml/badge.svg)](https://github.com/embeddedos-org/eipc/actions/workflows/ci.yml) [![Go](https://img.shields.io/badge/Go-1.22+-00ADD8?logo=go&logoColor=white)](https://go.dev) [![License](https://img.shields.io/badge/License-MIT-blue.svg)](LICENSE) [![Platform](https://img.shields.io/badge/Platform-Linux%20%7C%20macOS%20%7C%20Windows-lightgrey)](https://github.com/embeddedos-org/eipc) diff --git a/cmd/eipc-cli/main.go b/cmd/eipc-cli/main.go index 5842493..865672c 100644 --- a/cmd/eipc-cli/main.go +++ b/cmd/eipc-cli/main.go @@ -1,207 +1,207 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -// eipc-cli is a debugging command-line tool for sending and receiving EIPC messages. -// -// Usage: -// -// eipc-cli send --addr HOST:PORT --type chat --payload '{"text":"hello"}' -// eipc-cli listen --addr HOST:PORT -// eipc-cli inspect --addr HOST:PORT -package main - -import ( - "bytes" - "encoding/json" - "flag" - "fmt" - "log" - "os" - "time" - - "github.com/embeddedos-org/eipc/config" - "github.com/embeddedos-org/eipc/core" - "github.com/embeddedos-org/eipc/protocol" - "github.com/embeddedos-org/eipc/transport/tcp" -) - -func main() { - if len(os.Args) < 2 { - printUsage() - os.Exit(1) - } - - switch os.Args[1] { - case "send": - cmdSend(os.Args[2:]) - case "listen": - cmdListen(os.Args[2:]) - case "ping": - cmdPing(os.Args[2:]) - case "help": - printUsage() - default: - fmt.Fprintf(os.Stderr, "unknown command: %s\n", os.Args[1]) - printUsage() - os.Exit(1) - } -} - -func printUsage() { - fmt.Println(`eipc-cli — EIPC debugging tool - -Usage: - eipc-cli [options] - -Commands: - send Send a single message to an EIPC server - listen Connect and print incoming messages - ping Send a heartbeat and wait for response - help Show this help message - -Environment: - EIPC_HMAC_KEY Shared HMAC key (required) - EIPC_LISTEN_ADDR Default server address (optional)`) -} - -func cmdSend(args []string) { - fs := flag.NewFlagSet("send", flag.ExitOnError) - addr := fs.String("addr", config.LoadListenAddr(), "server address") - msgType := fs.String("type", "chat", "message type (intent|chat|heartbeat|ack)") - payload := fs.String("payload", `{}`, "JSON payload") - source := fs.String("source", "eipc-cli", "source service ID") - capability := fs.String("cap", "", "capability header") - _ = fs.Parse(args) - - hmacKey, err := config.LoadHMACKey() - if err != nil { - log.Fatalf("HMAC key: %v", err) - } - - t := tcp.New() - conn, err := t.Dial(*addr) - if err != nil { - log.Fatalf("dial %s: %v", *addr, err) - } - defer conn.Close() - - ep := core.NewClientEndpoint(conn, protocol.DefaultCodec(), hmacKey, "") - - mt := core.MessageType(*msgType) - msg := core.Message{ - Version: core.ProtocolVersion, - Type: mt, - Source: *source, - Timestamp: time.Now().UTC(), - RequestID: fmt.Sprintf("cli-%d", time.Now().UnixNano()), - Priority: core.PriorityP1, - Capability: *capability, - Payload: []byte(*payload), - } - - if err := ep.Send(msg); err != nil { - log.Fatalf("send: %v", err) - } - - fmt.Printf("sent type=%s to=%s size=%d bytes\n", *msgType, *addr, len(*payload)) - - resp, err := ep.Receive() - if err != nil { - fmt.Printf("no response (connection closed or timeout)\n") - return - } - - printMessage("response", resp) -} - -func cmdListen(args []string) { - fs := flag.NewFlagSet("listen", flag.ExitOnError) - addr := fs.String("addr", config.LoadListenAddr(), "server address") - count := fs.Int("count", 0, "max messages to receive (0=unlimited)") - _ = fs.Parse(args) - - hmacKey, err := config.LoadHMACKey() - if err != nil { - log.Fatalf("HMAC key: %v", err) - } - - t := tcp.New() - conn, err := t.Dial(*addr) - if err != nil { - log.Fatalf("dial %s: %v", *addr, err) - } - defer conn.Close() - - ep := core.NewClientEndpoint(conn, protocol.DefaultCodec(), hmacKey, "") - - fmt.Printf("listening on %s ...\n", *addr) - received := 0 - for { - msg, err := ep.Receive() - if err != nil { - log.Fatalf("receive: %v", err) - } - received++ - printMessage(fmt.Sprintf("msg#%d", received), msg) - - if *count > 0 && received >= *count { - break - } - } -} - -func cmdPing(args []string) { - fs := flag.NewFlagSet("ping", flag.ExitOnError) - addr := fs.String("addr", config.LoadListenAddr(), "server address") - _ = fs.Parse(args) - - hmacKey, err := config.LoadHMACKey() - if err != nil { - log.Fatalf("HMAC key: %v", err) - } - - t := tcp.New() - conn, err := t.Dial(*addr) - if err != nil { - log.Fatalf("dial %s: %v", *addr, err) - } - defer conn.Close() - - ep := core.NewClientEndpoint(conn, protocol.DefaultCodec(), hmacKey, "") - - start := time.Now() - msg := core.Message{ - Version: core.ProtocolVersion, - Type: core.TypeHeartbeat, - Source: "eipc-cli", - Timestamp: time.Now().UTC(), - RequestID: "ping", - Priority: core.PriorityP0, - Payload: []byte(`{"service":"eipc-cli","status":"ping"}`), - } - - if err := ep.Send(msg); err != nil { - log.Fatalf("send ping: %v", err) - } - - _, err = ep.Receive() - elapsed := time.Since(start) - - if err != nil { - fmt.Printf("ping %s: no response (%v)\n", *addr, err) - return - } - fmt.Printf("ping %s: rtt=%v\n", *addr, elapsed) -} - -func printMessage(label string, msg core.Message) { - var indented bytes.Buffer - if err := json.Indent(&indented, msg.Payload, " ", " "); err != nil { - fmt.Printf("[%s] type=%s source=%s req=%s priority=P%d cap=%s\n payload: %s\n", - label, msg.Type, msg.Source, msg.RequestID, msg.Priority, msg.Capability, string(msg.Payload)) - return - } - - fmt.Printf("[%s] type=%s source=%s req=%s priority=P%d cap=%s\n payload: %s\n", - label, msg.Type, msg.Source, msg.RequestID, msg.Priority, msg.Capability, indented.String()) -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +// eipc-cli is a debugging command-line tool for sending and receiving EIPC messages. +// +// Usage: +// +// eipc-cli send --addr HOST:PORT --type chat --payload '{"text":"hello"}' +// eipc-cli listen --addr HOST:PORT +// eipc-cli inspect --addr HOST:PORT +package main + +import ( + "bytes" + "encoding/json" + "flag" + "fmt" + "log" + "os" + "time" + + "github.com/embeddedos-org/eipc/config" + "github.com/embeddedos-org/eipc/core" + "github.com/embeddedos-org/eipc/protocol" + "github.com/embeddedos-org/eipc/transport/tcp" +) + +func main() { + if len(os.Args) < 2 { + printUsage() + os.Exit(1) + } + + switch os.Args[1] { + case "send": + cmdSend(os.Args[2:]) + case "listen": + cmdListen(os.Args[2:]) + case "ping": + cmdPing(os.Args[2:]) + case "help": + printUsage() + default: + fmt.Fprintf(os.Stderr, "unknown command: %s\n", os.Args[1]) + printUsage() + os.Exit(1) + } +} + +func printUsage() { + fmt.Println(`eipc-cli — EIPC debugging tool + +Usage: + eipc-cli [options] + +Commands: + send Send a single message to an EIPC server + listen Connect and print incoming messages + ping Send a heartbeat and wait for response + help Show this help message + +Environment: + EIPC_HMAC_KEY Shared HMAC key (required) + EIPC_LISTEN_ADDR Default server address (optional)`) +} + +func cmdSend(args []string) { + fs := flag.NewFlagSet("send", flag.ExitOnError) + addr := fs.String("addr", config.LoadListenAddr(), "server address") + msgType := fs.String("type", "chat", "message type (intent|chat|heartbeat|ack)") + payload := fs.String("payload", `{}`, "JSON payload") + source := fs.String("source", "eipc-cli", "source service ID") + capability := fs.String("cap", "", "capability header") + _ = fs.Parse(args) + + hmacKey, err := config.LoadHMACKey() + if err != nil { + log.Fatalf("HMAC key: %v", err) + } + + t := tcp.New() + conn, err := t.Dial(*addr) + if err != nil { + log.Fatalf("dial %s: %v", *addr, err) + } + defer conn.Close() + + ep := core.NewClientEndpoint(conn, protocol.DefaultCodec(), hmacKey, "") + + mt := core.MessageType(*msgType) + msg := core.Message{ + Version: core.ProtocolVersion, + Type: mt, + Source: *source, + Timestamp: time.Now().UTC(), + RequestID: fmt.Sprintf("cli-%d", time.Now().UnixNano()), + Priority: core.PriorityP1, + Capability: *capability, + Payload: []byte(*payload), + } + + if err := ep.Send(msg); err != nil { + log.Fatalf("send: %v", err) + } + + fmt.Printf("sent type=%s to=%s size=%d bytes\n", *msgType, *addr, len(*payload)) + + resp, err := ep.Receive() + if err != nil { + fmt.Printf("no response (connection closed or timeout)\n") + return + } + + printMessage("response", resp) +} + +func cmdListen(args []string) { + fs := flag.NewFlagSet("listen", flag.ExitOnError) + addr := fs.String("addr", config.LoadListenAddr(), "server address") + count := fs.Int("count", 0, "max messages to receive (0=unlimited)") + _ = fs.Parse(args) + + hmacKey, err := config.LoadHMACKey() + if err != nil { + log.Fatalf("HMAC key: %v", err) + } + + t := tcp.New() + conn, err := t.Dial(*addr) + if err != nil { + log.Fatalf("dial %s: %v", *addr, err) + } + defer conn.Close() + + ep := core.NewClientEndpoint(conn, protocol.DefaultCodec(), hmacKey, "") + + fmt.Printf("listening on %s ...\n", *addr) + received := 0 + for { + msg, err := ep.Receive() + if err != nil { + log.Fatalf("receive: %v", err) + } + received++ + printMessage(fmt.Sprintf("msg#%d", received), msg) + + if *count > 0 && received >= *count { + break + } + } +} + +func cmdPing(args []string) { + fs := flag.NewFlagSet("ping", flag.ExitOnError) + addr := fs.String("addr", config.LoadListenAddr(), "server address") + _ = fs.Parse(args) + + hmacKey, err := config.LoadHMACKey() + if err != nil { + log.Fatalf("HMAC key: %v", err) + } + + t := tcp.New() + conn, err := t.Dial(*addr) + if err != nil { + log.Fatalf("dial %s: %v", *addr, err) + } + defer conn.Close() + + ep := core.NewClientEndpoint(conn, protocol.DefaultCodec(), hmacKey, "") + + start := time.Now() + msg := core.Message{ + Version: core.ProtocolVersion, + Type: core.TypeHeartbeat, + Source: "eipc-cli", + Timestamp: time.Now().UTC(), + RequestID: "ping", + Priority: core.PriorityP0, + Payload: []byte(`{"service":"eipc-cli","status":"ping"}`), + } + + if err := ep.Send(msg); err != nil { + log.Fatalf("send ping: %v", err) + } + + _, err = ep.Receive() + elapsed := time.Since(start) + + if err != nil { + fmt.Printf("ping %s: no response (%v)\n", *addr, err) + return + } + fmt.Printf("ping %s: rtt=%v\n", *addr, elapsed) +} + +func printMessage(label string, msg core.Message) { + var indented bytes.Buffer + if err := json.Indent(&indented, msg.Payload, " ", " "); err != nil { + fmt.Printf("[%s] type=%s source=%s req=%s priority=P%d cap=%s\n payload: %s\n", + label, msg.Type, msg.Source, msg.RequestID, msg.Priority, msg.Capability, string(msg.Payload)) + return + } + + fmt.Printf("[%s] type=%s source=%s req=%s priority=P%d cap=%s\n payload: %s\n", + label, msg.Type, msg.Source, msg.RequestID, msg.Priority, msg.Capability, indented.String()) +} diff --git a/cmd/eipc-client/main.go b/cmd/eipc-client/main.go index 84a6066..b6f9f82 100644 --- a/cmd/eipc-client/main.go +++ b/cmd/eipc-client/main.go @@ -1,235 +1,235 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package main - -import ( - "crypto/hmac" - "crypto/sha256" - "encoding/hex" - "encoding/json" - "fmt" - "log" - "os" - "time" - - "github.com/embeddedos-org/eipc/config" - "github.com/embeddedos-org/eipc/core" - "github.com/embeddedos-org/eipc/protocol" - "github.com/embeddedos-org/eipc/transport/tcp" -) - -func main() { - addr := "127.0.0.1:9090" - if len(os.Args) > 1 { - addr = os.Args[1] - } - - sharedSecret, err := config.LoadHMACKey() - if err != nil { - log.Fatalf("[CONFIG] %v", err) - } - - serviceID := "nia.min" - codec := protocol.DefaultCodec() - - log.Printf("EIPC client connecting to %s as %s", addr, serviceID) - - tcpTransport := tcp.New() - if err := tcpTransport.SetupTLSFromEnv(); err != nil { - log.Fatalf("TLS setup: %v", err) - } - - conn, err := tcpTransport.Dial(addr) - if err != nil { - log.Fatalf("dial: %v", err) - } - defer conn.Close() - - endpoint := core.NewClientEndpoint(conn, codec, sharedSecret, "") - - // Step 1: Send authentication request - log.Println("[1] Sending authentication request...") - type authRequest struct { - ServiceID string `json:"service_id"` - } - authPayload, err := codec.Marshal(authRequest{ServiceID: serviceID}) - if err != nil { - log.Fatalf("marshal auth request: %v", err) - } - - if err := endpoint.Send(core.Message{ - Version: core.ProtocolVersion, - Type: core.TypeAuth, - Source: serviceID, - Timestamp: time.Now().UTC(), - Payload: authPayload, - }); err != nil { - log.Fatalf("send auth request: %v", err) - } - - // Step 2: Receive challenge (nonce) - challengeMsg, err := endpoint.Receive() - if err != nil { - log.Fatalf("receive challenge: %v", err) - } - if challengeMsg.Type != core.TypeChallenge { - log.Fatalf("[AUTH] expected TypeChallenge, got %s", challengeMsg.Type) - } - - type challengeResponse struct { - Status string `json:"status"` - Nonce string `json:"nonce"` - Error string `json:"error,omitempty"` - } - var challenge challengeResponse - if err := json.Unmarshal(challengeMsg.Payload, &challenge); err != nil { - log.Fatalf("unmarshal challenge: %v", err) - } - - if challenge.Status == "denied" { - log.Fatalf("[AUTH] rejected: %s", challenge.Error) - } - if len(challenge.Nonce) >= 16 { - log.Printf("[2] Received challenge nonce: %s...%s", - challenge.Nonce[:8], challenge.Nonce[len(challenge.Nonce)-8:]) - } else { - log.Printf("[2] Received challenge nonce: %s", challenge.Nonce) - } - - // Step 3: Compute HMAC-SHA256(secret, nonce) and send response - nonceBytes, err := hex.DecodeString(challenge.Nonce) - if err != nil { - log.Fatalf("decode nonce: %v", err) - } - - mac := hmac.New(sha256.New, sharedSecret) - mac.Write(nonceBytes) - response := mac.Sum(nil) - - type authChallengeResponse struct { - ServiceID string `json:"service_id"` - Response string `json:"response"` - } - chalRespPayload, err := codec.Marshal(authChallengeResponse{ - ServiceID: serviceID, - Response: hex.EncodeToString(response), - }) - if err != nil { - log.Fatalf("marshal challenge response: %v", err) - } - - if err := endpoint.Send(core.Message{ - Version: core.ProtocolVersion, - Type: core.TypeAuthResponse, - Source: serviceID, - Timestamp: time.Now().UTC(), - Payload: chalRespPayload, - }); err != nil { - log.Fatalf("send auth response: %v", err) - } - - // Step 4: Receive session token - authResp, err := endpoint.Receive() - if err != nil { - log.Fatalf("receive auth response: %v", err) - } - if authResp.Type != core.TypeAuthResponse { - log.Fatalf("[AUTH] expected TypeAuthResponse, got %s", authResp.Type) - } - - type authResult struct { - Status string `json:"status"` - SessionToken string `json:"session_token"` - Capabilities []string `json:"capabilities"` - Error string `json:"error,omitempty"` - } - var authRes authResult - if err := json.Unmarshal(authResp.Payload, &authRes); err != nil { - log.Fatalf("unmarshal auth response: %v", err) - } - - sessionToken := authRes.SessionToken - if len(sessionToken) >= 16 { - log.Printf("[3] Authenticated! token=%s...%s caps=%v", - sessionToken[:8], sessionToken[len(sessionToken)-8:], authRes.Capabilities) - } else { - log.Printf("[3] Authenticated! token=%s caps=%v", sessionToken, authRes.Capabilities) - } - - // Step 5: Send HMAC-protected intent - log.Println("[4] Sending intent: move_left (confidence=0.91)") - intentPayload, err := codec.Marshal(core.IntentEvent{ - Intent: "move_left", - Confidence: 0.91, - SessionID: sessionToken, - }) - if err != nil { - log.Fatalf("marshal intent: %v", err) - } - - if err := endpoint.Send(core.Message{ - Version: core.ProtocolVersion, - Type: core.TypeIntent, - Source: serviceID, - Timestamp: time.Now().UTC(), - SessionID: sessionToken, - RequestID: "req-1", - Priority: core.PriorityP0, - Capability: "ui:control", - Payload: intentPayload, - }); err != nil { - log.Fatalf("send intent: %v", err) - } - - // Step 6: Receive ack - ackMsg, err := endpoint.Receive() - if err != nil { - log.Fatalf("receive ack: %v", err) - } - - var ack core.AckEvent - if err := json.Unmarshal(ackMsg.Payload, &ack); err != nil { - log.Fatalf("unmarshal ack: %v", err) - } - - log.Printf("[5] Received ACK: request_id=%s status=%s", ack.RequestID, ack.Status) - - // Step 7: Send heartbeat - log.Println("[6] Sending heartbeat...") - hbPayload, err := codec.Marshal(core.HeartbeatEvent{ - Service: serviceID, - Status: "ready", - }) - if err != nil { - log.Fatalf("marshal heartbeat: %v", err) - } - - if err := endpoint.Send(core.Message{ - Version: core.ProtocolVersion, - Type: core.TypeHeartbeat, - Source: serviceID, - Timestamp: time.Now().UTC(), - SessionID: sessionToken, - RequestID: "hb-1", - Priority: core.PriorityP2, - Payload: hbPayload, - }); err != nil { - log.Fatalf("send heartbeat: %v", err) - } - - time.Sleep(100 * time.Millisecond) - - endpoint.Close() - fmt.Println() - fmt.Println("=== EIPC Demo Complete (Hardened) ===") - fmt.Println("End-to-end flow demonstrated:") - fmt.Println(" 1. Client connected to server (TLS if configured)") - fmt.Println(" 2. Server sent challenge nonce") - fmt.Println(" 3. Client proved secret via HMAC-SHA256 response") - fmt.Println(" 4. Server validated & issued session token") - fmt.Println(" 5. Client sent HMAC-protected intents") - fmt.Println(" 6. Server enforced capability (ui:control)") - fmt.Println(" 7. Audit events recorded") - fmt.Println(" 8. Acks returned to client") -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package main + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "log" + "os" + "time" + + "github.com/embeddedos-org/eipc/config" + "github.com/embeddedos-org/eipc/core" + "github.com/embeddedos-org/eipc/protocol" + "github.com/embeddedos-org/eipc/transport/tcp" +) + +func main() { + addr := "127.0.0.1:9090" + if len(os.Args) > 1 { + addr = os.Args[1] + } + + sharedSecret, err := config.LoadHMACKey() + if err != nil { + log.Fatalf("[CONFIG] %v", err) + } + + serviceID := "nia.min" + codec := protocol.DefaultCodec() + + log.Printf("EIPC client connecting to %s as %s", addr, serviceID) + + tcpTransport := tcp.New() + if err := tcpTransport.SetupTLSFromEnv(); err != nil { + log.Fatalf("TLS setup: %v", err) + } + + conn, err := tcpTransport.Dial(addr) + if err != nil { + log.Fatalf("dial: %v", err) + } + defer conn.Close() + + endpoint := core.NewClientEndpoint(conn, codec, sharedSecret, "") + + // Step 1: Send authentication request + log.Println("[1] Sending authentication request...") + type authRequest struct { + ServiceID string `json:"service_id"` + } + authPayload, err := codec.Marshal(authRequest{ServiceID: serviceID}) + if err != nil { + log.Fatalf("marshal auth request: %v", err) + } + + if err := endpoint.Send(core.Message{ + Version: core.ProtocolVersion, + Type: core.TypeAuth, + Source: serviceID, + Timestamp: time.Now().UTC(), + Payload: authPayload, + }); err != nil { + log.Fatalf("send auth request: %v", err) + } + + // Step 2: Receive challenge (nonce) + challengeMsg, err := endpoint.Receive() + if err != nil { + log.Fatalf("receive challenge: %v", err) + } + if challengeMsg.Type != core.TypeChallenge { + log.Fatalf("[AUTH] expected TypeChallenge, got %s", challengeMsg.Type) + } + + type challengeResponse struct { + Status string `json:"status"` + Nonce string `json:"nonce"` + Error string `json:"error,omitempty"` + } + var challenge challengeResponse + if err := json.Unmarshal(challengeMsg.Payload, &challenge); err != nil { + log.Fatalf("unmarshal challenge: %v", err) + } + + if challenge.Status == "denied" { + log.Fatalf("[AUTH] rejected: %s", challenge.Error) + } + if len(challenge.Nonce) >= 16 { + log.Printf("[2] Received challenge nonce: %s...%s", + challenge.Nonce[:8], challenge.Nonce[len(challenge.Nonce)-8:]) + } else { + log.Printf("[2] Received challenge nonce: %s", challenge.Nonce) + } + + // Step 3: Compute HMAC-SHA256(secret, nonce) and send response + nonceBytes, err := hex.DecodeString(challenge.Nonce) + if err != nil { + log.Fatalf("decode nonce: %v", err) + } + + mac := hmac.New(sha256.New, sharedSecret) + mac.Write(nonceBytes) + response := mac.Sum(nil) + + type authChallengeResponse struct { + ServiceID string `json:"service_id"` + Response string `json:"response"` + } + chalRespPayload, err := codec.Marshal(authChallengeResponse{ + ServiceID: serviceID, + Response: hex.EncodeToString(response), + }) + if err != nil { + log.Fatalf("marshal challenge response: %v", err) + } + + if err := endpoint.Send(core.Message{ + Version: core.ProtocolVersion, + Type: core.TypeAuthResponse, + Source: serviceID, + Timestamp: time.Now().UTC(), + Payload: chalRespPayload, + }); err != nil { + log.Fatalf("send auth response: %v", err) + } + + // Step 4: Receive session token + authResp, err := endpoint.Receive() + if err != nil { + log.Fatalf("receive auth response: %v", err) + } + if authResp.Type != core.TypeAuthResponse { + log.Fatalf("[AUTH] expected TypeAuthResponse, got %s", authResp.Type) + } + + type authResult struct { + Status string `json:"status"` + SessionToken string `json:"session_token"` + Capabilities []string `json:"capabilities"` + Error string `json:"error,omitempty"` + } + var authRes authResult + if err := json.Unmarshal(authResp.Payload, &authRes); err != nil { + log.Fatalf("unmarshal auth response: %v", err) + } + + sessionToken := authRes.SessionToken + if len(sessionToken) >= 16 { + log.Printf("[3] Authenticated! token=%s...%s caps=%v", + sessionToken[:8], sessionToken[len(sessionToken)-8:], authRes.Capabilities) + } else { + log.Printf("[3] Authenticated! token=%s caps=%v", sessionToken, authRes.Capabilities) + } + + // Step 5: Send HMAC-protected intent + log.Println("[4] Sending intent: move_left (confidence=0.91)") + intentPayload, err := codec.Marshal(core.IntentEvent{ + Intent: "move_left", + Confidence: 0.91, + SessionID: sessionToken, + }) + if err != nil { + log.Fatalf("marshal intent: %v", err) + } + + if err := endpoint.Send(core.Message{ + Version: core.ProtocolVersion, + Type: core.TypeIntent, + Source: serviceID, + Timestamp: time.Now().UTC(), + SessionID: sessionToken, + RequestID: "req-1", + Priority: core.PriorityP0, + Capability: "ui:control", + Payload: intentPayload, + }); err != nil { + log.Fatalf("send intent: %v", err) + } + + // Step 6: Receive ack + ackMsg, err := endpoint.Receive() + if err != nil { + log.Fatalf("receive ack: %v", err) + } + + var ack core.AckEvent + if err := json.Unmarshal(ackMsg.Payload, &ack); err != nil { + log.Fatalf("unmarshal ack: %v", err) + } + + log.Printf("[5] Received ACK: request_id=%s status=%s", ack.RequestID, ack.Status) + + // Step 7: Send heartbeat + log.Println("[6] Sending heartbeat...") + hbPayload, err := codec.Marshal(core.HeartbeatEvent{ + Service: serviceID, + Status: "ready", + }) + if err != nil { + log.Fatalf("marshal heartbeat: %v", err) + } + + if err := endpoint.Send(core.Message{ + Version: core.ProtocolVersion, + Type: core.TypeHeartbeat, + Source: serviceID, + Timestamp: time.Now().UTC(), + SessionID: sessionToken, + RequestID: "hb-1", + Priority: core.PriorityP2, + Payload: hbPayload, + }); err != nil { + log.Fatalf("send heartbeat: %v", err) + } + + time.Sleep(100 * time.Millisecond) + + endpoint.Close() + fmt.Println() + fmt.Println("=== EIPC Demo Complete (Hardened) ===") + fmt.Println("End-to-end flow demonstrated:") + fmt.Println(" 1. Client connected to server (TLS if configured)") + fmt.Println(" 2. Server sent challenge nonce") + fmt.Println(" 3. Client proved secret via HMAC-SHA256 response") + fmt.Println(" 4. Server validated & issued session token") + fmt.Println(" 5. Client sent HMAC-protected intents") + fmt.Println(" 6. Server enforced capability (ui:control)") + fmt.Println(" 7. Audit events recorded") + fmt.Println(" 8. Acks returned to client") +} diff --git a/cmd/eipc-server/main.go b/cmd/eipc-server/main.go index e15b025..8f374e6 100644 --- a/cmd/eipc-server/main.go +++ b/cmd/eipc-server/main.go @@ -1,673 +1,673 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package main - -import ( - "crypto/hmac" - "crypto/sha256" - "encoding/hex" - "encoding/json" - "fmt" - "log" - "os" - "os/signal" - "time" - - "github.com/embeddedos-org/eipc/config" - "github.com/embeddedos-org/eipc/core" - "github.com/embeddedos-org/eipc/protocol" - "github.com/embeddedos-org/eipc/security/auth" - "github.com/embeddedos-org/eipc/security/capability" - "github.com/embeddedos-org/eipc/services/audit" - "github.com/embeddedos-org/eipc/services/health" - "github.com/embeddedos-org/eipc/services/registry" - "github.com/embeddedos-org/eipc/transport" - "github.com/embeddedos-org/eipc/transport/tcp" -) - -func main() { - addr := config.LoadListenAddr() - if len(os.Args) > 1 { - addr = os.Args[1] - } - - sharedSecret, err := config.LoadHMACKey() - if err != nil { - log.Fatalf("[CONFIG] %v", err) - } - - sessionTTL := config.LoadSessionTTL() - maxConns := config.LoadMaxConnections() - - authenticator := auth.NewAuthenticator(sharedSecret, map[string][]string{ - "nia.min": {"ui:control", "device:read"}, - "nia.framework": {"ui:control", "device:read", "device:write"}, - "ail.min.agent": {"ui:control"}, - "ail.framework": {"ui:control", "device:read", "device:write", "system:restricted"}, - "ebot.client": {"ai:chat"}, - }) - authenticator.SetSessionTTL(sessionTTL) - - capChecker := capability.NewChecker(map[string][]string{ - "ui:control": {"ui.cursor.move", "ui.click", "ui.scroll"}, - "device:read": {"device.sensor.read", "device.status"}, - "device:write": {"device.actuator.write"}, - "system:restricted": {"system.reboot", "system.update"}, - "ai:chat": {"ai.chat.send", "ai.complete.send"}, - }) - - auditLogger, err := audit.NewFileLogger("") - if err != nil { - log.Fatalf("audit logger: %v", err) - } - defer auditLogger.Close() - - healthSvc := health.NewService(5*time.Second, 15*time.Second) - - reg := registry.NewRegistry() - if err := reg.Register(registry.ServiceInfo{ - ServiceID: "eipc-server", - Capabilities: []string{"ui:control", "device:read", "device:write", "ai:chat"}, - Versions: []uint16{1}, - MessageTypes: []core.MessageType{ - core.TypeIntent, core.TypeAck, core.TypeHeartbeat, core.TypeAudit, - core.TypeChat, core.TypeComplete, core.TypeAuth, core.TypeChallenge, core.TypeAuthResponse, - }, - Priority: core.PriorityP0, - }); err != nil { - log.Printf("[REGISTRY] failed to register eipc-server: %v", err) - } - if err := reg.Register(registry.ServiceInfo{ - ServiceID: "ebot.client", - Capabilities: []string{"ai:chat"}, - Versions: []uint16{1}, - MessageTypes: []core.MessageType{core.TypeChat, core.TypeComplete, core.TypeAck, core.TypeAuth, core.TypeAuthResponse}, - Priority: core.PriorityP1, - }); err != nil { - log.Printf("[REGISTRY] failed to register ebot.client: %v", err) - } - - router := core.NewRouter() - - router.Handle(core.TypeIntent, func(msg core.Message) (*core.Message, error) { - var intent core.IntentEvent - codec := protocol.DefaultCodec() - if err := codec.Unmarshal(msg.Payload, &intent); err != nil { - return nil, fmt.Errorf("unmarshal intent: %w", err) - } - - log.Printf("[INTENT] from=%s intent=%s confidence=%.2f session=%s", - msg.Source, intent.Intent, intent.Confidence, intent.SessionID) - - if err := capChecker.Check([]string{msg.Capability}, "ui.cursor.move"); err != nil { - log.Printf("[POLICY] DENIED: %v", err) - if err := auditLogger.Log(audit.Entry{ - RequestID: msg.RequestID, - Source: msg.Source, - Target: "eipc-server", - Action: intent.Intent, - Decision: "denied", - Result: err.Error(), - }); err != nil { - log.Printf("[AUDIT] failed: %v", err) - } - return nil, err - } - - log.Printf("[POLICY] ALLOWED: capability=%s action=%s", msg.Capability, intent.Intent) - - if err := auditLogger.Log(audit.Entry{ - RequestID: msg.RequestID, - Source: msg.Source, - Target: "eipc-server", - Action: intent.Intent, - Decision: "allowed", - Result: "success", - }); err != nil { - log.Printf("[AUDIT] failed: %v", err) - } - - ackPayload, err := codec.Marshal(core.AckEvent{ - RequestID: msg.RequestID, - Status: "ok", - }) - if err != nil { - return nil, fmt.Errorf("marshal ack: %w", err) - } - - ack := core.Message{ - Version: core.ProtocolVersion, - Type: core.TypeAck, - Source: "eipc-server", - Timestamp: time.Now().UTC(), - SessionID: msg.SessionID, - RequestID: msg.RequestID, - Priority: core.PriorityP0, - Payload: ackPayload, - } - return &ack, nil - }) - - router.Handle(core.TypeHeartbeat, func(msg core.Message) (*core.Message, error) { - var hb core.HeartbeatEvent - codec := protocol.DefaultCodec() - if err := codec.Unmarshal(msg.Payload, &hb); err != nil { - return nil, err - } - healthSvc.RecordHeartbeat(hb.Service, hb.Status) - log.Printf("[HEARTBEAT] service=%s status=%s", hb.Service, hb.Status) - return nil, nil - }) - - router.Handle(core.TypeChat, func(msg core.Message) (*core.Message, error) { - var chatReq core.ChatRequestEvent - codec := protocol.DefaultCodec() - if err := codec.Unmarshal(msg.Payload, &chatReq); err != nil { - return nil, fmt.Errorf("unmarshal chat request: %w", err) - } - - if err := capChecker.Check([]string{msg.Capability}, "ai.chat.send"); err != nil { - log.Printf("[POLICY] DENIED chat: %v", err) - if err := auditLogger.Log(audit.Entry{ - RequestID: msg.RequestID, - Source: msg.Source, - Target: "eipc-server", - Action: "ai.chat.send", - Decision: "denied", - Result: err.Error(), - }); err != nil { - log.Printf("[AUDIT] failed: %v", err) - } - return nil, err - } - - log.Printf("[CHAT] from=%s session=%s prompt=%q", - msg.Source, chatReq.SessionID, chatReq.UserPrompt) - - if err := auditLogger.Log(audit.Entry{ - RequestID: msg.RequestID, - Source: msg.Source, - Target: "eai", - Action: "ai.chat.send", - Decision: "allowed", - Result: "forwarded", - }); err != nil { - log.Printf("[AUDIT] failed: %v", err) - } - - // TODO: Forward to EAI agent loop. For now, echo acknowledgment. - chatResp := core.ChatResponseEvent{ - SessionID: chatReq.SessionID, - Response: fmt.Sprintf("[EIPC] Chat received: %s", chatReq.UserPrompt), - Model: chatReq.Model, - TokensUsed: 0, - } - respPayload, err := codec.Marshal(chatResp) - if err != nil { - return nil, fmt.Errorf("marshal chat response: %w", err) - } - - return &core.Message{ - Version: core.ProtocolVersion, - Type: core.TypeChat, - Source: "eipc-server", - Timestamp: time.Now().UTC(), - SessionID: msg.SessionID, - RequestID: msg.RequestID, - Priority: core.PriorityP1, - Payload: respPayload, - }, nil - }) - - router.Handle(core.TypeComplete, func(msg core.Message) (*core.Message, error) { - var completeReq core.CompleteRequestEvent - codec := protocol.DefaultCodec() - if err := codec.Unmarshal(msg.Payload, &completeReq); err != nil { - return nil, fmt.Errorf("unmarshal complete request: %w", err) - } - - if err := capChecker.Check([]string{msg.Capability}, "ai.complete.send"); err != nil { - log.Printf("[POLICY] DENIED complete: %v", err) - if err := auditLogger.Log(audit.Entry{ - RequestID: msg.RequestID, - Source: msg.Source, - Target: "eipc-server", - Action: "ai.complete.send", - Decision: "denied", - Result: err.Error(), - }); err != nil { - log.Printf("[AUDIT] failed: %v", err) - } - return nil, err - } - - log.Printf("[COMPLETE] from=%s session=%s prompt=%q", - msg.Source, completeReq.SessionID, completeReq.Prompt) - - if err := auditLogger.Log(audit.Entry{ - RequestID: msg.RequestID, - Source: msg.Source, - Target: "eai", - Action: "ai.complete.send", - Decision: "allowed", - Result: "forwarded", - }); err != nil { - log.Printf("[AUDIT] failed: %v", err) - } - - completeResp := core.CompleteResponseEvent{ - SessionID: completeReq.SessionID, - Completion: fmt.Sprintf("[EIPC] Completion received: %s", completeReq.Prompt), - Model: completeReq.Model, - TokensUsed: 0, - } - respPayload, err := codec.Marshal(completeResp) - if err != nil { - return nil, fmt.Errorf("marshal complete response: %w", err) - } - - return &core.Message{ - Version: core.ProtocolVersion, - Type: core.TypeComplete, - Source: "eipc-server", - Timestamp: time.Now().UTC(), - SessionID: msg.SessionID, - RequestID: msg.RequestID, - Priority: core.PriorityP1, - Payload: respPayload, - }, nil - }) - - tcpTransport := tcp.New() - if err := tcpTransport.SetupTLSFromEnv(); err != nil { - log.Fatalf("TLS setup: %v", err) - } - if err := tcpTransport.Listen(addr); err != nil { - log.Fatalf("listen: %v", err) - } - defer tcpTransport.Close() - - tlsMode := "plaintext" - if config.TLSEnabled() { - tlsMode = "TLS" - } - log.Printf("EIPC server listening on %s [%s] (max_conns=%d, session_ttl=%s)", - tcpTransport.Addr(), tlsMode, maxConns, sessionTTL) - - // Connection limit semaphore - connSem := make(chan struct{}, maxConns) - - // Background session cleanup goroutine - go func() { - ticker := time.NewTicker(5 * time.Minute) - defer ticker.Stop() - for range ticker.C { - removed := authenticator.CleanupExpired() - if removed > 0 { - log.Printf("[SESSION] cleaned up %d expired sessions", removed) - } - } - }() - - go func() { - sigCh := make(chan os.Signal, 1) - signal.Notify(sigCh, os.Interrupt) - <-sigCh - log.Println("Shutting down...") - tcpTransport.Close() - os.Exit(0) - }() - - codec := protocol.DefaultCodec() - - for { - conn, err := tcpTransport.Accept() - if err != nil { - log.Printf("accept error: %v", err) - return - } - - select { - case connSem <- struct{}{}: - go func() { - defer func() { <-connSem }() - handleConnection(conn, authenticator, codec, sharedSecret, router, auditLogger, capChecker) - }() - default: - log.Printf("[CONN] rejected connection from %s: max connections (%d) reached", conn.RemoteAddr(), maxConns) - if err := auditLogger.Log(audit.Entry{ - Source: conn.RemoteAddr(), - Target: "eipc-server", - Action: "connect", - Decision: "denied", - Result: "connection limit exceeded", - }); err != nil { - log.Printf("[AUDIT] failed: %v", err) - } - conn.Close() - } - } -} - -func handleConnection( - conn transport.Connection, - authenticator *auth.Authenticator, - codec protocol.Codec, - hmacKey []byte, - router *core.Router, - auditLogger audit.Logger, - capChecker *capability.Checker, -) { - defer conn.Close() - log.Printf("[CONN] new connection from %s", conn.RemoteAddr()) - - endpoint := core.NewServerEndpoint(conn, codec, hmacKey) - - // Auth timeout: 10s - authDone := make(chan struct{}) - go func() { - select { - case <-authDone: - case <-time.After(10 * time.Second): - log.Printf("[AUTH] timeout waiting for auth from %s", conn.RemoteAddr()) - conn.Close() - } - }() - - // Step 1: Receive auth request - authMsg, err := endpoint.Receive() - if err != nil { - log.Printf("[AUTH] failed to receive auth message: %v", err) - close(authDone) - return - } - if authMsg.Type != core.TypeAuth { - log.Printf("[AUTH] expected TypeAuth, got %s", authMsg.Type) - close(authDone) - return - } - - type authRequest struct { - ServiceID string `json:"service_id"` - } - var authReq authRequest - if err := json.Unmarshal(authMsg.Payload, &authReq); err != nil { - log.Printf("[AUTH] bad auth payload: %v", err) - close(authDone) - return - } - - // Step 2: Create challenge (send nonce) - challenge, err := authenticator.CreateChallenge(authReq.ServiceID) - if err != nil { - log.Printf("[AUTH] REJECTED: %v", err) - if err := auditLogger.Log(audit.Entry{ - RequestID: authMsg.RequestID, - Source: authReq.ServiceID, - Target: "eipc-server", - Action: "authenticate", - Decision: "denied", - Result: err.Error(), - }); err != nil { - log.Printf("[AUDIT] failed: %v", err) - } - type authResponse struct { - Status string `json:"status"` - Error string `json:"error,omitempty"` - } - respPayload, err := codec.Marshal(authResponse{Status: "denied", Error: err.Error()}) - if err != nil { - log.Printf("[AUTH] failed to marshal auth response: %v", err) - } else { - if err := endpoint.Send(core.Message{ - Version: core.ProtocolVersion, - Type: core.TypeAuthResponse, - Source: "eipc-server", - Timestamp: time.Now().UTC(), - RequestID: authMsg.RequestID, - Payload: respPayload, - }); err != nil { - log.Printf("[AUTH] failed to send auth response: %v", err) - } - } - close(authDone) - return - } - - type challengeMessage struct { - Status string `json:"status"` - Nonce string `json:"nonce"` - } - challengePayload, err := codec.Marshal(challengeMessage{ - Status: "challenge", - Nonce: hex.EncodeToString(challenge.Nonce), - }) - if err != nil { - log.Printf("[AUTH] failed to marshal challenge: %v", err) - close(authDone) - return - } - if err := endpoint.Send(core.Message{ - Version: core.ProtocolVersion, - Type: core.TypeChallenge, - Source: "eipc-server", - Timestamp: time.Now().UTC(), - RequestID: authMsg.RequestID, - Payload: challengePayload, - }); err != nil { - log.Printf("[AUTH] failed to send challenge: %v", err) - close(authDone) - return - } - - // Step 3: Receive HMAC response - responseMsg, err := endpoint.Receive() - if err != nil { - log.Printf("[AUTH] failed to receive challenge response: %v", err) - close(authDone) - return - } - if responseMsg.Type != core.TypeAuthResponse { - log.Printf("[AUTH] expected TypeAuthResponse, got %s", responseMsg.Type) - close(authDone) - return - } - - type challengeResponse struct { - ServiceID string `json:"service_id"` - Response string `json:"response"` - } - var chalResp challengeResponse - if err := json.Unmarshal(responseMsg.Payload, &chalResp); err != nil { - log.Printf("[AUTH] bad challenge response: %v", err) - close(authDone) - return - } - - responseBytes, err := hex.DecodeString(chalResp.Response) - if err != nil { - log.Printf("[AUTH] bad response encoding: %v", err) - close(authDone) - return - } - - // Step 4: Verify response - peer, err := authenticator.VerifyResponse(authReq.ServiceID, responseBytes) - if err != nil { - log.Printf("[AUTH] REJECTED (challenge-response): %v", err) - if err := auditLogger.Log(audit.Entry{ - RequestID: authMsg.RequestID, - Source: authReq.ServiceID, - Target: "eipc-server", - Action: "authenticate", - Decision: "denied", - Result: "challenge-response failed", - }); err != nil { - log.Printf("[AUDIT] failed: %v", err) - } - type authResponse struct { - Status string `json:"status"` - Error string `json:"error,omitempty"` - } - respPayload, err := codec.Marshal(authResponse{Status: "denied", Error: err.Error()}) - if err != nil { - log.Printf("[AUTH] failed to marshal auth response: %v", err) - } else { - if err := endpoint.Send(core.Message{ - Version: core.ProtocolVersion, - Type: core.TypeAuthResponse, - Source: "eipc-server", - Timestamp: time.Now().UTC(), - RequestID: authMsg.RequestID, - Payload: respPayload, - }); err != nil { - log.Printf("[AUTH] failed to send auth response: %v", err) - } - } - close(authDone) - return - } - - close(authDone) // Auth completed successfully - - log.Printf("[AUTH] ACCEPTED: service=%s token=%s...%s caps=%v", - peer.ServiceID, peer.SessionToken[:8], peer.SessionToken[len(peer.SessionToken)-8:], peer.Capabilities) - - if err := auditLogger.Log(audit.Entry{ - RequestID: authMsg.RequestID, - Source: peer.ServiceID, - Target: "eipc-server", - Action: "authenticate", - Decision: "allowed", - Result: "session created", - }); err != nil { - log.Printf("[AUDIT] failed: %v", err) - } - - // Set peer capabilities on the endpoint for validation - endpoint.SetPeerCapabilities(peer.Capabilities) - - type authResult struct { - Status string `json:"status"` - SessionToken string `json:"session_token"` - Capabilities []string `json:"capabilities"` - } - respPayload, err := codec.Marshal(authResult{ - Status: "ok", - SessionToken: peer.SessionToken, - Capabilities: peer.Capabilities, - }) - if err != nil { - log.Printf("[AUTH] failed to marshal auth result: %v", err) - return - } - if err := endpoint.Send(core.Message{ - Version: core.ProtocolVersion, - Type: core.TypeAuthResponse, - Source: "eipc-server", - Timestamp: time.Now().UTC(), - RequestID: authMsg.RequestID, - Payload: respPayload, - }); err != nil { - log.Printf("[AUTH] failed to send auth response: %v", err) - return - } - - // Message loop with idle timeout and capability enforcement - for { - msg, err := endpoint.Receive() - if err != nil { - log.Printf("[CONN] connection closed: %v", err) - return - } - - // Check session TTL - if peer.IsExpired() { - log.Printf("[SESSION] expired for %s", peer.ServiceID) - if err := auditLogger.Log(audit.Entry{ - Source: peer.ServiceID, - Target: "eipc-server", - Action: "session_check", - Decision: "denied", - Result: "session expired", - }); err != nil { - log.Printf("[AUDIT] failed: %v", err) - } - return - } - - // Enforce capability binding - if err := endpoint.ValidateCapability(msg.Capability); err != nil { - log.Printf("[CAPABILITY] DENIED: %s tried %s", peer.ServiceID, msg.Capability) - if err := auditLogger.Log(audit.Entry{ - RequestID: msg.RequestID, - Source: peer.ServiceID, - Target: "eipc-server", - Action: msg.Capability, - Decision: "denied", - Result: "capability violation", - }); err != nil { - log.Printf("[AUDIT] failed: %v", err) - } - errPayload, err := codec.Marshal(core.AckEvent{ - RequestID: msg.RequestID, - Status: "error", - Error: err.Error(), - }) - if err != nil { - log.Printf("[CAPABILITY] failed to marshal error: %v", err) - } else { - if err := endpoint.Send(core.Message{ - Version: core.ProtocolVersion, - Type: core.TypeAck, - Source: "eipc-server", - Timestamp: time.Now().UTC(), - SessionID: msg.SessionID, - RequestID: msg.RequestID, - Priority: core.PriorityP0, - Payload: errPayload, - }); err != nil { - log.Printf("[CAPABILITY] failed to send error: %v", err) - } - } - continue - } - - resp, err := router.Dispatch(msg) - if err != nil { - log.Printf("[DISPATCH] error: %v", err) - errPayload, err := codec.Marshal(core.AckEvent{ - RequestID: msg.RequestID, - Status: "error", - Error: err.Error(), - }) - if err != nil { - log.Printf("[DISPATCH] failed to marshal error: %v", err) - } else { - if err := endpoint.Send(core.Message{ - Version: core.ProtocolVersion, - Type: core.TypeAck, - Source: "eipc-server", - Timestamp: time.Now().UTC(), - SessionID: msg.SessionID, - RequestID: msg.RequestID, - Priority: core.PriorityP0, - Payload: errPayload, - }); err != nil { - log.Printf("[DISPATCH] failed to send error: %v", err) - } - } - continue - } - - if resp != nil { - if err := endpoint.Send(*resp); err != nil { - log.Printf("[SEND] error: %v", err) - return - } - } - } -} - -// computeChallengeResponse computes HMAC-SHA256(secret, nonce) for client-side auth. +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package main + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "fmt" + "log" + "os" + "os/signal" + "time" + + "github.com/embeddedos-org/eipc/config" + "github.com/embeddedos-org/eipc/core" + "github.com/embeddedos-org/eipc/protocol" + "github.com/embeddedos-org/eipc/security/auth" + "github.com/embeddedos-org/eipc/security/capability" + "github.com/embeddedos-org/eipc/services/audit" + "github.com/embeddedos-org/eipc/services/health" + "github.com/embeddedos-org/eipc/services/registry" + "github.com/embeddedos-org/eipc/transport" + "github.com/embeddedos-org/eipc/transport/tcp" +) + +func main() { + addr := config.LoadListenAddr() + if len(os.Args) > 1 { + addr = os.Args[1] + } + + sharedSecret, err := config.LoadHMACKey() + if err != nil { + log.Fatalf("[CONFIG] %v", err) + } + + sessionTTL := config.LoadSessionTTL() + maxConns := config.LoadMaxConnections() + + authenticator := auth.NewAuthenticator(sharedSecret, map[string][]string{ + "nia.min": {"ui:control", "device:read"}, + "nia.framework": {"ui:control", "device:read", "device:write"}, + "ail.min.agent": {"ui:control"}, + "ail.framework": {"ui:control", "device:read", "device:write", "system:restricted"}, + "ebot.client": {"ai:chat"}, + }) + authenticator.SetSessionTTL(sessionTTL) + + capChecker := capability.NewChecker(map[string][]string{ + "ui:control": {"ui.cursor.move", "ui.click", "ui.scroll"}, + "device:read": {"device.sensor.read", "device.status"}, + "device:write": {"device.actuator.write"}, + "system:restricted": {"system.reboot", "system.update"}, + "ai:chat": {"ai.chat.send", "ai.complete.send"}, + }) + + auditLogger, err := audit.NewFileLogger("") + if err != nil { + log.Fatalf("audit logger: %v", err) + } + defer auditLogger.Close() + + healthSvc := health.NewService(5*time.Second, 15*time.Second) + + reg := registry.NewRegistry() + if err := reg.Register(registry.ServiceInfo{ + ServiceID: "eipc-server", + Capabilities: []string{"ui:control", "device:read", "device:write", "ai:chat"}, + Versions: []uint16{1}, + MessageTypes: []core.MessageType{ + core.TypeIntent, core.TypeAck, core.TypeHeartbeat, core.TypeAudit, + core.TypeChat, core.TypeComplete, core.TypeAuth, core.TypeChallenge, core.TypeAuthResponse, + }, + Priority: core.PriorityP0, + }); err != nil { + log.Printf("[REGISTRY] failed to register eipc-server: %v", err) + } + if err := reg.Register(registry.ServiceInfo{ + ServiceID: "ebot.client", + Capabilities: []string{"ai:chat"}, + Versions: []uint16{1}, + MessageTypes: []core.MessageType{core.TypeChat, core.TypeComplete, core.TypeAck, core.TypeAuth, core.TypeAuthResponse}, + Priority: core.PriorityP1, + }); err != nil { + log.Printf("[REGISTRY] failed to register ebot.client: %v", err) + } + + router := core.NewRouter() + + router.Handle(core.TypeIntent, func(msg core.Message) (*core.Message, error) { + var intent core.IntentEvent + codec := protocol.DefaultCodec() + if err := codec.Unmarshal(msg.Payload, &intent); err != nil { + return nil, fmt.Errorf("unmarshal intent: %w", err) + } + + log.Printf("[INTENT] from=%s intent=%s confidence=%.2f session=%s", + msg.Source, intent.Intent, intent.Confidence, intent.SessionID) + + if err := capChecker.Check([]string{msg.Capability}, "ui.cursor.move"); err != nil { + log.Printf("[POLICY] DENIED: %v", err) + if err := auditLogger.Log(audit.Entry{ + RequestID: msg.RequestID, + Source: msg.Source, + Target: "eipc-server", + Action: intent.Intent, + Decision: "denied", + Result: err.Error(), + }); err != nil { + log.Printf("[AUDIT] failed: %v", err) + } + return nil, err + } + + log.Printf("[POLICY] ALLOWED: capability=%s action=%s", msg.Capability, intent.Intent) + + if err := auditLogger.Log(audit.Entry{ + RequestID: msg.RequestID, + Source: msg.Source, + Target: "eipc-server", + Action: intent.Intent, + Decision: "allowed", + Result: "success", + }); err != nil { + log.Printf("[AUDIT] failed: %v", err) + } + + ackPayload, err := codec.Marshal(core.AckEvent{ + RequestID: msg.RequestID, + Status: "ok", + }) + if err != nil { + return nil, fmt.Errorf("marshal ack: %w", err) + } + + ack := core.Message{ + Version: core.ProtocolVersion, + Type: core.TypeAck, + Source: "eipc-server", + Timestamp: time.Now().UTC(), + SessionID: msg.SessionID, + RequestID: msg.RequestID, + Priority: core.PriorityP0, + Payload: ackPayload, + } + return &ack, nil + }) + + router.Handle(core.TypeHeartbeat, func(msg core.Message) (*core.Message, error) { + var hb core.HeartbeatEvent + codec := protocol.DefaultCodec() + if err := codec.Unmarshal(msg.Payload, &hb); err != nil { + return nil, err + } + healthSvc.RecordHeartbeat(hb.Service, hb.Status) + log.Printf("[HEARTBEAT] service=%s status=%s", hb.Service, hb.Status) + return nil, nil + }) + + router.Handle(core.TypeChat, func(msg core.Message) (*core.Message, error) { + var chatReq core.ChatRequestEvent + codec := protocol.DefaultCodec() + if err := codec.Unmarshal(msg.Payload, &chatReq); err != nil { + return nil, fmt.Errorf("unmarshal chat request: %w", err) + } + + if err := capChecker.Check([]string{msg.Capability}, "ai.chat.send"); err != nil { + log.Printf("[POLICY] DENIED chat: %v", err) + if err := auditLogger.Log(audit.Entry{ + RequestID: msg.RequestID, + Source: msg.Source, + Target: "eipc-server", + Action: "ai.chat.send", + Decision: "denied", + Result: err.Error(), + }); err != nil { + log.Printf("[AUDIT] failed: %v", err) + } + return nil, err + } + + log.Printf("[CHAT] from=%s session=%s prompt=%q", + msg.Source, chatReq.SessionID, chatReq.UserPrompt) + + if err := auditLogger.Log(audit.Entry{ + RequestID: msg.RequestID, + Source: msg.Source, + Target: "eai", + Action: "ai.chat.send", + Decision: "allowed", + Result: "forwarded", + }); err != nil { + log.Printf("[AUDIT] failed: %v", err) + } + + // TODO: Forward to EAI agent loop. For now, echo acknowledgment. + chatResp := core.ChatResponseEvent{ + SessionID: chatReq.SessionID, + Response: fmt.Sprintf("[EIPC] Chat received: %s", chatReq.UserPrompt), + Model: chatReq.Model, + TokensUsed: 0, + } + respPayload, err := codec.Marshal(chatResp) + if err != nil { + return nil, fmt.Errorf("marshal chat response: %w", err) + } + + return &core.Message{ + Version: core.ProtocolVersion, + Type: core.TypeChat, + Source: "eipc-server", + Timestamp: time.Now().UTC(), + SessionID: msg.SessionID, + RequestID: msg.RequestID, + Priority: core.PriorityP1, + Payload: respPayload, + }, nil + }) + + router.Handle(core.TypeComplete, func(msg core.Message) (*core.Message, error) { + var completeReq core.CompleteRequestEvent + codec := protocol.DefaultCodec() + if err := codec.Unmarshal(msg.Payload, &completeReq); err != nil { + return nil, fmt.Errorf("unmarshal complete request: %w", err) + } + + if err := capChecker.Check([]string{msg.Capability}, "ai.complete.send"); err != nil { + log.Printf("[POLICY] DENIED complete: %v", err) + if err := auditLogger.Log(audit.Entry{ + RequestID: msg.RequestID, + Source: msg.Source, + Target: "eipc-server", + Action: "ai.complete.send", + Decision: "denied", + Result: err.Error(), + }); err != nil { + log.Printf("[AUDIT] failed: %v", err) + } + return nil, err + } + + log.Printf("[COMPLETE] from=%s session=%s prompt=%q", + msg.Source, completeReq.SessionID, completeReq.Prompt) + + if err := auditLogger.Log(audit.Entry{ + RequestID: msg.RequestID, + Source: msg.Source, + Target: "eai", + Action: "ai.complete.send", + Decision: "allowed", + Result: "forwarded", + }); err != nil { + log.Printf("[AUDIT] failed: %v", err) + } + + completeResp := core.CompleteResponseEvent{ + SessionID: completeReq.SessionID, + Completion: fmt.Sprintf("[EIPC] Completion received: %s", completeReq.Prompt), + Model: completeReq.Model, + TokensUsed: 0, + } + respPayload, err := codec.Marshal(completeResp) + if err != nil { + return nil, fmt.Errorf("marshal complete response: %w", err) + } + + return &core.Message{ + Version: core.ProtocolVersion, + Type: core.TypeComplete, + Source: "eipc-server", + Timestamp: time.Now().UTC(), + SessionID: msg.SessionID, + RequestID: msg.RequestID, + Priority: core.PriorityP1, + Payload: respPayload, + }, nil + }) + + tcpTransport := tcp.New() + if err := tcpTransport.SetupTLSFromEnv(); err != nil { + log.Fatalf("TLS setup: %v", err) + } + if err := tcpTransport.Listen(addr); err != nil { + log.Fatalf("listen: %v", err) + } + defer tcpTransport.Close() + + tlsMode := "plaintext" + if config.TLSEnabled() { + tlsMode = "TLS" + } + log.Printf("EIPC server listening on %s [%s] (max_conns=%d, session_ttl=%s)", + tcpTransport.Addr(), tlsMode, maxConns, sessionTTL) + + // Connection limit semaphore + connSem := make(chan struct{}, maxConns) + + // Background session cleanup goroutine + go func() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + for range ticker.C { + removed := authenticator.CleanupExpired() + if removed > 0 { + log.Printf("[SESSION] cleaned up %d expired sessions", removed) + } + } + }() + + go func() { + sigCh := make(chan os.Signal, 1) + signal.Notify(sigCh, os.Interrupt) + <-sigCh + log.Println("Shutting down...") + tcpTransport.Close() + os.Exit(0) + }() + + codec := protocol.DefaultCodec() + + for { + conn, err := tcpTransport.Accept() + if err != nil { + log.Printf("accept error: %v", err) + return + } + + select { + case connSem <- struct{}{}: + go func() { + defer func() { <-connSem }() + handleConnection(conn, authenticator, codec, sharedSecret, router, auditLogger, capChecker) + }() + default: + log.Printf("[CONN] rejected connection from %s: max connections (%d) reached", conn.RemoteAddr(), maxConns) + if err := auditLogger.Log(audit.Entry{ + Source: conn.RemoteAddr(), + Target: "eipc-server", + Action: "connect", + Decision: "denied", + Result: "connection limit exceeded", + }); err != nil { + log.Printf("[AUDIT] failed: %v", err) + } + conn.Close() + } + } +} + +func handleConnection( + conn transport.Connection, + authenticator *auth.Authenticator, + codec protocol.Codec, + hmacKey []byte, + router *core.Router, + auditLogger audit.Logger, + capChecker *capability.Checker, +) { + defer conn.Close() + log.Printf("[CONN] new connection from %s", conn.RemoteAddr()) + + endpoint := core.NewServerEndpoint(conn, codec, hmacKey) + + // Auth timeout: 10s + authDone := make(chan struct{}) + go func() { + select { + case <-authDone: + case <-time.After(10 * time.Second): + log.Printf("[AUTH] timeout waiting for auth from %s", conn.RemoteAddr()) + conn.Close() + } + }() + + // Step 1: Receive auth request + authMsg, err := endpoint.Receive() + if err != nil { + log.Printf("[AUTH] failed to receive auth message: %v", err) + close(authDone) + return + } + if authMsg.Type != core.TypeAuth { + log.Printf("[AUTH] expected TypeAuth, got %s", authMsg.Type) + close(authDone) + return + } + + type authRequest struct { + ServiceID string `json:"service_id"` + } + var authReq authRequest + if err := json.Unmarshal(authMsg.Payload, &authReq); err != nil { + log.Printf("[AUTH] bad auth payload: %v", err) + close(authDone) + return + } + + // Step 2: Create challenge (send nonce) + challenge, err := authenticator.CreateChallenge(authReq.ServiceID) + if err != nil { + log.Printf("[AUTH] REJECTED: %v", err) + if err := auditLogger.Log(audit.Entry{ + RequestID: authMsg.RequestID, + Source: authReq.ServiceID, + Target: "eipc-server", + Action: "authenticate", + Decision: "denied", + Result: err.Error(), + }); err != nil { + log.Printf("[AUDIT] failed: %v", err) + } + type authResponse struct { + Status string `json:"status"` + Error string `json:"error,omitempty"` + } + respPayload, err := codec.Marshal(authResponse{Status: "denied", Error: err.Error()}) + if err != nil { + log.Printf("[AUTH] failed to marshal auth response: %v", err) + } else { + if err := endpoint.Send(core.Message{ + Version: core.ProtocolVersion, + Type: core.TypeAuthResponse, + Source: "eipc-server", + Timestamp: time.Now().UTC(), + RequestID: authMsg.RequestID, + Payload: respPayload, + }); err != nil { + log.Printf("[AUTH] failed to send auth response: %v", err) + } + } + close(authDone) + return + } + + type challengeMessage struct { + Status string `json:"status"` + Nonce string `json:"nonce"` + } + challengePayload, err := codec.Marshal(challengeMessage{ + Status: "challenge", + Nonce: hex.EncodeToString(challenge.Nonce), + }) + if err != nil { + log.Printf("[AUTH] failed to marshal challenge: %v", err) + close(authDone) + return + } + if err := endpoint.Send(core.Message{ + Version: core.ProtocolVersion, + Type: core.TypeChallenge, + Source: "eipc-server", + Timestamp: time.Now().UTC(), + RequestID: authMsg.RequestID, + Payload: challengePayload, + }); err != nil { + log.Printf("[AUTH] failed to send challenge: %v", err) + close(authDone) + return + } + + // Step 3: Receive HMAC response + responseMsg, err := endpoint.Receive() + if err != nil { + log.Printf("[AUTH] failed to receive challenge response: %v", err) + close(authDone) + return + } + if responseMsg.Type != core.TypeAuthResponse { + log.Printf("[AUTH] expected TypeAuthResponse, got %s", responseMsg.Type) + close(authDone) + return + } + + type challengeResponse struct { + ServiceID string `json:"service_id"` + Response string `json:"response"` + } + var chalResp challengeResponse + if err := json.Unmarshal(responseMsg.Payload, &chalResp); err != nil { + log.Printf("[AUTH] bad challenge response: %v", err) + close(authDone) + return + } + + responseBytes, err := hex.DecodeString(chalResp.Response) + if err != nil { + log.Printf("[AUTH] bad response encoding: %v", err) + close(authDone) + return + } + + // Step 4: Verify response + peer, err := authenticator.VerifyResponse(authReq.ServiceID, responseBytes) + if err != nil { + log.Printf("[AUTH] REJECTED (challenge-response): %v", err) + if err := auditLogger.Log(audit.Entry{ + RequestID: authMsg.RequestID, + Source: authReq.ServiceID, + Target: "eipc-server", + Action: "authenticate", + Decision: "denied", + Result: "challenge-response failed", + }); err != nil { + log.Printf("[AUDIT] failed: %v", err) + } + type authResponse struct { + Status string `json:"status"` + Error string `json:"error,omitempty"` + } + respPayload, err := codec.Marshal(authResponse{Status: "denied", Error: err.Error()}) + if err != nil { + log.Printf("[AUTH] failed to marshal auth response: %v", err) + } else { + if err := endpoint.Send(core.Message{ + Version: core.ProtocolVersion, + Type: core.TypeAuthResponse, + Source: "eipc-server", + Timestamp: time.Now().UTC(), + RequestID: authMsg.RequestID, + Payload: respPayload, + }); err != nil { + log.Printf("[AUTH] failed to send auth response: %v", err) + } + } + close(authDone) + return + } + + close(authDone) // Auth completed successfully + + log.Printf("[AUTH] ACCEPTED: service=%s token=%s...%s caps=%v", + peer.ServiceID, peer.SessionToken[:8], peer.SessionToken[len(peer.SessionToken)-8:], peer.Capabilities) + + if err := auditLogger.Log(audit.Entry{ + RequestID: authMsg.RequestID, + Source: peer.ServiceID, + Target: "eipc-server", + Action: "authenticate", + Decision: "allowed", + Result: "session created", + }); err != nil { + log.Printf("[AUDIT] failed: %v", err) + } + + // Set peer capabilities on the endpoint for validation + endpoint.SetPeerCapabilities(peer.Capabilities) + + type authResult struct { + Status string `json:"status"` + SessionToken string `json:"session_token"` + Capabilities []string `json:"capabilities"` + } + respPayload, err := codec.Marshal(authResult{ + Status: "ok", + SessionToken: peer.SessionToken, + Capabilities: peer.Capabilities, + }) + if err != nil { + log.Printf("[AUTH] failed to marshal auth result: %v", err) + return + } + if err := endpoint.Send(core.Message{ + Version: core.ProtocolVersion, + Type: core.TypeAuthResponse, + Source: "eipc-server", + Timestamp: time.Now().UTC(), + RequestID: authMsg.RequestID, + Payload: respPayload, + }); err != nil { + log.Printf("[AUTH] failed to send auth response: %v", err) + return + } + + // Message loop with idle timeout and capability enforcement + for { + msg, err := endpoint.Receive() + if err != nil { + log.Printf("[CONN] connection closed: %v", err) + return + } + + // Check session TTL + if peer.IsExpired() { + log.Printf("[SESSION] expired for %s", peer.ServiceID) + if err := auditLogger.Log(audit.Entry{ + Source: peer.ServiceID, + Target: "eipc-server", + Action: "session_check", + Decision: "denied", + Result: "session expired", + }); err != nil { + log.Printf("[AUDIT] failed: %v", err) + } + return + } + + // Enforce capability binding + if err := endpoint.ValidateCapability(msg.Capability); err != nil { + log.Printf("[CAPABILITY] DENIED: %s tried %s", peer.ServiceID, msg.Capability) + if err := auditLogger.Log(audit.Entry{ + RequestID: msg.RequestID, + Source: peer.ServiceID, + Target: "eipc-server", + Action: msg.Capability, + Decision: "denied", + Result: "capability violation", + }); err != nil { + log.Printf("[AUDIT] failed: %v", err) + } + errPayload, err := codec.Marshal(core.AckEvent{ + RequestID: msg.RequestID, + Status: "error", + Error: err.Error(), + }) + if err != nil { + log.Printf("[CAPABILITY] failed to marshal error: %v", err) + } else { + if err := endpoint.Send(core.Message{ + Version: core.ProtocolVersion, + Type: core.TypeAck, + Source: "eipc-server", + Timestamp: time.Now().UTC(), + SessionID: msg.SessionID, + RequestID: msg.RequestID, + Priority: core.PriorityP0, + Payload: errPayload, + }); err != nil { + log.Printf("[CAPABILITY] failed to send error: %v", err) + } + } + continue + } + + resp, err := router.Dispatch(msg) + if err != nil { + log.Printf("[DISPATCH] error: %v", err) + errPayload, err := codec.Marshal(core.AckEvent{ + RequestID: msg.RequestID, + Status: "error", + Error: err.Error(), + }) + if err != nil { + log.Printf("[DISPATCH] failed to marshal error: %v", err) + } else { + if err := endpoint.Send(core.Message{ + Version: core.ProtocolVersion, + Type: core.TypeAck, + Source: "eipc-server", + Timestamp: time.Now().UTC(), + SessionID: msg.SessionID, + RequestID: msg.RequestID, + Priority: core.PriorityP0, + Payload: errPayload, + }); err != nil { + log.Printf("[DISPATCH] failed to send error: %v", err) + } + } + continue + } + + if resp != nil { + if err := endpoint.Send(*resp); err != nil { + log.Printf("[SEND] error: %v", err) + return + } + } + } +} + +// computeChallengeResponse computes HMAC-SHA256(secret, nonce) for client-side auth. diff --git a/config/config.go b/config/config.go index dc6f41e..db27bc3 100644 --- a/config/config.go +++ b/config/config.go @@ -1,67 +1,67 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package config - -import ( - "fmt" - "os" - "path/filepath" - "strconv" - "time" -) - -// LoadHMACKey loads the shared HMAC key from environment or file. -// Priority: EIPC_HMAC_KEY env var → EIPC_KEY_FILE env var → error. -func LoadHMACKey() ([]byte, error) { - if key := os.Getenv("EIPC_HMAC_KEY"); key != "" { - return []byte(key), nil - } - - if path := os.Getenv("EIPC_KEY_FILE"); path != "" { - path = filepath.Clean(path) - data, err := os.ReadFile(path) - if err != nil { - return nil, fmt.Errorf("read key file %q: %w", path, err) - } - return data, nil - } - - return nil, fmt.Errorf("EIPC_HMAC_KEY or EIPC_KEY_FILE must be set") -} - -// LoadSessionTTL reads the session TTL from EIPC_SESSION_TTL env var. -// Format is Go duration string (e.g. "1h", "30m"). Default: 1h. -func LoadSessionTTL() time.Duration { - if s := os.Getenv("EIPC_SESSION_TTL"); s != "" { - if d, err := time.ParseDuration(s); err == nil { - return d - } - } - return 1 * time.Hour -} - -// LoadMaxConnections reads max connections from EIPC_MAX_CONNECTIONS env var. -// Default: 64. -func LoadMaxConnections() int { - if s := os.Getenv("EIPC_MAX_CONNECTIONS"); s != "" { - if n, err := strconv.Atoi(s); err == nil && n > 0 { - return n - } - } - return 64 -} - -// LoadListenAddr reads the listen address from EIPC_LISTEN_ADDR env var. -// Default: 127.0.0.1:9090. -func LoadListenAddr() string { - if addr := os.Getenv("EIPC_LISTEN_ADDR"); addr != "" { - return addr - } - return "127.0.0.1:9090" -} - -// TLSEnabled returns true if TLS cert files are configured or auto-cert is enabled. -func TLSEnabled() bool { - return os.Getenv("EIPC_TLS_CERT") != "" || os.Getenv("EIPC_TLS_AUTO_CERT") == "true" -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package config + +import ( + "fmt" + "os" + "path/filepath" + "strconv" + "time" +) + +// LoadHMACKey loads the shared HMAC key from environment or file. +// Priority: EIPC_HMAC_KEY env var → EIPC_KEY_FILE env var → error. +func LoadHMACKey() ([]byte, error) { + if key := os.Getenv("EIPC_HMAC_KEY"); key != "" { + return []byte(key), nil + } + + if path := os.Getenv("EIPC_KEY_FILE"); path != "" { + path = filepath.Clean(path) + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("read key file %q: %w", path, err) + } + return data, nil + } + + return nil, fmt.Errorf("EIPC_HMAC_KEY or EIPC_KEY_FILE must be set") +} + +// LoadSessionTTL reads the session TTL from EIPC_SESSION_TTL env var. +// Format is Go duration string (e.g. "1h", "30m"). Default: 1h. +func LoadSessionTTL() time.Duration { + if s := os.Getenv("EIPC_SESSION_TTL"); s != "" { + if d, err := time.ParseDuration(s); err == nil { + return d + } + } + return 1 * time.Hour +} + +// LoadMaxConnections reads max connections from EIPC_MAX_CONNECTIONS env var. +// Default: 64. +func LoadMaxConnections() int { + if s := os.Getenv("EIPC_MAX_CONNECTIONS"); s != "" { + if n, err := strconv.Atoi(s); err == nil && n > 0 { + return n + } + } + return 64 +} + +// LoadListenAddr reads the listen address from EIPC_LISTEN_ADDR env var. +// Default: 127.0.0.1:9090. +func LoadListenAddr() string { + if addr := os.Getenv("EIPC_LISTEN_ADDR"); addr != "" { + return addr + } + return "127.0.0.1:9090" +} + +// TLSEnabled returns true if TLS cert files are configured or auto-cert is enabled. +func TLSEnabled() bool { + return os.Getenv("EIPC_TLS_CERT") != "" || os.Getenv("EIPC_TLS_AUTO_CERT") == "true" +} diff --git a/config/config_test.go b/config/config_test.go index 1890f1b..d89fc03 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -1,162 +1,162 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package config - -import ( - "os" - "path/filepath" - "testing" - "time" -) - -func TestLoadHMACKey_EnvVar(t *testing.T) { - os.Setenv("EIPC_HMAC_KEY", "test-secret-key") - defer os.Unsetenv("EIPC_HMAC_KEY") - os.Unsetenv("EIPC_KEY_FILE") - - key, err := LoadHMACKey() - if err != nil { - t.Fatalf("LoadHMACKey: %v", err) - } - if string(key) != "test-secret-key" { - t.Errorf("expected 'test-secret-key', got %q", string(key)) - } -} - -func TestLoadHMACKey_File(t *testing.T) { - os.Unsetenv("EIPC_HMAC_KEY") - - tmpFile := filepath.Join(t.TempDir(), "hmac.key") - _ = os.WriteFile(tmpFile, []byte("file-based-key"), 0600) - - os.Setenv("EIPC_KEY_FILE", tmpFile) - defer os.Unsetenv("EIPC_KEY_FILE") - - key, err := LoadHMACKey() - if err != nil { - t.Fatalf("LoadHMACKey from file: %v", err) - } - if string(key) != "file-based-key" { - t.Errorf("expected 'file-based-key', got %q", string(key)) - } -} - -func TestLoadHMACKey_Missing(t *testing.T) { - os.Unsetenv("EIPC_HMAC_KEY") - os.Unsetenv("EIPC_KEY_FILE") - - _, err := LoadHMACKey() - if err == nil { - t.Fatal("expected error when no key source configured") - } -} - -func TestLoadHMACKey_BadFile(t *testing.T) { - os.Unsetenv("EIPC_HMAC_KEY") - os.Setenv("EIPC_KEY_FILE", "/nonexistent/path/key.dat") - defer os.Unsetenv("EIPC_KEY_FILE") - - _, err := LoadHMACKey() - if err == nil { - t.Fatal("expected error for nonexistent key file") - } -} - -func TestLoadSessionTTL_Default(t *testing.T) { - os.Unsetenv("EIPC_SESSION_TTL") - ttl := LoadSessionTTL() - if ttl != 1*time.Hour { - t.Errorf("expected default 1h, got %v", ttl) - } -} - -func TestLoadSessionTTL_Custom(t *testing.T) { - os.Setenv("EIPC_SESSION_TTL", "30m") - defer os.Unsetenv("EIPC_SESSION_TTL") - - ttl := LoadSessionTTL() - if ttl != 30*time.Minute { - t.Errorf("expected 30m, got %v", ttl) - } -} - -func TestLoadSessionTTL_Invalid(t *testing.T) { - os.Setenv("EIPC_SESSION_TTL", "not-a-duration") - defer os.Unsetenv("EIPC_SESSION_TTL") - - ttl := LoadSessionTTL() - if ttl != 1*time.Hour { - t.Errorf("expected default 1h for invalid input, got %v", ttl) - } -} - -func TestLoadMaxConnections_Default(t *testing.T) { - os.Unsetenv("EIPC_MAX_CONNECTIONS") - max := LoadMaxConnections() - if max != 64 { - t.Errorf("expected default 64, got %d", max) - } -} - -func TestLoadMaxConnections_Custom(t *testing.T) { - os.Setenv("EIPC_MAX_CONNECTIONS", "128") - defer os.Unsetenv("EIPC_MAX_CONNECTIONS") - - max := LoadMaxConnections() - if max != 128 { - t.Errorf("expected 128, got %d", max) - } -} - -func TestLoadMaxConnections_Invalid(t *testing.T) { - os.Setenv("EIPC_MAX_CONNECTIONS", "abc") - defer os.Unsetenv("EIPC_MAX_CONNECTIONS") - - max := LoadMaxConnections() - if max != 64 { - t.Errorf("expected default 64 for invalid input, got %d", max) - } -} - -func TestLoadListenAddr_Default(t *testing.T) { - os.Unsetenv("EIPC_LISTEN_ADDR") - addr := LoadListenAddr() - if addr != "127.0.0.1:9090" { - t.Errorf("expected default '127.0.0.1:9090', got %q", addr) - } -} - -func TestLoadListenAddr_Custom(t *testing.T) { - os.Setenv("EIPC_LISTEN_ADDR", "0.0.0.0:8080") - defer os.Unsetenv("EIPC_LISTEN_ADDR") - - addr := LoadListenAddr() - if addr != "0.0.0.0:8080" { - t.Errorf("expected '0.0.0.0:8080', got %q", addr) - } -} - -func TestTLSEnabled(t *testing.T) { - os.Unsetenv("EIPC_TLS_CERT") - os.Unsetenv("EIPC_TLS_AUTO_CERT") - if TLSEnabled() { - t.Error("TLS should be disabled by default") - } - - os.Setenv("EIPC_TLS_CERT", "/path/to/cert.pem") - defer os.Unsetenv("EIPC_TLS_CERT") - if !TLSEnabled() { - t.Error("TLS should be enabled with EIPC_TLS_CERT set") - } -} - -func TestTLSEnabled_AutoCert(t *testing.T) { - os.Unsetenv("EIPC_TLS_CERT") - os.Setenv("EIPC_TLS_AUTO_CERT", "true") - defer os.Unsetenv("EIPC_TLS_AUTO_CERT") - - if !TLSEnabled() { - t.Error("TLS should be enabled with EIPC_TLS_AUTO_CERT=true") - } -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package config + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestLoadHMACKey_EnvVar(t *testing.T) { + os.Setenv("EIPC_HMAC_KEY", "test-secret-key") + defer os.Unsetenv("EIPC_HMAC_KEY") + os.Unsetenv("EIPC_KEY_FILE") + + key, err := LoadHMACKey() + if err != nil { + t.Fatalf("LoadHMACKey: %v", err) + } + if string(key) != "test-secret-key" { + t.Errorf("expected 'test-secret-key', got %q", string(key)) + } +} + +func TestLoadHMACKey_File(t *testing.T) { + os.Unsetenv("EIPC_HMAC_KEY") + + tmpFile := filepath.Join(t.TempDir(), "hmac.key") + _ = os.WriteFile(tmpFile, []byte("file-based-key"), 0600) + + os.Setenv("EIPC_KEY_FILE", tmpFile) + defer os.Unsetenv("EIPC_KEY_FILE") + + key, err := LoadHMACKey() + if err != nil { + t.Fatalf("LoadHMACKey from file: %v", err) + } + if string(key) != "file-based-key" { + t.Errorf("expected 'file-based-key', got %q", string(key)) + } +} + +func TestLoadHMACKey_Missing(t *testing.T) { + os.Unsetenv("EIPC_HMAC_KEY") + os.Unsetenv("EIPC_KEY_FILE") + + _, err := LoadHMACKey() + if err == nil { + t.Fatal("expected error when no key source configured") + } +} + +func TestLoadHMACKey_BadFile(t *testing.T) { + os.Unsetenv("EIPC_HMAC_KEY") + os.Setenv("EIPC_KEY_FILE", "/nonexistent/path/key.dat") + defer os.Unsetenv("EIPC_KEY_FILE") + + _, err := LoadHMACKey() + if err == nil { + t.Fatal("expected error for nonexistent key file") + } +} + +func TestLoadSessionTTL_Default(t *testing.T) { + os.Unsetenv("EIPC_SESSION_TTL") + ttl := LoadSessionTTL() + if ttl != 1*time.Hour { + t.Errorf("expected default 1h, got %v", ttl) + } +} + +func TestLoadSessionTTL_Custom(t *testing.T) { + os.Setenv("EIPC_SESSION_TTL", "30m") + defer os.Unsetenv("EIPC_SESSION_TTL") + + ttl := LoadSessionTTL() + if ttl != 30*time.Minute { + t.Errorf("expected 30m, got %v", ttl) + } +} + +func TestLoadSessionTTL_Invalid(t *testing.T) { + os.Setenv("EIPC_SESSION_TTL", "not-a-duration") + defer os.Unsetenv("EIPC_SESSION_TTL") + + ttl := LoadSessionTTL() + if ttl != 1*time.Hour { + t.Errorf("expected default 1h for invalid input, got %v", ttl) + } +} + +func TestLoadMaxConnections_Default(t *testing.T) { + os.Unsetenv("EIPC_MAX_CONNECTIONS") + max := LoadMaxConnections() + if max != 64 { + t.Errorf("expected default 64, got %d", max) + } +} + +func TestLoadMaxConnections_Custom(t *testing.T) { + os.Setenv("EIPC_MAX_CONNECTIONS", "128") + defer os.Unsetenv("EIPC_MAX_CONNECTIONS") + + max := LoadMaxConnections() + if max != 128 { + t.Errorf("expected 128, got %d", max) + } +} + +func TestLoadMaxConnections_Invalid(t *testing.T) { + os.Setenv("EIPC_MAX_CONNECTIONS", "abc") + defer os.Unsetenv("EIPC_MAX_CONNECTIONS") + + max := LoadMaxConnections() + if max != 64 { + t.Errorf("expected default 64 for invalid input, got %d", max) + } +} + +func TestLoadListenAddr_Default(t *testing.T) { + os.Unsetenv("EIPC_LISTEN_ADDR") + addr := LoadListenAddr() + if addr != "127.0.0.1:9090" { + t.Errorf("expected default '127.0.0.1:9090', got %q", addr) + } +} + +func TestLoadListenAddr_Custom(t *testing.T) { + os.Setenv("EIPC_LISTEN_ADDR", "0.0.0.0:8080") + defer os.Unsetenv("EIPC_LISTEN_ADDR") + + addr := LoadListenAddr() + if addr != "0.0.0.0:8080" { + t.Errorf("expected '0.0.0.0:8080', got %q", addr) + } +} + +func TestTLSEnabled(t *testing.T) { + os.Unsetenv("EIPC_TLS_CERT") + os.Unsetenv("EIPC_TLS_AUTO_CERT") + if TLSEnabled() { + t.Error("TLS should be disabled by default") + } + + os.Setenv("EIPC_TLS_CERT", "/path/to/cert.pem") + defer os.Unsetenv("EIPC_TLS_CERT") + if !TLSEnabled() { + t.Error("TLS should be enabled with EIPC_TLS_CERT set") + } +} + +func TestTLSEnabled_AutoCert(t *testing.T) { + os.Unsetenv("EIPC_TLS_CERT") + os.Setenv("EIPC_TLS_AUTO_CERT", "true") + defer os.Unsetenv("EIPC_TLS_AUTO_CERT") + + if !TLSEnabled() { + t.Error("TLS should be enabled with EIPC_TLS_AUTO_CERT=true") + } +} diff --git a/core/benchmark_test.go b/core/benchmark_test.go index 23b01e7..a6f9adf 100644 --- a/core/benchmark_test.go +++ b/core/benchmark_test.go @@ -1,68 +1,68 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package core - -import ( - "testing" - "time" -) - -func BenchmarkNewMessage(b *testing.B) { - payload := []byte(`{"intent":"move_left","confidence":0.91}`) - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = NewMessage(TypeIntent, "eni.min", payload) - } -} - -func BenchmarkRouterDispatch(b *testing.B) { - router := NewRouter() - router.Handle(TypeIntent, func(msg Message) (*Message, error) { - return nil, nil - }) - - msg := Message{ - Version: ProtocolVersion, - Type: TypeIntent, - Source: "bench", - Timestamp: time.Now().UTC(), - Priority: PriorityP0, - Payload: []byte(`{"intent":"test"}`), - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - _ = router.Dispatch(msg) - } -} - -func BenchmarkRouterDispatchBatch(b *testing.B) { - router := NewRouter() - router.Handle(TypeIntent, func(msg Message) (*Message, error) { - return nil, nil - }) - router.Handle(TypeHeartbeat, func(msg Message) (*Message, error) { - return nil, nil - }) - - msgs := []Message{ - {Type: TypeIntent, Priority: PriorityP2, Payload: []byte(`{}`)}, - {Type: TypeHeartbeat, Priority: PriorityP0, Payload: []byte(`{}`)}, - {Type: TypeIntent, Priority: PriorityP1, Payload: []byte(`{}`)}, - {Type: TypeHeartbeat, Priority: PriorityP3, Payload: []byte(`{}`)}, - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - router.DispatchBatch(msgs) - } -} - -func BenchmarkMsgTypeToByte(b *testing.B) { - for i := 0; i < b.N; i++ { - MsgTypeToByte(TypeIntent) - MsgTypeToByte(TypeChat) - MsgTypeToByte(TypeAck) - } -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package core + +import ( + "testing" + "time" +) + +func BenchmarkNewMessage(b *testing.B) { + payload := []byte(`{"intent":"move_left","confidence":0.91}`) + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = NewMessage(TypeIntent, "eni.min", payload) + } +} + +func BenchmarkRouterDispatch(b *testing.B) { + router := NewRouter() + router.Handle(TypeIntent, func(msg Message) (*Message, error) { + return nil, nil + }) + + msg := Message{ + Version: ProtocolVersion, + Type: TypeIntent, + Source: "bench", + Timestamp: time.Now().UTC(), + Priority: PriorityP0, + Payload: []byte(`{"intent":"test"}`), + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + _ = router.Dispatch(msg) + } +} + +func BenchmarkRouterDispatchBatch(b *testing.B) { + router := NewRouter() + router.Handle(TypeIntent, func(msg Message) (*Message, error) { + return nil, nil + }) + router.Handle(TypeHeartbeat, func(msg Message) (*Message, error) { + return nil, nil + }) + + msgs := []Message{ + {Type: TypeIntent, Priority: PriorityP2, Payload: []byte(`{}`)}, + {Type: TypeHeartbeat, Priority: PriorityP0, Payload: []byte(`{}`)}, + {Type: TypeIntent, Priority: PriorityP1, Payload: []byte(`{}`)}, + {Type: TypeHeartbeat, Priority: PriorityP3, Payload: []byte(`{}`)}, + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + router.DispatchBatch(msgs) + } +} + +func BenchmarkMsgTypeToByte(b *testing.B) { + for i := 0; i < b.N; i++ { + MsgTypeToByte(TypeIntent) + MsgTypeToByte(TypeChat) + MsgTypeToByte(TypeAck) + } +} diff --git a/core/endpoint.go b/core/endpoint.go index e3edaa1..b219e82 100644 --- a/core/endpoint.go +++ b/core/endpoint.go @@ -1,306 +1,306 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package core - -import ( - "fmt" - "sync" - "sync/atomic" - "time" - - "github.com/embeddedos-org/eipc/protocol" - "github.com/embeddedos-org/eipc/security/integrity" - "github.com/embeddedos-org/eipc/security/replay" - "github.com/embeddedos-org/eipc/transport" -) - -// Endpoint is the Go API surface for sending/receiving EIPC messages. -type Endpoint interface { - Send(msg Message) error - Receive() (Message, error) - Close() error -} - -// ClientEndpoint connects to an EIPC server and exchanges messages. -type ClientEndpoint struct { - conn transport.Connection - codec protocol.Codec - hmacKey []byte - sessionID string - sequence atomic.Uint64 -} - -// NewClientEndpoint creates a client endpoint over an existing connection. -func NewClientEndpoint(conn transport.Connection, codec protocol.Codec, hmacKey []byte, sessionID string) *ClientEndpoint { - return &ClientEndpoint{ - conn: conn, - codec: codec, - hmacKey: hmacKey, - sessionID: sessionID, - } -} - -func (e *ClientEndpoint) Send(msg Message) error { - seq := e.sequence.Add(1) - - hdr := protocol.Header{ - ServiceID: msg.Source, - SessionID: msg.SessionID, - RequestID: msg.RequestID, - Sequence: seq, - Timestamp: msg.Timestamp.Format(time.RFC3339Nano), - Priority: uint8(msg.Priority), - Capability: msg.Capability, - PayloadFormat: uint8(PayloadJSON), - } - hdrBytes, err := e.codec.Marshal(hdr) - if err != nil { - return fmt.Errorf("marshal header: %w", err) - } - - frame := &protocol.Frame{ - Version: ProtocolVersion, - MsgType: MsgTypeToByte(msg.Type), - Flags: protocol.FlagHMAC, - Header: hdrBytes, - Payload: msg.Payload, - } - - frame.MAC = integrity.Sign(e.hmacKey, frame.SignableBytes()) - - return e.conn.Send(frame) -} - -func (e *ClientEndpoint) Receive() (Message, error) { - frame, err := e.conn.Receive() - if err != nil { - return Message{}, err - } - - if frame.Flags&protocol.FlagHMAC != 0 { - if !integrity.Verify(e.hmacKey, frame.SignableBytes(), frame.MAC) { - return Message{}, ErrIntegrity - } - } - - var hdr protocol.Header - if err := e.codec.Unmarshal(frame.Header, &hdr); err != nil { - return Message{}, fmt.Errorf("unmarshal header: %w", err) - } - - var ts time.Time - if t, err := time.Parse(time.RFC3339Nano, hdr.Timestamp); err == nil { - ts = t - } - - return Message{ - Version: frame.Version, - Type: msgTypeFromByte(frame.MsgType), - Source: hdr.ServiceID, - Timestamp: ts, - SessionID: hdr.SessionID, - RequestID: hdr.RequestID, - Priority: Priority(hdr.Priority), - Capability: hdr.Capability, - Payload: frame.Payload, - }, nil -} - -func (e *ClientEndpoint) Close() error { - return e.conn.Close() -} - -// ServerEndpoint handles a single server-side connection. -type ServerEndpoint struct { - conn transport.Connection - codec protocol.Codec - hmacKey []byte - replay *replay.Tracker - sendMu sync.Mutex - seq atomic.Uint64 - peerCapabilities []string -} - -// SetPeerCapabilities sets the authenticated peer's capability list. -func (e *ServerEndpoint) SetPeerCapabilities(caps []string) { - e.sendMu.Lock() - defer e.sendMu.Unlock() - cpy := make([]string, len(caps)) - copy(cpy, caps) - e.peerCapabilities = cpy -} - -// ValidateCapability checks if msgCap is in the peer's granted capabilities. -func (e *ServerEndpoint) ValidateCapability(msgCap string) error { - if msgCap == "" { - return nil - } - e.sendMu.Lock() - caps := e.peerCapabilities - e.sendMu.Unlock() - for _, c := range caps { - if c == msgCap { - return nil - } - } - return fmt.Errorf("%w: peer lacks capability %q", ErrCapability, msgCap) -} - -// NewServerEndpoint wraps a server-side connection. -func NewServerEndpoint(conn transport.Connection, codec protocol.Codec, hmacKey []byte) *ServerEndpoint { - return &ServerEndpoint{ - conn: conn, - codec: codec, - hmacKey: hmacKey, - replay: replay.NewTracker(0), - } -} - -func (e *ServerEndpoint) Send(msg Message) error { - e.sendMu.Lock() - defer e.sendMu.Unlock() - - seq := e.seq.Add(1) - - hdr := protocol.Header{ - ServiceID: msg.Source, - SessionID: msg.SessionID, - RequestID: msg.RequestID, - Sequence: seq, - Timestamp: msg.Timestamp.Format(time.RFC3339Nano), - Priority: uint8(msg.Priority), - Capability: msg.Capability, - PayloadFormat: uint8(PayloadJSON), - } - hdrBytes, err := e.codec.Marshal(hdr) - if err != nil { - return fmt.Errorf("marshal header: %w", err) - } - - frame := &protocol.Frame{ - Version: ProtocolVersion, - MsgType: MsgTypeToByte(msg.Type), - Flags: protocol.FlagHMAC, - Header: hdrBytes, - Payload: msg.Payload, - } - - frame.MAC = integrity.Sign(e.hmacKey, frame.SignableBytes()) - - return e.conn.Send(frame) -} - -func (e *ServerEndpoint) Receive() (Message, error) { - frame, err := e.conn.Receive() - if err != nil { - return Message{}, err - } - - if frame.Flags&protocol.FlagHMAC != 0 { - if !integrity.Verify(e.hmacKey, frame.SignableBytes(), frame.MAC) { - return Message{}, ErrIntegrity - } - } - - var hdr protocol.Header - if err := e.codec.Unmarshal(frame.Header, &hdr); err != nil { - return Message{}, fmt.Errorf("unmarshal header: %w", err) - } - - if err := e.replay.Check(hdr.Sequence); err != nil { - return Message{}, err - } - - var ts time.Time - if t, err := time.Parse(time.RFC3339Nano, hdr.Timestamp); err == nil { - ts = t - } - - return Message{ - Version: frame.Version, - Type: msgTypeFromByte(frame.MsgType), - Source: hdr.ServiceID, - Timestamp: ts, - SessionID: hdr.SessionID, - RequestID: hdr.RequestID, - Priority: Priority(hdr.Priority), - Capability: hdr.Capability, - Payload: frame.Payload, - }, nil -} - -func (e *ServerEndpoint) Close() error { - return e.conn.Close() -} - -// RemoteAddr returns the remote address of the underlying connection. -func (e *ServerEndpoint) RemoteAddr() string { - return e.conn.RemoteAddr() -} - -func msgTypeFromByte(b uint8) MessageType { - switch b { - case 'i': - return TypeIntent - case 'f': - return TypeFeatures - case 't': - return TypeToolRequest - case 'a': - return TypeAck - case 'A': - return TypeAuth - case 'H': - return TypeChallenge - case 'R': - return TypeAuthResponse - case 'p': - return TypePolicyResult - case 'h': - return TypeHeartbeat - case 'u': - return TypeAudit - case 'c': - return TypeChat - case 'C': - return TypeComplete - default: - return MessageType(string(rune(b))) - } -} - -// MsgTypeToByte converts a MessageType to its wire byte representation. -func MsgTypeToByte(mt MessageType) uint8 { - switch mt { - case TypeIntent: - return 'i' - case TypeFeatures: - return 'f' - case TypeToolRequest: - return 't' - case TypeAck: - return 'a' - case TypeAuth: - return 'A' - case TypeChallenge: - return 'H' - case TypeAuthResponse: - return 'R' - case TypePolicyResult: - return 'p' - case TypeHeartbeat: - return 'h' - case TypeAudit: - return 'u' - case TypeChat: - return 'c' - case TypeComplete: - return 'C' - default: - if len(mt) > 0 { - return mt[0] - } - return 0 - } -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package core + +import ( + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/embeddedos-org/eipc/protocol" + "github.com/embeddedos-org/eipc/security/integrity" + "github.com/embeddedos-org/eipc/security/replay" + "github.com/embeddedos-org/eipc/transport" +) + +// Endpoint is the Go API surface for sending/receiving EIPC messages. +type Endpoint interface { + Send(msg Message) error + Receive() (Message, error) + Close() error +} + +// ClientEndpoint connects to an EIPC server and exchanges messages. +type ClientEndpoint struct { + conn transport.Connection + codec protocol.Codec + hmacKey []byte + sessionID string + sequence atomic.Uint64 +} + +// NewClientEndpoint creates a client endpoint over an existing connection. +func NewClientEndpoint(conn transport.Connection, codec protocol.Codec, hmacKey []byte, sessionID string) *ClientEndpoint { + return &ClientEndpoint{ + conn: conn, + codec: codec, + hmacKey: hmacKey, + sessionID: sessionID, + } +} + +func (e *ClientEndpoint) Send(msg Message) error { + seq := e.sequence.Add(1) + + hdr := protocol.Header{ + ServiceID: msg.Source, + SessionID: msg.SessionID, + RequestID: msg.RequestID, + Sequence: seq, + Timestamp: msg.Timestamp.Format(time.RFC3339Nano), + Priority: uint8(msg.Priority), + Capability: msg.Capability, + PayloadFormat: uint8(PayloadJSON), + } + hdrBytes, err := e.codec.Marshal(hdr) + if err != nil { + return fmt.Errorf("marshal header: %w", err) + } + + frame := &protocol.Frame{ + Version: ProtocolVersion, + MsgType: MsgTypeToByte(msg.Type), + Flags: protocol.FlagHMAC, + Header: hdrBytes, + Payload: msg.Payload, + } + + frame.MAC = integrity.Sign(e.hmacKey, frame.SignableBytes()) + + return e.conn.Send(frame) +} + +func (e *ClientEndpoint) Receive() (Message, error) { + frame, err := e.conn.Receive() + if err != nil { + return Message{}, err + } + + if frame.Flags&protocol.FlagHMAC != 0 { + if !integrity.Verify(e.hmacKey, frame.SignableBytes(), frame.MAC) { + return Message{}, ErrIntegrity + } + } + + var hdr protocol.Header + if err := e.codec.Unmarshal(frame.Header, &hdr); err != nil { + return Message{}, fmt.Errorf("unmarshal header: %w", err) + } + + var ts time.Time + if t, err := time.Parse(time.RFC3339Nano, hdr.Timestamp); err == nil { + ts = t + } + + return Message{ + Version: frame.Version, + Type: msgTypeFromByte(frame.MsgType), + Source: hdr.ServiceID, + Timestamp: ts, + SessionID: hdr.SessionID, + RequestID: hdr.RequestID, + Priority: Priority(hdr.Priority), + Capability: hdr.Capability, + Payload: frame.Payload, + }, nil +} + +func (e *ClientEndpoint) Close() error { + return e.conn.Close() +} + +// ServerEndpoint handles a single server-side connection. +type ServerEndpoint struct { + conn transport.Connection + codec protocol.Codec + hmacKey []byte + replay *replay.Tracker + sendMu sync.Mutex + seq atomic.Uint64 + peerCapabilities []string +} + +// SetPeerCapabilities sets the authenticated peer's capability list. +func (e *ServerEndpoint) SetPeerCapabilities(caps []string) { + e.sendMu.Lock() + defer e.sendMu.Unlock() + cpy := make([]string, len(caps)) + copy(cpy, caps) + e.peerCapabilities = cpy +} + +// ValidateCapability checks if msgCap is in the peer's granted capabilities. +func (e *ServerEndpoint) ValidateCapability(msgCap string) error { + if msgCap == "" { + return nil + } + e.sendMu.Lock() + caps := e.peerCapabilities + e.sendMu.Unlock() + for _, c := range caps { + if c == msgCap { + return nil + } + } + return fmt.Errorf("%w: peer lacks capability %q", ErrCapability, msgCap) +} + +// NewServerEndpoint wraps a server-side connection. +func NewServerEndpoint(conn transport.Connection, codec protocol.Codec, hmacKey []byte) *ServerEndpoint { + return &ServerEndpoint{ + conn: conn, + codec: codec, + hmacKey: hmacKey, + replay: replay.NewTracker(0), + } +} + +func (e *ServerEndpoint) Send(msg Message) error { + e.sendMu.Lock() + defer e.sendMu.Unlock() + + seq := e.seq.Add(1) + + hdr := protocol.Header{ + ServiceID: msg.Source, + SessionID: msg.SessionID, + RequestID: msg.RequestID, + Sequence: seq, + Timestamp: msg.Timestamp.Format(time.RFC3339Nano), + Priority: uint8(msg.Priority), + Capability: msg.Capability, + PayloadFormat: uint8(PayloadJSON), + } + hdrBytes, err := e.codec.Marshal(hdr) + if err != nil { + return fmt.Errorf("marshal header: %w", err) + } + + frame := &protocol.Frame{ + Version: ProtocolVersion, + MsgType: MsgTypeToByte(msg.Type), + Flags: protocol.FlagHMAC, + Header: hdrBytes, + Payload: msg.Payload, + } + + frame.MAC = integrity.Sign(e.hmacKey, frame.SignableBytes()) + + return e.conn.Send(frame) +} + +func (e *ServerEndpoint) Receive() (Message, error) { + frame, err := e.conn.Receive() + if err != nil { + return Message{}, err + } + + if frame.Flags&protocol.FlagHMAC != 0 { + if !integrity.Verify(e.hmacKey, frame.SignableBytes(), frame.MAC) { + return Message{}, ErrIntegrity + } + } + + var hdr protocol.Header + if err := e.codec.Unmarshal(frame.Header, &hdr); err != nil { + return Message{}, fmt.Errorf("unmarshal header: %w", err) + } + + if err := e.replay.Check(hdr.Sequence); err != nil { + return Message{}, err + } + + var ts time.Time + if t, err := time.Parse(time.RFC3339Nano, hdr.Timestamp); err == nil { + ts = t + } + + return Message{ + Version: frame.Version, + Type: msgTypeFromByte(frame.MsgType), + Source: hdr.ServiceID, + Timestamp: ts, + SessionID: hdr.SessionID, + RequestID: hdr.RequestID, + Priority: Priority(hdr.Priority), + Capability: hdr.Capability, + Payload: frame.Payload, + }, nil +} + +func (e *ServerEndpoint) Close() error { + return e.conn.Close() +} + +// RemoteAddr returns the remote address of the underlying connection. +func (e *ServerEndpoint) RemoteAddr() string { + return e.conn.RemoteAddr() +} + +func msgTypeFromByte(b uint8) MessageType { + switch b { + case 'i': + return TypeIntent + case 'f': + return TypeFeatures + case 't': + return TypeToolRequest + case 'a': + return TypeAck + case 'A': + return TypeAuth + case 'H': + return TypeChallenge + case 'R': + return TypeAuthResponse + case 'p': + return TypePolicyResult + case 'h': + return TypeHeartbeat + case 'u': + return TypeAudit + case 'c': + return TypeChat + case 'C': + return TypeComplete + default: + return MessageType(string(rune(b))) + } +} + +// MsgTypeToByte converts a MessageType to its wire byte representation. +func MsgTypeToByte(mt MessageType) uint8 { + switch mt { + case TypeIntent: + return 'i' + case TypeFeatures: + return 'f' + case TypeToolRequest: + return 't' + case TypeAck: + return 'a' + case TypeAuth: + return 'A' + case TypeChallenge: + return 'H' + case TypeAuthResponse: + return 'R' + case TypePolicyResult: + return 'p' + case TypeHeartbeat: + return 'h' + case TypeAudit: + return 'u' + case TypeChat: + return 'c' + case TypeComplete: + return 'C' + default: + if len(mt) > 0 { + return mt[0] + } + return 0 + } +} diff --git a/core/lifecycle.go b/core/lifecycle.go index 68a7213..4fa6dcd 100644 --- a/core/lifecycle.go +++ b/core/lifecycle.go @@ -1,121 +1,121 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package core - -import ( - "fmt" - "math" - "sync" - "time" -) - -// ReconnectPolicy configures automatic reconnection behavior. -type ReconnectPolicy struct { - MaxRetries int - InitialBackoff time.Duration - MaxBackoff time.Duration - BackoffFactor float64 -} - -// DefaultReconnectPolicy returns a sensible default reconnect policy. -func DefaultReconnectPolicy() ReconnectPolicy { - return ReconnectPolicy{ - MaxRetries: 10, - InitialBackoff: 100 * time.Millisecond, - MaxBackoff: 30 * time.Second, - BackoffFactor: 2.0, - } -} - -// Backoff calculates the backoff duration for the given attempt number. -func (p ReconnectPolicy) Backoff(attempt int) time.Duration { - if attempt <= 0 { - return p.InitialBackoff - } - backoff := float64(p.InitialBackoff) * math.Pow(p.BackoffFactor, float64(attempt)) - if backoff > float64(p.MaxBackoff) { - backoff = float64(p.MaxBackoff) - } - return time.Duration(backoff) -} - -// HeartbeatConfig configures periodic heartbeat sending. -type HeartbeatConfig struct { - Interval time.Duration - ServiceID string -} - -// HeartbeatSender sends periodic heartbeat messages over an endpoint. -type HeartbeatSender struct { - endpoint Endpoint - config HeartbeatConfig - stopCh chan struct{} - stopped bool - mu sync.Mutex -} - -// NewHeartbeatSender creates a heartbeat sender for the given endpoint. -func NewHeartbeatSender(endpoint Endpoint, config HeartbeatConfig) *HeartbeatSender { - if config.Interval == 0 { - config.Interval = 5 * time.Second - } - return &HeartbeatSender{ - endpoint: endpoint, - config: config, - stopCh: make(chan struct{}), - } -} - -// Start begins sending heartbeats in a background goroutine. -func (h *HeartbeatSender) Start() { - go func() { - ticker := time.NewTicker(h.config.Interval) - defer ticker.Stop() - - for { - select { - case <-h.stopCh: - return - case <-ticker.C: - msg := Message{ - Version: ProtocolVersion, - Type: TypeHeartbeat, - Source: h.config.ServiceID, - Timestamp: time.Now().UTC(), - Priority: PriorityP3, - Payload: []byte(fmt.Sprintf(`{"service":"%s","status":"alive"}`, h.config.ServiceID)), - } - if err := h.endpoint.Send(msg); err != nil { - fmt.Printf("heartbeat send failed: %v\n", err) - } - } - } - }() -} - -// Stop halts heartbeat sending. -func (h *HeartbeatSender) Stop() { - h.mu.Lock() - defer h.mu.Unlock() - if !h.stopped { - close(h.stopCh) - h.stopped = true - } -} - -// GracefulShutdown closes an endpoint after draining in-flight messages. -// It waits up to timeout before forcing closure. -func GracefulShutdown(endpoint Endpoint, timeout time.Duration) error { - done := make(chan error, 1) - go func() { - done <- endpoint.Close() - }() - - select { - case err := <-done: - return err - case <-time.After(timeout): - return fmt.Errorf("shutdown timed out after %v", timeout) - } -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package core + +import ( + "fmt" + "math" + "sync" + "time" +) + +// ReconnectPolicy configures automatic reconnection behavior. +type ReconnectPolicy struct { + MaxRetries int + InitialBackoff time.Duration + MaxBackoff time.Duration + BackoffFactor float64 +} + +// DefaultReconnectPolicy returns a sensible default reconnect policy. +func DefaultReconnectPolicy() ReconnectPolicy { + return ReconnectPolicy{ + MaxRetries: 10, + InitialBackoff: 100 * time.Millisecond, + MaxBackoff: 30 * time.Second, + BackoffFactor: 2.0, + } +} + +// Backoff calculates the backoff duration for the given attempt number. +func (p ReconnectPolicy) Backoff(attempt int) time.Duration { + if attempt <= 0 { + return p.InitialBackoff + } + backoff := float64(p.InitialBackoff) * math.Pow(p.BackoffFactor, float64(attempt)) + if backoff > float64(p.MaxBackoff) { + backoff = float64(p.MaxBackoff) + } + return time.Duration(backoff) +} + +// HeartbeatConfig configures periodic heartbeat sending. +type HeartbeatConfig struct { + Interval time.Duration + ServiceID string +} + +// HeartbeatSender sends periodic heartbeat messages over an endpoint. +type HeartbeatSender struct { + endpoint Endpoint + config HeartbeatConfig + stopCh chan struct{} + stopped bool + mu sync.Mutex +} + +// NewHeartbeatSender creates a heartbeat sender for the given endpoint. +func NewHeartbeatSender(endpoint Endpoint, config HeartbeatConfig) *HeartbeatSender { + if config.Interval == 0 { + config.Interval = 5 * time.Second + } + return &HeartbeatSender{ + endpoint: endpoint, + config: config, + stopCh: make(chan struct{}), + } +} + +// Start begins sending heartbeats in a background goroutine. +func (h *HeartbeatSender) Start() { + go func() { + ticker := time.NewTicker(h.config.Interval) + defer ticker.Stop() + + for { + select { + case <-h.stopCh: + return + case <-ticker.C: + msg := Message{ + Version: ProtocolVersion, + Type: TypeHeartbeat, + Source: h.config.ServiceID, + Timestamp: time.Now().UTC(), + Priority: PriorityP3, + Payload: []byte(fmt.Sprintf(`{"service":"%s","status":"alive"}`, h.config.ServiceID)), + } + if err := h.endpoint.Send(msg); err != nil { + fmt.Printf("heartbeat send failed: %v\n", err) + } + } + } + }() +} + +// Stop halts heartbeat sending. +func (h *HeartbeatSender) Stop() { + h.mu.Lock() + defer h.mu.Unlock() + if !h.stopped { + close(h.stopCh) + h.stopped = true + } +} + +// GracefulShutdown closes an endpoint after draining in-flight messages. +// It waits up to timeout before forcing closure. +func GracefulShutdown(endpoint Endpoint, timeout time.Duration) error { + done := make(chan error, 1) + go func() { + done <- endpoint.Close() + }() + + select { + case err := <-done: + return err + case <-time.After(timeout): + return fmt.Errorf("shutdown timed out after %v", timeout) + } +} diff --git a/core/router_test.go b/core/router_test.go index bc067f9..6e5d66f 100644 --- a/core/router_test.go +++ b/core/router_test.go @@ -1,105 +1,105 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package core - -import ( - "testing" - "time" -) - -func TestRouterHandle(t *testing.T) { - r := NewRouter() - - called := false - r.Handle(TypeIntent, func(msg Message) (*Message, error) { - called = true - resp := NewMessage(TypeAck, "router", []byte("handled")) - return &resp, nil - }) - - msg := NewMessage(TypeIntent, "test", []byte("data")) - resp, err := r.Dispatch(msg) - if err != nil { - t.Fatalf("Dispatch failed: %v", err) - } - if !called { - t.Error("handler was not called") - } - if resp == nil { - t.Fatal("expected response, got nil") - } - if resp.Type != TypeAck { - t.Errorf("expected TypeAck, got %v", resp.Type) - } -} - -func TestRouterDispatchUnregistered(t *testing.T) { - r := NewRouter() - - msg := NewMessage(TypeIntent, "test", nil) - resp, err := r.Dispatch(msg) - if err != nil { - t.Fatalf("Dispatch should not error for unregistered type: %v", err) - } - if resp != nil { - t.Error("expected nil response for unregistered type") - } -} - -func TestRouterBatchPriorityOrder(t *testing.T) { - r := NewRouter() - - var order []Priority - r.Handle(TypeIntent, func(msg Message) (*Message, error) { - order = append(order, msg.Priority) - return nil, nil - }) - - messages := []Message{ - {Type: TypeIntent, Priority: PriorityP2, Source: "low", Timestamp: time.Now()}, - {Type: TypeIntent, Priority: PriorityP0, Source: "critical", Timestamp: time.Now()}, - {Type: TypeIntent, Priority: PriorityP1, Source: "normal", Timestamp: time.Now()}, - } - - r.DispatchBatch(messages) - - if len(order) != 3 { - t.Fatalf("expected 3 dispatches, got %d", len(order)) - } - if order[0] != PriorityP0 { - t.Errorf("first dispatch should be P0, got P%d", order[0]) - } - if order[1] != PriorityP1 { - t.Errorf("second dispatch should be P1, got P%d", order[1]) - } - if order[2] != PriorityP2 { - t.Errorf("third dispatch should be P2, got P%d", order[2]) - } -} - -func TestRouterMultipleHandlers(t *testing.T) { - r := NewRouter() - - intentCalled := false - ackCalled := false - - r.Handle(TypeIntent, func(msg Message) (*Message, error) { - intentCalled = true - return nil, nil - }) - r.Handle(TypeAck, func(msg Message) (*Message, error) { - ackCalled = true - return nil, nil - }) - - _ = r.Dispatch(NewMessage(TypeIntent, "test", nil)) - _ = r.Dispatch(NewMessage(TypeAck, "test", nil)) - - if !intentCalled { - t.Error("intent handler not called") - } - if !ackCalled { - t.Error("ack handler not called") - } -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package core + +import ( + "testing" + "time" +) + +func TestRouterHandle(t *testing.T) { + r := NewRouter() + + called := false + r.Handle(TypeIntent, func(msg Message) (*Message, error) { + called = true + resp := NewMessage(TypeAck, "router", []byte("handled")) + return &resp, nil + }) + + msg := NewMessage(TypeIntent, "test", []byte("data")) + resp, err := r.Dispatch(msg) + if err != nil { + t.Fatalf("Dispatch failed: %v", err) + } + if !called { + t.Error("handler was not called") + } + if resp == nil { + t.Fatal("expected response, got nil") + } + if resp.Type != TypeAck { + t.Errorf("expected TypeAck, got %v", resp.Type) + } +} + +func TestRouterDispatchUnregistered(t *testing.T) { + r := NewRouter() + + msg := NewMessage(TypeIntent, "test", nil) + resp, err := r.Dispatch(msg) + if err != nil { + t.Fatalf("Dispatch should not error for unregistered type: %v", err) + } + if resp != nil { + t.Error("expected nil response for unregistered type") + } +} + +func TestRouterBatchPriorityOrder(t *testing.T) { + r := NewRouter() + + var order []Priority + r.Handle(TypeIntent, func(msg Message) (*Message, error) { + order = append(order, msg.Priority) + return nil, nil + }) + + messages := []Message{ + {Type: TypeIntent, Priority: PriorityP2, Source: "low", Timestamp: time.Now()}, + {Type: TypeIntent, Priority: PriorityP0, Source: "critical", Timestamp: time.Now()}, + {Type: TypeIntent, Priority: PriorityP1, Source: "normal", Timestamp: time.Now()}, + } + + r.DispatchBatch(messages) + + if len(order) != 3 { + t.Fatalf("expected 3 dispatches, got %d", len(order)) + } + if order[0] != PriorityP0 { + t.Errorf("first dispatch should be P0, got P%d", order[0]) + } + if order[1] != PriorityP1 { + t.Errorf("second dispatch should be P1, got P%d", order[1]) + } + if order[2] != PriorityP2 { + t.Errorf("third dispatch should be P2, got P%d", order[2]) + } +} + +func TestRouterMultipleHandlers(t *testing.T) { + r := NewRouter() + + intentCalled := false + ackCalled := false + + r.Handle(TypeIntent, func(msg Message) (*Message, error) { + intentCalled = true + return nil, nil + }) + r.Handle(TypeAck, func(msg Message) (*Message, error) { + ackCalled = true + return nil, nil + }) + + _ = r.Dispatch(NewMessage(TypeIntent, "test", nil)) + _ = r.Dispatch(NewMessage(TypeAck, "test", nil)) + + if !intentCalled { + t.Error("intent handler not called") + } + if !ackCalled { + t.Error("ack handler not called") + } +} diff --git a/core/types.go b/core/types.go index 5371b17..df204b6 100644 --- a/core/types.go +++ b/core/types.go @@ -1,43 +1,43 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package core - -// MessageType identifies the kind of EIPC message. -type MessageType string - -const ( - TypeIntent MessageType = "intent" - TypeFeatures MessageType = "features" - TypeToolRequest MessageType = "tool_request" - TypeAuth MessageType = "auth" - TypeChallenge MessageType = "challenge" - TypeAuthResponse MessageType = "auth_response" - TypeAck MessageType = "ack" - TypePolicyResult MessageType = "policy_result" - TypeHeartbeat MessageType = "heartbeat" - TypeAudit MessageType = "audit" - TypeChat MessageType = "chat" - TypeComplete MessageType = "complete" -) - -// Priority defines message urgency lanes. -type Priority uint8 - -const ( - PriorityP0 Priority = 0 // Control-critical - PriorityP1 Priority = 1 // Interactive - PriorityP2 Priority = 2 // Telemetry - PriorityP3 Priority = 3 // Debug / audit bulk -) - -// PayloadFormat identifies serialization encoding. -type PayloadFormat uint8 - -const ( - PayloadJSON PayloadFormat = 0 - PayloadMsgPack PayloadFormat = 1 -) - -// ProtocolVersion is the current EIPC protocol version. -const ProtocolVersion uint16 = 1 +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package core + +// MessageType identifies the kind of EIPC message. +type MessageType string + +const ( + TypeIntent MessageType = "intent" + TypeFeatures MessageType = "features" + TypeToolRequest MessageType = "tool_request" + TypeAuth MessageType = "auth" + TypeChallenge MessageType = "challenge" + TypeAuthResponse MessageType = "auth_response" + TypeAck MessageType = "ack" + TypePolicyResult MessageType = "policy_result" + TypeHeartbeat MessageType = "heartbeat" + TypeAudit MessageType = "audit" + TypeChat MessageType = "chat" + TypeComplete MessageType = "complete" +) + +// Priority defines message urgency lanes. +type Priority uint8 + +const ( + PriorityP0 Priority = 0 // Control-critical + PriorityP1 Priority = 1 // Interactive + PriorityP2 Priority = 2 // Telemetry + PriorityP3 Priority = 3 // Debug / audit bulk +) + +// PayloadFormat identifies serialization encoding. +type PayloadFormat uint8 + +const ( + PayloadJSON PayloadFormat = 0 + PayloadMsgPack PayloadFormat = 1 +) + +// ProtocolVersion is the current EIPC protocol version. +const ProtocolVersion uint16 = 1 diff --git a/protocol/benchmark_test.go b/protocol/benchmark_test.go index 5984f68..fe316e7 100644 --- a/protocol/benchmark_test.go +++ b/protocol/benchmark_test.go @@ -1,62 +1,62 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package protocol - -import ( - "bytes" - "testing" -) - -func BenchmarkFrameEncode(b *testing.B) { - frame := &Frame{ - Version: ProtocolVersion, - MsgType: 'i', - Flags: FlagHMAC, - Header: []byte(`{"service_id":"eni.min","session_id":"sess-1","request_id":"req-1","sequence":1,"timestamp":"2026-01-01T00:00:00Z","priority":0}`), - Payload: []byte(`{"intent":"move_left","confidence":0.91}`), - MAC: make([]byte, MACSize), - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - var buf bytes.Buffer - _ = frame.Encode(&buf) - } -} - -func BenchmarkFrameDecode(b *testing.B) { - frame := &Frame{ - Version: ProtocolVersion, - MsgType: 'i', - Flags: FlagHMAC, - Header: []byte(`{"service_id":"eni.min","session_id":"sess-1","request_id":"req-1","sequence":1,"timestamp":"2026-01-01T00:00:00Z","priority":0}`), - Payload: []byte(`{"intent":"move_left","confidence":0.91}`), - MAC: make([]byte, MACSize), - } - - var buf bytes.Buffer - _ = frame.Encode(&buf) - encoded := buf.Bytes() - - b.ResetTimer() - for i := 0; i < b.N; i++ { - r := bytes.NewReader(encoded) - Decode(r) - } -} - -func BenchmarkSignableBytes(b *testing.B) { - frame := &Frame{ - Version: ProtocolVersion, - MsgType: 'i', - Flags: FlagHMAC, - Header: []byte(`{"service_id":"eni.min"}`), - Payload: []byte(`{"intent":"move"}`), - } - - b.ResetTimer() - for i := 0; i < b.N; i++ { - frame.SignableBytes() - } -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package protocol + +import ( + "bytes" + "testing" +) + +func BenchmarkFrameEncode(b *testing.B) { + frame := &Frame{ + Version: ProtocolVersion, + MsgType: 'i', + Flags: FlagHMAC, + Header: []byte(`{"service_id":"eni.min","session_id":"sess-1","request_id":"req-1","sequence":1,"timestamp":"2026-01-01T00:00:00Z","priority":0}`), + Payload: []byte(`{"intent":"move_left","confidence":0.91}`), + MAC: make([]byte, MACSize), + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + var buf bytes.Buffer + _ = frame.Encode(&buf) + } +} + +func BenchmarkFrameDecode(b *testing.B) { + frame := &Frame{ + Version: ProtocolVersion, + MsgType: 'i', + Flags: FlagHMAC, + Header: []byte(`{"service_id":"eni.min","session_id":"sess-1","request_id":"req-1","sequence":1,"timestamp":"2026-01-01T00:00:00Z","priority":0}`), + Payload: []byte(`{"intent":"move_left","confidence":0.91}`), + MAC: make([]byte, MACSize), + } + + var buf bytes.Buffer + _ = frame.Encode(&buf) + encoded := buf.Bytes() + + b.ResetTimer() + for i := 0; i < b.N; i++ { + r := bytes.NewReader(encoded) + Decode(r) + } +} + +func BenchmarkSignableBytes(b *testing.B) { + frame := &Frame{ + Version: ProtocolVersion, + MsgType: 'i', + Flags: FlagHMAC, + Header: []byte(`{"service_id":"eni.min"}`), + Payload: []byte(`{"intent":"move"}`), + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + frame.SignableBytes() + } +} diff --git a/protocol/frame.go b/protocol/frame.go index 70f79cb..63de50e 100644 --- a/protocol/frame.go +++ b/protocol/frame.go @@ -1,157 +1,157 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package protocol - -import ( - "encoding/binary" - "errors" - "fmt" - "io" -) - -// Magic bytes: "EIPC" = 0x45495043 -const ( - MagicBytes uint32 = 0x45495043 - MaxFrameSize uint32 = 1 << 20 // 1 MB - MACSize = 32 // HMAC-SHA256 - ProtocolVersion uint16 = 1 -) - -var ( - ErrBadMagic = errors.New("eipc: invalid magic bytes") - ErrBadVersion = errors.New("eipc: unsupported protocol version") - ErrFrameTooLarge = errors.New("eipc: frame exceeds maximum size") -) - -// Frame flags -const ( - FlagHMAC uint8 = 1 << 0 // Frame carries an HMAC - FlagCompress uint8 = 1 << 1 // Payload is compressed (future) - FlagEncrypted uint8 = 1 << 2 // Payload is encrypted with AES-GCM -) - -// Frame is the on-the-wire representation of an EIPC message. -// -// Wire format: -// -// [magic:4][version:2][msg_type:1][flags:1][header_len:4][payload_len:4][header][payload][mac:32?] -type Frame struct { - Version uint16 - MsgType uint8 - Flags uint8 - Header []byte - Payload []byte - MAC []byte // Present only when FlagHMAC is set -} - -// FrameFixedSize is the byte count of the fixed-width preamble. -const FrameFixedSize = 4 + 2 + 1 + 1 + 4 + 4 // 16 bytes - -// Encode writes the frame to w in the EIPC wire format. -func (f *Frame) Encode(w io.Writer) error { - if uint64(len(f.Header))+uint64(len(f.Payload)) > uint64(MaxFrameSize) { - return ErrFrameTooLarge - } - - buf := make([]byte, FrameFixedSize) - binary.BigEndian.PutUint32(buf[0:4], MagicBytes) - binary.BigEndian.PutUint16(buf[4:6], f.Version) - buf[6] = f.MsgType - buf[7] = f.Flags - binary.BigEndian.PutUint32(buf[8:12], uint32(len(f.Header))) - binary.BigEndian.PutUint32(buf[12:16], uint32(len(f.Payload))) - - if _, err := w.Write(buf); err != nil { - return fmt.Errorf("write preamble: %w", err) - } - if _, err := w.Write(f.Header); err != nil { - return fmt.Errorf("write header: %w", err) - } - if _, err := w.Write(f.Payload); err != nil { - return fmt.Errorf("write payload: %w", err) - } - if f.Flags&FlagHMAC != 0 && len(f.MAC) == MACSize { - if _, err := w.Write(f.MAC); err != nil { - return fmt.Errorf("write mac: %w", err) - } - } - return nil -} - -// Decode reads a frame from r in the EIPC wire format. -func Decode(r io.Reader) (*Frame, error) { - preamble := make([]byte, FrameFixedSize) - if _, err := io.ReadFull(r, preamble); err != nil { - return nil, fmt.Errorf("read preamble: %w", err) - } - - magic := binary.BigEndian.Uint32(preamble[0:4]) - if magic != MagicBytes { - return nil, ErrBadMagic - } - - version := binary.BigEndian.Uint16(preamble[4:6]) - if version != ProtocolVersion { - return nil, ErrBadVersion - } - - f := &Frame{ - Version: version, - MsgType: preamble[6], - Flags: preamble[7], - } - - headerLen := binary.BigEndian.Uint32(preamble[8:12]) - payloadLen := binary.BigEndian.Uint32(preamble[12:16]) - - if uint64(headerLen)+uint64(payloadLen) > uint64(MaxFrameSize) { - return nil, ErrFrameTooLarge - } - - if headerLen > 0 { - f.Header = make([]byte, headerLen) - if _, err := io.ReadFull(r, f.Header); err != nil { - return nil, fmt.Errorf("read header: %w", err) - } - } - - if payloadLen > 0 { - f.Payload = make([]byte, payloadLen) - if _, err := io.ReadFull(r, f.Payload); err != nil { - return nil, fmt.Errorf("read payload: %w", err) - } - } - - if f.Flags&FlagHMAC != 0 { - f.MAC = make([]byte, MACSize) - if _, err := io.ReadFull(r, f.MAC); err != nil { - return nil, fmt.Errorf("read mac: %w", err) - } - } - - return f, nil -} - -// SignableBytes returns the portion of the frame that is covered by the MAC -// (everything except the MAC itself). -func (f *Frame) SignableBytes() []byte { - if uint64(len(f.Header))+uint64(len(f.Payload)) > uint64(MaxFrameSize) { - return nil - } - size := FrameFixedSize + len(f.Header) + len(f.Payload) - buf := make([]byte, 0, size) - - preamble := make([]byte, FrameFixedSize) - binary.BigEndian.PutUint32(preamble[0:4], MagicBytes) - binary.BigEndian.PutUint16(preamble[4:6], f.Version) - preamble[6] = f.MsgType - preamble[7] = f.Flags - binary.BigEndian.PutUint32(preamble[8:12], uint32(len(f.Header))) - binary.BigEndian.PutUint32(preamble[12:16], uint32(len(f.Payload))) - - buf = append(buf, preamble...) - buf = append(buf, f.Header...) - buf = append(buf, f.Payload...) - return buf -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package protocol + +import ( + "encoding/binary" + "errors" + "fmt" + "io" +) + +// Magic bytes: "EIPC" = 0x45495043 +const ( + MagicBytes uint32 = 0x45495043 + MaxFrameSize uint32 = 1 << 20 // 1 MB + MACSize = 32 // HMAC-SHA256 + ProtocolVersion uint16 = 1 +) + +var ( + ErrBadMagic = errors.New("eipc: invalid magic bytes") + ErrBadVersion = errors.New("eipc: unsupported protocol version") + ErrFrameTooLarge = errors.New("eipc: frame exceeds maximum size") +) + +// Frame flags +const ( + FlagHMAC uint8 = 1 << 0 // Frame carries an HMAC + FlagCompress uint8 = 1 << 1 // Payload is compressed (future) + FlagEncrypted uint8 = 1 << 2 // Payload is encrypted with AES-GCM +) + +// Frame is the on-the-wire representation of an EIPC message. +// +// Wire format: +// +// [magic:4][version:2][msg_type:1][flags:1][header_len:4][payload_len:4][header][payload][mac:32?] +type Frame struct { + Version uint16 + MsgType uint8 + Flags uint8 + Header []byte + Payload []byte + MAC []byte // Present only when FlagHMAC is set +} + +// FrameFixedSize is the byte count of the fixed-width preamble. +const FrameFixedSize = 4 + 2 + 1 + 1 + 4 + 4 // 16 bytes + +// Encode writes the frame to w in the EIPC wire format. +func (f *Frame) Encode(w io.Writer) error { + if uint64(len(f.Header))+uint64(len(f.Payload)) > uint64(MaxFrameSize) { + return ErrFrameTooLarge + } + + buf := make([]byte, FrameFixedSize) + binary.BigEndian.PutUint32(buf[0:4], MagicBytes) + binary.BigEndian.PutUint16(buf[4:6], f.Version) + buf[6] = f.MsgType + buf[7] = f.Flags + binary.BigEndian.PutUint32(buf[8:12], uint32(len(f.Header))) + binary.BigEndian.PutUint32(buf[12:16], uint32(len(f.Payload))) + + if _, err := w.Write(buf); err != nil { + return fmt.Errorf("write preamble: %w", err) + } + if _, err := w.Write(f.Header); err != nil { + return fmt.Errorf("write header: %w", err) + } + if _, err := w.Write(f.Payload); err != nil { + return fmt.Errorf("write payload: %w", err) + } + if f.Flags&FlagHMAC != 0 && len(f.MAC) == MACSize { + if _, err := w.Write(f.MAC); err != nil { + return fmt.Errorf("write mac: %w", err) + } + } + return nil +} + +// Decode reads a frame from r in the EIPC wire format. +func Decode(r io.Reader) (*Frame, error) { + preamble := make([]byte, FrameFixedSize) + if _, err := io.ReadFull(r, preamble); err != nil { + return nil, fmt.Errorf("read preamble: %w", err) + } + + magic := binary.BigEndian.Uint32(preamble[0:4]) + if magic != MagicBytes { + return nil, ErrBadMagic + } + + version := binary.BigEndian.Uint16(preamble[4:6]) + if version != ProtocolVersion { + return nil, ErrBadVersion + } + + f := &Frame{ + Version: version, + MsgType: preamble[6], + Flags: preamble[7], + } + + headerLen := binary.BigEndian.Uint32(preamble[8:12]) + payloadLen := binary.BigEndian.Uint32(preamble[12:16]) + + if uint64(headerLen)+uint64(payloadLen) > uint64(MaxFrameSize) { + return nil, ErrFrameTooLarge + } + + if headerLen > 0 { + f.Header = make([]byte, headerLen) + if _, err := io.ReadFull(r, f.Header); err != nil { + return nil, fmt.Errorf("read header: %w", err) + } + } + + if payloadLen > 0 { + f.Payload = make([]byte, payloadLen) + if _, err := io.ReadFull(r, f.Payload); err != nil { + return nil, fmt.Errorf("read payload: %w", err) + } + } + + if f.Flags&FlagHMAC != 0 { + f.MAC = make([]byte, MACSize) + if _, err := io.ReadFull(r, f.MAC); err != nil { + return nil, fmt.Errorf("read mac: %w", err) + } + } + + return f, nil +} + +// SignableBytes returns the portion of the frame that is covered by the MAC +// (everything except the MAC itself). +func (f *Frame) SignableBytes() []byte { + if uint64(len(f.Header))+uint64(len(f.Payload)) > uint64(MaxFrameSize) { + return nil + } + size := FrameFixedSize + len(f.Header) + len(f.Payload) + buf := make([]byte, 0, size) + + preamble := make([]byte, FrameFixedSize) + binary.BigEndian.PutUint32(preamble[0:4], MagicBytes) + binary.BigEndian.PutUint16(preamble[4:6], f.Version) + preamble[6] = f.MsgType + preamble[7] = f.Flags + binary.BigEndian.PutUint32(preamble[8:12], uint32(len(f.Header))) + binary.BigEndian.PutUint32(preamble[12:16], uint32(len(f.Payload))) + + buf = append(buf, preamble...) + buf = append(buf, f.Header...) + buf = append(buf, f.Payload...) + return buf +} diff --git a/protocol/fuzz_test.go b/protocol/fuzz_test.go index 582930a..6079fec 100644 --- a/protocol/fuzz_test.go +++ b/protocol/fuzz_test.go @@ -1,59 +1,59 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package protocol - -import ( - "bytes" - "encoding/binary" - "testing" -) - -func FuzzFrameDecode(f *testing.F) { - // Seed with a valid frame - var validFrame bytes.Buffer - frame := &Frame{ - Version: ProtocolVersion, - MsgType: 'i', - Flags: 0, - Header: []byte(`{"service_id":"test"}`), - Payload: []byte(`{"intent":"move"}`), - } - _ = frame.Encode(&validFrame) - f.Add(validFrame.Bytes()) - - // Seed with minimal valid preamble - preamble := make([]byte, FrameFixedSize) - binary.BigEndian.PutUint32(preamble[0:4], MagicBytes) - binary.BigEndian.PutUint16(preamble[4:6], ProtocolVersion) - preamble[6] = 'a' - preamble[7] = 0 - binary.BigEndian.PutUint32(preamble[8:12], 0) - binary.BigEndian.PutUint32(preamble[12:16], 0) - f.Add(preamble) - - // Seed with empty data - f.Add([]byte{}) - - // Seed with garbage - f.Add([]byte{0xff, 0xfe, 0xfd, 0xfc, 0x00, 0x01}) - - f.Fuzz(func(t *testing.T, data []byte) { - r := bytes.NewReader(data) - frame, err := Decode(r) - if err != nil { - return // Expected for random input - } - - // If decode succeeded, verify the frame is reasonable - if frame.Version != ProtocolVersion { - t.Errorf("decoded frame with unexpected version %d", frame.Version) - } - - // Re-encode and verify round-trip - var buf bytes.Buffer - if err := frame.Encode(&buf); err != nil { - t.Errorf("failed to re-encode decoded frame: %v", err) - } - }) -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package protocol + +import ( + "bytes" + "encoding/binary" + "testing" +) + +func FuzzFrameDecode(f *testing.F) { + // Seed with a valid frame + var validFrame bytes.Buffer + frame := &Frame{ + Version: ProtocolVersion, + MsgType: 'i', + Flags: 0, + Header: []byte(`{"service_id":"test"}`), + Payload: []byte(`{"intent":"move"}`), + } + _ = frame.Encode(&validFrame) + f.Add(validFrame.Bytes()) + + // Seed with minimal valid preamble + preamble := make([]byte, FrameFixedSize) + binary.BigEndian.PutUint32(preamble[0:4], MagicBytes) + binary.BigEndian.PutUint16(preamble[4:6], ProtocolVersion) + preamble[6] = 'a' + preamble[7] = 0 + binary.BigEndian.PutUint32(preamble[8:12], 0) + binary.BigEndian.PutUint32(preamble[12:16], 0) + f.Add(preamble) + + // Seed with empty data + f.Add([]byte{}) + + // Seed with garbage + f.Add([]byte{0xff, 0xfe, 0xfd, 0xfc, 0x00, 0x01}) + + f.Fuzz(func(t *testing.T, data []byte) { + r := bytes.NewReader(data) + frame, err := Decode(r) + if err != nil { + return // Expected for random input + } + + // If decode succeeded, verify the frame is reasonable + if frame.Version != ProtocolVersion { + t.Errorf("decoded frame with unexpected version %d", frame.Version) + } + + // Re-encode and verify round-trip + var buf bytes.Buffer + if err := frame.Encode(&buf); err != nil { + t.Errorf("failed to re-encode decoded frame: %v", err) + } + }) +} diff --git a/sdk/c/src/eipc_client.c b/sdk/c/src/eipc_client.c index 2aef119..0f6ca46 100644 --- a/sdk/c/src/eipc_client.c +++ b/sdk/c/src/eipc_client.c @@ -1,312 +1,312 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project -// ISO/IEC 25000 | ISO/IEC/IEEE 15288:2023 - -/* - * EIPC High-level Client API - */ - -#include "eipc.h" - -#include -#include -#include -#include - -/* --------------- utility implementations --------------- */ - -void eipc_timestamp_now(char *buf, size_t buf_size) { - if (!buf || buf_size == 0) return; - snprintf(buf, buf_size, "%llu", (unsigned long long)time(NULL)); -} - -void eipc_generate_request_id(char *buf, size_t buf_size) { - if (!buf || buf_size == 0) return; - snprintf(buf, buf_size, "req-%08lx-%04x", - (unsigned long)time(NULL), - (unsigned)(rand() & 0xFFFF)); -} - -/* --------------- helpers --------------- */ - -static void fill_header(eipc_header_t *h, - const char *service_id, - uint64_t sequence) { - memset(h, 0, sizeof(*h)); - strncpy(h->service_id, service_id, sizeof(h->service_id) - 1); - h->service_id[sizeof(h->service_id) - 1] = '\0'; - eipc_generate_request_id(h->request_id, sizeof(h->request_id)); - h->sequence = sequence; - eipc_timestamp_now(h->timestamp, sizeof(h->timestamp)); - h->priority = EIPC_PRIORITY_P1; - h->payload_format = EIPC_PAYLOAD_JSON; -} - -static eipc_status_t build_and_send(eipc_client_t *c, - uint8_t msg_type, - const eipc_header_t *hdr, - const uint8_t *payload, - size_t payload_len) { - eipc_frame_t frame; - eipc_status_t rc; - char header_json[EIPC_MAX_HEADER]; - size_t signable_len; - - memset(&frame, 0, sizeof(frame)); - - frame.version = EIPC_PROTOCOL_VER; - frame.msg_type = msg_type; - frame.flags = EIPC_FLAG_HMAC; - - size_t hdr_json_len = 0; - rc = eipc_header_to_json(hdr, header_json, sizeof(header_json), &hdr_json_len); - if (rc != EIPC_OK) return rc; - memcpy(frame.header, header_json, hdr_json_len); - frame.header_len = (uint32_t)hdr_json_len; - - if (payload && payload_len > 0) { - if (payload_len > sizeof(frame.payload)) - return EIPC_ERR_FRAME_TOO_LARGE; - memcpy(frame.payload, payload, payload_len); - frame.payload_len = (uint32_t)payload_len; - } - - { - uint8_t signable[EIPC_MAX_FRAME]; - signable_len = eipc_frame_signable_bytes(&frame, signable, sizeof(signable)); - if (signable_len == 0) - return EIPC_ERR_INTEGRITY; - - eipc_hmac_sign(c->hmac_key, c->hmac_key_len, - signable, signable_len, - frame.mac); - } - - return eipc_transport_send_frame(c->sock, &frame); -} - -/* --------------- public API --------------- */ - -eipc_status_t eipc_client_init(eipc_client_t *c, const char *service_id) { - if (!c || !service_id) return EIPC_ERR_INVALID; - - memset(c, 0, sizeof(*c)); - c->sock = EIPC_INVALID_SOCKET; - strncpy(c->service_id, service_id, sizeof(c->service_id) - 1); - c->service_id[sizeof(c->service_id) - 1] = '\0'; - - return EIPC_OK; -} - -eipc_status_t eipc_client_connect(eipc_client_t *c, - const char *address, - const char *hmac_key) { - eipc_status_t rc; - size_t key_len; - - if (!c || !address || !hmac_key) return EIPC_ERR_INVALID; - - key_len = strlen(hmac_key); - if (key_len > sizeof(c->hmac_key)) return EIPC_ERR_INVALID; - memcpy(c->hmac_key, hmac_key, key_len); - c->hmac_key_len = (uint32_t)key_len; - - rc = eipc_transport_connect(&c->sock, address); - if (rc != EIPC_OK) return rc; - - c->connected = true; - c->sequence = 0; - - return EIPC_OK; -} - -eipc_status_t eipc_client_send_intent(eipc_client_t *c, - const char *intent, - float confidence) { - eipc_header_t hdr; - eipc_intent_event_t ev; - char payload_json[EIPC_MAX_PAYLOAD]; - eipc_status_t rc; - - if (!c || !c->connected || !intent) return EIPC_ERR_INVALID; - - c->sequence++; - - fill_header(&hdr, c->service_id, c->sequence); - - memset(&ev, 0, sizeof(ev)); - strncpy(ev.intent, intent, sizeof(ev.intent) - 1); - ev.intent[sizeof(ev.intent) - 1] = '\0'; - ev.confidence = confidence; - - size_t payload_written = 0; - rc = eipc_intent_to_json(&ev, payload_json, sizeof(payload_json), &payload_written); - if (rc != EIPC_OK) return rc; - - return build_and_send(c, EIPC_MSG_INTENT, &hdr, - (const uint8_t *)payload_json, strlen(payload_json)); -} - -eipc_status_t eipc_client_send_tool_request(eipc_client_t *c, - const char *tool, - const eipc_kv_t *args, - int arg_count) { - eipc_header_t hdr; - eipc_tool_request_t req; - char payload_json[EIPC_MAX_PAYLOAD]; - eipc_status_t rc; - - if (!c || !c->connected || !tool) return EIPC_ERR_INVALID; - - c->sequence++; - - fill_header(&hdr, c->service_id, c->sequence); - - memset(&req, 0, sizeof(req)); - strncpy(req.tool, tool, sizeof(req.tool) - 1); - req.tool[sizeof(req.tool) - 1] = '\0'; - if (args && arg_count > 0) { - int n = arg_count; - if (n > EIPC_MAX_ARGS) n = EIPC_MAX_ARGS; - memcpy(req.args, args, sizeof(eipc_kv_t) * (size_t)n); - req.arg_count = n; - } - - size_t payload_written = 0; - rc = eipc_tool_request_to_json(&req, payload_json, sizeof(payload_json), &payload_written); - if (rc != EIPC_OK) return rc; - - return build_and_send(c, EIPC_MSG_TOOL_REQUEST, &hdr, - (const uint8_t *)payload_json, strlen(payload_json)); -} - -eipc_status_t eipc_client_send_heartbeat(eipc_client_t *c) { - eipc_header_t hdr; - eipc_heartbeat_event_t hb; - char payload_json[EIPC_MAX_PAYLOAD]; - eipc_status_t rc; - - if (!c || !c->connected) return EIPC_ERR_INVALID; - - c->sequence++; - - fill_header(&hdr, c->service_id, c->sequence); - - memset(&hb, 0, sizeof(hb)); - strncpy(hb.service, c->service_id, sizeof(hb.service) - 1); - hb.service[sizeof(hb.service) - 1] = '\0'; - strncpy(hb.status, "alive", sizeof(hb.status) - 1); - hb.status[sizeof(hb.status) - 1] = '\0'; - - size_t payload_written = 0; - rc = eipc_heartbeat_to_json(&hb, payload_json, sizeof(payload_json), &payload_written); - if (rc != EIPC_OK) return rc; - - return build_and_send(c, EIPC_MSG_HEARTBEAT, &hdr, - (const uint8_t *)payload_json, strlen(payload_json)); -} - -eipc_status_t eipc_client_send_chat(eipc_client_t *c, - const eipc_chat_request_t *req) { - eipc_header_t hdr; - char payload_json[EIPC_MAX_PAYLOAD]; - eipc_status_t rc; - - if (!c || !c->connected || !req) return EIPC_ERR_INVALID; - - c->sequence++; - fill_header(&hdr, c->service_id, c->sequence); - strncpy(hdr.capability, "ai:chat", sizeof(hdr.capability) - 1); - hdr.capability[sizeof(hdr.capability) - 1] = '\0'; - - size_t chat_written = 0; - rc = eipc_chat_request_to_json(req, payload_json, sizeof(payload_json), &chat_written); - if (rc != EIPC_OK) return rc; - - return build_and_send(c, EIPC_MSG_CHAT, &hdr, - (const uint8_t *)payload_json, strlen(payload_json)); -} - -eipc_status_t eipc_client_send_complete(eipc_client_t *c, - const char *prompt, - const char *session_id) { - eipc_header_t hdr; - char payload_json[EIPC_MAX_PAYLOAD]; - int n; - - if (!c || !c->connected || !prompt) return EIPC_ERR_INVALID; - - c->sequence++; - fill_header(&hdr, c->service_id, c->sequence); - strncpy(hdr.capability, "ai:chat", sizeof(hdr.capability) - 1); - hdr.capability[sizeof(hdr.capability) - 1] = '\0'; - - n = snprintf(payload_json, sizeof(payload_json), - "{\"session_id\":\"%s\",\"prompt\":\"%s\",\"model\":\"\",\"max_tokens\":0}", - session_id ? session_id : "", prompt); - if (n < 0 || (size_t)n >= sizeof(payload_json)) return EIPC_ERR_FRAME_TOO_LARGE; - - return build_and_send(c, EIPC_MSG_COMPLETE, &hdr, - (const uint8_t *)payload_json, (size_t)n); -} - -eipc_status_t eipc_client_receive(eipc_client_t *c, eipc_message_t *msg) { - eipc_frame_t frame; - eipc_header_t hdr; - eipc_status_t rc; - - if (!c || !c->connected || !msg) return EIPC_ERR_INVALID; - - memset(&frame, 0, sizeof(frame)); - - rc = eipc_transport_recv_frame(c->sock, &frame); - if (rc != EIPC_OK) return rc; - - if (frame.flags & EIPC_FLAG_HMAC) { - uint8_t signable[EIPC_MAX_FRAME]; - size_t signable_len = eipc_frame_signable_bytes(&frame, signable, sizeof(signable)); - if (signable_len == 0) - return EIPC_ERR_INTEGRITY; - - if (!eipc_hmac_verify(c->hmac_key, c->hmac_key_len, - signable, signable_len, frame.mac)) - return EIPC_ERR_AUTH; - } - - memset(msg, 0, sizeof(*msg)); - msg->msg_type = frame.msg_type; - msg->version = frame.version; - - rc = eipc_header_from_json((const char *)frame.header, frame.header_len, &hdr); - if (rc != EIPC_OK) return rc; - - strncpy(msg->source, hdr.service_id, sizeof(msg->source) - 1); - msg->source[sizeof(msg->source) - 1] = '\0'; - strncpy(msg->session_id, hdr.session_id, sizeof(msg->session_id) - 1); - msg->session_id[sizeof(msg->session_id) - 1] = '\0'; - strncpy(msg->request_id, hdr.request_id, sizeof(msg->request_id) - 1); - msg->request_id[sizeof(msg->request_id) - 1] = '\0'; - msg->priority = hdr.priority; - strncpy(msg->capability, hdr.capability, sizeof(msg->capability) - 1); - msg->capability[sizeof(msg->capability) - 1] = '\0'; - - if (frame.payload_len > 0) { - if (frame.payload_len > sizeof(msg->payload)) - return EIPC_ERR_FRAME_TOO_LARGE; - memcpy(msg->payload, frame.payload, frame.payload_len); - msg->payload_len = frame.payload_len; - } - - return EIPC_OK; -} - -void eipc_client_close(eipc_client_t *c) { - if (!c) return; - - if (c->sock != EIPC_INVALID_SOCKET) { - eipc_transport_close(c->sock); - c->sock = EIPC_INVALID_SOCKET; - } - c->connected = false; - c->sequence = 0; -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project +// ISO/IEC 25000 | ISO/IEC/IEEE 15288:2023 + +/* + * EIPC High-level Client API + */ + +#include "eipc.h" + +#include +#include +#include +#include + +/* --------------- utility implementations --------------- */ + +void eipc_timestamp_now(char *buf, size_t buf_size) { + if (!buf || buf_size == 0) return; + snprintf(buf, buf_size, "%llu", (unsigned long long)time(NULL)); +} + +void eipc_generate_request_id(char *buf, size_t buf_size) { + if (!buf || buf_size == 0) return; + snprintf(buf, buf_size, "req-%08lx-%04x", + (unsigned long)time(NULL), + (unsigned)(rand() & 0xFFFF)); +} + +/* --------------- helpers --------------- */ + +static void fill_header(eipc_header_t *h, + const char *service_id, + uint64_t sequence) { + memset(h, 0, sizeof(*h)); + strncpy(h->service_id, service_id, sizeof(h->service_id) - 1); + h->service_id[sizeof(h->service_id) - 1] = '\0'; + eipc_generate_request_id(h->request_id, sizeof(h->request_id)); + h->sequence = sequence; + eipc_timestamp_now(h->timestamp, sizeof(h->timestamp)); + h->priority = EIPC_PRIORITY_P1; + h->payload_format = EIPC_PAYLOAD_JSON; +} + +static eipc_status_t build_and_send(eipc_client_t *c, + uint8_t msg_type, + const eipc_header_t *hdr, + const uint8_t *payload, + size_t payload_len) { + eipc_frame_t frame; + eipc_status_t rc; + char header_json[EIPC_MAX_HEADER]; + size_t signable_len; + + memset(&frame, 0, sizeof(frame)); + + frame.version = EIPC_PROTOCOL_VER; + frame.msg_type = msg_type; + frame.flags = EIPC_FLAG_HMAC; + + size_t hdr_json_len = 0; + rc = eipc_header_to_json(hdr, header_json, sizeof(header_json), &hdr_json_len); + if (rc != EIPC_OK) return rc; + memcpy(frame.header, header_json, hdr_json_len); + frame.header_len = (uint32_t)hdr_json_len; + + if (payload && payload_len > 0) { + if (payload_len > sizeof(frame.payload)) + return EIPC_ERR_FRAME_TOO_LARGE; + memcpy(frame.payload, payload, payload_len); + frame.payload_len = (uint32_t)payload_len; + } + + { + uint8_t signable[EIPC_MAX_FRAME]; + signable_len = eipc_frame_signable_bytes(&frame, signable, sizeof(signable)); + if (signable_len == 0) + return EIPC_ERR_INTEGRITY; + + eipc_hmac_sign(c->hmac_key, c->hmac_key_len, + signable, signable_len, + frame.mac); + } + + return eipc_transport_send_frame(c->sock, &frame); +} + +/* --------------- public API --------------- */ + +eipc_status_t eipc_client_init(eipc_client_t *c, const char *service_id) { + if (!c || !service_id) return EIPC_ERR_INVALID; + + memset(c, 0, sizeof(*c)); + c->sock = EIPC_INVALID_SOCKET; + strncpy(c->service_id, service_id, sizeof(c->service_id) - 1); + c->service_id[sizeof(c->service_id) - 1] = '\0'; + + return EIPC_OK; +} + +eipc_status_t eipc_client_connect(eipc_client_t *c, + const char *address, + const char *hmac_key) { + eipc_status_t rc; + size_t key_len; + + if (!c || !address || !hmac_key) return EIPC_ERR_INVALID; + + key_len = strlen(hmac_key); + if (key_len > sizeof(c->hmac_key)) return EIPC_ERR_INVALID; + memcpy(c->hmac_key, hmac_key, key_len); + c->hmac_key_len = (uint32_t)key_len; + + rc = eipc_transport_connect(&c->sock, address); + if (rc != EIPC_OK) return rc; + + c->connected = true; + c->sequence = 0; + + return EIPC_OK; +} + +eipc_status_t eipc_client_send_intent(eipc_client_t *c, + const char *intent, + float confidence) { + eipc_header_t hdr; + eipc_intent_event_t ev; + char payload_json[EIPC_MAX_PAYLOAD]; + eipc_status_t rc; + + if (!c || !c->connected || !intent) return EIPC_ERR_INVALID; + + c->sequence++; + + fill_header(&hdr, c->service_id, c->sequence); + + memset(&ev, 0, sizeof(ev)); + strncpy(ev.intent, intent, sizeof(ev.intent) - 1); + ev.intent[sizeof(ev.intent) - 1] = '\0'; + ev.confidence = confidence; + + size_t payload_written = 0; + rc = eipc_intent_to_json(&ev, payload_json, sizeof(payload_json), &payload_written); + if (rc != EIPC_OK) return rc; + + return build_and_send(c, EIPC_MSG_INTENT, &hdr, + (const uint8_t *)payload_json, strlen(payload_json)); +} + +eipc_status_t eipc_client_send_tool_request(eipc_client_t *c, + const char *tool, + const eipc_kv_t *args, + int arg_count) { + eipc_header_t hdr; + eipc_tool_request_t req; + char payload_json[EIPC_MAX_PAYLOAD]; + eipc_status_t rc; + + if (!c || !c->connected || !tool) return EIPC_ERR_INVALID; + + c->sequence++; + + fill_header(&hdr, c->service_id, c->sequence); + + memset(&req, 0, sizeof(req)); + strncpy(req.tool, tool, sizeof(req.tool) - 1); + req.tool[sizeof(req.tool) - 1] = '\0'; + if (args && arg_count > 0) { + int n = arg_count; + if (n > EIPC_MAX_ARGS) n = EIPC_MAX_ARGS; + memcpy(req.args, args, sizeof(eipc_kv_t) * (size_t)n); + req.arg_count = n; + } + + size_t payload_written = 0; + rc = eipc_tool_request_to_json(&req, payload_json, sizeof(payload_json), &payload_written); + if (rc != EIPC_OK) return rc; + + return build_and_send(c, EIPC_MSG_TOOL_REQUEST, &hdr, + (const uint8_t *)payload_json, strlen(payload_json)); +} + +eipc_status_t eipc_client_send_heartbeat(eipc_client_t *c) { + eipc_header_t hdr; + eipc_heartbeat_event_t hb; + char payload_json[EIPC_MAX_PAYLOAD]; + eipc_status_t rc; + + if (!c || !c->connected) return EIPC_ERR_INVALID; + + c->sequence++; + + fill_header(&hdr, c->service_id, c->sequence); + + memset(&hb, 0, sizeof(hb)); + strncpy(hb.service, c->service_id, sizeof(hb.service) - 1); + hb.service[sizeof(hb.service) - 1] = '\0'; + strncpy(hb.status, "alive", sizeof(hb.status) - 1); + hb.status[sizeof(hb.status) - 1] = '\0'; + + size_t payload_written = 0; + rc = eipc_heartbeat_to_json(&hb, payload_json, sizeof(payload_json), &payload_written); + if (rc != EIPC_OK) return rc; + + return build_and_send(c, EIPC_MSG_HEARTBEAT, &hdr, + (const uint8_t *)payload_json, strlen(payload_json)); +} + +eipc_status_t eipc_client_send_chat(eipc_client_t *c, + const eipc_chat_request_t *req) { + eipc_header_t hdr; + char payload_json[EIPC_MAX_PAYLOAD]; + eipc_status_t rc; + + if (!c || !c->connected || !req) return EIPC_ERR_INVALID; + + c->sequence++; + fill_header(&hdr, c->service_id, c->sequence); + strncpy(hdr.capability, "ai:chat", sizeof(hdr.capability) - 1); + hdr.capability[sizeof(hdr.capability) - 1] = '\0'; + + size_t chat_written = 0; + rc = eipc_chat_request_to_json(req, payload_json, sizeof(payload_json), &chat_written); + if (rc != EIPC_OK) return rc; + + return build_and_send(c, EIPC_MSG_CHAT, &hdr, + (const uint8_t *)payload_json, strlen(payload_json)); +} + +eipc_status_t eipc_client_send_complete(eipc_client_t *c, + const char *prompt, + const char *session_id) { + eipc_header_t hdr; + char payload_json[EIPC_MAX_PAYLOAD]; + int n; + + if (!c || !c->connected || !prompt) return EIPC_ERR_INVALID; + + c->sequence++; + fill_header(&hdr, c->service_id, c->sequence); + strncpy(hdr.capability, "ai:chat", sizeof(hdr.capability) - 1); + hdr.capability[sizeof(hdr.capability) - 1] = '\0'; + + n = snprintf(payload_json, sizeof(payload_json), + "{\"session_id\":\"%s\",\"prompt\":\"%s\",\"model\":\"\",\"max_tokens\":0}", + session_id ? session_id : "", prompt); + if (n < 0 || (size_t)n >= sizeof(payload_json)) return EIPC_ERR_FRAME_TOO_LARGE; + + return build_and_send(c, EIPC_MSG_COMPLETE, &hdr, + (const uint8_t *)payload_json, (size_t)n); +} + +eipc_status_t eipc_client_receive(eipc_client_t *c, eipc_message_t *msg) { + eipc_frame_t frame; + eipc_header_t hdr; + eipc_status_t rc; + + if (!c || !c->connected || !msg) return EIPC_ERR_INVALID; + + memset(&frame, 0, sizeof(frame)); + + rc = eipc_transport_recv_frame(c->sock, &frame); + if (rc != EIPC_OK) return rc; + + if (frame.flags & EIPC_FLAG_HMAC) { + uint8_t signable[EIPC_MAX_FRAME]; + size_t signable_len = eipc_frame_signable_bytes(&frame, signable, sizeof(signable)); + if (signable_len == 0) + return EIPC_ERR_INTEGRITY; + + if (!eipc_hmac_verify(c->hmac_key, c->hmac_key_len, + signable, signable_len, frame.mac)) + return EIPC_ERR_AUTH; + } + + memset(msg, 0, sizeof(*msg)); + msg->msg_type = frame.msg_type; + msg->version = frame.version; + + rc = eipc_header_from_json((const char *)frame.header, frame.header_len, &hdr); + if (rc != EIPC_OK) return rc; + + strncpy(msg->source, hdr.service_id, sizeof(msg->source) - 1); + msg->source[sizeof(msg->source) - 1] = '\0'; + strncpy(msg->session_id, hdr.session_id, sizeof(msg->session_id) - 1); + msg->session_id[sizeof(msg->session_id) - 1] = '\0'; + strncpy(msg->request_id, hdr.request_id, sizeof(msg->request_id) - 1); + msg->request_id[sizeof(msg->request_id) - 1] = '\0'; + msg->priority = hdr.priority; + strncpy(msg->capability, hdr.capability, sizeof(msg->capability) - 1); + msg->capability[sizeof(msg->capability) - 1] = '\0'; + + if (frame.payload_len > 0) { + if (frame.payload_len > sizeof(msg->payload)) + return EIPC_ERR_FRAME_TOO_LARGE; + memcpy(msg->payload, frame.payload, frame.payload_len); + msg->payload_len = frame.payload_len; + } + + return EIPC_OK; +} + +void eipc_client_close(eipc_client_t *c) { + if (!c) return; + + if (c->sock != EIPC_INVALID_SOCKET) { + eipc_transport_close(c->sock); + c->sock = EIPC_INVALID_SOCKET; + } + c->connected = false; + c->sequence = 0; +} diff --git a/security/auth/identity.go b/security/auth/identity.go index a9da571..d7a4786 100644 --- a/security/auth/identity.go +++ b/security/auth/identity.go @@ -1,226 +1,226 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package auth - -import ( - "crypto/hmac" - "crypto/rand" - "crypto/sha256" - "encoding/hex" - "fmt" - "sync" - "time" - - "github.com/embeddedos-org/eipc/core" -) - -// PeerIdentity represents an authenticated EIPC peer. -type PeerIdentity struct { - ServiceID string `json:"service_id"` - Capabilities []string `json:"capabilities"` - SessionToken string `json:"session_token"` - CreatedAt time.Time `json:"created_at"` - SessionTTL time.Duration `json:"-"` -} - -// IsExpired returns true if the session has exceeded its TTL. -func (p *PeerIdentity) IsExpired() bool { - if p.SessionTTL <= 0 { - return false - } - return time.Since(p.CreatedAt) > p.SessionTTL -} - -// Challenge represents a pending challenge-response authentication. -type Challenge struct { - ServiceID string - Nonce []byte - CreatedAt time.Time -} - -// Authenticator validates peer credentials and issues session tokens. -type Authenticator struct { - mu sync.RWMutex - sharedSecret []byte - knownServices map[string][]string // service_id → allowed capabilities - activeSessions map[string]*PeerIdentity - pendingChallenges map[string]*Challenge // serviceID → pending challenge - sessionTTL time.Duration -} - -// NewAuthenticator creates an authenticator with the given shared secret. -// knownServices maps service IDs to their allowed capability sets. -func NewAuthenticator(sharedSecret []byte, knownServices map[string][]string) *Authenticator { - known := make(map[string][]string, len(knownServices)) - for k, v := range knownServices { - caps := make([]string, len(v)) - copy(caps, v) - known[k] = caps - } - return &Authenticator{ - sharedSecret: sharedSecret, - knownServices: known, - activeSessions: make(map[string]*PeerIdentity), - pendingChallenges: make(map[string]*Challenge), - sessionTTL: 1 * time.Hour, - } -} - -// SetSessionTTL configures the default session TTL for new sessions. -func (a *Authenticator) SetSessionTTL(ttl time.Duration) { - a.mu.Lock() - defer a.mu.Unlock() - a.sessionTTL = ttl -} - -// CreateChallenge generates a 32-byte nonce challenge for the given service ID. -// Returns error if the service is unknown. -func (a *Authenticator) CreateChallenge(serviceID string) (*Challenge, error) { - a.mu.Lock() - defer a.mu.Unlock() - - if _, ok := a.knownServices[serviceID]; !ok { - return nil, fmt.Errorf("%w: unknown service %q", core.ErrAuth, serviceID) - } - - nonce := make([]byte, 32) - if _, err := rand.Read(nonce); err != nil { - return nil, fmt.Errorf("generate nonce: %w", err) - } - - challenge := &Challenge{ - ServiceID: serviceID, - Nonce: nonce, - CreatedAt: time.Now().UTC(), - } - a.pendingChallenges[serviceID] = challenge - return challenge, nil -} - -// VerifyResponse verifies the HMAC-SHA256 response to a challenge. -// On success, creates and returns a PeerIdentity with session token. -func (a *Authenticator) VerifyResponse(serviceID string, response []byte) (*PeerIdentity, error) { - a.mu.Lock() - defer a.mu.Unlock() - - challenge, ok := a.pendingChallenges[serviceID] - if !ok { - return nil, fmt.Errorf("%w: no pending challenge for %q", core.ErrAuth, serviceID) - } - delete(a.pendingChallenges, serviceID) - - // Compute expected response: HMAC-SHA256(sharedSecret, nonce) - mac := hmac.New(sha256.New, a.sharedSecret) - mac.Write(challenge.Nonce) - expected := mac.Sum(nil) - - if !hmac.Equal(expected, response) { - return nil, fmt.Errorf("%w: challenge-response verification failed for %q", core.ErrAuth, serviceID) - } - - capsSrc := a.knownServices[serviceID] - caps := make([]string, len(capsSrc)) - copy(caps, capsSrc) - - token, err := generateToken() - if err != nil { - return nil, fmt.Errorf("generate token: %w", err) - } - - peer := &PeerIdentity{ - ServiceID: serviceID, - Capabilities: caps, - SessionToken: token, - CreatedAt: time.Now().UTC(), - SessionTTL: a.sessionTTL, - } - a.activeSessions[token] = peer - return peer, nil -} - -// Authenticate validates a peer's service ID and returns a PeerIdentity -// with a fresh session token. (Legacy simple auth, kept for backward compat.) -func (a *Authenticator) Authenticate(serviceID string) (*PeerIdentity, error) { - a.mu.Lock() - defer a.mu.Unlock() - - capsSrc, ok := a.knownServices[serviceID] - if !ok { - return nil, fmt.Errorf("%w: unknown service %q", core.ErrAuth, serviceID) - } - - caps := make([]string, len(capsSrc)) - copy(caps, capsSrc) - - token, err := generateToken() - if err != nil { - return nil, fmt.Errorf("generate token: %w", err) - } - - peer := &PeerIdentity{ - ServiceID: serviceID, - Capabilities: caps, - SessionToken: token, - CreatedAt: time.Now().UTC(), - SessionTTL: a.sessionTTL, - } - a.activeSessions[token] = peer - return peer, nil -} - -// ValidateSession checks whether a session token is valid and not expired. -func (a *Authenticator) ValidateSession(token string) (*PeerIdentity, error) { - a.mu.RLock() - defer a.mu.RUnlock() - - peer, ok := a.activeSessions[token] - if !ok { - return nil, fmt.Errorf("%w: invalid session token", core.ErrAuth) - } - if peer.IsExpired() { - return nil, fmt.Errorf("%w: session expired", core.ErrAuth) - } - return peer, nil -} - -// SharedSecret returns the shared HMAC key for message signing. -func (a *Authenticator) SharedSecret() []byte { - return a.sharedSecret -} - -// RevokeSession removes a session. -func (a *Authenticator) RevokeSession(token string) { - a.mu.Lock() - defer a.mu.Unlock() - delete(a.activeSessions, token) -} - -// CleanupExpired removes all expired sessions. Returns count removed. -func (a *Authenticator) CleanupExpired() int { - a.mu.Lock() - defer a.mu.Unlock() - removed := 0 - for token, peer := range a.activeSessions { - if peer.IsExpired() { - delete(a.activeSessions, token) - removed++ - } - } - return removed -} - -// ActiveSessionCount returns the number of active sessions. -func (a *Authenticator) ActiveSessionCount() int { - a.mu.RLock() - defer a.mu.RUnlock() - return len(a.activeSessions) -} - -func generateToken() (string, error) { - b := make([]byte, 32) - if _, err := rand.Read(b); err != nil { - return "", err - } - return hex.EncodeToString(b), nil -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package auth + +import ( + "crypto/hmac" + "crypto/rand" + "crypto/sha256" + "encoding/hex" + "fmt" + "sync" + "time" + + "github.com/embeddedos-org/eipc/core" +) + +// PeerIdentity represents an authenticated EIPC peer. +type PeerIdentity struct { + ServiceID string `json:"service_id"` + Capabilities []string `json:"capabilities"` + SessionToken string `json:"session_token"` + CreatedAt time.Time `json:"created_at"` + SessionTTL time.Duration `json:"-"` +} + +// IsExpired returns true if the session has exceeded its TTL. +func (p *PeerIdentity) IsExpired() bool { + if p.SessionTTL <= 0 { + return false + } + return time.Since(p.CreatedAt) > p.SessionTTL +} + +// Challenge represents a pending challenge-response authentication. +type Challenge struct { + ServiceID string + Nonce []byte + CreatedAt time.Time +} + +// Authenticator validates peer credentials and issues session tokens. +type Authenticator struct { + mu sync.RWMutex + sharedSecret []byte + knownServices map[string][]string // service_id → allowed capabilities + activeSessions map[string]*PeerIdentity + pendingChallenges map[string]*Challenge // serviceID → pending challenge + sessionTTL time.Duration +} + +// NewAuthenticator creates an authenticator with the given shared secret. +// knownServices maps service IDs to their allowed capability sets. +func NewAuthenticator(sharedSecret []byte, knownServices map[string][]string) *Authenticator { + known := make(map[string][]string, len(knownServices)) + for k, v := range knownServices { + caps := make([]string, len(v)) + copy(caps, v) + known[k] = caps + } + return &Authenticator{ + sharedSecret: sharedSecret, + knownServices: known, + activeSessions: make(map[string]*PeerIdentity), + pendingChallenges: make(map[string]*Challenge), + sessionTTL: 1 * time.Hour, + } +} + +// SetSessionTTL configures the default session TTL for new sessions. +func (a *Authenticator) SetSessionTTL(ttl time.Duration) { + a.mu.Lock() + defer a.mu.Unlock() + a.sessionTTL = ttl +} + +// CreateChallenge generates a 32-byte nonce challenge for the given service ID. +// Returns error if the service is unknown. +func (a *Authenticator) CreateChallenge(serviceID string) (*Challenge, error) { + a.mu.Lock() + defer a.mu.Unlock() + + if _, ok := a.knownServices[serviceID]; !ok { + return nil, fmt.Errorf("%w: unknown service %q", core.ErrAuth, serviceID) + } + + nonce := make([]byte, 32) + if _, err := rand.Read(nonce); err != nil { + return nil, fmt.Errorf("generate nonce: %w", err) + } + + challenge := &Challenge{ + ServiceID: serviceID, + Nonce: nonce, + CreatedAt: time.Now().UTC(), + } + a.pendingChallenges[serviceID] = challenge + return challenge, nil +} + +// VerifyResponse verifies the HMAC-SHA256 response to a challenge. +// On success, creates and returns a PeerIdentity with session token. +func (a *Authenticator) VerifyResponse(serviceID string, response []byte) (*PeerIdentity, error) { + a.mu.Lock() + defer a.mu.Unlock() + + challenge, ok := a.pendingChallenges[serviceID] + if !ok { + return nil, fmt.Errorf("%w: no pending challenge for %q", core.ErrAuth, serviceID) + } + delete(a.pendingChallenges, serviceID) + + // Compute expected response: HMAC-SHA256(sharedSecret, nonce) + mac := hmac.New(sha256.New, a.sharedSecret) + mac.Write(challenge.Nonce) + expected := mac.Sum(nil) + + if !hmac.Equal(expected, response) { + return nil, fmt.Errorf("%w: challenge-response verification failed for %q", core.ErrAuth, serviceID) + } + + capsSrc := a.knownServices[serviceID] + caps := make([]string, len(capsSrc)) + copy(caps, capsSrc) + + token, err := generateToken() + if err != nil { + return nil, fmt.Errorf("generate token: %w", err) + } + + peer := &PeerIdentity{ + ServiceID: serviceID, + Capabilities: caps, + SessionToken: token, + CreatedAt: time.Now().UTC(), + SessionTTL: a.sessionTTL, + } + a.activeSessions[token] = peer + return peer, nil +} + +// Authenticate validates a peer's service ID and returns a PeerIdentity +// with a fresh session token. (Legacy simple auth, kept for backward compat.) +func (a *Authenticator) Authenticate(serviceID string) (*PeerIdentity, error) { + a.mu.Lock() + defer a.mu.Unlock() + + capsSrc, ok := a.knownServices[serviceID] + if !ok { + return nil, fmt.Errorf("%w: unknown service %q", core.ErrAuth, serviceID) + } + + caps := make([]string, len(capsSrc)) + copy(caps, capsSrc) + + token, err := generateToken() + if err != nil { + return nil, fmt.Errorf("generate token: %w", err) + } + + peer := &PeerIdentity{ + ServiceID: serviceID, + Capabilities: caps, + SessionToken: token, + CreatedAt: time.Now().UTC(), + SessionTTL: a.sessionTTL, + } + a.activeSessions[token] = peer + return peer, nil +} + +// ValidateSession checks whether a session token is valid and not expired. +func (a *Authenticator) ValidateSession(token string) (*PeerIdentity, error) { + a.mu.RLock() + defer a.mu.RUnlock() + + peer, ok := a.activeSessions[token] + if !ok { + return nil, fmt.Errorf("%w: invalid session token", core.ErrAuth) + } + if peer.IsExpired() { + return nil, fmt.Errorf("%w: session expired", core.ErrAuth) + } + return peer, nil +} + +// SharedSecret returns the shared HMAC key for message signing. +func (a *Authenticator) SharedSecret() []byte { + return a.sharedSecret +} + +// RevokeSession removes a session. +func (a *Authenticator) RevokeSession(token string) { + a.mu.Lock() + defer a.mu.Unlock() + delete(a.activeSessions, token) +} + +// CleanupExpired removes all expired sessions. Returns count removed. +func (a *Authenticator) CleanupExpired() int { + a.mu.Lock() + defer a.mu.Unlock() + removed := 0 + for token, peer := range a.activeSessions { + if peer.IsExpired() { + delete(a.activeSessions, token) + removed++ + } + } + return removed +} + +// ActiveSessionCount returns the number of active sessions. +func (a *Authenticator) ActiveSessionCount() int { + a.mu.RLock() + defer a.mu.RUnlock() + return len(a.activeSessions) +} + +func generateToken() (string, error) { + b := make([]byte, 32) + if _, err := rand.Read(b); err != nil { + return "", err + } + return hex.EncodeToString(b), nil +} diff --git a/security/encryption/aes_test.go b/security/encryption/aes_test.go index 7a613f1..a53f2e1 100644 --- a/security/encryption/aes_test.go +++ b/security/encryption/aes_test.go @@ -1,168 +1,168 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package encryption - -import ( - "bytes" - "crypto/rand" - "io" - "testing" -) - -func TestEncryptDecrypt(t *testing.T) { - key := make([]byte, KeySize) - if _, err := io.ReadFull(rand.Reader, key); err != nil { - t.Fatalf("rand.Read: %v", err) - } - - plaintext := []byte("hello EIPC encryption") - - ciphertext, err := Encrypt(key, plaintext) - if err != nil { - t.Fatalf("Encrypt: %v", err) - } - - if bytes.Equal(ciphertext[NonceSize:], plaintext) { - t.Error("ciphertext should not equal plaintext") - } - - decrypted, err := Decrypt(key, ciphertext) - if err != nil { - t.Fatalf("Decrypt: %v", err) - } - - if !bytes.Equal(decrypted, plaintext) { - t.Errorf("decrypted %q != plaintext %q", decrypted, plaintext) - } -} - -func TestEncrypt_WrongKeySize(t *testing.T) { - _, err := Encrypt([]byte("short"), []byte("data")) - if err == nil { - t.Fatal("expected error for wrong key size") - } -} - -func TestDecrypt_WrongKey(t *testing.T) { - key := make([]byte, KeySize) - if _, err := io.ReadFull(rand.Reader, key); err != nil { - t.Fatalf("rand.Read: %v", err) - } - - ciphertext, err := Encrypt(key, []byte("secret data")) - if err != nil { - t.Fatalf("Encrypt: %v", err) - } - - wrongKey := make([]byte, KeySize) - if _, err := io.ReadFull(rand.Reader, wrongKey); err != nil { - t.Fatalf("rand.Read: %v", err) - } - - _, err = Decrypt(wrongKey, ciphertext) - if err == nil { - t.Fatal("expected error for wrong key") - } -} - -func TestDecrypt_TamperedCiphertext(t *testing.T) { - key := make([]byte, KeySize) - if _, err := io.ReadFull(rand.Reader, key); err != nil { - t.Fatalf("rand.Read: %v", err) - } - - ciphertext, err := Encrypt(key, []byte("secret data")) - if err != nil { - t.Fatalf("Encrypt: %v", err) - } - - // Tamper with the ciphertext (not the nonce) - if len(ciphertext) > NonceSize+1 { - ciphertext[NonceSize+1] ^= 0xff - } - - _, err = Decrypt(key, ciphertext) - if err == nil { - t.Fatal("expected error for tampered ciphertext") - } -} - -func TestDecrypt_TooShort(t *testing.T) { - key := make([]byte, KeySize) - if _, err := io.ReadFull(rand.Reader, key); err != nil { - t.Fatalf("rand.Read: %v", err) - } - - _, err := Decrypt(key, []byte{1, 2, 3}) - if err == nil { - t.Fatal("expected error for ciphertext too short") - } -} - -func TestEncryptDecrypt_EmptyPlaintext(t *testing.T) { - key := make([]byte, KeySize) - if _, err := io.ReadFull(rand.Reader, key); err != nil { - t.Fatalf("rand.Read: %v", err) - } - - ciphertext, err := Encrypt(key, []byte{}) - if err != nil { - t.Fatalf("Encrypt empty: %v", err) - } - - decrypted, err := Decrypt(key, ciphertext) - if err != nil { - t.Fatalf("Decrypt empty: %v", err) - } - - if len(decrypted) != 0 { - t.Errorf("expected empty plaintext, got %d bytes", len(decrypted)) - } -} - -func TestEncryptDecrypt_LargePayload(t *testing.T) { - key := make([]byte, KeySize) - if _, err := io.ReadFull(rand.Reader, key); err != nil { - t.Fatalf("rand.Read: %v", err) - } - - plaintext := make([]byte, 1<<16) // 64KB - if _, err := io.ReadFull(rand.Reader, plaintext); err != nil { - t.Fatalf("rand.Read: %v", err) - } - - ciphertext, err := Encrypt(key, plaintext) - if err != nil { - t.Fatalf("Encrypt large: %v", err) - } - - decrypted, err := Decrypt(key, ciphertext) - if err != nil { - t.Fatalf("Decrypt large: %v", err) - } - - if !bytes.Equal(decrypted, plaintext) { - t.Error("large payload round-trip failed") - } -} - -func TestEncrypt_UniqueNonces(t *testing.T) { - key := make([]byte, KeySize) - if _, err := io.ReadFull(rand.Reader, key); err != nil { - t.Fatalf("rand.Read: %v", err) - } - - ct1, err := Encrypt(key, []byte("same data")) - if err != nil { - t.Fatalf("Encrypt 1: %v", err) - } - ct2, err := Encrypt(key, []byte("same data")) - if err != nil { - t.Fatalf("Encrypt 2: %v", err) - } - - if bytes.Equal(ct1, ct2) { - t.Error("two encryptions of same data should produce different ciphertexts (different nonces)") - } -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package encryption + +import ( + "bytes" + "crypto/rand" + "io" + "testing" +) + +func TestEncryptDecrypt(t *testing.T) { + key := make([]byte, KeySize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatalf("rand.Read: %v", err) + } + + plaintext := []byte("hello EIPC encryption") + + ciphertext, err := Encrypt(key, plaintext) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + + if bytes.Equal(ciphertext[NonceSize:], plaintext) { + t.Error("ciphertext should not equal plaintext") + } + + decrypted, err := Decrypt(key, ciphertext) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + + if !bytes.Equal(decrypted, plaintext) { + t.Errorf("decrypted %q != plaintext %q", decrypted, plaintext) + } +} + +func TestEncrypt_WrongKeySize(t *testing.T) { + _, err := Encrypt([]byte("short"), []byte("data")) + if err == nil { + t.Fatal("expected error for wrong key size") + } +} + +func TestDecrypt_WrongKey(t *testing.T) { + key := make([]byte, KeySize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatalf("rand.Read: %v", err) + } + + ciphertext, err := Encrypt(key, []byte("secret data")) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + + wrongKey := make([]byte, KeySize) + if _, err := io.ReadFull(rand.Reader, wrongKey); err != nil { + t.Fatalf("rand.Read: %v", err) + } + + _, err = Decrypt(wrongKey, ciphertext) + if err == nil { + t.Fatal("expected error for wrong key") + } +} + +func TestDecrypt_TamperedCiphertext(t *testing.T) { + key := make([]byte, KeySize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatalf("rand.Read: %v", err) + } + + ciphertext, err := Encrypt(key, []byte("secret data")) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + + // Tamper with the ciphertext (not the nonce) + if len(ciphertext) > NonceSize+1 { + ciphertext[NonceSize+1] ^= 0xff + } + + _, err = Decrypt(key, ciphertext) + if err == nil { + t.Fatal("expected error for tampered ciphertext") + } +} + +func TestDecrypt_TooShort(t *testing.T) { + key := make([]byte, KeySize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatalf("rand.Read: %v", err) + } + + _, err := Decrypt(key, []byte{1, 2, 3}) + if err == nil { + t.Fatal("expected error for ciphertext too short") + } +} + +func TestEncryptDecrypt_EmptyPlaintext(t *testing.T) { + key := make([]byte, KeySize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatalf("rand.Read: %v", err) + } + + ciphertext, err := Encrypt(key, []byte{}) + if err != nil { + t.Fatalf("Encrypt empty: %v", err) + } + + decrypted, err := Decrypt(key, ciphertext) + if err != nil { + t.Fatalf("Decrypt empty: %v", err) + } + + if len(decrypted) != 0 { + t.Errorf("expected empty plaintext, got %d bytes", len(decrypted)) + } +} + +func TestEncryptDecrypt_LargePayload(t *testing.T) { + key := make([]byte, KeySize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatalf("rand.Read: %v", err) + } + + plaintext := make([]byte, 1<<16) // 64KB + if _, err := io.ReadFull(rand.Reader, plaintext); err != nil { + t.Fatalf("rand.Read: %v", err) + } + + ciphertext, err := Encrypt(key, plaintext) + if err != nil { + t.Fatalf("Encrypt large: %v", err) + } + + decrypted, err := Decrypt(key, ciphertext) + if err != nil { + t.Fatalf("Decrypt large: %v", err) + } + + if !bytes.Equal(decrypted, plaintext) { + t.Error("large payload round-trip failed") + } +} + +func TestEncrypt_UniqueNonces(t *testing.T) { + key := make([]byte, KeySize) + if _, err := io.ReadFull(rand.Reader, key); err != nil { + t.Fatalf("rand.Read: %v", err) + } + + ct1, err := Encrypt(key, []byte("same data")) + if err != nil { + t.Fatalf("Encrypt 1: %v", err) + } + ct2, err := Encrypt(key, []byte("same data")) + if err != nil { + t.Fatalf("Encrypt 2: %v", err) + } + + if bytes.Equal(ct1, ct2) { + t.Error("two encryptions of same data should produce different ciphertexts (different nonces)") + } +} diff --git a/security/keyring/keyring.go b/security/keyring/keyring.go index 44456ee..8744087 100644 --- a/security/keyring/keyring.go +++ b/security/keyring/keyring.go @@ -1,164 +1,164 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package keyring - -import ( - "crypto/rand" - "encoding/hex" - "fmt" - "sync" - "time" -) - -type KeyEntry struct { - ID string - Key []byte - CreatedAt time.Time - ExpiresAt time.Time - Revoked bool -} - -type Keyring struct { - mu sync.RWMutex - keys map[string]*KeyEntry -} - -func New() *Keyring { - return &Keyring{keys: make(map[string]*KeyEntry)} -} - -func (kr *Keyring) Generate(id string, keyLen int, ttl time.Duration) (*KeyEntry, error) { - if id == "" { - return nil, fmt.Errorf("key id is required") - } - if keyLen <= 0 { - keyLen = 32 - } - raw := make([]byte, keyLen) - if _, err := rand.Read(raw); err != nil { - return nil, fmt.Errorf("generate random key: %w", err) - } - now := time.Now().UTC() - entry := &KeyEntry{ID: id, Key: raw, CreatedAt: now} - if ttl > 0 { - entry.ExpiresAt = now.Add(ttl) - } - kr.mu.Lock() - defer kr.mu.Unlock() - kr.keys[id] = entry - return entry, nil -} - -func (kr *Keyring) Store(id string, key []byte, ttl time.Duration) error { - if id == "" || len(key) == 0 { - return fmt.Errorf("id and key are required") - } - now := time.Now().UTC() - entry := &KeyEntry{ID: id, Key: make([]byte, len(key)), CreatedAt: now} - copy(entry.Key, key) - if ttl > 0 { - entry.ExpiresAt = now.Add(ttl) - } - kr.mu.Lock() - defer kr.mu.Unlock() - kr.keys[id] = entry - return nil -} - -func (kr *Keyring) Lookup(id string) (*KeyEntry, error) { - kr.mu.RLock() - defer kr.mu.RUnlock() - entry, ok := kr.keys[id] - if !ok { - return nil, fmt.Errorf("key %q not found", id) - } - if entry.Revoked { - return nil, fmt.Errorf("key %q has been revoked", id) - } - if !entry.ExpiresAt.IsZero() && time.Now().After(entry.ExpiresAt) { - return nil, fmt.Errorf("key %q has expired", id) - } - return entry, nil -} - -func (kr *Keyring) Revoke(id string) error { - kr.mu.Lock() - defer kr.mu.Unlock() - entry, ok := kr.keys[id] - if !ok { - return fmt.Errorf("key %q not found", id) - } - entry.Revoked = true - return nil -} - -func (kr *Keyring) Delete(id string) { - kr.mu.Lock() - defer kr.mu.Unlock() - delete(kr.keys, id) -} - -func (kr *Keyring) Rotate(id string, keyLen int, ttl time.Duration) (*KeyEntry, error) { - oldID := id + ".prev." + hex.EncodeToString(randomBytes(4)) - kr.mu.Lock() - if old, ok := kr.keys[id]; ok { - old.Revoked = true - kr.keys[oldID] = old - delete(kr.keys, id) - } - - if keyLen <= 0 { - keyLen = 32 - } - raw := make([]byte, keyLen) - if _, err := rand.Read(raw); err != nil { - kr.mu.Unlock() - return nil, fmt.Errorf("generate random key: %w", err) - } - now := time.Now().UTC() - entry := &KeyEntry{ID: id, Key: raw, CreatedAt: now} - if ttl > 0 { - entry.ExpiresAt = now.Add(ttl) - } - kr.keys[id] = entry - kr.mu.Unlock() - return entry, nil -} - -func (kr *Keyring) ListActive() []KeyEntry { - kr.mu.RLock() - defer kr.mu.RUnlock() - now := time.Now() - var result []KeyEntry - for _, e := range kr.keys { - if e.Revoked { - continue - } - if !e.ExpiresAt.IsZero() && now.After(e.ExpiresAt) { - continue - } - result = append(result, *e) - } - return result -} - -func (kr *Keyring) Cleanup() int { - kr.mu.Lock() - defer kr.mu.Unlock() - now := time.Now() - removed := 0 - for id, e := range kr.keys { - if e.Revoked || (!e.ExpiresAt.IsZero() && now.After(e.ExpiresAt)) { - delete(kr.keys, id) - removed++ - } - } - return removed -} - -func randomBytes(n int) []byte { - b := make([]byte, n) - _ = rand.Read(b) - return b -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package keyring + +import ( + "crypto/rand" + "encoding/hex" + "fmt" + "sync" + "time" +) + +type KeyEntry struct { + ID string + Key []byte + CreatedAt time.Time + ExpiresAt time.Time + Revoked bool +} + +type Keyring struct { + mu sync.RWMutex + keys map[string]*KeyEntry +} + +func New() *Keyring { + return &Keyring{keys: make(map[string]*KeyEntry)} +} + +func (kr *Keyring) Generate(id string, keyLen int, ttl time.Duration) (*KeyEntry, error) { + if id == "" { + return nil, fmt.Errorf("key id is required") + } + if keyLen <= 0 { + keyLen = 32 + } + raw := make([]byte, keyLen) + if _, err := rand.Read(raw); err != nil { + return nil, fmt.Errorf("generate random key: %w", err) + } + now := time.Now().UTC() + entry := &KeyEntry{ID: id, Key: raw, CreatedAt: now} + if ttl > 0 { + entry.ExpiresAt = now.Add(ttl) + } + kr.mu.Lock() + defer kr.mu.Unlock() + kr.keys[id] = entry + return entry, nil +} + +func (kr *Keyring) Store(id string, key []byte, ttl time.Duration) error { + if id == "" || len(key) == 0 { + return fmt.Errorf("id and key are required") + } + now := time.Now().UTC() + entry := &KeyEntry{ID: id, Key: make([]byte, len(key)), CreatedAt: now} + copy(entry.Key, key) + if ttl > 0 { + entry.ExpiresAt = now.Add(ttl) + } + kr.mu.Lock() + defer kr.mu.Unlock() + kr.keys[id] = entry + return nil +} + +func (kr *Keyring) Lookup(id string) (*KeyEntry, error) { + kr.mu.RLock() + defer kr.mu.RUnlock() + entry, ok := kr.keys[id] + if !ok { + return nil, fmt.Errorf("key %q not found", id) + } + if entry.Revoked { + return nil, fmt.Errorf("key %q has been revoked", id) + } + if !entry.ExpiresAt.IsZero() && time.Now().After(entry.ExpiresAt) { + return nil, fmt.Errorf("key %q has expired", id) + } + return entry, nil +} + +func (kr *Keyring) Revoke(id string) error { + kr.mu.Lock() + defer kr.mu.Unlock() + entry, ok := kr.keys[id] + if !ok { + return fmt.Errorf("key %q not found", id) + } + entry.Revoked = true + return nil +} + +func (kr *Keyring) Delete(id string) { + kr.mu.Lock() + defer kr.mu.Unlock() + delete(kr.keys, id) +} + +func (kr *Keyring) Rotate(id string, keyLen int, ttl time.Duration) (*KeyEntry, error) { + oldID := id + ".prev." + hex.EncodeToString(randomBytes(4)) + kr.mu.Lock() + if old, ok := kr.keys[id]; ok { + old.Revoked = true + kr.keys[oldID] = old + delete(kr.keys, id) + } + + if keyLen <= 0 { + keyLen = 32 + } + raw := make([]byte, keyLen) + if _, err := rand.Read(raw); err != nil { + kr.mu.Unlock() + return nil, fmt.Errorf("generate random key: %w", err) + } + now := time.Now().UTC() + entry := &KeyEntry{ID: id, Key: raw, CreatedAt: now} + if ttl > 0 { + entry.ExpiresAt = now.Add(ttl) + } + kr.keys[id] = entry + kr.mu.Unlock() + return entry, nil +} + +func (kr *Keyring) ListActive() []KeyEntry { + kr.mu.RLock() + defer kr.mu.RUnlock() + now := time.Now() + var result []KeyEntry + for _, e := range kr.keys { + if e.Revoked { + continue + } + if !e.ExpiresAt.IsZero() && now.After(e.ExpiresAt) { + continue + } + result = append(result, *e) + } + return result +} + +func (kr *Keyring) Cleanup() int { + kr.mu.Lock() + defer kr.mu.Unlock() + now := time.Now() + removed := 0 + for id, e := range kr.keys { + if e.Revoked || (!e.ExpiresAt.IsZero() && now.After(e.ExpiresAt)) { + delete(kr.keys, id) + removed++ + } + } + return removed +} + +func randomBytes(n int) []byte { + b := make([]byte, n) + _ = rand.Read(b) + return b +} diff --git a/security/keyring/keyring_test.go b/security/keyring/keyring_test.go index c941360..182b1d7 100644 --- a/security/keyring/keyring_test.go +++ b/security/keyring/keyring_test.go @@ -1,146 +1,146 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package keyring - -import ( - "testing" - "time" -) - -func TestGenerateAndLookup(t *testing.T) { - kr := New() - entry, err := kr.Generate("hmac-key-1", 32, 0) - if err != nil { - t.Fatalf("Generate failed: %v", err) - } - if len(entry.Key) != 32 { - t.Errorf("expected 32-byte key, got %d", len(entry.Key)) - } - found, err := kr.Lookup("hmac-key-1") - if err != nil { - t.Fatalf("Lookup failed: %v", err) - } - if found.ID != entry.ID { - t.Error("lookup returned wrong entry") - } -} - -func TestLookupNotFound(t *testing.T) { - kr := New() - _, err := kr.Lookup("nonexistent") - if err == nil { - t.Error("expected error for nonexistent key") - } -} - -func TestRevoke(t *testing.T) { - kr := New() - if _, err := kr.Generate("key1", 16, 0); err != nil { - t.Fatalf("Generate failed: %v", err) - } - if err := kr.Revoke("key1"); err != nil { - t.Fatalf("Revoke failed: %v", err) - } - _, err := kr.Lookup("key1") - if err == nil { - t.Error("expected error for revoked key") - } -} - -func TestExpiry(t *testing.T) { - kr := New() - if _, err := kr.Generate("short", 16, 1*time.Millisecond); err != nil { - t.Fatalf("Generate failed: %v", err) - } - time.Sleep(5 * time.Millisecond) - _, err := kr.Lookup("short") - if err == nil { - t.Error("expected error for expired key") - } -} - -func TestStore(t *testing.T) { - kr := New() - key := []byte("my-secret-key-32-bytes-00000000") - if err := kr.Store("ext", key, 0); err != nil { - t.Fatalf("Store failed: %v", err) - } - found, err := kr.Lookup("ext") - if err != nil { - t.Fatalf("Lookup failed: %v", err) - } - if string(found.Key) != string(key) { - t.Error("stored key mismatch") - } -} - -func TestRotate(t *testing.T) { - kr := New() - if _, err := kr.Generate("rot", 32, 0); err != nil { - t.Fatalf("Generate failed: %v", err) - } - newEntry, err := kr.Rotate("rot", 32, 0) - if err != nil { - t.Fatalf("Rotate failed: %v", err) - } - found, err := kr.Lookup("rot") - if err != nil { - t.Fatalf("Lookup failed: %v", err) - } - if string(found.Key) != string(newEntry.Key) { - t.Error("rotated key mismatch") - } -} - -func TestListActive(t *testing.T) { - kr := New() - if _, err := kr.Generate("a1", 16, 0); err != nil { - t.Fatalf("Generate failed: %v", err) - } - if _, err := kr.Generate("a2", 16, 0); err != nil { - t.Fatalf("Generate failed: %v", err) - } - if _, err := kr.Generate("r1", 16, 0); err != nil { - t.Fatalf("Generate failed: %v", err) - } - if err := kr.Revoke("r1"); err != nil { - t.Fatalf("Revoke failed: %v", err) - } - if len(kr.ListActive()) != 2 { - t.Errorf("expected 2 active, got %d", len(kr.ListActive())) - } -} - -func TestCleanup(t *testing.T) { - kr := New() - if _, err := kr.Generate("keep", 16, 0); err != nil { - t.Fatalf("Generate failed: %v", err) - } - if _, err := kr.Generate("exp", 16, 1*time.Millisecond); err != nil { - t.Fatalf("Generate failed: %v", err) - } - if _, err := kr.Generate("rev", 16, 0); err != nil { - t.Fatalf("Generate failed: %v", err) - } - if err := kr.Revoke("rev"); err != nil { - t.Fatalf("Revoke failed: %v", err) - } - time.Sleep(5 * time.Millisecond) - removed := kr.Cleanup() - if removed != 2 { - t.Errorf("expected 2 removed, got %d", removed) - } -} - -func TestDelete(t *testing.T) { - kr := New() - if _, err := kr.Generate("del", 16, 0); err != nil { - t.Fatalf("Generate failed: %v", err) - } - kr.Delete("del") - _, err := kr.Lookup("del") - if err == nil { - t.Error("expected error for deleted key") - } -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package keyring + +import ( + "testing" + "time" +) + +func TestGenerateAndLookup(t *testing.T) { + kr := New() + entry, err := kr.Generate("hmac-key-1", 32, 0) + if err != nil { + t.Fatalf("Generate failed: %v", err) + } + if len(entry.Key) != 32 { + t.Errorf("expected 32-byte key, got %d", len(entry.Key)) + } + found, err := kr.Lookup("hmac-key-1") + if err != nil { + t.Fatalf("Lookup failed: %v", err) + } + if found.ID != entry.ID { + t.Error("lookup returned wrong entry") + } +} + +func TestLookupNotFound(t *testing.T) { + kr := New() + _, err := kr.Lookup("nonexistent") + if err == nil { + t.Error("expected error for nonexistent key") + } +} + +func TestRevoke(t *testing.T) { + kr := New() + if _, err := kr.Generate("key1", 16, 0); err != nil { + t.Fatalf("Generate failed: %v", err) + } + if err := kr.Revoke("key1"); err != nil { + t.Fatalf("Revoke failed: %v", err) + } + _, err := kr.Lookup("key1") + if err == nil { + t.Error("expected error for revoked key") + } +} + +func TestExpiry(t *testing.T) { + kr := New() + if _, err := kr.Generate("short", 16, 1*time.Millisecond); err != nil { + t.Fatalf("Generate failed: %v", err) + } + time.Sleep(5 * time.Millisecond) + _, err := kr.Lookup("short") + if err == nil { + t.Error("expected error for expired key") + } +} + +func TestStore(t *testing.T) { + kr := New() + key := []byte("my-secret-key-32-bytes-00000000") + if err := kr.Store("ext", key, 0); err != nil { + t.Fatalf("Store failed: %v", err) + } + found, err := kr.Lookup("ext") + if err != nil { + t.Fatalf("Lookup failed: %v", err) + } + if string(found.Key) != string(key) { + t.Error("stored key mismatch") + } +} + +func TestRotate(t *testing.T) { + kr := New() + if _, err := kr.Generate("rot", 32, 0); err != nil { + t.Fatalf("Generate failed: %v", err) + } + newEntry, err := kr.Rotate("rot", 32, 0) + if err != nil { + t.Fatalf("Rotate failed: %v", err) + } + found, err := kr.Lookup("rot") + if err != nil { + t.Fatalf("Lookup failed: %v", err) + } + if string(found.Key) != string(newEntry.Key) { + t.Error("rotated key mismatch") + } +} + +func TestListActive(t *testing.T) { + kr := New() + if _, err := kr.Generate("a1", 16, 0); err != nil { + t.Fatalf("Generate failed: %v", err) + } + if _, err := kr.Generate("a2", 16, 0); err != nil { + t.Fatalf("Generate failed: %v", err) + } + if _, err := kr.Generate("r1", 16, 0); err != nil { + t.Fatalf("Generate failed: %v", err) + } + if err := kr.Revoke("r1"); err != nil { + t.Fatalf("Revoke failed: %v", err) + } + if len(kr.ListActive()) != 2 { + t.Errorf("expected 2 active, got %d", len(kr.ListActive())) + } +} + +func TestCleanup(t *testing.T) { + kr := New() + if _, err := kr.Generate("keep", 16, 0); err != nil { + t.Fatalf("Generate failed: %v", err) + } + if _, err := kr.Generate("exp", 16, 1*time.Millisecond); err != nil { + t.Fatalf("Generate failed: %v", err) + } + if _, err := kr.Generate("rev", 16, 0); err != nil { + t.Fatalf("Generate failed: %v", err) + } + if err := kr.Revoke("rev"); err != nil { + t.Fatalf("Revoke failed: %v", err) + } + time.Sleep(5 * time.Millisecond) + removed := kr.Cleanup() + if removed != 2 { + t.Errorf("expected 2 removed, got %d", removed) + } +} + +func TestDelete(t *testing.T) { + kr := New() + if _, err := kr.Generate("del", 16, 0); err != nil { + t.Fatalf("Generate failed: %v", err) + } + kr.Delete("del") + _, err := kr.Lookup("del") + if err == nil { + t.Error("expected error for deleted key") + } +} diff --git a/services/audit/audit.go b/services/audit/audit.go index e469884..9535451 100644 --- a/services/audit/audit.go +++ b/services/audit/audit.go @@ -1,84 +1,84 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package audit - -import ( - "encoding/json" - "fmt" - "io" - "os" - "path/filepath" - "sync" - "time" -) - -// Logger defines the audit logging interface. -type Logger interface { - Log(entry Entry) error - Close() error -} - -// Entry is a single audit log record. -type Entry struct { - Timestamp string `json:"timestamp"` - RequestID string `json:"request_id"` - Source string `json:"source"` - Target string `json:"target"` - Action string `json:"action"` - Decision string `json:"decision"` - Result string `json:"result"` -} - -// FileLogger writes audit entries as JSON lines to a file. -type FileLogger struct { - mu sync.Mutex - writer io.WriteCloser -} - -// NewFileLogger creates an audit logger that writes to the given file path. -// If path is empty, it writes to stdout. -func NewFileLogger(path string) (*FileLogger, error) { - var w io.WriteCloser - if path == "" { - w = os.Stdout - } else { - path = filepath.Clean(path) - f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600) - if err != nil { - return nil, fmt.Errorf("open audit log: %w", err) - } - w = f - } - return &FileLogger{writer: w}, nil -} - -// Log writes an audit entry as a JSON line. -func (l *FileLogger) Log(entry Entry) error { - if entry.Timestamp == "" { - entry.Timestamp = time.Now().UTC().Format(time.RFC3339Nano) - } - - data, err := json.Marshal(entry) - if err != nil { - return fmt.Errorf("marshal audit entry: %w", err) - } - - l.mu.Lock() - defer l.mu.Unlock() - - if _, err := l.writer.Write(append(data, '\n')); err != nil { - return fmt.Errorf("write audit entry: %w", err) - } - return nil -} - -// Close closes the underlying writer. -func (l *FileLogger) Close() error { - l.mu.Lock() - defer l.mu.Unlock() - if l.writer != os.Stdout { - return l.writer.Close() - } - return nil -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package audit + +import ( + "encoding/json" + "fmt" + "io" + "os" + "path/filepath" + "sync" + "time" +) + +// Logger defines the audit logging interface. +type Logger interface { + Log(entry Entry) error + Close() error +} + +// Entry is a single audit log record. +type Entry struct { + Timestamp string `json:"timestamp"` + RequestID string `json:"request_id"` + Source string `json:"source"` + Target string `json:"target"` + Action string `json:"action"` + Decision string `json:"decision"` + Result string `json:"result"` +} + +// FileLogger writes audit entries as JSON lines to a file. +type FileLogger struct { + mu sync.Mutex + writer io.WriteCloser +} + +// NewFileLogger creates an audit logger that writes to the given file path. +// If path is empty, it writes to stdout. +func NewFileLogger(path string) (*FileLogger, error) { + var w io.WriteCloser + if path == "" { + w = os.Stdout + } else { + path = filepath.Clean(path) + f, err := os.OpenFile(path, os.O_CREATE|os.O_APPEND|os.O_WRONLY, 0600) + if err != nil { + return nil, fmt.Errorf("open audit log: %w", err) + } + w = f + } + return &FileLogger{writer: w}, nil +} + +// Log writes an audit entry as a JSON line. +func (l *FileLogger) Log(entry Entry) error { + if entry.Timestamp == "" { + entry.Timestamp = time.Now().UTC().Format(time.RFC3339Nano) + } + + data, err := json.Marshal(entry) + if err != nil { + return fmt.Errorf("marshal audit entry: %w", err) + } + + l.mu.Lock() + defer l.mu.Unlock() + + if _, err := l.writer.Write(append(data, '\n')); err != nil { + return fmt.Errorf("write audit entry: %w", err) + } + return nil +} + +// Close closes the underlying writer. +func (l *FileLogger) Close() error { + l.mu.Lock() + defer l.mu.Unlock() + if l.writer != os.Stdout { + return l.writer.Close() + } + return nil +} diff --git a/services/audit/audit_test.go b/services/audit/audit_test.go index c814a45..756b80e 100644 --- a/services/audit/audit_test.go +++ b/services/audit/audit_test.go @@ -1,142 +1,142 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package audit - -import ( - "encoding/json" - "os" - "path/filepath" - "strings" - "testing" -) - -func TestFileLogger_Stdout(t *testing.T) { - logger, err := NewFileLogger("") - if err != nil { - t.Fatalf("NewFileLogger stdout: %v", err) - } - defer logger.Close() - - err = logger.Log(Entry{ - RequestID: "req-1", - Source: "test", - Target: "server", - Action: "test.action", - Decision: "allowed", - Result: "ok", - }) - if err != nil { - t.Errorf("Log to stdout: %v", err) - } -} - -func TestFileLogger_File(t *testing.T) { - tmpFile := filepath.Join(t.TempDir(), "audit.jsonl") - logger, err := NewFileLogger(tmpFile) - if err != nil { - t.Fatalf("NewFileLogger file: %v", err) - } - - err = logger.Log(Entry{ - RequestID: "req-1", - Source: "test.client", - Target: "eipc-server", - Action: "authenticate", - Decision: "allowed", - Result: "session created", - }) - if err != nil { - t.Fatalf("Log: %v", err) - } - - err = logger.Log(Entry{ - RequestID: "req-2", - Source: "test.client", - Target: "eipc-server", - Action: "ui.cursor.move", - Decision: "denied", - Result: "capability violation", - }) - if err != nil { - t.Fatalf("Log second entry: %v", err) - } - - logger.Close() - - data, err := os.ReadFile(tmpFile) - if err != nil { - t.Fatalf("read audit file: %v", err) - } - - lines := strings.Split(strings.TrimSpace(string(data)), "\n") - if len(lines) != 2 { - t.Fatalf("expected 2 lines, got %d", len(lines)) - } - - var entry Entry - if err := json.Unmarshal([]byte(lines[0]), &entry); err != nil { - t.Fatalf("unmarshal first line: %v", err) - } - if entry.RequestID != "req-1" { - t.Errorf("expected request_id 'req-1', got %q", entry.RequestID) - } - if entry.Timestamp == "" { - t.Error("expected auto-filled timestamp") - } -} - -func TestFileLogger_TimestampAutoFill(t *testing.T) { - tmpFile := filepath.Join(t.TempDir(), "audit-ts.jsonl") - logger, err := NewFileLogger(tmpFile) - if err != nil { - t.Fatalf("NewFileLogger: %v", err) - } - - _ = logger.Log(Entry{Source: "test", Action: "check"}) - logger.Close() - - data, err := os.ReadFile(tmpFile) - if err != nil { - t.Fatalf("read audit file: %v", err) - } - var entry Entry - if err := json.Unmarshal([]byte(strings.TrimSpace(string(data))), &entry); err != nil { - t.Fatalf("unmarshal: %v", err) - } - if entry.Timestamp == "" { - t.Error("timestamp should be auto-filled when empty") - } -} - -func TestFileLogger_PresetTimestamp(t *testing.T) { - tmpFile := filepath.Join(t.TempDir(), "audit-preset.jsonl") - logger, err := NewFileLogger(tmpFile) - if err != nil { - t.Fatalf("NewFileLogger: %v", err) - } - - customTS := "2026-01-01T00:00:00Z" - _ = logger.Log(Entry{Timestamp: customTS, Source: "test", Action: "check"}) - logger.Close() - - data, err := os.ReadFile(tmpFile) - if err != nil { - t.Fatalf("read audit file: %v", err) - } - var entry Entry - if err := json.Unmarshal([]byte(strings.TrimSpace(string(data))), &entry); err != nil { - t.Fatalf("unmarshal: %v", err) - } - if entry.Timestamp != customTS { - t.Errorf("expected preset timestamp %q, got %q", customTS, entry.Timestamp) - } -} - -func TestFileLogger_Close_Stdout(t *testing.T) { - logger, _ := NewFileLogger("") - err := logger.Close() - if err != nil { - t.Errorf("closing stdout logger should not error: %v", err) - } -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package audit + +import ( + "encoding/json" + "os" + "path/filepath" + "strings" + "testing" +) + +func TestFileLogger_Stdout(t *testing.T) { + logger, err := NewFileLogger("") + if err != nil { + t.Fatalf("NewFileLogger stdout: %v", err) + } + defer logger.Close() + + err = logger.Log(Entry{ + RequestID: "req-1", + Source: "test", + Target: "server", + Action: "test.action", + Decision: "allowed", + Result: "ok", + }) + if err != nil { + t.Errorf("Log to stdout: %v", err) + } +} + +func TestFileLogger_File(t *testing.T) { + tmpFile := filepath.Join(t.TempDir(), "audit.jsonl") + logger, err := NewFileLogger(tmpFile) + if err != nil { + t.Fatalf("NewFileLogger file: %v", err) + } + + err = logger.Log(Entry{ + RequestID: "req-1", + Source: "test.client", + Target: "eipc-server", + Action: "authenticate", + Decision: "allowed", + Result: "session created", + }) + if err != nil { + t.Fatalf("Log: %v", err) + } + + err = logger.Log(Entry{ + RequestID: "req-2", + Source: "test.client", + Target: "eipc-server", + Action: "ui.cursor.move", + Decision: "denied", + Result: "capability violation", + }) + if err != nil { + t.Fatalf("Log second entry: %v", err) + } + + logger.Close() + + data, err := os.ReadFile(tmpFile) + if err != nil { + t.Fatalf("read audit file: %v", err) + } + + lines := strings.Split(strings.TrimSpace(string(data)), "\n") + if len(lines) != 2 { + t.Fatalf("expected 2 lines, got %d", len(lines)) + } + + var entry Entry + if err := json.Unmarshal([]byte(lines[0]), &entry); err != nil { + t.Fatalf("unmarshal first line: %v", err) + } + if entry.RequestID != "req-1" { + t.Errorf("expected request_id 'req-1', got %q", entry.RequestID) + } + if entry.Timestamp == "" { + t.Error("expected auto-filled timestamp") + } +} + +func TestFileLogger_TimestampAutoFill(t *testing.T) { + tmpFile := filepath.Join(t.TempDir(), "audit-ts.jsonl") + logger, err := NewFileLogger(tmpFile) + if err != nil { + t.Fatalf("NewFileLogger: %v", err) + } + + _ = logger.Log(Entry{Source: "test", Action: "check"}) + logger.Close() + + data, err := os.ReadFile(tmpFile) + if err != nil { + t.Fatalf("read audit file: %v", err) + } + var entry Entry + if err := json.Unmarshal([]byte(strings.TrimSpace(string(data))), &entry); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if entry.Timestamp == "" { + t.Error("timestamp should be auto-filled when empty") + } +} + +func TestFileLogger_PresetTimestamp(t *testing.T) { + tmpFile := filepath.Join(t.TempDir(), "audit-preset.jsonl") + logger, err := NewFileLogger(tmpFile) + if err != nil { + t.Fatalf("NewFileLogger: %v", err) + } + + customTS := "2026-01-01T00:00:00Z" + _ = logger.Log(Entry{Timestamp: customTS, Source: "test", Action: "check"}) + logger.Close() + + data, err := os.ReadFile(tmpFile) + if err != nil { + t.Fatalf("read audit file: %v", err) + } + var entry Entry + if err := json.Unmarshal([]byte(strings.TrimSpace(string(data))), &entry); err != nil { + t.Fatalf("unmarshal: %v", err) + } + if entry.Timestamp != customTS { + t.Errorf("expected preset timestamp %q, got %q", customTS, entry.Timestamp) + } +} + +func TestFileLogger_Close_Stdout(t *testing.T) { + logger, _ := NewFileLogger("") + err := logger.Close() + if err != nil { + t.Errorf("closing stdout logger should not error: %v", err) + } +} diff --git a/services/broker/broker.go b/services/broker/broker.go index 6150cf9..7127aaf 100644 --- a/services/broker/broker.go +++ b/services/broker/broker.go @@ -1,161 +1,161 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package broker - -import ( - "fmt" - "sync" - - "github.com/embeddedos-org/eipc/core" - "github.com/embeddedos-org/eipc/services/audit" - "github.com/embeddedos-org/eipc/services/registry" -) - -type Subscriber struct { - ServiceID string - Capabilities []string - Endpoint core.Endpoint - Priority core.Priority -} - -type Broker struct { - mu sync.RWMutex - subscribers map[string]*Subscriber - routes map[core.MessageType][]string - registry *registry.Registry - audit audit.Logger - router *core.Router - running bool -} - -func NewBroker(reg *registry.Registry, auditLogger audit.Logger) *Broker { - return &Broker{ - subscribers: make(map[string]*Subscriber), - routes: make(map[core.MessageType][]string), - registry: reg, - audit: auditLogger, - router: core.NewRouter(), - } -} - -func (b *Broker) Subscribe(sub *Subscriber) error { - if sub == nil || sub.ServiceID == "" || sub.Endpoint == nil { - return fmt.Errorf("invalid subscriber: service_id and endpoint required") - } - b.mu.Lock() - defer b.mu.Unlock() - b.subscribers[sub.ServiceID] = sub - return nil -} - -func (b *Broker) Unsubscribe(serviceID string) { - b.mu.Lock() - defer b.mu.Unlock() - if sub, ok := b.subscribers[serviceID]; ok { - sub.Endpoint.Close() - delete(b.subscribers, serviceID) - } -} - -func (b *Broker) AddRoute(msgType core.MessageType, targets ...string) { - b.mu.Lock() - defer b.mu.Unlock() - existing := b.routes[msgType] - seen := make(map[string]bool) - for _, t := range existing { - seen[t] = true - } - for _, t := range targets { - if !seen[t] { - existing = append(existing, t) - seen[t] = true - } - } - b.routes[msgType] = existing -} - -func (b *Broker) RemoveRoute(msgType core.MessageType, target string) { - b.mu.Lock() - defer b.mu.Unlock() - targets := b.routes[msgType] - for i, t := range targets { - if t == target { - b.routes[msgType] = append(targets[:i], targets[i+1:]...) - return - } - } -} - -func (b *Broker) Route(msg core.Message) []RouteResult { - b.mu.RLock() - targets := b.routes[msg.Type] - subs := make([]*Subscriber, 0, len(targets)) - for _, t := range targets { - if sub, ok := b.subscribers[t]; ok { - subs = append(subs, sub) - } - } - b.mu.RUnlock() - if len(subs) == 0 { - return nil - } - sortByPriority(subs) - results := make([]RouteResult, len(subs)) - for i, sub := range subs { - err := sub.Endpoint.Send(msg) - results[i] = RouteResult{ServiceID: sub.ServiceID, Err: err} - if b.audit != nil { - decision := "delivered" - result := "ok" - if err != nil { - decision = "failed" - result = err.Error() - } - _ = b.audit.Log(audit.Entry{ - RequestID: msg.RequestID, Source: msg.Source, - Target: sub.ServiceID, Action: string(msg.Type), - Decision: decision, Result: result, - }) - } - } - return results -} - -func (b *Broker) Fanout(msg core.Message) []RouteResult { - b.mu.RLock() - subs := make([]*Subscriber, 0, len(b.subscribers)) - for _, sub := range b.subscribers { - subs = append(subs, sub) - } - b.mu.RUnlock() - results := make([]RouteResult, len(subs)) - for i, sub := range subs { - err := sub.Endpoint.Send(msg) - results[i] = RouteResult{ServiceID: sub.ServiceID, Err: err} - } - return results -} - -func (b *Broker) Subscribers() []string { - b.mu.RLock() - defer b.mu.RUnlock() - ids := make([]string, 0, len(b.subscribers)) - for id := range b.subscribers { - ids = append(ids, id) - } - return ids -} - -type RouteResult struct { - ServiceID string - Err error -} - -func sortByPriority(subs []*Subscriber) { - for i := 1; i < len(subs); i++ { - for j := i; j > 0 && subs[j].Priority < subs[j-1].Priority; j-- { - subs[j], subs[j-1] = subs[j-1], subs[j] - } - } -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package broker + +import ( + "fmt" + "sync" + + "github.com/embeddedos-org/eipc/core" + "github.com/embeddedos-org/eipc/services/audit" + "github.com/embeddedos-org/eipc/services/registry" +) + +type Subscriber struct { + ServiceID string + Capabilities []string + Endpoint core.Endpoint + Priority core.Priority +} + +type Broker struct { + mu sync.RWMutex + subscribers map[string]*Subscriber + routes map[core.MessageType][]string + registry *registry.Registry + audit audit.Logger + router *core.Router + running bool +} + +func NewBroker(reg *registry.Registry, auditLogger audit.Logger) *Broker { + return &Broker{ + subscribers: make(map[string]*Subscriber), + routes: make(map[core.MessageType][]string), + registry: reg, + audit: auditLogger, + router: core.NewRouter(), + } +} + +func (b *Broker) Subscribe(sub *Subscriber) error { + if sub == nil || sub.ServiceID == "" || sub.Endpoint == nil { + return fmt.Errorf("invalid subscriber: service_id and endpoint required") + } + b.mu.Lock() + defer b.mu.Unlock() + b.subscribers[sub.ServiceID] = sub + return nil +} + +func (b *Broker) Unsubscribe(serviceID string) { + b.mu.Lock() + defer b.mu.Unlock() + if sub, ok := b.subscribers[serviceID]; ok { + sub.Endpoint.Close() + delete(b.subscribers, serviceID) + } +} + +func (b *Broker) AddRoute(msgType core.MessageType, targets ...string) { + b.mu.Lock() + defer b.mu.Unlock() + existing := b.routes[msgType] + seen := make(map[string]bool) + for _, t := range existing { + seen[t] = true + } + for _, t := range targets { + if !seen[t] { + existing = append(existing, t) + seen[t] = true + } + } + b.routes[msgType] = existing +} + +func (b *Broker) RemoveRoute(msgType core.MessageType, target string) { + b.mu.Lock() + defer b.mu.Unlock() + targets := b.routes[msgType] + for i, t := range targets { + if t == target { + b.routes[msgType] = append(targets[:i], targets[i+1:]...) + return + } + } +} + +func (b *Broker) Route(msg core.Message) []RouteResult { + b.mu.RLock() + targets := b.routes[msg.Type] + subs := make([]*Subscriber, 0, len(targets)) + for _, t := range targets { + if sub, ok := b.subscribers[t]; ok { + subs = append(subs, sub) + } + } + b.mu.RUnlock() + if len(subs) == 0 { + return nil + } + sortByPriority(subs) + results := make([]RouteResult, len(subs)) + for i, sub := range subs { + err := sub.Endpoint.Send(msg) + results[i] = RouteResult{ServiceID: sub.ServiceID, Err: err} + if b.audit != nil { + decision := "delivered" + result := "ok" + if err != nil { + decision = "failed" + result = err.Error() + } + _ = b.audit.Log(audit.Entry{ + RequestID: msg.RequestID, Source: msg.Source, + Target: sub.ServiceID, Action: string(msg.Type), + Decision: decision, Result: result, + }) + } + } + return results +} + +func (b *Broker) Fanout(msg core.Message) []RouteResult { + b.mu.RLock() + subs := make([]*Subscriber, 0, len(b.subscribers)) + for _, sub := range b.subscribers { + subs = append(subs, sub) + } + b.mu.RUnlock() + results := make([]RouteResult, len(subs)) + for i, sub := range subs { + err := sub.Endpoint.Send(msg) + results[i] = RouteResult{ServiceID: sub.ServiceID, Err: err} + } + return results +} + +func (b *Broker) Subscribers() []string { + b.mu.RLock() + defer b.mu.RUnlock() + ids := make([]string, 0, len(b.subscribers)) + for id := range b.subscribers { + ids = append(ids, id) + } + return ids +} + +type RouteResult struct { + ServiceID string + Err error +} + +func sortByPriority(subs []*Subscriber) { + for i := 1; i < len(subs); i++ { + for j := i; j > 0 && subs[j].Priority < subs[j-1].Priority; j-- { + subs[j], subs[j-1] = subs[j-1], subs[j] + } + } +} diff --git a/services/broker/broker_test.go b/services/broker/broker_test.go index 9bcac04..f9d97b1 100644 --- a/services/broker/broker_test.go +++ b/services/broker/broker_test.go @@ -1,73 +1,73 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package broker - -import ( - "testing" - - "github.com/embeddedos-org/eipc/core" -) - -type mockEndpoint struct { - sent []core.Message -} - -func (m *mockEndpoint) Send(msg core.Message) error { - m.sent = append(m.sent, msg) - return nil -} - -func (m *mockEndpoint) Receive() (core.Message, error) { - return core.Message{}, nil -} - -func (m *mockEndpoint) Close() error { - return nil -} - -func TestBrokerSubscribeAndRoute(t *testing.T) { - b := NewBroker(nil, nil) - ep1 := &mockEndpoint{} - ep2 := &mockEndpoint{} - _ = b.Subscribe(&Subscriber{ServiceID: "eai.agent", Endpoint: ep1, Priority: core.PriorityP1}) - _ = b.Subscribe(&Subscriber{ServiceID: "eai.monitor", Endpoint: ep2, Priority: core.PriorityP2}) - b.AddRoute(core.TypeIntent, "eai.agent") - msg := core.NewMessage(core.TypeIntent, "eni.min", []byte(`{"intent":"move_left"}`)) - results := b.Route(msg) - if len(results) != 1 { - t.Fatalf("expected 1 result, got %d", len(results)) - } - if results[0].ServiceID != "eai.agent" { - t.Errorf("expected eai.agent, got %s", results[0].ServiceID) - } - if len(ep1.sent) != 1 { - t.Errorf("expected 1 msg to ep1, got %d", len(ep1.sent)) - } -} - -func TestBrokerFanout(t *testing.T) { - b := NewBroker(nil, nil) - ep1 := &mockEndpoint{} - ep2 := &mockEndpoint{} - _ = b.Subscribe(&Subscriber{ServiceID: "svc1", Endpoint: ep1}) - b.Subscribe(&Subscriber{ServiceID: "svc2", Endpoint: ep2}) - msg := core.NewMessage(core.TypeHeartbeat, "system", []byte(`{}`)) - results := b.Fanout(msg) - if len(results) != 2 { - t.Fatalf("expected 2 results, got %d", len(results)) - } -} - -func TestBrokerUnsubscribe(t *testing.T) { - b := NewBroker(nil, nil) - ep := &mockEndpoint{} - b.Subscribe(&Subscriber{ServiceID: "svc1", Endpoint: ep}) - if len(b.Subscribers()) != 1 { - t.Fatal("expected 1 subscriber") - } - b.Unsubscribe("svc1") - if len(b.Subscribers()) != 0 { - t.Fatal("expected 0 subscribers") - } -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package broker + +import ( + "testing" + + "github.com/embeddedos-org/eipc/core" +) + +type mockEndpoint struct { + sent []core.Message +} + +func (m *mockEndpoint) Send(msg core.Message) error { + m.sent = append(m.sent, msg) + return nil +} + +func (m *mockEndpoint) Receive() (core.Message, error) { + return core.Message{}, nil +} + +func (m *mockEndpoint) Close() error { + return nil +} + +func TestBrokerSubscribeAndRoute(t *testing.T) { + b := NewBroker(nil, nil) + ep1 := &mockEndpoint{} + ep2 := &mockEndpoint{} + _ = b.Subscribe(&Subscriber{ServiceID: "eai.agent", Endpoint: ep1, Priority: core.PriorityP1}) + _ = b.Subscribe(&Subscriber{ServiceID: "eai.monitor", Endpoint: ep2, Priority: core.PriorityP2}) + b.AddRoute(core.TypeIntent, "eai.agent") + msg := core.NewMessage(core.TypeIntent, "eni.min", []byte(`{"intent":"move_left"}`)) + results := b.Route(msg) + if len(results) != 1 { + t.Fatalf("expected 1 result, got %d", len(results)) + } + if results[0].ServiceID != "eai.agent" { + t.Errorf("expected eai.agent, got %s", results[0].ServiceID) + } + if len(ep1.sent) != 1 { + t.Errorf("expected 1 msg to ep1, got %d", len(ep1.sent)) + } +} + +func TestBrokerFanout(t *testing.T) { + b := NewBroker(nil, nil) + ep1 := &mockEndpoint{} + ep2 := &mockEndpoint{} + _ = b.Subscribe(&Subscriber{ServiceID: "svc1", Endpoint: ep1}) + b.Subscribe(&Subscriber{ServiceID: "svc2", Endpoint: ep2}) + msg := core.NewMessage(core.TypeHeartbeat, "system", []byte(`{}`)) + results := b.Fanout(msg) + if len(results) != 2 { + t.Fatalf("expected 2 results, got %d", len(results)) + } +} + +func TestBrokerUnsubscribe(t *testing.T) { + b := NewBroker(nil, nil) + ep := &mockEndpoint{} + b.Subscribe(&Subscriber{ServiceID: "svc1", Endpoint: ep}) + if len(b.Subscribers()) != 1 { + t.Fatal("expected 1 subscriber") + } + b.Unsubscribe("svc1") + if len(b.Subscribers()) != 0 { + t.Fatal("expected 0 subscribers") + } +} diff --git a/services/policy/policy.go b/services/policy/policy.go index 419d62b..9cf41cc 100644 --- a/services/policy/policy.go +++ b/services/policy/policy.go @@ -1,231 +1,231 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package policy - -import ( - "fmt" - "sync" - "time" - - "github.com/embeddedos-org/eipc/core" - "github.com/embeddedos-org/eipc/services/audit" -) - -// ActionClass categorizes actions by risk level. -type ActionClass int - -const ( - ActionSafe ActionClass = iota // UI control, read-only - ActionControlled // Device writes - ActionRestricted // System operations -) - -// Verdict is the policy decision for an action. -type Verdict int - -const ( - VerdictAllow Verdict = iota - VerdictDeny - VerdictConfirm // Requires operator approval -) - -// Rule defines a single policy rule. -type Rule struct { - Action string - Class ActionClass - Verdict Verdict - Capability string - Description string -} - -// Request represents an action request to be evaluated. -type Request struct { - Source string - Action string - Capability string - RequestID string -} - -// Result is the outcome of a policy evaluation. -type Result struct { - Allowed bool - Verdict Verdict - Reason string - RequestID string - Timestamp time.Time -} - -// Engine evaluates policy rules for incoming EIPC requests. -// It implements the EIPC security model: identity check → capability check → -// policy check → tool execution → audit log. -type Engine struct { - mu sync.RWMutex - rules map[string]*Rule - defaultDeny bool - audit audit.Logger -} - -// NewEngine creates a policy engine. If defaultDeny is true, -// any action without an explicit rule is denied. -func NewEngine(defaultDeny bool, auditLogger audit.Logger) *Engine { - return &Engine{ - rules: make(map[string]*Rule), - defaultDeny: defaultDeny, - audit: auditLogger, - } -} - -// AddRule registers a policy rule for the given action. -func (e *Engine) AddRule(rule Rule) error { - if rule.Action == "" { - return fmt.Errorf("rule action is required") - } - - e.mu.Lock() - defer e.mu.Unlock() - - e.rules[rule.Action] = &rule - return nil -} - -// RemoveRule removes a policy rule. -func (e *Engine) RemoveRule(action string) { - e.mu.Lock() - defer e.mu.Unlock() - delete(e.rules, action) -} - -// Evaluate checks whether a request is allowed by policy. -func (e *Engine) Evaluate(req Request) Result { - e.mu.RLock() - rule, exists := e.rules[req.Action] - defaultDeny := e.defaultDeny - e.mu.RUnlock() - - result := Result{ - RequestID: req.RequestID, - Timestamp: time.Now().UTC(), - } - - if !exists { - if defaultDeny { - result.Allowed = false - result.Verdict = VerdictDeny - result.Reason = fmt.Sprintf("no rule for action %q, default deny", req.Action) - } else { - result.Allowed = true - result.Verdict = VerdictAllow - result.Reason = "no rule, default allow" - } - } else { - switch rule.Verdict { - case VerdictAllow: - if rule.Capability != "" && rule.Capability != req.Capability { - result.Allowed = false - result.Verdict = VerdictDeny - result.Reason = fmt.Sprintf("capability mismatch: need %q, have %q", - rule.Capability, req.Capability) - } else { - result.Allowed = true - result.Verdict = VerdictAllow - result.Reason = "allowed by rule" - } - case VerdictDeny: - result.Allowed = false - result.Verdict = VerdictDeny - result.Reason = fmt.Sprintf("denied by rule: %s", rule.Description) - case VerdictConfirm: - result.Allowed = false - result.Verdict = VerdictConfirm - result.Reason = "requires operator confirmation" - } - } - - if e.audit != nil { - decision := "allow" - if !result.Allowed { - decision = "deny" - } - _ = e.audit.Log(audit.Entry{ - RequestID: req.RequestID, - Source: req.Source, - Action: req.Action, - Decision: decision, - Result: result.Reason, - }) - } - - return result -} - -// EvaluateMessage is a convenience method that evaluates a core.Message. -func (e *Engine) EvaluateMessage(msg core.Message) Result { - return e.Evaluate(Request{ - Source: msg.Source, - Action: string(msg.Type), - Capability: msg.Capability, - RequestID: msg.RequestID, - }) -} - -// ListRules returns all registered rules. -func (e *Engine) ListRules() []Rule { - e.mu.RLock() - defer e.mu.RUnlock() - - rules := make([]Rule, 0, len(e.rules)) - for _, r := range e.rules { - rules = append(rules, *r) - } - return rules -} - -// SetDefaultDeny changes the default deny policy. -func (e *Engine) SetDefaultDeny(deny bool) { - e.mu.Lock() - defer e.mu.Unlock() - e.defaultDeny = deny -} - -// LoadSafeDefaults adds permissive rules for safe action classes -// and restrictive rules for controlled/restricted classes. -func (e *Engine) LoadSafeDefaults() { - safeActions := []string{ - "ui.cursor.move", "ui.scroll", "ui.select", "ui.focus", - "sensor:read", "device:read", "status:read", - } - for _, a := range safeActions { - _ = e.AddRule(Rule{ - Action: a, - Class: ActionSafe, - Verdict: VerdictAllow, - }) - } - - controlledActions := []string{ - "device:write", "actuator:write", "motor:move", - "iot:publish", "config:write", - } - for _, a := range controlledActions { - _ = e.AddRule(Rule{ - Action: a, - Class: ActionControlled, - Verdict: VerdictAllow, - Capability: "device:write", - }) - } - - restrictedActions := []string{ - "system:reboot", "system:shutdown", "firmware:update", - "security:modify", "policy:modify", - } - for _, a := range restrictedActions { - _ = e.AddRule(Rule{ - Action: a, - Class: ActionRestricted, - Verdict: VerdictConfirm, - Description: "restricted system operation", - }) - } -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package policy + +import ( + "fmt" + "sync" + "time" + + "github.com/embeddedos-org/eipc/core" + "github.com/embeddedos-org/eipc/services/audit" +) + +// ActionClass categorizes actions by risk level. +type ActionClass int + +const ( + ActionSafe ActionClass = iota // UI control, read-only + ActionControlled // Device writes + ActionRestricted // System operations +) + +// Verdict is the policy decision for an action. +type Verdict int + +const ( + VerdictAllow Verdict = iota + VerdictDeny + VerdictConfirm // Requires operator approval +) + +// Rule defines a single policy rule. +type Rule struct { + Action string + Class ActionClass + Verdict Verdict + Capability string + Description string +} + +// Request represents an action request to be evaluated. +type Request struct { + Source string + Action string + Capability string + RequestID string +} + +// Result is the outcome of a policy evaluation. +type Result struct { + Allowed bool + Verdict Verdict + Reason string + RequestID string + Timestamp time.Time +} + +// Engine evaluates policy rules for incoming EIPC requests. +// It implements the EIPC security model: identity check → capability check → +// policy check → tool execution → audit log. +type Engine struct { + mu sync.RWMutex + rules map[string]*Rule + defaultDeny bool + audit audit.Logger +} + +// NewEngine creates a policy engine. If defaultDeny is true, +// any action without an explicit rule is denied. +func NewEngine(defaultDeny bool, auditLogger audit.Logger) *Engine { + return &Engine{ + rules: make(map[string]*Rule), + defaultDeny: defaultDeny, + audit: auditLogger, + } +} + +// AddRule registers a policy rule for the given action. +func (e *Engine) AddRule(rule Rule) error { + if rule.Action == "" { + return fmt.Errorf("rule action is required") + } + + e.mu.Lock() + defer e.mu.Unlock() + + e.rules[rule.Action] = &rule + return nil +} + +// RemoveRule removes a policy rule. +func (e *Engine) RemoveRule(action string) { + e.mu.Lock() + defer e.mu.Unlock() + delete(e.rules, action) +} + +// Evaluate checks whether a request is allowed by policy. +func (e *Engine) Evaluate(req Request) Result { + e.mu.RLock() + rule, exists := e.rules[req.Action] + defaultDeny := e.defaultDeny + e.mu.RUnlock() + + result := Result{ + RequestID: req.RequestID, + Timestamp: time.Now().UTC(), + } + + if !exists { + if defaultDeny { + result.Allowed = false + result.Verdict = VerdictDeny + result.Reason = fmt.Sprintf("no rule for action %q, default deny", req.Action) + } else { + result.Allowed = true + result.Verdict = VerdictAllow + result.Reason = "no rule, default allow" + } + } else { + switch rule.Verdict { + case VerdictAllow: + if rule.Capability != "" && rule.Capability != req.Capability { + result.Allowed = false + result.Verdict = VerdictDeny + result.Reason = fmt.Sprintf("capability mismatch: need %q, have %q", + rule.Capability, req.Capability) + } else { + result.Allowed = true + result.Verdict = VerdictAllow + result.Reason = "allowed by rule" + } + case VerdictDeny: + result.Allowed = false + result.Verdict = VerdictDeny + result.Reason = fmt.Sprintf("denied by rule: %s", rule.Description) + case VerdictConfirm: + result.Allowed = false + result.Verdict = VerdictConfirm + result.Reason = "requires operator confirmation" + } + } + + if e.audit != nil { + decision := "allow" + if !result.Allowed { + decision = "deny" + } + _ = e.audit.Log(audit.Entry{ + RequestID: req.RequestID, + Source: req.Source, + Action: req.Action, + Decision: decision, + Result: result.Reason, + }) + } + + return result +} + +// EvaluateMessage is a convenience method that evaluates a core.Message. +func (e *Engine) EvaluateMessage(msg core.Message) Result { + return e.Evaluate(Request{ + Source: msg.Source, + Action: string(msg.Type), + Capability: msg.Capability, + RequestID: msg.RequestID, + }) +} + +// ListRules returns all registered rules. +func (e *Engine) ListRules() []Rule { + e.mu.RLock() + defer e.mu.RUnlock() + + rules := make([]Rule, 0, len(e.rules)) + for _, r := range e.rules { + rules = append(rules, *r) + } + return rules +} + +// SetDefaultDeny changes the default deny policy. +func (e *Engine) SetDefaultDeny(deny bool) { + e.mu.Lock() + defer e.mu.Unlock() + e.defaultDeny = deny +} + +// LoadSafeDefaults adds permissive rules for safe action classes +// and restrictive rules for controlled/restricted classes. +func (e *Engine) LoadSafeDefaults() { + safeActions := []string{ + "ui.cursor.move", "ui.scroll", "ui.select", "ui.focus", + "sensor:read", "device:read", "status:read", + } + for _, a := range safeActions { + _ = e.AddRule(Rule{ + Action: a, + Class: ActionSafe, + Verdict: VerdictAllow, + }) + } + + controlledActions := []string{ + "device:write", "actuator:write", "motor:move", + "iot:publish", "config:write", + } + for _, a := range controlledActions { + _ = e.AddRule(Rule{ + Action: a, + Class: ActionControlled, + Verdict: VerdictAllow, + Capability: "device:write", + }) + } + + restrictedActions := []string{ + "system:reboot", "system:shutdown", "firmware:update", + "security:modify", "policy:modify", + } + for _, a := range restrictedActions { + _ = e.AddRule(Rule{ + Action: a, + Class: ActionRestricted, + Verdict: VerdictConfirm, + Description: "restricted system operation", + }) + } +} diff --git a/services/registry/registry_test.go b/services/registry/registry_test.go index b8fa04a..3b19ff7 100644 --- a/services/registry/registry_test.go +++ b/services/registry/registry_test.go @@ -1,127 +1,127 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package registry - -import ( - "testing" - - "github.com/embeddedos-org/eipc/core" -) - -func TestRegister(t *testing.T) { - reg := NewRegistry() - err := reg.Register(ServiceInfo{ - ServiceID: "eni.min", - Capabilities: []string{"ui:control"}, - Versions: []uint16{1}, - }) - if err != nil { - t.Fatalf("Register: %v", err) - } -} - -func TestRegister_EmptyID(t *testing.T) { - reg := NewRegistry() - err := reg.Register(ServiceInfo{}) - if err == nil { - t.Fatal("expected error for empty service_id") - } -} - -func TestLookup(t *testing.T) { - reg := NewRegistry() - _ = reg.Register(ServiceInfo{ - ServiceID: "eni.min", - Capabilities: []string{"ui:control"}, - }) - - info, err := reg.Lookup("eni.min") - if err != nil { - t.Fatalf("Lookup: %v", err) - } - if info.ServiceID != "eni.min" { - t.Errorf("expected 'eni.min', got %q", info.ServiceID) - } -} - -func TestLookup_NotFound(t *testing.T) { - reg := NewRegistry() - _, err := reg.Lookup("nonexistent") - if err == nil { - t.Fatal("expected error for nonexistent service") - } -} - -func TestDeregister(t *testing.T) { - reg := NewRegistry() - _ = reg.Register(ServiceInfo{ServiceID: "eni.min"}) - reg.Deregister("eni.min") - - _, err := reg.Lookup("eni.min") - if err == nil { - t.Fatal("expected error after deregistration") - } -} - -func TestList(t *testing.T) { - reg := NewRegistry() - _ = reg.Register(ServiceInfo{ServiceID: "eni.min"}) - reg.Register(ServiceInfo{ServiceID: "eai.agent"}) - - list := reg.List() - if len(list) != 2 { - t.Fatalf("expected 2 services, got %d", len(list)) - } -} - -func TestFindByCapability(t *testing.T) { - reg := NewRegistry() - reg.Register(ServiceInfo{ServiceID: "eni.min", Capabilities: []string{"ui:control"}}) - reg.Register(ServiceInfo{ServiceID: "eai.agent", Capabilities: []string{"ai:chat"}}) - reg.Register(ServiceInfo{ServiceID: "tool.svc", Capabilities: []string{"ui:control", "device:read"}}) - - results := reg.FindByCapability("ui:control") - if len(results) != 2 { - t.Fatalf("expected 2 services with ui:control, got %d", len(results)) - } -} - -func TestFindByCapability_None(t *testing.T) { - reg := NewRegistry() - reg.Register(ServiceInfo{ServiceID: "eni.min", Capabilities: []string{"ui:control"}}) - - results := reg.FindByCapability("nonexistent") - if len(results) != 0 { - t.Fatalf("expected 0 results, got %d", len(results)) - } -} - -func TestRegister_WithMessageTypes(t *testing.T) { - reg := NewRegistry() - reg.Register(ServiceInfo{ - ServiceID: "eni.min", - Capabilities: []string{"ui:control"}, - MessageTypes: []core.MessageType{core.TypeIntent, core.TypeHeartbeat}, - Priority: core.PriorityP0, - }) - - info, _ := reg.Lookup("eni.min") - if len(info.MessageTypes) != 2 { - t.Errorf("expected 2 message types, got %d", len(info.MessageTypes)) - } - if info.Priority != core.PriorityP0 { - t.Errorf("expected P0 priority, got %d", info.Priority) - } -} - -func TestRegister_Update(t *testing.T) { - reg := NewRegistry() - reg.Register(ServiceInfo{ServiceID: "eni.min", Capabilities: []string{"ui:control"}}) - reg.Register(ServiceInfo{ServiceID: "eni.min", Capabilities: []string{"ui:control", "device:read"}}) - - info, _ := reg.Lookup("eni.min") - if len(info.Capabilities) != 2 { - t.Errorf("expected updated capabilities with 2 entries, got %d", len(info.Capabilities)) - } -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package registry + +import ( + "testing" + + "github.com/embeddedos-org/eipc/core" +) + +func TestRegister(t *testing.T) { + reg := NewRegistry() + err := reg.Register(ServiceInfo{ + ServiceID: "eni.min", + Capabilities: []string{"ui:control"}, + Versions: []uint16{1}, + }) + if err != nil { + t.Fatalf("Register: %v", err) + } +} + +func TestRegister_EmptyID(t *testing.T) { + reg := NewRegistry() + err := reg.Register(ServiceInfo{}) + if err == nil { + t.Fatal("expected error for empty service_id") + } +} + +func TestLookup(t *testing.T) { + reg := NewRegistry() + _ = reg.Register(ServiceInfo{ + ServiceID: "eni.min", + Capabilities: []string{"ui:control"}, + }) + + info, err := reg.Lookup("eni.min") + if err != nil { + t.Fatalf("Lookup: %v", err) + } + if info.ServiceID != "eni.min" { + t.Errorf("expected 'eni.min', got %q", info.ServiceID) + } +} + +func TestLookup_NotFound(t *testing.T) { + reg := NewRegistry() + _, err := reg.Lookup("nonexistent") + if err == nil { + t.Fatal("expected error for nonexistent service") + } +} + +func TestDeregister(t *testing.T) { + reg := NewRegistry() + _ = reg.Register(ServiceInfo{ServiceID: "eni.min"}) + reg.Deregister("eni.min") + + _, err := reg.Lookup("eni.min") + if err == nil { + t.Fatal("expected error after deregistration") + } +} + +func TestList(t *testing.T) { + reg := NewRegistry() + _ = reg.Register(ServiceInfo{ServiceID: "eni.min"}) + reg.Register(ServiceInfo{ServiceID: "eai.agent"}) + + list := reg.List() + if len(list) != 2 { + t.Fatalf("expected 2 services, got %d", len(list)) + } +} + +func TestFindByCapability(t *testing.T) { + reg := NewRegistry() + reg.Register(ServiceInfo{ServiceID: "eni.min", Capabilities: []string{"ui:control"}}) + reg.Register(ServiceInfo{ServiceID: "eai.agent", Capabilities: []string{"ai:chat"}}) + reg.Register(ServiceInfo{ServiceID: "tool.svc", Capabilities: []string{"ui:control", "device:read"}}) + + results := reg.FindByCapability("ui:control") + if len(results) != 2 { + t.Fatalf("expected 2 services with ui:control, got %d", len(results)) + } +} + +func TestFindByCapability_None(t *testing.T) { + reg := NewRegistry() + reg.Register(ServiceInfo{ServiceID: "eni.min", Capabilities: []string{"ui:control"}}) + + results := reg.FindByCapability("nonexistent") + if len(results) != 0 { + t.Fatalf("expected 0 results, got %d", len(results)) + } +} + +func TestRegister_WithMessageTypes(t *testing.T) { + reg := NewRegistry() + reg.Register(ServiceInfo{ + ServiceID: "eni.min", + Capabilities: []string{"ui:control"}, + MessageTypes: []core.MessageType{core.TypeIntent, core.TypeHeartbeat}, + Priority: core.PriorityP0, + }) + + info, _ := reg.Lookup("eni.min") + if len(info.MessageTypes) != 2 { + t.Errorf("expected 2 message types, got %d", len(info.MessageTypes)) + } + if info.Priority != core.PriorityP0 { + t.Errorf("expected P0 priority, got %d", info.Priority) + } +} + +func TestRegister_Update(t *testing.T) { + reg := NewRegistry() + reg.Register(ServiceInfo{ServiceID: "eni.min", Capabilities: []string{"ui:control"}}) + reg.Register(ServiceInfo{ServiceID: "eni.min", Capabilities: []string{"ui:control", "device:read"}}) + + info, _ := reg.Lookup("eni.min") + if len(info.Capabilities) != 2 { + t.Errorf("expected updated capabilities with 2 entries, got %d", len(info.Capabilities)) + } +} diff --git a/tests/integration_test.go b/tests/integration_test.go index 8d2e9cc..83e8f61 100644 --- a/tests/integration_test.go +++ b/tests/integration_test.go @@ -1,319 +1,319 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package tests - -import ( - "crypto/hmac" - "crypto/sha256" - "encoding/hex" - "encoding/json" - "os" - "testing" - "time" - - "github.com/embeddedos-org/eipc/core" - "github.com/embeddedos-org/eipc/protocol" - "github.com/embeddedos-org/eipc/security/auth" - "github.com/embeddedos-org/eipc/security/capability" - "github.com/embeddedos-org/eipc/transport/tcp" -) - -const testSecret = "test-secret-key-32-bytes-long!!" - -func init() { - os.Setenv("EIPC_HMAC_KEY", testSecret) -} - -func TestChallengeResponseAuth(t *testing.T) { - authenticator := auth.NewAuthenticator([]byte(testSecret), map[string][]string{ - "test.client": {"ui:control"}, - }) - - challenge, err := authenticator.CreateChallenge("test.client") - if err != nil { - t.Fatalf("CreateChallenge: %v", err) - } - if len(challenge.Nonce) != 32 { - t.Fatalf("expected 32-byte nonce, got %d", len(challenge.Nonce)) - } - - mac := hmac.New(sha256.New, []byte(testSecret)) - mac.Write(challenge.Nonce) - response := mac.Sum(nil) - - peer, err := authenticator.VerifyResponse("test.client", response) - if err != nil { - t.Fatalf("VerifyResponse: %v", err) - } - if peer.ServiceID != "test.client" { - t.Errorf("expected service_id 'test.client', got %q", peer.ServiceID) - } - if len(peer.Capabilities) != 1 || peer.Capabilities[0] != "ui:control" { - t.Errorf("unexpected capabilities: %v", peer.Capabilities) - } - if peer.SessionToken == "" { - t.Error("expected non-empty session token") - } -} - -func TestChallengeResponseAuth_WrongSecret(t *testing.T) { - authenticator := auth.NewAuthenticator([]byte(testSecret), map[string][]string{ - "test.client": {"ui:control"}, - }) - - challenge, err := authenticator.CreateChallenge("test.client") - if err != nil { - t.Fatalf("CreateChallenge: %v", err) - } - - mac := hmac.New(sha256.New, []byte("wrong-secret-key-definitely-bad!")) - mac.Write(challenge.Nonce) - wrongResponse := mac.Sum(nil) - - _, err = authenticator.VerifyResponse("test.client", wrongResponse) - if err == nil { - t.Fatal("expected error for wrong secret, got nil") - } -} - -func TestChallengeResponseAuth_UnknownService(t *testing.T) { - authenticator := auth.NewAuthenticator([]byte(testSecret), map[string][]string{ - "test.client": {"ui:control"}, - }) - - _, err := authenticator.CreateChallenge("unknown.service") - if err == nil { - t.Fatal("expected error for unknown service") - } -} - -func TestCapabilityEnforcement(t *testing.T) { - checker := capability.NewChecker(map[string][]string{ - "ui:control": {"ui.cursor.move", "ui.click"}, - "ai:chat": {"ai.chat.send", "ai.complete.send"}, - }) - - if err := checker.Check([]string{"ui:control"}, "ui.cursor.move"); err != nil { - t.Errorf("expected ui:control to permit ui.cursor.move: %v", err) - } - - if err := checker.Check([]string{"ai:chat"}, "ai.chat.send"); err != nil { - t.Errorf("expected ai:chat to permit ai.chat.send: %v", err) - } - - if err := checker.Check([]string{"ui:control"}, "ai.chat.send"); err == nil { - t.Error("expected ui:control NOT to permit ai.chat.send") - } - - if err := checker.Check([]string{"ai:chat"}, "ui.cursor.move"); err == nil { - t.Error("expected ai:chat NOT to permit ui.cursor.move") - } -} - -func TestSessionTTL(t *testing.T) { - authenticator := auth.NewAuthenticator([]byte(testSecret), map[string][]string{ - "test.client": {"ui:control"}, - }) - authenticator.SetSessionTTL(100 * time.Millisecond) - - peer, err := authenticator.Authenticate("test.client") - if err != nil { - t.Fatalf("Authenticate: %v", err) - } - - if peer.IsExpired() { - t.Error("peer should not be expired immediately") - } - - time.Sleep(150 * time.Millisecond) - - if !peer.IsExpired() { - t.Error("peer should be expired after TTL") - } - - removed := authenticator.CleanupExpired() - if removed != 1 { - t.Errorf("expected 1 expired session cleaned, got %d", removed) - } -} - -func TestChatCompleteMessageTypes(t *testing.T) { - if core.TypeChat != "chat" { - t.Errorf("expected TypeChat='chat', got %q", core.TypeChat) - } - if core.TypeComplete != "complete" { - t.Errorf("expected TypeComplete='complete', got %q", core.TypeComplete) - } - - if core.MsgTypeToByte(core.TypeChat) != 'c' { - t.Errorf("expected chat wire byte 'c', got %c", core.MsgTypeToByte(core.TypeChat)) - } - if core.MsgTypeToByte(core.TypeComplete) != 'C' { - t.Errorf("expected complete wire byte 'C', got %c", core.MsgTypeToByte(core.TypeComplete)) - } -} - -func TestChatRequestEventSerialization(t *testing.T) { - codec := protocol.DefaultCodec() - req := core.ChatRequestEvent{ - SessionID: "sess-123", - UserPrompt: "Hello EIPC", - Model: "llama3", - MaxTokens: 512, - } - - data, err := codec.Marshal(req) - if err != nil { - t.Fatalf("marshal: %v", err) - } - - var parsed core.ChatRequestEvent - if err := codec.Unmarshal(data, &parsed); err != nil { - t.Fatalf("unmarshal: %v", err) - } - - if parsed.SessionID != "sess-123" || parsed.UserPrompt != "Hello EIPC" { - t.Errorf("roundtrip mismatch: %+v", parsed) - } -} - -func TestEbot_EIPC_EAI_ChatFlow(t *testing.T) { - secret := []byte(testSecret) - codec := protocol.DefaultCodec() - - authenticator := auth.NewAuthenticator(secret, map[string][]string{ - "ebot.client": {"ai:chat"}, - }) - authenticator.SetSessionTTL(1 * time.Hour) - - tcpTransport := tcp.New() - if err := tcpTransport.Listen("127.0.0.1:0"); err != nil { - t.Fatalf("listen: %v", err) - } - defer tcpTransport.Close() - addr := tcpTransport.Addr() - - serverDone := make(chan struct{}) - go func() { - defer close(serverDone) - conn, err := tcpTransport.Accept() - if err != nil { - return - } - defer conn.Close() - - ep := core.NewServerEndpoint(conn, codec, secret) - - authMsg, _ := ep.Receive() - var authReq struct { - ServiceID string `json:"service_id"` - } - _ = json.Unmarshal(authMsg.Payload, &authReq) - - challenge, _ := authenticator.CreateChallenge(authReq.ServiceID) - chalPayload, _ := codec.Marshal(map[string]string{ - "status": "challenge", - "nonce": hex.EncodeToString(challenge.Nonce), - }) - _ = ep.Send(core.Message{Version: 1, Type: core.TypeAck, Source: "server", - Timestamp: time.Now().UTC(), Payload: chalPayload}) - - respMsg, _ := ep.Receive() - var chalResp struct { - Response string `json:"response"` - } - _ = json.Unmarshal(respMsg.Payload, &chalResp) - respBytes, _ := hex.DecodeString(chalResp.Response) - - peer, _ := authenticator.VerifyResponse(authReq.ServiceID, respBytes) - ep.SetPeerCapabilities(peer.Capabilities) - - resultPayload, _ := codec.Marshal(map[string]interface{}{ - "status": "ok", - "session_token": peer.SessionToken, - "capabilities": peer.Capabilities, - }) - _ = ep.Send(core.Message{Version: 1, Type: core.TypeAck, Source: "server", - Timestamp: time.Now().UTC(), Payload: resultPayload}) - - chatMsg, _ := ep.Receive() - var chatReq core.ChatRequestEvent - _ = codec.Unmarshal(chatMsg.Payload, &chatReq) - - chatResp := core.ChatResponseEvent{ - SessionID: chatReq.SessionID, - Response: "Echo: " + chatReq.UserPrompt, - TokensUsed: 5, - } - respPayload, _ := codec.Marshal(chatResp) - _ = ep.Send(core.Message{Version: 1, Type: core.TypeChat, Source: "server", - Timestamp: time.Now().UTC(), Payload: respPayload}) - }() - - clientTransport := tcp.New() - conn, err := clientTransport.Dial(addr) - if err != nil { - t.Fatalf("dial: %v", err) - } - defer conn.Close() - - ep := core.NewClientEndpoint(conn, codec, secret, "") - - authPayload, _ := codec.Marshal(map[string]string{"service_id": "ebot.client"}) - ep.Send(core.Message{Version: 1, Type: core.TypeAck, Source: "ebot.client", - Timestamp: time.Now().UTC(), RequestID: "auth-1", Payload: authPayload}) - - chalMsg, _ := ep.Receive() - var chalData struct { - Nonce string `json:"nonce"` - } - _ = json.Unmarshal(chalMsg.Payload, &chalData) - nonceBytes, _ := hex.DecodeString(chalData.Nonce) - - mac := hmac.New(sha256.New, secret) - mac.Write(nonceBytes) - response := mac.Sum(nil) - - chalRespPayload, _ := codec.Marshal(map[string]string{ - "service_id": "ebot.client", - "response": hex.EncodeToString(response), - }) - ep.Send(core.Message{Version: 1, Type: core.TypeAck, Source: "ebot.client", - Timestamp: time.Now().UTC(), RequestID: "auth-2", Payload: chalRespPayload}) - - authResult, _ := ep.Receive() - var result struct { - Status string `json:"status"` - } - json.Unmarshal(authResult.Payload, &result) - if result.Status != "ok" { - t.Fatalf("auth failed: %s", result.Status) - } - - chatReq := core.ChatRequestEvent{ - SessionID: "ebot-test", - UserPrompt: "Hello from ebot", - Model: "test", - } - chatPayload, _ := codec.Marshal(chatReq) - ep.Send(core.Message{Version: 1, Type: core.TypeChat, Source: "ebot.client", - Timestamp: time.Now().UTC(), RequestID: "chat-1", Capability: "ai:chat", - Payload: chatPayload}) - - chatResp, err := ep.Receive() - if err != nil { - t.Fatalf("receive chat response: %v", err) - } - - var resp core.ChatResponseEvent - if err := codec.Unmarshal(chatResp.Payload, &resp); err != nil { - t.Fatalf("unmarshal chat response: %v", err) - } - - if resp.Response != "Echo: Hello from ebot" { - t.Errorf("unexpected response: %q", resp.Response) - } - - <-serverDone -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package tests + +import ( + "crypto/hmac" + "crypto/sha256" + "encoding/hex" + "encoding/json" + "os" + "testing" + "time" + + "github.com/embeddedos-org/eipc/core" + "github.com/embeddedos-org/eipc/protocol" + "github.com/embeddedos-org/eipc/security/auth" + "github.com/embeddedos-org/eipc/security/capability" + "github.com/embeddedos-org/eipc/transport/tcp" +) + +const testSecret = "test-secret-key-32-bytes-long!!" + +func init() { + os.Setenv("EIPC_HMAC_KEY", testSecret) +} + +func TestChallengeResponseAuth(t *testing.T) { + authenticator := auth.NewAuthenticator([]byte(testSecret), map[string][]string{ + "test.client": {"ui:control"}, + }) + + challenge, err := authenticator.CreateChallenge("test.client") + if err != nil { + t.Fatalf("CreateChallenge: %v", err) + } + if len(challenge.Nonce) != 32 { + t.Fatalf("expected 32-byte nonce, got %d", len(challenge.Nonce)) + } + + mac := hmac.New(sha256.New, []byte(testSecret)) + mac.Write(challenge.Nonce) + response := mac.Sum(nil) + + peer, err := authenticator.VerifyResponse("test.client", response) + if err != nil { + t.Fatalf("VerifyResponse: %v", err) + } + if peer.ServiceID != "test.client" { + t.Errorf("expected service_id 'test.client', got %q", peer.ServiceID) + } + if len(peer.Capabilities) != 1 || peer.Capabilities[0] != "ui:control" { + t.Errorf("unexpected capabilities: %v", peer.Capabilities) + } + if peer.SessionToken == "" { + t.Error("expected non-empty session token") + } +} + +func TestChallengeResponseAuth_WrongSecret(t *testing.T) { + authenticator := auth.NewAuthenticator([]byte(testSecret), map[string][]string{ + "test.client": {"ui:control"}, + }) + + challenge, err := authenticator.CreateChallenge("test.client") + if err != nil { + t.Fatalf("CreateChallenge: %v", err) + } + + mac := hmac.New(sha256.New, []byte("wrong-secret-key-definitely-bad!")) + mac.Write(challenge.Nonce) + wrongResponse := mac.Sum(nil) + + _, err = authenticator.VerifyResponse("test.client", wrongResponse) + if err == nil { + t.Fatal("expected error for wrong secret, got nil") + } +} + +func TestChallengeResponseAuth_UnknownService(t *testing.T) { + authenticator := auth.NewAuthenticator([]byte(testSecret), map[string][]string{ + "test.client": {"ui:control"}, + }) + + _, err := authenticator.CreateChallenge("unknown.service") + if err == nil { + t.Fatal("expected error for unknown service") + } +} + +func TestCapabilityEnforcement(t *testing.T) { + checker := capability.NewChecker(map[string][]string{ + "ui:control": {"ui.cursor.move", "ui.click"}, + "ai:chat": {"ai.chat.send", "ai.complete.send"}, + }) + + if err := checker.Check([]string{"ui:control"}, "ui.cursor.move"); err != nil { + t.Errorf("expected ui:control to permit ui.cursor.move: %v", err) + } + + if err := checker.Check([]string{"ai:chat"}, "ai.chat.send"); err != nil { + t.Errorf("expected ai:chat to permit ai.chat.send: %v", err) + } + + if err := checker.Check([]string{"ui:control"}, "ai.chat.send"); err == nil { + t.Error("expected ui:control NOT to permit ai.chat.send") + } + + if err := checker.Check([]string{"ai:chat"}, "ui.cursor.move"); err == nil { + t.Error("expected ai:chat NOT to permit ui.cursor.move") + } +} + +func TestSessionTTL(t *testing.T) { + authenticator := auth.NewAuthenticator([]byte(testSecret), map[string][]string{ + "test.client": {"ui:control"}, + }) + authenticator.SetSessionTTL(100 * time.Millisecond) + + peer, err := authenticator.Authenticate("test.client") + if err != nil { + t.Fatalf("Authenticate: %v", err) + } + + if peer.IsExpired() { + t.Error("peer should not be expired immediately") + } + + time.Sleep(150 * time.Millisecond) + + if !peer.IsExpired() { + t.Error("peer should be expired after TTL") + } + + removed := authenticator.CleanupExpired() + if removed != 1 { + t.Errorf("expected 1 expired session cleaned, got %d", removed) + } +} + +func TestChatCompleteMessageTypes(t *testing.T) { + if core.TypeChat != "chat" { + t.Errorf("expected TypeChat='chat', got %q", core.TypeChat) + } + if core.TypeComplete != "complete" { + t.Errorf("expected TypeComplete='complete', got %q", core.TypeComplete) + } + + if core.MsgTypeToByte(core.TypeChat) != 'c' { + t.Errorf("expected chat wire byte 'c', got %c", core.MsgTypeToByte(core.TypeChat)) + } + if core.MsgTypeToByte(core.TypeComplete) != 'C' { + t.Errorf("expected complete wire byte 'C', got %c", core.MsgTypeToByte(core.TypeComplete)) + } +} + +func TestChatRequestEventSerialization(t *testing.T) { + codec := protocol.DefaultCodec() + req := core.ChatRequestEvent{ + SessionID: "sess-123", + UserPrompt: "Hello EIPC", + Model: "llama3", + MaxTokens: 512, + } + + data, err := codec.Marshal(req) + if err != nil { + t.Fatalf("marshal: %v", err) + } + + var parsed core.ChatRequestEvent + if err := codec.Unmarshal(data, &parsed); err != nil { + t.Fatalf("unmarshal: %v", err) + } + + if parsed.SessionID != "sess-123" || parsed.UserPrompt != "Hello EIPC" { + t.Errorf("roundtrip mismatch: %+v", parsed) + } +} + +func TestEbot_EIPC_EAI_ChatFlow(t *testing.T) { + secret := []byte(testSecret) + codec := protocol.DefaultCodec() + + authenticator := auth.NewAuthenticator(secret, map[string][]string{ + "ebot.client": {"ai:chat"}, + }) + authenticator.SetSessionTTL(1 * time.Hour) + + tcpTransport := tcp.New() + if err := tcpTransport.Listen("127.0.0.1:0"); err != nil { + t.Fatalf("listen: %v", err) + } + defer tcpTransport.Close() + addr := tcpTransport.Addr() + + serverDone := make(chan struct{}) + go func() { + defer close(serverDone) + conn, err := tcpTransport.Accept() + if err != nil { + return + } + defer conn.Close() + + ep := core.NewServerEndpoint(conn, codec, secret) + + authMsg, _ := ep.Receive() + var authReq struct { + ServiceID string `json:"service_id"` + } + _ = json.Unmarshal(authMsg.Payload, &authReq) + + challenge, _ := authenticator.CreateChallenge(authReq.ServiceID) + chalPayload, _ := codec.Marshal(map[string]string{ + "status": "challenge", + "nonce": hex.EncodeToString(challenge.Nonce), + }) + _ = ep.Send(core.Message{Version: 1, Type: core.TypeAck, Source: "server", + Timestamp: time.Now().UTC(), Payload: chalPayload}) + + respMsg, _ := ep.Receive() + var chalResp struct { + Response string `json:"response"` + } + _ = json.Unmarshal(respMsg.Payload, &chalResp) + respBytes, _ := hex.DecodeString(chalResp.Response) + + peer, _ := authenticator.VerifyResponse(authReq.ServiceID, respBytes) + ep.SetPeerCapabilities(peer.Capabilities) + + resultPayload, _ := codec.Marshal(map[string]interface{}{ + "status": "ok", + "session_token": peer.SessionToken, + "capabilities": peer.Capabilities, + }) + _ = ep.Send(core.Message{Version: 1, Type: core.TypeAck, Source: "server", + Timestamp: time.Now().UTC(), Payload: resultPayload}) + + chatMsg, _ := ep.Receive() + var chatReq core.ChatRequestEvent + _ = codec.Unmarshal(chatMsg.Payload, &chatReq) + + chatResp := core.ChatResponseEvent{ + SessionID: chatReq.SessionID, + Response: "Echo: " + chatReq.UserPrompt, + TokensUsed: 5, + } + respPayload, _ := codec.Marshal(chatResp) + _ = ep.Send(core.Message{Version: 1, Type: core.TypeChat, Source: "server", + Timestamp: time.Now().UTC(), Payload: respPayload}) + }() + + clientTransport := tcp.New() + conn, err := clientTransport.Dial(addr) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer conn.Close() + + ep := core.NewClientEndpoint(conn, codec, secret, "") + + authPayload, _ := codec.Marshal(map[string]string{"service_id": "ebot.client"}) + ep.Send(core.Message{Version: 1, Type: core.TypeAck, Source: "ebot.client", + Timestamp: time.Now().UTC(), RequestID: "auth-1", Payload: authPayload}) + + chalMsg, _ := ep.Receive() + var chalData struct { + Nonce string `json:"nonce"` + } + _ = json.Unmarshal(chalMsg.Payload, &chalData) + nonceBytes, _ := hex.DecodeString(chalData.Nonce) + + mac := hmac.New(sha256.New, secret) + mac.Write(nonceBytes) + response := mac.Sum(nil) + + chalRespPayload, _ := codec.Marshal(map[string]string{ + "service_id": "ebot.client", + "response": hex.EncodeToString(response), + }) + ep.Send(core.Message{Version: 1, Type: core.TypeAck, Source: "ebot.client", + Timestamp: time.Now().UTC(), RequestID: "auth-2", Payload: chalRespPayload}) + + authResult, _ := ep.Receive() + var result struct { + Status string `json:"status"` + } + json.Unmarshal(authResult.Payload, &result) + if result.Status != "ok" { + t.Fatalf("auth failed: %s", result.Status) + } + + chatReq := core.ChatRequestEvent{ + SessionID: "ebot-test", + UserPrompt: "Hello from ebot", + Model: "test", + } + chatPayload, _ := codec.Marshal(chatReq) + ep.Send(core.Message{Version: 1, Type: core.TypeChat, Source: "ebot.client", + Timestamp: time.Now().UTC(), RequestID: "chat-1", Capability: "ai:chat", + Payload: chatPayload}) + + chatResp, err := ep.Receive() + if err != nil { + t.Fatalf("receive chat response: %v", err) + } + + var resp core.ChatResponseEvent + if err := codec.Unmarshal(chatResp.Payload, &resp); err != nil { + t.Fatalf("unmarshal chat response: %v", err) + } + + if resp.Response != "Echo: Hello from ebot" { + t.Errorf("unexpected response: %q", resp.Response) + } + + <-serverDone +} diff --git a/tests/stress_test.go b/tests/stress_test.go index b60dd7e..cd26bc7 100644 --- a/tests/stress_test.go +++ b/tests/stress_test.go @@ -1,235 +1,235 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package tests - -import ( - "crypto/rand" - "sync" - "testing" - "time" - - "github.com/embeddedos-org/eipc/core" - "github.com/embeddedos-org/eipc/protocol" - "github.com/embeddedos-org/eipc/transport/tcp" -) - -func TestStress_LargePayload(t *testing.T) { - secret := []byte(testSecret) - codec := protocol.DefaultCodec() - - transport := tcp.New() - if err := transport.Listen("127.0.0.1:0"); err != nil { - t.Fatalf("listen: %v", err) - } - defer transport.Close() - addr := transport.Addr() - - payload := make([]byte, 512*1024) // 512KB - _ = rand.Read(payload) - - done := make(chan struct{}) - go func() { - defer close(done) - conn, err := transport.Accept() - if err != nil { - return - } - defer conn.Close() - - ep := core.NewServerEndpoint(conn, codec, secret) - msg, err := ep.Receive() - if err != nil { - t.Errorf("server receive: %v", err) - return - } - if len(msg.Payload) != len(payload) { - t.Errorf("expected payload %d bytes, got %d", len(payload), len(msg.Payload)) - } - }() - - clientTransport := tcp.New() - conn, err := clientTransport.Dial(addr) - if err != nil { - t.Fatalf("dial: %v", err) - } - defer conn.Close() - - ep := core.NewClientEndpoint(conn, codec, secret, "") - err = ep.Send(core.Message{ - Version: core.ProtocolVersion, - Type: core.TypeIntent, - Source: "stress.client", - Timestamp: time.Now().UTC(), - RequestID: "large-1", - Priority: core.PriorityP0, - Payload: payload, - }) - if err != nil { - t.Fatalf("send large payload: %v", err) - } - - <-done -} - -func TestStress_ConcurrentClients(t *testing.T) { - secret := []byte(testSecret) - codec := protocol.DefaultCodec() - numClients := 10 - - serverTransport := tcp.New() - if err := serverTransport.Listen("127.0.0.1:0"); err != nil { - t.Fatalf("listen: %v", err) - } - defer serverTransport.Close() - addr := serverTransport.Addr() - - var serverWg sync.WaitGroup - serverWg.Add(numClients) - - go func() { - for i := 0; i < numClients; i++ { - conn, err := serverTransport.Accept() - if err != nil { - return - } - go func() { - defer serverWg.Done() - defer conn.Close() - ep := core.NewServerEndpoint(conn, codec, secret) - msg, err := ep.Receive() - if err != nil { - t.Errorf("server receive: %v", err) - return - } - - resp := core.Message{ - Version: core.ProtocolVersion, - Type: core.TypeAck, - Source: "server", - Timestamp: time.Now().UTC(), - RequestID: msg.RequestID, - Priority: core.PriorityP0, - Payload: []byte(`{"status":"ok"}`), - } - ep.Send(resp) - }() - } - }() - - var clientWg sync.WaitGroup - errors := make(chan error, numClients) - - for i := 0; i < numClients; i++ { - clientWg.Add(1) - go func(id int) { - defer clientWg.Done() - - ct := tcp.New() - conn, err := ct.Dial(addr) - if err != nil { - errors <- err - return - } - defer conn.Close() - - ep := core.NewClientEndpoint(conn, codec, secret, "") - err = ep.Send(core.Message{ - Version: core.ProtocolVersion, - Type: core.TypeIntent, - Source: "client", - Timestamp: time.Now().UTC(), - RequestID: "concurrent", - Priority: core.PriorityP1, - Payload: []byte(`{"intent":"test"}`), - }) - if err != nil { - errors <- err - return - } - - _, err = ep.Receive() - if err != nil { - errors <- err - } - }(i) - } - - clientWg.Wait() - close(errors) - - for err := range errors { - t.Errorf("client error: %v", err) - } - - serverWg.Wait() -} - -func TestStress_MessageOrdering(t *testing.T) { - secret := []byte(testSecret) - codec := protocol.DefaultCodec() - numMessages := 50 - - serverTransport := tcp.New() - if err := serverTransport.Listen("127.0.0.1:0"); err != nil { - t.Fatalf("listen: %v", err) - } - defer serverTransport.Close() - addr := serverTransport.Addr() - - received := make(chan string, numMessages) - done := make(chan struct{}) - - go func() { - defer close(done) - conn, err := serverTransport.Accept() - if err != nil { - return - } - defer conn.Close() - - ep := core.NewServerEndpoint(conn, codec, secret) - for i := 0; i < numMessages; i++ { - msg, err := ep.Receive() - if err != nil { - t.Errorf("receive %d: %v", i, err) - return - } - received <- msg.RequestID - } - }() - - ct := tcp.New() - conn, err := ct.Dial(addr) - if err != nil { - t.Fatalf("dial: %v", err) - } - defer conn.Close() - - ep := core.NewClientEndpoint(conn, codec, secret, "") - for i := 0; i < numMessages; i++ { - err := ep.Send(core.Message{ - Version: core.ProtocolVersion, - Type: core.TypeIntent, - Source: "ordering.client", - Timestamp: time.Now().UTC(), - RequestID: string(rune('A' + i%26)), - Priority: core.PriorityP0, - Payload: []byte(`{}`), - }) - if err != nil { - t.Fatalf("send %d: %v", i, err) - } - } - - <-done - close(received) - - count := 0 - for range received { - count++ - } - if count != numMessages { - t.Errorf("expected %d messages, received %d", numMessages, count) - } -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package tests + +import ( + "crypto/rand" + "sync" + "testing" + "time" + + "github.com/embeddedos-org/eipc/core" + "github.com/embeddedos-org/eipc/protocol" + "github.com/embeddedos-org/eipc/transport/tcp" +) + +func TestStress_LargePayload(t *testing.T) { + secret := []byte(testSecret) + codec := protocol.DefaultCodec() + + transport := tcp.New() + if err := transport.Listen("127.0.0.1:0"); err != nil { + t.Fatalf("listen: %v", err) + } + defer transport.Close() + addr := transport.Addr() + + payload := make([]byte, 512*1024) // 512KB + _ = rand.Read(payload) + + done := make(chan struct{}) + go func() { + defer close(done) + conn, err := transport.Accept() + if err != nil { + return + } + defer conn.Close() + + ep := core.NewServerEndpoint(conn, codec, secret) + msg, err := ep.Receive() + if err != nil { + t.Errorf("server receive: %v", err) + return + } + if len(msg.Payload) != len(payload) { + t.Errorf("expected payload %d bytes, got %d", len(payload), len(msg.Payload)) + } + }() + + clientTransport := tcp.New() + conn, err := clientTransport.Dial(addr) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer conn.Close() + + ep := core.NewClientEndpoint(conn, codec, secret, "") + err = ep.Send(core.Message{ + Version: core.ProtocolVersion, + Type: core.TypeIntent, + Source: "stress.client", + Timestamp: time.Now().UTC(), + RequestID: "large-1", + Priority: core.PriorityP0, + Payload: payload, + }) + if err != nil { + t.Fatalf("send large payload: %v", err) + } + + <-done +} + +func TestStress_ConcurrentClients(t *testing.T) { + secret := []byte(testSecret) + codec := protocol.DefaultCodec() + numClients := 10 + + serverTransport := tcp.New() + if err := serverTransport.Listen("127.0.0.1:0"); err != nil { + t.Fatalf("listen: %v", err) + } + defer serverTransport.Close() + addr := serverTransport.Addr() + + var serverWg sync.WaitGroup + serverWg.Add(numClients) + + go func() { + for i := 0; i < numClients; i++ { + conn, err := serverTransport.Accept() + if err != nil { + return + } + go func() { + defer serverWg.Done() + defer conn.Close() + ep := core.NewServerEndpoint(conn, codec, secret) + msg, err := ep.Receive() + if err != nil { + t.Errorf("server receive: %v", err) + return + } + + resp := core.Message{ + Version: core.ProtocolVersion, + Type: core.TypeAck, + Source: "server", + Timestamp: time.Now().UTC(), + RequestID: msg.RequestID, + Priority: core.PriorityP0, + Payload: []byte(`{"status":"ok"}`), + } + ep.Send(resp) + }() + } + }() + + var clientWg sync.WaitGroup + errors := make(chan error, numClients) + + for i := 0; i < numClients; i++ { + clientWg.Add(1) + go func(id int) { + defer clientWg.Done() + + ct := tcp.New() + conn, err := ct.Dial(addr) + if err != nil { + errors <- err + return + } + defer conn.Close() + + ep := core.NewClientEndpoint(conn, codec, secret, "") + err = ep.Send(core.Message{ + Version: core.ProtocolVersion, + Type: core.TypeIntent, + Source: "client", + Timestamp: time.Now().UTC(), + RequestID: "concurrent", + Priority: core.PriorityP1, + Payload: []byte(`{"intent":"test"}`), + }) + if err != nil { + errors <- err + return + } + + _, err = ep.Receive() + if err != nil { + errors <- err + } + }(i) + } + + clientWg.Wait() + close(errors) + + for err := range errors { + t.Errorf("client error: %v", err) + } + + serverWg.Wait() +} + +func TestStress_MessageOrdering(t *testing.T) { + secret := []byte(testSecret) + codec := protocol.DefaultCodec() + numMessages := 50 + + serverTransport := tcp.New() + if err := serverTransport.Listen("127.0.0.1:0"); err != nil { + t.Fatalf("listen: %v", err) + } + defer serverTransport.Close() + addr := serverTransport.Addr() + + received := make(chan string, numMessages) + done := make(chan struct{}) + + go func() { + defer close(done) + conn, err := serverTransport.Accept() + if err != nil { + return + } + defer conn.Close() + + ep := core.NewServerEndpoint(conn, codec, secret) + for i := 0; i < numMessages; i++ { + msg, err := ep.Receive() + if err != nil { + t.Errorf("receive %d: %v", i, err) + return + } + received <- msg.RequestID + } + }() + + ct := tcp.New() + conn, err := ct.Dial(addr) + if err != nil { + t.Fatalf("dial: %v", err) + } + defer conn.Close() + + ep := core.NewClientEndpoint(conn, codec, secret, "") + for i := 0; i < numMessages; i++ { + err := ep.Send(core.Message{ + Version: core.ProtocolVersion, + Type: core.TypeIntent, + Source: "ordering.client", + Timestamp: time.Now().UTC(), + RequestID: string(rune('A' + i%26)), + Priority: core.PriorityP0, + Payload: []byte(`{}`), + }) + if err != nil { + t.Fatalf("send %d: %v", i, err) + } + } + + <-done + close(received) + + count := 0 + for range received { + count++ + } + if count != numMessages { + t.Errorf("expected %d messages, received %d", numMessages, count) + } +} diff --git a/transport/shm/shm.go b/transport/shm/shm.go index e3022e0..4770b42 100644 --- a/transport/shm/shm.go +++ b/transport/shm/shm.go @@ -1,208 +1,208 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package shm - -import ( - "fmt" - "runtime" - "sync" - "sync/atomic" - - "github.com/embeddedos-org/eipc/protocol" - "github.com/embeddedos-org/eipc/transport" -) - -// Config holds shared memory transport configuration. -type Config struct { - Name string // Shared memory region name - BufferSize int // Ring buffer size in bytes (default 64KB) - SlotCount int // Number of message slots (default 256) -} - -const ( - DefaultBufferSize = 64 * 1024 - DefaultSlotCount = 256 - MaxSlotSize = 8192 -) - -// RingBuffer is an in-process shared memory ring buffer for -// high-rate feature streams between ENI and EAI. -// -// This is a process-local implementation suitable for threads/goroutines. -// For true cross-process shared memory, the platform-specific mmap -// implementation should be used (see shm_linux.go, shm_windows.go). -type RingBuffer struct { - mu sync.Mutex - name string - buf []byte - slots int - slotSize int - head atomic.Uint64 - tail atomic.Uint64 - hmacKey []byte // When set, frames are HMAC-signed on write and verified on read -} - -// SetHMACKey enables HMAC-SHA256 signing/verification for SHM frames. -func (rb *RingBuffer) SetHMACKey(key []byte) { - rb.hmacKey = make([]byte, len(key)) - copy(rb.hmacKey, key) -} - -// NewRingBuffer creates a new shared-memory ring buffer. -func NewRingBuffer(cfg Config) *RingBuffer { - if cfg.BufferSize <= 0 { - cfg.BufferSize = DefaultBufferSize - } - if cfg.SlotCount <= 0 { - cfg.SlotCount = DefaultSlotCount - } - slotSize := cfg.BufferSize / cfg.SlotCount - if slotSize > MaxSlotSize { - slotSize = MaxSlotSize - } - - return &RingBuffer{ - name: cfg.Name, - buf: make([]byte, cfg.SlotCount*slotSize), - slots: cfg.SlotCount, - slotSize: slotSize, - } -} - -// Write places a frame into the next available slot. -// Returns ErrBackpressure if the buffer is full. -func (rb *RingBuffer) Write(frame *protocol.Frame) error { - data := frame.SignableBytes() - if len(data) > rb.slotSize-2 { - return fmt.Errorf("frame too large for slot (%d > %d)", len(data), rb.slotSize-2) - } - - rb.mu.Lock() - defer rb.mu.Unlock() - - head := rb.head.Load() - tail := rb.tail.Load() - - if head-tail >= uint64(rb.slots) { - return fmt.Errorf("ring buffer full (backpressure)") - } - - idx := int(head % uint64(rb.slots)) - offset := idx * rb.slotSize - - rb.buf[offset] = byte(len(data) >> 8) - rb.buf[offset+1] = byte(len(data) & 0xFF) - copy(rb.buf[offset+2:], data) - - rb.head.Add(1) - return nil -} - -// Read retrieves the next frame from the buffer. -// Returns nil if the buffer is empty. -func (rb *RingBuffer) Read() (*protocol.Frame, error) { - rb.mu.Lock() - defer rb.mu.Unlock() - - head := rb.head.Load() - tail := rb.tail.Load() - - if tail >= head { - return nil, nil - } - - idx := int(tail % uint64(rb.slots)) - offset := idx * rb.slotSize - - length := int(rb.buf[offset])<<8 | int(rb.buf[offset+1]) - if length > rb.slotSize-2 { - return nil, fmt.Errorf("invalid slot length (%d)", length) - } - data := make([]byte, length) - copy(data, rb.buf[offset+2:offset+2+length]) - - rb.tail.Add(1) - - if length < protocol.FrameFixedSize { - return nil, fmt.Errorf("slot data too short (%d bytes)", length) - } - - frame := &protocol.Frame{ - Version: uint16(data[4])<<8 | uint16(data[5]), - MsgType: data[6], - Flags: data[7], - } - - headerLen := int(data[8])<<24 | int(data[9])<<16 | int(data[10])<<8 | int(data[11]) - payloadLen := int(data[12])<<24 | int(data[13])<<16 | int(data[14])<<8 | int(data[15]) - - pos := protocol.FrameFixedSize - if headerLen > 0 && pos+headerLen <= length { - frame.Header = make([]byte, headerLen) - copy(frame.Header, data[pos:pos+headerLen]) - pos += headerLen - } - if payloadLen > 0 && pos+payloadLen <= length { - frame.Payload = make([]byte, payloadLen) - copy(frame.Payload, data[pos:pos+payloadLen]) - } - - return frame, nil -} - -// Len returns the number of unread messages in the buffer. -func (rb *RingBuffer) Len() int { - return int(rb.head.Load() - rb.tail.Load()) -} - -// Name returns the shared memory region name. -func (rb *RingBuffer) Name() string { - return rb.name -} - -// Connection wraps a RingBuffer pair (tx/rx) as a transport.Connection. -type Connection struct { - tx *RingBuffer - rx *RingBuffer - remote string - codec protocol.Codec -} - -// NewConnection creates a shared-memory connection using two ring buffers. -func NewConnection(tx, rx *RingBuffer, remote string) *Connection { - return &Connection{ - tx: tx, - rx: rx, - remote: remote, - codec: protocol.DefaultCodec(), - } -} - -func (c *Connection) Send(frame *protocol.Frame) error { - return c.tx.Write(frame) -} - -func (c *Connection) Receive() (*protocol.Frame, error) { - for { - frame, err := c.rx.Read() - if err != nil { - return nil, err - } - if frame != nil { - return frame, nil - } - runtime.Gosched() - } -} - -func (c *Connection) Close() error { - return nil -} - -func (c *Connection) RemoteAddr() string { - return c.remote -} - -// Ensure Connection satisfies transport.Connection. -var _ transport.Connection = (*Connection)(nil) +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package shm + +import ( + "fmt" + "runtime" + "sync" + "sync/atomic" + + "github.com/embeddedos-org/eipc/protocol" + "github.com/embeddedos-org/eipc/transport" +) + +// Config holds shared memory transport configuration. +type Config struct { + Name string // Shared memory region name + BufferSize int // Ring buffer size in bytes (default 64KB) + SlotCount int // Number of message slots (default 256) +} + +const ( + DefaultBufferSize = 64 * 1024 + DefaultSlotCount = 256 + MaxSlotSize = 8192 +) + +// RingBuffer is an in-process shared memory ring buffer for +// high-rate feature streams between ENI and EAI. +// +// This is a process-local implementation suitable for threads/goroutines. +// For true cross-process shared memory, the platform-specific mmap +// implementation should be used (see shm_linux.go, shm_windows.go). +type RingBuffer struct { + mu sync.Mutex + name string + buf []byte + slots int + slotSize int + head atomic.Uint64 + tail atomic.Uint64 + hmacKey []byte // When set, frames are HMAC-signed on write and verified on read +} + +// SetHMACKey enables HMAC-SHA256 signing/verification for SHM frames. +func (rb *RingBuffer) SetHMACKey(key []byte) { + rb.hmacKey = make([]byte, len(key)) + copy(rb.hmacKey, key) +} + +// NewRingBuffer creates a new shared-memory ring buffer. +func NewRingBuffer(cfg Config) *RingBuffer { + if cfg.BufferSize <= 0 { + cfg.BufferSize = DefaultBufferSize + } + if cfg.SlotCount <= 0 { + cfg.SlotCount = DefaultSlotCount + } + slotSize := cfg.BufferSize / cfg.SlotCount + if slotSize > MaxSlotSize { + slotSize = MaxSlotSize + } + + return &RingBuffer{ + name: cfg.Name, + buf: make([]byte, cfg.SlotCount*slotSize), + slots: cfg.SlotCount, + slotSize: slotSize, + } +} + +// Write places a frame into the next available slot. +// Returns ErrBackpressure if the buffer is full. +func (rb *RingBuffer) Write(frame *protocol.Frame) error { + data := frame.SignableBytes() + if len(data) > rb.slotSize-2 { + return fmt.Errorf("frame too large for slot (%d > %d)", len(data), rb.slotSize-2) + } + + rb.mu.Lock() + defer rb.mu.Unlock() + + head := rb.head.Load() + tail := rb.tail.Load() + + if head-tail >= uint64(rb.slots) { + return fmt.Errorf("ring buffer full (backpressure)") + } + + idx := int(head % uint64(rb.slots)) + offset := idx * rb.slotSize + + rb.buf[offset] = byte(len(data) >> 8) + rb.buf[offset+1] = byte(len(data) & 0xFF) + copy(rb.buf[offset+2:], data) + + rb.head.Add(1) + return nil +} + +// Read retrieves the next frame from the buffer. +// Returns nil if the buffer is empty. +func (rb *RingBuffer) Read() (*protocol.Frame, error) { + rb.mu.Lock() + defer rb.mu.Unlock() + + head := rb.head.Load() + tail := rb.tail.Load() + + if tail >= head { + return nil, nil + } + + idx := int(tail % uint64(rb.slots)) + offset := idx * rb.slotSize + + length := int(rb.buf[offset])<<8 | int(rb.buf[offset+1]) + if length > rb.slotSize-2 { + return nil, fmt.Errorf("invalid slot length (%d)", length) + } + data := make([]byte, length) + copy(data, rb.buf[offset+2:offset+2+length]) + + rb.tail.Add(1) + + if length < protocol.FrameFixedSize { + return nil, fmt.Errorf("slot data too short (%d bytes)", length) + } + + frame := &protocol.Frame{ + Version: uint16(data[4])<<8 | uint16(data[5]), + MsgType: data[6], + Flags: data[7], + } + + headerLen := int(data[8])<<24 | int(data[9])<<16 | int(data[10])<<8 | int(data[11]) + payloadLen := int(data[12])<<24 | int(data[13])<<16 | int(data[14])<<8 | int(data[15]) + + pos := protocol.FrameFixedSize + if headerLen > 0 && pos+headerLen <= length { + frame.Header = make([]byte, headerLen) + copy(frame.Header, data[pos:pos+headerLen]) + pos += headerLen + } + if payloadLen > 0 && pos+payloadLen <= length { + frame.Payload = make([]byte, payloadLen) + copy(frame.Payload, data[pos:pos+payloadLen]) + } + + return frame, nil +} + +// Len returns the number of unread messages in the buffer. +func (rb *RingBuffer) Len() int { + return int(rb.head.Load() - rb.tail.Load()) +} + +// Name returns the shared memory region name. +func (rb *RingBuffer) Name() string { + return rb.name +} + +// Connection wraps a RingBuffer pair (tx/rx) as a transport.Connection. +type Connection struct { + tx *RingBuffer + rx *RingBuffer + remote string + codec protocol.Codec +} + +// NewConnection creates a shared-memory connection using two ring buffers. +func NewConnection(tx, rx *RingBuffer, remote string) *Connection { + return &Connection{ + tx: tx, + rx: rx, + remote: remote, + codec: protocol.DefaultCodec(), + } +} + +func (c *Connection) Send(frame *protocol.Frame) error { + return c.tx.Write(frame) +} + +func (c *Connection) Receive() (*protocol.Frame, error) { + for { + frame, err := c.rx.Read() + if err != nil { + return nil, err + } + if frame != nil { + return frame, nil + } + runtime.Gosched() + } +} + +func (c *Connection) Close() error { + return nil +} + +func (c *Connection) RemoteAddr() string { + return c.remote +} + +// Ensure Connection satisfies transport.Connection. +var _ transport.Connection = (*Connection)(nil) diff --git a/transport/shm/shm_test.go b/transport/shm/shm_test.go index 5322ba2..727ad84 100644 --- a/transport/shm/shm_test.go +++ b/transport/shm/shm_test.go @@ -1,94 +1,94 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package shm - -import ( - "testing" - - "github.com/embeddedos-org/eipc/protocol" -) - -func TestRingBufferWriteRead(t *testing.T) { - rb := NewRingBuffer(Config{Name: "test", BufferSize: 8192, SlotCount: 8}) - - frame := &protocol.Frame{ - Version: 1, - MsgType: 'i', - Flags: 0, - Header: []byte(`{"service_id":"test"}`), - Payload: []byte(`{"intent":"move_left"}`), - } - - if err := rb.Write(frame); err != nil { - t.Fatalf("Write failed: %v", err) - } - - if rb.Len() != 1 { - t.Errorf("expected Len()=1, got %d", rb.Len()) - } - - readFrame, err := rb.Read() - if err != nil { - t.Fatalf("Read failed: %v", err) - } - if readFrame == nil { - t.Fatal("Read returned nil frame") - } - if readFrame.MsgType != 'i' { - t.Errorf("expected msg_type 'i', got %c", readFrame.MsgType) - } -} - -func TestRingBufferEmpty(t *testing.T) { - rb := NewRingBuffer(Config{Name: "empty", BufferSize: 4096, SlotCount: 4}) - - frame, err := rb.Read() - if err != nil { - t.Fatalf("Read from empty buffer should not error: %v", err) - } - if frame != nil { - t.Error("Read from empty buffer should return nil") - } -} - -func TestRingBufferFull(t *testing.T) { - rb := NewRingBuffer(Config{Name: "full", BufferSize: 4096, SlotCount: 2}) - - frame := &protocol.Frame{ - Version: 1, - MsgType: 'h', - Header: []byte(`{}`), - Payload: []byte(`{}`), - } - - _ = rb.Write(frame) - _ = rb.Write(frame) - - err := rb.Write(frame) - if err == nil { - t.Error("expected backpressure error when buffer is full") - } -} - -func TestConnectionInterface(t *testing.T) { - txBuf := NewRingBuffer(Config{Name: "tx", BufferSize: 8192, SlotCount: 8}) - rxBuf := NewRingBuffer(Config{Name: "rx", BufferSize: 8192, SlotCount: 8}) - - conn := NewConnection(txBuf, rxBuf, "shm://test") - - if conn.RemoteAddr() != "shm://test" { - t.Errorf("expected remote addr shm://test, got %s", conn.RemoteAddr()) - } - - if err := conn.Close(); err != nil { - t.Errorf("Close should not error: %v", err) - } -} - -func TestRingBufferName(t *testing.T) { - rb := NewRingBuffer(Config{Name: "my-region"}) - if rb.Name() != "my-region" { - t.Errorf("expected name my-region, got %s", rb.Name()) - } -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package shm + +import ( + "testing" + + "github.com/embeddedos-org/eipc/protocol" +) + +func TestRingBufferWriteRead(t *testing.T) { + rb := NewRingBuffer(Config{Name: "test", BufferSize: 8192, SlotCount: 8}) + + frame := &protocol.Frame{ + Version: 1, + MsgType: 'i', + Flags: 0, + Header: []byte(`{"service_id":"test"}`), + Payload: []byte(`{"intent":"move_left"}`), + } + + if err := rb.Write(frame); err != nil { + t.Fatalf("Write failed: %v", err) + } + + if rb.Len() != 1 { + t.Errorf("expected Len()=1, got %d", rb.Len()) + } + + readFrame, err := rb.Read() + if err != nil { + t.Fatalf("Read failed: %v", err) + } + if readFrame == nil { + t.Fatal("Read returned nil frame") + } + if readFrame.MsgType != 'i' { + t.Errorf("expected msg_type 'i', got %c", readFrame.MsgType) + } +} + +func TestRingBufferEmpty(t *testing.T) { + rb := NewRingBuffer(Config{Name: "empty", BufferSize: 4096, SlotCount: 4}) + + frame, err := rb.Read() + if err != nil { + t.Fatalf("Read from empty buffer should not error: %v", err) + } + if frame != nil { + t.Error("Read from empty buffer should return nil") + } +} + +func TestRingBufferFull(t *testing.T) { + rb := NewRingBuffer(Config{Name: "full", BufferSize: 4096, SlotCount: 2}) + + frame := &protocol.Frame{ + Version: 1, + MsgType: 'h', + Header: []byte(`{}`), + Payload: []byte(`{}`), + } + + _ = rb.Write(frame) + _ = rb.Write(frame) + + err := rb.Write(frame) + if err == nil { + t.Error("expected backpressure error when buffer is full") + } +} + +func TestConnectionInterface(t *testing.T) { + txBuf := NewRingBuffer(Config{Name: "tx", BufferSize: 8192, SlotCount: 8}) + rxBuf := NewRingBuffer(Config{Name: "rx", BufferSize: 8192, SlotCount: 8}) + + conn := NewConnection(txBuf, rxBuf, "shm://test") + + if conn.RemoteAddr() != "shm://test" { + t.Errorf("expected remote addr shm://test, got %s", conn.RemoteAddr()) + } + + if err := conn.Close(); err != nil { + t.Errorf("Close should not error: %v", err) + } +} + +func TestRingBufferName(t *testing.T) { + rb := NewRingBuffer(Config{Name: "my-region"}) + if rb.Name() != "my-region" { + t.Errorf("expected name my-region, got %s", rb.Name()) + } +} diff --git a/transport/tcp/tcp.go b/transport/tcp/tcp.go index 6d8e6cc..d235c59 100644 --- a/transport/tcp/tcp.go +++ b/transport/tcp/tcp.go @@ -1,154 +1,154 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package tcp - -import ( - "crypto/tls" - "fmt" - "net" - "os" - "sync" - "time" - - "github.com/embeddedos-org/eipc/transport" -) - -// Transport implements the EIPC transport interface over TCP. -// Works on Linux, Windows, and macOS. Supports optional TLS/mTLS. -type Transport struct { - mu sync.Mutex - listener net.Listener - tlsConfig *tls.Config -} - -// New creates a new TCP transport. -func New() *Transport { - return &Transport{} -} - -// WithTLS configures TLS from cert/key/CA files. -func (t *Transport) WithTLS(certFile, keyFile, caFile string) error { - cfg, err := LoadTLSConfig(certFile, keyFile, caFile) - if err != nil { - return err - } - t.tlsConfig = cfg - return nil -} - -// WithTLSConfig sets a pre-built TLS config. -func (t *Transport) WithTLSConfig(cfg *tls.Config) { - t.tlsConfig = cfg -} - -// Listen starts a TCP listener on the given address (e.g. "127.0.0.1:9090"). -// If TLS is configured, wraps listener with tls.NewListener. -func (t *Transport) Listen(address string) error { - t.mu.Lock() - defer t.mu.Unlock() - - ln, err := net.Listen("tcp", address) - if err != nil { - return fmt.Errorf("tcp listen: %w", err) - } - - if t.tlsConfig != nil { - ln = tls.NewListener(ln, t.tlsConfig) - } - - t.listener = ln - return nil -} - -// Dial connects to a remote TCP address and returns a Connection. -// If TLS is configured, uses tls.Dial; also enables TCP keepalive. -func (t *Transport) Dial(address string) (transport.Connection, error) { - if t.tlsConfig != nil { - clientTLS := t.tlsConfig.Clone() - conn, err := tls.Dial("tcp", address, clientTLS) - if err != nil { - return nil, fmt.Errorf("tcp tls dial: %w", err) - } - return transport.NewConnWrapper(conn), nil - } - - conn, err := net.Dial("tcp", address) - if err != nil { - return nil, fmt.Errorf("tcp dial: %w", err) - } - if tc, ok := conn.(*net.TCPConn); ok { - _ = tc.SetKeepAlive(true) - _ = tc.SetKeepAlivePeriod(30 * time.Second) - } - return transport.NewConnWrapper(conn), nil -} - -// Accept waits for and returns the next inbound TCP connection. -// Enables keepalive on accepted TCP connections. -func (t *Transport) Accept() (transport.Connection, error) { - t.mu.Lock() - ln := t.listener - t.mu.Unlock() - - if ln == nil { - return nil, fmt.Errorf("tcp: not listening") - } - - conn, err := ln.Accept() - if err != nil { - return nil, fmt.Errorf("tcp accept: %w", err) - } - - if tc, ok := conn.(*net.TCPConn); ok { - _ = tc.SetKeepAlive(true) - _ = tc.SetKeepAlivePeriod(30 * time.Second) - } - - return transport.NewConnWrapper(conn), nil -} - -// Close shuts down the TCP listener. -func (t *Transport) Close() error { - t.mu.Lock() - defer t.mu.Unlock() - - if t.listener != nil { - return t.listener.Close() - } - return nil -} - -// Addr returns the listener's address. Returns "" if not listening. -func (t *Transport) Addr() string { - t.mu.Lock() - defer t.mu.Unlock() - if t.listener != nil { - return t.listener.Addr().String() - } - return "" -} - -// SetupTLSFromEnv configures TLS from environment variables. -// EIPC_TLS_CERT, EIPC_TLS_KEY for cert/key files. -// EIPC_TLS_CA for mTLS CA file. -// EIPC_TLS_AUTO_CERT=true for auto-generated self-signed cert. -func (t *Transport) SetupTLSFromEnv() error { - certFile := os.Getenv("EIPC_TLS_CERT") - keyFile := os.Getenv("EIPC_TLS_KEY") - - if certFile != "" && keyFile != "" { - return t.WithTLS(certFile, keyFile, os.Getenv("EIPC_TLS_CA")) - } - - if os.Getenv("EIPC_TLS_AUTO_CERT") == "true" { - cfg, err := AutoTLSConfig() - if err != nil { - return err - } - t.tlsConfig = cfg - return nil - } - - return nil -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package tcp + +import ( + "crypto/tls" + "fmt" + "net" + "os" + "sync" + "time" + + "github.com/embeddedos-org/eipc/transport" +) + +// Transport implements the EIPC transport interface over TCP. +// Works on Linux, Windows, and macOS. Supports optional TLS/mTLS. +type Transport struct { + mu sync.Mutex + listener net.Listener + tlsConfig *tls.Config +} + +// New creates a new TCP transport. +func New() *Transport { + return &Transport{} +} + +// WithTLS configures TLS from cert/key/CA files. +func (t *Transport) WithTLS(certFile, keyFile, caFile string) error { + cfg, err := LoadTLSConfig(certFile, keyFile, caFile) + if err != nil { + return err + } + t.tlsConfig = cfg + return nil +} + +// WithTLSConfig sets a pre-built TLS config. +func (t *Transport) WithTLSConfig(cfg *tls.Config) { + t.tlsConfig = cfg +} + +// Listen starts a TCP listener on the given address (e.g. "127.0.0.1:9090"). +// If TLS is configured, wraps listener with tls.NewListener. +func (t *Transport) Listen(address string) error { + t.mu.Lock() + defer t.mu.Unlock() + + ln, err := net.Listen("tcp", address) + if err != nil { + return fmt.Errorf("tcp listen: %w", err) + } + + if t.tlsConfig != nil { + ln = tls.NewListener(ln, t.tlsConfig) + } + + t.listener = ln + return nil +} + +// Dial connects to a remote TCP address and returns a Connection. +// If TLS is configured, uses tls.Dial; also enables TCP keepalive. +func (t *Transport) Dial(address string) (transport.Connection, error) { + if t.tlsConfig != nil { + clientTLS := t.tlsConfig.Clone() + conn, err := tls.Dial("tcp", address, clientTLS) + if err != nil { + return nil, fmt.Errorf("tcp tls dial: %w", err) + } + return transport.NewConnWrapper(conn), nil + } + + conn, err := net.Dial("tcp", address) + if err != nil { + return nil, fmt.Errorf("tcp dial: %w", err) + } + if tc, ok := conn.(*net.TCPConn); ok { + _ = tc.SetKeepAlive(true) + _ = tc.SetKeepAlivePeriod(30 * time.Second) + } + return transport.NewConnWrapper(conn), nil +} + +// Accept waits for and returns the next inbound TCP connection. +// Enables keepalive on accepted TCP connections. +func (t *Transport) Accept() (transport.Connection, error) { + t.mu.Lock() + ln := t.listener + t.mu.Unlock() + + if ln == nil { + return nil, fmt.Errorf("tcp: not listening") + } + + conn, err := ln.Accept() + if err != nil { + return nil, fmt.Errorf("tcp accept: %w", err) + } + + if tc, ok := conn.(*net.TCPConn); ok { + _ = tc.SetKeepAlive(true) + _ = tc.SetKeepAlivePeriod(30 * time.Second) + } + + return transport.NewConnWrapper(conn), nil +} + +// Close shuts down the TCP listener. +func (t *Transport) Close() error { + t.mu.Lock() + defer t.mu.Unlock() + + if t.listener != nil { + return t.listener.Close() + } + return nil +} + +// Addr returns the listener's address. Returns "" if not listening. +func (t *Transport) Addr() string { + t.mu.Lock() + defer t.mu.Unlock() + if t.listener != nil { + return t.listener.Addr().String() + } + return "" +} + +// SetupTLSFromEnv configures TLS from environment variables. +// EIPC_TLS_CERT, EIPC_TLS_KEY for cert/key files. +// EIPC_TLS_CA for mTLS CA file. +// EIPC_TLS_AUTO_CERT=true for auto-generated self-signed cert. +func (t *Transport) SetupTLSFromEnv() error { + certFile := os.Getenv("EIPC_TLS_CERT") + keyFile := os.Getenv("EIPC_TLS_KEY") + + if certFile != "" && keyFile != "" { + return t.WithTLS(certFile, keyFile, os.Getenv("EIPC_TLS_CA")) + } + + if os.Getenv("EIPC_TLS_AUTO_CERT") == "true" { + cfg, err := AutoTLSConfig() + if err != nil { + return err + } + t.tlsConfig = cfg + return nil + } + + return nil +} diff --git a/transport/tcp/tls_config.go b/transport/tcp/tls_config.go index dce3e21..946a530 100644 --- a/transport/tcp/tls_config.go +++ b/transport/tcp/tls_config.go @@ -1,108 +1,108 @@ -// SPDX-License-Identifier: MIT -// Copyright (c) 2026 EoS Project - -package tcp - -import ( - "crypto/ecdsa" - "crypto/elliptic" - "crypto/rand" - "crypto/tls" - "crypto/x509" - "crypto/x509/pkix" - "encoding/pem" - "fmt" - "math/big" - "os" - "path/filepath" - "time" -) - -// GenerateSelfSignedCert creates a P-256 ECDSA self-signed certificate -// and returns the TLS certificate suitable for tls.Config. -func GenerateSelfSignedCert() (tls.Certificate, error) { - key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - if err != nil { - return tls.Certificate{}, fmt.Errorf("generate key: %w", err) - } - - serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) - if err != nil { - return tls.Certificate{}, fmt.Errorf("generate serial: %w", err) - } - - template := x509.Certificate{ - SerialNumber: serial, - Subject: pkix.Name{ - Organization: []string{"EoS EIPC"}, - CommonName: "eipc-server", - }, - NotBefore: time.Now().Add(-1 * time.Hour), - NotAfter: time.Now().Add(365 * 24 * time.Hour), - KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, - ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, - BasicConstraintsValid: true, - DNSNames: []string{"localhost"}, - } - - certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) - if err != nil { - return tls.Certificate{}, fmt.Errorf("create certificate: %w", err) - } - - certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) - - keyDER, err := x509.MarshalECPrivateKey(key) - if err != nil { - return tls.Certificate{}, fmt.Errorf("marshal key: %w", err) - } - keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) - - return tls.X509KeyPair(certPEM, keyPEM) -} - -// LoadTLSConfig builds a *tls.Config from cert/key/CA file paths. -// If caFile is provided, mTLS (RequireAndVerifyClientCert) is enabled. -func LoadTLSConfig(certFile, keyFile, caFile string) (*tls.Config, error) { - certFile = filepath.Clean(certFile) - keyFile = filepath.Clean(keyFile) - - cert, err := tls.LoadX509KeyPair(certFile, keyFile) - if err != nil { - return nil, fmt.Errorf("load key pair: %w", err) - } - - cfg := &tls.Config{ - Certificates: []tls.Certificate{cert}, - MinVersion: tls.VersionTLS12, - } - - if caFile != "" { - caFile = filepath.Clean(caFile) - caPEM, err := os.ReadFile(caFile) - if err != nil { - return nil, fmt.Errorf("read CA file: %w", err) - } - pool := x509.NewCertPool() - if !pool.AppendCertsFromPEM(caPEM) { - return nil, fmt.Errorf("failed to parse CA certificate") - } - cfg.ClientCAs = pool - cfg.ClientAuth = tls.RequireAndVerifyClientCert - } - - return cfg, nil -} - -// AutoTLSConfig generates a self-signed cert for development use. -// Controlled by EIPC_TLS_AUTO_CERT env var. -func AutoTLSConfig() (*tls.Config, error) { - cert, err := GenerateSelfSignedCert() - if err != nil { - return nil, err - } - return &tls.Config{ - Certificates: []tls.Certificate{cert}, - MinVersion: tls.VersionTLS12, - }, nil -} +// SPDX-License-Identifier: MIT +// Copyright (c) 2026 EoS Project + +package tcp + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "os" + "path/filepath" + "time" +) + +// GenerateSelfSignedCert creates a P-256 ECDSA self-signed certificate +// and returns the TLS certificate suitable for tls.Config. +func GenerateSelfSignedCert() (tls.Certificate, error) { + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return tls.Certificate{}, fmt.Errorf("generate key: %w", err) + } + + serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128)) + if err != nil { + return tls.Certificate{}, fmt.Errorf("generate serial: %w", err) + } + + template := x509.Certificate{ + SerialNumber: serial, + Subject: pkix.Name{ + Organization: []string{"EoS EIPC"}, + CommonName: "eipc-server", + }, + NotBefore: time.Now().Add(-1 * time.Hour), + NotAfter: time.Now().Add(365 * 24 * time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth, x509.ExtKeyUsageClientAuth}, + BasicConstraintsValid: true, + DNSNames: []string{"localhost"}, + } + + certDER, err := x509.CreateCertificate(rand.Reader, &template, &template, &key.PublicKey, key) + if err != nil { + return tls.Certificate{}, fmt.Errorf("create certificate: %w", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: certDER}) + + keyDER, err := x509.MarshalECPrivateKey(key) + if err != nil { + return tls.Certificate{}, fmt.Errorf("marshal key: %w", err) + } + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: keyDER}) + + return tls.X509KeyPair(certPEM, keyPEM) +} + +// LoadTLSConfig builds a *tls.Config from cert/key/CA file paths. +// If caFile is provided, mTLS (RequireAndVerifyClientCert) is enabled. +func LoadTLSConfig(certFile, keyFile, caFile string) (*tls.Config, error) { + certFile = filepath.Clean(certFile) + keyFile = filepath.Clean(keyFile) + + cert, err := tls.LoadX509KeyPair(certFile, keyFile) + if err != nil { + return nil, fmt.Errorf("load key pair: %w", err) + } + + cfg := &tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS12, + } + + if caFile != "" { + caFile = filepath.Clean(caFile) + caPEM, err := os.ReadFile(caFile) + if err != nil { + return nil, fmt.Errorf("read CA file: %w", err) + } + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(caPEM) { + return nil, fmt.Errorf("failed to parse CA certificate") + } + cfg.ClientCAs = pool + cfg.ClientAuth = tls.RequireAndVerifyClientCert + } + + return cfg, nil +} + +// AutoTLSConfig generates a self-signed cert for development use. +// Controlled by EIPC_TLS_AUTO_CERT env var. +func AutoTLSConfig() (*tls.Config, error) { + cert, err := GenerateSelfSignedCert() + if err != nil { + return nil, err + } + return &tls.Config{ + Certificates: []tls.Certificate{cert}, + MinVersion: tls.VersionTLS12, + }, nil +}