diff --git a/README.md b/README.md index bf3581d..506a9c0 100644 --- a/README.md +++ b/README.md @@ -35,7 +35,7 @@ func main() { - **Tasks**: Collections of actions that execute sequentially - **Actions**: Individual operations (file, Docker, system) -- **Parameters**: Pass data between actions using `ActionOutput()` and `TaskOutput()` +- **Parameters**: Pass data between actions using `ActionOutput()` and `TaskOutput()`, and fetch rich results using `ActionResult()` and `TaskResult()` - **Context**: Share data across tasks with `TaskManager` ## Built-in Actions @@ -51,7 +51,7 @@ See [ACTIONS.md](ACTIONS.md) for complete list. ## Parameter Passing -Actions can reference outputs from previous actions: +Actions can reference outputs from previous actions, results from actions, and results from tasks: ```go // Reference action output @@ -70,6 +70,12 @@ docker.NewDockerRunAction( []string{"-p", "8080:8080"}, logger, ) + +// Reference action result (from an action implementing ResultProvider) +useChecksum := task_engine.ActionResultField("download-artifact", "checksum") + +// Reference task result (from a task implementing ResultProvider or using ResultBuilder) +preflightMode := task_engine.TaskResultField("preflight", "UpdateMode") ``` ## Task Management diff --git a/action.go b/action.go index 56847cf..2da1fc0 100644 --- a/action.go +++ b/action.go @@ -5,7 +5,6 @@ import ( "fmt" "io" "log/slog" - "reflect" "regexp" "strings" "sync" @@ -36,276 +35,6 @@ type ActionWithResults interface { ResultProvider } -// ActionParameter interface for all parameter types that can be resolved at runtime -// to provide values for action execution. Parameters support references to outputs -// from other actions, tasks, or static values. -type ActionParameter interface { - // Resolve returns the actual value for this parameter by looking up - // references in the global context or returning static values. - Resolve(ctx context.Context, globalContext *GlobalContext) (interface{}, error) -} - -// StaticParameter represents a fixed value that doesn't need resolution. -// Use this for values known at task creation time. -type StaticParameter struct { - Value interface{} // The static value to use -} - -func (p StaticParameter) Resolve(ctx context.Context, globalContext *GlobalContext) (interface{}, error) { - return p.Value, nil -} - -// ActionOutputParameter references output from a specific action. -// Use this to pass data between actions within the same task. -type ActionOutputParameter struct { - ActionID string // Required: ID of the action to reference - OutputKey string // Optional: specific output field to extract (omit for entire output) -} - -func (p ActionOutputParameter) Resolve(ctx context.Context, globalContext *GlobalContext) (interface{}, error) { - if p.ActionID == "" { - return nil, fmt.Errorf("ActionOutputParameter: ActionID cannot be empty") - } - - output, exists := globalContext.ActionOutputs[p.ActionID] - if !exists { - return nil, fmt.Errorf("ActionOutputParameter: action '%s' not found in context", p.ActionID) - } - - if p.OutputKey != "" { - // Validate OutputKey exists in output - if outputMap, ok := output.(map[string]interface{}); ok { - if value, exists := outputMap[p.OutputKey]; exists { - return value, nil - } - return nil, fmt.Errorf("ActionOutputParameter: output key '%s' not found in action '%s'", p.OutputKey, p.ActionID) - } - return nil, fmt.Errorf("ActionOutputParameter: action '%s' output is not a map, cannot extract key '%s'", p.ActionID, p.OutputKey) - } - - return output, nil -} - -// ActionResultParameter references results from actions implementing ResultProvider -type ActionResultParameter struct { - ActionID string // Required: ID of the action to reference - ResultKey string // Optional: specific result field to extract -} - -func (p ActionResultParameter) Resolve(ctx context.Context, globalContext *GlobalContext) (interface{}, error) { - if p.ActionID == "" { - return nil, fmt.Errorf("ActionResultParameter: ActionID cannot be empty") - } - - resultProvider, exists := globalContext.ActionResults[p.ActionID] - if !exists { - return nil, fmt.Errorf("ActionResultParameter: action '%s' not found in context", p.ActionID) - } - - result := resultProvider.GetResult() - if p.ResultKey != "" { - // Extract specific field from result - if resultMap, ok := result.(map[string]interface{}); ok { - if value, exists := resultMap[p.ResultKey]; exists { - return value, nil - } - return nil, fmt.Errorf("ActionResultParameter: result key '%s' not found in action '%s'", p.ResultKey, p.ActionID) - } - return nil, fmt.Errorf("ActionResultParameter: action '%s' result is not a map, cannot extract key '%s'", p.ActionID, p.ResultKey) - } - - return result, nil -} - -// TaskOutputParameter references output from a specific task -type TaskOutputParameter struct { - TaskID string // Required: ID of the task to reference - OutputKey string // Optional: specific output field to extract -} - -func (p TaskOutputParameter) Resolve(ctx context.Context, globalContext *GlobalContext) (interface{}, error) { - if p.TaskID == "" { - return nil, fmt.Errorf("TaskOutputParameter: TaskID cannot be empty") - } - - output, exists := globalContext.TaskOutputs[p.TaskID] - if !exists { - return nil, fmt.Errorf("TaskOutputParameter: task '%s' not found in context", p.TaskID) - } - - if p.OutputKey != "" { - // Extract specific field from output - if outputMap, ok := output.(map[string]interface{}); ok { - if value, exists := outputMap[p.OutputKey]; exists { - return value, nil - } - return nil, fmt.Errorf("TaskOutputParameter: output key '%s' not found in task '%s'", p.OutputKey, p.TaskID) - } - return nil, fmt.Errorf("TaskOutputParameter: task '%s' output is not a map, cannot extract key '%s'", p.TaskID, p.OutputKey) - } - - return output, nil -} - -// EntityOutputParameter references output from any entity (action or task) -type EntityOutputParameter struct { - EntityType string // Required: "action" or "task" - EntityID string // Required: ID of the entity to reference - OutputKey string // Optional: specific output field to extract -} - -func (p EntityOutputParameter) Resolve(ctx context.Context, globalContext *GlobalContext) (interface{}, error) { - if p.EntityType == "" || p.EntityID == "" { - return nil, fmt.Errorf("EntityOutputParameter: EntityType and EntityID cannot be empty") - } - - switch p.EntityType { - case "action": - // Try ActionOutputs first - if output, exists := globalContext.ActionOutputs[p.EntityID]; exists { - if p.OutputKey != "" { - if outputMap, ok := output.(map[string]interface{}); ok { - if value, exists := outputMap[p.OutputKey]; exists { - return value, nil - } - return nil, fmt.Errorf("EntityOutputParameter: output key '%s' not found in action '%s'", p.OutputKey, p.EntityID) - } - return nil, fmt.Errorf("EntityOutputParameter: action '%s' output is not a map, cannot extract key '%s'", p.EntityID, p.OutputKey) - } - return output, nil - } - // Try ActionResults if ActionOutputs doesn't have it - if resultProvider, exists := globalContext.ActionResults[p.EntityID]; exists { - result := resultProvider.GetResult() - if p.OutputKey != "" { - if resultMap, ok := result.(map[string]interface{}); ok { - if value, exists := resultMap[p.OutputKey]; exists { - return value, nil - } - return nil, fmt.Errorf("EntityOutputParameter: result key '%s' not found in action '%s'", p.OutputKey, p.EntityID) - } - return nil, fmt.Errorf("EntityOutputParameter: action '%s' result is not a map, cannot extract key '%s'", p.EntityID, p.OutputKey) - } - return result, nil - } - return nil, fmt.Errorf("EntityOutputParameter: action '%s' not found in context", p.EntityID) - - case "task": - output, exists := globalContext.TaskOutputs[p.EntityID] - if !exists { - return nil, fmt.Errorf("EntityOutputParameter: task '%s' not found in context", p.EntityID) - } - if p.OutputKey != "" { - if outputMap, ok := output.(map[string]interface{}); ok { - if value, exists := outputMap[p.OutputKey]; exists { - return value, nil - } - return nil, fmt.Errorf("EntityOutputParameter: output key '%s' not found in task '%s'", p.OutputKey, p.EntityID) - } - return nil, fmt.Errorf("EntityOutputParameter: task '%s' output is not a map, cannot extract key '%s'", p.EntityID, p.OutputKey) - } - return output, nil - - default: - return nil, fmt.Errorf("EntityOutputParameter: invalid entity type '%s', must be 'action' or 'task'", p.EntityType) - } -} - -// --- Typed parameter resolution helpers --- - -// ResolveString resolves an ActionParameter to a string with helpful -// conversions and clear error messages. When the parameter is nil, -// it returns an empty string without error. -func ResolveString(ctx context.Context, p ActionParameter, globalContext *GlobalContext) (string, error) { - if p == nil { - return "", nil - } - v, err := p.Resolve(ctx, globalContext) - if err != nil { - return "", err - } - switch t := v.(type) { - case string: - return t, nil - case []byte: - return string(t), nil - case fmt.Stringer: - return t.String(), nil - case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool: - return fmt.Sprint(v), nil - default: - return "", fmt.Errorf("parameter is not a string, got %T", v) - } -} - -// ResolveBool resolves an ActionParameter to a bool with common coercions. -// If parameter is nil, returns false. -func ResolveBool(ctx context.Context, p ActionParameter, globalContext *GlobalContext) (bool, error) { - if p == nil { - return false, nil - } - v, err := p.Resolve(ctx, globalContext) - if err != nil { - return false, err - } - switch t := v.(type) { - case bool: - return t, nil - case string: - s := strings.TrimSpace(strings.ToLower(t)) - if s == "true" || s == "1" || s == "yes" || s == "y" { // common truthy strings - return true, nil - } - if s == "false" || s == "0" || s == "no" || s == "n" { - return false, nil - } - return false, fmt.Errorf("cannot convert string '%s' to bool", t) - case int: - return t != 0, nil - case int64: - return t != 0, nil - case uint: - return t != 0, nil - default: - return false, fmt.Errorf("parameter is not a bool, got %T", v) - } -} - -// ResolveStringSlice resolves an ActionParameter into a []string. -// Accepts []string directly, or splits a string by comma or spaces. -func ResolveStringSlice(ctx context.Context, p ActionParameter, globalContext *GlobalContext) ([]string, error) { - if p == nil { - return nil, nil - } - v, err := p.Resolve(ctx, globalContext) - if err != nil { - return nil, err - } - switch t := v.(type) { - case []string: - return t, nil - case string: - s := strings.TrimSpace(t) - if s == "" { - return []string{}, nil - } - if strings.Contains(s, ",") { - parts := strings.Split(s, ",") - out := make([]string, 0, len(parts)) - for _, p := range parts { - p = strings.TrimSpace(p) - if p != "" { - out = append(out, p) - } - } - return out, nil - } - return strings.Fields(s), nil - default: - return nil, fmt.Errorf("parameter is not a string slice or string, got %T", v) - } -} - // --- Consistent action ID helpers --- var idSanitizer = regexp.MustCompile(`[^a-z0-9_:\-.]+`) @@ -344,73 +73,6 @@ func BuildActionID(prefix string, parts ...string) string { return base + "-" + strings.Join(cleaned, "-") + "-action" } -// Helper functions for common parameter patterns -// ActionOutput creates a parameter reference to an entire action output -func ActionOutput(actionID string) ActionOutputParameter { - return ActionOutputParameter{ActionID: actionID} -} - -// ActionOutputField creates a parameter reference to a specific field in an action output -func ActionOutputField(actionID, field string) ActionOutputParameter { - return ActionOutputParameter{ActionID: actionID, OutputKey: field} -} - -// ActionResult creates a parameter reference to an action result (for ResultProvider actions) -func ActionResult(actionID string) ActionResultParameter { - return ActionResultParameter{ActionID: actionID} -} - -// ActionResultField creates a parameter reference to a specific field in an action result -func ActionResultField(actionID, field string) ActionResultParameter { - return ActionResultParameter{ActionID: actionID, ResultKey: field} -} - -// TaskOutput creates a parameter reference to an entire task output -func TaskOutput(taskID string) TaskOutputParameter { - return TaskOutputParameter{TaskID: taskID} -} - -// TaskOutputField creates a parameter reference to a specific field in a task output -func TaskOutputField(taskID, field string) TaskOutputParameter { - return TaskOutputParameter{TaskID: taskID, OutputKey: field} -} - -// EntityOutput creates a parameter reference to any entity type (action or task) -func EntityOutput(entityType, entityID string) EntityOutputParameter { - return EntityOutputParameter{EntityType: entityType, EntityID: entityID} -} - -// EntityOutputField creates a parameter reference to a specific field in any entity output -func EntityOutputField(entityType, entityID, field string) EntityOutputParameter { - return EntityOutputParameter{EntityType: entityType, EntityID: entityID, OutputKey: field} -} - -// --- Phase 5 Ergonomics --- - -// TypedOutputKey provides a way to associate an output field name with an expected -// struct type T. Validate can be used to check that the field exists on T at runtime. -// Note: This is a runtime validation helper; compile-time validation would require codegen. -// TypedOutputKey provides compile-time validation of output keys for type-safe -// parameter references. Use this when you want to ensure output keys exist -// in your output types at compile time. -type TypedOutputKey[T any] struct { - ActionID string // ID of the action to reference - Key string // Field name to extract from the output -} - -// Validate checks whether Key is a valid exported field on T when T is a struct. -// If T is not a struct, Validate returns nil (no validation performed). -func (k TypedOutputKey[T]) Validate() error { - t := reflect.TypeOf((*T)(nil)).Elem() - if t.Kind() != reflect.Struct { - return nil - } - if _, exists := t.FieldByName(k.Key); !exists { - return fmt.Errorf("field '%s' does not exist on output type %s", k.Key, t.Name()) - } - return nil -} - // BaseAction is used as a composite struct for newly defined actions, to provide a default no-op implementation of the before/after // hooks. It also has a logger passed from the action that wraps it. // BaseAction provides common functionality for actions including logging @@ -492,6 +154,151 @@ func (gc *GlobalContext) StoreTaskResult(taskID string, resultProvider ResultPro gc.TaskResults[taskID] = resultProvider } +// --- Typed convenience helpers (simplest way to fetch data) --- + +// ActionResultAs returns a typed action result from an action implementing ResultProvider. +func ActionResultAs[T any](gc *GlobalContext, actionID string) (T, bool) { + gc.mu.RLock() + rp, ok := gc.ActionResults[actionID] + gc.mu.RUnlock() + var zero T + if !ok || rp == nil { + return zero, false + } + v, ok := rp.GetResult().(T) + return v, ok +} + +// TaskResultAs returns a typed task result from a task implementing ResultProvider. +func TaskResultAs[T any](gc *GlobalContext, taskID string) (T, bool) { + gc.mu.RLock() + rp, ok := gc.TaskResults[taskID] + gc.mu.RUnlock() + var zero T + if !ok || rp == nil { + return zero, false + } + v, ok := rp.GetResult().(T) + return v, ok +} + +// ActionOutputFieldAs returns a typed value from an action's output map. +func ActionOutputFieldAs[T any](gc *GlobalContext, actionID, key string) (T, error) { + gc.mu.RLock() + output, exists := gc.ActionOutputs[actionID] + gc.mu.RUnlock() + var zero T + if !exists { + return zero, fmt.Errorf("action '%s' not found in context", actionID) + } + if key == "" { + if v, ok := output.(T); ok { + return v, nil + } + return zero, fmt.Errorf("action '%s' output is not %T", actionID, zero) + } + m, ok := output.(map[string]interface{}) + if !ok { + return zero, fmt.Errorf("action '%s' output is not a map, cannot extract key '%s'", actionID, key) + } + val, exists := m[key] + if !exists { + return zero, fmt.Errorf("output key '%s' not found in action '%s'", key, actionID) + } + typed, ok := val.(T) + if !ok { + return zero, fmt.Errorf("action '%s' output key '%s' is not %T", actionID, key, zero) + } + return typed, nil +} + +// TaskOutputFieldAs returns a typed value from a task's output map. +func TaskOutputFieldAs[T any](gc *GlobalContext, taskID, key string) (T, error) { + gc.mu.RLock() + output, exists := gc.TaskOutputs[taskID] + gc.mu.RUnlock() + var zero T + if !exists { + return zero, fmt.Errorf("task '%s' not found in context", taskID) + } + if key == "" { + if v, ok := output.(T); ok { + return v, nil + } + return zero, fmt.Errorf("task '%s' output is not %T", taskID, zero) + } + m, ok := output.(map[string]interface{}) + if !ok { + return zero, fmt.Errorf("task '%s' output is not a map, cannot extract key '%s'", taskID, key) + } + val, exists := m[key] + if !exists { + return zero, fmt.Errorf("output key '%s' not found in task '%s'", key, taskID) + } + typed, ok := val.(T) + if !ok { + return zero, fmt.Errorf("task '%s' output key '%s' is not %T", taskID, key, zero) + } + return typed, nil +} + +// EntityValue returns a value from either outputs or results for the given entity. +// For actions, tries ActionOutputs then ActionResults. For tasks, tries TaskOutputs then TaskResults. +func EntityValue(gc *GlobalContext, entityType, id, key string) (interface{}, error) { + switch entityType { + case "action": + if key == "" { + gc.mu.RLock() + out, exists := gc.ActionOutputs[id] + gc.mu.RUnlock() + if exists { + return out, nil + } + gc.mu.RLock() + rp, exists := gc.ActionResults[id] + gc.mu.RUnlock() + if exists && rp != nil { + return rp.GetResult(), nil + } + return nil, fmt.Errorf("action '%s' not found in context", id) + } + return ActionOutputFieldAs[interface{}](gc, id, key) + case "task": + if key == "" { + gc.mu.RLock() + out, exists := gc.TaskOutputs[id] + gc.mu.RUnlock() + if exists { + return out, nil + } + gc.mu.RLock() + rp, exists := gc.TaskResults[id] + gc.mu.RUnlock() + if exists && rp != nil { + return rp.GetResult(), nil + } + return nil, fmt.Errorf("task '%s' not found in context", id) + } + return TaskOutputFieldAs[interface{}](gc, id, key) + default: + return nil, fmt.Errorf("invalid entity type '%s'", entityType) + } +} + +// EntityValueAs returns a typed value from either outputs or results for the given entity. +func EntityValueAs[T any](gc *GlobalContext, entityType, id, key string) (T, error) { + var zero T + v, err := EntityValue(gc, entityType, id, key) + if err != nil { + return zero, err + } + out, ok := v.(T) + if !ok { + return zero, fmt.Errorf("entity '%s' value is not %T", id, zero) + } + return out, nil +} + // ActionWrapper interface for actions that can be executed by tasks. // This interface provides the contract that tasks use to interact with actions, // including execution, metadata access, and output retrieval. diff --git a/docs/API.md b/docs/API.md index 6a7e3cd..c136e0b 100644 --- a/docs/API.md +++ b/docs/API.md @@ -13,6 +13,8 @@ type Task struct { Logger *slog.Logger TotalTime time.Duration CompletedTasks int + // Optional builder to produce a structured task result at the end + ResultBuilder func(ctx *TaskContext) (interface{}, error) } func (t *Task) Run(ctx context.Context) error @@ -21,6 +23,10 @@ func (t *Task) GetID() string func (t *Task) GetName() string func (t *Task) GetCompletedTasks() int func (t *Task) GetTotalTime() time.Duration +// If the task provides results +func (t *Task) SetResult(result interface{}) +func (t *Task) GetResult() interface{} +func (t *Task) GetError() error ``` ### Action @@ -67,16 +73,17 @@ func (tm *TaskManager) ResetGlobalContext() ```go type GlobalContext struct { ActionOutputs map[string]interface{} - TaskOutputs map[string]interface{} ActionResults map[string]ResultProvider + TaskOutputs map[string]interface{} + TaskResults map[string]ResultProvider mu sync.RWMutex } func NewGlobalContext() *GlobalContext -func (gc *GlobalContext) SetActionOutput(actionID string, output interface{}) -func (gc *GlobalContext) SetTaskOutput(taskID string, output interface{}) -func (gc *GlobalContext) GetActionOutput(actionID string) (interface{}, bool) -func (gc *GlobalContext) GetTaskOutput(taskID string) (interface{}, bool) +func (gc *GlobalContext) StoreActionOutput(actionID string, output interface{}) +func (gc *GlobalContext) StoreActionResult(actionID string, resultProvider ResultProvider) +func (gc *GlobalContext) StoreTaskOutput(taskID string, output interface{}) +func (gc *GlobalContext) StoreTaskResult(taskID string, resultProvider ResultProvider) ``` ## Parameter Types @@ -129,21 +136,53 @@ func (p ActionResultParameter) Resolve(ctx context.Context, globalContext *Globa ### ActionOutput ```go -func ActionOutput(actionID string, outputKey string) ActionOutputParameter +func ActionOutput(actionID string) ActionOutputParameter +func ActionOutputField(actionID, field string) ActionOutputParameter ``` ### TaskOutput ```go -func TaskOutput(taskID string, outputKey string) TaskOutputParameter +func TaskOutput(taskID string) TaskOutputParameter +func TaskOutputField(taskID, field string) TaskOutputParameter ``` ### ActionResult +````go +func ActionResult(actionID string) ActionResultParameter +func ActionResultField(actionID, field string) ActionResultParameter + +### TaskResult + ```go -func ActionResult(actionID string, resultKey string) ActionResultParameter +func TaskResult(taskID string) TaskResultParameter +func TaskResultField(taskID, field string) TaskResultParameter ``` +### Simple Typed Helpers (Recommended) + +```go +func ActionResultAs[T any](gc *GlobalContext, actionID string) (T, bool) +func TaskResultAs[T any](gc *GlobalContext, taskID string) (T, bool) +func ActionOutputFieldAs[T any](gc *GlobalContext, actionID, key string) (T, error) +func TaskOutputFieldAs[T any](gc *GlobalContext, taskID, key string) (T, error) +func EntityValue(gc *GlobalContext, entityType, id, key string) (interface{}, error) +func EntityValueAs[T any](gc *GlobalContext, entityType, id, key string) (T, error) + +// Generic parameter resolver +func ResolveAs[T any](ctx context.Context, p ActionParameter, gc *GlobalContext) (T, error) +``` + +### EntityOutput + +```go +func EntityOutput(entityType, entityID string) EntityOutputParameter +func EntityOutputField(entityType, entityID, field string) EntityOutputParameter +``` + +```` + ## Interfaces ### ActionInterface @@ -159,7 +198,7 @@ type ActionInterface interface { ### TaskInterface -```go +````go type TaskInterface interface { GetID() string GetName() string @@ -168,7 +207,17 @@ type TaskInterface interface { GetCompletedTasks() int GetTotalTime() time.Duration } -``` + +### TaskWithResults + +```go +type TaskWithResults interface { + TaskInterface + ResultProvider +} +```` + +```` ### TaskManagerInterface @@ -183,7 +232,7 @@ type TaskManagerInterface interface { GetGlobalContext() *GlobalContext ResetGlobalContext() } -``` +```` ### ResultProvider diff --git a/docs/ARCHITECTURE.md b/docs/ARCHITECTURE.md index 0cb9426..73e685e 100644 --- a/docs/ARCHITECTURE.md +++ b/docs/ARCHITECTURE.md @@ -12,6 +12,8 @@ type Task struct { Name string Actions []ActionWrapper Logger *slog.Logger + // Optional: build a structured result at the end of execution + ResultBuilder func(ctx *TaskContext) (interface{}, error) } ``` @@ -62,8 +64,25 @@ engine.ActionOutput("read-action", "content") Reference outputs from other tasks using the global context. -```go +````go engine.TaskOutput("build-task", "imageID") + +### Action Result Parameters + +Use rich results from actions that implement `ResultProvider`. + +```go +engine.ActionResult("download-artifact") +engine.ActionResultField("download-artifact", "checksum") +```` + +### Task Result Parameters + +Use rich results from tasks that implement `ResultProvider` or define a `ResultBuilder`. + +```go +engine.TaskResult("preflight") +engine.TaskResultField("preflight", "UpdateMode") ``` ## Execution Flow @@ -79,8 +98,9 @@ engine.TaskOutput("build-task", "imageID") The `GlobalContext` maintains: - `ActionOutputs`: Results from completed actions -- `TaskOutputs`: Results from completed tasks - `ActionResults`: Rich results from actions implementing `ResultProvider` +- `TaskOutputs`: Results from completed tasks +- `TaskResults`: Rich results from tasks implementing `ResultProvider` (or using `ResultBuilder`) Context is shared across tasks via the `TaskManager` and embedded in the execution context. @@ -95,3 +115,7 @@ Context is shared across tasks via the `TaskManager` and embedded in the executi - **Mocks**: Complete mock implementations for all interfaces - **Testable Manager**: Enhanced TaskManager with testing hooks - **Performance Testing**: Built-in benchmarking and load testing utilities + +``` + +``` diff --git a/docs/QUICKSTART.md b/docs/QUICKSTART.md index d778fbf..89c1570 100644 --- a/docs/QUICKSTART.md +++ b/docs/QUICKSTART.md @@ -109,7 +109,7 @@ task := &task_engine.Task{ action := file.NewReplaceLinesAction(logger).WithParameters( task_engine.StaticParameter{Value: "/tmp/output.txt"}, map[*regexp.Regexp]task_engine.ActionParameter{ - regexp.MustCompile("old"): task_engine.ActionOutput("read-file", "content"), + regexp.MustCompile("old"): task_engine.ActionOutputField("read-file", "content"), }, ) return action diff --git a/docs/examples/README.md b/docs/examples/README.md index 3afda72..a65c531 100644 --- a/docs/examples/README.md +++ b/docs/examples/README.md @@ -36,8 +36,8 @@ err := task.Run(context.Background()) ## Key Concepts -- **Parameters**: `task_engine.ActionOutput()`, `task_engine.TaskOutput()` +- **Parameters**: `task_engine.ActionOutput()`, `task_engine.TaskOutput()`, `task_engine.ActionResult()`, `task_engine.TaskResult()` - **Global Context**: Share data between tasks using `TaskManager` -- **Output Methods**: Implement `GetOutput()` in actions +- **Output Methods & Results**: Implement `GetOutput()` in actions; implement `ResultProvider` or define a task `ResultBuilder` for rich results See [README.md](../../README.md) for quick start and [ACTIONS.md](../../ACTIONS.md) for available actions. diff --git a/docs/examples/parameter_passing_examples.md b/docs/examples/parameter_passing_examples.md index 122f578..3bf6a99 100644 --- a/docs/examples/parameter_passing_examples.md +++ b/docs/examples/parameter_passing_examples.md @@ -85,35 +85,19 @@ type ContentProcessorAction struct { } func (a *ContentProcessorAction) Execute(ctx context.Context) error { - // Get global context + // Resolve content directly using parameter helper instead of manual map parsing globalCtx, ok := ctx.Value(task_engine.GlobalContextKey).(*task_engine.GlobalContext) if !ok { return fmt.Errorf("global context not found") } - - // Get content from read action - readOutput, exists := globalCtx.ActionOutputs["read-source-file"] - if !exists { - return fmt.Errorf("read action output not found") - } - - // Extract content - readOutputMap, ok := readOutput.(map[string]interface{}) - if !ok { - return fmt.Errorf("read action output is not a map") - } - - content, exists := readOutputMap["content"] - if !exists { - return fmt.Errorf("content field not found") + v, err := task_engine.ActionOutputField("read-source-file", "content").Resolve(ctx, globalCtx) + if err != nil { + return err } - - // Process content (convert to uppercase) - contentBytes, ok := content.([]byte) + contentBytes, ok := v.([]byte) if !ok { return fmt.Errorf("content is not []byte") } - a.processedContent = bytes.ToUpper(contentBytes) return nil } diff --git a/interface.go b/interface.go index b3f0f54..1bd756f 100644 --- a/interface.go +++ b/interface.go @@ -32,3 +32,11 @@ type ResultProvider interface { GetResult() interface{} GetError() error } + +// TaskWithResults interface for tasks that can optionally provide rich results +// Combines the task lifecycle with the ability to provide results and errors +// after execution, mirroring ActionWithResults. +type TaskWithResults interface { + TaskInterface + ResultProvider +} diff --git a/parameters.go b/parameters.go new file mode 100644 index 0000000..1de3295 --- /dev/null +++ b/parameters.go @@ -0,0 +1,419 @@ +package task_engine + +import ( + "context" + "fmt" + "reflect" + "strings" +) + +// ActionParameter interface for all parameter types that can be resolved at runtime +// to provide values for action execution. Parameters support references to outputs +// from other actions, tasks, or static values. +type ActionParameter interface { + // Resolve returns the actual value for this parameter by looking up + // references in the global context or returning static values. + Resolve(ctx context.Context, globalContext *GlobalContext) (interface{}, error) +} + +// StaticParameter represents a fixed value that doesn't need resolution. +// Use this for values known at task creation time. +type StaticParameter struct { + Value interface{} // The static value to use +} + +func (p StaticParameter) Resolve(ctx context.Context, globalContext *GlobalContext) (interface{}, error) { + return p.Value, nil +} + +// ActionOutputParameter references output from a specific action. +// Use this to pass data between actions within the same task. +type ActionOutputParameter struct { + ActionID string // Required: ID of the action to reference + OutputKey string // Optional: specific output field to extract (omit for entire output) +} + +func (p ActionOutputParameter) Resolve(ctx context.Context, globalContext *GlobalContext) (interface{}, error) { + if p.ActionID == "" { + return nil, fmt.Errorf("ActionOutputParameter: ActionID cannot be empty") + } + + output, exists := globalContext.ActionOutputs[p.ActionID] + if !exists { + return nil, fmt.Errorf("ActionOutputParameter: action '%s' not found in context", p.ActionID) + } + + if p.OutputKey != "" { + // Validate OutputKey exists in output + if outputMap, ok := output.(map[string]interface{}); ok { + if value, exists := outputMap[p.OutputKey]; exists { + return value, nil + } + return nil, fmt.Errorf("ActionOutputParameter: output key '%s' not found in action '%s'", p.OutputKey, p.ActionID) + } + return nil, fmt.Errorf("ActionOutputParameter: action '%s' output is not a map, cannot extract key '%s'", p.ActionID, p.OutputKey) + } + + return output, nil +} + +// ActionResultParameter references results from actions implementing ResultProvider +type ActionResultParameter struct { + ActionID string // Required: ID of the action to reference + ResultKey string // Optional: specific result field to extract +} + +func (p ActionResultParameter) Resolve(ctx context.Context, globalContext *GlobalContext) (interface{}, error) { + if p.ActionID == "" { + return nil, fmt.Errorf("ActionResultParameter: ActionID cannot be empty") + } + + resultProvider, exists := globalContext.ActionResults[p.ActionID] + if !exists { + return nil, fmt.Errorf("ActionResultParameter: action '%s' not found in context", p.ActionID) + } + + result := resultProvider.GetResult() + if p.ResultKey != "" { + // Extract specific field from result + if resultMap, ok := result.(map[string]interface{}); ok { + if value, exists := resultMap[p.ResultKey]; exists { + return value, nil + } + return nil, fmt.Errorf("ActionResultParameter: result key '%s' not found in action '%s'", p.ResultKey, p.ActionID) + } + return nil, fmt.Errorf("ActionResultParameter: action '%s' result is not a map, cannot extract key '%s'", p.ActionID, p.ResultKey) + } + + return result, nil +} + +// TaskResultParameter references results from tasks implementing ResultProvider +type TaskResultParameter struct { + TaskID string // Required: ID of the task to reference + ResultKey string // Optional: specific result field to extract +} + +func (p TaskResultParameter) Resolve(ctx context.Context, globalContext *GlobalContext) (interface{}, error) { + if p.TaskID == "" { + return nil, fmt.Errorf("TaskResultParameter: TaskID cannot be empty") + } + + resultProvider, exists := globalContext.TaskResults[p.TaskID] + if !exists { + return nil, fmt.Errorf("TaskResultParameter: task '%s' not found in context", p.TaskID) + } + + result := resultProvider.GetResult() + if p.ResultKey != "" { + if resultMap, ok := result.(map[string]interface{}); ok { + if value, exists := resultMap[p.ResultKey]; exists { + return value, nil + } + return nil, fmt.Errorf("TaskResultParameter: result key '%s' not found in task '%s'", p.ResultKey, p.TaskID) + } + return nil, fmt.Errorf("TaskResultParameter: task '%s' result is not a map, cannot extract key '%s'", p.TaskID, p.ResultKey) + } + + return result, nil +} + +// TaskOutputParameter references output from a specific task +type TaskOutputParameter struct { + TaskID string // Required: ID of the task to reference + OutputKey string // Optional: specific output field to extract +} + +func (p TaskOutputParameter) Resolve(ctx context.Context, globalContext *GlobalContext) (interface{}, error) { + if p.TaskID == "" { + return nil, fmt.Errorf("TaskOutputParameter: TaskID cannot be empty") + } + + output, exists := globalContext.TaskOutputs[p.TaskID] + if !exists { + return nil, fmt.Errorf("TaskOutputParameter: task '%s' not found in context", p.TaskID) + } + + if p.OutputKey != "" { + // Extract specific field from output + if outputMap, ok := output.(map[string]interface{}); ok { + if value, exists := outputMap[p.OutputKey]; exists { + return value, nil + } + return nil, fmt.Errorf("TaskOutputParameter: output key '%s' not found in task '%s'", p.OutputKey, p.TaskID) + } + return nil, fmt.Errorf("TaskOutputParameter: task '%s' output is not a map, cannot extract key '%s'", p.TaskID, p.OutputKey) + } + + return output, nil +} + +// EntityOutputParameter references output from any entity (action or task) +type EntityOutputParameter struct { + EntityType string // Required: "action" or "task" + EntityID string // Required: ID of the entity to reference + OutputKey string // Optional: specific output field to extract +} + +func (p EntityOutputParameter) Resolve(ctx context.Context, globalContext *GlobalContext) (interface{}, error) { + if p.EntityType == "" || p.EntityID == "" { + return nil, fmt.Errorf("EntityOutputParameter: EntityType and EntityID cannot be empty") + } + + const ( + entityTypeAction = "action" + entityTypeTask = "task" + ) + + switch p.EntityType { + case entityTypeAction: + // Try ActionOutputs first + if output, exists := globalContext.ActionOutputs[p.EntityID]; exists { + if p.OutputKey != "" { + if outputMap, ok := output.(map[string]interface{}); ok { + if value, exists := outputMap[p.OutputKey]; exists { + return value, nil + } + return nil, fmt.Errorf("EntityOutputParameter: output key '%s' not found in action '%s'", p.OutputKey, p.EntityID) + } + return nil, fmt.Errorf("EntityOutputParameter: action '%s' output is not a map, cannot extract key '%s'", p.EntityID, p.OutputKey) + } + return output, nil + } + // Try ActionResults if ActionOutputs doesn't have it + if resultProvider, exists := globalContext.ActionResults[p.EntityID]; exists { + result := resultProvider.GetResult() + if p.OutputKey != "" { + if resultMap, ok := result.(map[string]interface{}); ok { + if value, exists := resultMap[p.OutputKey]; exists { + return value, nil + } + return nil, fmt.Errorf("EntityOutputParameter: result key '%s' not found in action '%s'", p.OutputKey, p.EntityID) + } + return nil, fmt.Errorf("EntityOutputParameter: action '%s' result is not a map, cannot extract key '%s'", p.EntityID, p.OutputKey) + } + return result, nil + } + return nil, fmt.Errorf("EntityOutputParameter: action '%s' not found in context", p.EntityID) + + case entityTypeTask: + // Try TaskOutputs first + if output, exists := globalContext.TaskOutputs[p.EntityID]; exists { + if p.OutputKey != "" { + if outputMap, ok := output.(map[string]interface{}); ok { + if value, exists := outputMap[p.OutputKey]; exists { + return value, nil + } + return nil, fmt.Errorf("EntityOutputParameter: output key '%s' not found in task '%s'", p.OutputKey, p.EntityID) + } + return nil, fmt.Errorf("EntityOutputParameter: task '%s' output is not a map, cannot extract key '%s'", p.EntityID, p.OutputKey) + } + return output, nil + } + // Try TaskResults if TaskOutputs doesn't have it + if resultProvider, exists := globalContext.TaskResults[p.EntityID]; exists { + result := resultProvider.GetResult() + if p.OutputKey != "" { + if resultMap, ok := result.(map[string]interface{}); ok { + if value, exists := resultMap[p.OutputKey]; exists { + return value, nil + } + return nil, fmt.Errorf("EntityOutputParameter: result key '%s' not found in task '%s'", p.OutputKey, p.EntityID) + } + return nil, fmt.Errorf("EntityOutputParameter: task '%s' result is not a map, cannot extract key '%s'", p.EntityID, p.OutputKey) + } + return result, nil + } + return nil, fmt.Errorf("EntityOutputParameter: task '%s' not found in context", p.EntityID) + + default: + return nil, fmt.Errorf("EntityOutputParameter: invalid entity type '%s', must be 'action' or 'task'", p.EntityType) + } +} + +// --- Typed parameter resolution helpers --- + +// ResolveString resolves an ActionParameter to a string with helpful +// conversions and clear error messages. When the parameter is nil, +// it returns an empty string without error. +func ResolveString(ctx context.Context, p ActionParameter, globalContext *GlobalContext) (string, error) { + if p == nil { + return "", nil + } + v, err := p.Resolve(ctx, globalContext) + if err != nil { + return "", err + } + switch t := v.(type) { + case string: + return t, nil + case []byte: + return string(t), nil + case fmt.Stringer: + return t.String(), nil + case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64, bool: + return fmt.Sprint(v), nil + default: + return "", fmt.Errorf("parameter is not a string, got %T", v) + } +} + +// ResolveBool resolves an ActionParameter to a bool with common coercions. +// If parameter is nil, returns false. +func ResolveBool(ctx context.Context, p ActionParameter, globalContext *GlobalContext) (bool, error) { + if p == nil { + return false, nil + } + v, err := p.Resolve(ctx, globalContext) + if err != nil { + return false, err + } + switch t := v.(type) { + case bool: + return t, nil + case string: + s := strings.TrimSpace(strings.ToLower(t)) + if s == "true" || s == "1" || s == "yes" || s == "y" { // common truthy strings + return true, nil + } + if s == "false" || s == "0" || s == "no" || s == "n" { + return false, nil + } + return false, fmt.Errorf("cannot convert string '%s' to bool", t) + case int: + return t != 0, nil + case int64: + return t != 0, nil + case uint: + return t != 0, nil + default: + return false, fmt.Errorf("parameter is not a bool, got %T", v) + } +} + +// ResolveStringSlice resolves an ActionParameter into a []string. +// Accepts []string directly, or splits a string by comma or spaces. +func ResolveStringSlice(ctx context.Context, p ActionParameter, globalContext *GlobalContext) ([]string, error) { + if p == nil { + return nil, nil + } + v, err := p.Resolve(ctx, globalContext) + if err != nil { + return nil, err + } + switch t := v.(type) { + case []string: + return t, nil + case string: + s := strings.TrimSpace(t) + if s == "" { + return []string{}, nil + } + if strings.Contains(s, ",") { + parts := strings.Split(s, ",") + out := make([]string, 0, len(parts)) + for _, p := range parts { + p = strings.TrimSpace(p) + if p != "" { + out = append(out, p) + } + } + return out, nil + } + return strings.Fields(s), nil + default: + return nil, fmt.Errorf("parameter is not a string slice or string, got %T", v) + } +} + +// ResolveAs provides a generic typed resolver using existing parameter resolution. +func ResolveAs[T any](ctx context.Context, p ActionParameter, globalContext *GlobalContext) (T, error) { + var zero T + if p == nil { + return zero, nil + } + v, err := p.Resolve(ctx, globalContext) + if err != nil { + return zero, err + } + out, ok := v.(T) + if !ok { + return zero, fmt.Errorf("expected %T, got %T", zero, v) + } + return out, nil +} + +// Helper functions for common parameter patterns +// ActionOutput creates a parameter reference to an entire action output +func ActionOutput(actionID string) ActionOutputParameter { + return ActionOutputParameter{ActionID: actionID} +} + +// ActionOutputField creates a parameter reference to a specific field in an action output +func ActionOutputField(actionID, field string) ActionOutputParameter { + return ActionOutputParameter{ActionID: actionID, OutputKey: field} +} + +// ActionResult creates a parameter reference to an action result (for ResultProvider actions) +func ActionResult(actionID string) ActionResultParameter { + return ActionResultParameter{ActionID: actionID} +} + +// ActionResultField creates a parameter reference to a specific field in an action result +func ActionResultField(actionID, field string) ActionResultParameter { + return ActionResultParameter{ActionID: actionID, ResultKey: field} +} + +// TaskOutput creates a parameter reference to an entire task output +func TaskOutput(taskID string) TaskOutputParameter { + return TaskOutputParameter{TaskID: taskID} +} + +// TaskOutputField creates a parameter reference to a specific field in a task output +func TaskOutputField(taskID, field string) TaskOutputParameter { + return TaskOutputParameter{TaskID: taskID, OutputKey: field} +} + +// TaskResult creates a parameter reference to an entire task result (for ResultProvider tasks) +func TaskResult(taskID string) TaskResultParameter { + return TaskResultParameter{TaskID: taskID} +} + +// TaskResultField creates a parameter reference to a specific field in a task result +func TaskResultField(taskID, field string) TaskResultParameter { + return TaskResultParameter{TaskID: taskID, ResultKey: field} +} + +// EntityOutput creates a parameter reference to any entity type (action or task) +func EntityOutput(entityType, entityID string) EntityOutputParameter { + return EntityOutputParameter{EntityType: entityType, EntityID: entityID} +} + +// EntityOutputField creates a parameter reference to a specific field in any entity output +func EntityOutputField(entityType, entityID, field string) EntityOutputParameter { + return EntityOutputParameter{EntityType: entityType, EntityID: entityID, OutputKey: field} +} + +// TypedOutputKey provides a way to associate an output field name with an expected +// struct type T. Validate can be used to check that the field exists on T at runtime. +// Note: This is a runtime validation helper; compile-time validation would require codegen. +// TypedOutputKey provides compile-time validation of output keys for type-safe +// parameter references. Use this when you want to ensure output keys exist +// in your output types at compile time. +type TypedOutputKey[T any] struct { + ActionID string // ID of the action to reference + Key string // Field name to extract from the output +} + +// Validate checks whether Key is a valid exported field on T when T is a struct. +// If T is not a struct, Validate returns nil (no validation performed). +func (k TypedOutputKey[T]) Validate() error { + t := reflect.TypeOf((*T)(nil)).Elem() + if t.Kind() != reflect.Struct { + return nil + } + if _, exists := t.FieldByName(k.Key); !exists { + return fmt.Errorf("field '%s' does not exist on output type %s", k.Key, t.Name()) + } + return nil +} diff --git a/task.go b/task.go index 3ca5db4..7c9e3e4 100644 --- a/task.go +++ b/task.go @@ -25,6 +25,11 @@ type Task struct { TotalTime time.Duration CompletedTasks int mu sync.Mutex // protects concurrent access to TotalTime and CompletedTasks + // ResultProvider support + executionError error + customResult interface{} + // Optional: build a custom task result from accumulated action outputs + ResultBuilder func(ctx *TaskContext) (interface{}, error) } // TaskContext maintains execution context for a single task @@ -76,6 +81,10 @@ func (t *Task) RunWithContext(ctx context.Context, globalContext *GlobalContext) select { case <-ctx.Done(): t.log("Task canceled", "taskID", t.ID, "runID", runID, "reason", ctx.Err()) + t.SetError(ctx.Err()) + // Ensure task output and result provider are stored even on cancellation + t.storeTaskOutput(globalContext) + t.storeTaskResultIfAbsent(globalContext) return ctx.Err() default: // Execute action @@ -88,9 +97,17 @@ func (t *Task) RunWithContext(ctx context.Context, globalContext *GlobalContext) if execErr != nil { if errors.Is(execErr, ErrPrerequisiteNotMet) { t.log("Task aborted: prerequisite not met", "taskID", t.ID, "runID", runID, "actionID", action.GetID(), "error", execErr) + t.SetError(execErr) + // Store task output and result provider on failure + t.storeTaskOutput(globalContext) + t.storeTaskResultIfAbsent(globalContext) return fmt.Errorf("task %s (run %s) aborted: prerequisite not met in action %s: %w", t.ID, runID, action.GetID(), execErr) } else { t.log("Task failed: action execution error", "taskID", t.ID, "runID", runID, "actionID", action.GetID(), "error", execErr) + t.SetError(execErr) + // Store task output and result provider on failure + t.storeTaskOutput(globalContext) + t.storeTaskResultIfAbsent(globalContext) return fmt.Errorf("task %s (run %s) failed at action %s: %w", t.ID, runID, action.GetID(), execErr) } } @@ -108,8 +125,19 @@ func (t *Task) RunWithContext(ctx context.Context, globalContext *GlobalContext) t.mu.Unlock() } - // Store task output in global context + // Build custom result if a ResultBuilder is provided + if t.ResultBuilder != nil { + res, buildErr := t.ResultBuilder(taskContext) + if buildErr != nil { + t.SetError(buildErr) + } else if res != nil { + t.SetResult(res) + } + } + + // Store task output and result provider in global context on success t.storeTaskOutput(globalContext) + t.storeTaskResultIfAbsent(globalContext) t.log("Task completed", "taskID", t.ID, "runID", runID, "totalDuration", t.GetTotalTime()) return nil @@ -154,7 +182,10 @@ func (t *Task) storeTaskOutput(globalContext *GlobalContext) { "name": t.Name, "totalTime": t.TotalTime, "completedTasks": t.CompletedTasks, - "success": true, + "success": t.GetError() == nil, + } + if err := t.GetError(); err != nil { + taskOutput["error"] = err.Error() } globalContext.StoreTaskOutput(t.ID, taskOutput) @@ -210,3 +241,67 @@ func (t *Task) GetID() string { func (t *Task) GetName() string { return t.Name } + +// storeTaskResultIfAbsent stores this task as the task-level ResultProvider only +// if a TaskResult has not already been set by an action during execution. +func (t *Task) storeTaskResultIfAbsent(globalContext *GlobalContext) { + globalContext.mu.RLock() + _, exists := globalContext.TaskResults[t.ID] + globalContext.mu.RUnlock() + if exists { + return + } + globalContext.StoreTaskResult(t.ID, t) +} + +// SetResult allows setting a custom result payload for this task +func (t *Task) SetResult(result interface{}) { + t.mu.Lock() + defer t.mu.Unlock() + t.customResult = result +} + +// GetResult returns either a custom result if provided, or a default summary +// of the task execution as a map[string]interface{}. +func (t *Task) GetResult() interface{} { + t.mu.Lock() + result := t.customResult + runID := t.RunID + id := t.ID + name := t.Name + total := t.TotalTime + completed := t.CompletedTasks + err := t.executionError + t.mu.Unlock() + + if result != nil { + return result + } + + out := map[string]interface{}{ + "taskID": id, + "runID": runID, + "name": name, + "totalTime": total, + "completedTasks": completed, + "success": err == nil, + } + if err != nil { + out["error"] = err.Error() + } + return out +} + +// SetError stores an execution error for the task +func (t *Task) SetError(err error) { + t.mu.Lock() + defer t.mu.Unlock() + t.executionError = err +} + +// GetError returns the stored execution error for the task +func (t *Task) GetError() error { + t.mu.Lock() + defer t.mu.Unlock() + return t.executionError +} diff --git a/task_engine_test.go b/task_engine_test.go index 1f5462c..9fd2479 100644 --- a/task_engine_test.go +++ b/task_engine_test.go @@ -80,6 +80,12 @@ func (a *AfterExecuteFailingAction) Execute(ctx context.Context) error { return nil } +// testResultProvider is a minimal ResultProvider for tests +type testResultProvider struct{ v interface{} } + +func (p testResultProvider) GetResult() interface{} { return p.v } +func (p testResultProvider) GetError() error { return nil } + func (a *AfterExecuteFailingAction) AfterExecute(ctx context.Context) error { if a.ShouldFailAfter { return errors.New("simulated AfterExecute failure") @@ -294,6 +300,80 @@ func TestGlobalContext(t *testing.T) { }) } +func TestTypedGlobalContextHelpers(t *testing.T) { + gc := task_engine.NewGlobalContext() + + // Prepare action output and result + gc.StoreActionOutput("act1", map[string]interface{}{"k": 123, "s": "abc"}) + + // Simple ResultProviders + gc.StoreActionResult("actRes", testResultProvider{v: map[string]interface{}{"sum": 7}}) + gc.StoreTaskOutput("task1", map[string]interface{}{"ok": true, "n": 9}) + gc.StoreTaskResult("taskRes", testResultProvider{v: "done"}) + + // ActionOutputFieldAs + vInt, err := task_engine.ActionOutputFieldAs[int](gc, "act1", "k") + if err != nil || vInt != 123 { + t.Fatalf("expected 123, got %v, err=%v", vInt, err) + } + vStr, err := task_engine.ActionOutputFieldAs[string](gc, "act1", "s") + if err != nil || vStr != "abc" { + t.Fatalf("expected 'abc', got %v, err=%v", vStr, err) + } + + // TaskOutputFieldAs + vBool, err := task_engine.TaskOutputFieldAs[bool](gc, "task1", "ok") + if err != nil || vBool != true { + t.Fatalf("expected true, got %v, err=%v", vBool, err) + } + vNum, err := task_engine.TaskOutputFieldAs[int](gc, "task1", "n") + if err != nil || vNum != 9 { + t.Fatalf("expected 9, got %v, err=%v", vNum, err) + } + + // ActionResultAs / TaskResultAs + rmap, ok := task_engine.ActionResultAs[map[string]interface{}](gc, "actRes") + if !ok || rmap["sum"].(int) != 7 { + t.Fatalf("expected action result sum=7, got %v", rmap) + } + rstr, ok := task_engine.TaskResultAs[string](gc, "taskRes") + if !ok || rstr != "done" { + t.Fatalf("expected task result 'done', got %v", rstr) + } + + // EntityValue / EntityValueAs + if v, err := task_engine.EntityValue(gc, "action", "act1", "k"); err != nil || v.(int) != 123 { + t.Fatalf("EntityValue action k expected 123, got %v, err=%v", v, err) + } + if v, err := task_engine.EntityValue(gc, "task", "task1", "ok"); err != nil || v.(bool) != true { + t.Fatalf("EntityValue task ok expected true, got %v, err=%v", v, err) + } + if v, err := task_engine.EntityValue(gc, "action", "actRes", ""); err != nil { + t.Fatalf("EntityValue action result expected no error, got err=%v", err) + } else { + if vm, ok := v.(map[string]interface{}); !ok || vm["sum"].(int) != 7 { + t.Fatalf("EntityValue action result expected map with sum=7, got %v", v) + } + } + if s, err := task_engine.EntityValueAs[string](gc, "task", "taskRes", ""); err != nil || s != "done" { + t.Fatalf("EntityValueAs task result expected 'done', got %v, err=%v", s, err) + } +} + +func TestResolveAsGeneric(t *testing.T) { + gc := task_engine.NewGlobalContext() + gc.StoreActionOutput("act", map[string]interface{}{"name": "demo", "count": 5}) + + name, err := task_engine.ResolveAs[string](context.Background(), task_engine.ActionOutputField("act", "name"), gc) + if err != nil || name != "demo" { + t.Fatalf("expected 'demo', got %v, err=%v", name, err) + } + count, err := task_engine.ResolveAs[int](context.Background(), task_engine.ActionOutputField("act", "count"), gc) + if err != nil || count != 5 { + t.Fatalf("expected 5, got %v, err=%v", count, err) + } +} + func TestTaskWithParameterPassing(t *testing.T) { t.Run("TaskExecutionWithGlobalContext", func(t *testing.T) { logger := NewDiscardLogger() diff --git a/task_test.go b/task_test.go index 525a9b2..dc2589a 100644 --- a/task_test.go +++ b/task_test.go @@ -61,6 +61,34 @@ func newMockAction(logger *slog.Logger, name string, returnError error, executed } } +// outputAction is a simple action that returns a fixed output map +type outputAction struct { + engine.BaseAction + Output map[string]interface{} +} + +func (a *outputAction) Execute(ctx context.Context) error { return nil } +func (a *outputAction) GetOutput() interface{} { return a.Output } + +// overrideTaskResultAction sets a task-level ResultProvider during execution +type overrideTaskResultAction struct { + engine.BaseAction + TaskID string + Provider *mocks.ResultProviderMock +} + +func (a *overrideTaskResultAction) Execute(ctx context.Context) error { + if gc, ok := ctx.Value(engine.GlobalContextKey).(*engine.GlobalContext); ok { + gc.StoreTaskResult(a.TaskID, a.Provider) + } + return nil +} +func (a *overrideTaskResultAction) GetOutput() interface{} { return nil } + +type testResult struct { + Value string +} + func (suite *TaskTestSuite) TestRun_Success() { logger := mocks.NewDiscardLogger() action1Executed := false @@ -194,3 +222,178 @@ func (suite *TaskTestSuite) TestRun_ContextCancellation() { // During cancellation, we might have 0 or 1 completed tasks depending on timing assert.LessOrEqual(suite.T(), task.GetCompletedTasks(), 1, "Completed tasks should be 0 or 1 due to cancellation") } + +func (suite *TaskTestSuite) TestTask_ImplementsTaskWithResults_AndBuilderStoresCustomStruct() { + logger := mocks.NewDiscardLogger() + gc := engine.NewGlobalContext() + + // Action that produces an output used by the builder + producer := &engine.Action[*outputAction]{ + ID: "produce", + Wrapped: &outputAction{ + BaseAction: engine.BaseAction{Logger: logger}, + Output: map[string]interface{}{"field": "value-from-action"}, + }, + } + + // Task with a ResultBuilder that pulls from action output + task := &engine.Task{ + ID: "builder-task", + Name: "Builder Task", + Logger: logger, + Actions: []engine.ActionWrapper{ + producer, + }, + ResultBuilder: func(ctx *engine.TaskContext) (interface{}, error) { + v, err := engine.ActionOutputField("produce", "field").Resolve(context.Background(), ctx.GlobalContext) + if err != nil { + return nil, err + } + s, _ := v.(string) + return &testResult{Value: s}, nil + }, + } + + // Compile-time-ish and runtime checks for TaskWithResults + if _, ok := any(task).(engine.TaskWithResults); !ok { + suite.T().Fatal("Task should implement TaskWithResults") + } + + // Run and verify + assert.NoError(suite.T(), task.RunWithContext(context.Background(), gc)) + + // Fetch provider stored on task + rp, exists := gc.TaskResults[task.ID] + assert.True(suite.T(), exists) + res, ok := rp.GetResult().(*testResult) + if assert.True(suite.T(), ok) { + assert.Equal(suite.T(), "value-from-action", res.Value) + } + + // Also verify TaskResult parameter resolves the entire struct + val, err := engine.TaskResult(task.ID).Resolve(context.Background(), gc) + assert.NoError(suite.T(), err) + _, ok = val.(*testResult) + assert.True(suite.T(), ok) +} + +func (suite *TaskTestSuite) TestTask_ActionOverrideWinsOverBuilder() { + logger := mocks.NewDiscardLogger() + gc := engine.NewGlobalContext() + + // Provider to override task result + p := mocks.NewResultProviderMock() + p.SetResult("override-result") + // Ensure mock has expectations to satisfy testify's Called() + p.On("GetResult").Return("override-result") + p.On("GetError").Return(nil) + + override := &engine.Action[*overrideTaskResultAction]{ + ID: "override", + Wrapped: &overrideTaskResultAction{ + BaseAction: engine.BaseAction{Logger: logger}, + TaskID: "override-task", + Provider: p, + }, + } + + // Builder returns a different value, which should NOT replace override + task := &engine.Task{ + ID: "override-task", + Name: "Override Task", + Logger: logger, + Actions: []engine.ActionWrapper{ + override, + }, + ResultBuilder: func(ctx *engine.TaskContext) (interface{}, error) { + return &testResult{Value: "builder-result"}, nil + }, + } + + assert.NoError(suite.T(), task.RunWithContext(context.Background(), gc)) + + rp, exists := gc.TaskResults[task.ID] + assert.True(suite.T(), exists) + assert.Equal(suite.T(), "override-result", rp.GetResult()) + + // Parameter-based fetch of entire result + val, err := engine.TaskResult(task.ID).Resolve(context.Background(), gc) + assert.NoError(suite.T(), err) + assert.Equal(suite.T(), "override-result", val) +} + +// Minimal pattern test: +// - two actions produce a string and an int as outputs +// - the task aggregates into a single struct via ResultBuilder +// - test fetches the struct from task results and asserts values +func (suite *TaskTestSuite) TestTask_SimpleResultAggregation() { + logger := mocks.NewDiscardLogger() + gc := engine.NewGlobalContext() + + strAction := &engine.Action[*outputAction]{ + ID: "string-action", + Wrapped: &outputAction{ + BaseAction: engine.BaseAction{Logger: logger}, + Output: map[string]interface{}{"text": "hello"}, + }, + } + + intAction := &engine.Action[*outputAction]{ + ID: "int-action", + Wrapped: &outputAction{ + BaseAction: engine.BaseAction{Logger: logger}, + Output: map[string]interface{}{"num": 42}, + }, + } + + type simpleResult struct { + Text string + Num int + } + + task := &engine.Task{ + ID: "simple-aggregate-task", + Name: "Simple Aggregate", + Logger: logger, + Actions: []engine.ActionWrapper{ + strAction, + intAction, + }, + ResultBuilder: func(ctx *engine.TaskContext) (interface{}, error) { + gc := ctx.GlobalContext + res := &simpleResult{} + + if v, err := engine.ActionOutputField("string-action", "text").Resolve(context.Background(), gc); err == nil { + if s, ok := v.(string); ok { + res.Text = s + } + } + if v, err := engine.ActionOutputField("int-action", "num").Resolve(context.Background(), gc); err == nil { + // handle both int and numeric types gracefully for the test + switch n := v.(type) { + case int: + res.Num = n + case int32: + res.Num = int(n) + case int64: + res.Num = int(n) + case float64: + res.Num = int(n) + } + } + + return res, nil + }, + } + + assert.NoError(suite.T(), task.RunWithContext(context.Background(), gc)) + + rp, exists := gc.TaskResults[task.ID] + assert.True(suite.T(), exists) + if out, ok := rp.GetResult().(*simpleResult); ok { + assert.Equal(suite.T(), "hello", out.Text) + assert.Equal(suite.T(), 42, out.Num) + } else { + suite.T().Fatal("unexpected result type") + } +}