Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions app/jobs/taskjob/result_channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,12 @@ func TestTaskJobStreamsOutputAndFinalOverResultChannel(t *testing.T) {

job.processTask(context.Background(), fetcher.tasks[0], reporter, channel)

if len(channel.started) != 1 {
t.Fatalf("started len = %d, want 1", len(channel.started))
}
if channel.started[0].TaskID != "task-1" || channel.started[0].ExecutionAttemptID != "attempt-1" {
t.Fatalf("started = %#v", channel.started[0])
}
if len(channel.outputs) != 2 {
t.Fatalf("outputs len = %d, want 2", len(channel.outputs))
}
Expand Down Expand Up @@ -189,12 +195,21 @@ func (f *fakeTaskReporter) Report(taskID string, result *taskreporter.TaskResult

type fakeResultChannel struct {
mu sync.Mutex
started []localtaskstore.TaskReceipt
outputs []localtaskstore.OutputChunk
finals []localtaskstore.FinalResult
startedErr error
outputErrs []error
finalErr error
}

func (f *fakeResultChannel) SendStarted(ctx context.Context, receipt localtaskstore.TaskReceipt) error {
f.mu.Lock()
defer f.mu.Unlock()
f.started = append(f.started, receipt)
return f.startedErr
}

func (f *fakeResultChannel) SendOutput(ctx context.Context, chunk localtaskstore.OutputChunk) error {
f.mu.Lock()
defer f.mu.Unlock()
Expand Down
37 changes: 33 additions & 4 deletions app/jobs/taskjob/taskjob.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,16 @@ type TaskJobConfig struct {
}

type ResultChannel interface {
SendStarted(context.Context, localtaskstore.TaskReceipt) error
SendOutput(context.Context, localtaskstore.OutputChunk) error
SendFinal(context.Context, localtaskstore.FinalResult) error
}

type TaskJob struct {
config TaskJobConfig
cancel context.CancelFunc
wg sync.WaitGroup
config TaskJobConfig
enqueueCh chan task.Task
cancel context.CancelFunc
wg sync.WaitGroup
}

func New() *TaskJob {
Expand All @@ -59,7 +61,8 @@ func NewJobWithConf(cfg TaskJobConfig) *TaskJob {
}

return &TaskJob{
config: cfg,
config: cfg,
enqueueCh: make(chan task.Task, 16),
}
}

Expand All @@ -71,6 +74,18 @@ func (tj *TaskJob) Register(ctx context.Context, tf taskfetcher.TaskFetcher, tr
channel = channels[0]
}
tj.wg.Add(1)
go func() {
defer tj.wg.Done()
for {
select {
case queued := <-tj.enqueueCh:
tj.processTask(ctx, queued, tr, channel)
case <-ctx.Done():
return
}
}
}()
tj.wg.Add(1)
go func() {
defer tj.wg.Done()
tj.config.Trigger(ctx, func() error {
Expand All @@ -93,6 +108,15 @@ func (tj *TaskJob) Register(ctx context.Context, tf taskfetcher.TaskFetcher, tr
return cancel
}

func (tj *TaskJob) Enqueue(ctx context.Context, t task.Task) error {
select {
case tj.enqueueCh <- t:
return nil
case <-ctx.Done():
return ctx.Err()
}
}

func (tj *TaskJob) processTask(ctx context.Context, t task.Task, tr taskreporter.TaskReporter, channel ResultChannel) {
tempFile, err := os.CreateTemp("", "*_script.sh")
if err != nil {
Expand Down Expand Up @@ -184,6 +208,11 @@ func (tj *TaskJob) processTaskWithResultChannel(ctx context.Context, t task.Task
tj.reportHTTPResult(t, tr, "failed", "", err.Error(), 1)
return
}
if err := channel.SendStarted(ctx, localtaskstore.TaskReceipt{TaskID: t.ID, ExecutionAttemptID: t.ExecutionAttemptID}); err != nil {
_ = execCmd.Process.Kill()
tj.reportHTTPResult(t, tr, "failed", "", fmt.Sprintf("failed to report task start: %v", err), 1)
return
}

var stdoutBuf bytes.Buffer
var stderrBuf bytes.Buffer
Expand Down
1 change: 1 addition & 0 deletions app/services/localtaskstore/store.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ type Config struct {

type ReceiptStore interface {
RecordReceived(TaskReceipt) (TaskState, error)
RecordStarted(taskID, executionAttemptID string) error
TaskState(taskID, executionAttemptID string) (TaskState, error)
}

Expand Down
116 changes: 110 additions & 6 deletions app/services/wsclient/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (

"hostlink/app/services/agentstate"
"hostlink/app/services/requestsigner"
"hostlink/domain/task"
"hostlink/internal/wsprotocol"
)

Expand All @@ -31,6 +32,10 @@ type Conn interface {

type SleepFunc func(context.Context, time.Duration) error

type TaskEnqueuer interface {
Enqueue(context.Context, task.Task) error
}

type Config struct {
URL string
AgentState *agentstate.AgentState
Expand All @@ -41,6 +46,8 @@ type Config struct {
PingInterval time.Duration
SleepFunc SleepFunc
ResultOutbox localtaskstore.ResultOutbox
ReceiptStore localtaskstore.ReceiptStore
TaskEnqueuer TaskEnqueuer
}

type Client struct {
Expand All @@ -53,12 +60,14 @@ type Client struct {
pingInterval time.Duration
sleep SleepFunc

mu sync.RWMutex
writeMu sync.Mutex
active bool
lastAck *wsprotocol.AckPayload
conn Conn
outbox localtaskstore.ResultOutbox
mu sync.RWMutex
writeMu sync.Mutex
active bool
lastAck *wsprotocol.AckPayload
conn Conn
outbox localtaskstore.ResultOutbox
receipts localtaskstore.ReceiptStore
enqueuer TaskEnqueuer
}

func New(cfg Config) (*Client, error) {
Expand Down Expand Up @@ -99,6 +108,8 @@ func New(cfg Config) (*Client, error) {
pingInterval: cfg.PingInterval,
sleep: cfg.SleepFunc,
outbox: cfg.ResultOutbox,
receipts: cfg.ReceiptStore,
enqueuer: cfg.TaskEnqueuer,
}, nil
}

Expand Down Expand Up @@ -234,12 +245,82 @@ func (c *Client) readLoop(ctx context.Context, conn Conn, helloMessageID string)
continue
}
return fmt.Errorf("websocket protocol error: %s", env.MessageID)
case wsprotocol.TypeTaskDeliver:
if err := c.receiveTaskDeliver(ctx, conn, env); err != nil {
return err
}
default:
return fmt.Errorf("unsupported inbound websocket message type: %s", env.Type)
}
}
}

func (c *Client) receiveTaskDeliver(ctx context.Context, conn Conn, env wsprotocol.Envelope) error {
if c.receipts == nil {
return fmt.Errorf("receipt store is not configured")
}
payload, err := wsprotocol.DecodePayload[wsprotocol.TaskDeliverPayload](env)
if err != nil {
return err
}
if err := payload.Validate(); err != nil {
return err
}

previous, err := c.receipts.TaskState(env.TaskID, env.ExecutionAttemptID)
if err != nil {
return err
}
state, err := c.receipts.RecordReceived(localtaskstore.TaskReceipt{
TaskID: env.TaskID,
ExecutionAttemptID: env.ExecutionAttemptID,
})
if err != nil {
return err
}
if previous.Exists && previous.Status == localtaskstore.TaskStatusRunning {
return c.writeEnvelope(ctx, conn, c.buildTaskStateEnvelope(wsprotocol.TypeTaskStarted, env.TaskID, env.ExecutionAttemptID))
}
if previous.Exists && (previous.Status == localtaskstore.TaskStatusFinal || previous.Status == localtaskstore.TaskStatusInterrupted) {
replayed, err := c.replayTaskFinal(ctx, conn, env.TaskID, env.ExecutionAttemptID)
if err != nil {
return err
}
if replayed {
return nil
}
}
if err := c.writeEnvelope(ctx, conn, c.buildTaskStateEnvelope(wsprotocol.TypeTaskReceived, env.TaskID, env.ExecutionAttemptID)); err != nil {
return err
}
if !previous.Exists && state.Status == localtaskstore.TaskStatusReceived && c.enqueuer != nil {
return c.enqueuer.Enqueue(ctx, task.Task{
ID: env.TaskID,
ExecutionAttemptID: env.ExecutionAttemptID,
Command: payload.Command,
Status: "pending",
Priority: payload.Priority,
})
}
return nil
}

func (c *Client) replayTaskFinal(ctx context.Context, conn Conn, taskID, executionAttemptID string) (bool, error) {
if c.outbox == nil {
return false, nil
}
messages, err := c.outbox.UnackedMessages()
if err != nil {
return false, err
}
for _, message := range messages {
if message.TaskID == taskID && message.ExecutionAttemptID == executionAttemptID && message.Type == localtaskstore.OutboxMessageTypeFinal {
return true, c.writeEnvelope(ctx, conn, envelopeFromOutboxMessage(c.agentID, message))
}
}
return false, nil
}

func (c *Client) SendOutput(ctx context.Context, chunk localtaskstore.OutputChunk) error {
if c.outbox == nil {
return fmt.Errorf("result outbox is not configured")
Expand Down Expand Up @@ -283,6 +364,16 @@ func (c *Client) SendFinal(ctx context.Context, result localtaskstore.FinalResul
return nil
}

func (c *Client) SendStarted(ctx context.Context, receipt localtaskstore.TaskReceipt) error {
if c.receipts == nil {
return fmt.Errorf("receipt store cannot record started state")
}
if err := c.receipts.RecordStarted(receipt.TaskID, receipt.ExecutionAttemptID); err != nil {
return err
}
return c.sendIfActive(ctx, c.buildTaskStateEnvelope(wsprotocol.TypeTaskStarted, receipt.TaskID, receipt.ExecutionAttemptID))
}

func (c *Client) buildHello() wsprotocol.Envelope {
return wsprotocol.Envelope{
ProtocolVersion: wsprotocol.ProtocolVersion,
Expand All @@ -294,6 +385,19 @@ func (c *Client) buildHello() wsprotocol.Envelope {
}
}

func (c *Client) buildTaskStateEnvelope(messageType wsprotocol.MessageType, taskID, executionAttemptID string) wsprotocol.Envelope {
return wsprotocol.Envelope{
ProtocolVersion: wsprotocol.ProtocolVersion,
MessageID: fmt.Sprintf("msg_%s_%s_%d", taskID, messageType, time.Now().UnixNano()),
Type: messageType,
AgentID: c.agentID,
TaskID: taskID,
ExecutionAttemptID: executionAttemptID,
SentAt: time.Now().UTC().Format(time.RFC3339),
Payload: map[string]any{},
}
}

func (c *Client) setActive(active bool) {
c.mu.Lock()
defer c.mu.Unlock()
Expand Down
Loading
Loading