diff --git a/manager/api/endpoint.go b/manager/api/endpoint.go index 3f68a577..42da116d 100644 --- a/manager/api/endpoint.go +++ b/manager/api/endpoint.go @@ -20,7 +20,7 @@ func listPropletsEndpoint(svc manager.Service) endpoint.Endpoint { return listpropletResponse{}, errors.Join(apiutil.ErrValidation, err) } - proplets, err := svc.ListProplets(ctx, req.offset, req.limit) + proplets, err := svc.ListProplets(ctx, req.offset, req.limit, req.status) if err != nil { return listpropletResponse{}, err } @@ -169,7 +169,7 @@ func listJobsEndpoint(svc manager.Service) endpoint.Endpoint { return listJobResponse{}, errors.Join(apiutil.ErrValidation, err) } - jobs, err := svc.ListJobs(ctx, req.offset, req.limit) + jobs, err := svc.ListJobs(ctx, req.offset, req.limit, req.status) if err != nil { return listJobResponse{}, err } @@ -231,6 +231,9 @@ func listTasksEndpoint(svc manager.Service) endpoint.Endpoint { if err := req.validate(); err != nil { return listTaskResponse{}, errors.Join(apiutil.ErrValidation, err) } + if req.status != "" { + return listTaskResponse{}, errors.Join(apiutil.ErrValidation, pkgerrors.ErrInvalidValue) + } tasks, err := svc.ListTasks(ctx, req.offset, req.limit) if err != nil { diff --git a/manager/api/requests.go b/manager/api/requests.go index 3be90bbe..9d3f7f2e 100644 --- a/manager/api/requests.go +++ b/manager/api/requests.go @@ -94,6 +94,7 @@ func (e *entityReq) validate() error { type listEntityReq struct { offset, limit uint64 + status string } func (e *listEntityReq) validate() error { diff --git a/manager/api/transport.go b/manager/api/transport.go index faf0e260..1e49cc0c 100644 --- a/manager/api/transport.go +++ b/manager/api/transport.go @@ -309,6 +309,7 @@ func decodeListEntityReq(_ context.Context, r *http.Request) (any, error) { return listEntityReq{ offset: o, limit: l, + status: r.URL.Query().Get("status"), }, nil } diff --git a/manager/manager.go b/manager/manager.go index e38fc7af..6fbd6639 100644 --- a/manager/manager.go +++ b/manager/manager.go @@ -9,7 +9,7 @@ import ( type Service interface { GetProplet(ctx context.Context, propletID string) (proplet.Proplet, error) - ListProplets(ctx context.Context, offset, limit uint64) (proplet.PropletPage, error) + ListProplets(ctx context.Context, offset, limit uint64, status string) (proplet.PropletPage, error) SelectProplet(ctx context.Context, task task.Task) (proplet.Proplet, error) DeleteProplet(ctx context.Context, propletID string) error @@ -18,7 +18,7 @@ type Service interface { CreateJob(ctx context.Context, name string, tasks []task.Task, executionMode string) (string, []task.Task, error) GetTask(ctx context.Context, taskID string) (task.Task, error) GetJob(ctx context.Context, jobID string) ([]task.Task, error) - ListJobs(ctx context.Context, offset, limit uint64) (JobPage, error) + ListJobs(ctx context.Context, offset, limit uint64, status string) (JobPage, error) StartJob(ctx context.Context, jobID string) error StopJob(ctx context.Context, jobID string) error ListTasks(ctx context.Context, offset, limit uint64) (task.TaskPage, error) diff --git a/manager/middleware/logging.go b/manager/middleware/logging.go index 9ef334a3..0cc5ccde 100644 --- a/manager/middleware/logging.go +++ b/manager/middleware/logging.go @@ -42,12 +42,13 @@ func (lm *loggingMiddleware) GetProplet(ctx context.Context, id string) (resp pr return lm.svc.GetProplet(ctx, id) } -func (lm *loggingMiddleware) ListProplets(ctx context.Context, offset, limit uint64) (resp proplet.PropletPage, err error) { +func (lm *loggingMiddleware) ListProplets(ctx context.Context, offset, limit uint64, status string) (resp proplet.PropletPage, err error) { defer func(begin time.Time) { args := []any{ slog.String("duration", time.Since(begin).String()), slog.Uint64("offset", offset), slog.Uint64("limit", limit), + slog.String("status", status), } if err != nil { args = append(args, slog.Any("error", err)) @@ -58,7 +59,7 @@ func (lm *loggingMiddleware) ListProplets(ctx context.Context, offset, limit uin lm.logger.Info("List proplets completed successfully", args...) }(time.Now()) - return lm.svc.ListProplets(ctx, offset, limit) + return lm.svc.ListProplets(ctx, offset, limit, status) } func (lm *loggingMiddleware) SelectProplet(ctx context.Context, t task.Task) (w proplet.Proplet, err error) { @@ -348,12 +349,13 @@ func (lm *loggingMiddleware) GetJob(ctx context.Context, jobID string) (resp []t return lm.svc.GetJob(ctx, jobID) } -func (lm *loggingMiddleware) ListJobs(ctx context.Context, offset, limit uint64) (resp manager.JobPage, err error) { +func (lm *loggingMiddleware) ListJobs(ctx context.Context, offset, limit uint64, status string) (resp manager.JobPage, err error) { defer func(begin time.Time) { args := []any{ slog.String("duration", time.Since(begin).String()), slog.Uint64("offset", offset), slog.Uint64("limit", limit), + slog.String("status", status), } if err != nil { args = append(args, slog.Any("error", err)) @@ -365,7 +367,7 @@ func (lm *loggingMiddleware) ListJobs(ctx context.Context, offset, limit uint64) lm.logger.Info("List jobs completed successfully", args...) }(time.Now()) - return lm.svc.ListJobs(ctx, offset, limit) + return lm.svc.ListJobs(ctx, offset, limit, status) } func (lm *loggingMiddleware) StartJob(ctx context.Context, jobID string) (err error) { diff --git a/manager/middleware/metrics.go b/manager/middleware/metrics.go index c4ae24db..6c8972e8 100644 --- a/manager/middleware/metrics.go +++ b/manager/middleware/metrics.go @@ -35,13 +35,13 @@ func (mm *metricsMiddleware) GetProplet(ctx context.Context, id string) (proplet return mm.svc.GetProplet(ctx, id) } -func (mm *metricsMiddleware) ListProplets(ctx context.Context, offset, limit uint64) (proplet.PropletPage, error) { +func (mm *metricsMiddleware) ListProplets(ctx context.Context, offset, limit uint64, status string) (proplet.PropletPage, error) { defer func(begin time.Time) { mm.counter.With("method", "list-proplets").Add(1) mm.latency.With("method", "list-proplets").Observe(time.Since(begin).Seconds()) }(time.Now()) - return mm.svc.ListProplets(ctx, offset, limit) + return mm.svc.ListProplets(ctx, offset, limit, status) } func (mm *metricsMiddleware) SelectProplet(ctx context.Context, t task.Task) (proplet.Proplet, error) { @@ -170,13 +170,13 @@ func (mm *metricsMiddleware) GetJob(ctx context.Context, jobID string) ([]task.T return mm.svc.GetJob(ctx, jobID) } -func (mm *metricsMiddleware) ListJobs(ctx context.Context, offset, limit uint64) (manager.JobPage, error) { +func (mm *metricsMiddleware) ListJobs(ctx context.Context, offset, limit uint64, status string) (manager.JobPage, error) { defer func(begin time.Time) { mm.counter.With("method", "list-jobs").Add(1) mm.latency.With("method", "list-jobs").Observe(time.Since(begin).Seconds()) }(time.Now()) - return mm.svc.ListJobs(ctx, offset, limit) + return mm.svc.ListJobs(ctx, offset, limit, status) } func (mm *metricsMiddleware) StartJob(ctx context.Context, jobID string) error { diff --git a/manager/middleware/tracing.go b/manager/middleware/tracing.go index 3000a4f0..daa754b8 100644 --- a/manager/middleware/tracing.go +++ b/manager/middleware/tracing.go @@ -30,14 +30,15 @@ func (tm *tracing) GetProplet(ctx context.Context, id string) (resp proplet.Prop return tm.svc.GetProplet(ctx, id) } -func (tm *tracing) ListProplets(ctx context.Context, offset, limit uint64) (resp proplet.PropletPage, err error) { +func (tm *tracing) ListProplets(ctx context.Context, offset, limit uint64, status string) (resp proplet.PropletPage, err error) { ctx, span := tm.tracer.Start(ctx, "list-proplets", trace.WithAttributes( attribute.Int64("offset", int64(offset)), attribute.Int64("limit", int64(limit)), + attribute.String("status", status), )) defer span.End() - return tm.svc.ListProplets(ctx, offset, limit) + return tm.svc.ListProplets(ctx, offset, limit, status) } func (tm *tracing) SelectProplet(ctx context.Context, t task.Task) (resp proplet.Proplet, err error) { @@ -178,14 +179,15 @@ func (tm *tracing) GetJob(ctx context.Context, jobID string) (resp []task.Task, return tm.svc.GetJob(ctx, jobID) } -func (tm *tracing) ListJobs(ctx context.Context, offset, limit uint64) (resp manager.JobPage, err error) { +func (tm *tracing) ListJobs(ctx context.Context, offset, limit uint64, status string) (resp manager.JobPage, err error) { ctx, span := tm.tracer.Start(ctx, "list-jobs", trace.WithAttributes( attribute.Int64("offset", int64(offset)), attribute.Int64("limit", int64(limit)), + attribute.String("status", status), )) defer span.End() - return tm.svc.ListJobs(ctx, offset, limit) + return tm.svc.ListJobs(ctx, offset, limit, status) } func (tm *tracing) StartJob(ctx context.Context, jobID string) (err error) { diff --git a/manager/mocks/service.go b/manager/mocks/service.go index 77f39d3f..2eba3883 100644 --- a/manager/mocks/service.go +++ b/manager/mocks/service.go @@ -1062,8 +1062,8 @@ func (_c *MockService_GetTaskResults_Call) RunAndReturn(run func(ctx context.Con } // ListJobs provides a mock function for the type MockService -func (_mock *MockService) ListJobs(ctx context.Context, offset uint64, limit uint64) (manager.JobPage, error) { - ret := _mock.Called(ctx, offset, limit) +func (_mock *MockService) ListJobs(ctx context.Context, offset uint64, limit uint64, status string) (manager.JobPage, error) { + ret := _mock.Called(ctx, offset, limit, status) if len(ret) == 0 { panic("no return value specified for ListJobs") @@ -1071,16 +1071,16 @@ func (_mock *MockService) ListJobs(ctx context.Context, offset uint64, limit uin var r0 manager.JobPage var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, uint64, uint64) (manager.JobPage, error)); ok { - return returnFunc(ctx, offset, limit) + if returnFunc, ok := ret.Get(0).(func(context.Context, uint64, uint64, string) (manager.JobPage, error)); ok { + return returnFunc(ctx, offset, limit, status) } - if returnFunc, ok := ret.Get(0).(func(context.Context, uint64, uint64) manager.JobPage); ok { - r0 = returnFunc(ctx, offset, limit) + if returnFunc, ok := ret.Get(0).(func(context.Context, uint64, uint64, string) manager.JobPage); ok { + r0 = returnFunc(ctx, offset, limit, status) } else { r0 = ret.Get(0).(manager.JobPage) } - if returnFunc, ok := ret.Get(1).(func(context.Context, uint64, uint64) error); ok { - r1 = returnFunc(ctx, offset, limit) + if returnFunc, ok := ret.Get(1).(func(context.Context, uint64, uint64, string) error); ok { + r1 = returnFunc(ctx, offset, limit, status) } else { r1 = ret.Error(1) } @@ -1096,11 +1096,12 @@ type MockService_ListJobs_Call struct { // - ctx context.Context // - offset uint64 // - limit uint64 -func (_e *MockService_Expecter) ListJobs(ctx interface{}, offset interface{}, limit interface{}) *MockService_ListJobs_Call { - return &MockService_ListJobs_Call{Call: _e.mock.On("ListJobs", ctx, offset, limit)} +// - status string +func (_e *MockService_Expecter) ListJobs(ctx interface{}, offset interface{}, limit interface{}, status interface{}) *MockService_ListJobs_Call { + return &MockService_ListJobs_Call{Call: _e.mock.On("ListJobs", ctx, offset, limit, status)} } -func (_c *MockService_ListJobs_Call) Run(run func(ctx context.Context, offset uint64, limit uint64)) *MockService_ListJobs_Call { +func (_c *MockService_ListJobs_Call) Run(run func(ctx context.Context, offset uint64, limit uint64, status string)) *MockService_ListJobs_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -1114,10 +1115,15 @@ func (_c *MockService_ListJobs_Call) Run(run func(ctx context.Context, offset ui if args[2] != nil { arg2 = args[2].(uint64) } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } run( arg0, arg1, arg2, + arg3, ) }) return _c @@ -1128,14 +1134,14 @@ func (_c *MockService_ListJobs_Call) Return(jobPage manager.JobPage, err error) return _c } -func (_c *MockService_ListJobs_Call) RunAndReturn(run func(ctx context.Context, offset uint64, limit uint64) (manager.JobPage, error)) *MockService_ListJobs_Call { +func (_c *MockService_ListJobs_Call) RunAndReturn(run func(ctx context.Context, offset uint64, limit uint64, status string) (manager.JobPage, error)) *MockService_ListJobs_Call { _c.Call.Return(run) return _c } // ListProplets provides a mock function for the type MockService -func (_mock *MockService) ListProplets(ctx context.Context, offset uint64, limit uint64) (proplet.PropletPage, error) { - ret := _mock.Called(ctx, offset, limit) +func (_mock *MockService) ListProplets(ctx context.Context, offset uint64, limit uint64, status string) (proplet.PropletPage, error) { + ret := _mock.Called(ctx, offset, limit, status) if len(ret) == 0 { panic("no return value specified for ListProplets") @@ -1143,16 +1149,16 @@ func (_mock *MockService) ListProplets(ctx context.Context, offset uint64, limit var r0 proplet.PropletPage var r1 error - if returnFunc, ok := ret.Get(0).(func(context.Context, uint64, uint64) (proplet.PropletPage, error)); ok { - return returnFunc(ctx, offset, limit) + if returnFunc, ok := ret.Get(0).(func(context.Context, uint64, uint64, string) (proplet.PropletPage, error)); ok { + return returnFunc(ctx, offset, limit, status) } - if returnFunc, ok := ret.Get(0).(func(context.Context, uint64, uint64) proplet.PropletPage); ok { - r0 = returnFunc(ctx, offset, limit) + if returnFunc, ok := ret.Get(0).(func(context.Context, uint64, uint64, string) proplet.PropletPage); ok { + r0 = returnFunc(ctx, offset, limit, status) } else { r0 = ret.Get(0).(proplet.PropletPage) } - if returnFunc, ok := ret.Get(1).(func(context.Context, uint64, uint64) error); ok { - r1 = returnFunc(ctx, offset, limit) + if returnFunc, ok := ret.Get(1).(func(context.Context, uint64, uint64, string) error); ok { + r1 = returnFunc(ctx, offset, limit, status) } else { r1 = ret.Error(1) } @@ -1168,11 +1174,12 @@ type MockService_ListProplets_Call struct { // - ctx context.Context // - offset uint64 // - limit uint64 -func (_e *MockService_Expecter) ListProplets(ctx interface{}, offset interface{}, limit interface{}) *MockService_ListProplets_Call { - return &MockService_ListProplets_Call{Call: _e.mock.On("ListProplets", ctx, offset, limit)} +// - status string +func (_e *MockService_Expecter) ListProplets(ctx interface{}, offset interface{}, limit interface{}, status interface{}) *MockService_ListProplets_Call { + return &MockService_ListProplets_Call{Call: _e.mock.On("ListProplets", ctx, offset, limit, status)} } -func (_c *MockService_ListProplets_Call) Run(run func(ctx context.Context, offset uint64, limit uint64)) *MockService_ListProplets_Call { +func (_c *MockService_ListProplets_Call) Run(run func(ctx context.Context, offset uint64, limit uint64, status string)) *MockService_ListProplets_Call { _c.Call.Run(func(args mock.Arguments) { var arg0 context.Context if args[0] != nil { @@ -1186,10 +1193,15 @@ func (_c *MockService_ListProplets_Call) Run(run func(ctx context.Context, offse if args[2] != nil { arg2 = args[2].(uint64) } + var arg3 string + if args[3] != nil { + arg3 = args[3].(string) + } run( arg0, arg1, arg2, + arg3, ) }) return _c @@ -1200,7 +1212,7 @@ func (_c *MockService_ListProplets_Call) Return(propletPage proplet.PropletPage, return _c } -func (_c *MockService_ListProplets_Call) RunAndReturn(run func(ctx context.Context, offset uint64, limit uint64) (proplet.PropletPage, error)) *MockService_ListProplets_Call { +func (_c *MockService_ListProplets_Call) RunAndReturn(run func(ctx context.Context, offset uint64, limit uint64, status string) (proplet.PropletPage, error)) *MockService_ListProplets_Call { _c.Call.Return(run) return _c } diff --git a/manager/service.go b/manager/service.go index c06d40f1..893d3100 100644 --- a/manager/service.go +++ b/manager/service.go @@ -39,6 +39,14 @@ const ( ExecutionModeConfigurable = "configurable" EnvJobExecutionMode = "JOB_EXECUTION_MODE" shutdownTaskStopWait = 200 * time.Millisecond + + PropletStatusActive = "active" + PropletStatusInactive = "inactive" + + JobStatusPending = "pending" + JobStatusRunning = "running" + JobStatusCompleted = "completed" + JobStatusFailed = "failed" ) var ( @@ -112,8 +120,31 @@ func (svc *service) GetProplet(ctx context.Context, propletID string) (proplet.P return w, nil } -func (svc *service) ListProplets(ctx context.Context, offset, limit uint64) (proplet.PropletPage, error) { - proplets, total, err := svc.propletRepo.List(ctx, offset, limit) +func (svc *service) ListProplets(ctx context.Context, offset, limit uint64, status string) (proplet.PropletPage, error) { + if status != "" && status != PropletStatusActive && status != PropletStatusInactive { + return proplet.PropletPage{}, fmt.Errorf("%w: proplet status must be %q, %q, or empty, got %q", pkgerrors.ErrInvalidValue, PropletStatusActive, PropletStatusInactive, status) + } + + if status == "" { + proplets, total, err := svc.propletRepo.List(ctx, offset, limit) + if err != nil { + return proplet.PropletPage{}, err + } + for i := range proplets { + proplets[i].SetAlive() + } + + return proplet.PropletPage{ + Offset: offset, + Limit: limit, + Total: total, + Proplets: proplets, + }, nil + } + + alive := status == PropletStatusActive + since := time.Now().Add(-proplet.AliveTimeout) + proplets, total, err := svc.propletRepo.ListByAlive(ctx, offset, limit, alive, since) if err != nil { return proplet.PropletPage{}, err } @@ -130,7 +161,7 @@ func (svc *service) ListProplets(ctx context.Context, offset, limit uint64) (pro } func (svc *service) SelectProplet(ctx context.Context, t task.Task) (proplet.Proplet, error) { - proplets, err := svc.ListProplets(ctx, defOffset, defLimit) + proplets, err := svc.ListProplets(ctx, defOffset, defLimit, "") if err != nil { return proplet.Proplet{}, err } @@ -338,7 +369,11 @@ func (svc *service) GetJob(ctx context.Context, jobID string) ([]task.Task, erro return svc.getJobTasks(ctx, jobID) } -func (svc *service) ListJobs(ctx context.Context, offset, limit uint64) (JobPage, error) { +func (svc *service) ListJobs(ctx context.Context, offset, limit uint64, status string) (JobPage, error) { + if status != "" && status != JobStatusPending && status != JobStatusRunning && status != JobStatusCompleted && status != JobStatusFailed { + return JobPage{}, fmt.Errorf("%w: job status must be %q, %q, %q, %q, or empty, got %q", pkgerrors.ErrInvalidValue, JobStatusPending, JobStatusRunning, JobStatusCompleted, JobStatusFailed, status) + } + jobs := make([]JobSummary, 0) seen := make(map[string]struct{}) @@ -393,6 +428,27 @@ func (svc *service) ListJobs(ctx context.Context, offset, limit uint64) (JobPage } }) + if status != "" { + // ComputeJobState only ever returns Pending, Running, Completed, or Failed. + // Skipped and Interrupted task states are collapsed: Interrupted → Failed, + // Scheduled → Running, all-Skipped → Completed. No job summary can carry + // any other state, so the map below is exhaustive. + statusStateMap := map[string]task.State{ + JobStatusPending: task.Pending, + JobStatusRunning: task.Running, + JobStatusCompleted: task.Completed, + JobStatusFailed: task.Failed, + } + targetState := statusStateMap[status] + filtered := make([]JobSummary, 0, len(jobs)) + for i := range jobs { + if jobs[i].State == targetState { + filtered = append(filtered, jobs[i]) + } + } + jobs = filtered + } + total := uint64(len(jobs)) if offset >= total { return JobPage{ diff --git a/manager/service_job_test.go b/manager/service_job_test.go index 9b01ab55..50424eae 100644 --- a/manager/service_job_test.go +++ b/manager/service_job_test.go @@ -66,7 +66,7 @@ func TestListJobs(t *testing.T) { }, "sequential") require.NoError(t, err) - page, err := svc.ListJobs(context.Background(), 0, 100) + page, err := svc.ListJobs(context.Background(), 0, 100, "") require.NoError(t, err) assert.Equal(t, uint64(2), page.Total) assert.Len(t, page.Jobs, 2) @@ -88,7 +88,7 @@ func TestListJobsIncludesLegacyTaskOnlyJob(t *testing.T) { }) require.NoError(t, err) - page, err := svc.ListJobs(context.Background(), 0, 100) + page, err := svc.ListJobs(context.Background(), 0, 100, "") require.NoError(t, err) assert.Equal(t, uint64(2), page.Total) @@ -198,12 +198,12 @@ func TestListJobsPagination(t *testing.T) { require.NoError(t, err) } - page, err := svc.ListJobs(context.Background(), 0, 3) + page, err := svc.ListJobs(context.Background(), 0, 3, "") require.NoError(t, err) assert.Equal(t, uint64(5), page.Total) assert.Len(t, page.Jobs, 3) - page2, err := svc.ListJobs(context.Background(), 3, 3) + page2, err := svc.ListJobs(context.Background(), 3, 3, "") require.NoError(t, err) assert.Len(t, page2.Jobs, 2) } @@ -265,3 +265,60 @@ func TestComputeJobState(t *testing.T) { }) } } + +func TestListJobsFilterByStatus(t *testing.T) { + t.Parallel() + svc := newService(t) + _, _, err := svc.CreateJob(context.Background(), "pending-job", []task.Task{ + {Name: "p1", State: task.Pending}, + }, "parallel") + require.NoError(t, err) + + _, _, err = svc.CreateJob(context.Background(), "running-job", []task.Task{ + {Name: "r1", State: task.Running}, + }, "parallel") + require.NoError(t, err) + + _, _, err = svc.CreateJob(context.Background(), "completed-job", []task.Task{ + {Name: "c1", State: task.Completed}, + }, "parallel") + require.NoError(t, err) + + _, _, err = svc.CreateJob(context.Background(), "failed-job", []task.Task{ + {Name: "f1", State: task.Failed}, + }, "parallel") + require.NoError(t, err) + + all, err := svc.ListJobs(context.Background(), 0, 100, "") + require.NoError(t, err) + assert.Equal(t, uint64(4), all.Total) + + page, err := svc.ListJobs(context.Background(), 0, 100, "pending") + require.NoError(t, err) + assert.Equal(t, uint64(1), page.Total) + assert.Equal(t, task.Pending, page.Jobs[0].State) + + page, err = svc.ListJobs(context.Background(), 0, 100, "running") + require.NoError(t, err) + assert.Equal(t, uint64(1), page.Total) + assert.Equal(t, task.Running, page.Jobs[0].State) + + page, err = svc.ListJobs(context.Background(), 0, 100, "completed") + require.NoError(t, err) + assert.Equal(t, uint64(1), page.Total) + assert.Equal(t, task.Completed, page.Jobs[0].State) + + page, err = svc.ListJobs(context.Background(), 0, 100, "failed") + require.NoError(t, err) + assert.Equal(t, uint64(1), page.Total) + assert.Equal(t, task.Failed, page.Jobs[0].State) +} + +func TestListJobsInvalidStatusFilter(t *testing.T) { + t.Parallel() + svc := newService(t) + + _, err := svc.ListJobs(context.Background(), 0, 100, "invalid") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid value provided") +} diff --git a/manager/service_proplet_test.go b/manager/service_proplet_test.go new file mode 100644 index 00000000..411b545a --- /dev/null +++ b/manager/service_proplet_test.go @@ -0,0 +1,132 @@ +package manager_test + +import ( + "context" + "log/slog" + "testing" + "time" + + "github.com/absmach/propeller/manager" + mqttmocks "github.com/absmach/propeller/pkg/mqtt/mocks" + "github.com/absmach/propeller/pkg/proplet" + "github.com/absmach/propeller/pkg/scheduler" + "github.com/absmach/propeller/pkg/storage" + "github.com/absmach/propeller/pkg/task" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" +) + +func newServiceWithRepos(t *testing.T) (manager.Service, *storage.Repositories) { + t.Helper() + repos, err := storage.NewRepositories(storage.Config{Type: "memory"}) + require.NoError(t, err) + sched := scheduler.NewRoundRobin() + pubsub := mqttmocks.NewMockPubSub(t) + pubsub.On("Publish", mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + pubsub.On("Subscribe", mock.Anything, mock.Anything, mock.Anything).Return(nil).Maybe() + pubsub.On("Unsubscribe", mock.Anything, mock.Anything).Return(nil).Maybe() + pubsub.On("Disconnect", mock.Anything).Return(nil).Maybe() + logger := slog.Default() + svc, _ := manager.NewService(repos, sched, pubsub, "test-domain", "test-channel", logger) + + return svc, repos +} + +func TestListPropletsFilterByStatus(t *testing.T) { + t.Parallel() + svc, repos := newServiceWithRepos(t) + ctx := context.Background() + + activeProplet := proplet.Proplet{ + ID: uuid.NewString(), + Name: "active-proplet", + AliveHistory: []time.Time{time.Now()}, + } + inactiveProplet := proplet.Proplet{ + ID: uuid.NewString(), + Name: "inactive-proplet", + AliveHistory: []time.Time{time.Now().Add(-1 * time.Hour)}, + } + require.NoError(t, repos.Proplets.Create(ctx, activeProplet)) + require.NoError(t, repos.Proplets.Create(ctx, inactiveProplet)) + + all, err := svc.ListProplets(ctx, 0, 100, "") + require.NoError(t, err) + assert.Equal(t, uint64(2), all.Total) + + active, err := svc.ListProplets(ctx, 0, 100, manager.PropletStatusActive) + require.NoError(t, err) + assert.Equal(t, uint64(1), active.Total) + assert.True(t, active.Proplets[0].Alive) + + inactive, err := svc.ListProplets(ctx, 0, 100, manager.PropletStatusInactive) + require.NoError(t, err) + assert.Equal(t, uint64(1), inactive.Total) + assert.False(t, inactive.Proplets[0].Alive) +} + +func TestListPropletsInvalidStatus(t *testing.T) { + t.Parallel() + svc, _ := newServiceWithRepos(t) + + _, err := svc.ListProplets(context.Background(), 0, 100, "unknown") + require.Error(t, err) + assert.Contains(t, err.Error(), "invalid value") +} + +func TestListPropletsFilterPagination(t *testing.T) { + t.Parallel() + svc, repos := newServiceWithRepos(t) + ctx := context.Background() + + for range 5 { + p := proplet.Proplet{ + ID: uuid.NewString(), + Name: uuid.NewString(), + AliveHistory: []time.Time{time.Now()}, + } + require.NoError(t, repos.Proplets.Create(ctx, p)) + } + + page, err := svc.ListProplets(ctx, 0, 3, manager.PropletStatusActive) + require.NoError(t, err) + assert.Equal(t, uint64(5), page.Total) + assert.Len(t, page.Proplets, 3) + + page2, err := svc.ListProplets(ctx, 3, 3, manager.PropletStatusActive) + require.NoError(t, err) + assert.Equal(t, uint64(5), page2.Total) + assert.Len(t, page2.Proplets, 2) +} + +func TestListJobsInterruptedMapsToFailed(t *testing.T) { + t.Parallel() + svc := newService(t) + ctx := context.Background() + + _, _, err := svc.CreateJob(ctx, "interrupted-job", []task.Task{ + {Name: "t1", State: task.Interrupted}, + }, "parallel") + require.NoError(t, err) + + page, err := svc.ListJobs(ctx, 0, 100, manager.JobStatusFailed) + require.NoError(t, err) + assert.Equal(t, uint64(1), page.Total, "interrupted job must appear under 'failed' filter") +} + +func TestListJobsScheduledMapsToRunning(t *testing.T) { + t.Parallel() + svc := newService(t) + ctx := context.Background() + + _, _, err := svc.CreateJob(ctx, "scheduled-job", []task.Task{ + {Name: "t1", State: task.Scheduled}, + }, "parallel") + require.NoError(t, err) + + page, err := svc.ListJobs(ctx, 0, 100, manager.JobStatusRunning) + require.NoError(t, err) + assert.Equal(t, uint64(1), page.Total, "scheduled job must appear under 'running' filter") +} diff --git a/pkg/api/api.go b/pkg/api/api.go index 1cfaee7a..f9bbac90 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -8,6 +8,7 @@ import ( pkgerrors "github.com/absmach/propeller/pkg/errors" "github.com/absmach/supermq" + apiutil "github.com/absmach/supermq/api/http/util" ) const ( @@ -40,7 +41,9 @@ func EncodeResponse(_ context.Context, w http.ResponseWriter, response any) erro func EncodeError(_ context.Context, err error, w http.ResponseWriter) { w.Header().Set("Content-Type", ContentType) switch { - case errors.Is(err, pkgerrors.ErrEmptyKey): + case errors.Is(err, apiutil.ErrValidation), + errors.Is(err, pkgerrors.ErrEmptyKey), + errors.Is(err, pkgerrors.ErrInvalidValue): w.WriteHeader(http.StatusBadRequest) case errors.Is(err, pkgerrors.ErrNotFound): w.WriteHeader(http.StatusNotFound) diff --git a/pkg/proplet/proplet.go b/pkg/proplet/proplet.go index 1f76ed6e..9b7f2777 100644 --- a/pkg/proplet/proplet.go +++ b/pkg/proplet/proplet.go @@ -2,7 +2,7 @@ package proplet import "time" -const aliveTimeout = 10 * time.Second +const AliveTimeout = 10 * time.Second type PropletMetadata struct { Description string `json:"description,omitempty"` @@ -30,7 +30,7 @@ type Proplet struct { func (p *Proplet) SetAlive() { if len(p.AliveHistory) > 0 { lastAlive := p.AliveHistory[len(p.AliveHistory)-1] - if time.Since(lastAlive) <= aliveTimeout { + if time.Since(lastAlive) <= AliveTimeout { p.Alive = true return diff --git a/pkg/sdk/proplet.go b/pkg/sdk/proplet.go index 19c68361..436628c7 100644 --- a/pkg/sdk/proplet.go +++ b/pkg/sdk/proplet.go @@ -1,13 +1,65 @@ package sdk -import "net/http" +import ( + "encoding/json" + "fmt" + "net/http" + "net/url" + "strings" + "time" +) const propletsEndpoint = "/proplets" +type Proplet struct { + ID string `json:"id"` + Name string `json:"name"` + TaskCount uint64 `json:"task_count"` + Alive bool `json:"alive"` + CreatedAt time.Time `json:"created_at"` +} + +type PropletPage struct { + Offset uint64 `json:"offset"` + Limit uint64 `json:"limit"` + Total uint64 `json:"total"` + Proplets []Proplet `json:"proplets"` +} + +func (sdk *propSDK) ListProplets(offset, limit uint64, status string) (PropletPage, error) { + params := make([]string, 0) + if offset > 0 { + params = append(params, fmt.Sprintf("offset=%d", offset)) + } + if limit > 0 { + params = append(params, fmt.Sprintf("limit=%d", limit)) + } + if status != "" { + params = append(params, "status="+url.QueryEscape(status)) + } + query := "" + if len(params) > 0 { + query = "?" + strings.Join(params, "&") + } + reqURL := sdk.managerURL + propletsEndpoint + query + + body, err := sdk.processRequest(http.MethodGet, reqURL, nil, http.StatusOK) + if err != nil { + return PropletPage{}, err + } + + var pp PropletPage + if err := json.Unmarshal(body, &pp); err != nil { + return PropletPage{}, err + } + + return pp, nil +} + func (sdk *propSDK) DeleteProplet(id string) error { - url := sdk.managerURL + propletsEndpoint + "/" + id + reqURL := sdk.managerURL + propletsEndpoint + "/" + id - if _, err := sdk.processRequest(http.MethodDelete, url, nil, http.StatusNoContent); err != nil { + if _, err := sdk.processRequest(http.MethodDelete, reqURL, nil, http.StatusNoContent); err != nil { return err } diff --git a/pkg/sdk/sdk.go b/pkg/sdk/sdk.go index c3cb6b28..31557d4b 100644 --- a/pkg/sdk/sdk.go +++ b/pkg/sdk/sdk.go @@ -88,11 +88,13 @@ type SDK interface { // job, _ := sdk.GetJob("b1d10738-c5d7-4ff1-8f4d-b9328ce6f040") GetJob(jobID string) (JobResponse, error) - // ListJobs lists jobs. + // ListJobs lists jobs with optional status filter. + // Status can be "pending", "running", "completed", "failed", or "" (all). // // example: - // jobPage, _ := sdk.ListJobs(0, 10) - ListJobs(offset uint64, limit uint64) (JobPage, error) + // jobPage, _ := sdk.ListJobs(0, 10, "") + // jobPage, _ := sdk.ListJobs(0, 10, "running") + ListJobs(offset uint64, limit uint64, status string) (JobPage, error) // StartJob starts a job. // @@ -106,6 +108,14 @@ type SDK interface { // _ := sdk.StopJob("b1d10738-c5d7-4ff1-8f4d-b9328ce6f040") StopJob(jobID string) error + // ListProplets lists proplets with optional status filter. + // Status can be "active", "inactive", or "" (all). + // + // example: + // page, _ := sdk.ListProplets(0, 10, "") + // page, _ := sdk.ListProplets(0, 10, "active") + ListProplets(offset uint64, limit uint64, status string) (PropletPage, error) + // DeleteProplet deletes a proplet by id. // // example: diff --git a/pkg/sdk/task.go b/pkg/sdk/task.go index 47ab7321..2cdb0e77 100644 --- a/pkg/sdk/task.go +++ b/pkg/sdk/task.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "net/http" + "net/url" "strings" "time" ) @@ -40,9 +41,9 @@ func (sdk *propSDK) CreateTask(task Task) (Task, error) { return Task{}, err } - url := sdk.managerURL + tasksEndpoint + reqURL := sdk.managerURL + tasksEndpoint - body, err := sdk.processRequest(http.MethodPost, url, data, http.StatusCreated) + body, err := sdk.processRequest(http.MethodPost, reqURL, data, http.StatusCreated) if err != nil { return Task{}, err } @@ -56,9 +57,9 @@ func (sdk *propSDK) CreateTask(task Task) (Task, error) { } func (sdk *propSDK) GetTask(id string) (Task, error) { - url := sdk.managerURL + tasksEndpoint + "/" + id + reqURL := sdk.managerURL + tasksEndpoint + "/" + id - body, err := sdk.processRequest(http.MethodGet, url, nil, http.StatusOK) + body, err := sdk.processRequest(http.MethodGet, reqURL, nil, http.StatusOK) if err != nil { return Task{}, err } @@ -83,9 +84,9 @@ func (sdk *propSDK) ListTasks(offset, limit uint64) (TaskPage, error) { if len(queries) > 0 { query = "?" + strings.Join(queries, "&") } - url := sdk.managerURL + tasksEndpoint + query + reqURL := sdk.managerURL + tasksEndpoint + query - body, err := sdk.processRequest(http.MethodGet, url, nil, http.StatusOK) + body, err := sdk.processRequest(http.MethodGet, reqURL, nil, http.StatusOK) if err != nil { return TaskPage{}, err } @@ -103,9 +104,9 @@ func (sdk *propSDK) UpdateTask(task Task) (Task, error) { if err != nil { return Task{}, err } - url := sdk.managerURL + tasksEndpoint + "/" + task.ID + reqURL := sdk.managerURL + tasksEndpoint + "/" + task.ID - body, err := sdk.processRequest(http.MethodPut, url, data, http.StatusOK) + body, err := sdk.processRequest(http.MethodPut, reqURL, data, http.StatusOK) if err != nil { return Task{}, err } @@ -119,9 +120,9 @@ func (sdk *propSDK) UpdateTask(task Task) (Task, error) { } func (sdk *propSDK) DeleteTask(id string) error { - url := sdk.managerURL + tasksEndpoint + "/" + id + reqURL := sdk.managerURL + tasksEndpoint + "/" + id - if _, err := sdk.processRequest(http.MethodDelete, url, nil, http.StatusNoContent); err != nil { + if _, err := sdk.processRequest(http.MethodDelete, reqURL, nil, http.StatusNoContent); err != nil { return err } @@ -129,9 +130,9 @@ func (sdk *propSDK) DeleteTask(id string) error { } func (sdk *propSDK) StartTask(id string) error { - url := fmt.Sprintf("%s/tasks/%s/start", sdk.managerURL, id) + reqURL := fmt.Sprintf("%s/tasks/%s/start", sdk.managerURL, id) - if _, err := sdk.processRequest(http.MethodPost, url, nil, http.StatusOK); err != nil { + if _, err := sdk.processRequest(http.MethodPost, reqURL, nil, http.StatusOK); err != nil { return err } @@ -139,9 +140,9 @@ func (sdk *propSDK) StartTask(id string) error { } func (sdk *propSDK) StopTask(id string) error { - url := fmt.Sprintf("%s/tasks/%s/stop", sdk.managerURL, id) + reqURL := fmt.Sprintf("%s/tasks/%s/stop", sdk.managerURL, id) - if _, err := sdk.processRequest(http.MethodPost, url, nil, http.StatusOK); err != nil { + if _, err := sdk.processRequest(http.MethodPost, reqURL, nil, http.StatusOK); err != nil { return err } @@ -184,9 +185,9 @@ func (sdk *propSDK) CreateJob(req JobRequest) (JobResponse, error) { return JobResponse{}, err } - url := sdk.managerURL + jobsEndpoint + reqURL := sdk.managerURL + jobsEndpoint - body, err := sdk.processRequest(http.MethodPost, url, data, http.StatusCreated) + body, err := sdk.processRequest(http.MethodPost, reqURL, data, http.StatusCreated) if err != nil { return JobResponse{}, err } @@ -200,9 +201,9 @@ func (sdk *propSDK) CreateJob(req JobRequest) (JobResponse, error) { } func (sdk *propSDK) GetJob(jobID string) (JobResponse, error) { - url := sdk.managerURL + jobsEndpoint + "/" + jobID + reqURL := sdk.managerURL + jobsEndpoint + "/" + jobID - body, err := sdk.processRequest(http.MethodGet, url, nil, http.StatusOK) + body, err := sdk.processRequest(http.MethodGet, reqURL, nil, http.StatusOK) if err != nil { return JobResponse{}, err } @@ -215,21 +216,24 @@ func (sdk *propSDK) GetJob(jobID string) (JobResponse, error) { return jr, nil } -func (sdk *propSDK) ListJobs(offset, limit uint64) (JobPage, error) { - queries := make([]string, 0) +func (sdk *propSDK) ListJobs(offset, limit uint64, status string) (JobPage, error) { + params := make([]string, 0) if offset > 0 { - queries = append(queries, fmt.Sprintf("offset=%d", offset)) + params = append(params, fmt.Sprintf("offset=%d", offset)) } if limit > 0 { - queries = append(queries, fmt.Sprintf("limit=%d", limit)) + params = append(params, fmt.Sprintf("limit=%d", limit)) + } + if status != "" { + params = append(params, "status="+url.QueryEscape(status)) } query := "" - if len(queries) > 0 { - query = "?" + strings.Join(queries, "&") + if len(params) > 0 { + query = "?" + strings.Join(params, "&") } - url := sdk.managerURL + jobsEndpoint + query + reqURL := sdk.managerURL + jobsEndpoint + query - body, err := sdk.processRequest(http.MethodGet, url, nil, http.StatusOK) + body, err := sdk.processRequest(http.MethodGet, reqURL, nil, http.StatusOK) if err != nil { return JobPage{}, err } @@ -243,9 +247,9 @@ func (sdk *propSDK) ListJobs(offset, limit uint64) (JobPage, error) { } func (sdk *propSDK) StartJob(jobID string) error { - url := fmt.Sprintf("%s/jobs/%s/start", sdk.managerURL, jobID) + reqURL := fmt.Sprintf("%s/jobs/%s/start", sdk.managerURL, jobID) - if _, err := sdk.processRequest(http.MethodPost, url, nil, http.StatusOK); err != nil { + if _, err := sdk.processRequest(http.MethodPost, reqURL, nil, http.StatusOK); err != nil { return err } @@ -253,9 +257,9 @@ func (sdk *propSDK) StartJob(jobID string) error { } func (sdk *propSDK) StopJob(jobID string) error { - url := fmt.Sprintf("%s/jobs/%s/stop", sdk.managerURL, jobID) + reqURL := fmt.Sprintf("%s/jobs/%s/stop", sdk.managerURL, jobID) - if _, err := sdk.processRequest(http.MethodPost, url, nil, http.StatusOK); err != nil { + if _, err := sdk.processRequest(http.MethodPost, reqURL, nil, http.StatusOK); err != nil { return err } diff --git a/pkg/storage/badger/init.go b/pkg/storage/badger/init.go index 0155a1b4..9b5d593b 100644 --- a/pkg/storage/badger/init.go +++ b/pkg/storage/badger/init.go @@ -62,6 +62,7 @@ type PropletRepository interface { Get(ctx context.Context, id string) (proplet.Proplet, error) Update(ctx context.Context, p proplet.Proplet) error List(ctx context.Context, offset, limit uint64) ([]proplet.Proplet, uint64, error) + ListByAlive(ctx context.Context, offset, limit uint64, alive bool, since time.Time) ([]proplet.Proplet, uint64, error) Delete(ctx context.Context, id string) error } diff --git a/pkg/storage/badger/proplets.go b/pkg/storage/badger/proplets.go index f4ec7712..09ef6925 100644 --- a/pkg/storage/badger/proplets.go +++ b/pkg/storage/badger/proplets.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "time" "github.com/absmach/propeller/pkg/proplet" ) @@ -78,6 +79,36 @@ func (r *propletRepo) List(ctx context.Context, offset, limit uint64) ([]proplet return proplets, total, nil } +const maxBadgerScan uint64 = 100000 + +func (r *propletRepo) ListByAlive(ctx context.Context, offset, limit uint64, alive bool, since time.Time) ([]proplet.Proplet, uint64, error) { + prefix := []byte("proplet:") + values, err := r.db.listWithPrefix(prefix, 0, maxBadgerScan) + if err != nil { + return nil, 0, err + } + + var filtered []proplet.Proplet + for _, val := range values { + var p proplet.Proplet + if err := json.Unmarshal(val, &p); err != nil { + return nil, 0, fmt.Errorf("unmarshal error: %w", err) + } + isAlive := len(p.AliveHistory) > 0 && !p.AliveHistory[len(p.AliveHistory)-1].Before(since) + if isAlive == alive { + filtered = append(filtered, p) + } + } + + filteredTotal := uint64(len(filtered)) + if offset >= filteredTotal { + return []proplet.Proplet{}, filteredTotal, nil + } + end := min(offset+limit, filteredTotal) + + return filtered[offset:end], filteredTotal, nil +} + func (r *propletRepo) Delete(ctx context.Context, id string) error { key := []byte("proplet:" + id) diff --git a/pkg/storage/factory.go b/pkg/storage/factory.go index f4d37798..b2c55d4b 100644 --- a/pkg/storage/factory.go +++ b/pkg/storage/factory.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "time" "github.com/absmach/propeller/pkg/job" "github.com/absmach/propeller/pkg/proplet" @@ -189,6 +190,10 @@ func (a *postgresPropletAdapter) List(ctx context.Context, offset, limit uint64) return a.repo.List(ctx, offset, limit) } +func (a *postgresPropletAdapter) ListByAlive(ctx context.Context, offset, limit uint64, alive bool, since time.Time) ([]proplet.Proplet, uint64, error) { + return a.repo.ListByAlive(ctx, offset, limit, alive, since) +} + func (a *postgresPropletAdapter) Delete(ctx context.Context, id string) error { return a.repo.Delete(ctx, id) } @@ -332,6 +337,10 @@ func (a *sqlitePropletAdapter) List(ctx context.Context, offset, limit uint64) ( return a.repo.List(ctx, offset, limit) } +func (a *sqlitePropletAdapter) ListByAlive(ctx context.Context, offset, limit uint64, alive bool, since time.Time) ([]proplet.Proplet, uint64, error) { + return a.repo.ListByAlive(ctx, offset, limit, alive, since) +} + func (a *sqlitePropletAdapter) Delete(ctx context.Context, id string) error { return a.repo.Delete(ctx, id) } @@ -475,6 +484,10 @@ func (a *badgerPropletAdapter) List(ctx context.Context, offset, limit uint64) ( return a.repo.List(ctx, offset, limit) } +func (a *badgerPropletAdapter) ListByAlive(ctx context.Context, offset, limit uint64, alive bool, since time.Time) ([]proplet.Proplet, uint64, error) { + return a.repo.ListByAlive(ctx, offset, limit, alive, since) +} + func (a *badgerPropletAdapter) Delete(ctx context.Context, id string) error { return a.repo.Delete(ctx, id) } diff --git a/pkg/storage/memory_adapter.go b/pkg/storage/memory_adapter.go index 77399099..15f508a1 100644 --- a/pkg/storage/memory_adapter.go +++ b/pkg/storage/memory_adapter.go @@ -3,6 +3,7 @@ package storage import ( "context" "fmt" + "time" pkgerrors "github.com/absmach/propeller/pkg/errors" "github.com/absmach/propeller/pkg/job" @@ -152,6 +153,33 @@ func (r *memoryPropletRepo) List(ctx context.Context, offset, limit uint64) ([]p return proplets, total, nil } +func (r *memoryPropletRepo) ListByAlive(ctx context.Context, offset, limit uint64, alive bool, since time.Time) ([]proplet.Proplet, uint64, error) { + data, _, err := r.storage.List(ctx, 0, maxMemoryFetch) + if err != nil { + return nil, 0, err + } + + var filtered []proplet.Proplet + for _, d := range data { + p, ok := d.(proplet.Proplet) + if !ok { + continue + } + isAlive := len(p.AliveHistory) > 0 && !p.AliveHistory[len(p.AliveHistory)-1].Before(since) + if isAlive == alive { + filtered = append(filtered, p) + } + } + + filteredTotal := uint64(len(filtered)) + if offset >= filteredTotal { + return []proplet.Proplet{}, filteredTotal, nil + } + end := min(offset+limit, filteredTotal) + + return filtered[offset:end], filteredTotal, nil +} + func (r *memoryPropletRepo) Delete(ctx context.Context, id string) error { return r.storage.Delete(ctx, id) } diff --git a/pkg/storage/mocks/proplet_repository.go b/pkg/storage/mocks/proplet_repository.go index 6c4040ab..39e4fabf 100644 --- a/pkg/storage/mocks/proplet_repository.go +++ b/pkg/storage/mocks/proplet_repository.go @@ -6,6 +6,7 @@ package mocks import ( "context" + "time" "github.com/absmach/propeller/pkg/proplet" mock "github.com/stretchr/testify/mock" @@ -298,6 +299,84 @@ func (_c *MockPropletRepository_List_Call) RunAndReturn(run func(ctx context.Con return _c } +func (_mock *MockPropletRepository) ListByAlive(ctx context.Context, offset uint64, limit uint64, alive bool, since time.Time) ([]proplet.Proplet, uint64, error) { + ret := _mock.Called(ctx, offset, limit, alive, since) + + if len(ret) == 0 { + panic("no return value specified for ListByAlive") + } + + var r0 []proplet.Proplet + var r1 uint64 + var r2 error + if returnFunc, ok := ret.Get(0).(func(context.Context, uint64, uint64, bool, time.Time) ([]proplet.Proplet, uint64, error)); ok { + return returnFunc(ctx, offset, limit, alive, since) + } + if returnFunc, ok := ret.Get(0).(func(context.Context, uint64, uint64, bool, time.Time) []proplet.Proplet); ok { + r0 = returnFunc(ctx, offset, limit, alive, since) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]proplet.Proplet) + } + } + if returnFunc, ok := ret.Get(1).(func(context.Context, uint64, uint64, bool, time.Time) uint64); ok { + r1 = returnFunc(ctx, offset, limit, alive, since) + } else { + r1 = ret.Get(1).(uint64) + } + if returnFunc, ok := ret.Get(2).(func(context.Context, uint64, uint64, bool, time.Time) error); ok { + r2 = returnFunc(ctx, offset, limit, alive, since) + } else { + r2 = ret.Error(2) + } + return r0, r1, r2 +} + +type MockPropletRepository_ListByAlive_Call struct { + *mock.Call +} + +func (_e *MockPropletRepository_Expecter) ListByAlive(ctx interface{}, offset interface{}, limit interface{}, alive interface{}, since interface{}) *MockPropletRepository_ListByAlive_Call { + return &MockPropletRepository_ListByAlive_Call{Call: _e.mock.On("ListByAlive", ctx, offset, limit, alive, since)} +} + +func (_c *MockPropletRepository_ListByAlive_Call) Run(run func(ctx context.Context, offset uint64, limit uint64, alive bool, since time.Time)) *MockPropletRepository_ListByAlive_Call { + _c.Call.Run(func(args mock.Arguments) { + var arg0 context.Context + if args[0] != nil { + arg0 = args[0].(context.Context) + } + var arg1 uint64 + if args[1] != nil { + arg1 = args[1].(uint64) + } + var arg2 uint64 + if args[2] != nil { + arg2 = args[2].(uint64) + } + var arg3 bool + if args[3] != nil { + arg3 = args[3].(bool) + } + var arg4 time.Time + if args[4] != nil { + arg4 = args[4].(time.Time) + } + run(arg0, arg1, arg2, arg3, arg4) + }) + return _c +} + +func (_c *MockPropletRepository_ListByAlive_Call) Return(proplets []proplet.Proplet, v uint64, err error) *MockPropletRepository_ListByAlive_Call { + _c.Call.Return(proplets, v, err) + return _c +} + +func (_c *MockPropletRepository_ListByAlive_Call) RunAndReturn(run func(ctx context.Context, offset uint64, limit uint64, alive bool, since time.Time) ([]proplet.Proplet, uint64, error)) *MockPropletRepository_ListByAlive_Call { + _c.Call.Return(run) + return _c +} + // Update provides a mock function for the type MockPropletRepository func (_mock *MockPropletRepository) Update(ctx context.Context, p proplet.Proplet) error { ret := _mock.Called(ctx, p) diff --git a/pkg/storage/postgres/init.go b/pkg/storage/postgres/init.go index 6465d4b6..85870376 100644 --- a/pkg/storage/postgres/init.go +++ b/pkg/storage/postgres/init.go @@ -64,6 +64,7 @@ type PropletRepository interface { Get(ctx context.Context, id string) (proplet.Proplet, error) Update(ctx context.Context, p proplet.Proplet) error List(ctx context.Context, offset, limit uint64) ([]proplet.Proplet, uint64, error) + ListByAlive(ctx context.Context, offset, limit uint64, alive bool, since time.Time) ([]proplet.Proplet, uint64, error) Delete(ctx context.Context, id string) error } diff --git a/pkg/storage/postgres/proplets.go b/pkg/storage/postgres/proplets.go index 0b45700b..d7d47f87 100644 --- a/pkg/storage/postgres/proplets.go +++ b/pkg/storage/postgres/proplets.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "fmt" + "time" "github.com/absmach/propeller/pkg/proplet" ) @@ -119,6 +120,51 @@ func (r *propletRepo) List(ctx context.Context, offset, limit uint64) ([]proplet return proplets, total, nil } +func (r *propletRepo) ListByAlive(ctx context.Context, offset, limit uint64, alive bool, since time.Time) ([]proplet.Proplet, uint64, error) { + var whereClause string + if alive { + whereClause = `WHERE alive_history IS NOT NULL AND jsonb_array_length(alive_history) > 0 AND (alive_history ->> (jsonb_array_length(alive_history) - 1))::timestamptz >= $1` + } else { + whereClause = `WHERE alive_history IS NULL OR jsonb_array_length(alive_history) = 0 OR (alive_history ->> (jsonb_array_length(alive_history) - 1))::timestamptz < $1` + } + + tx, err := r.db.BeginTxx(ctx, &sql.TxOptions{Isolation: sql.LevelRepeatableRead, ReadOnly: true}) + if err != nil { + return nil, 0, fmt.Errorf("%w: %w", ErrDBQuery, err) + } + defer func() { _ = tx.Rollback() }() + + var total uint64 + if err := tx.GetContext(ctx, &total, "SELECT COUNT(*) FROM proplets "+whereClause, since); err != nil { + return nil, 0, fmt.Errorf("%w: %w", ErrDBQuery, err) + } + + query := fmt.Sprintf(`SELECT id, name, task_count, alive, alive_history, metadata FROM proplets %s LIMIT $2 OFFSET $3`, whereClause) + rows, err := tx.QueryContext(ctx, query, since, limit, offset) + if err != nil { + return nil, 0, fmt.Errorf("%w: %w", ErrDBQuery, err) + } + defer rows.Close() + + proplets := make([]proplet.Proplet, 0) + for rows.Next() { + var dbp dbProplet + if err := rows.Scan(&dbp.ID, &dbp.Name, &dbp.TaskCount, &dbp.Alive, &dbp.AliveHistory, &dbp.Metadata); err != nil { + return nil, 0, fmt.Errorf("%w: %w", ErrDBScan, err) + } + p, err := r.toProplet(dbp) + if err != nil { + return nil, 0, fmt.Errorf("%w: %w", ErrDBScan, err) + } + proplets = append(proplets, p) + } + if err := rows.Err(); err != nil { + return nil, 0, fmt.Errorf("%w: %w", ErrDBQuery, err) + } + + return proplets, total, nil +} + func (r *propletRepo) Delete(ctx context.Context, id string) error { query := `DELETE FROM proplets WHERE id = $1` diff --git a/pkg/storage/repository.go b/pkg/storage/repository.go index 0379139c..174e5d0d 100644 --- a/pkg/storage/repository.go +++ b/pkg/storage/repository.go @@ -2,6 +2,7 @@ package storage import ( "context" + "time" "github.com/absmach/propeller/pkg/job" "github.com/absmach/propeller/pkg/proplet" @@ -23,6 +24,7 @@ type PropletRepository interface { Get(ctx context.Context, id string) (proplet.Proplet, error) Update(ctx context.Context, p proplet.Proplet) error List(ctx context.Context, offset, limit uint64) ([]proplet.Proplet, uint64, error) + ListByAlive(ctx context.Context, offset, limit uint64, alive bool, since time.Time) ([]proplet.Proplet, uint64, error) Delete(ctx context.Context, id string) error } diff --git a/pkg/storage/sqlite/init.go b/pkg/storage/sqlite/init.go index 2ad7da03..af29be2f 100644 --- a/pkg/storage/sqlite/init.go +++ b/pkg/storage/sqlite/init.go @@ -64,6 +64,7 @@ type PropletRepository interface { Get(ctx context.Context, id string) (proplet.Proplet, error) Update(ctx context.Context, p proplet.Proplet) error List(ctx context.Context, offset, limit uint64) ([]proplet.Proplet, uint64, error) + ListByAlive(ctx context.Context, offset, limit uint64, alive bool, since time.Time) ([]proplet.Proplet, uint64, error) Delete(ctx context.Context, id string) error } diff --git a/pkg/storage/sqlite/proplets.go b/pkg/storage/sqlite/proplets.go index 86c991f8..0ad0e432 100644 --- a/pkg/storage/sqlite/proplets.go +++ b/pkg/storage/sqlite/proplets.go @@ -5,6 +5,7 @@ import ( "database/sql" "errors" "fmt" + "time" "github.com/absmach/propeller/pkg/proplet" ) @@ -120,6 +121,46 @@ func (r *propletRepo) List(ctx context.Context, offset, limit uint64) ([]proplet return proplets, total, nil } +func (r *propletRepo) ListByAlive(ctx context.Context, offset, limit uint64, alive bool, since time.Time) ([]proplet.Proplet, uint64, error) { + tx, err := r.db.BeginTxx(ctx, nil) + if err != nil { + return nil, 0, fmt.Errorf("%w: %w", ErrDBQuery, err) + } + defer func() { _ = tx.Rollback() }() + + rows, err := tx.QueryContext(ctx, `SELECT id, name, task_count, alive, alive_history, metadata FROM proplets`) + if err != nil { + return nil, 0, fmt.Errorf("%w: %w", ErrDBQuery, err) + } + defer rows.Close() + + var filtered []proplet.Proplet + for rows.Next() { + var dbp dbProplet + if err := rows.Scan(&dbp.ID, &dbp.Name, &dbp.TaskCount, &dbp.Alive, &dbp.AliveHistory, &dbp.Metadata); err != nil { + return nil, 0, fmt.Errorf("%w: %w", ErrDBScan, err) + } + p, err := r.toProplet(dbp) + if err != nil { + return nil, 0, fmt.Errorf("%w: %w", ErrDBScan, err) + } + isAlive := len(p.AliveHistory) > 0 && !p.AliveHistory[len(p.AliveHistory)-1].Before(since) + if isAlive == alive { + filtered = append(filtered, p) + } + } + if err := rows.Err(); err != nil { + return nil, 0, fmt.Errorf("%w: %w", ErrDBQuery, err) + } + + filteredTotal := uint64(len(filtered)) + if offset >= filteredTotal { + return []proplet.Proplet{}, filteredTotal, nil + } + + return filtered[offset:min(offset+limit, filteredTotal)], filteredTotal, nil +} + func (r *propletRepo) Delete(ctx context.Context, id string) error { query := `DELETE FROM proplets WHERE id = ?`