Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 110 additions & 0 deletions ai_policy_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
// Copyright 2024 The casbin Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package casbin

import (
"testing"

fileadapter "github.com/casbin/casbin/v3/persist/file-adapter"
"github.com/casbin/casbin/v3/util"
)

func TestAIPolicyLoad(t *testing.T) {
e, err := NewEnforcer("examples/ai_policy_model.conf", "examples/ai_policy.csv")
if err != nil {
t.Fatal(err)
}

// Test that regular policies are loaded
policies, err := e.GetPolicy()
if err != nil {
t.Fatal(err)
}

expectedPolicies := [][]string{
{"alice", "data1", "read", "09:00", "18:00"},
{"bob", "data2", "write", "13:00", "16:00"},
}

if !util.Array2DEquals(expectedPolicies, policies) {
t.Errorf("Policies = %v, want %v", policies, expectedPolicies)
}

// Test that grouping policies are loaded
groupingPolicies, err := e.GetGroupingPolicy()
if err != nil {
t.Fatal(err)
}

expectedGrouping := [][]string{
{"cathy", "alice"},
}

if !util.Array2DEquals(expectedGrouping, groupingPolicies) {
t.Errorf("Grouping policies = %v, want %v", groupingPolicies, expectedGrouping)
}

// Test that AI policies are loaded
aiPolicies, err := e.model.GetPolicy("a", "ai")
if err != nil {
t.Fatal(err)
}

expectedAI := [][]string{
{`if the request object contains anything like credential/secret leak, then deny`},
}

if !util.Array2DEquals(expectedAI, aiPolicies) {
t.Errorf("AI policies = %v, want %v", aiPolicies, expectedAI)
}
}

func TestAIPolicySave(t *testing.T) {
// Create a temporary file for testing
tmpFile := t.TempDir() + "/ai_policy_test.csv"

e, err := NewEnforcer("examples/ai_policy_model.conf", "examples/ai_policy.csv")
if err != nil {
t.Fatal(err)
}

// Update adapter to save to temp file
e.SetAdapter(fileadapter.NewAdapter(tmpFile))
// Save to the temporary file
err = e.SavePolicy()
if err != nil {
t.Fatal(err)
}

// Load from the saved file
e2, err := NewEnforcer("examples/ai_policy_model.conf", tmpFile)
if err != nil {
t.Fatal(err)
}

// Verify AI policies are preserved
aiPolicies, err := e2.model.GetPolicy("a", "ai")
if err != nil {
t.Fatal(err)
}

expectedAI := [][]string{
{`if the request object contains anything like credential/secret leak, then deny`},
}

if !util.Array2DEquals(expectedAI, aiPolicies) {
t.Errorf("AI policies after save/load = %v, want %v", aiPolicies, expectedAI)
}
}
Binary file added demo_ai_policy
Binary file not shown.
6 changes: 6 additions & 0 deletions examples/ai_policy.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
p, alice, data1, read, 09:00, 18:00
p, bob, data2, write, 13:00, 16:00

g, cathy, alice

ai, "if the request object contains anything like credential/secret leak, then deny"
17 changes: 17 additions & 0 deletions examples/ai_policy_model.conf
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
[request_definition]
r = sub, obj, act, time

[policy_definition]
p = sub, obj, act, time_start, time_end

[role_definition]
g = _, _

[ai_definition]
ai = rule

[policy_effect]
e = some(where (p.eft == allow))

[matchers]
m = g(r.sub, p.sub) && r.obj == p.obj && r.act == p.act && r.time >= p.time_start && r.time <= p.time_end
4 changes: 2 additions & 2 deletions examples/basic_policy.csv
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
p, alice, data1, read
p, bob, data2, write
p,alice,data1,read
p,bob,data2,write
10 changes: 5 additions & 5 deletions examples/rbac_policy.csv
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
p, alice, data1, read
p, bob, data2, write
p, data2_admin, data2, read
p, data2_admin, data2, write
g, alice, data2_admin
p,alice,data1,read
p,bob,data2,write
p,data2_admin,data2,read
p,data2_admin,data2,write
g,alice,data2_admin
12 changes: 6 additions & 6 deletions examples/rbac_with_domains_policy.csv
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
p, admin, domain1, data1, read
p, admin, domain1, data1, write
p, admin, domain2, data2, read
p, admin, domain2, data2, write
g, alice, admin, domain1
g, bob, admin, domain2
p,admin,domain1,data1,read
p,admin,domain1,data1,write
p,admin,domain2,data2,read
p,admin,domain2,data2,write
g,alice,admin,domain1
g,bob,admin,domain2
7 changes: 6 additions & 1 deletion model/model.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ var sectionNameMap = map[string]string{
"e": "policy_effect",
"m": "matchers",
"c": "constraint_definition",
"a": "ai_definition",
}

