Skip to content
164 changes: 151 additions & 13 deletions enforcer.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"runtime/debug"
"strings"
"sync"
"time"

"github.com/casbin/casbin/v3/effector"
"github.com/casbin/casbin/v3/log"
Expand Down Expand Up @@ -54,7 +55,8 @@ type Enforcer struct {
autoNotifyDispatcher bool
acceptJsonRequest bool

logger log.Logger
logger log.Logger
subscribeCache map[log.EventType]bool
}

// EnforceContext is used as the first element of the parameter "rvals" in method "enforce".
Expand All @@ -80,7 +82,7 @@ func (e EnforceContext) GetCacheKey() string {
// a := mysqladapter.NewDBAdapter("mysql", "mysql_username:mysql_password@tcp(127.0.0.1:3306)/")
// e := casbin.NewEnforcer("path/to/basic_model.conf", a)
func NewEnforcer(params ...interface{}) (*Enforcer, error) {
e := &Enforcer{logger: &log.DefaultLogger{}}
e := &Enforcer{logger: log.NewDefaultLogger()}

parsedParamLen := 0
paramLen := len(params)
Expand Down Expand Up @@ -194,18 +196,46 @@ func (e *Enforcer) InitWithModelAndAdapter(m model.Model, adapter persist.Adapte
return nil
}

// SetLogger changes the current enforcer's logger.
// SetLogger sets the logger for the enforcer.
func (e *Enforcer) SetLogger(logger log.Logger) {
e.logger = logger
e.model.SetLogger(e.logger)
for k := range e.rmMap {
e.rmMap[k].SetLogger(e.logger)
e.updateSubscribeCache()
}

// updateSubscribeCache updates the subscription cache for quick event type lookup.
func (e *Enforcer) updateSubscribeCache() {
e.subscribeCache = make(map[log.EventType]bool)

if e.logger == nil {
return
}
for k := range e.condRmMap {
e.condRmMap[k].SetLogger(e.logger)

events := e.logger.Subscribe()
// Both nil and empty slice mean subscribe to all events.
if len(events) == 0 {
e.subscribeCache = nil
return
}

for _, event := range events {
e.subscribeCache[event] = true
}
}

// shouldLog checks if we should log this event type.
func (e *Enforcer) shouldLog(eventType log.EventType) bool {
if e.logger == nil || !e.logger.IsEnabled() {
return false
}

// nil cache means subscribe to all events.
if e.subscribeCache == nil {
return true
}

return e.subscribeCache[eventType]
}

func (e *Enforcer) initialize() {
e.rmMap = map[string]rbac.RoleManager{}
e.condRmMap = map[string]rbac.ConditionalRoleManager{}
Expand Down Expand Up @@ -327,14 +357,32 @@ func (e *Enforcer) ClearPolicy() {

// LoadPolicy reloads the policy from file/database.
func (e *Enforcer) LoadPolicy() error {
entry, handle, shouldLog := e.logEventStart(log.EventPolicyLoad)
if shouldLog {
entry.Operation = "load"
defer func() {
if shouldLog {
entry.RuleCount = e.GetPolicyCount()
}
e.logEventEnd(handle, entry, shouldLog)
}()
}

newModel, err := e.loadPolicyFromAdapter(e.model)
if err != nil {
if shouldLog {
entry.Error = err
}
return err
}
err = e.applyModifiedModel(newModel)
if err != nil {
if shouldLog {
entry.Error = err
}
return err
}

return nil
}

Expand Down Expand Up @@ -478,10 +526,24 @@ func (e *Enforcer) IsFiltered() bool {

// SavePolicy saves the current policy (usually after changed with Casbin API) back to file/database.
func (e *Enforcer) SavePolicy() error {
entry, handle, shouldLog := e.logEventStart(log.EventPolicySave)
if shouldLog {
entry.Operation = "save"
entry.RuleCount = e.GetPolicyCount()
defer e.logEventEnd(handle, entry, shouldLog)
}

if e.IsFiltered() {
return errors.New("cannot save a filtered policy")
err := errors.New("cannot save a filtered policy")
if shouldLog {
entry.Error = err
}
return err
}
if err := e.adapter.SavePolicy(e.model); err != nil {
if shouldLog {
entry.Error = err
}
return err
}
if e.watcher != nil {
Expand All @@ -491,8 +553,12 @@ func (e *Enforcer) SavePolicy() error {
} else {
err = e.watcher.Update()
}
if shouldLog {
entry.Error = err
}
return err
}

return nil
}

Expand Down Expand Up @@ -533,7 +599,7 @@ func (e *Enforcer) EnableEnforce(enable bool) {

// EnableLog changes whether Casbin will log messages to the Logger.
func (e *Enforcer) EnableLog(enable bool) {
e.logger.EnableLog(enable)
e.logger.Enable(enable)
}

// IsLogEnabled returns the current logger's enabled status.
Expand Down Expand Up @@ -610,6 +676,47 @@ func (e *Enforcer) invalidateMatcherMap() {

// enforce use a custom matcher to decides whether a "subject" can access a "object" with the operation "action", input parameters are usually: (matcher, sub, obj, act), use model matcher by default when matcher is "".
func (e *Enforcer) enforce(matcher string, explains *[]string, rvals ...interface{}) (ok bool, err error) { //nolint:funlen,cyclop,gocyclo // TODO: reduce function complexity
// Event logging setup
var entry *log.LogEntry
var handle *log.Handle
var logExplains [][]string
shouldLog := e.shouldLog(log.EventEnforce)

if shouldLog {
entry = &log.LogEntry{
Type: log.EventEnforce,
Timestamp: time.Now(),
Request: rvals,
Attributes: make(map[string]interface{}),
}

// Parse request parameters
if len(rvals) >= 1 {
entry.Subject = toString(rvals[0])
}
if len(rvals) >= 2 {
entry.Object = toString(rvals[1])
}
if len(rvals) >= 3 {
entry.Action = toString(rvals[2])
}
if len(rvals) >= 4 {
entry.Domain = toString(rvals[3])
}

handle = e.logger.OnBeforeEvent(entry)
}

defer func() {
if shouldLog && entry != nil && handle != nil {
entry.Duration = time.Since(handle.StartTime)
entry.Allowed = ok
entry.Matched = logExplains
entry.Error = err
e.logger.OnAfterEvent(handle, entry)
}
}()

defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic: %v\n%s", r, debug.Stack())
Expand Down Expand Up @@ -810,8 +917,6 @@ func (e *Enforcer) enforce(matcher string, explains *[]string, rvals ...interfac
}
}

var logExplains [][]string

if explains != nil {
if len(*explains) > 0 {
logExplains = append(logExplains, *explains)
Expand All @@ -828,7 +933,7 @@ func (e *Enforcer) enforce(matcher string, explains *[]string, rvals ...interfac
if effect == effector.Allow {
result = true
}
e.logger.LogEnforce(expString, rvals, result, logExplains)
// Note: LogEnforce was removed as enforcement is now logged via OnBeforeEvent/OnAfterEvent.

return result, nil
}
Expand Down Expand Up @@ -1012,3 +1117,36 @@ func generateEvalFunction(functions map[string]govaluate.ExpressionFunction, par
return expr.Eval(parameters)
}
}

// logEventStart initializes event logging for a given event type.
// Returns entry, handle, and shouldLog flag.
func (e *Enforcer) logEventStart(eventType log.EventType) (*log.LogEntry, *log.Handle, bool) {
shouldLog := e.shouldLog(eventType)
if !shouldLog {
return nil, nil, false
}

entry := &log.LogEntry{
Type: eventType,
Timestamp: time.Now(),
Attributes: make(map[string]interface{}),
}
handle := e.logger.OnBeforeEvent(entry)
return entry, handle, true
}

// logEventEnd finalizes event logging.
func (e *Enforcer) logEventEnd(handle *log.Handle, entry *log.LogEntry, shouldLog bool) {
if shouldLog && entry != nil && handle != nil {
entry.Duration = time.Since(handle.StartTime)
e.logger.OnAfterEvent(handle, entry)
}
}

// toString converts an interface{} to string for logging.
func toString(v interface{}) string {
if s, ok := v.(string); ok {
return s
}
return fmt.Sprintf("%v", v)
}
2 changes: 1 addition & 1 deletion enforcer_cached.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ func GetCacheKey(params ...interface{}) (string, bool) {
func (e *CachedEnforcer) ClearPolicy() {
if atomic.LoadInt32(&e.enableCache) != 0 {
if err := e.cache.Clear(); err != nil {
e.logger.LogError(err, "clear cache failed")
// Note: LogError was removed as the new Logger interface is event-based.
return
}
}
Expand Down
Loading
Loading