diff --git a/app/account.go b/app/account.go index ab3f42e7..977b543f 100644 --- a/app/account.go +++ b/app/account.go @@ -8,6 +8,7 @@ import ( "github.com/temporalio/tcld/protogen/api/account/v1" "github.com/temporalio/tcld/protogen/api/accountservice/v1" + "github.com/temporalio/tcld/protogen/api/request/v1" "github.com/urfave/cli/v2" "google.golang.org/grpc" ) @@ -78,7 +79,7 @@ func (c *AccountClient) listRegions() ([]regionInfo, error) { return regions, nil } -func (c *AccountClient) updateAccount(ctx *cli.Context, a *account.Account) error { +func (c *AccountClient) updateAccount(ctx *cli.Context, a *account.Account) (*request.RequestStatus, error) { resourceVersion := a.ResourceVersion if v := ctx.String(ResourceVersionFlagName); v != "" { resourceVersion = v @@ -90,10 +91,9 @@ func (c *AccountClient) updateAccount(ctx *cli.Context, a *account.Account) erro Spec: a.Spec, }) if err != nil { - return err + return nil, err } - - return PrintProto(res) + return res.RequestStatus, nil } func (c *AccountClient) parseExistingMetricsCerts(ctx *cli.Context) (account *account.Account, existing caCerts, err error) { @@ -113,8 +113,9 @@ func (c *AccountClient) parseExistingMetricsCerts(ctx *cli.Context) (account *ac return a, existingCerts, nil } -func NewAccountCommand(getAccountClientFn GetAccountClientFn) (CommandOut, error) { +func NewAccountCommand(getAccountClientFn GetAccountClientFn, getRequestClientFn GetRequestClientFn) (CommandOut, error) { var c *AccountClient + var r *RequestClient return CommandOut{ Command: &cli.Command{ Name: "account", @@ -123,6 +124,10 @@ func NewAccountCommand(getAccountClientFn GetAccountClientFn) (CommandOut, error Before: func(ctx *cli.Context) error { var err error c, err = getAccountClientFn(ctx) + if err != nil { + return err + } + r, err = getRequestClientFn(ctx) return err }, Subcommands: []*cli.Command{ @@ -158,6 +163,10 @@ func NewAccountCommand(getAccountClientFn GetAccountClientFn) (CommandOut, error { Name: "enable", Usage: "Enables the metrics endpoint. CA Certificates *must* be configured prior to enabling the endpoint", + Flags: []cli.Flag{ + RequestTimeoutFlag, + WaitForRequestFlag, + }, Action: func(ctx *cli.Context) error { a, err := c.getAccount() if err != nil { @@ -173,12 +182,20 @@ func NewAccountCommand(getAccountClientFn GetAccountClientFn) (CommandOut, error } a.Spec.Metrics.Enabled = true - return c.updateAccount(ctx, a) + status, err := c.updateAccount(ctx, a) + if err != nil { + return err + } + return r.HandleRequestStatus(ctx, "enable metrics", status) }, }, { Name: "disable", Usage: "Disables the metrics endpoint", + Flags: []cli.Flag{ + RequestTimeoutFlag, + WaitForRequestFlag, + }, Action: func(ctx *cli.Context) error { a, err := c.getAccount() if err != nil { @@ -190,7 +207,11 @@ func NewAccountCommand(getAccountClientFn GetAccountClientFn) (CommandOut, error } a.Spec.Metrics.Enabled = false - return c.updateAccount(ctx, a) + status, err := c.updateAccount(ctx, a) + if err != nil { + return err + } + return r.HandleRequestStatus(ctx, "disable metrics", status) }, }, { @@ -207,6 +228,8 @@ func NewAccountCommand(getAccountClientFn GetAccountClientFn) (CommandOut, error ResourceVersionFlag, CaCertificateFlag, CaCertificateFileFlag, + RequestTimeoutFlag, + WaitForRequestFlag, }, Action: func(ctx *cli.Context) error { newCerts, err := readAndParseCACerts(ctx) @@ -238,7 +261,11 @@ func NewAccountCommand(getAccountClientFn GetAccountClientFn) (CommandOut, error } a.Spec.Metrics.AcceptedClientCa = bundle - return c.updateAccount(ctx, a) + status, err := c.updateAccount(ctx, a) + if err != nil { + return err + } + return r.HandleRequestStatus(ctx, "add metrics ca certificate", status) }, }, { @@ -251,6 +278,8 @@ func NewAccountCommand(getAccountClientFn GetAccountClientFn) (CommandOut, error CaCertificateFlag, CaCertificateFileFlag, caCertificateFingerprintFlag, + RequestTimeoutFlag, + WaitForRequestFlag, }, Action: func(ctx *cli.Context) error { a, existingCerts, err := c.parseExistingMetricsCerts(ctx) @@ -296,7 +325,11 @@ func NewAccountCommand(getAccountClientFn GetAccountClientFn) (CommandOut, error if err != nil || !y { return err } - return c.updateAccount(ctx, a) + status, err := c.updateAccount(ctx, a) + if err != nil { + return err + } + return r.HandleRequestStatus(ctx, "remove metrics ca certificate", status) }, }, { @@ -308,6 +341,8 @@ func NewAccountCommand(getAccountClientFn GetAccountClientFn) (CommandOut, error ResourceVersionFlag, CaCertificateFlag, CaCertificateFileFlag, + RequestTimeoutFlag, + WaitForRequestFlag, }, Action: func(ctx *cli.Context) error { cert, err := ReadCACerts(ctx) @@ -331,7 +366,11 @@ func NewAccountCommand(getAccountClientFn GetAccountClientFn) (CommandOut, error a.Spec.Metrics = &account.MetricsSpec{} } a.Spec.Metrics.AcceptedClientCa = cert - return c.updateAccount(ctx, a) + status, err := c.updateAccount(ctx, a) + if err != nil { + return err + } + return r.HandleRequestStatus(ctx, "set metrics ca certificates", status) }, }, { diff --git a/app/account_test.go b/app/account_test.go index 91dec5b6..a195ee94 100644 --- a/app/account_test.go +++ b/app/account_test.go @@ -17,6 +17,7 @@ import ( "github.com/temporalio/tcld/protogen/api/common/v1" "github.com/temporalio/tcld/protogen/api/request/v1" accountservicemock "github.com/temporalio/tcld/protogen/apimock/accountservice/v1" + requestservicemock "github.com/temporalio/tcld/protogen/apimock/requestservice/v1" "github.com/urfave/cli/v2" ) @@ -26,20 +27,31 @@ func TestAccount(t *testing.T) { type AccountTestSuite struct { suite.Suite - cliApp *cli.App - mockCtrl *gomock.Controller - mockService *accountservicemock.MockAccountServiceClient + cliApp *cli.App + mockCtrl *gomock.Controller + mockService *accountservicemock.MockAccountServiceClient + mockReqService *requestservicemock.MockRequestServiceClient } func (s *AccountTestSuite) SetupTest() { s.mockCtrl = gomock.NewController(s.T()) s.mockService = accountservicemock.NewMockAccountServiceClient(s.mockCtrl) - out, err := NewAccountCommand(func(ctx *cli.Context) (*AccountClient, error) { + s.mockReqService = requestservicemock.NewMockRequestServiceClient(s.mockCtrl) + + getAccountClientFn := func(ctx *cli.Context) (*AccountClient, error) { return &AccountClient{ ctx: context.TODO(), client: s.mockService, }, nil - }) + } + getRequestClientFn := func(ctx *cli.Context) (*RequestClient, error) { + return &RequestClient{ + ctx: context.TODO(), + client: s.mockReqService, + }, nil + } + + out, err := NewAccountCommand(getAccountClientFn, getRequestClientFn) s.Require().NoError(err) AutoConfirmFlag.Value = true s.cliApp = &cli.App{ diff --git a/app/app.go b/app/app.go index 3a3c883f..4df8d1d5 100644 --- a/app/app.go +++ b/app/app.go @@ -5,6 +5,10 @@ import ( "go.uber.org/fx" ) +const ( + AppName = "tcld" +) + type AppParams struct { fx.In Commands []*cli.Command `group:"commands"` @@ -17,7 +21,7 @@ type CommandOut struct { func NewApp(params AppParams) (*cli.App, error) { app := &cli.App{ - Name: "tcld", + Name: AppName, Usage: "Temporal Cloud cli", Flags: []cli.Flag{ ServerFlag, diff --git a/app/namespace.go b/app/namespace.go index 4d5e506c..dddc5acd 100644 --- a/app/namespace.go +++ b/app/namespace.go @@ -11,6 +11,7 @@ import ( "strings" "github.com/temporalio/tcld/protogen/api/auth/v1" + "github.com/temporalio/tcld/protogen/api/request/v1" "github.com/temporalio/tcld/protogen/api/sink/v1" "go.uber.org/multierr" @@ -187,7 +188,8 @@ func (c *NamespaceClient) getExportSinkResourceVersion(ctx *cli.Context, namespa return resourceVersion, nil } -func (c *NamespaceClient) deleteNamespace(ctx *cli.Context, n *namespace.Namespace) error { + +func (c *NamespaceClient) deleteNamespace(ctx *cli.Context, n *namespace.Namespace) (*request.RequestStatus, error) { resourceVersion := n.ResourceVersion if v := ctx.String(ResourceVersionFlagName); v != "" { resourceVersion = v @@ -198,12 +200,12 @@ func (c *NamespaceClient) deleteNamespace(ctx *cli.Context, n *namespace.Namespa ResourceVersion: resourceVersion, }) if err != nil { - return err + return nil, err } - return PrintProto(res) + return res.RequestStatus, nil } -func (c *NamespaceClient) createNamespace(n *namespace.Namespace, p []*auth.UserNamespacePermissions) error { +func (c *NamespaceClient) createNamespace(n *namespace.Namespace, p []*auth.UserNamespacePermissions) (*request.RequestStatus, error) { res, err := c.client.CreateNamespace(c.ctx, &namespaceservice.CreateNamespaceRequest{ RequestId: n.RequestId, Namespace: n.Namespace, @@ -211,9 +213,9 @@ func (c *NamespaceClient) createNamespace(n *namespace.Namespace, p []*auth.User UserNamespacePermissions: p, }) if err != nil { - return err + return nil, err } - return PrintProto(res) + return res.RequestStatus, nil } func (c *NamespaceClient) listNamespaces() error { @@ -249,7 +251,7 @@ func (c *NamespaceClient) getNamespace(namespace string) (*namespace.Namespace, return res.Namespace, nil } -func (c *NamespaceClient) updateNamespace(ctx *cli.Context, n *namespace.Namespace) error { +func (c *NamespaceClient) updateNamespace(ctx *cli.Context, n *namespace.Namespace) (*request.RequestStatus, error) { resourceVersion := n.ResourceVersion if v := ctx.String(ResourceVersionFlagName); v != "" { resourceVersion = v @@ -262,13 +264,13 @@ func (c *NamespaceClient) updateNamespace(ctx *cli.Context, n *namespace.Namespa Spec: n.Spec, }) if err != nil { - return err + return nil, err } - return PrintProto(res) + return res.RequestStatus, nil } -func (c *NamespaceClient) renameSearchAttribute(ctx *cli.Context, n *namespace.Namespace, existingName string, newName string) error { +func (c *NamespaceClient) renameSearchAttribute(ctx *cli.Context, n *namespace.Namespace, existingName string, newName string) (*request.RequestStatus, error) { resourceVersion := n.ResourceVersion if v := ctx.String(ResourceVersionFlagName); v != "" { resourceVersion = v @@ -281,9 +283,9 @@ func (c *NamespaceClient) renameSearchAttribute(ctx *cli.Context, n *namespace.N NewCustomSearchAttributeName: newName, }) if err != nil { - return err + return nil, err } - return PrintProto(res) + return res.RequestStatus, nil } func (c *NamespaceClient) parseExistingCerts(ctx *cli.Context) (namespace *namespace.Namespace, existing caCerts, err error) { @@ -376,8 +378,12 @@ func ReadCertFilters(ctx *cli.Context) ([]byte, error) { return certFilterBytes, nil } -func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, error) { - var c *NamespaceClient +func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn, getRequestClientFn GetRequestClientFn) (CommandOut, error) { + var ( + nc *NamespaceClient + rc *RequestClient + ) + subCommands := []*cli.Command{ { Name: "create", @@ -386,6 +392,8 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, Flags: []cli.Flag{ RequestIDFlag, CaCertificateFlag, + WaitForRequestFlag, + RequestTimeoutFlag, &cli.StringFlag{ Name: NamespaceFlagName, Usage: "The namespace hosted on temporal cloud", @@ -462,7 +470,7 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, if err != nil { return err } - unp, err = c.toUserNamespacePermissions(unpMap) + unp, err = nc.toUserNamespacePermissions(unpMap) if err != nil { return err } @@ -499,8 +507,11 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, } } } - - return c.createNamespace(n, unp) + status, err := nc.createNamespace(n, unp) + if err != nil { + return err + } + return rc.HandleRequestStatus(ctx, "create namespace", status) }, }, { @@ -510,6 +521,8 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, Flags: []cli.Flag{ RequestIDFlag, ResourceVersionFlag, + WaitForRequestFlag, + RequestTimeoutFlag, &cli.StringFlag{ Name: NamespaceFlagName, Usage: "The namespace hosted on temporal cloud", @@ -531,11 +544,15 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, if !yes { return nil } - n, err := c.getNamespace(namespaceName) + n, err := nc.getNamespace(namespaceName) + if err != nil { + return err + } + status, err := nc.deleteNamespace(ctx, n) if err != nil { return err } - return c.deleteNamespace(ctx, n) + return rc.HandleRequestStatus(ctx, "delete namespace", status) }, }, { @@ -544,7 +561,7 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, Aliases: []string{"l"}, Flags: []cli.Flag{}, Action: func(ctx *cli.Context) error { - return c.listNamespaces() + return nc.listNamespaces() }, }, { @@ -555,7 +572,7 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, NamespaceFlag, }, Action: func(ctx *cli.Context) error { - n, err := c.getNamespace(ctx.String(NamespaceFlagName)) + n, err := nc.getNamespace(ctx.String(NamespaceFlagName)) if err != nil { return err } @@ -575,7 +592,7 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, NamespaceFlag, }, Action: func(ctx *cli.Context) error { - n, err := c.getNamespace(ctx.String(NamespaceFlagName)) + n, err := nc.getNamespace(ctx.String(NamespaceFlagName)) if err != nil { return err } @@ -594,6 +611,8 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, NamespaceFlag, RequestIDFlag, ResourceVersionFlag, + WaitForRequestFlag, + RequestTimeoutFlag, CaCertificateFlag, CaCertificateFileFlag, }, @@ -602,7 +621,7 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, if err != nil { return err } - n, existingCerts, err := c.parseExistingCerts(ctx) + n, existingCerts, err := nc.parseExistingCerts(ctx) if err != nil { return err } @@ -618,7 +637,11 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, return errors.New("nothing to change") } n.Spec.AcceptedClientCa = bundle - return c.updateNamespace(ctx, n) + status, err := nc.updateNamespace(ctx, n) + if err != nil { + return err + } + return rc.HandleRequestStatus(ctx, "add ca certificate", status) }, }, { @@ -629,12 +652,14 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, NamespaceFlag, RequestIDFlag, ResourceVersionFlag, + WaitForRequestFlag, + RequestTimeoutFlag, CaCertificateFlag, CaCertificateFileFlag, caCertificateFingerprintFlag, }, Action: func(ctx *cli.Context) error { - n, existingCerts, err := c.parseExistingCerts(ctx) + n, existingCerts, err := nc.parseExistingCerts(ctx) if err != nil { return err } @@ -669,7 +694,11 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, if err != nil || !y { return err } - return c.updateNamespace(ctx, n) + status, err := nc.updateNamespace(ctx, n) + if err != nil { + return err + } + return rc.HandleRequestStatus(ctx, "remove ca certificate", status) }, }, { @@ -680,6 +709,8 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, NamespaceFlag, RequestIDFlag, ResourceVersionFlag, + WaitForRequestFlag, + RequestTimeoutFlag, CaCertificateFlag, CaCertificateFileFlag, }, @@ -688,7 +719,7 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, if err != nil { return err } - n, err := c.getNamespace(ctx.String(NamespaceFlagName)) + n, err := nc.getNamespace(ctx.String(NamespaceFlagName)) if err != nil { return err } @@ -696,7 +727,11 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, return errors.New("nothing to change") } n.Spec.AcceptedClientCa = cert - return c.updateNamespace(ctx, n) + status, err := nc.updateNamespace(ctx, n) + if err != nil { + return err + } + return rc.HandleRequestStatus(ctx, "set ca certificates", status) }, }, }, @@ -714,6 +749,8 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, NamespaceFlag, RequestIDFlag, ResourceVersionFlag, + WaitForRequestFlag, + RequestTimeoutFlag, &cli.PathFlag{ Name: certificateFilterFileFlagName, Usage: `Path to a JSON file that defines the certificate filters that will be configured on the namespace. This will replace the existing filter configuration. Sample JSON: { "filters": [ { "commonName": "test1" } ] }`, @@ -752,7 +789,7 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, return err } - n, err := c.getNamespace(ctx.String(NamespaceFlagName)) + n, err := nc.getNamespace(ctx.String(NamespaceFlagName)) if err != nil { return err } @@ -770,13 +807,17 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, return err } - if confirmed { - n.Spec.CertificateFilters = replacementFilters.toSpec() - return c.updateNamespace(ctx, n) + if !confirmed { + fmt.Println("operation canceled") + return nil } - fmt.Println("operation canceled") - return nil + n.Spec.CertificateFilters = replacementFilters.toSpec() + status, err := nc.updateNamespace(ctx, n) + if err != nil { + return err + } + return rc.HandleRequestStatus(ctx, "import certificate filters", status) }, }, { @@ -794,7 +835,7 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, }, }, Action: func(ctx *cli.Context) error { - n, err := c.getNamespace(ctx.String(NamespaceFlagName)) + n, err := nc.getNamespace(ctx.String(NamespaceFlagName)) if err != nil { return err } @@ -829,7 +870,7 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, ResourceVersionFlag, }, Action: func(ctx *cli.Context) error { - n, err := c.getNamespace(ctx.String(NamespaceFlagName)) + n, err := nc.getNamespace(ctx.String(NamespaceFlagName)) if err != nil { return err } @@ -844,13 +885,17 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, return err } - if confirmed { - n.Spec.CertificateFilters = nil - return c.updateNamespace(ctx, n) + if !confirmed { + fmt.Println("operation canceled") + return nil } - fmt.Println("operation canceled") - return nil + n.Spec.CertificateFilters = nil + status, err := nc.updateNamespace(ctx, n) + if err != nil { + return err + } + return rc.HandleRequestStatus(ctx, "clear certificate filters", status) }, }, { @@ -861,6 +906,8 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, NamespaceFlag, RequestIDFlag, ResourceVersionFlag, + WaitForRequestFlag, + RequestTimeoutFlag, &cli.PathFlag{ Name: certificateFilterFileFlagName, Usage: `Path to a JSON file that defines the certificate filters that will be added to the namespace. Sample JSON: { "filters": [ { "commonName": "test1" } ] }`, @@ -913,18 +960,22 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, return err } - if confirmed { - n, err := c.getNamespace(ctx.String(NamespaceFlagName)) - if err != nil { - return err - } + if !confirmed { + fmt.Println("operation canceled") + return nil + } + n, err := nc.getNamespace(ctx.String(NamespaceFlagName)) + if err != nil { + return err + } - n.Spec.CertificateFilters = append(n.Spec.CertificateFilters, newFilters.toSpec()...) - return c.updateNamespace(ctx, n) + n.Spec.CertificateFilters = append(n.Spec.CertificateFilters, newFilters.toSpec()...) + status, err := nc.updateNamespace(ctx, n) + if err != nil { + return err } + return rc.HandleRequestStatus(ctx, "add certificate filters", status) - fmt.Println("operation canceled") - return nil }, }, }, @@ -953,7 +1004,7 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, }, }, Action: func(ctx *cli.Context) error { - n, err := c.getNamespace(ctx.String(NamespaceFlagName)) + n, err := nc.getNamespace(ctx.String(NamespaceFlagName)) if err != nil { return err } @@ -979,7 +1030,11 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, if confirmed { n.Spec.CodecSpec = replacement - return c.updateNamespace(ctx, n) + status, err := nc.updateNamespace(ctx, n) + if err != nil { + return err + } + return rc.HandleRequestStatus(ctx, "codec server update", status) } fmt.Println("operation canceled") @@ -998,6 +1053,8 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, Flags: []cli.Flag{ NamespaceFlag, ResourceVersionFlag, + WaitForRequestFlag, + RequestTimeoutFlag, RetentionDaysFlag, RequestIDFlag, }, @@ -1009,7 +1066,7 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, if retention < 0 { return fmt.Errorf("retention cannot be negative") } - n, err := c.getNamespace(ctx.String(NamespaceFlagName)) + n, err := nc.getNamespace(ctx.String(NamespaceFlagName)) if err != nil { return err } @@ -1017,7 +1074,11 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, return fmt.Errorf("retention for namespace is already set at %d days", ctx.Int(RetentionDaysFlagName)) } n.Spec.RetentionDays = int32(retention) - return c.updateNamespace(ctx, n) + status, err := nc.updateNamespace(ctx, n) + if err != nil { + return err + } + return rc.HandleRequestStatus(ctx, "set retention", status) }, }, { @@ -1028,7 +1089,7 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, NamespaceFlag, }, Action: func(ctx *cli.Context) error { - n, err := c.getNamespace(ctx.String(NamespaceFlagName)) + n, err := nc.getNamespace(ctx.String(NamespaceFlagName)) if err != nil { return err } @@ -1051,6 +1112,8 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, NamespaceFlag, RequestIDFlag, ResourceVersionFlag, + WaitForRequestFlag, + RequestTimeoutFlag, &cli.StringSliceFlag{ Name: "search-attribute", Usage: fmt.Sprintf("Flag can be used multiple times; value must be \"name=type\"; valid types are: %v", getSearchAttributeTypes()), @@ -1063,7 +1126,7 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, if err != nil { return err } - n, err := c.getNamespace(ctx.String(NamespaceFlagName)) + n, err := nc.getNamespace(ctx.String(NamespaceFlagName)) if err != nil { return err } @@ -1077,8 +1140,11 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, n.Spec.SearchAttributes[attrName] = attrType } } - - return c.updateNamespace(ctx, n) + status, err := nc.updateNamespace(ctx, n) + if err != nil { + return err + } + return rc.HandleRequestStatus(ctx, "add search attribute", status) }, }, { @@ -1089,6 +1155,8 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, NamespaceFlag, RequestIDFlag, ResourceVersionFlag, + WaitForRequestFlag, + RequestTimeoutFlag, &cli.StringFlag{ Name: "existing-name", Usage: "The name of an existing search attribute", @@ -1103,7 +1171,7 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, }, }, Action: func(ctx *cli.Context) error { - n, err := c.getNamespace( + n, err := nc.getNamespace( ctx.String(NamespaceFlagName), ) if err != nil { @@ -1121,7 +1189,12 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, if err != nil || !y { return err } - return c.renameSearchAttribute(ctx, n, existingName, newName) + status, err := nc.renameSearchAttribute(ctx, n, existingName, newName) + if err != nil { + return err + } + return rc.HandleRequestStatus(ctx, "rename search attribute", status) + }, }, }, @@ -1155,7 +1228,7 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, } namespace := ctx.String(NamespaceFlagName) - ns, err := c.getNamespace(namespace) + ns, err := nc.getNamespace(namespace) if err != nil { return fmt.Errorf("unable to get namespace: %v", err) } @@ -1177,12 +1250,11 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, RequestId: ctx.String(RequestIDFlagName), } - res, err := c.client.CreateExportSink(c.ctx, request) + res, err := nc.client.CreateExportSink(nc.ctx, request) if err != nil { return err } - - return PrintProto(res.RequestStatus) + return rc.HandleRequestStatus(ctx, "create export sink", res.RequestStatus) }, }, { @@ -1194,7 +1266,7 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, sinkNameFlag, }, Action: func(ctx *cli.Context) error { - sink, err := c.getExportSink(ctx, ctx.String(NamespaceFlagName), ctx.String(sinkNameFlag.Name)) + sink, err := nc.getExportSink(ctx, ctx.String(NamespaceFlagName), ctx.String(sinkNameFlag.Name)) if err != nil { return err @@ -1216,7 +1288,7 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, Action: func(ctx *cli.Context) error { namespaceName := ctx.String(NamespaceFlagName) sinkName := ctx.String(sinkNameFlag.Name) - resourceVersion, err := c.getExportSinkResourceVersion(ctx, namespaceName, sinkName) + resourceVersion, err := nc.getExportSinkResourceVersion(ctx, namespaceName, sinkName) if err != nil { return err } @@ -1228,12 +1300,12 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, RequestId: ctx.String(RequestIDFlagName), } - deleteResp, err := c.client.DeleteExportSink(c.ctx, deleteRequest) + deleteResp, err := nc.client.DeleteExportSink(nc.ctx, deleteRequest) if err != nil { return err } + return rc.HandleRequestStatus(ctx, "create export sink", deleteResp.RequestStatus) - return PrintProto(deleteResp.RequestStatus) }, }, { @@ -1252,7 +1324,7 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, PageToken: ctx.String(pageTokenFlag.Name), } - resp, err := c.client.ListExportSinks(c.ctx, request) + resp, err := nc.client.ListExportSinks(nc.ctx, request) if err != nil { return err } @@ -1277,18 +1349,18 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, Action: func(ctx *cli.Context) error { namespaceName := ctx.String(NamespaceFlagName) sinkName := ctx.String(sinkNameFlag.Name) - sink, err := c.getExportSink(ctx, namespaceName, sinkName) + sink, err := nc.getExportSink(ctx, namespaceName, sinkName) if err != nil { return err } - resourceVersion := c.selectExportSinkResourceVersion(ctx, sink) + resourceVersion := nc.selectExportSinkResourceVersion(ctx, sink) - isEnabledChange, err := c.isSinkEnabledChange(ctx, sink) + isEnabledChange, err := nc.isSinkEnabledChange(ctx, sink) if err != nil { return err } - if !isEnabledChange && !c.isAssumedRoleChange(ctx, sink) && !c.isKmsArnChange(ctx, sink) && !c.isS3BucketChange(ctx, sink) { + if !isEnabledChange && !nc.isAssumedRoleChange(ctx, sink) && !nc.isKmsArnChange(ctx, sink) && !nc.isS3BucketChange(ctx, sink) { fmt.Println("nothing to update") return nil } @@ -1297,7 +1369,7 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, sink.Spec.Enabled = !sink.Spec.Enabled } - if c.isAssumedRoleChange(ctx, sink) { + if nc.isAssumedRoleChange(ctx, sink) { awsAccountID, roleName, err := parseAssumedRole(ctx.String(sinkAssumedRoleFlagOptional.Name)) if err != nil { return err @@ -1306,11 +1378,11 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, sink.Spec.S3Sink.AwsAccountId = awsAccountID } - if c.isKmsArnChange(ctx, sink) { + if nc.isKmsArnChange(ctx, sink) { sink.Spec.S3Sink.KmsArn = ctx.String(kmsArnFlag.Name) } - if c.isS3BucketChange(ctx, sink) { + if nc.isS3BucketChange(ctx, sink) { sink.Spec.S3Sink.BucketName = ctx.String(s3BucketFlagOptional.Name) } @@ -1321,12 +1393,11 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, RequestId: ctx.String(RequestIDFlagName), } - resp, err := c.client.UpdateExportSink(c.ctx, request) + resp, err := nc.client.UpdateExportSink(nc.ctx, request) if err != nil { return err } - - return PrintProto(resp.RequestStatus) + return rc.HandleRequestStatus(ctx, "update export sink", resp.RequestStatus) }, }, }, @@ -1339,7 +1410,11 @@ func NewNamespaceCommand(getNamespaceClientFn GetNamespaceClientFn) (CommandOut, Usage: "Namespace operations", Before: func(ctx *cli.Context) error { var err error - c, err = getNamespaceClientFn(ctx) + nc, err = getNamespaceClientFn(ctx) + if err != nil { + return err + } + rc, err = getRequestClientFn(ctx) return err }, Subcommands: subCommands, diff --git a/app/namespace_test.go b/app/namespace_test.go index 2b817eb9..d4c0fdf2 100644 --- a/app/namespace_test.go +++ b/app/namespace_test.go @@ -19,6 +19,7 @@ import ( "github.com/temporalio/tcld/protogen/api/request/v1" authservicemock "github.com/temporalio/tcld/protogen/apimock/authservice/v1" namespaceservicemock "github.com/temporalio/tcld/protogen/apimock/namespaceservice/v1" + requestservicemock "github.com/temporalio/tcld/protogen/apimock/requestservice/v1" "github.com/urfave/cli/v2" ) @@ -32,6 +33,7 @@ type NamespaceTestSuite struct { mockCtrl *gomock.Controller mockService *namespaceservicemock.MockNamespaceServiceClient mockAuthService *authservicemock.MockAuthServiceClient + mockReqService *requestservicemock.MockRequestServiceClient } func (s *NamespaceTestSuite) SetupTest() { @@ -50,15 +52,22 @@ func (s *NamespaceTestSuite) SetupTest() { s.mockCtrl = gomock.NewController(s.T()) s.mockService = namespaceservicemock.NewMockNamespaceServiceClient(s.mockCtrl) + s.mockReqService = requestservicemock.NewMockRequestServiceClient(s.mockCtrl) s.mockAuthService = authservicemock.NewMockAuthServiceClient(s.mockCtrl) - out, err := NewNamespaceCommand(func(ctx *cli.Context) (*NamespaceClient, error) { + getNamespaceClientFn := func(ctx *cli.Context) (*NamespaceClient, error) { return &NamespaceClient{ ctx: context.TODO(), client: s.mockService, authClient: s.mockAuthService, }, nil - }) - + } + getRequestClientFn := func(ctx *cli.Context) (*RequestClient, error) { + return &RequestClient{ + ctx: context.TODO(), + client: s.mockReqService, + }, nil + } + out, err := NewNamespaceCommand(getNamespaceClientFn, getRequestClientFn) s.Require().NoError(err) AutoConfirmFlag.Value = true s.cliApp.Commands = []*cli.Command{out.Command} diff --git a/app/prompt.go b/app/prompt.go index acf5b861..98fdd3d7 100644 --- a/app/prompt.go +++ b/app/prompt.go @@ -10,12 +10,13 @@ import ( ) const ( - AutoConfirmFlagName = "auto_confirm" + AutoConfirmFlagName = "auto-confirm" ) var ( AutoConfirmFlag = &cli.BoolFlag{ Name: AutoConfirmFlagName, + Aliases: []string{"auto_confirm"}, Usage: "Automatically confirm all prompts", EnvVars: []string{"AUTO_CONFIRM"}, } diff --git a/app/request.go b/app/request.go index d28b2c60..f0d8410f 100644 --- a/app/request.go +++ b/app/request.go @@ -2,12 +2,39 @@ package app import ( "context" + "fmt" + "time" + "github.com/gosuri/uilive" + "github.com/temporalio/tcld/protogen/api/request/v1" "github.com/temporalio/tcld/protogen/api/requestservice/v1" "github.com/urfave/cli/v2" "google.golang.org/grpc" ) +const ( + WaitForRequestFlagName = "wait-for-request" + RequestTimeoutFlagName = "request-timeout" + + minCheckDurationTime = time.Second +) + +var ( + RequestTimeoutFlag = &cli.DurationFlag{ + Name: RequestTimeoutFlagName, + Usage: "Time to wait for requests to complete", + EnvVars: []string{"REQUEST_TIMEOUT"}, + Aliases: []string{"rt"}, + Value: time.Hour, + } + WaitForRequestFlag = &cli.BoolFlag{ + Name: WaitForRequestFlagName, + Usage: "Wait for request to complete", + Aliases: []string{"wait"}, + EnvVars: []string{"WAIT_FOR_REQUEST"}, + } +) + type RequestClient struct { client requestservice.RequestServiceClient ctx context.Context @@ -30,14 +57,14 @@ func GetRequestClient(ctx *cli.Context) (*RequestClient, error) { return NewRequestClient(ct, conn), nil } -func (c *RequestClient) getRequestStatus(requestID string) error { +func (c *RequestClient) getRequestStatus(ctx *cli.Context, requestID string) (*request.RequestStatus, error) { res, err := c.client.GetRequestStatus(c.ctx, &requestservice.GetRequestStatusRequest{ RequestId: requestID, }) if err != nil { - return err + return nil, err } - return PrintProto(res) + return res.RequestStatus, nil } func NewRequestCommand(getRequestClientFn GetRequestClientFn) (CommandOut, error) { @@ -45,7 +72,7 @@ func NewRequestCommand(getRequestClientFn GetRequestClientFn) (CommandOut, error var c *RequestClient return CommandOut{Command: &cli.Command{ Name: "request", - Usage: "Manage asynchronous requests", + Usage: "Manage requests", Aliases: []string{"r"}, Before: func(ctx *cli.Context) error { var err error @@ -59,15 +86,121 @@ func NewRequestCommand(getRequestClientFn GetRequestClientFn) (CommandOut, error Aliases: []string{"g"}, Flags: []cli.Flag{ &cli.StringFlag{ - Name: "request-id", - Usage: "The request-id of the asynchronous request", + Name: RequestIDFlagName, + Usage: "The request-id of the request", Aliases: []string{"r"}, Required: true, }, }, Action: func(ctx *cli.Context) error { - return c.getRequestStatus(ctx.String("request-id")) + res, err := c.getRequestStatus(ctx, ctx.String(RequestIDFlagName)) + if err != nil { + return err + } + return PrintProto(res) + }, + }, { + Name: "wait", + Usage: "wait for the request complete", + Aliases: []string{"w"}, + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: RequestIDFlagName, + Usage: "the request-id of the request", + Aliases: []string{"r"}, + Required: true, + }, + }, + Action: func(ctx *cli.Context) error { + return c.waitOnRequest(ctx, "", ctx.String(RequestIDFlagName)) }, }}, }}, nil } + +func (c *RequestClient) waitOnRequest(ctx *cli.Context, operation string, requestID string) error { + + ticker := time.NewTicker(time.Millisecond) + defer ticker.Stop() + timer := time.NewTimer(ctx.Duration(RequestTimeoutFlagName)) + defer timer.Stop() + + writer := uilive.New() + writer.Start() + defer writer.Stop() + + var status *request.RequestStatus + defer func() { + if status != nil { + if err := PrintProto(status); err != nil { + fmt.Fprintf(writer, "failed to print status: %s", err) + } + } + }() +loop: + for { + select { + case <-timer.C: + return fmt.Errorf("timed out waiting for request to complete, requestID=%s, timeout=%s", + requestID, + ctx.Duration(RequestTimeoutFlagName), + ) + case <-ticker.C: + var err error + status, err = c.getRequestStatus(ctx, requestID) + if err != nil { + return err + } + switch status.State { + case request.STATE_FULFILLED: + break loop + case request.STATE_FAILED: + fmt.Fprintf(writer, "operation failed \n") + return fmt.Errorf("request failed: %s", status.FailureReason) + case request.STATE_CANCELLED: + fmt.Fprintf(writer, "operation cancelled\n") + return fmt.Errorf("request was cancelled: %s", status.FailureReason) + } + if operation != "" { + fmt.Fprintf(writer, "waiting for %s operation (id='%s') to finish, current state: %s\n", + operation, requestID, request.State_name[int32(status.State)]) + } else { + fmt.Fprintf(writer, "waiting for request with id='%s' to finish, current state: %s\n", + requestID, request.State_name[int32(status.State)]) + } + if status.CheckDuration == nil || status.CheckDuration.Seconds == 0 { + ticker.Reset(minCheckDurationTime) // min check duration is 1 second + } else { + ticker.Reset(time.Second * time.Duration(status.CheckDuration.Seconds)) + } + } + } + if operation != "" { + fmt.Fprintf(writer, "%s operation completed successfully\n", operation) + } else { + fmt.Fprintf(writer, "request with id='%s' finished successfully\n", requestID) + } + return nil +} + +func (c *RequestClient) HandleRequestStatus( + ctx *cli.Context, + operation string, + status *request.RequestStatus, +) error { + if status == nil { + // status can be empty when the operation is cancelled + return nil + } + if ctx.Bool(WaitForRequestFlagName) { + return c.waitOnRequest(ctx, operation, status.RequestId) + } + if err := PrintProto(status); err != nil { + return err + } + fmt.Printf( + "started %s operation with request id='%s', to monitor its progress use command: `%s request get -r '%s'`", + operation, status.RequestId, AppName, status.RequestId, + ) + return nil +} diff --git a/app/request_test.go b/app/request_test.go index 5d4f565e..354033ce 100644 --- a/app/request_test.go +++ b/app/request_test.go @@ -5,8 +5,10 @@ import ( "errors" "testing" + "github.com/gogo/protobuf/types" "github.com/golang/mock/gomock" "github.com/stretchr/testify/suite" + "github.com/temporalio/tcld/protogen/api/request/v1" "github.com/temporalio/tcld/protogen/api/requestservice/v1" requestservicemock "github.com/temporalio/tcld/protogen/apimock/requestservice/v1" "github.com/urfave/cli/v2" @@ -36,6 +38,9 @@ func (s *RequestTestSuite) SetupTest() { s.cliApp = &cli.App{ Name: "test", Commands: []*cli.Command{out.Command}, + Flags: []cli.Flag{ + RequestTimeoutFlag, + }, } } @@ -48,7 +53,6 @@ func (s *RequestTestSuite) RunCmd(args ...string) error { } func (s *RequestTestSuite) TestGet() { - s.Error(s.RunCmd("request", "get")) s.mockService.EXPECT().GetRequestStatus(gomock.Any(), &requestservice.GetRequestStatusRequest{ @@ -58,6 +62,68 @@ func (s *RequestTestSuite) TestGet() { s.mockService.EXPECT().GetRequestStatus(gomock.Any(), &requestservice.GetRequestStatusRequest{ RequestId: "req1", - }).Return(&requestservice.GetRequestStatusResponse{}, nil).Times(1) + }).Return(&requestservice.GetRequestStatusResponse{ + RequestStatus: &request.RequestStatus{ + State: request.STATE_PENDING, + CheckDuration: &types.Duration{Seconds: 1}, + }, + }, nil).Times(1) s.NoError(s.RunCmd("request", "get", "--request-id", "req1")) } + +func (s *RequestTestSuite) TestWait() { + + s.Error(s.RunCmd("request", "wait")) + + // an error is returned by the api + s.mockService.EXPECT().GetRequestStatus(gomock.Any(), &requestservice.GetRequestStatusRequest{ + RequestId: "req1", + }).Return(nil, errors.New("some error")).Times(1) + s.Error(s.RunCmd("request", "wait", "--request-id", "req1")) + + // call repetatively till a fulfilled is received + s.mockService.EXPECT().GetRequestStatus(gomock.Any(), &requestservice.GetRequestStatusRequest{ + RequestId: "req1", + }).Return(&requestservice.GetRequestStatusResponse{ + RequestStatus: &request.RequestStatus{ + State: request.STATE_PENDING, + CheckDuration: &types.Duration{Seconds: 1}, + }, + }, nil).Times(2) + s.mockService.EXPECT().GetRequestStatus(gomock.Any(), &requestservice.GetRequestStatusRequest{ + RequestId: "req1", + }).Return(&requestservice.GetRequestStatusResponse{ + RequestStatus: &request.RequestStatus{ + State: request.STATE_IN_PROGRESS, + CheckDuration: &types.Duration{Seconds: 1}, + }, + }, nil).Times(2) + s.mockService.EXPECT().GetRequestStatus(gomock.Any(), &requestservice.GetRequestStatusRequest{ + RequestId: "req1", + }).Return(&requestservice.GetRequestStatusResponse{ + RequestStatus: &request.RequestStatus{ + State: request.STATE_FULFILLED, + CheckDuration: &types.Duration{Seconds: 1}, + }, + }, nil).Times(1) + s.NoError(s.RunCmd("request", "wait", "--request-id", "req1")) + + // call repetatively till a state changes to failed is received + s.mockService.EXPECT().GetRequestStatus(gomock.Any(), &requestservice.GetRequestStatusRequest{ + RequestId: "req1", + }).Return(&requestservice.GetRequestStatusResponse{ + RequestStatus: &request.RequestStatus{ + State: request.STATE_PENDING, + CheckDuration: &types.Duration{Seconds: 1}, + }, + }, nil).Times(2) + s.mockService.EXPECT().GetRequestStatus(gomock.Any(), &requestservice.GetRequestStatusRequest{ + RequestId: "req1", + }).Return(&requestservice.GetRequestStatusResponse{ + RequestStatus: &request.RequestStatus{ + State: request.STATE_FAILED, + CheckDuration: &types.Duration{Seconds: 1}, + }, + }, nil).Times(1) + s.Error(s.RunCmd("request", "wait", "--request-id", "req1")) +} diff --git a/app/user.go b/app/user.go index 90474c05..a625c4a5 100644 --- a/app/user.go +++ b/app/user.go @@ -5,10 +5,12 @@ import ( "errors" "fmt" + "strings" + "github.com/temporalio/tcld/protogen/api/auth/v1" "github.com/temporalio/tcld/protogen/api/authservice/v1" + "github.com/temporalio/tcld/protogen/api/request/v1" "github.com/urfave/cli/v2" - "strings" ) const ( @@ -137,9 +139,9 @@ func (c *UserClient) inviteUsers( emails []string, namespacePermissions []string, accountRole string, -) error { +) (*request.RequestStatus, error) { if len(accountRole) == 0 { - return errors.New("account role required for inviting new users") + return nil, errors.New("account role required for inviting new users") } // the role ids to invite the users for @@ -148,7 +150,7 @@ func (c *UserClient) inviteUsers( // first get the required account role role, err := getAccountRole(c.ctx, c.client, accountRole) if err != nil { - return err + return nil, err } roleIDs = append(roleIDs, role.GetId()) @@ -156,11 +158,11 @@ func (c *UserClient) inviteUsers( if len(namespacePermissions) > 0 { npm, err := toNamespacePermissionsMap(namespacePermissions) if err != nil { - return err + return nil, err } nsRoles, err := getNamespaceRolesBatch(c.ctx, c.client, npm) if err != nil { - return err + return nil, err } for _, nsRole := range nsRoles { roleIDs = append(roleIDs, nsRole.GetId()) @@ -180,19 +182,19 @@ func (c *UserClient) inviteUsers( } resp, err := c.client.InviteUsers(c.ctx, req) if err != nil { - return fmt.Errorf("unable to invite users: %w", err) + return nil, fmt.Errorf("unable to invite users: %w", err) } - return PrintProto(resp.GetRequestStatus()) + return resp.GetRequestStatus(), nil } func (c *UserClient) resendInvitation( ctx *cli.Context, userID string, userEmail string, -) error { +) (*request.RequestStatus, error) { user, err := c.getUser(userID, userEmail) if err != nil { - return err + return nil, err } req := &authservice.ResendUserInviteRequest{ UserId: user.Id, @@ -200,19 +202,19 @@ func (c *UserClient) resendInvitation( } resp, err := c.client.ResendUserInvite(c.ctx, req) if err != nil { - return fmt.Errorf("unable to resend invitation for user: %w", err) + return nil, fmt.Errorf("unable to resend invitation for user: %w", err) } - return PrintProto(resp.GetRequestStatus()) + return resp.GetRequestStatus(), nil } func (c *UserClient) deleteUser( ctx *cli.Context, userID string, userEmail string, -) error { +) (*request.RequestStatus, error) { u, err := c.getUser(userID, userEmail) if err != nil { - return err + return nil, err } req := &authservice.DeleteUserRequest{ UserId: u.Id, @@ -224,12 +226,12 @@ func (c *UserClient) deleteUser( } resp, err := c.client.DeleteUser(c.ctx, req) if err != nil { - return fmt.Errorf("unable to delete user: %w", err) + return nil, fmt.Errorf("unable to delete user: %w", err) } - return PrintProto(resp.GetRequestStatus()) + return resp.GetRequestStatus(), nil } -func (c *UserClient) performUpdate(ctx *cli.Context, user *auth.User) error { +func (c *UserClient) performUpdate(ctx *cli.Context, user *auth.User) (*request.RequestStatus, error) { req := &authservice.UpdateUserRequest{ UserId: user.Id, Spec: user.Spec, @@ -241,9 +243,9 @@ func (c *UserClient) performUpdate(ctx *cli.Context, user *auth.User) error { } resp, err := c.client.UpdateUser(c.ctx, req) if err != nil { - return fmt.Errorf("unable to update user: %w", err) + return nil, fmt.Errorf("unable to update user: %w", err) } - return PrintProto(resp.GetRequestStatus()) + return resp.GetRequestStatus(), nil } func (c *UserClient) setAccountRole( @@ -251,25 +253,25 @@ func (c *UserClient) setAccountRole( userID string, userEmail string, accountRole string, -) error { +) (*request.RequestStatus, error) { user, userRoles, err := c.getUserAndRoles(userID, userEmail) if err != nil { - return err + return nil, err } var newRoleIDs []string accountRoleToSet, err := getAccountRole(c.ctx, c.client, accountRole) if err != nil { - return err + return nil, err } if accountRoleToSet.Spec.AccountRole.ActionGroup == auth.ACCOUNT_ACTION_GROUP_ADMIN { // set the user account admin role y, err := ConfirmPrompt(ctx, "Setting admin role on user. All existing namespace permissions will be replaced, please confirm") if err != nil { - return err + return nil, err } if !y { fmt.Println("operation canceled") - return nil + return nil, nil } // ensure we overwrite all existing roles since the global admin role has permissions to everything newRoleIDs = []string{accountRoleToSet.Id} @@ -293,10 +295,10 @@ func (c *UserClient) setNamespacePermissions( userID string, userEmail string, namespacePermissions []string, -) error { +) (*request.RequestStatus, error) { user, userRoles, err := c.getUserAndRoles(userID, userEmail) if err != nil { - return err + return nil, err } var newRoleIDs []string for _, r := range userRoles { @@ -310,21 +312,21 @@ func (c *UserClient) setNamespacePermissions( if len(namespacePermissions) == 0 { y, err := ConfirmPrompt(ctx, "Looks like you are about to remove all namespace permissions, please confirm") if err != nil { - return err + return nil, err } if !y { fmt.Println("operation canceled") - return nil + return nil, nil } } else { // collect the namespace roles and update user npm, err := toNamespacePermissionsMap(namespacePermissions) if err != nil { - return err + return nil, err } nsRoles, err := getNamespaceRolesBatch(c.ctx, c.client, npm) if err != nil { - return err + return nil, err } for _, nsRole := range nsRoles { newRoleIDs = append(newRoleIDs, nsRole.Id) @@ -393,8 +395,11 @@ func toUserWrapper(u *auth.User, roles []*auth.Role) *auth.UserWrapper { return p } -func NewUserCommand(getUserClientFn GetUserClientFn) (CommandOut, error) { - var c *UserClient +func NewUserCommand(getUserClientFn GetUserClientFn, getRequestClientFn GetRequestClientFn) (CommandOut, error) { + var ( + uc *UserClient + rc *RequestClient + ) return CommandOut{ Command: &cli.Command{ Name: "user", @@ -402,7 +407,11 @@ func NewUserCommand(getUserClientFn GetUserClientFn) (CommandOut, error) { Usage: "User management operations", Before: func(ctx *cli.Context) error { var err error - c, err = getUserClientFn(ctx) + uc, err = getUserClientFn(ctx) + if err != nil { + return err + } + rc, err = getRequestClientFn(ctx) return err }, Subcommands: []*cli.Command{ @@ -429,7 +438,7 @@ func NewUserCommand(getUserClientFn GetUserClientFn) (CommandOut, error) { }, }, Action: func(ctx *cli.Context) error { - return c.listUsers(ctx.String(NamespaceFlagName), ctx.String(pageTokenFlagName), ctx.Int(pageSizeFlagName)) + return uc.listUsers(ctx.String(NamespaceFlagName), ctx.String(pageTokenFlagName), ctx.Int(pageSizeFlagName)) }, }, { @@ -441,7 +450,7 @@ func NewUserCommand(getUserClientFn GetUserClientFn) (CommandOut, error) { userEmailFlag, }, Action: func(ctx *cli.Context) error { - u, roles, err := c.getUserAndRoles(ctx.String(userIDFlagName), ctx.String(userEmailFlagName)) + u, roles, err := uc.getUserAndRoles(ctx.String(userIDFlagName), ctx.String(userEmailFlagName)) if err != nil { return err } @@ -471,14 +480,20 @@ func NewUserCommand(getUserClientFn GetUserClientFn) (CommandOut, error) { Aliases: []string{"p"}, }, RequestIDFlag, + WaitForRequestFlag, + RequestTimeoutFlag, }, Action: func(ctx *cli.Context) error { - return c.inviteUsers( + status, err := uc.inviteUsers( ctx, ctx.StringSlice(userEmailFlagName), ctx.StringSlice(namespacePermissionFlagName), ctx.String(accountRoleFlagName), ) + if err != nil { + return err + } + return rc.HandleRequestStatus(ctx, "invite users", status) }, }, { @@ -489,13 +504,19 @@ func NewUserCommand(getUserClientFn GetUserClientFn) (CommandOut, error) { userIDFlag, userEmailFlag, RequestIDFlag, + WaitForRequestFlag, + RequestTimeoutFlag, }, Action: func(ctx *cli.Context) error { - return c.resendInvitation( + status, err := uc.resendInvitation( ctx, ctx.String(userIDFlagName), ctx.String(userEmailFlagName), ) + if err != nil { + return err + } + return rc.HandleRequestStatus(ctx, "resend user invite", status) }, }, { @@ -507,13 +528,19 @@ func NewUserCommand(getUserClientFn GetUserClientFn) (CommandOut, error) { userEmailFlag, ResourceVersionFlag, RequestIDFlag, + WaitForRequestFlag, + RequestTimeoutFlag, }, Action: func(ctx *cli.Context) error { - return c.deleteUser( + status, err := uc.deleteUser( ctx, ctx.String(userIDFlagName), ctx.String(userEmailFlagName), ) + if err != nil { + return err + } + return rc.HandleRequestStatus(ctx, "delete user", status) }, }, { @@ -525,6 +552,8 @@ func NewUserCommand(getUserClientFn GetUserClientFn) (CommandOut, error) { userEmailFlag, RequestIDFlag, ResourceVersionFlag, + WaitForRequestFlag, + RequestTimeoutFlag, &cli.StringFlag{ Name: accountRoleFlagName, Usage: fmt.Sprintf("The account role to set on the user; valid types are: %v", accountActionGroups), @@ -533,12 +562,16 @@ func NewUserCommand(getUserClientFn GetUserClientFn) (CommandOut, error) { }, }, Action: func(ctx *cli.Context) error { - return c.setAccountRole( + status, err := uc.setAccountRole( ctx, ctx.String(userIDFlagName), ctx.String(userEmailFlagName), ctx.String(accountRoleFlagName), ) + if err != nil { + return err + } + return rc.HandleRequestStatus(ctx, "set account role", status) }, }, { @@ -557,12 +590,16 @@ func NewUserCommand(getUserClientFn GetUserClientFn) (CommandOut, error) { }, }, Action: func(ctx *cli.Context) error { - return c.setNamespacePermissions( + status, err := uc.setNamespacePermissions( ctx, ctx.String(userIDFlagName), ctx.String(userEmailFlagName), ctx.StringSlice(namespacePermissionFlagName), ) + if err != nil { + return err + } + return rc.HandleRequestStatus(ctx, "set namespace permissions", status) }, }, }, diff --git a/app/user_test.go b/app/user_test.go index 579fa8d2..a04ac9d4 100644 --- a/app/user_test.go +++ b/app/user_test.go @@ -3,15 +3,17 @@ package app import ( "context" "errors" + "reflect" + "testing" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/suite" "github.com/temporalio/tcld/protogen/api/auth/v1" "github.com/temporalio/tcld/protogen/api/authservice/v1" "github.com/temporalio/tcld/protogen/api/request/v1" authservicemock "github.com/temporalio/tcld/protogen/apimock/authservice/v1" + requestservicemock "github.com/temporalio/tcld/protogen/apimock/requestservice/v1" "github.com/urfave/cli/v2" - "reflect" - "testing" ) func TestUser(t *testing.T) { @@ -23,17 +25,27 @@ type UserTestSuite struct { cliApp *cli.App mockCtrl *gomock.Controller mockAuthService *authservicemock.MockAuthServiceClient + mockReqService *requestservicemock.MockRequestServiceClient } func (s *UserTestSuite) SetupTest() { s.mockCtrl = gomock.NewController(s.T()) s.mockAuthService = authservicemock.NewMockAuthServiceClient(s.mockCtrl) - out, err := NewUserCommand(func(ctx *cli.Context) (*UserClient, error) { + s.mockReqService = requestservicemock.NewMockRequestServiceClient(s.mockCtrl) + + getUserClient := func(ctx *cli.Context) (*UserClient, error) { return &UserClient{ ctx: context.TODO(), client: s.mockAuthService, }, nil - }) + } + getRequestClient := func(ctx *cli.Context) (*RequestClient, error) { + return &RequestClient{ + ctx: context.TODO(), + client: s.mockReqService, + }, nil + } + out, err := NewUserCommand(getUserClient, getRequestClient) s.Require().NoError(err) AutoConfirmFlag.Value = true s.cliApp = &cli.App{ diff --git a/go.mod b/go.mod index cc157710..9ee96b8f 100644 --- a/go.mod +++ b/go.mod @@ -8,6 +8,7 @@ require ( github.com/golang/mock v1.6.0 github.com/golang/protobuf v1.5.2 github.com/google/uuid v1.3.0 + github.com/gosuri/uilive v0.0.4 github.com/kylelemons/godebug v1.1.0 github.com/stretchr/testify v1.8.2 github.com/urfave/cli/v2 v2.25.3 @@ -23,6 +24,7 @@ require ( github.com/go-playground/locales v0.14.1 // indirect github.com/go-playground/universal-translator v0.18.1 // indirect github.com/leodido/go-urn v1.2.3 // indirect + github.com/mattn/go-isatty v0.0.19 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect github.com/russross/blackfriday/v2 v2.1.0 // indirect github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673 // indirect diff --git a/go.sum b/go.sum index a7756a2f..d05eefa5 100644 --- a/go.sum +++ b/go.sum @@ -22,12 +22,16 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38= github.com/google/uuid v1.3.0 h1:t6JiXgmwXMjEs8VusXIJk2BXHsn+wx8BZdTaoZ5fu7I= github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/gosuri/uilive v0.0.4 h1:hUEBpQDj8D8jXgtCdBu7sWsy5sbW/5GhuO8KBwJ2jyY= +github.com/gosuri/uilive v0.0.4/go.mod h1:V/epo5LjjlDE5RJUcqx8dbw+zc93y5Ya3yg8tfZ74VI= github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= github.com/leodido/go-urn v1.2.3 h1:6BE2vPT0lqoz3fmOesHZiaiFh7889ssCo2GMvLCfiuA= github.com/leodido/go-urn v1.2.3/go.mod h1:7ZrI8mTSeBSHl/UaRyKQW1qZeMgak41ANeCNaVckg+4= +github.com/mattn/go-isatty v0.0.19 h1:JITubQf0MOLdlGRuRq+jtsDlekdYPia9ZFsB8h/APPA= +github.com/mattn/go-isatty v0.0.19/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D7dTCTo3Y= github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=