diff --git a/ai_api.go b/ai_api.go new file mode 100644 index 00000000..78cef557 --- /dev/null +++ b/ai_api.go @@ -0,0 +1,221 @@ +// Copyright 2026 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package casbin + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" +) + +// AIConfig contains configuration for AI API calls. +type AIConfig struct { + // Endpoint is the API endpoint (e.g., "https://api.openai.com/v1/chat/completions") + Endpoint string + // APIKey is the authentication key for the API + APIKey string + // Model is the model to use (e.g., "gpt-3.5-turbo", "gpt-4") + Model string + // Timeout for API requests (default: 30s) + Timeout time.Duration +} + +// aiMessage represents a message in the OpenAI chat format. +type aiMessage struct { + Role string `json:"role"` + Content string `json:"content"` +} + +// aiChatRequest represents the request to OpenAI chat completions API. +type aiChatRequest struct { + Model string `json:"model"` + Messages []aiMessage `json:"messages"` +} + +// aiChatResponse represents the response from OpenAI chat completions API. +type aiChatResponse struct { + Choices []struct { + Message aiMessage `json:"message"` + } `json:"choices"` + Error *struct { + Message string `json:"message"` + } `json:"error,omitempty"` +} + +// SetAIConfig sets the configuration for AI API calls. +func (e *Enforcer) SetAIConfig(config AIConfig) { + if config.Timeout == 0 { + config.Timeout = 30 * time.Second + } + e.aiConfig = config +} + +// Explain returns an AI-generated explanation of why Enforce returned a particular result. +// It calls the configured OpenAI-compatible API to generate a natural language explanation. +func (e *Enforcer) Explain(rvals ...interface{}) (string, error) { + if e.aiConfig.Endpoint == "" { + return "", errors.New("AI config not set, use SetAIConfig first") + } + + // Get enforcement result and matched rules + result, matchedRules, err := e.EnforceEx(rvals...) + if err != nil { + return "", fmt.Errorf("failed to enforce: %w", err) + } + + // Build context for AI + explainContext := e.buildExplainContext(rvals, result, matchedRules) + + // Call AI API + explanation, err := e.callAIAPI(explainContext) + if err != nil { + return "", fmt.Errorf("failed to get AI explanation: %w", err) + } + + return explanation, nil +} + +// buildExplainContext builds the context string for AI explanation. +func (e *Enforcer) buildExplainContext(rvals []interface{}, result bool, matchedRules []string) string { + var sb strings.Builder + + // Add request information + sb.WriteString("Authorization Request:\n") + sb.WriteString(fmt.Sprintf("Subject: %v\n", rvals[0])) + if len(rvals) > 1 { + sb.WriteString(fmt.Sprintf("Object: %v\n", rvals[1])) + } + if len(rvals) > 2 { + sb.WriteString(fmt.Sprintf("Action: %v\n", rvals[2])) + } + sb.WriteString(fmt.Sprintf("\nEnforcement Result: %v\n", result)) + + // Add matched rules + if len(matchedRules) > 0 { + sb.WriteString("\nMatched Policy Rules:\n") + for _, rule := range matchedRules { + sb.WriteString(fmt.Sprintf("- %s\n", rule)) + } + } else { + sb.WriteString("\nNo policy rules matched.\n") + } + + // Add model information + sb.WriteString("\nAccess Control Model:\n") + if m, ok := e.model["m"]; ok { + for key, ast := range m { + sb.WriteString(fmt.Sprintf("Matcher (%s): %s\n", key, ast.Value)) + } + } + if eff, ok := e.model["e"]; ok { + for key, ast := range eff { + sb.WriteString(fmt.Sprintf("Effect (%s): %s\n", key, ast.Value)) + } + } + + // Add all policies + policies, _ := e.GetPolicy() + if len(policies) > 0 { + sb.WriteString("\nAll Policy Rules:\n") + for _, policy := range policies { + sb.WriteString(fmt.Sprintf("- %s\n", strings.Join(policy, ", "))) + } + } + + return sb.String() +} + +// callAIAPI calls the configured AI API to get an explanation. +func (e *Enforcer) callAIAPI(explainContext string) (string, error) { + // Prepare the request + messages := []aiMessage{ + { + Role: "system", + Content: "You are an expert in access control and authorization systems. " + + "Explain why an authorization request was allowed or denied based on the " + + "provided access control model, policies, and enforcement result. " + + "Be clear, concise, and educational.", + }, + { + Role: "user", + Content: fmt.Sprintf("Please explain the following authorization decision:\n\n%s", explainContext), + }, + } + + reqBody := aiChatRequest{ + Model: e.aiConfig.Model, + Messages: messages, + } + + jsonData, err := json.Marshal(reqBody) + if err != nil { + return "", fmt.Errorf("failed to marshal request: %w", err) + } + + // Create HTTP request with context + reqCtx, cancel := context.WithTimeout(context.Background(), e.aiConfig.Timeout) + defer cancel() + + req, err := http.NewRequestWithContext(reqCtx, http.MethodPost, e.aiConfig.Endpoint, bytes.NewBuffer(jsonData)) + if err != nil { + return "", fmt.Errorf("failed to create request: %w", err) + } + + req.Header.Set("Content-Type", "application/json") + req.Header.Set("Authorization", "Bearer "+e.aiConfig.APIKey) + + // Execute request + client := &http.Client{} + + resp, err := client.Do(req) + if err != nil { + return "", fmt.Errorf("failed to execute request: %w", err) + } + defer resp.Body.Close() + + // Read response + body, err := io.ReadAll(resp.Body) + if err != nil { + return "", fmt.Errorf("failed to read response: %w", err) + } + + // Parse response + var chatResp aiChatResponse + if err := json.Unmarshal(body, &chatResp); err != nil { + return "", fmt.Errorf("failed to parse response: %w", err) + } + + // Check for API errors + if chatResp.Error != nil { + return "", fmt.Errorf("API error: %s", chatResp.Error.Message) + } + + if resp.StatusCode != http.StatusOK { + return "", fmt.Errorf("API returned status %d: %s", resp.StatusCode, string(body)) + } + + // Extract explanation + if len(chatResp.Choices) == 0 { + return "", errors.New("no response from AI") + } + + return chatResp.Choices[0].Message.Content, nil +} diff --git a/ai_api_test.go b/ai_api_test.go new file mode 100644 index 00000000..4ff0b4e9 --- /dev/null +++ b/ai_api_test.go @@ -0,0 +1,250 @@ +// Copyright 2026 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package casbin + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +// TestExplainWithoutConfig tests that Explain returns error when config is not set. +func TestExplainWithoutConfig(t *testing.T) { + e, err := NewEnforcer("examples/basic_model.conf", "examples/basic_policy.csv") + if err != nil { + t.Fatal(err) + } + + _, err = e.Explain("alice", "data1", "read") + if err == nil { + t.Error("Expected error when AI config is not set") + } + if !strings.Contains(err.Error(), "AI config not set") { + t.Errorf("Expected 'AI config not set' error, got: %v", err) + } +} + +// TestExplainWithMockAPI tests Explain with a mock OpenAI-compatible API. +func TestExplainWithMockAPI(t *testing.T) { + // Create a mock server that simulates OpenAI API + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Verify request + if r.Method != http.MethodPost { + t.Errorf("Expected POST request, got %s", r.Method) + } + if r.Header.Get("Content-Type") != "application/json" { + t.Errorf("Expected Content-Type: application/json, got %s", r.Header.Get("Content-Type")) + } + if !strings.HasPrefix(r.Header.Get("Authorization"), "Bearer ") { + t.Errorf("Expected Bearer token in Authorization header, got %s", r.Header.Get("Authorization")) + } + + // Parse request to verify structure + var req aiChatRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Errorf("Failed to decode request: %v", err) + } + + if req.Model != "gpt-3.5-turbo" { + t.Errorf("Expected model gpt-3.5-turbo, got %s", req.Model) + } + + if len(req.Messages) != 2 { + t.Errorf("Expected 2 messages, got %d", len(req.Messages)) + } + + // Send mock response + resp := aiChatResponse{ + Choices: []struct { + Message aiMessage `json:"message"` + }{ + { + Message: aiMessage{ + Role: "assistant", + Content: "The request was allowed because alice has read permission on data1 according to the policy rule.", + }, + }, + }, + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(resp) + })) + defer mockServer.Close() + + // Create enforcer + e, err := NewEnforcer("examples/basic_model.conf", "examples/basic_policy.csv") + if err != nil { + t.Fatal(err) + } + + // Set AI config with mock server + e.SetAIConfig(AIConfig{ + Endpoint: mockServer.URL, + APIKey: "test-api-key", + Model: "gpt-3.5-turbo", + Timeout: 5 * time.Second, + }) + + // Test explanation for allowed request + explanation, err := e.Explain("alice", "data1", "read") + if err != nil { + t.Fatalf("Failed to get explanation: %v", err) + } + + if explanation == "" { + t.Error("Expected non-empty explanation") + } + + if !strings.Contains(explanation, "allowed") { + t.Errorf("Expected explanation to mention 'allowed', got: %s", explanation) + } +} + +// TestExplainDenied tests Explain for a denied request. +func TestExplainDenied(t *testing.T) { + // Create a mock server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := aiChatResponse{ + Choices: []struct { + Message aiMessage `json:"message"` + }{ + { + Message: aiMessage{ + Role: "assistant", + Content: "The request was denied because there is no policy rule that allows alice to write to data1.", + }, + }, + }, + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(resp) + })) + defer mockServer.Close() + + // Create enforcer + e, err := NewEnforcer("examples/basic_model.conf", "examples/basic_policy.csv") + if err != nil { + t.Fatal(err) + } + + // Set AI config + e.SetAIConfig(AIConfig{ + Endpoint: mockServer.URL, + APIKey: "test-api-key", + Model: "gpt-3.5-turbo", + Timeout: 5 * time.Second, + }) + + // Test explanation for denied request + explanation, err := e.Explain("alice", "data1", "write") + if err != nil { + t.Fatalf("Failed to get explanation: %v", err) + } + + if explanation == "" { + t.Error("Expected non-empty explanation") + } + + if !strings.Contains(explanation, "denied") { + t.Errorf("Expected explanation to mention 'denied', got: %s", explanation) + } +} + +// TestExplainAPIError tests handling of API errors. +func TestExplainAPIError(t *testing.T) { + // Create a mock server that returns an error + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := aiChatResponse{ + Error: &struct { + Message string `json:"message"` + }{ + Message: "Invalid API key", + }, + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusUnauthorized) + json.NewEncoder(w).Encode(resp) + })) + defer mockServer.Close() + + // Create enforcer + e, err := NewEnforcer("examples/basic_model.conf", "examples/basic_policy.csv") + if err != nil { + t.Fatal(err) + } + + // Set AI config + e.SetAIConfig(AIConfig{ + Endpoint: mockServer.URL, + APIKey: "invalid-key", + Model: "gpt-3.5-turbo", + Timeout: 5 * time.Second, + }) + + // Test that API error is properly handled + _, err = e.Explain("alice", "data1", "read") + if err == nil { + t.Error("Expected error for API failure") + } + if !strings.Contains(err.Error(), "Invalid API key") { + t.Errorf("Expected API error message, got: %v", err) + } +} + +// TestBuildExplainContext tests the context building function. +func TestBuildExplainContext(t *testing.T) { + e, err := NewEnforcer("examples/basic_model.conf", "examples/basic_policy.csv") + if err != nil { + t.Fatal(err) + } + + // Test with matched rules + rvals := []interface{}{"alice", "data1", "read"} + result := true + matchedRules := []string{"alice, data1, read"} + + context := e.buildExplainContext(rvals, result, matchedRules) + + // Verify context contains expected elements + if !strings.Contains(context, "alice") { + t.Error("Context should contain subject 'alice'") + } + if !strings.Contains(context, "data1") { + t.Error("Context should contain object 'data1'") + } + if !strings.Contains(context, "read") { + t.Error("Context should contain action 'read'") + } + if !strings.Contains(context, "true") { + t.Error("Context should contain result 'true'") + } + if !strings.Contains(context, "alice, data1, read") { + t.Error("Context should contain matched rule") + } + + // Test with no matched rules + context2 := e.buildExplainContext(rvals, false, []string{}) + if !strings.Contains(context2, "No policy rules matched") { + t.Error("Context should indicate no matched rules") + } +} diff --git a/enforcer.go b/enforcer.go index ff8c5431..a6bf1740 100644 --- a/enforcer.go +++ b/enforcer.go @@ -56,6 +56,8 @@ type Enforcer struct { autoNotifyWatcher bool autoNotifyDispatcher bool acceptJsonRequest bool + + aiConfig AIConfig } // EnforceContext is used as the first element of the parameter "rvals" in method "enforce". diff --git a/enforcer_interface.go b/enforcer_interface.go index 73365318..94baf84e 100644 --- a/enforcer_interface.go +++ b/enforcer_interface.go @@ -41,6 +41,7 @@ type IEnforcer interface { GetRoleManager() rbac.RoleManager SetRoleManager(rm rbac.RoleManager) SetEffector(eft effector.Effector) + SetAIConfig(config AIConfig) ClearPolicy() LoadPolicy() error LoadFilteredPolicy(filter interface{}) error @@ -58,6 +59,7 @@ type IEnforcer interface { EnforceExWithMatcher(matcher string, rvals ...interface{}) (bool, []string, error) BatchEnforce(requests [][]interface{}) ([]bool, error) BatchEnforceWithMatcher(matcher string, requests [][]interface{}) ([]bool, error) + Explain(rvals ...interface{}) (string, error) /* RBAC API */ GetRolesForUser(name string, domain ...string) ([]string, error)