diff --git a/ai_api.go b/ai_api.go index 78cef557..cce6b1ae 100644 --- a/ai_api.go +++ b/ai_api.go @@ -145,18 +145,23 @@ func (e *Enforcer) buildExplainContext(rvals []interface{}, result bool, matched // callAIAPI calls the configured AI API to get an explanation. func (e *Enforcer) callAIAPI(explainContext string) (string, error) { + return e.callAIAPIWithSystemPrompt(explainContext, "You are an expert in access control and authorization systems. "+ + "Explain why an authorization request was allowed or denied based on the "+ + "provided access control model, policies, and enforcement result. "+ + "Be clear, concise, and educational.") +} + +// callAIAPIWithSystemPrompt calls the configured AI API with a custom system prompt. +func (e *Enforcer) callAIAPIWithSystemPrompt(userContent, systemPrompt string) (string, error) { // Prepare the request messages := []aiMessage{ { - Role: "system", - Content: "You are an expert in access control and authorization systems. " + - "Explain why an authorization request was allowed or denied based on the " + - "provided access control model, policies, and enforcement result. " + - "Be clear, concise, and educational.", + Role: "system", + Content: systemPrompt, }, { Role: "user", - Content: fmt.Sprintf("Please explain the following authorization decision:\n\n%s", explainContext), + Content: userContent, }, } @@ -219,3 +224,51 @@ func (e *Enforcer) callAIAPI(explainContext string) (string, error) { return chatResp.Choices[0].Message.Content, nil } + +// evaluateAIPolicy evaluates an AI policy by calling the configured LLM API. +// It returns true if the AI policy allows the request, false otherwise. +func (e *Enforcer) evaluateAIPolicy(policyPrompt string, rvals []interface{}) (bool, error) { + if e.aiConfig.Endpoint == "" { + return false, errors.New("AI config not set, use SetAIConfig first") + } + + // Build context for AI + var sb strings.Builder + sb.WriteString("Authorization Request:\n") + if len(rvals) > 0 { + sb.WriteString(fmt.Sprintf("Subject: %v\n", rvals[0])) + } + if len(rvals) > 1 { + sb.WriteString(fmt.Sprintf("Object: %v\n", rvals[1])) + } + if len(rvals) > 2 { + sb.WriteString(fmt.Sprintf("Action: %v\n", rvals[2])) + } + + sb.WriteString(fmt.Sprintf("\nAI Policy Rule: %s\n", policyPrompt)) + sb.WriteString("\nQuestion: Does this request satisfy the AI policy rule? Answer with 'ALLOW' if yes, 'DENY' if no.") + + // Call AI API + systemPrompt := "You are an AI security policy evaluator. " + + "Your task is to determine if an authorization request satisfies the given AI policy rule. " + + "Respond with ONLY the word 'ALLOW' or 'DENY' based on your evaluation." + + response, err := e.callAIAPIWithSystemPrompt(sb.String(), systemPrompt) + if err != nil { + return false, fmt.Errorf("failed to evaluate AI policy: %w", err) + } + + // Parse response + response = strings.TrimSpace(strings.ToUpper(response)) + // More robust parsing: check if response starts with ALLOW or DENY + // to avoid false positives like "I cannot ALLOW this" + if strings.HasPrefix(response, "ALLOW") { + return true, nil + } + if strings.HasPrefix(response, "DENY") { + return false, nil + } + + // If response doesn't clearly start with ALLOW or DENY, deny by default for safety + return false, nil +} diff --git a/ai_policy_api.go b/ai_policy_api.go new file mode 100644 index 00000000..2396ff2e --- /dev/null +++ b/ai_policy_api.go @@ -0,0 +1,123 @@ +// Copyright 2026 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package casbin + +// GetAIPolicy gets all the AI policy rules in the policy. +func (e *Enforcer) GetAIPolicy() ([][]string, error) { + return e.GetNamedAIPolicy("a") +} + +// GetFilteredAIPolicy gets all the AI policy rules in the policy, field filters can be specified. +func (e *Enforcer) GetFilteredAIPolicy(fieldIndex int, fieldValues ...string) ([][]string, error) { + return e.GetFilteredNamedAIPolicy("a", fieldIndex, fieldValues...) +} + +// GetNamedAIPolicy gets all the AI policy rules in the named policy. +func (e *Enforcer) GetNamedAIPolicy(ptype string) ([][]string, error) { + return e.model.GetPolicy("a", ptype) +} + +// GetFilteredNamedAIPolicy gets all the AI policy rules in the named policy, field filters can be specified. +func (e *Enforcer) GetFilteredNamedAIPolicy(ptype string, fieldIndex int, fieldValues ...string) ([][]string, error) { + return e.model.GetFilteredPolicy("a", ptype, fieldIndex, fieldValues...) +} + +// HasAIPolicy determines whether an AI policy rule exists. +func (e *Enforcer) HasAIPolicy(params ...string) (bool, error) { + return e.HasNamedAIPolicy("a", params...) +} + +// HasNamedAIPolicy determines whether a named AI policy rule exists. +func (e *Enforcer) HasNamedAIPolicy(ptype string, params ...string) (bool, error) { + return e.model.HasPolicy("a", ptype, params) +} + +// AddAIPolicy adds an AI policy rule to the current policy. +// If the rule already exists, the function returns false and the rule will not be added. +// Otherwise the function returns true by adding the new rule. +func (e *Enforcer) AddAIPolicy(params ...string) (bool, error) { + return e.AddNamedAIPolicy("a", params...) +} + +// AddAIPolicies adds AI policy rules to the current policy. +// If the rule already exists, the function returns false for the corresponding rule and the rule will not be added. +// Otherwise the function returns true for the corresponding rule by adding the new rule. +func (e *Enforcer) AddAIPolicies(rules [][]string) (bool, error) { + return e.AddNamedAIPolicies("a", rules) +} + +// AddNamedAIPolicy adds an AI policy rule to the current named policy. +// If the rule already exists, the function returns false and the rule will not be added. +// Otherwise the function returns true by adding the new rule. +func (e *Enforcer) AddNamedAIPolicy(ptype string, params ...string) (bool, error) { + return e.addPolicy("a", ptype, params) +} + +// AddNamedAIPolicies adds AI policy rules to the current named policy. +// If the rule already exists, the function returns false for the corresponding policy rule and the rule will not be added. +// Otherwise the function returns true for the corresponding policy rule by adding the new rule. +func (e *Enforcer) AddNamedAIPolicies(ptype string, rules [][]string) (bool, error) { + return e.addPolicies("a", ptype, rules, false) +} + +// RemoveAIPolicy removes an AI policy rule from the current policy. +func (e *Enforcer) RemoveAIPolicy(params ...string) (bool, error) { + return e.RemoveNamedAIPolicy("a", params...) +} + +// RemoveAIPolicies removes AI policy rules from the current policy. +func (e *Enforcer) RemoveAIPolicies(rules [][]string) (bool, error) { + return e.RemoveNamedAIPolicies("a", rules) +} + +// RemoveFilteredAIPolicy removes an AI policy rule from the current policy, field filters can be specified. +func (e *Enforcer) RemoveFilteredAIPolicy(fieldIndex int, fieldValues ...string) (bool, error) { + return e.RemoveFilteredNamedAIPolicy("a", fieldIndex, fieldValues...) +} + +// RemoveNamedAIPolicy removes an AI policy rule from the current named policy. +func (e *Enforcer) RemoveNamedAIPolicy(ptype string, params ...string) (bool, error) { + return e.removePolicy("a", ptype, params) +} + +// RemoveNamedAIPolicies removes AI policy rules from the current named policy. +func (e *Enforcer) RemoveNamedAIPolicies(ptype string, rules [][]string) (bool, error) { + return e.removePolicies("a", ptype, rules) +} + +// RemoveFilteredNamedAIPolicy removes an AI policy rule from the current named policy, field filters can be specified. +func (e *Enforcer) RemoveFilteredNamedAIPolicy(ptype string, fieldIndex int, fieldValues ...string) (bool, error) { + return e.removeFilteredPolicy("a", ptype, fieldIndex, fieldValues) +} + +// UpdateAIPolicy updates an AI policy rule from the current policy. +func (e *Enforcer) UpdateAIPolicy(oldPolicy []string, newPolicy []string) (bool, error) { + return e.UpdateNamedAIPolicy("a", oldPolicy, newPolicy) +} + +// UpdateAIPolicies updates AI policy rules from the current policy. +func (e *Enforcer) UpdateAIPolicies(oldPolicies [][]string, newPolicies [][]string) (bool, error) { + return e.UpdateNamedAIPolicies("a", oldPolicies, newPolicies) +} + +// UpdateNamedAIPolicy updates an AI policy rule from the current named policy. +func (e *Enforcer) UpdateNamedAIPolicy(ptype string, oldPolicy []string, newPolicy []string) (bool, error) { + return e.updatePolicy("a", ptype, oldPolicy, newPolicy) +} + +// UpdateNamedAIPolicies updates AI policy rules from the current named policy. +func (e *Enforcer) UpdateNamedAIPolicies(ptype string, oldPolicies [][]string, newPolicies [][]string) (bool, error) { + return e.updatePolicies("a", ptype, oldPolicies, newPolicies) +} diff --git a/ai_policy_api_test.go b/ai_policy_api_test.go new file mode 100644 index 00000000..44d75c88 --- /dev/null +++ b/ai_policy_api_test.go @@ -0,0 +1,437 @@ +// Copyright 2026 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package casbin + +import ( + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +// TestAIPolicyManagementAPI tests the management APIs for AI policies. +func TestAIPolicyManagementAPI(t *testing.T) { + e, err := NewEnforcer("examples/ai_policy_model.conf", "examples/ai_policy.csv") + if err != nil { + t.Fatal(err) + } + + // Test GetAIPolicy + policies, err := e.GetAIPolicy() + if err != nil { + t.Fatalf("GetAIPolicy failed: %v", err) + } + if len(policies) != 2 { + t.Errorf("Expected 2 AI policies, got %d", len(policies)) + } + + // Test HasAIPolicy + has, err := e.HasAIPolicy("allow US residential IPs to read data1") + if err != nil { + t.Fatalf("HasAIPolicy failed: %v", err) + } + if !has { + t.Error("Expected AI policy to exist") + } + + // Test AddAIPolicy + added, err := e.AddAIPolicy("deny requests with suspicious patterns") + if err != nil { + t.Fatalf("AddAIPolicy failed: %v", err) + } + if !added { + t.Error("Expected AI policy to be added") + } + + // Verify the policy was added + policies, err = e.GetAIPolicy() + if err != nil { + t.Fatalf("GetAIPolicy failed: %v", err) + } + if len(policies) != 3 { + t.Errorf("Expected 3 AI policies after adding, got %d", len(policies)) + } + + // Test AddAIPolicy with duplicate (should not add) + added, err = e.AddAIPolicy("deny requests with suspicious patterns") + if err != nil { + t.Fatalf("AddAIPolicy failed: %v", err) + } + if added { + t.Error("Expected duplicate AI policy not to be added") + } + + // Test RemoveAIPolicy + removed, err := e.RemoveAIPolicy("deny requests with suspicious patterns") + if err != nil { + t.Fatalf("RemoveAIPolicy failed: %v", err) + } + if !removed { + t.Error("Expected AI policy to be removed") + } + + // Verify the policy was removed + policies, err = e.GetAIPolicy() + if err != nil { + t.Fatalf("GetAIPolicy failed: %v", err) + } + if len(policies) != 2 { + t.Errorf("Expected 2 AI policies after removing, got %d", len(policies)) + } +} + +// TestAIPolicyBulkOperations tests bulk operations for AI policies. +func TestAIPolicyBulkOperations(t *testing.T) { + e, err := NewEnforcer("examples/ai_policy_model.conf") + if err != nil { + t.Fatal(err) + } + + // Test AddAIPolicies + rules := [][]string{ + {"allow authenticated users to read public data"}, + {"deny anonymous users from writing sensitive data"}, + {"allow admin users all access"}, + } + + added, err := e.AddAIPolicies(rules) + if err != nil { + t.Fatalf("AddAIPolicies failed: %v", err) + } + if !added { + t.Error("Expected AI policies to be added") + } + + // Verify the policies were added + policies, err := e.GetAIPolicy() + if err != nil { + t.Fatalf("GetAIPolicy failed: %v", err) + } + if len(policies) != 3 { + t.Errorf("Expected 3 AI policies, got %d", len(policies)) + } + + // Test RemoveAIPolicies + removeRules := [][]string{ + {"allow authenticated users to read public data"}, + {"deny anonymous users from writing sensitive data"}, + } + + removed, err := e.RemoveAIPolicies(removeRules) + if err != nil { + t.Fatalf("RemoveAIPolicies failed: %v", err) + } + if !removed { + t.Error("Expected AI policies to be removed") + } + + // Verify the policies were removed + policies, err = e.GetAIPolicy() + if err != nil { + t.Fatalf("GetAIPolicy failed: %v", err) + } + if len(policies) != 1 { + t.Errorf("Expected 1 AI policy after removing, got %d", len(policies)) + } +} + +// TestAIPolicyUpdate tests updating AI policies. +func TestAIPolicyUpdate(t *testing.T) { + e, err := NewEnforcer("examples/ai_policy_model.conf") + if err != nil { + t.Fatal(err) + } + + // Add a policy first + _, err = e.AddAIPolicy("allow read access to public data") + if err != nil { + t.Fatalf("AddAIPolicy failed: %v", err) + } + + // Update the policy + updated, err := e.UpdateAIPolicy( + []string{"allow read access to public data"}, + []string{"allow read and write access to public data"}, + ) + if err != nil { + t.Fatalf("UpdateAIPolicy failed: %v", err) + } + if !updated { + t.Error("Expected AI policy to be updated") + } + + // Verify the policy was updated + has, err := e.HasAIPolicy("allow read and write access to public data") + if err != nil { + t.Fatalf("HasAIPolicy failed: %v", err) + } + if !has { + t.Error("Expected updated AI policy to exist") + } + + // Verify old policy doesn't exist + has, err = e.HasAIPolicy("allow read access to public data") + if err != nil { + t.Fatalf("HasAIPolicy failed: %v", err) + } + if has { + t.Error("Expected old AI policy not to exist") + } +} + +// TestAIPolicyEnforcement tests AI policy enforcement with a mock LLM API. +func TestAIPolicyEnforcement(t *testing.T) { + // Create a mock server that simulates LLM API + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // Parse request + var req aiChatRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + t.Errorf("Failed to decode request: %v", err) + } + + // Determine response based on request content + userMessage := req.Messages[1].Content + var responseContent string + + if strings.Contains(userMessage, "192.168.2.1") && strings.Contains(userMessage, "data1") && strings.Contains(userMessage, "read") { + if strings.Contains(userMessage, "allow US residential IPs to read data1") { + responseContent = "ALLOW" + } else { + responseContent = "DENY" + } + } else if strings.Contains(userMessage, "credential") || strings.Contains(userMessage, "secret") { + responseContent = "DENY" + } else { + responseContent = "DENY" + } + + resp := aiChatResponse{ + Choices: []struct { + Message aiMessage `json:"message"` + }{ + { + Message: aiMessage{ + Role: "assistant", + Content: responseContent, + }, + }, + }, + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(resp) + })) + defer mockServer.Close() + + // Create enforcer with AI policy + e, err := NewEnforcer("examples/ai_policy_model.conf") + if err != nil { + t.Fatal(err) + } + + // Set AI config + e.SetAIConfig(AIConfig{ + Endpoint: mockServer.URL, + APIKey: "test-api-key", + Model: "gpt-3.5-turbo", + Timeout: 5 * time.Second, + }) + + // Add an AI policy + _, err = e.AddAIPolicy("allow US residential IPs to read data1") + if err != nil { + t.Fatalf("AddAIPolicy failed: %v", err) + } + + // Test enforcement - should be allowed by AI policy + allowed, err := e.Enforce("192.168.2.1", "data1", "read") + if err != nil { + t.Fatalf("Enforce failed: %v", err) + } + if !allowed { + t.Error("Expected request to be allowed by AI policy") + } +} + +// TestAIPolicyWithoutAIConfig tests that enforcement works when AI config is not set. +func TestAIPolicyWithoutAIConfig(t *testing.T) { + e, err := NewEnforcer("examples/ai_policy_model.conf") + if err != nil { + t.Fatal(err) + } + + // Add a traditional policy first (using IP addresses since the model uses ipMatch) + _, err = e.AddPolicy("192.168.1.0/24", "data1", "read") + if err != nil { + t.Fatalf("AddPolicy failed: %v", err) + } + + // Add an AI policy without setting AI config + _, err = e.AddAIPolicy("allow all requests") + if err != nil { + t.Fatalf("AddAIPolicy failed: %v", err) + } + + // Test enforcement - should fall through to traditional policies since AI evaluation fails + // 192.168.1.5 has permission to read data1 from the policy we just added + allowed, err := e.Enforce("192.168.1.5", "data1", "read") + if err != nil { + t.Fatalf("Enforce failed: %v", err) + } + // Without AI config, AI policy evaluation will fail and fall through to traditional policies + if !allowed { + t.Error("Expected request to be allowed by traditional policy when AI config is not set") + } + + // Test a request that should be denied + allowed, err = e.Enforce("192.168.1.5", "data2", "write") + if err != nil { + t.Fatalf("Enforce failed: %v", err) + } + if allowed { + t.Error("Expected request to be denied when no matching policy exists") + } +} + +// TestAIPolicyWithTraditionalPolicies tests AI policies working alongside traditional policies. +func TestAIPolicyWithTraditionalPolicies(t *testing.T) { + // Create a mock server + mockServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := aiChatResponse{ + Choices: []struct { + Message aiMessage `json:"message"` + }{ + { + Message: aiMessage{ + Role: "assistant", + Content: "ALLOW", + }, + }, + }, + } + + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusOK) + json.NewEncoder(w).Encode(resp) + })) + defer mockServer.Close() + + e, err := NewEnforcer("examples/ai_policy_model.conf") + if err != nil { + t.Fatal(err) + } + + // Set AI config + e.SetAIConfig(AIConfig{ + Endpoint: mockServer.URL, + APIKey: "test-api-key", + Model: "gpt-3.5-turbo", + Timeout: 5 * time.Second, + }) + + // Add both traditional and AI policies + _, err = e.AddPolicy("alice", "data1", "read") + if err != nil { + t.Fatalf("AddPolicy failed: %v", err) + } + + _, err = e.AddAIPolicy("allow all authenticated users") + if err != nil { + t.Fatalf("AddAIPolicy failed: %v", err) + } + + // Test enforcement - AI policy is checked first + allowed, err := e.Enforce("bob", "data2", "write") + if err != nil { + t.Fatalf("Enforce failed: %v", err) + } + if !allowed { + t.Error("Expected request to be allowed by AI policy") + } +} + +// TestGetFilteredAIPolicy tests filtering AI policies. +func TestGetFilteredAIPolicy(t *testing.T) { + e, err := NewEnforcer("examples/ai_policy_model.conf") + if err != nil { + t.Fatal(err) + } + + // Add multiple AI policies + rules := [][]string{ + {"allow read access"}, + {"allow write access"}, + {"deny delete access"}, + } + + _, err = e.AddAIPolicies(rules) + if err != nil { + t.Fatalf("AddAIPolicies failed: %v", err) + } + + // Test filtering + filtered, err := e.GetFilteredAIPolicy(0, "allow read access") + if err != nil { + t.Fatalf("GetFilteredAIPolicy failed: %v", err) + } + if len(filtered) != 1 { + t.Errorf("Expected 1 filtered AI policy, got %d", len(filtered)) + } + if filtered[0][0] != "allow read access" { + t.Errorf("Expected 'allow read access', got %s", filtered[0][0]) + } +} + +// TestRemoveFilteredAIPolicy tests removing filtered AI policies. +func TestRemoveFilteredAIPolicy(t *testing.T) { + e, err := NewEnforcer("examples/ai_policy_model.conf") + if err != nil { + t.Fatal(err) + } + + // Add multiple AI policies + rules := [][]string{ + {"allow read access to public data"}, + {"allow read access to private data"}, + {"deny write access"}, + } + + _, err = e.AddAIPolicies(rules) + if err != nil { + t.Fatalf("AddAIPolicies failed: %v", err) + } + + // Remove policies that start with "allow read" + // Note: This removes based on exact match at the specified field index + removed, err := e.RemoveFilteredAIPolicy(0, "allow read access to public data") + if err != nil { + t.Fatalf("RemoveFilteredAIPolicy failed: %v", err) + } + if !removed { + t.Error("Expected AI policies to be removed") + } + + // Verify + policies, err := e.GetAIPolicy() + if err != nil { + t.Fatalf("GetAIPolicy failed: %v", err) + } + if len(policies) != 2 { + t.Errorf("Expected 2 AI policies after removal, got %d", len(policies)) + } +} diff --git a/enforcer.go b/enforcer.go index a6bf1740..8cc00654 100644 --- a/enforcer.go +++ b/enforcer.go @@ -20,6 +20,7 @@ import ( "runtime/debug" "strings" "sync" + "time" "github.com/casbin/casbin/v3/detector" "github.com/casbin/casbin/v3/effector" @@ -679,6 +680,58 @@ func (e *Enforcer) invalidateMatcherMap() { e.matcherMap = sync.Map{} } +// checkAIPolicies evaluates AI policies and returns true if any policy allows the request. +func (e *Enforcer) checkAIPolicies(rvals []interface{}) (bool, error) { + aType := "a" + + // Check if AI policies exist + if _, ok := e.model["a"]; !ok { + return false, nil + } + + aPolicies, ok := e.model["a"][aType] + if !ok || len(aPolicies.Policy) == 0 { + return false, nil + } + + // Evaluate AI policies + for _, aPolicy := range aPolicies.Policy { + if len(aPolicy) == 0 { + continue + } + + // The AI policy prompt is the first (and typically only) field + policyPrompt := aPolicy[0] + allowed, err := e.evaluateAIPolicy(policyPrompt, rvals) + if err != nil { + // If AI evaluation fails, log the error and continue with other AI policies + e.logAIPolicyError(err) + continue + } + + if allowed { + return true, nil + } + } + + return false, nil +} + +// logAIPolicyError logs AI policy evaluation errors. +func (e *Enforcer) logAIPolicyError(err error) { + if e.logger == nil { + return + } + + logEntry := &log.LogEntry{ + EventType: "ai_policy_evaluation_error", + Error: err, + StartTime: time.Now(), + EndTime: time.Now(), + } + _ = e.logger.OnAfterEvent(logEntry) +} + // enforce use a custom matcher to decides whether a "subject" can access a "object" with the operation "action", input parameters are usually: (matcher, sub, obj, act), use model matcher by default when matcher is "". func (e *Enforcer) enforce(matcher string, explains *[]string, rvals ...interface{}) (ok bool, err error) { //nolint:funlen,cyclop,gocyclo // TODO: reduce function complexity logEntry := e.onLogBeforeEventInEnforce(rvals) @@ -799,6 +852,15 @@ func (e *Enforcer) enforce(matcher string, explains *[]string, rvals ...interfac var effect effector.Effect var explainIndex int + // Check AI policies first if they exist + aiPolicyAllowed, err := e.checkAIPolicies(rvals) + if err != nil { + return false, err + } + if aiPolicyAllowed { + return true, nil + } + if policyLen := len(e.model["p"][pType].Policy); policyLen != 0 && strings.Contains(expString, pType+"_") { //nolint:nestif // TODO: reduce function complexity policyEffects = make([]effector.Effect, policyLen) matcherResults = make([]float64, policyLen) diff --git a/examples/ai_policy.csv b/examples/ai_policy.csv new file mode 100644 index 00000000..da2c13fe --- /dev/null +++ b/examples/ai_policy.csv @@ -0,0 +1,4 @@ +p, 192.168.2.0/24, data1, read +p, 10.0.0.0/16, data2, write +a, "allow US residential IPs to read data1" +a, "allow global cloud IPs to write data2" diff --git a/examples/ai_policy_model.conf b/examples/ai_policy_model.conf new file mode 100644 index 00000000..61fa693c --- /dev/null +++ b/examples/ai_policy_model.conf @@ -0,0 +1,14 @@ +[request_definition] +r = sub, obj, act + +[policy_definition] +p = sub, obj, act + +[ai_policy_definition] +a = prompt + +[policy_effect] +e = some(where (p.eft == allow)) + +[matchers] +m = ipMatch(r.sub, p.sub) && r.obj == p.obj && r.act == p.act diff --git a/model/model.go b/model/model.go index b541e1b8..826464b8 100644 --- a/model/model.go +++ b/model/model.go @@ -44,6 +44,7 @@ var sectionNameMap = map[string]string{ "e": "policy_effect", "m": "matchers", "c": "constraint_definition", + "a": "ai_policy_definition", } // Minimal required sections for a model to be valid. @@ -78,7 +79,7 @@ func (model Model) AddDef(sec string, key string, value string) bool { ast.PolicyMap = make(map[string]int) ast.FieldIndexMap = make(map[string]int) - if sec == "r" || sec == "p" { + if sec == "r" || sec == "p" || sec == "a" { ast.Tokens = strings.Split(ast.Value, ",") for i := range ast.Tokens { ast.Tokens[i] = key + "_" + strings.TrimSpace(ast.Tokens[i]) diff --git a/model/policy.go b/model/policy.go index e55bf410..6843c262 100644 --- a/model/policy.go +++ b/model/policy.go @@ -105,6 +105,11 @@ func (model Model) ClearPolicy() { ast.Policy = nil ast.PolicyMap = map[string]int{} } + + for _, ast := range model["a"] { + ast.Policy = nil + ast.PolicyMap = map[string]int{} + } } // GetPolicy gets all rules in a policy.