// Minimal required sections for a model to be valid.
Expand Down Expand Up @@ -78,7 +79,7 @@ func (model Model) AddDef(sec string, key string, value string) bool {
ast.PolicyMap = make(map[string]int)
ast.FieldIndexMap = make(map[string]int)

if sec == "r" || sec == "p" {
if sec == "r" || sec == "p" || sec == "a" {
ast.Tokens = strings.Split(ast.Value, ",")
for i := range ast.Tokens {
ast.Tokens[i] = key + "_" + strings.TrimSpace(ast.Tokens[i])
Expand Down Expand Up @@ -187,6 +188,10 @@ func (model Model) LoadModelFromText(text string) error {
func (model Model) loadModelFromConfig(cfg config.ConfigInterface) error {
for s := range sectionNameMap {
loadSection(model, cfg, s)
// Special handling for AI section to load "ai" key
if s == "a" {
loadAssertion(model, cfg, s, "ai")
}
}
ms := make([]string, 0)
for _, rs := range requiredSections {
Expand Down
9 changes: 7 additions & 2 deletions model/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,11 @@ func (model Model) ClearPolicy() {
ast.Policy = nil
ast.PolicyMap = map[string]int{}
}

for _, ast := range model["a"] {
ast.Policy = nil
ast.PolicyMap = map[string]int{}
}
}

// GetPolicy gets all rules in a policy.
Expand Down Expand Up @@ -148,11 +153,11 @@ func (model Model) HasPolicyEx(sec string, ptype string, rule []string) (bool, e
return false, err
}
switch sec {
case "p":
case "p", "a":
if len(rule) != len(assertion.Tokens) {
return false, fmt.Errorf(
"invalid policy rule size: expected %d, got %d, rule: %v",
len(model["p"][ptype].Tokens),
len(model[sec][ptype].Tokens),
len(rule),
rule)
}
Expand Down
31 changes: 24 additions & 7 deletions persist/file-adapter/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@ package fileadapter
import (
"bufio"
"bytes"
"encoding/csv"
"errors"
"os"
"strings"

"github.com/casbin/casbin/v3/model"
"github.com/casbin/casbin/v3/persist"
"github.com/casbin/casbin/v3/util"
)

// Adapter is the file adapter for Casbin.
Expand Down Expand Up @@ -65,23 +65,40 @@ func (a *Adapter) SavePolicy(model model.Model) error {
}

var tmp bytes.Buffer
writer := csv.NewWriter(&tmp)

for ptype, ast := range model["p"] {
for _, rule := range ast.Policy {
tmp.WriteString(ptype + ", ")
tmp.WriteString(util.ArrayToString(rule))
tmp.WriteString("\n")
record := append([]string{ptype}, rule...)
if err := writer.Write(record); err != nil {
return err
}
}
}

for ptype, ast := range model["g"] {
for _, rule := range ast.Policy {
tmp.WriteString(ptype + ", ")
tmp.WriteString(util.ArrayToString(rule))
tmp.WriteString("\n")
record := append([]string{ptype}, rule...)
if err := writer.Write(record); err != nil {
return err
}
}
}

for ptype, ast := range model["a"] {
for _, rule := range ast.Policy {
record := append([]string{ptype}, rule...)
if err := writer.Write(record); err != nil {
return err
}
}
}

writer.Flush()
if err := writer.Error(); err != nil {
return err
}

return a.savePolicyFile(strings.TrimRight(tmp.String(), "\n"))
}

Expand Down
3 changes: 3 additions & 0 deletions persist/file-adapter/adapter_filtered.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ type Filter struct {
G3 []string
G4 []string
G5 []string
AI []string
}

// NewFilteredAdapter is the constructor for FilteredAdapter.
Expand Down Expand Up @@ -137,6 +138,8 @@ func filterLine(line string, filter *Filter) bool {
filterSlice = filter.G4
case "g5":
filterSlice = filter.G5
case "ai":
filterSlice = filter.AI
}
return filterWords(p, filterSlice)
}
Expand Down
33 changes: 26 additions & 7 deletions persist/string-adapter/adapter.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ package stringadapter

import (
"bytes"
"encoding/csv"
"errors"
"strings"

"github.com/casbin/casbin/v3/model"
"github.com/casbin/casbin/v3/persist"
"github.com/casbin/casbin/v3/util"
)

// Adapter is the string adapter for Casbin.
Expand Down Expand Up @@ -56,21 +56,40 @@ func (a *Adapter) LoadPolicy(model model.Model) error {
// SavePolicy saves all policy rules to the storage.
func (a *Adapter) SavePolicy(model model.Model) error {
var tmp bytes.Buffer
writer := csv.NewWriter(&tmp)

for ptype, ast := range model["p"] {
for _, rule := range ast.Policy {
tmp.WriteString(ptype + ", ")
tmp.WriteString(util.ArrayToString(rule))
tmp.WriteString("\n")
record := append([]string{ptype}, rule...)
if err := writer.Write(record); err != nil {
return err
}
}
}

for ptype, ast := range model["g"] {
for _, rule := range ast.Policy {
tmp.WriteString(ptype + ", ")
tmp.WriteString(util.ArrayToString(rule))
tmp.WriteString("\n")
record := append([]string{ptype}, rule...)
if err := writer.Write(record); err != nil {
return err
}
}
}

for ptype, ast := range model["a"] {
for _, rule := range ast.Policy {
record := append([]string{ptype}, rule...)
if err := writer.Write(record); err != nil {
return err
}
}
}

writer.Flush()
if err := writer.Error(); err != nil {
return err
}

a.Line = strings.TrimRight(tmp.String(), "\n")
return nil
}
Expand Down
Loading