diff --git a/ai_policy_test.go b/ai_policy_test.go new file mode 100644 index 000000000..3c3a8f050 --- /dev/null +++ b/ai_policy_test.go @@ -0,0 +1,110 @@ +// Copyright 2024 The casbin Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package casbin + +import ( + "testing" + + fileadapter "github.com/casbin/casbin/v3/persist/file-adapter" + "github.com/casbin/casbin/v3/util" +) + +func TestAIPolicyLoad(t *testing.T) { + e, err := NewEnforcer("examples/ai_policy_model.conf", "examples/ai_policy.csv") + if err != nil { + t.Fatal(err) + } + + // Test that regular policies are loaded + policies, err := e.GetPolicy() + if err != nil { + t.Fatal(err) + } + + expectedPolicies := [][]string{ + {"alice", "data1", "read", "09:00", "18:00"}, + {"bob", "data2", "write", "13:00", "16:00"}, + } + + if !util.Array2DEquals(expectedPolicies, policies) { + t.Errorf("Policies = %v, want %v", policies, expectedPolicies) + } + + // Test that grouping policies are loaded + groupingPolicies, err := e.GetGroupingPolicy() + if err != nil { + t.Fatal(err) + } + + expectedGrouping := [][]string{ + {"cathy", "alice"}, + } + + if !util.Array2DEquals(expectedGrouping, groupingPolicies) { + t.Errorf("Grouping policies = %v, want %v", groupingPolicies, expectedGrouping) + } + + // Test that AI policies are loaded + aiPolicies, err := e.model.GetPolicy("a", "ai") + if err != nil { + t.Fatal(err) + } + + expectedAI := [][]string{ + {`if the request object contains anything like credential/secret leak, then deny`}, + } + + if !util.Array2DEquals(expectedAI, aiPolicies) { + t.Errorf("AI policies = %v, want %v", aiPolicies, expectedAI) + } +} + +func TestAIPolicySave(t *testing.T) { + // Create a temporary file for testing + tmpFile := t.TempDir() + "/ai_policy_test.csv" + + e, err := NewEnforcer("examples/ai_policy_model.conf", "examples/ai_policy.csv") + if err != nil { + t.Fatal(err) + } + + // Update adapter to save to temp file + e.SetAdapter(fileadapter.NewAdapter(tmpFile)) + // Save to the temporary file + err = e.SavePolicy() + if err != nil { + t.Fatal(err) + } + + // Load from the saved file + e2, err := NewEnforcer("examples/ai_policy_model.conf", tmpFile) + if err != nil { + t.Fatal(err) + } + + // Verify AI policies are preserved + aiPolicies, err := e2.model.GetPolicy("a", "ai") + if err != nil { + t.Fatal(err) + } + + expectedAI := [][]string{ + {`if the request object contains anything like credential/secret leak, then deny`}, + } + + if !util.Array2DEquals(expectedAI, aiPolicies) { + t.Errorf("AI policies after save/load = %v, want %v", aiPolicies, expectedAI) + } +} diff --git a/demo_ai_policy b/demo_ai_policy new file mode 100755 index 000000000..0df5d3cc3 Binary files /dev/null and b/demo_ai_policy differ diff --git a/examples/ai_policy.csv b/examples/ai_policy.csv new file mode 100644 index 000000000..a070ce96f --- /dev/null +++ b/examples/ai_policy.csv @@ -0,0 +1,6 @@ +p, alice, data1, read, 09:00, 18:00 +p, bob, data2, write, 13:00, 16:00 + +g, cathy, alice + +ai, "if the request object contains anything like credential/secret leak, then deny" diff --git a/examples/ai_policy_model.conf b/examples/ai_policy_model.conf new file mode 100644 index 000000000..3fef53cb8 --- /dev/null +++ b/examples/ai_policy_model.conf @@ -0,0 +1,17 @@ +[request_definition] +r = sub, obj, act, time + +[policy_definition] +p = sub, obj, act, time_start, time_end + +[role_definition] +g = _, _ + +[ai_definition] +ai = rule + +[policy_effect] +e = some(where (p.eft == allow)) + +[matchers] +m = g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act && r.time >= p.time_start && r.time <= p.time_end diff --git a/examples/basic_policy.csv b/examples/basic_policy.csv index 57aaa9760..60da4b08c 100644 --- a/examples/basic_policy.csv +++ b/examples/basic_policy.csv @@ -1,2 +1,2 @@ -p, alice, data1, read -p, bob, data2, write \ No newline at end of file +p,alice,data1,read +p,bob,data2,write \ No newline at end of file diff --git a/examples/rbac_policy.csv b/examples/rbac_policy.csv index f93d6df81..0779945d4 100644 --- a/examples/rbac_policy.csv +++ b/examples/rbac_policy.csv @@ -1,5 +1,5 @@ -p, alice, data1, read -p, bob, data2, write -p, data2_admin, data2, read -p, data2_admin, data2, write -g, alice, data2_admin \ No newline at end of file +p,alice,data1,read +p,bob,data2,write +p,data2_admin,data2,read +p,data2_admin,data2,write +g,alice,data2_admin \ No newline at end of file diff --git a/examples/rbac_with_domains_policy.csv b/examples/rbac_with_domains_policy.csv index 8558d171c..3ccd99bb8 100644 --- a/examples/rbac_with_domains_policy.csv +++ b/examples/rbac_with_domains_policy.csv @@ -1,6 +1,6 @@ -p, admin, domain1, data1, read -p, admin, domain1, data1, write -p, admin, domain2, data2, read -p, admin, domain2, data2, write -g, alice, admin, domain1 -g, bob, admin, domain2 \ No newline at end of file +p,admin,domain1,data1,read +p,admin,domain1,data1,write +p,admin,domain2,data2,read +p,admin,domain2,data2,write +g,alice,admin,domain1 +g,bob,admin,domain2 \ No newline at end of file diff --git a/model/model.go b/model/model.go index b541e1b84..e3b100b75 100644 --- a/model/model.go +++ b/model/model.go @@ -44,6 +44,7 @@ var sectionNameMap = map[string]string{ "e": "policy_effect", "m": "matchers", "c": "constraint_definition", + "a": "ai_definition", } // Minimal required sections for a model to be valid. @@ -78,7 +79,7 @@ func (model Model) AddDef(sec string, key string, value string) bool { ast.PolicyMap = make(map[string]int) ast.FieldIndexMap = make(map[string]int) - if sec == "r" || sec == "p" { + if sec == "r" || sec == "p" || sec == "a" { ast.Tokens = strings.Split(ast.Value, ",") for i := range ast.Tokens { ast.Tokens[i] = key + "_" + strings.TrimSpace(ast.Tokens[i]) @@ -187,6 +188,10 @@ func (model Model) LoadModelFromText(text string) error { func (model Model) loadModelFromConfig(cfg config.ConfigInterface) error { for s := range sectionNameMap { loadSection(model, cfg, s) + // Special handling for AI section to load "ai" key + if s == "a" { + loadAssertion(model, cfg, s, "ai") + } } ms := make([]string, 0) for _, rs := range requiredSections { diff --git a/model/policy.go b/model/policy.go index e55bf4105..f3c9fe8c8 100644 --- a/model/policy.go +++ b/model/policy.go @@ -105,6 +105,11 @@ func (model Model) ClearPolicy() { ast.Policy = nil ast.PolicyMap = map[string]int{} } + + for _, ast := range model["a"] { + ast.Policy = nil + ast.PolicyMap = map[string]int{} + } } // GetPolicy gets all rules in a policy. @@ -148,11 +153,11 @@ func (model Model) HasPolicyEx(sec string, ptype string, rule []string) (bool, e return false, err } switch sec { - case "p": + case "p", "a": if len(rule) != len(assertion.Tokens) { return false, fmt.Errorf( "invalid policy rule size: expected %d, got %d, rule: %v", - len(model["p"][ptype].Tokens), + len(model[sec][ptype].Tokens), len(rule), rule) } diff --git a/persist/file-adapter/adapter.go b/persist/file-adapter/adapter.go index 454b2d13c..f37caf113 100644 --- a/persist/file-adapter/adapter.go +++ b/persist/file-adapter/adapter.go @@ -17,13 +17,13 @@ package fileadapter import ( "bufio" "bytes" + "encoding/csv" "errors" "os" "strings" "github.com/casbin/casbin/v3/model" "github.com/casbin/casbin/v3/persist" - "github.com/casbin/casbin/v3/util" ) // Adapter is the file adapter for Casbin. @@ -65,23 +65,40 @@ func (a *Adapter) SavePolicy(model model.Model) error { } var tmp bytes.Buffer + writer := csv.NewWriter(&tmp) for ptype, ast := range model["p"] { for _, rule := range ast.Policy { - tmp.WriteString(ptype + ", ") - tmp.WriteString(util.ArrayToString(rule)) - tmp.WriteString("\n") + record := append([]string{ptype}, rule...) + if err := writer.Write(record); err != nil { + return err + } } } for ptype, ast := range model["g"] { for _, rule := range ast.Policy { - tmp.WriteString(ptype + ", ") - tmp.WriteString(util.ArrayToString(rule)) - tmp.WriteString("\n") + record := append([]string{ptype}, rule...) + if err := writer.Write(record); err != nil { + return err + } } } + for ptype, ast := range model["a"] { + for _, rule := range ast.Policy { + record := append([]string{ptype}, rule...) + if err := writer.Write(record); err != nil { + return err + } + } + } + + writer.Flush() + if err := writer.Error(); err != nil { + return err + } + return a.savePolicyFile(strings.TrimRight(tmp.String(), "\n")) } diff --git a/persist/file-adapter/adapter_filtered.go b/persist/file-adapter/adapter_filtered.go index bf033f98b..edbaefa33 100644 --- a/persist/file-adapter/adapter_filtered.go +++ b/persist/file-adapter/adapter_filtered.go @@ -41,6 +41,7 @@ type Filter struct { G3 []string G4 []string G5 []string + AI []string } // NewFilteredAdapter is the constructor for FilteredAdapter. @@ -137,6 +138,8 @@ func filterLine(line string, filter *Filter) bool { filterSlice = filter.G4 case "g5": filterSlice = filter.G5 + case "ai": + filterSlice = filter.AI } return filterWords(p, filterSlice) } diff --git a/persist/string-adapter/adapter.go b/persist/string-adapter/adapter.go index c9655acbe..7006fd135 100644 --- a/persist/string-adapter/adapter.go +++ b/persist/string-adapter/adapter.go @@ -16,12 +16,12 @@ package stringadapter import ( "bytes" + "encoding/csv" "errors" "strings" "github.com/casbin/casbin/v3/model" "github.com/casbin/casbin/v3/persist" - "github.com/casbin/casbin/v3/util" ) // Adapter is the string adapter for Casbin. @@ -56,21 +56,40 @@ func (a *Adapter) LoadPolicy(model model.Model) error { // SavePolicy saves all policy rules to the storage. func (a *Adapter) SavePolicy(model model.Model) error { var tmp bytes.Buffer + writer := csv.NewWriter(&tmp) + for ptype, ast := range model["p"] { for _, rule := range ast.Policy { - tmp.WriteString(ptype + ", ") - tmp.WriteString(util.ArrayToString(rule)) - tmp.WriteString("\n") + record := append([]string{ptype}, rule...) + if err := writer.Write(record); err != nil { + return err + } } } for ptype, ast := range model["g"] { for _, rule := range ast.Policy { - tmp.WriteString(ptype + ", ") - tmp.WriteString(util.ArrayToString(rule)) - tmp.WriteString("\n") + record := append([]string{ptype}, rule...) + if err := writer.Write(record); err != nil { + return err + } + } + } + + for ptype, ast := range model["a"] { + for _, rule := range ast.Policy { + record := append([]string{ptype}, rule...) + if err := writer.Write(record); err != nil { + return err + } } } + + writer.Flush() + if err := writer.Error(); err != nil { + return err + } + a.Line = strings.TrimRight(tmp.String(), "\n") return nil }