Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions pkg/runner/input.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,15 @@ func (r *Runner) handleInput(callCtx engine.Context, monitor Monitor, env []stri
data := map[string]any{}
_ = json.Unmarshal([]byte(input), &data)
data["input"] = input
inputData, err := json.Marshal(data)

inputArgs, err := argsForFilters(callCtx.Program, inputToolRef, &State{
StartInput: &input,
}, data)
if err != nil {
return "", fmt.Errorf("failed to marshal input: %w", err)
}

res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, inputToolRef.ToolID, string(inputData), "", engine.InputToolCategory)
res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, inputToolRef.ToolID, inputArgs, "", engine.InputToolCategory)
if err != nil {
return "", err
}
Expand Down
40 changes: 38 additions & 2 deletions pkg/runner/output.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,48 @@ import (
"encoding/json"
"errors"
"fmt"
"maps"
"strings"

"github.com/gptscript-ai/gptscript/pkg/engine"
"github.com/gptscript-ai/gptscript/pkg/types"
)

func (r *Runner) handleOutput(callCtx engine.Context, monitor Monitor, env []string, state *State, retErr error) (*State, error) {
func argsForFilters(prg *types.Program, tool types.ToolReference, startState *State, filterDefinedInput map[string]any) (string, error) {
startInput := ""
if startState.ResumeInput != nil {
startInput = *startState.ResumeInput
} else if startState.StartInput != nil {
startInput = *startState.StartInput
}

parsedArgs, err := getToolRefInput(prg, tool, startInput)
if err != nil {
return "", err
}

argData := map[string]any{}
if strings.HasPrefix(parsedArgs, "{") {
if err := json.Unmarshal([]byte(parsedArgs), &argData); err != nil {
return "", fmt.Errorf("failed to unmarshal parsedArgs for filter: %w", err)
}
} else if _, hasInput := filterDefinedInput["input"]; parsedArgs != "" && !hasInput {
argData["input"] = parsedArgs
}

resultData := map[string]any{}
maps.Copy(resultData, filterDefinedInput)
maps.Copy(resultData, argData)

result, err := json.Marshal(resultData)
if err != nil {
return "", fmt.Errorf("failed to marshal resultData for filter: %w", err)
}

return string(result), nil
}

func (r *Runner) handleOutput(callCtx engine.Context, monitor Monitor, env []string, startState, state *State, retErr error) (*State, error) {
outputToolRefs, err := callCtx.Tool.GetToolsByType(callCtx.Program, types.ToolTypeOutput)
if err != nil {
return nil, err
Expand Down Expand Up @@ -40,7 +76,7 @@ func (r *Runner) handleOutput(callCtx engine.Context, monitor Monitor, env []str
}

for _, outputToolRef := range outputToolRefs {
inputData, err := json.Marshal(map[string]any{
inputData, err := argsForFilters(callCtx.Program, outputToolRef, startState, map[string]any{
"output": output,
"continuation": continuation,
"chat": callCtx.Tool.Chat,
Expand Down
16 changes: 9 additions & 7 deletions pkg/runner/runner.go
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,9 @@ func getToolRefInput(prg *types.Program, ref types.ToolReference, input string)
outputMap := map[string]interface{}{}

_ = json.Unmarshal([]byte(input), &inputMap)
for k, v := range inputMap {
inputMap[strings.ToLower(k)] = v
}

fields := strings.Fields(ref.Arg)

Expand All @@ -291,7 +294,7 @@ func getToolRefInput(prg *types.Program, ref types.ToolReference, input string)
key := strings.TrimPrefix(field, "$")
key = strings.TrimPrefix(key, "{")
key = strings.TrimSuffix(key, "}")
val = inputMap[key]
val = inputMap[strings.ToLower(key)]
} else {
val = field
}
Expand Down Expand Up @@ -425,6 +428,7 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
msg = "Tool call request has been denied"
}
return &State{
StartInput: &input,
Continuation: &engine.Return{
Result: &msg,
},
Expand All @@ -438,6 +442,7 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
}

return &State{
StartInput: &input,
Continuation: ret,
}, nil
}
Expand All @@ -447,6 +452,8 @@ type State struct {
ContinuationToolID string `json:"continuationToolID,omitempty"`
Result *string `json:"result,omitempty"`

StartInput *string `json:"startInput,omitempty"`

ResumeInput *string `json:"resumeInput,omitempty"`
SubCalls []SubCallResult `json:"subCalls,omitempty"`
SubCallID string `json:"subCallID,omitempty"`
Expand Down Expand Up @@ -485,14 +492,9 @@ func (s State) ContinuationContent() (string, error) {
return "", fmt.Errorf("illegal state: no result message found in chat response")
}

type Needed struct {
Content string `json:"content,omitempty"`
Input string `json:"input,omitempty"`
}

func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, state *State) (retState *State, retErr error) {
defer func() {
retState, retErr = r.handleOutput(callCtx, monitor, env, retState, retErr)
retState, retErr = r.handleOutput(callCtx, monitor, env, state, retState, retErr)
}()

if state.Continuation == nil {
Expand Down
91 changes: 91 additions & 0 deletions pkg/tests/runner2_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ package tests

import (
"context"
"encoding/json"
"testing"

"github.com/gptscript-ai/gptscript/pkg/loader"
"github.com/gptscript-ai/gptscript/pkg/tests/tester"
"github.com/hexops/autogold/v2"
"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -111,3 +113,92 @@ echo '{"env": {"CRED2": "that also worked"}}'
resp, err := r.Chat(context.Background(), nil, prg, nil, "")
r.AssertStep(t, resp, err)
}

func TestFilterArgs(t *testing.T) {
r := tester.NewRunner(t)
prg, err := loader.ProgramFromSource(context.Background(), `
inputfilters: input with ${Foo}
inputfilters: input with foo
inputfilters: input with *
outputfilters: output with *
outputfilters: output with foo
outputfilters: output with ${Foo}
params: Foo: a description

#!/bin/bash
echo ${FOO}

---
name: input
params: notfoo: a description

#!/bin/bash
echo "${GPTSCRIPT_INPUT}"

---
name: output
params: notfoo: a description

#!/bin/bash
echo "${GPTSCRIPT_INPUT}"
`, "")
require.NoError(t, err)

resp, err := r.Chat(context.Background(), nil, prg, nil, `{"foo":"baz", "start": true}`)
r.AssertStep(t, resp, err)

data := map[string]any{}
err = json.Unmarshal([]byte(resp.Content), &data)
require.NoError(t, err)

autogold.Expect(map[string]interface{}{
"chat": false,
"continuation": false,
"notfoo": "baz",
"output": `{"chat":false,"continuation":false,"notfoo":"foo","output":"{\"chat\":false,\"continuation\":false,\"foo\":\"baz\",\"input\":\"{\\\"foo\\\":\\\"baz\\\",\\\"input\\\":\\\"{\\\\\\\"foo\\\\\\\":\\\\\\\"baz\\\\\\\", \\\\\\\"start\\\\\\\": true}\\\",\\\"notfoo\\\":\\\"baz\\\",\\\"start\\\":true}\\n\",\"notfoo\":\"foo\",\"output\":\"baz\\n\",\"start\":true}\n"}
`,
}).Equal(t, data)

val := data["output"].(string)
data = map[string]any{}
err = json.Unmarshal([]byte(val), &data)
require.NoError(t, err)
autogold.Expect(map[string]interface{}{
"chat": false,
"continuation": false,
"notfoo": "foo",
"output": `{"chat":false,"continuation":false,"foo":"baz","input":"{\"foo\":\"baz\",\"input\":\"{\\\"foo\\\":\\\"baz\\\", \\\"start\\\": true}\",\"notfoo\":\"baz\",\"start\":true}\n","notfoo":"foo","output":"baz\n","start":true}
`,
}).Equal(t, data)

val = data["output"].(string)
data = map[string]any{}
err = json.Unmarshal([]byte(val), &data)
require.NoError(t, err)
autogold.Expect(map[string]interface{}{
"chat": false,
"continuation": false,
"foo": "baz", "input": `{"foo":"baz","input":"{\"foo\":\"baz\", \"start\": true}","notfoo":"baz","start":true}
`,
"notfoo": "foo",
"output": "baz\n",
"start": true,
}).Equal(t, data)

val = data["input"].(string)
data = map[string]any{}
err = json.Unmarshal([]byte(val), &data)
require.NoError(t, err)
autogold.Expect(map[string]interface{}{
"foo": "baz",
"input": `{"foo":"baz", "start": true}`,
"notfoo": "baz",
"start": true,
}).Equal(t, data)

val = data["input"].(string)
data = map[string]any{}
err = json.Unmarshal([]byte(val), &data)
require.NoError(t, err)
autogold.Expect(map[string]interface{}{"foo": "baz", "start": true}).Equal(t, data)
}
6 changes: 6 additions & 0 deletions pkg/tests/testdata/TestFilterArgs/step1.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
`{
"done": true,
"content": "{\"chat\":false,\"continuation\":false,\"notfoo\":\"baz\",\"output\":\"{\\\"chat\\\":false,\\\"continuation\\\":false,\\\"notfoo\\\":\\\"foo\\\",\\\"output\\\":\\\"{\\\\\\\"chat\\\\\\\":false,\\\\\\\"continuation\\\\\\\":false,\\\\\\\"foo\\\\\\\":\\\\\\\"baz\\\\\\\",\\\\\\\"input\\\\\\\":\\\\\\\"{\\\\\\\\\\\\\\\"foo\\\\\\\\\\\\\\\":\\\\\\\\\\\\\\\"baz\\\\\\\\\\\\\\\",\\\\\\\\\\\\\\\"input\\\\\\\\\\\\\\\":\\\\\\\\\\\\\\\"{\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\"foo\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\":\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\"baz\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\", \\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\"start\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\\": true}\\\\\\\\\\\\\\\",\\\\\\\\\\\\\\\"notfoo\\\\\\\\\\\\\\\":\\\\\\\\\\\\\\\"baz\\\\\\\\\\\\\\\",\\\\\\\\\\\\\\\"start\\\\\\\\\\\\\\\":true}\\\\\\\\n\\\\\\\",\\\\\\\"notfoo\\\\\\\":\\\\\\\"foo\\\\\\\",\\\\\\\"output\\\\\\\":\\\\\\\"baz\\\\\\\\n\\\\\\\",\\\\\\\"start\\\\\\\":true}\\\\n\\\"}\\n\"}\n",
"toolID": "",
"state": null
}`