diff --git a/cmd/hubauth-ext/main.go b/cmd/hubauth-ext/main.go index fce743d..834646d 100644 --- a/cmd/hubauth-ext/main.go +++ b/cmd/hubauth-ext/main.go @@ -90,7 +90,7 @@ func main() { } rootPubKey = biscuitKey.Public().Bytes() - accessTokenBuilder = token.NewBiscuitBuilder(kmsClient, audienceKeyNamer, biscuitKey) + accessTokenBuilder = token.NewBiscuitBuilder(kmsClient, datastore.New(dsClient), audienceKeyNamer, biscuitKey) default: log.Fatalf("invalid TOKEN_TYPE, must be one of: Bearer, Biscuit") } diff --git a/go.mod b/go.mod index 5dde64c..7cb5039 100644 --- a/go.mod +++ b/go.mod @@ -7,9 +7,10 @@ require ( cloud.google.com/go/datastore v1.3.0 contrib.go.opencensus.io/exporter/stackdriver v0.13.4 github.com/alecthomas/kong v0.2.12 + github.com/alecthomas/participle/v2 v2.0.0-alpha3.0.20201208114601-14bec2482095 github.com/aws/aws-sdk-go v1.36.7 // indirect github.com/census-instrumentation/opencensus-proto v0.3.0 // indirect - github.com/flynn/biscuit-go v0.0.0-20201009174859-e7eb59a90195 + github.com/flynn/biscuit-go v0.0.0-20201211135022-dbd2f8863bf4 github.com/golang/protobuf v1.4.3 github.com/googleapis/gax-go/v2 v2.0.5 github.com/jedib0t/go-pretty/v6 v6.0.5 diff --git a/go.sum b/go.sum index 287ce43..61cea50 100644 --- a/go.sum +++ b/go.sum @@ -54,7 +54,9 @@ github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03 github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo= github.com/alecthomas/kong v0.2.12 h1:X3kkCOXGUNzLmiu+nQtoxWqj4U2a39MpSJR3QdQXOwI= github.com/alecthomas/kong v0.2.12/go.mod h1:kQOmtJgV+Lb4aj+I2LEn40cbtawdWJ9Y8QLq+lElKxE= -github.com/alecthomas/participle v0.6.0/go.mod h1:HfdmEuwvr12HXQN44HPWXR0lHmVolVYe4dyL6lQ3duY= +github.com/alecthomas/participle/v2 v2.0.0-alpha3.0.20201208114601-14bec2482095 h1:DCGcCFtR/4YWEOoszqekJRdDoq41G+btPdOSWf5FoSo= +github.com/alecthomas/participle/v2 v2.0.0-alpha3.0.20201208114601-14bec2482095/go.mod h1:Z1zPLDbcGsVsBYsThKXY00i84575bN/nMczzIrU4rWU= +github.com/alecthomas/repr v0.0.0-20181024024818-d37bc2a10ba1 h1:GDQdwm/gAcJcLAKQQZGOJ4knlw+7rfEQQcmwTbt4p5E= github.com/alecthomas/repr v0.0.0-20181024024818-d37bc2a10ba1/go.mod h1:xTS7Pm1pD1mvyM075QCDSRqH6qRLXylzS24ZTpRiSzQ= github.com/aws/aws-sdk-go v1.23.20 h1:2CBuL21P0yKdZN5urf2NxKa1ha8fhnY+A3pBCHFeZoA= github.com/aws/aws-sdk-go v1.23.20/go.mod h1:KmX6BPdI08NWTb3/sm4ZGu5ShLoqVDhKgpiN924inxo= @@ -77,9 +79,11 @@ github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.m github.com/envoyproxy/go-control-plane v0.9.4/go.mod h1:6rpuAdCZL397s3pYoYcLgu1mIlRU8Am5FuJP05cCM98= github.com/envoyproxy/go-control-plane v0.9.7/go.mod h1:cwu0lG7PUMfa9snN8LXBig5ynNVH9qI8YYLbd1fK2po= github.com/envoyproxy/protoc-gen-validate v0.1.0/go.mod h1:iSmxcyjqTsJpI2R4NaDN7+kN2VEUnK/pcBlmesArF7c= -github.com/flynn/biscuit-go v0.0.0-20201015081742-15d7d351f345 h1:ME6bm5dwn9V2DUlfXJqeN121B5nM7rDFqLFOATALqYE= -github.com/flynn/biscuit-go v0.0.0-20201015081742-15d7d351f345/go.mod h1:Sj4oR2hNkrZH1cf3Cj5DPHc3Xq0o61GWeau6UkZR+3c= +github.com/flynn/biscuit-go v0.0.0-20201211135022-dbd2f8863bf4 h1:5TqasLkkptxZIP8TNawz76F+vMSf04Mab8/d8VdJWus= +github.com/flynn/biscuit-go v0.0.0-20201211135022-dbd2f8863bf4/go.mod h1:mY0paJD7nJ1hsxNzqHOKES2u+asFldh25X8WkveoMaw= github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= +github.com/go-gl/glfw v0.0.0-20190409004039-e6da0acd62b1/go.mod h1:vR7hzQXu2zJy9AVAgeJqvqgH9Q5CA+iKCZ2gyEVpxRU= +github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20191125211704-12ad95a8df72/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/go-gl/glfw/v3.3/glfw v0.0.0-20200222043503-6f7a984d4dc4/go.mod h1:tQ2UAYgL5IevRw8kRxooKSPJfGvJ9fJQFa0TUsXzTg8= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= @@ -291,6 +295,7 @@ golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJ golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200317015054-43a5402ce75a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20200625203802-6e8e738ad208/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9 h1:SQFwaSi55rU7vdNs9Yr0Z324VNlrF+0wMqRXT4St8ck= golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a h1:DcqTD9SDLc+1P/r1EmRBwnVsrOwW+kk2vWf9n+1sGhs= golang.org/x/sync v0.0.0-20201207232520-09787c993a3a/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= @@ -324,6 +329,7 @@ golang.org/x/sys v0.0.0-20200803210538-64077c9b5642/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20200905004654-be1d3432aa8f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20201201145000-ef89a241ccb3 h1:kzM6+9dur93BcC2kVlYl34cHU+TYZLanmpSJHVMmL64= golang.org/x/sys v0.0.0-20201201145000-ef89a241ccb3/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201214095126-aec9a390925b h1:tv7/y4pd+sR8bcNb2D6o7BNU6zjWm0VjQLac+w7fNNM= golang.org/x/sys v0.0.0-20201214095126-aec9a390925b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= @@ -386,6 +392,7 @@ golang.org/x/tools v0.0.0-20200904185747-39188db58858/go.mod h1:Cj7w3i3Rnn0Xh82u golang.org/x/tools v0.0.0-20200916150407-587cf2330ce8/go.mod h1:z6u4i615ZeAfBE4XtMziQW1fSVJXACjjbWkB/mvPzlU= golang.org/x/tools v0.0.0-20201110124207-079ba7bd75cd/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20201201161351-ac6f37ff4c2a/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= +golang.org/x/tools v0.0.0-20201208233053-a543418bbed2 h1:vEtypaVub6UvKkiXZ2xx9QIvp9TL7sI7xp7vdi2kezA= golang.org/x/tools v0.0.0-20201208233053-a543418bbed2/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= golang.org/x/tools v0.0.0-20201211185031-d93e913c1a58 h1:1Bs6RVeBFtLZ8Yi1Hk07DiOqzvwLD/4hln4iahvFlag= golang.org/x/tools v0.0.0-20201211185031-d93e913c1a58/go.mod h1:emZCQorbCU4vsT4fOWvOPXz4eW1wZW4PmDk9uLelYpA= @@ -457,6 +464,7 @@ google.golang.org/genproto v0.0.0-20200904004341-0bd0a958aa1d/go.mod h1:FWY/as6D google.golang.org/genproto v0.0.0-20200916143405-f6a2fa72f0c4/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20201109203340-2640f1f9cdfb/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20201201144952-b05cb90ed32e/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= +google.golang.org/genproto v0.0.0-20201210142538-e3217bee35cc h1:BgQmMjmd7K1zov8j8lYULHW0WnmBGUIMp6+VDwlGErc= google.golang.org/genproto v0.0.0-20201210142538-e3217bee35cc/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= google.golang.org/genproto v0.0.0-20201211151036-40ec1c210f7a h1:GnJAhasbD8HiT8DZMvsEx3QLVy/X0icq/MGr0MqRJ2M= google.golang.org/genproto v0.0.0-20201211151036-40ec1c210f7a/go.mod h1:FWY/as6DDZQgahTzZj3fqbO1CbirC29ZNUFHwi0/+no= diff --git a/pkg/cli/audiences.go b/pkg/cli/audiences.go index de43b0a..3dc11a5 100644 --- a/pkg/cli/audiences.go +++ b/pkg/cli/audiences.go @@ -5,26 +5,39 @@ import ( "encoding/base64" "encoding/pem" "fmt" + "io/ioutil" "net/url" "os" "strings" "github.com/flynn/hubauth/pkg/hubauth" + "github.com/flynn/hubauth/pkg/policy" "github.com/jedib0t/go-pretty/v6/table" "google.golang.org/genproto/googleapis/cloud/kms/v1" ) type audiencesCmd struct { - List audiencesListCmd `kong:"cmd,help='list audiences',default:'1'"` - Create audiencesCreateCmd `kong:"cmd,help='create audience'"` - UpdateType audienceUpdateTypeCmd `kong:"cmd,name='update-type',help='change audience type'"` - UpdateClientIDs audiencesUpdateClientsIDsCmd `kong:"cmd,name='update-client-ids',help='add or remove audience client IDs'"` - Delete audiencesDeleteCmd `kong:"cmd,help='delete audience and all its keys'"` + List audiencesListCmd `kong:"cmd,help='list audiences',default:'1'"` + Create audiencesCreateCmd `kong:"cmd,help='create audience'"` + UpdateType audienceUpdateTypeCmd `kong:"cmd,name='update-type',help='change audience type'"` + UpdateClientIDs audiencesUpdateClientsIDsCmd `kong:"cmd,name='update-client-ids',help='add or remove audience client IDs'"` + Delete audiencesDeleteCmd `kong:"cmd,help='delete audience and all its keys'"` + ListUserGroups audiencesListUserGroupsCmd `kong:"cmd,name='list-user-groups',help='list audience user groups'"` SetUserGroups audiencesSetUserGroupsCmd `kong:"cmd,name='set-user-groups',help='set audience auth user groups'"` UpdateUserGroups audiencesUpdateUserGroupsCmd `kong:"cmd,name='update-user-groups',help='modify audience user groups api user or groups'"` DeleteUserGroups audiencesDeleteUserGroupsCmd `kong:"cmd,name='delete-user-groups',help='delete audience auth user groups'"` - Key audiencesKeyCmd `kong:"cmd,help='get audience public key'"` + + Key audiencesKeyCmd `kong:"cmd,help='get audience public key'"` + + ListPolicies audiencesListPoliciesCmd `kong:"cmd,name='list-policies',help='list audience policies'"` + DumpPolicies audiencesDumpPoliciesCmd `kong:"cmd,name='dump-policies',help='dump audience policies'"` + SetPolicies audiencesSetPoliciesCmd `kong:"cmd,name='set-policies',help='set audience policies'"` + UpdatePolicy audiencesUpdatePolicyCmd `kong:"cmd,name='update-policy',help='modify audience policy content or groups'"` + DeletePolicy audiencesDeletePolicyCmd `kong:"cmd,name='delete-policy',help='delete audience policy'"` + + NewPolicy audiencesNewPolicyCmd `kong:"cmd,name='new-policy',help='print a new empty policy document on stdout'"` + ValidatePolicies audiencesValidatePoliciesCmd `kong:"cmd,name='validate-policies',help='validate a policy file'"` } type audiencesListCmd struct{} @@ -285,3 +298,264 @@ func (c *audiencesKeyCmd) Run(cfg *Config) error { fmt.Println(base64.URLEncoding.EncodeToString(b.Bytes)) return nil } + +type audiencesListPoliciesCmd struct { + AudienceURL string `kong:"required,name='audience-url',help='audience URL'"` +} + +func (c *audiencesListPoliciesCmd) Run(cfg *Config) error { + audience, err := cfg.DB.GetAudience(context.Background(), c.AudienceURL) + if err != nil { + return err + } + + t := table.NewWriter() + t.SetOutputMirror(os.Stdout) + t.AppendHeader(table.Row{"Name", "Groups", "Description"}) + for _, p := range audience.Policies { + t.AppendRow(table.Row{p.Name, p.Groups, getFirstComment(p)}) + } + t.Render() + return nil +} + +// getFirstComment parse the policy content and returns the first policy +// comment line if it exists. On failure to parse the policy content, or when unset, an empty string is returned. +func getFirstComment(p *hubauth.BiscuitPolicy) string { + doc, err := policy.ParseDocumentPolicy(strings.NewReader(p.Content)) + if err != nil { + return "" + } + if len(doc.Comments) == 0 { + return "" + } + return string(*doc.Comments[0]) +} + +type audiencesSetPoliciesCmd struct { + AudienceURL string `kong:"required,name='audience-url',help='audience URL'"` + Filepath string `kong:"required,name='filepath',help='policy file'"` + Groups []string `kong:"help='comma-separated group IDs'"` +} + +// Run parses Filepath for a list of policies, and creates or updates them on the audience identified by AudienceURL, +// forcing their groups to the provided Groups. +func (c *audiencesSetPoliciesCmd) Run(cfg *Config) error { + f, err := os.Open(c.Filepath) + if err != nil { + return err + } + + doc, err := policy.ParseNamed(f.Name(), f) + if err != nil { + return err + } + + muts := make([]*hubauth.AudienceMutation, len(doc.Policies)) + for i, p := range doc.Policies { + muts[i] = &hubauth.AudienceMutation{ + Op: hubauth.AudienceMutationSetPolicy, + Policy: hubauth.BiscuitPolicy{ + Name: *p.Name, + Content: policy.PrintPolicy(p), + Groups: c.Groups, + }, + } + } + + return cfg.DB.MutateAudience(context.Background(), c.AudienceURL, muts) +} + +type audiencesUpdatePolicyCmd struct { + AudienceURL string `kong:"required,name='audience-url',help='audience URL'"` + PolicyName string `kong:"required,help='policy name'"` + Filepath string `kong:"name='filepath',help='replace policy content from a file'"` + AddGroups []string `kong:"name='add-groups',help='comma-separated group IDs to add'"` + DeleteGroups []string `kong:"name='delete-groups',help='comma-separated group IDs to delete'"` +} + +func (c *audiencesUpdatePolicyCmd) Run(cfg *Config) error { + var muts []*hubauth.AudiencePolicyMutation + for _, groupID := range c.AddGroups { + muts = append(muts, &hubauth.AudiencePolicyMutation{ + Op: hubauth.AudiencePolicyMutationOpAddGroup, + Group: groupID, + }) + } + for _, groupID := range c.DeleteGroups { + muts = append(muts, &hubauth.AudiencePolicyMutation{ + Op: hubauth.AudiencePolicyMutationOpDeleteGroup, + Group: groupID, + }) + } + if c.Filepath != "" { + doc, err := parsePolicy(c.Filepath) + if err != nil { + return err + } + + var mutatedPolicy *policy.DocumentPolicy + for _, p := range doc.Policies { + if *p.Name == c.PolicyName { + mutatedPolicy = p + break + } + } + if mutatedPolicy == nil { + return fmt.Errorf("policy %q not found in file %q", c.PolicyName, c.Filepath) + } + + muts = append(muts, &hubauth.AudiencePolicyMutation{ + Op: hubauth.AudiencePolicyMutationOpSetContent, + Content: policy.PrintPolicy(mutatedPolicy), + }) + } + + return cfg.DB.MutateAudiencePolicy(context.Background(), c.AudienceURL, c.PolicyName, muts) +} + +type audiencesDeletePolicyCmd struct { + AudienceURL string `kong:"required,name='audience-url',help='audience URL'"` + PolicyName string `kong:"required,help='policy name'"` +} + +func (c *audiencesDeletePolicyCmd) Run(cfg *Config) error { + mut := &hubauth.AudienceMutation{ + Op: hubauth.AudienceMutationDeletePolicy, + Policy: hubauth.BiscuitPolicy{ + Name: c.PolicyName, + }, + } + return cfg.DB.MutateAudience(context.Background(), c.AudienceURL, []*hubauth.AudienceMutation{mut}) +} + +type audiencesNewPolicyCmd struct { + Filepath string `kong:"name='filepath',short='f',help='optionnal filepath where to write the policy (default: stdout)'"` +} + +var policyTemplate string = `// this is a template policy +policy "dummy" { + rules { + // this is a dummy rule + *head($var1) + <- body1(#ambient, $name), + body2($value) + @ $name == "example" + } + + caveats {[ + // this is a dummy caveat + *head($var1) + <- body1(#ambient, $name), + body2($value) + @ $name == "example" + ]} +}` + +func (c *audiencesNewPolicyCmd) Run(cfg *Config) error { + d, err := policy.Parse(strings.NewReader(policyTemplate)) + if err != nil { + return err + } + + out, err := policy.Print(d) + if err != nil { + return err + } + + if c.Filepath != "" { + ioutil.WriteFile(c.Filepath, []byte(out), 0644) + fmt.Printf("written %s\n", c.Filepath) + return nil + } + + fmt.Print(out) + return nil +} + +type audiencesValidatePoliciesCmd struct { + Filepath string `kong:"required,name='filepath',short='f',help='a file containing policy definitions'"` +} + +func (c *audiencesValidatePoliciesCmd) Run(cfg *Config) error { + f, err := os.Open(c.Filepath) + if err != nil { + return err + } + + _, err = policy.ParseNamed(f.Name(), f) + if err != nil { + return err + } + + return nil +} + +type audiencesDumpPoliciesCmd struct { + AudienceURL string `kong:"required,name='audience-url',help='audience URL'"` + PolicyNames []string `kong:"name='policy-names',help='comma separated policy names to dump (default: all)'"` + Filepath string `kong:"name='filepath',short='f',help='optionnal filepath where to write the policies (default: stdout)'"` +} + +func (c *audiencesDumpPoliciesCmd) Run(cfg *Config) error { + aud, err := cfg.DB.GetAudience(context.Background(), c.AudienceURL) + if err != nil { + return err + } + + if len(aud.Policies) == 0 { + return fmt.Errorf("audience %s have no policy", c.AudienceURL) + } + + dumpPolicies := aud.Policies + if len(c.PolicyNames) > 0 { + dumpPolicies = make([]*hubauth.BiscuitPolicy, 0, len(c.PolicyNames)) + for _, p := range aud.Policies { + for _, name := range c.PolicyNames { + if name == p.Name { + dumpPolicies = append(dumpPolicies, p) + break + } + } + } + } + + aggContent := "" + for _, p := range dumpPolicies { + aggContent += p.Content + } + + doc, err := policy.Parse(strings.NewReader(aggContent)) + if err != nil { + return err + } + + out, err := policy.Print(doc) + if err != nil { + return err + } + + if c.Filepath != "" { + ioutil.WriteFile(c.Filepath, []byte(out), 0644) + fmt.Printf("written %d policies to %s\n", len(dumpPolicies), c.Filepath) + return nil + } + + fmt.Printf("%s", out) + return nil +} + +func parsePolicy(path string) (*policy.Document, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + + doc, err := policy.ParseNamed(f.Name(), f) + if err != nil { + return nil, err + } + + return doc, nil +} diff --git a/pkg/cli/audiences_test.go b/pkg/cli/audiences_test.go index fff0b5f..8609f00 100644 --- a/pkg/cli/audiences_test.go +++ b/pkg/cli/audiences_test.go @@ -11,11 +11,14 @@ import ( "encoding/pem" "errors" "fmt" + "io/ioutil" "os" + "strings" "testing" "time" "github.com/flynn/hubauth/pkg/hubauth" + "github.com/flynn/hubauth/pkg/policy" "github.com/googleapis/gax-go/v2" "github.com/jedib0t/go-pretty/v6/table" "github.com/stretchr/testify/mock" @@ -77,7 +80,10 @@ func (m *mockAudienceDatastore) MutateAudience(ctx context.Context, url string, args := m.Called(ctx, url, mut) return args.Error(0) } - +func (m *mockAudienceDatastore) MutateAudiencePolicy(ctx context.Context, url string, policyName string, mut []*hubauth.AudiencePolicyMutation) error { + args := m.Called(ctx, url, policyName, mut) + return args.Error(0) +} func (m *mockAudienceDatastore) MutateAudienceUserGroups(ctx context.Context, url string, domain string, mut []*hubauth.AudienceUserGroupsMutation) error { args := m.Called(ctx, url, domain, mut) return args.Error(0) @@ -623,3 +629,364 @@ func TestAudienceUpdateTypeCmd(t *testing.T) { require.NoError(t, cmd.Run(cfg)) } + +func TestAudienceListPoliciesCmd(t *testing.T) { + cmd := &audiencesListPoliciesCmd{ + AudienceURL: "https://audience.url", + } + + cfg := &Config{ + DB: &mockAudienceDatastore{}, + } + + policy1Content := "// policy1 description\npolicy \"policy1\" {}" + policy2Content := "// policy2 description\npolicy \"policy2\" {}" + policy3Content := "policy \"policy3\" {}" + + audience := &hubauth.Audience{ + URL: cmd.AudienceURL, + Policies: []*hubauth.BiscuitPolicy{ + { + Name: "policy1", + Content: policy1Content, + Groups: []string{"grp1", "grp2"}, + }, + { + Name: "policy2", + Content: policy2Content, + Groups: nil, + }, + { + Name: "policy3", + Content: policy3Content, + Groups: nil, + }, + }, + } + + cfg.DB.(*mockAudienceDatastore).On("GetAudience", mock.Anything, cmd.AudienceURL).Return(audience, nil) + + r, w, err := os.Pipe() + require.NoError(t, err) + origStdout := os.Stdout + os.Stdout = w + + require.NoError(t, cmd.Run(cfg)) + + os.Stdout = origStdout + + buf := make([]byte, 2048) + n, err := r.Read(buf) + require.NoError(t, err) + + expectedBuf := new(bytes.Buffer) + tw := table.NewWriter() + tw.SetOutputMirror(expectedBuf) + tw.AppendHeader(table.Row{"Name", "Groups", "Description"}) + for _, p := range audience.Policies { + tw.AppendRow(table.Row{p.Name, p.Groups, getFirstComment(p)}) + } + tw.Render() + + require.Equal(t, expectedBuf.String(), string(buf[:n])) +} + +func TestAudiencesSetPoliciesCmd(t *testing.T) { + policy1Content := `// policy1 +policy "policy1" { + rules { + // rule1 + *r1($a) <- f1($a) + } +}` + + policy2Content := `// policy2 +policy "policy2" {} +` + + policyFile, err := ioutil.TempFile(os.TempDir(), "testaudiencessetpoliciescmd-") + require.NoError(t, err) + defer func() { + policyFile.Close() + os.Remove(policyFile.Name()) + }() + + _, err = policyFile.WriteString(policy1Content) + require.NoError(t, err) + _, err = policyFile.WriteString(policy2Content) + require.NoError(t, err) + + groups := []string{"grp1", "grp2"} + + cmd := &audiencesSetPoliciesCmd{ + AudienceURL: "https://audience.url", + Filepath: policyFile.Name(), + Groups: groups, + } + + cfg := &Config{ + DB: &mockAudienceDatastore{}, + } + + policy1ContentFmt, err := policy.Format(strings.NewReader(policy1Content)) + require.NoError(t, err) + policy2ContentFmt, err := policy.Format(strings.NewReader(policy2Content)) + require.NoError(t, err) + + expectedMuts := []*hubauth.AudienceMutation{ + { + Op: hubauth.AudienceMutationSetPolicy, + Policy: hubauth.BiscuitPolicy{ + Name: "policy1", + Content: policy1ContentFmt, + Groups: groups, + }, + }, + { + Op: hubauth.AudienceMutationSetPolicy, + Policy: hubauth.BiscuitPolicy{ + Name: "policy2", + Content: policy2ContentFmt, + Groups: groups, + }, + }, + } + + cfg.DB.(*mockAudienceDatastore).On("MutateAudience", mock.Anything, cmd.AudienceURL, expectedMuts).Return(nil) + + require.NoError(t, cmd.Run(cfg)) +} + +func TestAudiencesUpdatePolicyCmd(t *testing.T) { + policy1Content := "// policy1\npolicy \"policy1\" {}" + + policy1ContentFmt, err := policy.Format(strings.NewReader(policy1Content)) + require.NoError(t, err) + + policyFile, err := ioutil.TempFile(os.TempDir(), "testaudiencesupdatepoliciescmd-") + require.NoError(t, err) + defer func() { + policyFile.Close() + os.Remove(policyFile.Name()) + }() + + _, err = policyFile.WriteString(policy1Content) + require.NoError(t, err) + + cmd := &audiencesUpdatePolicyCmd{ + AudienceURL: "https://audience.url", + PolicyName: "policy1", + Filepath: policyFile.Name(), + AddGroups: []string{"grp1", "grp2"}, + DeleteGroups: []string{"grp3"}, + } + + cfg := &Config{ + DB: &mockAudienceDatastore{}, + } + + expectedMuts := []*hubauth.AudiencePolicyMutation{ + { + Op: hubauth.AudiencePolicyMutationOpAddGroup, + Group: "grp1", + }, + { + Op: hubauth.AudiencePolicyMutationOpAddGroup, + Group: "grp2", + }, + { + Op: hubauth.AudiencePolicyMutationOpDeleteGroup, + Group: "grp3", + }, + { + Op: hubauth.AudiencePolicyMutationOpSetContent, + Content: policy1ContentFmt, + }, + } + + cfg.DB.(*mockAudienceDatastore).On("MutateAudiencePolicy", mock.Anything, cmd.AudienceURL, cmd.PolicyName, expectedMuts).Return(nil) + + require.NoError(t, cmd.Run(cfg)) + + cmd.PolicyName = "not-existing-policy" + require.Error(t, cmd.Run(cfg)) +} + +func TestAudiencesDeletePolicyCmd(t *testing.T) { + cmd := audiencesDeletePolicyCmd{ + AudienceURL: "https://audience.url", + PolicyName: "policy1", + } + + cfg := &Config{ + DB: &mockAudienceDatastore{}, + } + + expectedMuts := []*hubauth.AudienceMutation{ + { + Op: hubauth.AudienceMutationDeletePolicy, + Policy: hubauth.BiscuitPolicy{ + Name: cmd.PolicyName, + }, + }, + } + + cfg.DB.(*mockAudienceDatastore).On("MutateAudience", mock.Anything, cmd.AudienceURL, expectedMuts).Return(nil) + + require.NoError(t, cmd.Run(cfg)) +} + +func TestAudiencesNewPolicyCmd(t *testing.T) { + cmd := &audiencesNewPolicyCmd{} + + r, w, err := os.Pipe() + require.NoError(t, err) + origStdout := os.Stdout + os.Stdout = w + + require.NoError(t, cmd.Run(&Config{})) + + os.Stdout = origStdout + + buf := make([]byte, 2048) + n, err := r.Read(buf) + require.NoError(t, err) + + policyTemplateFmt, err := policy.Format(strings.NewReader(policyTemplate)) + require.NoError(t, err) + + require.Equal(t, policyTemplateFmt, string(buf[:n])) + + policyFile, err := ioutil.TempFile(os.TempDir(), "testaudiencesnewpolicycmd-") + require.NoError(t, err) + defer func() { + policyFile.Close() + os.Remove(policyFile.Name()) + }() + + cmd.Filepath = policyFile.Name() + require.NoError(t, cmd.Run(&Config{})) + + out, err := ioutil.ReadFile(policyFile.Name()) + require.NoError(t, err) + require.Equal(t, policyTemplateFmt, string(out)) +} + +func TestAudiencesValidatePoliciesCmd(t *testing.T) { + testCases := []struct { + desc string + content string + expectValid bool + }{ + { + desc: "valid policy", + content: `policy "p1" {}`, + expectValid: true, + }, + { + desc: "invalid policy", + content: `policy {}`, + expectValid: false, + }, + { + desc: "empty", + content: ``, + expectValid: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.desc, func(t *testing.T) { + policyFile, err := ioutil.TempFile(os.TempDir(), "testaudiencesvalidatepolicycmd-") + require.NoError(t, err) + defer func() { + policyFile.Close() + os.Remove(policyFile.Name()) + }() + + _, err = policyFile.WriteString(tc.content) + require.NoError(t, err) + + cmd := &audiencesValidatePoliciesCmd{ + Filepath: policyFile.Name(), + } + + err = cmd.Run(&Config{}) + if tc.expectValid { + require.NoError(t, err) + } else { + require.Error(t, err) + } + }) + } +} + +func TestAudiencesDumpPoliciesCmd(t *testing.T) { + cmd := audiencesDumpPoliciesCmd{ + AudienceURL: "https://audience.url", + } + + cfg := &Config{ + DB: &mockAudienceDatastore{}, + } + + p1Content := "policy \"p1\" {}" + p2Content := "policy \"p2\" {}" + p3Content := "policy \"p3\" {}" + + audience := &hubauth.Audience{ + URL: cmd.AudienceURL, + Policies: []*hubauth.BiscuitPolicy{ + { + Name: "p1", + Content: p1Content, + }, + { + Name: "p2", + Content: p2Content, + }, + { + Name: "p3", + Content: p3Content, + }, + }, + } + + cfg.DB.(*mockAudienceDatastore).On("GetAudience", mock.Anything, cmd.AudienceURL).Return(audience, nil) + + r, w, err := os.Pipe() + require.NoError(t, err) + origStdout := os.Stdout + os.Stdout = w + + require.NoError(t, cmd.Run(cfg)) + + buf := make([]byte, 2048) + n, err := r.Read(buf) + require.NoError(t, err) + require.Equal(t, strings.Join([]string{p1Content, p2Content, p3Content}, "\n\n")+"\n", string(buf[:n])) + + cmd.PolicyNames = []string{"p1", "p3"} + require.NoError(t, cmd.Run(cfg)) + + os.Stdout = origStdout + + n, err = r.Read(buf) + require.NoError(t, err) + + expectedOut := strings.Join([]string{p1Content, p3Content}, "\n\n") + "\n" + require.Equal(t, expectedOut, string(buf[:n])) + + policyFile, err := ioutil.TempFile(os.TempDir(), "testaudiencesdumppolicycmd-") + require.NoError(t, err) + defer func() { + policyFile.Close() + os.Remove(policyFile.Name()) + }() + cmd.Filepath = policyFile.Name() + require.NoError(t, cmd.Run(cfg)) + + got, err := ioutil.ReadFile(policyFile.Name()) + require.NoError(t, err) + require.Equal(t, expectedOut, string(got)) +} diff --git a/pkg/cli/clients_test.go b/pkg/cli/clients_test.go index 687a7a8..3a14ef5 100644 --- a/pkg/cli/clients_test.go +++ b/pkg/cli/clients_test.go @@ -215,7 +215,7 @@ func TestClientUpdateCmd(t *testing.T) { } cfg := &Config{DB: &mockClientDatastore{}} expectedMutations := []*hubauth.ClientMutation{ - &hubauth.ClientMutation{ + { Op: hubauth.ClientMutationOpSetRefreshTokenExpiry, RefreshTokenExpiry: 5 * time.Minute, }, @@ -232,19 +232,19 @@ func TestClientUpdateCmd(t *testing.T) { } cfg := &Config{DB: &mockClientDatastore{}} expectedMutations := []*hubauth.ClientMutation{ - &hubauth.ClientMutation{ + { Op: hubauth.ClientMutationOpSetRefreshTokenExpiry, RefreshTokenExpiry: 5 * time.Minute, }, - &hubauth.ClientMutation{ + { Op: hubauth.ClientMutationOpAddRedirectURI, RedirectURI: "http://localhost:1234", }, - &hubauth.ClientMutation{ + { Op: hubauth.ClientMutationOpAddRedirectURI, RedirectURI: "http://localhost:5678", }, - &hubauth.ClientMutation{ + { Op: hubauth.ClientMutationOpDeleteRedirectURI, RedirectURI: "http://removed-domain:1234", }, diff --git a/pkg/datastore/audience.go b/pkg/datastore/audience.go index 4fad7c8..b8a2b6f 100644 --- a/pkg/datastore/audience.go +++ b/pkg/datastore/audience.go @@ -18,12 +18,18 @@ func buildAudience(c *hubauth.Audience) *audience { userGroups[i] = buildGoogleUserGroups(p) } + policies := make([]biscuitPolicy, len(c.Policies)) + for i, p := range c.Policies { + policies[i] = buildBiscuitPolicy(p) + } + return &audience{ Key: audienceKey(c.URL), Name: c.Name, Type: c.Type, ClientIDs: c.ClientIDs, UserGroups: userGroups, + Policies: policies, CreateTime: now, UpdateTime: now, } @@ -35,6 +41,7 @@ type audience struct { Type string ClientIDs []string UserGroups []googleUserGroups `datastore:",flatten"` + Policies []biscuitPolicy `datastore:",flatten"` CreateTime time.Time UpdateTime time.Time } @@ -47,32 +54,67 @@ func buildGoogleUserGroups(p *hubauth.GoogleUserGroups) googleUserGroups { } } +func buildBiscuitPolicy(p *hubauth.BiscuitPolicy) biscuitPolicy { + return biscuitPolicy{ + Name: p.Name, + Content: p.Content, + Groups: strings.Join(p.Groups, ","), + } +} + type googleUserGroups struct { Domain string APIUser string Groups string // datastore doesn't take nested lists, so encode by comma-separating } +type biscuitPolicy struct { + Name string + Content string + Groups string // datastore doesn't take nested lists, so encode by comma-separating +} + func (c *audience) Export() *hubauth.Audience { - userGroups := make([]*hubauth.GoogleUserGroups, len(c.UserGroups)) - for i, p := range c.UserGroups { - var grps []string - if p.Groups != "" { - grps = strings.Split(p.Groups, ",") + var userGroups []*hubauth.GoogleUserGroups + if len(c.UserGroups) > 0 { + userGroups = make([]*hubauth.GoogleUserGroups, len(c.UserGroups)) + for i, p := range c.UserGroups { + var grps []string + if p.Groups != "" { + grps = strings.Split(p.Groups, ",") + } + + userGroups[i] = &hubauth.GoogleUserGroups{ + Domain: p.Domain, + APIUser: p.APIUser, + Groups: grps, + } } + } + var policies []*hubauth.BiscuitPolicy + if len(c.Policies) > 0 { + policies = make([]*hubauth.BiscuitPolicy, len(c.Policies)) + for i, p := range c.Policies { + var grps []string + if p.Groups != "" { + grps = strings.Split(p.Groups, ",") + } - userGroups[i] = &hubauth.GoogleUserGroups{ - Domain: p.Domain, - APIUser: p.APIUser, - Groups: grps, + policies[i] = &hubauth.BiscuitPolicy{ + Name: p.Name, + Content: p.Content, + Groups: grps, + } } } + return &hubauth.Audience{ URL: c.Key.Name, Name: c.Name, Type: c.Type, ClientIDs: c.ClientIDs, UserGroups: userGroups, + Policies: policies, CreateTime: c.CreateTime, UpdateTime: c.UpdateTime, } @@ -172,6 +214,25 @@ func (s *service) MutateAudience(ctx context.Context, url string, mut []*hubauth } aud.Type = m.Type modified = true + case hubauth.AudienceMutationSetPolicy: + for i, p := range aud.Policies { + if p.Name == m.Policy.Name { + aud.Policies[i] = buildBiscuitPolicy(&m.Policy) + modified = true + continue outer + } + } + aud.Policies = append(aud.Policies, buildBiscuitPolicy(&m.Policy)) + modified = true + case hubauth.AudienceMutationDeletePolicy: + for i, p := range aud.Policies { + if p.Name != m.Policy.Name { + continue + } + aud.Policies[i] = aud.Policies[len(aud.Policies)-1] + aud.Policies = aud.Policies[:len(aud.Policies)-1] + modified = true + } default: return fmt.Errorf("datastore: unknown audience mutation op %s", m.Op) } @@ -272,6 +333,89 @@ func (s *service) MutateAudienceUserGroups(ctx context.Context, url string, doma return nil } +func (s *service) MutateAudiencePolicy(ctx context.Context, url string, policyName string, mut []*hubauth.AudiencePolicyMutation) error { + ctx, span := trace.StartSpan(ctx, "datastore.MutateAudiencePolicy") + span.AddAttributes( + trace.StringAttribute("audience_url", url), + trace.StringAttribute("audience_policy_name", policyName), + trace.Int64Attribute("audience_policy_mutation_count", int64(len(mut))), + ) + defer span.End() + + k := audienceKey(url) + _, err := s.db.RunInTransaction(ctx, func(tx *datastore.Transaction) error { + aud := &audience{} + if err := tx.Get(k, aud); err != nil { + if err == datastore.ErrNoSuchEntity { + err = hubauth.ErrNotFound + } + return fmt.Errorf("datastore: error fetching audience %s: %w", url, err) + } + + var policy *biscuitPolicy + for i := range aud.Policies { + if aud.Policies[i].Name == policyName { + policy = &aud.Policies[i] + break + } + } + if policy == nil { + return hubauth.ErrNotFound + } + + modified := false + outer: + for _, m := range mut { + switch m.Op { + case hubauth.AudiencePolicyMutationOpAddGroup: + var groups []string + if policy.Groups != "" { + groups = strings.Split(policy.Groups, ",") + } + for _, g := range groups { + if g == m.Group { + continue outer + } + } + policy.Groups = strings.Join(append(groups, m.Group), ",") + modified = true + case hubauth.AudiencePolicyMutationOpDeleteGroup: + var groups []string + if policy.Groups != "" { + groups = strings.Split(policy.Groups, ",") + } + for i, g := range groups { + if g != m.Group { + continue + } + groups[i] = groups[len(groups)-1] + groups = groups[:len(groups)-1] + } + policy.Groups = strings.Join(groups, ",") + modified = true + case hubauth.AudiencePolicyMutationOpSetContent: + if policy.Content == m.Content { + continue + } + policy.Content = m.Content + modified = true + default: + return fmt.Errorf("datastore: unknown audience policy mutation op %s", m.Op) + } + } + if !modified { + return nil + } + aud.UpdateTime = time.Now() + _, err := tx.Put(k, aud) + return err + }) + if err != nil { + return fmt.Errorf("datastore: error mutating audience %s: %w", url, err) + } + return nil +} + func (s *service) ListAudiences(ctx context.Context) ([]*hubauth.Audience, error) { ctx, span := trace.StartSpan(ctx, "datastore.ListAudiences") defer span.End() diff --git a/pkg/datastore/audience_test.go b/pkg/datastore/audience_test.go index 0c754bb..abfefa1 100644 --- a/pkg/datastore/audience_test.go +++ b/pkg/datastore/audience_test.go @@ -358,6 +358,105 @@ func TestAudienceMutate(t *testing.T) { Type: "new-type", }, }, + { + desc: "set new policy", + mut: []*hubauth.AudienceMutation{ + { + Op: hubauth.AudienceMutationSetPolicy, + Policy: hubauth.BiscuitPolicy{ + Name: "policy", + Content: "policy content", + Groups: []string{"grp1"}, + }, + }, + }, + before: &hubauth.Audience{ + Policies: nil, + }, + after: &hubauth.Audience{ + Policies: []*hubauth.BiscuitPolicy{ + { + Name: "policy", + Content: "policy content", + Groups: []string{"grp1"}, + }, + }, + }, + }, + { + desc: "set existing policy", + mut: []*hubauth.AudienceMutation{ + { + Op: hubauth.AudienceMutationSetPolicy, + Policy: hubauth.BiscuitPolicy{ + Name: "policy", + Content: "policy content", + Groups: []string{"grp1"}, + }, + }, + }, + before: &hubauth.Audience{ + Policies: []*hubauth.BiscuitPolicy{ + { + Name: "policy", + Content: "old policy content", + Groups: []string{"grpA", "grpB"}, + }, + { + Name: "another_policy", + Content: "another policy content", + }, + }, + }, + after: &hubauth.Audience{ + Policies: []*hubauth.BiscuitPolicy{ + { + Name: "policy", + Content: "policy content", + Groups: []string{"grp1"}, + }, + { + Name: "another_policy", + Content: "another policy content", + }, + }, + }, + }, + { + desc: "delete policy", + mut: []*hubauth.AudienceMutation{ + { + Op: hubauth.AudienceMutationDeletePolicy, + Policy: hubauth.BiscuitPolicy{ + Name: "policy", + Content: "policy content", + Groups: []string{"grp1"}, + }, + }, + }, + before: &hubauth.Audience{ + Policies: []*hubauth.BiscuitPolicy{ + { + Name: "policy", + Content: "old policy content", + Groups: []string{"grpA", "grpB"}, + }, + { + Name: "another_policy", + Content: "another policy content", + }, + }, + }, + after: &hubauth.Audience{ + Policies: []*hubauth.BiscuitPolicy{ + { + Name: "another_policy", + Content: "another policy content", + }, + }, + }, + }, + { desc: "multiple", mut: []*hubauth.AudienceMutation{ @@ -387,28 +486,30 @@ func TestAudienceMutate(t *testing.T) { ctx := context.Background() url := "https://cluster.mutate.example.com" for _, tt := range tests { - tt.before.URL = url - tt.after.URL = url - err := s.CreateAudience(ctx, tt.before) - require.NoError(t, err, tt.desc) - before, err := s.GetAudience(ctx, url) - require.NoError(t, err) - - err = s.MutateAudience(ctx, url, tt.mut) - require.NoError(t, err, tt.desc) - - res, err := s.GetAudience(ctx, url) - require.NoError(t, err, tt.desc) - if len(res.UserGroups) == 0 { - res.UserGroups = nil - } - require.Equal(t, before.CreateTime, res.CreateTime) - - res.CreateTime = time.Time{} - res.UpdateTime = time.Time{} - require.Equal(t, tt.after, res, tt.desc) - - s.DeleteAudience(ctx, url) + t.Run(tt.desc, func(t *testing.T) { + tt.before.URL = url + tt.after.URL = url + err := s.CreateAudience(ctx, tt.before) + require.NoError(t, err, tt.desc) + before, err := s.GetAudience(ctx, url) + require.NoError(t, err) + + err = s.MutateAudience(ctx, url, tt.mut) + require.NoError(t, err, tt.desc) + + res, err := s.GetAudience(ctx, url) + require.NoError(t, err, tt.desc) + if len(res.UserGroups) == 0 { + res.UserGroups = nil + } + require.Equal(t, before.CreateTime, res.CreateTime) + + res.CreateTime = time.Time{} + res.UpdateTime = time.Time{} + require.Equal(t, tt.after, res, tt.desc) + + s.DeleteAudience(ctx, url) + }) } } @@ -624,3 +725,170 @@ func TestMutateAudienceUserGroups(t *testing.T) { s.DeleteAudience(ctx, aud.URL) } } + +func TestMutateAudiencePolicy(t *testing.T) { + policyName := "policy_name" + type test struct { + desc string + mut []*hubauth.AudiencePolicyMutation + before []*hubauth.BiscuitPolicy + after []*hubauth.BiscuitPolicy + } + tests := []test{ + { + desc: "add single group", + mut: []*hubauth.AudiencePolicyMutation{ + { + Op: hubauth.AudiencePolicyMutationOpAddGroup, + Group: "grp1", + }, + }, + before: []*hubauth.BiscuitPolicy{ + { + Name: policyName, + Groups: nil, + }, + }, + after: []*hubauth.BiscuitPolicy{ + { + Name: policyName, + Groups: []string{"grp1"}, + }, + }, + }, + { + desc: "add multiple groups", + mut: []*hubauth.AudiencePolicyMutation{ + { + Op: hubauth.AudiencePolicyMutationOpAddGroup, + Group: "existing", + }, + { + Op: hubauth.AudiencePolicyMutationOpAddGroup, + Group: "grp1", + }, + { + Op: hubauth.AudiencePolicyMutationOpAddGroup, + Group: "grp2", + }, + }, + before: []*hubauth.BiscuitPolicy{ + { + Name: policyName, + Groups: []string{"existing"}, + }, + }, + after: []*hubauth.BiscuitPolicy{ + { + Name: policyName, + Groups: []string{"existing", "grp1", "grp2"}, + }, + }, + }, + { + desc: "delete last group", + mut: []*hubauth.AudiencePolicyMutation{ + { + Op: hubauth.AudiencePolicyMutationOpDeleteGroup, + Group: "grp1", + }, + }, + before: []*hubauth.BiscuitPolicy{ + { + Name: policyName, + Groups: []string{"grp1"}, + }, + }, + after: []*hubauth.BiscuitPolicy{ + { + Name: policyName, + Groups: nil, + }, + }, + }, + { + desc: "delete multiple groups", + mut: []*hubauth.AudiencePolicyMutation{ + { + Op: hubauth.AudiencePolicyMutationOpDeleteGroup, + Group: "grp1", + }, + { + Op: hubauth.AudiencePolicyMutationOpDeleteGroup, + Group: "grp2", + }, + }, + before: []*hubauth.BiscuitPolicy{ + { + Name: policyName, + Groups: []string{"existing", "grp1", "grp2"}, + }, + }, + after: []*hubauth.BiscuitPolicy{ + { + Name: policyName, + Groups: []string{"existing"}, + }, + }, + }, + { + desc: "set content", + mut: []*hubauth.AudiencePolicyMutation{ + { + Op: hubauth.AudiencePolicyMutationOpSetContent, + Content: "new content", + }, + }, + before: []*hubauth.BiscuitPolicy{ + { + Name: policyName, + Content: "", + }, + }, + after: []*hubauth.BiscuitPolicy{ + { + Name: policyName, + Content: "new content", + }, + }, + }, + } + + s := newTestService(t) + ctx := context.Background() + for _, tt := range tests { + t.Run(tt.desc, func(t *testing.T) { + aud := &hubauth.Audience{ + URL: "https://cluster.mutate.example.com", + Policies: tt.before, + } + + err := s.CreateAudience(ctx, aud) + require.NoError(t, err, tt.desc) + before, err := s.GetAudience(ctx, aud.URL) + require.NoError(t, err) + + err = s.MutateAudiencePolicy(ctx, aud.URL, policyName, tt.mut) + require.NoError(t, err, tt.desc) + + res, err := s.GetAudience(ctx, aud.URL) + require.NoError(t, err, tt.desc) + if len(res.UserGroups) == 0 { + res.UserGroups = nil + } + require.Equal(t, before.CreateTime, res.CreateTime) + + // sort to ensure consistent slice comparison + for _, p := range res.Policies { + sort.Strings(p.Groups) + } + for _, p := range tt.after { + sort.Strings(p.Groups) + } + + require.Equal(t, tt.after, res.Policies, tt.desc) + + s.DeleteAudience(ctx, aud.URL) + }) + } +} diff --git a/pkg/hubauth/data.go b/pkg/hubauth/data.go index dd5752f..b97dc42 100644 --- a/pkg/hubauth/data.go +++ b/pkg/hubauth/data.go @@ -57,11 +57,16 @@ type ClientMutation struct { RefreshTokenExpiry time.Duration } -type AudienceStore interface { +type AudienceGetterStore interface { GetAudience(ctx context.Context, url string) (*Audience, error) +} + +type AudienceStore interface { + AudienceGetterStore CreateAudience(ctx context.Context, audience *Audience) error MutateAudience(ctx context.Context, url string, mut []*AudienceMutation) error MutateAudienceUserGroups(ctx context.Context, url string, domain string, mut []*AudienceUserGroupsMutation) error + MutateAudiencePolicy(ctx context.Context, url string, policyName string, mut []*AudiencePolicyMutation) error ListAudiencesForClient(ctx context.Context, clientID string) ([]*Audience, error) ListAudiences(ctx context.Context) ([]*Audience, error) DeleteAudience(ctx context.Context, url string) error @@ -73,6 +78,7 @@ type Audience struct { Type string `json:"type"` ClientIDs []string `json:"-"` UserGroups []*GoogleUserGroups `json:"-"` + Policies []*BiscuitPolicy `json:"-"` CreateTime time.Time `json:"-"` UpdateTime time.Time `json:"-"` } @@ -83,6 +89,12 @@ type GoogleUserGroups struct { Groups []string } +type BiscuitPolicy struct { + Name string + Content string + Groups []string +} + type AudienceMutationOp byte const ( @@ -91,6 +103,8 @@ const ( AudienceMutationOpSetUserGroups AudienceMutationOpDeleteUserGroups AudienceMutationSetType + AudienceMutationSetPolicy + AudienceMutationDeletePolicy ) type AudienceMutation struct { @@ -99,6 +113,7 @@ type AudienceMutation struct { ClientID string Type string UserGroups GoogleUserGroups + Policy BiscuitPolicy } type AudienceUserGroupsMutationOp byte @@ -116,6 +131,21 @@ type AudienceUserGroupsMutation struct { Group string } +type AudiencePolicyMutationOp byte + +const ( + AudiencePolicyMutationOpAddGroup AudiencePolicyMutationOp = iota + AudiencePolicyMutationOpDeleteGroup + AudiencePolicyMutationOpSetContent +) + +type AudiencePolicyMutation struct { + Op AudiencePolicyMutationOp + + Content string + Group string +} + type CodeStore interface { GetCode(ctx context.Context, id string) (*Code, error) VerifyAndDeleteCode(ctx context.Context, id, secret string) (*Code, error) diff --git a/pkg/idp/oauth.go b/pkg/idp/oauth.go index 33a9102..b329095 100644 --- a/pkg/idp/oauth.go +++ b/pkg/idp/oauth.go @@ -35,7 +35,7 @@ func (clockImpl) Now() time.Time { } type idpSteps interface { - VerifyAudience(ctx context.Context, audienceURL, clientID, userID string) error + VerifyAudience(ctx context.Context, audienceURL, clientID, userID string) ([]string, error) VerifyUserGroups(ctx context.Context, userID string) error CreateCode(ctx context.Context, code *hubauth.Code) (string, string, error) @@ -277,8 +277,10 @@ func (s *idpService) ExchangeCode(parentCtx context.Context, req *hubauth.Exchan return err }) - g.Go(func() error { - return s.steps.VerifyAudience(ctx, req.Audience, req.ClientID, codeInfo.UserId) + var userGroups []string + g.Go(func() (err error) { + userGroups, err = s.steps.VerifyAudience(ctx, req.Audience, req.ClientID, codeInfo.UserId) + return err }) var client *hubauth.Client @@ -303,35 +305,35 @@ func (s *idpService) ExchangeCode(parentCtx context.Context, req *hubauth.Exchan return err }) + if err := g.Wait(); err != nil { + return nil, err + } + + // build access token var accessToken string var tokenType string - g.Go(func() (err error) { - if req.Audience == "" { - return nil - } - + if req.Audience != "" { var userPublicKey []byte if len(req.UserPublicKey) > 0 { var err error userPublicKey, err = base64Decode(req.UserPublicKey) if err != nil { - return fmt.Errorf("idp: invalid public key: %v", err) + return nil, fmt.Errorf("idp: invalid public key: %v", err) } } - accessToken, tokenType, err = s.steps.BuildAccessToken(ctx, req.Audience, &token.AccessTokenData{ + accessToken, tokenType, err = s.steps.BuildAccessToken(parentCtx, req.Audience, &token.AccessTokenData{ ClientID: req.ClientID, UserID: codeInfo.UserId, UserEmail: codeInfo.UserEmail, UserPublicKey: userPublicKey, + UserGroups: userGroups, IssueTime: now, ExpireTime: now.Add(accessTokenDuration), }) - return err - }) - - if err := g.Wait(); err != nil { - return nil, err + if err != nil { + return nil, err + } } res := &hubauth.AccessToken{ @@ -343,6 +345,7 @@ func (s *idpService) ExchangeCode(parentCtx context.Context, req *hubauth.Exchan RefreshTokenExpiresIn: int(client.RefreshTokenExpiry / time.Second), RefreshTokenIssueTime: now, } + if res.AccessToken == "" { // if no audience was provided, provide a refresh token that can be used to to access /audiences res.TokenType = "RefreshToken" @@ -355,16 +358,18 @@ func (s *idpService) ExchangeCode(parentCtx context.Context, req *hubauth.Exchan return res, nil } -func (s *idpService) RefreshToken(ctx context.Context, req *hubauth.RefreshTokenRequest) (*hubauth.AccessToken, error) { - oldToken, err := s.decodeRefreshToken(ctx, req.RefreshToken) +func (s *idpService) RefreshToken(parentCtx context.Context, req *hubauth.RefreshTokenRequest) (*hubauth.AccessToken, error) { + oldToken, err := s.decodeRefreshToken(parentCtx, req.RefreshToken) if err != nil { return nil, err } - g, ctx := errgroup.WithContext(ctx) + g, ctx := errgroup.WithContext(parentCtx) - g.Go(func() error { - return s.steps.VerifyAudience(ctx, req.Audience, req.ClientID, oldToken.UserID) + var userGroups []string + g.Go(func() (err error) { + userGroups, err = s.steps.VerifyAudience(ctx, req.Audience, req.ClientID, oldToken.UserID) + return err }) now := s.clock.Now() @@ -390,35 +395,34 @@ func (s *idpService) RefreshToken(ctx context.Context, req *hubauth.RefreshToken return err }) + if err := g.Wait(); err != nil { + return nil, err + } + var accessToken string var tokenType string - g.Go(func() (err error) { - if req.Audience == "" { - return nil - } - + if req.Audience != "" { var userPublicKey []byte if len(req.UserPublicKey) > 0 { var err error userPublicKey, err = base64Decode(req.UserPublicKey) if err != nil { - return fmt.Errorf("idp: invalid public key: %v", err) + return nil, fmt.Errorf("idp: invalid public key: %v", err) } } - accessToken, tokenType, err = s.steps.BuildAccessToken(ctx, req.Audience, &token.AccessTokenData{ + accessToken, tokenType, err = s.steps.BuildAccessToken(parentCtx, req.Audience, &token.AccessTokenData{ ClientID: req.ClientID, UserID: oldToken.UserID, UserEmail: oldToken.UserEmail, + UserGroups: userGroups, UserPublicKey: userPublicKey, IssueTime: now, ExpireTime: now.Add(accessTokenDuration), }) - return err - }) - - if err := g.Wait(); err != nil { - return nil, err + if err != nil { + return nil, err + } } res := &hubauth.AccessToken{ diff --git a/pkg/idp/oauth_test.go b/pkg/idp/oauth_test.go index 1722aec..48bef2b 100644 --- a/pkg/idp/oauth_test.go +++ b/pkg/idp/oauth_test.go @@ -58,9 +58,9 @@ func (m *mockSteps) SignCode(ctx context.Context, signKey hmacpb.Key, code *sign args := m.Called(ctx, signKey, code) return args.String(0), args.Error(1) } -func (m *mockSteps) VerifyAudience(ctx context.Context, audienceURL, clientID, userID string) error { +func (m *mockSteps) VerifyAudience(ctx context.Context, audienceURL, clientID, userID string) ([]string, error) { args := m.Called(ctx, audienceURL, clientID, userID) - return args.Error(0) + return args.Get(0).([]string), args.Error(1) } func (m *mockSteps) VerifyUserGroups(ctx context.Context, userID string) error { args := m.Called(ctx, userID) @@ -653,9 +653,11 @@ func TestExchangeCode(t *testing.T) { ExpiryTime: now.Add(refreshTokenExpiry), } + userGroups := []string{"grp1", "grp2"} + idpService.clock.(*mockClock).On("Now").Return(now) idpService.steps.(*mockSteps).On("AllocateRefreshToken", mock.Anything, clientID).Return(rtID, nil) - idpService.steps.(*mockSteps).On("VerifyAudience", mock.Anything, testCase.AudienceURL, clientID, userID).Return(nil) + idpService.steps.(*mockSteps).On("VerifyAudience", mock.Anything, testCase.AudienceURL, clientID, userID).Return(userGroups, nil) idpService.steps.(*mockSteps).On("VerifyCode", mock.Anything, &verifyCodeData{ ClientID: clientID, RedirectURI: redirectURI, @@ -669,6 +671,7 @@ func TestExchangeCode(t *testing.T) { ClientID: clientID, UserID: userID, UserEmail: userEmail, + UserGroups: userGroups, IssueTime: now, ExpireTime: now.Add(accessTokenDuration), }).Return(accessToken, testCase.Want.TokenType, nil) @@ -799,7 +802,7 @@ func TestExchangeCodeErrors(t *testing.T) { idpService.clock.(*mockClock).On("Now").Return(now) idpService.steps.(*mockSteps).On("AllocateRefreshToken", mock.Anything, mock.Anything).Return("", testCase.AllocateErr) idpService.steps.(*mockSteps).On("VerifyCode", mock.Anything, mock.Anything).Return(&hubauth.Code{}, testCase.VerifyCodeErr) - idpService.steps.(*mockSteps).On("VerifyAudience", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(testCase.VerifyAudienceErr) + idpService.steps.(*mockSteps).On("VerifyAudience", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]string{}, testCase.VerifyAudienceErr) idpService.steps.(*mockSteps).On("SaveRefreshToken", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&hubauth.Client{}, testCase.SaveErr) idpService.steps.(*mockSteps).On("SignRefreshToken", mock.Anything, mock.Anything, mock.Anything).Return("", testCase.SignRTErr) idpService.steps.(*mockSteps).On("BuildAccessToken", mock.Anything, mock.Anything, mock.Anything).Return("", "", testCase.SignATErr) @@ -879,12 +882,14 @@ func TestRefreshToken(t *testing.T) { }, } + userGroups := []string{"grp1", "grp2"} + for _, testCase := range testCases { t.Run(testCase.Desc, func(t *testing.T) { idpService := newTestIdPService(t) idpService.clock.(*mockClock).On("Now").Return(now) - idpService.steps.(*mockSteps).On("VerifyAudience", mock.Anything, testCase.AudienceURL, b64ClientID, userID).Return(nil) + idpService.steps.(*mockSteps).On("VerifyAudience", mock.Anything, testCase.AudienceURL, b64ClientID, userID).Return(userGroups, nil) idpService.steps.(*mockSteps).On("RenewRefreshToken", mock.Anything, b64ClientID, b64OldTokenID, issueTimeFromProto, now).Return(newRefreshToken, nil) idpService.steps.(*mockSteps).On("SignRefreshToken", mock.Anything, idpService.refreshKey, &signedRefreshTokenData{ refreshTokenData: &refreshTokenData{ @@ -900,6 +905,7 @@ func TestRefreshToken(t *testing.T) { ClientID: b64ClientID, UserID: userID, UserEmail: userEmail, + UserGroups: userGroups, IssueTime: now, ExpireTime: now.Add(accessTokenDuration), }).Return(newAccessTokenStr, testCase.Want.TokenType, nil) @@ -1002,7 +1008,7 @@ func TestRefreshTokenStepErrors(t *testing.T) { } idpService.clock.(*mockClock).On("Now").Return(now) - idpService.steps.(*mockSteps).On("VerifyAudience", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(testCase.VerifyAudienceErr) + idpService.steps.(*mockSteps).On("VerifyAudience", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return([]string{}, testCase.VerifyAudienceErr) idpService.steps.(*mockSteps).On("RenewRefreshToken", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(&hubauth.RefreshToken{}, testCase.RenewRTErr) idpService.steps.(*mockSteps).On("SignRefreshToken", mock.Anything, mock.Anything, mock.Anything).Return("", testCase.SignRTErr) idpService.steps.(*mockSteps).On("BuildAccessToken", mock.Anything, mock.Anything, mock.Anything).Return("", "", testCase.SignATErr) diff --git a/pkg/idp/steps.go b/pkg/idp/steps.go index db08163..c5d3620 100644 --- a/pkg/idp/steps.go +++ b/pkg/idp/steps.go @@ -122,19 +122,23 @@ func (s *steps) SignCode(ctx context.Context, signKey hmacpb.Key, code *signCode return base64Encode(res), nil } -func (s *steps) VerifyAudience(ctx context.Context, audienceURL, clientID, userID string) error { +// VerifyAudience ensure the user can access the audience by verifying +// that the user have at least one group belonging to the audience policies groups +// It returns the list of user groups, or an error when the user is not allowed to access this audience. +// When no audience is provided, no group and no error is returned, +func (s *steps) VerifyAudience(ctx context.Context, audienceURL, clientID, userID string) ([]string, error) { if audienceURL == "" { - return nil + return nil, nil } audience, err := s.db.GetAudience(ctx, audienceURL) if err != nil { if errors.Is(err, hubauth.ErrNotFound) { - return &hubauth.OAuthError{ + return nil, &hubauth.OAuthError{ Code: "invalid_request", Description: "unknown audience", } } - return fmt.Errorf("idp: error getting audience %s: %w", audienceURL, err) + return nil, fmt.Errorf("idp: error getting audience %s: %w", audienceURL, err) } foundClient := false for _, c := range audience.ClientIDs { @@ -145,20 +149,20 @@ func (s *steps) VerifyAudience(ctx context.Context, audienceURL, clientID, userI } if !foundClient { clog.Set(ctx, zap.Strings("audience_client_ids", audience.ClientIDs)) - return &hubauth.OAuthError{ + return nil, &hubauth.OAuthError{ Code: "invalid_client", Description: "unknown client for audience", } } - err = s.checkUser(ctx, audience, userID) + userGroups, err := s.checkUser(ctx, audience, userID) if errors.Is(err, hubauth.ErrUnauthorizedUser) { - return &hubauth.OAuthError{ + return nil, &hubauth.OAuthError{ Code: "access_denied", Description: "user is not authorized for access", } } - return err + return userGroups, err } func (s *steps) VerifyUserGroups(ctx context.Context, userID string) error { @@ -175,10 +179,10 @@ func (s *steps) VerifyUserGroups(ctx context.Context, userID string) error { return nil } -func (s *steps) checkUser(ctx context.Context, cluster *hubauth.Audience, userID string) error { +func (s *steps) checkUser(ctx context.Context, cluster *hubauth.Audience, userID string) ([]string, error) { groups, err := s.db.GetCachedMemberGroups(ctx, userID) if err != nil { - return fmt.Errorf("idp: error getting cached groups for user: %w", err) + return nil, fmt.Errorf("idp: error getting cached groups for user: %w", err) } // TODO: log allowed groups and cached groups @@ -196,9 +200,9 @@ outer: } } if !allowed { - return hubauth.ErrUnauthorizedUser + return nil, hubauth.ErrUnauthorizedUser } - return nil + return groups, nil } type refreshTokenData struct { diff --git a/pkg/idp/steps_test.go b/pkg/idp/steps_test.go index 9c02bbf..c950493 100644 --- a/pkg/idp/steps_test.go +++ b/pkg/idp/steps_test.go @@ -312,11 +312,12 @@ func TestVerifyAudience(t *testing.T) { require.NoError(t, err) testCases := []struct { - Desc string - Err error - AudienceURL string - ClientID string - UserID string + Desc string + Err error + ExpectedGroups []string + AudienceURL string + ClientID string + UserID string }{ { Desc: "no audience does nothing", @@ -350,21 +351,23 @@ func TestVerifyAudience(t *testing.T) { }, }, { - Desc: "all valid no error", - AudienceURL: validAudienceURL, - ClientID: validClientID, - UserID: validUserID, - Err: nil, + Desc: "all valid no error", + AudienceURL: validAudienceURL, + ExpectedGroups: []string{validGroupID}, + ClientID: validClientID, + UserID: validUserID, + Err: nil, }, } for _, testCase := range testCases { t.Run(testCase.Desc, func(t *testing.T) { - err := s.VerifyAudience(context.Background(), testCase.AudienceURL, testCase.ClientID, testCase.UserID) + grps, err := s.VerifyAudience(context.Background(), testCase.AudienceURL, testCase.ClientID, testCase.UserID) if testCase.Err != nil { require.Equal(t, testCase.Err, err) } else { require.NoError(t, err) + require.Equal(t, testCase.ExpectedGroups, grps) } }) } diff --git a/pkg/idp/token/biscuit.go b/pkg/idp/token/biscuit.go index 9c997be..a4857b0 100644 --- a/pkg/idp/token/biscuit.go +++ b/pkg/idp/token/biscuit.go @@ -6,10 +6,14 @@ import ( "encoding/base64" "errors" "fmt" + "strings" + "github.com/flynn/biscuit-go" "github.com/flynn/biscuit-go/cookbook/signedbiscuit" "github.com/flynn/biscuit-go/sig" + "github.com/flynn/hubauth/pkg/hubauth" "github.com/flynn/hubauth/pkg/kmssign" + "github.com/flynn/hubauth/pkg/policy" ) var ( @@ -18,13 +22,15 @@ var ( type biscuitBuilder struct { kms kmssign.KMSClient + db hubauth.AudienceGetterStore audienceKey kmssign.AudienceKeyNamer rootKeyPair sig.Keypair } -func NewBiscuitBuilder(kms kmssign.KMSClient, audienceKey kmssign.AudienceKeyNamer, rootKeyPair sig.Keypair) AccessTokenBuilder { +func NewBiscuitBuilder(kms kmssign.KMSClient, db hubauth.AudienceGetterStore, audienceKey kmssign.AudienceKeyNamer, rootKeyPair sig.Keypair) AccessTokenBuilder { return &biscuitBuilder{ kms: kms, + db: db, audienceKey: audienceKey, rootKeyPair: rootKeyPair, } @@ -37,13 +43,37 @@ func (b *biscuitBuilder) Build(ctx context.Context, audience string, t *AccessTo audienceKey := kmssign.NewPrivateKey(b.kms, b.audienceKey(audience), crypto.SHA256) meta := &signedbiscuit.Metadata{ - ClientID: t.ClientID, - UserID: t.UserID, - UserEmail: t.UserEmail, - IssueTime: t.IssueTime, + ClientID: t.ClientID, + UserID: t.UserID, + UserEmail: t.UserEmail, + UserGroups: t.UserGroups, + IssueTime: t.IssueTime, } - return signedbiscuit.GenerateSignable(b.rootKeyPair, audience, audienceKey, t.UserPublicKey, t.ExpireTime, meta) + builder := biscuit.NewBuilder(b.rootKeyPair) + builder, err := signedbiscuit.WithSignableFacts(builder, audience, audienceKey, t.UserPublicKey, t.ExpireTime, meta) + if err != nil { + return nil, err + } + + // retrieve policies from user groups and add each policy rules and caveats to the biscuit + userPolicies, err := b.getUserPolicies(ctx, audience, t.UserGroups) + if err != nil { + return nil, err + } + + for _, p := range userPolicies { + builder, err = withPolicy(builder, p) + if err != nil { + return nil, err + } + } + + bisc, err := builder.Build() + if err != nil { + return nil, err + } + return bisc.Serialize() } func (b *biscuitBuilder) TokenType() string { @@ -63,3 +93,50 @@ func DecodeB64PrivateKey(b64key string) (sig.Keypair, error) { kp = sig.NewKeypair(rootPrivateKey) return kp, nil } + +func (b *biscuitBuilder) getUserPolicies(ctx context.Context, audience string, userGroups []string) ([]*hubauth.BiscuitPolicy, error) { + aud, err := b.db.GetAudience(ctx, audience) + if err != nil { + return nil, err + } + + var userPolicies []*hubauth.BiscuitPolicy + for _, p := range aud.Policies { + outer: + for _, g := range p.Groups { + for _, ug := range userGroups { + if g == ug { + userPolicies = append(userPolicies, p) + continue outer + } + } + } + } + return userPolicies, nil +} + +func withPolicy(builder biscuit.Builder, p *hubauth.BiscuitPolicy) (biscuit.Builder, error) { + parsed, err := policy.ParseDocumentPolicy(strings.NewReader(p.Content)) + if err != nil { + return nil, err + } + for _, rule := range parsed.Rules { + biscuitRule, err := rule.ToBiscuit() + if err != nil { + return nil, err + } + if err := builder.AddAuthorityRule(*biscuitRule); err != nil { + return nil, err + } + } + for _, caveat := range parsed.Caveats { + biscuitCaveat, err := caveat.ToBiscuit() + if err != nil { + return nil, err + } + if err := builder.AddAuthorityCaveat(*biscuitCaveat); err != nil { + return nil, err + } + } + return builder, nil +} diff --git a/pkg/idp/token/biscuit_test.go b/pkg/idp/token/biscuit_test.go index 41b705f..0129b59 100644 --- a/pkg/idp/token/biscuit_test.go +++ b/pkg/idp/token/biscuit_test.go @@ -7,22 +7,40 @@ import ( "crypto/rand" "crypto/x509" "encoding/base64" + "encoding/pem" "testing" "time" "github.com/flynn/biscuit-go" + "github.com/flynn/biscuit-go/cookbook/signedbiscuit" "github.com/flynn/biscuit-go/sig" + "github.com/flynn/hubauth/pkg/hubauth" "github.com/flynn/hubauth/pkg/kmssign/kmssim" + "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "google.golang.org/genproto/googleapis/cloud/kms/v1" ) +type mockAudienceGetterStore struct { + mock.Mock +} + +func (m *mockAudienceGetterStore) GetAudience(ctx context.Context, url string) (*hubauth.Audience, error) { + args := m.Called(ctx, url) + return args.Get(0).(*hubauth.Audience), args.Error(1) +} + +var _ hubauth.AudienceGetterStore = (*mockAudienceGetterStore)(nil) + func TestBiscuitBuilder(t *testing.T) { audience := "https://audience.url" audienceKeyName := audienceKeyNamer(audience) kmsClient := kmssim.NewClient([]string{audienceKeyName}) rootKeyPair := sig.GenerateKeypair(rand.Reader) - builder := NewBiscuitBuilder(kmsClient, audienceKeyNamer, rootKeyPair) + audienceGetterStore := new(mockAudienceGetterStore) + + builder := NewBiscuitBuilder(kmsClient, audienceGetterStore, audienceKeyNamer, rootKeyPair) priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) require.NoError(t, err) @@ -30,26 +48,114 @@ func TestBiscuitBuilder(t *testing.T) { require.NoError(t, err) now := time.Now() + userGroups := []string{"grp1", "grp2"} + accessTokenData := &AccessTokenData{ ClientID: "clientID", ExpireTime: now.Add(1 * time.Minute), IssueTime: now, UserEmail: "user@email", + UserGroups: userGroups, UserID: "userID", } _, err = builder.Build(context.Background(), audience, accessTokenData) require.Equal(t, ErrPublicKeyRequired, err) accessTokenData.UserPublicKey = userPublicKey + + p1Content := ` + policy "p1" { + caveats {[ + *valid() <- test(#ambient, "policy1exists") + ]} + } + ` + + p2Content := ` + policy "p2" { + rules { + *test(#authority, $inputStr) + <- testRule(#ambient, $inputStr) + } + caveats {[ + *valid() <- test(#authority, "policy2exists") + ]} + } + ` + + p3Content := ` + policy "p3" { + caveats {[ + *valid() <- test(#ambient, "policy3exists") + ]} + } + ` + + aud := &hubauth.Audience{ + URL: audience, + Policies: []*hubauth.BiscuitPolicy{ + { + Name: "p1", + Content: p1Content, + Groups: []string{"grp1"}, + }, + { + Name: "p2", + Content: p2Content, + Groups: []string{"grp2", "grp3"}, + }, + { + Name: "p3", + Content: p3Content, + Groups: []string{"grp3"}, + }, + }, + } + audienceGetterStore.On("GetAudience", mock.Anything, audience).Return(aud, nil) + token, err := builder.Build(context.Background(), audience, accessTokenData) require.NoError(t, err) require.NotEmpty(t, token) + userKeyPair, err := signedbiscuit.NewECDSAKeyPair(priv) + require.NoError(t, err) + token, err = signedbiscuit.Sign(token, rootKeyPair.Public(), userKeyPair) + require.NoError(t, err) + b, err := biscuit.Unmarshal(token) require.NoError(t, err) - _, err = b.Verify(rootKeyPair.Public()) + verifier, err := b.Verify(rootKeyPair.Public()) + require.NoError(t, err) + + kmsPubkey, err := kmsClient.GetPublicKey(context.Background(), &kms.GetPublicKeyRequest{Name: audienceKeyName}) + require.NoError(t, err) + pemBlock, _ := pem.Decode([]byte(kmsPubkey.Pem)) + audiencePubKey, err := x509.ParsePKIXPublicKey(pemBlock.Bytes) + require.NoError(t, err) + + verifier, metas, err := signedbiscuit.WithSignatureVerification(verifier, audience, audiencePubKey.(*ecdsa.PublicKey)) require.NoError(t, err) + + require.Equal(t, accessTokenData.ClientID, metas.ClientID) + require.Equal(t, accessTokenData.UserEmail, metas.UserEmail) + require.Equal(t, accessTokenData.UserGroups, metas.UserGroups) + require.Equal(t, accessTokenData.UserID, metas.UserID) + require.Equal(t, accessTokenData.IssueTime.Unix(), metas.IssueTime.Unix()) + + require.Error(t, verifier.Verify()) + + verifier.AddFact(biscuit.Fact{Predicate: biscuit.Predicate{ + Name: "test", + IDs: []biscuit.Atom{biscuit.Symbol("ambient"), biscuit.String("policy1exists")}, + }}) + require.Error(t, verifier.Verify()) + + verifier.AddFact(biscuit.Fact{Predicate: biscuit.Predicate{ + Name: "testRule", + IDs: []biscuit.Atom{biscuit.Symbol("ambient"), biscuit.String("policy2exists")}, + }}) + require.NoError(t, verifier.Verify()) } func TestDecodeB64PrivateKey(t *testing.T) { diff --git a/pkg/idp/token/builder.go b/pkg/idp/token/builder.go index 0da8c9a..e6dd403 100644 --- a/pkg/idp/token/builder.go +++ b/pkg/idp/token/builder.go @@ -9,6 +9,7 @@ type AccessTokenData struct { ClientID string UserID string UserEmail string + UserGroups []string UserPublicKey []byte IssueTime time.Time ExpireTime time.Time diff --git a/pkg/policy/parser.go b/pkg/policy/parser.go new file mode 100644 index 0000000..b3886aa --- /dev/null +++ b/pkg/policy/parser.go @@ -0,0 +1,90 @@ +package policy + +import ( + "fmt" + "io" + + "github.com/alecthomas/participle/v2" + "github.com/alecthomas/participle/v2/lexer/stateful" + "github.com/flynn/biscuit-go" + "github.com/flynn/biscuit-go/parser" +) + +var defaultParserOptions = append(parser.DefaultParserOptions, participle.Lexer(policyLexer)) + +var policyLexer = stateful.MustSimple(append( + parser.BiscuitLexerRules, + stateful.Rule{Name: "Policy", Pattern: `policy`}, +)) + +type Document struct { + Policies []*DocumentPolicy `@@+` +} + +type DocumentPolicy struct { + Comments []*parser.Comment `@Comment*` + Name *string `"policy" @String "{"` + Rules []*parser.Rule `("rules" "{" @@* "}")?` + Caveats []*parser.Caveat `("caveats" "{" (@@ ("," @@+)*)* "}")? "}"` +} + +func (d *DocumentPolicy) BiscuitRules() ([]biscuit.Rule, error) { + rules := make([]biscuit.Rule, 0, len(d.Rules)) + for _, r := range d.Rules { + rule, err := r.ToBiscuit() + if err != nil { + return nil, err + } + rules = append(rules, *rule) + } + return rules, nil +} + +func (d *DocumentPolicy) BiscuitCaveats() ([]biscuit.Caveat, error) { + caveats := make([]biscuit.Caveat, 0, len(d.Caveats)) + for _, c := range d.Caveats { + caveat, err := c.ToBiscuit() + if err != nil { + return nil, err + } + + caveats = append(caveats, *caveat) + } + + return caveats, nil +} + +var documentParser = participle.MustBuild(&Document{}, defaultParserOptions...) +var documentPolicyParser = participle.MustBuild(&DocumentPolicy{}, defaultParserOptions...) + +func Parse(r io.Reader) (*Document, error) { + return ParseNamed("policy", r) +} + +func ParseNamed(filename string, r io.Reader) (*Document, error) { + parsed := &Document{} + if err := documentParser.Parse(filename, r, parsed); err != nil { + return nil, err + } + + policies := make(map[string]DocumentPolicy, len(parsed.Policies)) + for _, p := range parsed.Policies { + if _, exists := policies[*p.Name]; exists { + return nil, fmt.Errorf("parse error: duplicate policy %q", *p.Name) + } + } + + return parsed, nil +} + +func ParseDocumentPolicy(r io.Reader) (*DocumentPolicy, error) { + return ParseNamedDocumentPolicy("policy", r) +} + +func ParseNamedDocumentPolicy(name string, r io.Reader) (*DocumentPolicy, error) { + p := &DocumentPolicy{} + if err := documentPolicyParser.Parse(name, r, p); err != nil { + return nil, err + } + return p, nil +} diff --git a/pkg/policy/parser_test.go b/pkg/policy/parser_test.go new file mode 100644 index 0000000..8953d38 --- /dev/null +++ b/pkg/policy/parser_test.go @@ -0,0 +1,222 @@ +package policy + +import ( + "strings" + "testing" + + "github.com/flynn/biscuit-go/parser" + "github.com/stretchr/testify/require" +) + +func TestParse(t *testing.T) { + definition := ` + // admin policy comment + policy "admin" { + rules { + // rule 1 comment + *authorized($0) + <- namespace(#ambient, $0) + @ prefix($0, "demo.v1") + } + caveats {[ + // caveat 1 comment + *caveat0($0) <- authorized($0) + ]} + } + + policy "developer" { + rules { + *authorized("demo.v1.Account", $1) + <- namespace(#ambient, "demo.v1.Account"), + method(#ambient, $1), + arg(#ambient, "env", $2) + @ $1 in ["Create", "Read", "Update"], + $2 in ["DEV", "STAGING"] + *authorized("demo.v1.Account", "Read") + <- namespace(#ambient, "demo.v1.Account"), + method(#ambient, "Read"), + arg(#ambient, "env", "PROD") + } + caveats { + [*caveat1($1) <- authorized("demo.v1.Account", $1)] + } + } + + policy "auditor" { + rules { + *authorized("demo.v1.Account", "Read") + <- namespace(#ambient, "demo.v1.Account"), + method(#ambient, "Read"), + arg(#ambient, "env", "DEV") + } + caveats { + [*caveat2("Read") <- authorized("demo.v1.Account", "Read")] + } + } + ` + + doc, err := Parse(strings.NewReader(definition)) + require.NoError(t, err) + + expectedPolicies := &Document{ + Policies: []*DocumentPolicy{{ + Name: sptr("admin"), + Comments: []*parser.Comment{commentptr("admin policy comment")}, + Rules: []*parser.Rule{ + { + Comments: []*parser.Comment{commentptr("rule 1 comment")}, + Head: &parser.Predicate{Name: sptr("authorized"), IDs: []*parser.Atom{{Variable: varptr("0")}}}, + Body: []*parser.Predicate{ + {Name: sptr("namespace"), IDs: []*parser.Atom{{Symbol: symptr("ambient")}, {Variable: varptr("0")}}}, + }, + Constraints: []*parser.Constraint{ + {FunctionConstraint: &parser.FunctionConstraint{ + Function: sptr("prefix"), + Variable: varptr("0"), + Argument: sptr("demo.v1"), + }}, + }, + }, + }, + Caveats: []*parser.Caveat{{Queries: []*parser.Rule{ + { + Comments: []*parser.Comment{commentptr("caveat 1 comment")}, + Head: &parser.Predicate{Name: sptr("caveat0"), IDs: []*parser.Atom{{Variable: varptr("0")}}}, + Body: []*parser.Predicate{ + {Name: sptr("authorized"), IDs: []*parser.Atom{{Variable: varptr("0")}}}, + }, + }, + }}}, + }, + { + Name: sptr("developer"), + Rules: []*parser.Rule{ + { + Head: &parser.Predicate{Name: sptr("authorized"), IDs: []*parser.Atom{ + {String: sptr("demo.v1.Account")}, + {Variable: varptr("1")}, + }}, + Body: []*parser.Predicate{ + {Name: sptr("namespace"), IDs: []*parser.Atom{{Symbol: symptr("ambient")}, {String: sptr("demo.v1.Account")}}}, + {Name: sptr("method"), IDs: []*parser.Atom{{Symbol: symptr("ambient")}, {Variable: varptr("1")}}}, + {Name: sptr("arg"), IDs: []*parser.Atom{{Symbol: symptr("ambient")}, {String: sptr("env")}, {Variable: varptr("2")}}}, + }, + Constraints: []*parser.Constraint{ + { + VariableConstraint: &parser.VariableConstraint{ + Variable: varptr("1"), + Set: &parser.Set{ + Not: false, + String: []string{"Create", "Read", "Update"}, + }, + }, + }, + { + VariableConstraint: &parser.VariableConstraint{ + Variable: varptr("2"), + Set: &parser.Set{ + Not: false, + String: []string{"DEV", "STAGING"}, + }, + }, + }, + }, + }, + { + Head: &parser.Predicate{Name: sptr("authorized"), IDs: []*parser.Atom{{String: sptr("demo.v1.Account")}, {String: sptr("Read")}}}, + Body: []*parser.Predicate{ + {Name: sptr("namespace"), IDs: []*parser.Atom{{Symbol: symptr("ambient")}, {String: sptr("demo.v1.Account")}}}, + {Name: sptr("method"), IDs: []*parser.Atom{{Symbol: symptr("ambient")}, {String: sptr("Read")}}}, + {Name: sptr("arg"), IDs: []*parser.Atom{{Symbol: symptr("ambient")}, {String: sptr("env")}, {String: sptr("PROD")}}}, + }, + }, + }, + Caveats: []*parser.Caveat{{Queries: []*parser.Rule{ + { + Head: &parser.Predicate{Name: sptr("caveat1"), IDs: []*parser.Atom{{Variable: varptr("1")}}}, + Body: []*parser.Predicate{ + {Name: sptr("authorized"), IDs: []*parser.Atom{{String: sptr("demo.v1.Account")}, {Variable: varptr("1")}}}, + }, + }, + }}}, + }, + { + Name: sptr("auditor"), + Rules: []*parser.Rule{ + { + Head: &parser.Predicate{Name: sptr("authorized"), IDs: []*parser.Atom{{String: sptr("demo.v1.Account")}, {String: sptr("Read")}}}, + Body: []*parser.Predicate{ + {Name: sptr("namespace"), IDs: []*parser.Atom{{Symbol: symptr("ambient")}, {String: sptr("demo.v1.Account")}}}, + {Name: sptr("method"), IDs: []*parser.Atom{{Symbol: symptr("ambient")}, {String: sptr("Read")}}}, + {Name: sptr("arg"), IDs: []*parser.Atom{{Symbol: symptr("ambient")}, {String: sptr("env")}, {String: sptr("DEV")}}}, + }, + }, + }, + Caveats: []*parser.Caveat{{Queries: []*parser.Rule{ + { + Head: &parser.Predicate{Name: sptr("caveat2"), IDs: []*parser.Atom{{String: sptr("Read")}}}, + Body: []*parser.Predicate{ + {Name: sptr("authorized"), IDs: []*parser.Atom{{String: sptr("demo.v1.Account")}, {String: sptr("Read")}}}, + }, + }, + }}}, + }, + }, + } + + require.Equal(t, len(expectedPolicies.Policies), len(doc.Policies)) + for i, expectedPolicy := range expectedPolicies.Policies { + require.Equal(t, doc.Policies[i], expectedPolicy) + } +} + +func TestParseDocumentPolicy(t *testing.T) { + testCases := []struct { + Desc string + Input string + ExpectedErr bool + ExpectedOut *DocumentPolicy + }{ + { + Desc: "single policy", + Input: `policy "foo" {}`, + ExpectedOut: &DocumentPolicy{Name: sptr("foo")}, + }, + { + Desc: "empty document returns an error", + Input: "", + ExpectedErr: true, + }, + } + + for _, testCase := range testCases { + t.Run(testCase.Desc, func(t *testing.T) { + out, err := ParseDocumentPolicy(strings.NewReader(testCase.Input)) + if testCase.ExpectedErr { + require.Error(t, err) + } else { + require.NoError(t, err) + require.Equal(t, testCase.ExpectedOut, out) + } + }) + } +} + +func sptr(s string) *string { + return &s +} + +func symptr(s string) *parser.Symbol { + sym := parser.Symbol(s) + return &sym +} + +func varptr(s string) *parser.Variable { + v := parser.Variable(s) + return &v +} + +func commentptr(s string) *parser.Comment { + c := parser.Comment(s) + return &c +} diff --git a/pkg/policy/printer.go b/pkg/policy/printer.go new file mode 100644 index 0000000..8676784 --- /dev/null +++ b/pkg/policy/printer.go @@ -0,0 +1,243 @@ +package policy + +import ( + "fmt" + "io" + "strings" + + "github.com/flynn/biscuit-go/parser" +) + +func Format(r io.Reader) (string, error) { + d, err := Parse(r) + if err != nil { + return "", err + } + + return Print(d) +} + +func Print(d *Document) (string, error) { + p := &printer{ + indent: 0, + out: &strings.Builder{}, + } + + for i, policy := range d.Policies { + p.printPolicy(policy) + if i != len(d.Policies)-1 { + p.write("\n") + } + } + + return p.out.String(), nil +} + +func PrintPolicy(policy *DocumentPolicy) string { + p := &printer{ + indent: 0, + out: &strings.Builder{}, + } + + p.printPolicy(policy) + + return p.out.String() +} + +type printer struct { + indent int + out *strings.Builder +} + +func (p *printer) write(format string, args ...interface{}) { + format = strings.ReplaceAll(format, "\n", "\n"+strings.Repeat(" ", p.indent)) + p.out.WriteString(fmt.Sprintf(format, args...)) +} + +func (p *printer) printPolicy(policy *DocumentPolicy) { + for _, c := range policy.Comments { + p.write("// %s\n", *c) + } + + p.write("policy %q {", *policy.Name) + + if len(policy.Rules) > 0 { + p.indent++ + p.write("\nrules {") + p.indent++ + for _, r := range policy.Rules { + p.write("\n") + p.printRule(r) + } + p.indent-- + p.write("\n") + p.indent-- + p.write("}\n") + } + + if len(policy.Caveats) > 0 { + p.indent++ + p.write("\ncaveats {") + for i, c := range policy.Caveats { + p.indent++ + p.printCaveat(c) + if i != len(policy.Caveats)-1 { + p.write(", ") + } + } + p.indent-- + p.write("}\n") + } + + p.write("}\n") +} + +func (p *printer) printRule(rule *parser.Rule) { + for _, c := range rule.Comments { + p.write("// %s\n", *c) + } + + p.write("*") + p.printPredicate(rule.Head) + p.indent++ + p.write("\n") + + for i, b := range rule.Body { + if i == 0 { + p.write("<- ") + } else { + p.write(" ") + } + p.printPredicate(b) + if i != len(rule.Body)-1 { + p.write(",\n") + } + } + + if len(rule.Constraints) > 0 { + p.write("\n") + } + + for i, c := range rule.Constraints { + if i == 0 { + p.write("@ ") + } else { + p.write(" ") + } + p.printConstraint(c) + if i != len(rule.Constraints)-1 { + p.write(",\n") + } + } + p.indent-- +} + +func (p *printer) printCaveat(c *parser.Caveat) { + p.write("[\n") + for j, r := range c.Queries { + if j != 0 { + p.write("||") + p.indent++ + p.write("\n") + } + p.printRule(r) + if j != len(c.Queries)-1 { + p.indent-- + p.write("\n") + } + } + p.indent-- + p.write("\n]") +} + +func (p *printer) printPredicate(pred *parser.Predicate) { + p.write("%s(%s)", *pred.Name, strings.Join(atomsToString(pred.IDs), ", ")) +} + +func (p *printer) printConstraint(c *parser.Constraint) { + switch { + case c.FunctionConstraint != nil: + p.printFunctionConstraint(c.FunctionConstraint) + case c.VariableConstraint != nil: + p.printVariableConstraint(c.VariableConstraint) + } +} + +func (p *printer) printFunctionConstraint(c *parser.FunctionConstraint) { + p.write("%s($%s, %q)", *c.Function, *c.Variable, *c.Argument) +} + +func (p *printer) printVariableConstraint(c *parser.VariableConstraint) { + var op, target string + switch { + case c.Bytes != nil: + op = *c.Bytes.Operation + target = c.Bytes.Target.String() + case c.Date != nil: + op = *c.Date.Operation + target = fmt.Sprintf("%q", *c.Date.Target) + case c.Int != nil: + op = *c.Int.Operation + target = fmt.Sprintf("%d", *c.Int.Target) + case c.Set != nil: + op = "in" + if c.Set.Not { + op = "not in" + } + + switch { + case c.Set.Bytes != nil: + members := make([]string, 0, len(c.Set.Bytes)) + for _, b := range c.Set.Bytes { + members = append(members, b.String()) + } + target = fmt.Sprintf("[%s]", strings.Join(members, ", ")) + case c.Set.Int != nil: + members := make([]string, 0, len(c.Set.Int)) + for _, i := range c.Set.Int { + members = append(members, fmt.Sprintf("%d", i)) + } + target = fmt.Sprintf("[%s]", strings.Join(members, ", ")) + case c.Set.String != nil: + members := make([]string, 0, len(c.Set.String)) + for _, s := range c.Set.String { + members = append(members, fmt.Sprintf("%q", s)) + } + target = fmt.Sprintf("[%s]", strings.Join(members, ", ")) + case c.Set.Symbols != nil: + members := make([]string, 0, len(c.Set.Symbols)) + for _, s := range c.Set.Symbols { + members = append(members, fmt.Sprintf("#%s", s)) + } + target = fmt.Sprintf("[%s]", strings.Join(members, ", ")) + } + case c.String != nil: + op = *c.String.Operation + target = fmt.Sprintf("%q", *c.String.Target) + } + p.write("$%s %s %s", *c.Variable, op, target) +} + +func atomsToString(atoms []*parser.Atom) []string { + out := make([]string, 0, len(atoms)) + for _, a := range atoms { + var atomStr string + switch { + case a.Bytes != nil: + atomStr = a.Bytes.String() + case a.Integer != nil: + atomStr = fmt.Sprintf("%d", *a.Integer) + case a.Set != nil: + atomStr = fmt.Sprintf("[%s]", strings.Join(atomsToString(a.Set), ", ")) + case a.String != nil: + atomStr = fmt.Sprintf("%q", *a.String) + case a.Symbol != nil: + atomStr = fmt.Sprintf("#%s", *a.Symbol) + case a.Variable != nil: + atomStr = fmt.Sprintf("$%s", *a.Variable) + } + + out = append(out, atomStr) + } + return out +} diff --git a/pkg/policy/printer_test.go b/pkg/policy/printer_test.go new file mode 100644 index 0000000..248b9cc --- /dev/null +++ b/pkg/policy/printer_test.go @@ -0,0 +1,28 @@ +package policy + +import ( + "io/ioutil" + "path/filepath" + "strings" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestPrintTemplateGolden(t *testing.T) { + files, err := filepath.Glob("./testdata/printer/*.golden") + require.NoError(t, err) + + for _, f := range files { + src, err := ioutil.ReadFile(f) + require.NoError(t, err) + + golden := string(src) + d, err := Parse(strings.NewReader(golden)) + require.NoError(t, err) + + out, err := Print(d) + require.NoError(t, err) + require.Equal(t, golden, out) + } +} diff --git a/pkg/policy/testdata/printer/comments.golden b/pkg/policy/testdata/printer/comments.golden new file mode 100644 index 0000000..555490b --- /dev/null +++ b/pkg/policy/testdata/printer/comments.golden @@ -0,0 +1,40 @@ +// some comment +policy "developer" { + rules { + // comment this specific rule + // on multiple lines + *allow_method("Status") + <- service(#ambient, "demo.api.v1.Demo"), + method(#ambient, "Status") + *allow_method($0) + <- service(#ambient, "demo.api.v1.Demo"), + method(#ambient, $0), + arg(#ambient, "env", "DEV") + @ $0 in ["Create", "Delete"] + } + + caveats {[ + // this caveat is required + *authorized($0) + <- allow_method(#authority, $0) + || + // this caveat is required too + *authorized($0) + <- allow_method(#authority, $0) + @ $0 == "method" + ], [ + *authorized($0) + <- allow_method(#authority, $0) + ]} +} + +// some comment +policy "admin" { + rules { + // comment this specific rule + // on multiple lines + *allow_method("Status") + <- service(#ambient, "demo.api.v1.Demo"), + method(#ambient, "Status") + } +} diff --git a/pkg/policy/testdata/printer/empty_policy.golden b/pkg/policy/testdata/printer/empty_policy.golden new file mode 100644 index 0000000..c5767df --- /dev/null +++ b/pkg/policy/testdata/printer/empty_policy.golden @@ -0,0 +1 @@ +policy "test" {} diff --git a/pkg/policy/testdata/printer/multiple.golden b/pkg/policy/testdata/printer/multiple.golden new file mode 100644 index 0000000..6cfc860 --- /dev/null +++ b/pkg/policy/testdata/printer/multiple.golden @@ -0,0 +1,72 @@ +policy "admin" { + rules { + *allow_method($0) + <- service(#ambient, "demo.api.v1.Demo"), + method(#ambient, $0) + @ $0 in ["Status"] + } + + caveats {[ + *authorized($0) + <- allow_method(#authority, $0) + || + *authorized($0) + <- method(#ambient, $0), + env(#ambient, $1) + @ $1 in ["DEV", "STG"] + ], [ + *authorized_server($2) + <- service(#ambient, $2) + @ prefix($2, "demo.api.v1") + ]} +} + +policy "auditor" { + caveats {[ + *allow_dev() + <- arg(#ambient, "env", "DEV") + ]} +} + +policy "developer" { + rules { + *allow_method("Status") + <- service(#ambient, "demo.api.v1.Demo"), + method(#ambient, "Status") + *allow_method($0) + <- service(#ambient, "demo.api.v1.Demo"), + method(#ambient, $0), + arg(#ambient, "env", "DEV") + @ $0 in ["Create", "Delete", "Read", "Status", "Update"] + *allow_method($0) + <- service(#ambient, "demo.api.v1.Demo"), + method(#ambient, $0), + arg(#ambient, "env", $1) + @ $0 in ["Read"], + $1 in ["DEV"] + *allow_method("Read") + <- service(#ambient, "demo.api.v1.Demo"), + method(#ambient, "Read"), + arg(#ambient, "env", "PRD"), + arg(#ambient, "entities.name", $3) + @ $3 in ["entity1", "entity2", "entity3"] + } + + caveats {[ + *authorized($0) + <- allow_method(#authority, $0) + ]} +} + +policy "guest" { + rules { + *allow_method("Status") + <- service(#ambient, "demo.api.v1.Demo"), + method(#ambient, "Status") + } + + caveats {[ + *authorized($0) + <- allow_method(#authority, $0) + ]} +} diff --git a/pkg/policy/testdata/printer/single copy.golden b/pkg/policy/testdata/printer/single copy.golden new file mode 100644 index 0000000..cbedf97 --- /dev/null +++ b/pkg/policy/testdata/printer/single copy.golden @@ -0,0 +1,24 @@ +policy "developer" { + rules { + *allow_method("Status") + <- service(#ambient, "demo.api.v1.Demo"), + method(#ambient, "Status") + *allow_method($0) + <- service(#ambient, "demo.api.v1.Demo"), + method(#ambient, $0), + arg(#ambient, "env", "DEV") + @ $0 in ["Create", "Delete"] + } + + caveats {[ + *authorized($0) + <- allow_method(#authority, $0) + || + *authorized($0) + <- allow_method(#authority, $0) + @ $0 == "method" + ], [ + *authorized($0) + <- allow_method(#authority, $0) + ]} +}