diff --git a/app/services/localtaskstore/execution_test.go b/app/services/localtaskstore/execution_test.go new file mode 100644 index 0000000..64a593a --- /dev/null +++ b/app/services/localtaskstore/execution_test.go @@ -0,0 +1,68 @@ +package localtaskstore + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRecordReceivedSurvivesRestart(t *testing.T) { + storePath := filepath.Join(t.TempDir(), "task_store.db") + store := openTestStore(t, storePath, 1024*1024, 1024) + + received, err := store.RecordReceived(TaskReceipt{ + TaskID: "task-1", + ExecutionAttemptID: "attempt-1", + }) + require.NoError(t, err) + require.Equal(t, TaskStatusReceived, received.Status) + require.NoError(t, store.Close()) + + reopened := openTestStore(t, storePath, 1024*1024, 1024) + state, err := reopened.TaskState("task-1", "attempt-1") + require.NoError(t, err) + require.True(t, state.Exists) + require.Equal(t, TaskStatusReceived, state.Status) +} + +func TestRecordReceivedReturnsExistingDuplicateState(t *testing.T) { + store := newTestStore(t, 1024*1024, 1024) + receipt := TaskReceipt{TaskID: "task-1", ExecutionAttemptID: "attempt-1"} + + first, err := store.RecordReceived(receipt) + require.NoError(t, err) + second, err := store.RecordReceived(receipt) + require.NoError(t, err) + + require.True(t, second.Exists) + require.Equal(t, first.ID, second.ID) + require.Equal(t, TaskStatusReceived, second.Status) +} + +func TestTaskStateTreatsNewAttemptAsDistinct(t *testing.T) { + store := newTestStore(t, 1024*1024, 1024) + + _, err := store.RecordReceived(TaskReceipt{TaskID: "task-1", ExecutionAttemptID: "attempt-1"}) + require.NoError(t, err) + + state, err := store.TaskState("task-1", "attempt-2") + require.NoError(t, err) + require.False(t, state.Exists) +} + +func openTestStore(t *testing.T, path string, spoolCapBytes, terminalReserveBytes int64) *Store { + t.Helper() + + store, err := New(Config{ + Path: path, + SpoolCapBytes: spoolCapBytes, + TerminalReserveBytes: terminalReserveBytes, + }) + require.NoError(t, err) + t.Cleanup(func() { + _ = store.Close() + }) + + return store +} diff --git a/app/services/localtaskstore/outbox_test.go b/app/services/localtaskstore/outbox_test.go new file mode 100644 index 0000000..17dedae --- /dev/null +++ b/app/services/localtaskstore/outbox_test.go @@ -0,0 +1,87 @@ +package localtaskstore + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestRecordStartedSurvivesRestart(t *testing.T) { + storePath := filepath.Join(t.TempDir(), "task_store.db") + store := openTestStore(t, storePath, 1024*1024, 1024) + + _, err := store.RecordReceived(TaskReceipt{TaskID: "task-1", ExecutionAttemptID: "attempt-1"}) + require.NoError(t, err) + require.NoError(t, store.RecordStarted("task-1", "attempt-1")) + require.NoError(t, store.Close()) + + reopened := openTestStore(t, storePath, 1024*1024, 1024) + state, err := reopened.TaskState("task-1", "attempt-1") + require.NoError(t, err) + require.Equal(t, TaskStatusRunning, state.Status) +} + +func TestRecordFinalCreatesUnackedTerminalMessageAcrossRestart(t *testing.T) { + storePath := filepath.Join(t.TempDir(), "task_store.db") + store := openTestStore(t, storePath, 1024*1024, 1024) + + require.NoError(t, store.RecordFinal(FinalResult{ + MessageID: "msg-final-1", + TaskID: "task-1", + ExecutionAttemptID: "attempt-1", + Status: "completed", + ExitCode: 0, + Payload: `{"status":"completed"}`, + })) + require.NoError(t, store.Close()) + + reopened := openTestStore(t, storePath, 1024*1024, 1024) + messages, err := reopened.UnackedMessages() + require.NoError(t, err) + require.Len(t, messages, 1) + require.Equal(t, "msg-final-1", messages[0].MessageID) + require.Equal(t, OutboxMessageTypeFinal, messages[0].Type) +} + +func TestAckMessageRemovesOutputChunkFromResendQueue(t *testing.T) { + store := newTestStore(t, 1024*1024, 1024) + + require.NoError(t, store.AppendOutputChunk(OutputChunk{ + MessageID: "msg-output-1", + TaskID: "task-1", + ExecutionAttemptID: "attempt-1", + Stream: "stdout", + Sequence: 1, + Payload: "hello", + ByteCount: 5, + })) + require.NoError(t, store.AckMessage("msg-output-1")) + + messages, err := store.UnackedMessages() + require.NoError(t, err) + require.Empty(t, messages) +} + +func TestAckFinalRemovesOutboxButPreservesTaskState(t *testing.T) { + store := newTestStore(t, 1024*1024, 1024) + + require.NoError(t, store.RecordFinal(FinalResult{ + MessageID: "msg-final-1", + TaskID: "task-1", + ExecutionAttemptID: "attempt-1", + Status: "completed", + ExitCode: 0, + Payload: `{"status":"completed"}`, + })) + require.NoError(t, store.AckMessage("msg-final-1")) + + messages, err := store.UnackedMessages() + require.NoError(t, err) + require.Empty(t, messages) + + state, err := store.TaskState("task-1", "attempt-1") + require.NoError(t, err) + require.True(t, state.Exists) + require.Equal(t, TaskStatusFinal, state.Status) +} diff --git a/app/services/localtaskstore/recovery_test.go b/app/services/localtaskstore/recovery_test.go new file mode 100644 index 0000000..e1bb0b3 --- /dev/null +++ b/app/services/localtaskstore/recovery_test.go @@ -0,0 +1,54 @@ +package localtaskstore + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestSnapshotIncludesLocalTruncationState(t *testing.T) { + store := newTestStore(t, 16, 4) + appendChunk(t, store, "msg-1", "task-1", 1, "12345") + appendChunk(t, store, "msg-2", "task-1", 2, "67890") + appendChunk(t, store, "msg-3", "task-1", 3, "abcde") + + snapshot, err := store.Snapshot() + require.NoError(t, err) + require.Len(t, snapshot.Tasks, 1) + require.Equal(t, "task-1", snapshot.Tasks[0].TaskID) + require.True(t, snapshot.Tasks[0].LocalOutputTruncated) +} + +func TestMarkInterruptedRunningTasksQueuesTerminalRecordAcrossRestart(t *testing.T) { + storePath := filepath.Join(t.TempDir(), "task_store.db") + store := openTestStore(t, storePath, 1024*1024, 1024) + + require.NoError(t, store.RecordStarted("task-1", "attempt-1")) + require.NoError(t, store.Close()) + + reopened := openTestStore(t, storePath, 1024*1024, 1024) + require.NoError(t, reopened.MarkInterruptedRunningTasks()) + + state, err := reopened.TaskState("task-1", "attempt-1") + require.NoError(t, err) + require.Equal(t, TaskStatusInterrupted, state.Status) + + messages, err := reopened.UnackedMessages() + require.NoError(t, err) + require.Len(t, messages, 1) + require.Equal(t, OutboxMessageTypeFinal, messages[0].Type) + require.Contains(t, messages[0].Payload, "interrupted") +} + +func TestMarkInterruptedRunningTasksIsIdempotent(t *testing.T) { + store := newTestStore(t, 1024*1024, 1024) + + require.NoError(t, store.RecordStarted("task-1", "attempt-1")) + require.NoError(t, store.MarkInterruptedRunningTasks()) + require.NoError(t, store.MarkInterruptedRunningTasks()) + + messages, err := store.UnackedMessages() + require.NoError(t, err) + require.Len(t, messages, 1) +} diff --git a/app/services/localtaskstore/spool_test.go b/app/services/localtaskstore/spool_test.go new file mode 100644 index 0000000..e5514cc --- /dev/null +++ b/app/services/localtaskstore/spool_test.go @@ -0,0 +1,81 @@ +package localtaskstore + +import ( + "errors" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestAppendOutputChunkRotatesOldestChunksAndMarksTruncated(t *testing.T) { + store := newTestStore(t, 16, 4) + + appendChunk(t, store, "msg-1", "task-1", 1, "12345") + appendChunk(t, store, "msg-2", "task-1", 2, "67890") + appendChunk(t, store, "msg-3", "task-1", 3, "abcde") + + messages, err := store.UnackedMessages() + require.NoError(t, err) + require.Len(t, messages, 2) + require.Equal(t, "msg-2", messages[0].MessageID) + require.Equal(t, "msg-3", messages[1].MessageID) + + state, err := store.TaskState("task-1", "attempt-1") + require.NoError(t, err) + require.True(t, state.LocalOutputTruncated) +} + +func TestRecordFinalPreservedUnderChunkCapPressure(t *testing.T) { + store := newTestStore(t, 14, 4) + + appendChunk(t, store, "msg-1", "task-1", 1, "12345") + appendChunk(t, store, "msg-2", "task-1", 2, "67890") + require.NoError(t, store.RecordFinal(FinalResult{ + MessageID: "msg-final-1", + TaskID: "task-1", + ExecutionAttemptID: "attempt-1", + Status: "completed", + Payload: "done", + })) + + messages, err := store.UnackedMessages() + require.NoError(t, err) + require.Contains(t, messageIDs(messages), "msg-final-1") +} + +func TestRecordReceivedFailsWhenTerminalReserveUnavailable(t *testing.T) { + store := newTestStore(t, 8, 4) + require.NoError(t, store.RecordFinal(FinalResult{ + MessageID: "msg-final-1", + TaskID: "task-1", + ExecutionAttemptID: "attempt-1", + Status: "completed", + Payload: "12345", + })) + + _, err := store.RecordReceived(TaskReceipt{TaskID: "task-2", ExecutionAttemptID: "attempt-1"}) + require.Error(t, err) + require.True(t, errors.Is(err, ErrTerminalReserveUnavailable)) +} + +func messageIDs(messages []OutboxMessage) []string { + ids := make([]string, 0, len(messages)) + for _, message := range messages { + ids = append(ids, message.MessageID) + } + return ids +} + +func appendChunk(t *testing.T, store *Store, messageID, taskID string, sequence int64, payload string) { + t.Helper() + + require.NoError(t, store.AppendOutputChunk(OutputChunk{ + MessageID: messageID, + TaskID: taskID, + ExecutionAttemptID: "attempt-1", + Stream: "stdout", + Sequence: sequence, + Payload: payload, + ByteCount: int64(len(payload)), + })) +} diff --git a/app/services/localtaskstore/store.go b/app/services/localtaskstore/store.go new file mode 100644 index 0000000..460206e --- /dev/null +++ b/app/services/localtaskstore/store.go @@ -0,0 +1,625 @@ +package localtaskstore + +import ( + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "time" + + "hostlink/config/appconf" + + "github.com/glebarez/sqlite" + "gorm.io/gorm" +) + +var ErrTerminalReserveUnavailable = errors.New("terminal reserve unavailable") + +const ( + TaskStatusReceived = "received" + TaskStatusRunning = "running" + TaskStatusFinal = "final" + TaskStatusInterrupted = "interrupted" + + OutboxMessageTypeOutput = "output" + OutboxMessageTypeFinal = "final" +) + +type TaskReceipt struct { + TaskID string + ExecutionAttemptID string +} + +type TaskState struct { + ID uint + Exists bool + TaskID string + ExecutionAttemptID string + Status string + ExitCode int + OutputTruncated bool + ErrorTruncated bool + LocalOutputTruncated bool +} + +type OutputChunk struct { + MessageID string + TaskID string + ExecutionAttemptID string + Stream string + Sequence int64 + Payload string + ByteCount int64 +} + +type FinalResult struct { + MessageID string + TaskID string + ExecutionAttemptID string + Status string + ExitCode int + Payload string + OutputTruncated bool + ErrorTruncated bool +} + +type OutboxMessage struct { + MessageID string + TaskID string + ExecutionAttemptID string + Type string + Stream string + Sequence int64 + Payload string + ByteCount int64 +} + +type Snapshot struct { + Tasks []TaskState +} + +type Config struct { + Path string + SpoolCapBytes int64 + TerminalReserveBytes int64 +} + +type ReceiptStore interface { + RecordReceived(TaskReceipt) (TaskState, error) + TaskState(taskID, executionAttemptID string) (TaskState, error) +} + +type ResultOutbox interface { + AppendOutputChunk(OutputChunk) error + RecordFinal(FinalResult) error + UnackedMessages() ([]OutboxMessage, error) + AckMessage(messageID string) error +} + +type RecoveryStore interface { + Snapshot() (Snapshot, error) + MarkInterruptedRunningTasks() error +} + +type Store struct { + db *gorm.DB + spoolCapBytes int64 + terminalReserveBytes int64 +} + +type taskExecutionRecord struct { + ID uint `gorm:"primaryKey"` + TaskID string + ExecutionAttemptID string + Status string + ExitCode int + OutputTruncated bool + ErrorTruncated bool + LocalOutputTruncated bool + CreatedAt time.Time + UpdatedAt time.Time +} + +func (taskExecutionRecord) TableName() string { + return "local_task_executions" +} + +type outboxMessageRecord struct { + ID uint `gorm:"primaryKey"` + MessageID string + TaskID string + ExecutionAttemptID string + Type string + Stream string + Sequence int64 + Payload string + ByteCount int64 + AckedAt *time.Time + CreatedAt time.Time + UpdatedAt time.Time +} + +func (outboxMessageRecord) TableName() string { + return "local_task_outbox_messages" +} + +func New(cfg Config) (*Store, error) { + if cfg.Path == "" { + return nil, fmt.Errorf("local task store path is required") + } + if cfg.SpoolCapBytes <= 0 { + return nil, fmt.Errorf("local task store spool cap must be positive") + } + if cfg.TerminalReserveBytes <= 0 { + return nil, fmt.Errorf("local task store terminal reserve must be positive") + } + if cfg.TerminalReserveBytes > cfg.SpoolCapBytes { + return nil, fmt.Errorf("local task store terminal reserve cannot exceed spool cap") + } + + if err := os.MkdirAll(filepath.Dir(cfg.Path), 0700); err != nil { + return nil, fmt.Errorf("create local task store directory: %w", err) + } + + db, err := gorm.Open(sqlite.Open(cfg.Path), &gorm.Config{}) + if err != nil { + return nil, fmt.Errorf("open local task store: %w", err) + } + + store := &Store{ + db: db, + spoolCapBytes: cfg.SpoolCapBytes, + terminalReserveBytes: cfg.TerminalReserveBytes, + } + if err := store.migrate(); err != nil { + _ = store.Close() + return nil, err + } + + return store, nil +} + +func NewDefault() (*Store, error) { + return New(Config{ + Path: appconf.LocalTaskStorePath(), + SpoolCapBytes: appconf.LocalTaskStoreSpoolCapBytes(), + TerminalReserveBytes: appconf.LocalTaskStoreTerminalReserveBytes(), + }) +} + +func (s *Store) Close() error { + if s == nil || s.db == nil { + return nil + } + db, err := s.db.DB() + if err != nil { + return err + } + return db.Close() +} + +func (s *Store) migrate() error { + if err := s.db.AutoMigrate(&taskExecutionRecord{}, &outboxMessageRecord{}); err != nil { + return fmt.Errorf("migrate local task store: %w", err) + } + if err := s.db.Exec("CREATE UNIQUE INDEX IF NOT EXISTS idx_local_task_executions_attempt ON local_task_executions(task_id, execution_attempt_id)").Error; err != nil { + return fmt.Errorf("migrate local task execution index: %w", err) + } + if err := s.db.Exec("CREATE UNIQUE INDEX IF NOT EXISTS idx_local_task_outbox_message_id ON local_task_outbox_messages(message_id)").Error; err != nil { + return fmt.Errorf("migrate local task outbox index: %w", err) + } + return nil +} + +func (s *Store) RecordReceived(receipt TaskReceipt) (TaskState, error) { + if receipt.TaskID == "" { + return TaskState{}, fmt.Errorf("task ID is required") + } + if receipt.ExecutionAttemptID == "" { + return TaskState{}, fmt.Errorf("execution attempt ID is required") + } + + var state TaskState + err := s.db.Transaction(func(tx *gorm.DB) error { + if err := s.ensureTerminalReserveAvailable(tx); err != nil { + return err + } + + var existing taskExecutionRecord + err := tx.Where("task_id = ? AND execution_attempt_id = ?", receipt.TaskID, receipt.ExecutionAttemptID).First(&existing).Error + if err == nil { + state = taskStateFromRecord(existing) + return nil + } + if !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + + record := taskExecutionRecord{ + TaskID: receipt.TaskID, + ExecutionAttemptID: receipt.ExecutionAttemptID, + Status: TaskStatusReceived, + } + if err := tx.Create(&record).Error; err != nil { + return err + } + state = taskStateFromRecord(record) + return nil + }) + if err != nil { + return TaskState{}, fmt.Errorf("record task receipt: %w", err) + } + return state, nil +} + +func (s *Store) TaskState(taskID, executionAttemptID string) (TaskState, error) { + query := s.db.Where("task_id = ?", taskID) + if executionAttemptID != "" { + query = query.Where("execution_attempt_id = ?", executionAttemptID) + } + + var record taskExecutionRecord + err := query.Order("updated_at DESC, id DESC").First(&record).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + return TaskState{}, nil + } + if err != nil { + return TaskState{}, fmt.Errorf("load task state: %w", err) + } + return taskStateFromRecord(record), nil +} + +func (s *Store) RecordStarted(taskID, executionAttemptID string) error { + if taskID == "" { + return fmt.Errorf("task ID is required") + } + if executionAttemptID == "" { + return fmt.Errorf("execution attempt ID is required") + } + + return s.db.Transaction(func(tx *gorm.DB) error { + return s.upsertExecutionState(tx, taskExecutionRecord{ + TaskID: taskID, + ExecutionAttemptID: executionAttemptID, + Status: TaskStatusRunning, + }) + }) +} + +func (s *Store) AppendOutputChunk(chunk OutputChunk) error { + if err := validateOutputChunk(chunk); err != nil { + return err + } + + return s.db.Transaction(func(tx *gorm.DB) error { + if err := s.upsertExecutionState(tx, taskExecutionRecord{ + TaskID: chunk.TaskID, + ExecutionAttemptID: chunk.ExecutionAttemptID, + Status: TaskStatusRunning, + }); err != nil { + return err + } + + record := outboxMessageRecord{ + MessageID: chunk.MessageID, + TaskID: chunk.TaskID, + ExecutionAttemptID: chunk.ExecutionAttemptID, + Type: OutboxMessageTypeOutput, + Stream: chunk.Stream, + Sequence: chunk.Sequence, + Payload: chunk.Payload, + ByteCount: chunk.ByteCount, + } + if err := tx.Create(&record).Error; err != nil { + return err + } + return s.enforceChunkCap(tx) + }) +} + +func (s *Store) RecordFinal(result FinalResult) error { + if err := validateFinalResult(result); err != nil { + return err + } + + return s.db.Transaction(func(tx *gorm.DB) error { + if err := s.rotateChunksForTerminal(tx, int64(len(result.Payload))); err != nil { + return err + } + + if err := s.upsertExecutionState(tx, taskExecutionRecord{ + TaskID: result.TaskID, + ExecutionAttemptID: result.ExecutionAttemptID, + Status: TaskStatusFinal, + ExitCode: result.ExitCode, + OutputTruncated: result.OutputTruncated, + ErrorTruncated: result.ErrorTruncated, + }); err != nil { + return err + } + + record := outboxMessageRecord{ + MessageID: result.MessageID, + TaskID: result.TaskID, + ExecutionAttemptID: result.ExecutionAttemptID, + Type: OutboxMessageTypeFinal, + Payload: result.Payload, + ByteCount: int64(len(result.Payload)), + } + return tx.Create(&record).Error + }) +} + +func (s *Store) ensureTerminalReserveAvailable(tx *gorm.DB) error { + used, err := s.totalUnackedBytes(tx) + if err != nil { + return err + } + if used+s.terminalReserveBytes > s.spoolCapBytes { + return ErrTerminalReserveUnavailable + } + return nil +} + +func (s *Store) enforceChunkCap(tx *gorm.DB) error { + chunkCap := s.spoolCapBytes - s.terminalReserveBytes + if chunkCap < 0 { + chunkCap = 0 + } + + for { + used, err := s.unackedBytesByType(tx, OutboxMessageTypeOutput) + if err != nil { + return err + } + if used <= chunkCap { + return nil + } + rotated, err := s.rotateOldestChunk(tx) + if err != nil { + return err + } + if !rotated { + return nil + } + } +} + +func (s *Store) rotateChunksForTerminal(tx *gorm.DB, terminalBytes int64) error { + for { + used, err := s.totalUnackedBytes(tx) + if err != nil { + return err + } + if used+terminalBytes <= s.spoolCapBytes { + return nil + } + rotated, err := s.rotateOldestChunk(tx) + if err != nil { + return err + } + if !rotated { + return ErrTerminalReserveUnavailable + } + } +} + +func (s *Store) rotateOldestChunk(tx *gorm.DB) (bool, error) { + var oldest outboxMessageRecord + err := tx.Where("acked_at IS NULL AND type = ?", OutboxMessageTypeOutput).Order("created_at ASC, id ASC").First(&oldest).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + return false, nil + } + if err != nil { + return false, err + } + + now := time.Now().UTC() + if err := tx.Model(&oldest).Update("acked_at", now).Error; err != nil { + return false, err + } + if err := s.markLocalOutputTruncated(tx, oldest.TaskID, oldest.ExecutionAttemptID); err != nil { + return false, err + } + return true, nil +} + +func (s *Store) markLocalOutputTruncated(tx *gorm.DB, taskID, executionAttemptID string) error { + return s.upsertExecutionState(tx, taskExecutionRecord{ + TaskID: taskID, + ExecutionAttemptID: executionAttemptID, + Status: TaskStatusRunning, + LocalOutputTruncated: true, + }) +} + +func (s *Store) totalUnackedBytes(tx *gorm.DB) (int64, error) { + var total int64 + if err := tx.Model(&outboxMessageRecord{}).Where("acked_at IS NULL").Select("COALESCE(SUM(byte_count), 0)").Scan(&total).Error; err != nil { + return 0, err + } + return total, nil +} + +func (s *Store) unackedBytesByType(tx *gorm.DB, messageType string) (int64, error) { + var total int64 + if err := tx.Model(&outboxMessageRecord{}).Where("acked_at IS NULL AND type = ?", messageType).Select("COALESCE(SUM(byte_count), 0)").Scan(&total).Error; err != nil { + return 0, err + } + return total, nil +} + +func (s *Store) UnackedMessages() ([]OutboxMessage, error) { + var records []outboxMessageRecord + if err := s.db.Where("acked_at IS NULL").Order("created_at ASC, id ASC").Find(&records).Error; err != nil { + return nil, fmt.Errorf("load unacked messages: %w", err) + } + + messages := make([]OutboxMessage, 0, len(records)) + for _, record := range records { + messages = append(messages, outboxMessageFromRecord(record)) + } + return messages, nil +} + +func (s *Store) AckMessage(messageID string) error { + if messageID == "" { + return fmt.Errorf("message ID is required") + } + now := time.Now().UTC() + if err := s.db.Model(&outboxMessageRecord{}).Where("message_id = ? AND acked_at IS NULL", messageID).Update("acked_at", now).Error; err != nil { + return fmt.Errorf("ack message: %w", err) + } + return nil +} + +func (s *Store) Snapshot() (Snapshot, error) { + var records []taskExecutionRecord + if err := s.db.Order("updated_at ASC, id ASC").Find(&records).Error; err != nil { + return Snapshot{}, fmt.Errorf("load task snapshot: %w", err) + } + + snapshot := Snapshot{Tasks: make([]TaskState, 0, len(records))} + for _, record := range records { + snapshot.Tasks = append(snapshot.Tasks, taskStateFromRecord(record)) + } + return snapshot, nil +} + +func (s *Store) MarkInterruptedRunningTasks() error { + return s.db.Transaction(func(tx *gorm.DB) error { + var running []taskExecutionRecord + if err := tx.Where("status = ?", TaskStatusRunning).Find(&running).Error; err != nil { + return err + } + + for _, record := range running { + if err := tx.Model(&record).Updates(map[string]any{ + "status": TaskStatusInterrupted, + "exit_code": -1, + }).Error; err != nil { + return err + } + + messageID := interruptedMessageID(record.TaskID, record.ExecutionAttemptID) + var existing outboxMessageRecord + err := tx.Where("message_id = ?", messageID).First(&existing).Error + if err == nil { + continue + } + if !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + + payload := fmt.Sprintf(`{"status":"interrupted","exit_code":-1,"output_truncated":%t,"error_truncated":%t}`, record.OutputTruncated, record.ErrorTruncated) + outbox := outboxMessageRecord{ + MessageID: messageID, + TaskID: record.TaskID, + ExecutionAttemptID: record.ExecutionAttemptID, + Type: OutboxMessageTypeFinal, + Payload: payload, + ByteCount: int64(len(payload)), + } + if err := tx.Create(&outbox).Error; err != nil { + return err + } + } + return nil + }) +} + +func (s *Store) upsertExecutionState(tx *gorm.DB, next taskExecutionRecord) error { + var existing taskExecutionRecord + err := tx.Where("task_id = ? AND execution_attempt_id = ?", next.TaskID, next.ExecutionAttemptID).First(&existing).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + return tx.Create(&next).Error + } + if err != nil { + return err + } + + updates := map[string]any{"status": next.Status} + if next.ExitCode != 0 { + updates["exit_code"] = next.ExitCode + } + if next.OutputTruncated { + updates["output_truncated"] = true + } + if next.ErrorTruncated { + updates["error_truncated"] = true + } + if next.LocalOutputTruncated { + updates["local_output_truncated"] = true + } + return tx.Model(&existing).Updates(updates).Error +} + +func validateOutputChunk(chunk OutputChunk) error { + if chunk.MessageID == "" { + return fmt.Errorf("message ID is required") + } + if chunk.TaskID == "" { + return fmt.Errorf("task ID is required") + } + if chunk.ExecutionAttemptID == "" { + return fmt.Errorf("execution attempt ID is required") + } + if chunk.Stream == "" { + return fmt.Errorf("stream is required") + } + if chunk.Sequence <= 0 { + return fmt.Errorf("sequence must be positive") + } + if chunk.ByteCount < 0 { + return fmt.Errorf("byte count cannot be negative") + } + return nil +} + +func validateFinalResult(result FinalResult) error { + if result.MessageID == "" { + return fmt.Errorf("message ID is required") + } + if result.TaskID == "" { + return fmt.Errorf("task ID is required") + } + if result.ExecutionAttemptID == "" { + return fmt.Errorf("execution attempt ID is required") + } + if result.Status == "" { + return fmt.Errorf("status is required") + } + return nil +} + +func taskStateFromRecord(record taskExecutionRecord) TaskState { + return TaskState{ + ID: record.ID, + Exists: true, + TaskID: record.TaskID, + ExecutionAttemptID: record.ExecutionAttemptID, + Status: record.Status, + ExitCode: record.ExitCode, + OutputTruncated: record.OutputTruncated, + ErrorTruncated: record.ErrorTruncated, + LocalOutputTruncated: record.LocalOutputTruncated, + } +} + +func outboxMessageFromRecord(record outboxMessageRecord) OutboxMessage { + return OutboxMessage{ + MessageID: record.MessageID, + TaskID: record.TaskID, + ExecutionAttemptID: record.ExecutionAttemptID, + Type: record.Type, + Stream: record.Stream, + Sequence: record.Sequence, + Payload: record.Payload, + ByteCount: record.ByteCount, + } +} + +func interruptedMessageID(taskID, executionAttemptID string) string { + return "local-interrupted-" + strings.NewReplacer("/", "-", " ", "-", "|", "-").Replace(taskID+"-"+executionAttemptID) +} diff --git a/app/services/localtaskstore/store_test.go b/app/services/localtaskstore/store_test.go new file mode 100644 index 0000000..9c62c3f --- /dev/null +++ b/app/services/localtaskstore/store_test.go @@ -0,0 +1,77 @@ +package localtaskstore + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" +) + +func TestNewCreatesStoreFileUnderConfiguredPath(t *testing.T) { + storePath := filepath.Join(t.TempDir(), "nested", "task_store.db") + + store, err := New(Config{ + Path: storePath, + SpoolCapBytes: 1024 * 1024, + TerminalReserveBytes: 1024, + }) + require.NoError(t, err) + require.NoError(t, store.Close()) + require.FileExists(t, storePath) +} + +func TestNewMigratesExecutionAndOutboxTables(t *testing.T) { + store := newTestStore(t, 1024*1024, 1024) + + require.True(t, store.db.Migrator().HasTable(&taskExecutionRecord{})) + require.True(t, store.db.Migrator().HasTable(&outboxMessageRecord{})) +} + +func TestNewReopensExistingStore(t *testing.T) { + storePath := filepath.Join(t.TempDir(), "task_store.db") + + store, err := New(Config{ + Path: storePath, + SpoolCapBytes: 1024 * 1024, + TerminalReserveBytes: 1024, + }) + require.NoError(t, err) + require.NoError(t, store.Close()) + + reopened, err := New(Config{ + Path: storePath, + SpoolCapBytes: 1024 * 1024, + TerminalReserveBytes: 1024, + }) + require.NoError(t, err) + require.NoError(t, reopened.Close()) +} + +func TestNewDefaultUsesConfiguredAppconfValues(t *testing.T) { + stateDir := t.TempDir() + t.Setenv("HOSTLINK_STATE_PATH", stateDir) + t.Setenv("HOSTLINK_LOCAL_STORE_PATH", "") + t.Setenv("HOSTLINK_LOCAL_STORE_SPOOL_CAP_BYTES", "1048576") + t.Setenv("HOSTLINK_LOCAL_STORE_TERMINAL_RESERVE_BYTES", "1024") + + store, err := NewDefault() + require.NoError(t, err) + require.NoError(t, store.Close()) + require.FileExists(t, filepath.Join(stateDir, "task_store.db")) +} + +func newTestStore(t *testing.T, spoolCapBytes, terminalReserveBytes int64) *Store { + t.Helper() + + store, err := New(Config{ + Path: filepath.Join(t.TempDir(), "task_store.db"), + SpoolCapBytes: spoolCapBytes, + TerminalReserveBytes: terminalReserveBytes, + }) + require.NoError(t, err) + t.Cleanup(func() { + require.NoError(t, store.Close()) + }) + + return store +} diff --git a/config/appconf/appconf.go b/config/appconf/appconf.go index 6eec133..4acae86 100644 --- a/config/appconf/appconf.go +++ b/config/appconf/appconf.go @@ -5,6 +5,8 @@ import ( "net/url" "os" "path" + "path/filepath" + "strconv" "strings" "time" @@ -58,6 +60,21 @@ func AgentStatePath() string { return "/var/lib/hostlink" } +func LocalTaskStorePath() string { + if path := strings.TrimSpace(os.Getenv("HOSTLINK_LOCAL_STORE_PATH")); path != "" { + return path + } + return filepath.Join(AgentStatePath(), "task_store.db") +} + +func LocalTaskStoreSpoolCapBytes() int64 { + return parseInt64Positive("HOSTLINK_LOCAL_STORE_SPOOL_CAP_BYTES", 64*1024*1024) +} + +func LocalTaskStoreTerminalReserveBytes() int64 { + return parseInt64Positive("HOSTLINK_LOCAL_STORE_TERMINAL_RESERVE_BYTES", 1024*1024) +} + // InstallPath returns the target install path for the hostlink binary. // Controlled by HOSTLINK_INSTALL_PATH (default: /usr/bin/hostlink). func InstallPath() string { @@ -184,6 +201,19 @@ func parseDurationClamped(envVar string, defaultVal, min, max time.Duration) tim return d } +func parseInt64Positive(envVar string, defaultVal int64) int64 { + v := strings.TrimSpace(os.Getenv(envVar)) + if v == "" { + return defaultVal + } + n, err := strconv.ParseInt(v, 10, 64) + if err != nil || n <= 0 { + log.Warnf("invalid %s value %q, using default %d", envVar, v, defaultVal) + return defaultVal + } + return n +} + func init() { env := os.Getenv("APP_ENV") diff --git a/config/appconf/appconf_test.go b/config/appconf/appconf_test.go index 030d3db..0110040 100644 --- a/config/appconf/appconf_test.go +++ b/config/appconf/appconf_test.go @@ -1,6 +1,7 @@ package appconf import ( + "path/filepath" "testing" "time" @@ -86,7 +87,6 @@ func TestInstallPath_CustomValue(t *testing.T) { t.Setenv("HOSTLINK_INSTALL_PATH", "/opt/hostlink/bin/hostlink") assert.Equal(t, "/opt/hostlink/bin/hostlink", InstallPath()) } - func TestWebSocketEnabled_DefaultFalse(t *testing.T) { t.Setenv("HOSTLINK_WS_ENABLED", "") assert.False(t, WebSocketEnabled()) @@ -147,3 +147,54 @@ func TestWebSocketPingInterval_CustomValue(t *testing.T) { t.Setenv("HOSTLINK_WS_PING_INTERVAL", "45s") assert.Equal(t, 45*time.Second, WebSocketPingInterval()) } + +func TestLocalTaskStorePath_DefaultUnderAgentStatePath(t *testing.T) { + stateDir := t.TempDir() + t.Setenv("HOSTLINK_STATE_PATH", stateDir) + t.Setenv("HOSTLINK_LOCAL_STORE_PATH", "") + + assert.Equal(t, filepath.Join(stateDir, "task_store.db"), LocalTaskStorePath()) +} + +func TestLocalTaskStorePath_CustomValue(t *testing.T) { + customPath := filepath.Join(t.TempDir(), "custom.db") + t.Setenv("HOSTLINK_LOCAL_STORE_PATH", customPath) + + assert.Equal(t, customPath, LocalTaskStorePath()) +} + +func TestLocalTaskStoreSpoolCapBytes_Default64MiB(t *testing.T) { + t.Setenv("HOSTLINK_LOCAL_STORE_SPOOL_CAP_BYTES", "") + + assert.Equal(t, int64(64*1024*1024), LocalTaskStoreSpoolCapBytes()) +} + +func TestLocalTaskStoreSpoolCapBytes_CustomValue(t *testing.T) { + t.Setenv("HOSTLINK_LOCAL_STORE_SPOOL_CAP_BYTES", "2048") + + assert.Equal(t, int64(2048), LocalTaskStoreSpoolCapBytes()) +} + +func TestLocalTaskStoreSpoolCapBytes_InvalidFallsToDefault(t *testing.T) { + t.Setenv("HOSTLINK_LOCAL_STORE_SPOOL_CAP_BYTES", "garbage") + + assert.Equal(t, int64(64*1024*1024), LocalTaskStoreSpoolCapBytes()) +} + +func TestLocalTaskStoreTerminalReserveBytes_Default1MiB(t *testing.T) { + t.Setenv("HOSTLINK_LOCAL_STORE_TERMINAL_RESERVE_BYTES", "") + + assert.Equal(t, int64(1024*1024), LocalTaskStoreTerminalReserveBytes()) +} + +func TestLocalTaskStoreTerminalReserveBytes_CustomValue(t *testing.T) { + t.Setenv("HOSTLINK_LOCAL_STORE_TERMINAL_RESERVE_BYTES", "4096") + + assert.Equal(t, int64(4096), LocalTaskStoreTerminalReserveBytes()) +} + +func TestLocalTaskStoreTerminalReserveBytes_InvalidFallsToDefault(t *testing.T) { + t.Setenv("HOSTLINK_LOCAL_STORE_TERMINAL_RESERVE_BYTES", "garbage") + + assert.Equal(t, int64(1024*1024), LocalTaskStoreTerminalReserveBytes()) +} diff --git a/local_task_store_startup_test.go b/local_task_store_startup_test.go new file mode 100644 index 0000000..094f29e --- /dev/null +++ b/local_task_store_startup_test.go @@ -0,0 +1,49 @@ +package main + +import ( + "path/filepath" + "testing" + + "github.com/stretchr/testify/require" + + "hostlink/app/services/localtaskstore" +) + +func TestRecoverLocalTaskStoreInitializesStoreUnderStatePath(t *testing.T) { + stateDir := t.TempDir() + t.Setenv("HOSTLINK_STATE_PATH", stateDir) + t.Setenv("HOSTLINK_LOCAL_STORE_PATH", "") + t.Setenv("HOSTLINK_LOCAL_STORE_SPOOL_CAP_BYTES", "1048576") + t.Setenv("HOSTLINK_LOCAL_STORE_TERMINAL_RESERVE_BYTES", "1024") + + store, err := recoverLocalTaskStore() + require.NoError(t, err) + require.NoError(t, store.Close()) + require.FileExists(t, filepath.Join(stateDir, "task_store.db")) +} + +func TestRecoverLocalTaskStoreMarksRunningTasksInterrupted(t *testing.T) { + stateDir := t.TempDir() + storePath := filepath.Join(stateDir, "task_store.db") + t.Setenv("HOSTLINK_STATE_PATH", stateDir) + t.Setenv("HOSTLINK_LOCAL_STORE_PATH", "") + t.Setenv("HOSTLINK_LOCAL_STORE_SPOOL_CAP_BYTES", "1048576") + t.Setenv("HOSTLINK_LOCAL_STORE_TERMINAL_RESERVE_BYTES", "1024") + + store, err := localtaskstore.New(localtaskstore.Config{ + Path: storePath, + SpoolCapBytes: 1024 * 1024, + TerminalReserveBytes: 1024, + }) + require.NoError(t, err) + require.NoError(t, store.RecordStarted("task-1", "attempt-1")) + require.NoError(t, store.Close()) + + recovered, err := recoverLocalTaskStore() + require.NoError(t, err) + defer recovered.Close() + + state, err := recovered.TaskState("task-1", "attempt-1") + require.NoError(t, err) + require.Equal(t, localtaskstore.TaskStatusInterrupted, state.Status) +} diff --git a/main.go b/main.go index 7cde99b..d0c7b9e 100644 --- a/main.go +++ b/main.go @@ -11,6 +11,7 @@ import ( "hostlink/app/jobs/taskjob" "hostlink/app/services/agentstate" "hostlink/app/services/heartbeat" + "hostlink/app/services/localtaskstore" "hostlink/app/services/metrics" "hostlink/app/services/requestsigner" "hostlink/app/services/taskfetcher" @@ -251,6 +252,13 @@ func runServer(ctx context.Context, cmd *cli.Command) error { // Agent-related jobs run in goroutine after registration go func() { ctx := context.Background() + localStore, err := recoverLocalTaskStore() + if err != nil { + log.Printf("failed to initialize local task store: %v", err) + } else { + defer localStore.Close() + } + registeredChan := make(chan bool, 1) registrationJob := registrationjob.New() @@ -335,6 +343,19 @@ func newDefaultWebSocketRuntime() (webSocketRuntime, error) { }) } +func recoverLocalTaskStore() (*localtaskstore.Store, error) { + store, err := localtaskstore.NewDefault() + if err != nil { + return nil, err + } + + if err := store.MarkInterruptedRunningTasks(); err != nil { + _ = store.Close() + return nil, err + } + return store, nil +} + func startSelfUpdateJob(ctx context.Context) { paths := update.DefaultPaths()