From 485332c655dd5d3a02782425f61ba7adec4ec96c Mon Sep 17 00:00:00 2001 From: Donnie Adams Date: Mon, 4 Nov 2024 07:35:04 -0500 Subject: [PATCH] feat: add ability to pass request-specific env vars to chat completion This will allow authentication per-request in model providers. Signed-off-by: Donnie Adams --- go.mod | 2 +- go.sum | 4 ++-- pkg/context/context.go | 11 ----------- pkg/engine/engine.go | 5 ++--- pkg/llm/proxy.go | 4 ++-- pkg/llm/registry.go | 16 ++++++++-------- pkg/openai/client.go | 39 ++++++++++++++++++++++++++------------ pkg/remote/remote.go | 19 +++++++++---------- pkg/runner/output.go | 2 +- pkg/tests/judge/judge.go | 2 +- pkg/tests/tester/runner.go | 2 +- 11 files changed, 54 insertions(+), 52 deletions(-) diff --git a/go.mod b/go.mod index 4a95a521..5fa1a5c8 100644 --- a/go.mod +++ b/go.mod @@ -15,7 +15,7 @@ require ( github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/google/uuid v1.6.0 github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86 - github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3 + github.com/gptscript-ai/chat-completion-client v0.0.0-20241104122544-5fe75f07c131 github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb github.com/gptscript-ai/go-gptscript v0.9.5-rc5.0.20240927213153-2af51434b93e github.com/gptscript-ai/tui v0.0.0-20240923192013-172e51ccf1d6 diff --git a/go.sum b/go.sum index 80cbcea1..3661a6c6 100644 --- a/go.sum +++ b/go.sum @@ -200,8 +200,8 @@ github.com/gorilla/css v1.0.0 h1:BQqNyPTi50JCFMTw/b67hByjMVXZRwGha6wxVGkeihY= github.com/gorilla/css v1.0.0/go.mod h1:Dn721qIggHpt4+EFCcTLTU/vk5ySda2ReITrtgBl60c= github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86 h1:m9yLtIEd0z1ia8qFjq3u0Ozb6QKwidyL856JLJp6nbA= github.com/gptscript-ai/broadcaster v0.0.0-20240625175512-c43682019b86/go.mod h1:lK3K5EZx4dyT24UG3yCt0wmspkYqrj4D/8kxdN3relk= -github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3 h1:EQiFTZv+BnOWJX2B9XdF09fL2Zj7h19n1l23TpWCafc= -github.com/gptscript-ai/chat-completion-client v0.0.0-20240813051153-a440ada7e3c3/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo= +github.com/gptscript-ai/chat-completion-client v0.0.0-20241104122544-5fe75f07c131 h1:y2FcmT4X8U606gUS0teX5+JWX9K/NclsLEhHiyrd+EU= +github.com/gptscript-ai/chat-completion-client v0.0.0-20241104122544-5fe75f07c131/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo= github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb h1:ky2J2CzBOskC7Jgm2VJAQi2x3p7FVGa+2/PcywkFJuc= github.com/gptscript-ai/cmd v0.0.0-20240802230653-326b7baf6fcb/go.mod h1:DJAo1xTht1LDkNYFNydVjTHd576TC7MlpsVRl3oloVw= github.com/gptscript-ai/go-gptscript v0.9.5-rc5.0.20240927213153-2af51434b93e h1:WpNae0NBx+Ri8RB3SxF8DhadDKU7h+jfWPQterDpbJA= diff --git a/pkg/context/context.go b/pkg/context/context.go index 0169d0e0..31474f6c 100644 --- a/pkg/context/context.go +++ b/pkg/context/context.go @@ -46,14 +46,3 @@ func GetLogger(ctx context.Context) mvl.Logger { return l } - -type envKey struct{} - -func WithEnv(ctx context.Context, env []string) context.Context { - return context.WithValue(ctx, envKey{}, env) -} - -func GetEnv(ctx context.Context) []string { - l, _ := ctx.Value(envKey{}).([]string) - return l -} diff --git a/pkg/engine/engine.go b/pkg/engine/engine.go index 0665991c..44ed50bb 100644 --- a/pkg/engine/engine.go +++ b/pkg/engine/engine.go @@ -8,14 +8,13 @@ import ( "sync" "github.com/gptscript-ai/gptscript/pkg/config" - gcontext "github.com/gptscript-ai/gptscript/pkg/context" "github.com/gptscript-ai/gptscript/pkg/counter" "github.com/gptscript-ai/gptscript/pkg/types" "github.com/gptscript-ai/gptscript/pkg/version" ) type Model interface { - Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) + Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) ProxyInfo() (string, string, error) } @@ -389,7 +388,7 @@ func (e *Engine) complete(ctx context.Context, state *State) (*Return, error) { } }() - resp, err := e.Model.Call(gcontext.WithEnv(ctx, e.Env), state.Completion, progress) + resp, err := e.Model.Call(ctx, state.Completion, e.Env, progress) if err != nil { return nil, err } diff --git a/pkg/llm/proxy.go b/pkg/llm/proxy.go index 7c3091b3..aa8802be 100644 --- a/pkg/llm/proxy.go +++ b/pkg/llm/proxy.go @@ -54,7 +54,7 @@ func (r *Registry) ServeHTTP(w http.ResponseWriter, req *http.Request) { var ( model string - data = map[string]any{} + data map[string]any ) if json.Unmarshal(inBytes, &data) == nil { @@ -65,7 +65,7 @@ func (r *Registry) ServeHTTP(w http.ResponseWriter, req *http.Request) { model = builtin.GetDefaultModel() } - c, err := r.getClient(req.Context(), model) + c, err := r.getClient(req.Context(), model, nil) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) return diff --git a/pkg/llm/registry.go b/pkg/llm/registry.go index 8129c788..09fe1dce 100644 --- a/pkg/llm/registry.go +++ b/pkg/llm/registry.go @@ -15,7 +15,7 @@ import ( ) type Client interface { - Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) + Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) ListModels(ctx context.Context, providers ...string) (result []string, _ error) Supports(ctx context.Context, modelName string) (bool, error) } @@ -78,7 +78,7 @@ func (r *Registry) fastPath(modelName string) Client { return r.clients[0] } -func (r *Registry) getClient(ctx context.Context, modelName string) (Client, error) { +func (r *Registry) getClient(ctx context.Context, modelName string, env []string) (Client, error) { if c := r.fastPath(modelName); c != nil { return c, nil } @@ -101,7 +101,7 @@ func (r *Registry) getClient(ctx context.Context, modelName string) (Client, err if len(errs) > 0 && oaiClient != nil { // Prompt the user to enter their OpenAI API key and try again. - if err := oaiClient.RetrieveAPIKey(ctx); err != nil { + if err := oaiClient.RetrieveAPIKey(ctx, env); err != nil { return nil, err } ok, err := oaiClient.Supports(ctx, modelName) @@ -119,13 +119,13 @@ func (r *Registry) getClient(ctx context.Context, modelName string) (Client, err return nil, errors.Join(errs...) } -func (r *Registry) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) { +func (r *Registry) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) { if messageRequest.Model == "" { return nil, fmt.Errorf("model is required") } if c := r.fastPath(messageRequest.Model); c != nil { - return c.Call(ctx, messageRequest, status) + return c.Call(ctx, messageRequest, env, status) } var errs []error @@ -140,20 +140,20 @@ func (r *Registry) Call(ctx context.Context, messageRequest types.CompletionRequ errs = append(errs, err) } else if ok { - return client.Call(ctx, messageRequest, status) + return client.Call(ctx, messageRequest, env, status) } } if len(errs) > 0 && oaiClient != nil { // Prompt the user to enter their OpenAI API key and try again. - if err := oaiClient.RetrieveAPIKey(ctx); err != nil { + if err := oaiClient.RetrieveAPIKey(ctx, env); err != nil { return nil, err } ok, err := oaiClient.Supports(ctx, messageRequest.Model) if err != nil { return nil, err } else if ok { - return oaiClient.Call(ctx, messageRequest, status) + return oaiClient.Call(ctx, messageRequest, env, status) } } diff --git a/pkg/openai/client.go b/pkg/openai/client.go index be5c6253..6178c997 100644 --- a/pkg/openai/client.go +++ b/pkg/openai/client.go @@ -13,7 +13,6 @@ import ( openai "github.com/gptscript-ai/chat-completion-client" "github.com/gptscript-ai/gptscript/pkg/cache" - gcontext "github.com/gptscript-ai/gptscript/pkg/context" "github.com/gptscript-ai/gptscript/pkg/counter" "github.com/gptscript-ai/gptscript/pkg/credentials" "github.com/gptscript-ai/gptscript/pkg/hash" @@ -303,9 +302,9 @@ func toMessages(request types.CompletionRequest, compat bool) (result []openai.C return } -func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) { +func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) { if err := c.ValidAuth(); err != nil { - if err := c.RetrieveAPIKey(ctx); err != nil { + if err := c.RetrieveAPIKey(ctx, env); err != nil { return nil, err } } @@ -401,7 +400,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques if err != nil { return nil, err } else if !ok { - result, err = c.call(ctx, request, id, status) + result, err = c.call(ctx, request, id, env, status) // If we got back a context length exceeded error, keep retrying and shrinking the message history until we pass. var apiError *openai.APIError @@ -409,7 +408,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques // Decrease maxTokens by 10% to make garbage collection more aggressive. // The retry loop will further decrease maxTokens if needed. maxTokens := decreaseTenPercent(messageRequest.MaxTokens) - result, err = c.contextLimitRetryLoop(ctx, request, id, maxTokens, status) + result, err = c.contextLimitRetryLoop(ctx, request, id, env, maxTokens, status) } if err != nil { return nil, err @@ -443,7 +442,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques return &result, nil } -func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, maxTokens int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) { +func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatCompletionRequest, id string, env []string, maxTokens int, status chan<- types.CompletionStatus) (types.CompletionMessage, error) { var ( response types.CompletionMessage err error @@ -452,7 +451,7 @@ func (c *Client) contextLimitRetryLoop(ctx context.Context, request openai.ChatC for range 10 { // maximum 10 tries // Try to drop older messages again, with a decreased max tokens. request.Messages = dropMessagesOverCount(maxTokens, request.Messages) - response, err = c.call(ctx, request, id, status) + response, err = c.call(ctx, request, id, env, status) if err == nil { return response, nil } @@ -542,7 +541,7 @@ func override(left, right string) string { return left } -func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, transactionID string, partial chan<- types.CompletionStatus) (types.CompletionMessage, error) { +func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, transactionID string, env []string, partial chan<- types.CompletionStatus) (types.CompletionMessage, error) { streamResponse := os.Getenv("GPTSCRIPT_INTERNAL_OPENAI_STREAMING") != "false" partial <- types.CompletionStatus{ @@ -553,11 +552,27 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, }, } + var ( + headers map[string]string + modelProviderEnv []string + ) + for _, e := range env { + if strings.HasPrefix(e, "GPTSCRIPT_MODEL_PROVIDER_") { + modelProviderEnv = append(modelProviderEnv, e) + } + } + + if len(modelProviderEnv) > 0 { + headers = map[string]string{ + "X-GPTScript-Env": strings.Join(modelProviderEnv, ","), + } + } + slog.Debug("calling openai", "message", request.Messages) if !streamResponse { request.StreamOptions = nil - resp, err := c.c.CreateChatCompletion(ctx, request) + resp, err := c.c.CreateChatCompletion(ctx, request, headers) if err != nil { return types.CompletionMessage{}, err } @@ -582,7 +597,7 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, }), nil } - stream, err := c.c.CreateChatCompletionStream(ctx, request) + stream, err := c.c.CreateChatCompletionStream(ctx, request, headers) if err != nil { return types.CompletionMessage{}, err } @@ -614,8 +629,8 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest, } } -func (c *Client) RetrieveAPIKey(ctx context.Context) error { - k, err := prompt.GetModelProviderCredential(ctx, c.credStore, BuiltinCredName, "OPENAI_API_KEY", "Please provide your OpenAI API key:", gcontext.GetEnv(ctx)) +func (c *Client) RetrieveAPIKey(ctx context.Context, env []string) error { + k, err := prompt.GetModelProviderCredential(ctx, c.credStore, BuiltinCredName, "OPENAI_API_KEY", "Please provide your OpenAI API key:", env) if err != nil { return err } diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index fa1d40c2..5542372b 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -10,7 +10,6 @@ import ( "sync" "github.com/gptscript-ai/gptscript/pkg/cache" - gcontext "github.com/gptscript-ai/gptscript/pkg/context" "github.com/gptscript-ai/gptscript/pkg/credentials" "github.com/gptscript-ai/gptscript/pkg/engine" env2 "github.com/gptscript-ai/gptscript/pkg/env" @@ -42,13 +41,13 @@ func New(r *runner.Runner, envs []string, cache *cache.Client, credStore credent } } -func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) { +func (c *Client) Call(ctx context.Context, messageRequest types.CompletionRequest, env []string, status chan<- types.CompletionStatus) (*types.CompletionMessage, error) { _, provider := c.parseModel(messageRequest.Model) if provider == "" { return nil, fmt.Errorf("failed to find remote model %s", messageRequest.Model) } - client, err := c.load(ctx, provider) + client, err := c.load(ctx, provider, env...) if err != nil { return nil, err } @@ -60,7 +59,7 @@ func (c *Client) Call(ctx context.Context, messageRequest types.CompletionReques modelName = toolName } messageRequest.Model = modelName - return client.Call(ctx, messageRequest, status) + return client.Call(ctx, messageRequest, env, status) } func (c *Client) ListModels(ctx context.Context, providers ...string) (result []string, _ error) { @@ -111,7 +110,7 @@ func isHTTPURL(toolName string) bool { strings.HasPrefix(toolName, "https://") } -func (c *Client) clientFromURL(ctx context.Context, apiURL string) (*openai.Client, error) { +func (c *Client) clientFromURL(ctx context.Context, apiURL string, envs []string) (*openai.Client, error) { parsed, err := url.Parse(apiURL) if err != nil { return nil, err @@ -121,7 +120,7 @@ func (c *Client) clientFromURL(ctx context.Context, apiURL string) (*openai.Clie if key == "" && !isLocalhost(apiURL) { var err error - key, err = c.retrieveAPIKey(ctx, env, apiURL) + key, err = c.retrieveAPIKey(ctx, env, apiURL, envs) if err != nil { return nil, err } @@ -134,7 +133,7 @@ func (c *Client) clientFromURL(ctx context.Context, apiURL string) (*openai.Clie }) } -func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, error) { +func (c *Client) load(ctx context.Context, toolName string, env ...string) (*openai.Client, error) { c.clientsLock.Lock() defer c.clientsLock.Unlock() @@ -144,7 +143,7 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err } if isHTTPURL(toolName) { - remoteClient, err := c.clientFromURL(ctx, toolName) + remoteClient, err := c.clientFromURL(ctx, toolName, env) if err != nil { return nil, err } @@ -183,8 +182,8 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err return oClient, nil } -func (c *Client) retrieveAPIKey(ctx context.Context, env, url string) (string, error) { - return prompt.GetModelProviderCredential(ctx, c.credStore, url, env, fmt.Sprintf("Please provide your API key for %s", url), append(gcontext.GetEnv(ctx), c.envs...)) +func (c *Client) retrieveAPIKey(ctx context.Context, env, url string, envs []string) (string, error) { + return prompt.GetModelProviderCredential(ctx, c.credStore, url, env, fmt.Sprintf("Please provide your API key for %s", url), append(envs, c.envs...)) } func isLocalhost(url string) bool { diff --git a/pkg/runner/output.go b/pkg/runner/output.go index 8a6aefdb..5f1d2818 100644 --- a/pkg/runner/output.go +++ b/pkg/runner/output.go @@ -84,7 +84,7 @@ func (r *Runner) handleOutput(callCtx engine.Context, monitor Monitor, env []str if err != nil { return nil, fmt.Errorf("marshaling input for output filter: %w", err) } - res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, outputToolRef.ToolID, string(inputData), "", engine.OutputToolCategory) + res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, outputToolRef.ToolID, inputData, "", engine.OutputToolCategory) if err != nil { return nil, err } diff --git a/pkg/tests/judge/judge.go b/pkg/tests/judge/judge.go index f6581dcc..26464386 100644 --- a/pkg/tests/judge/judge.go +++ b/pkg/tests/judge/judge.go @@ -112,7 +112,7 @@ func (j *Judge[T]) Equal(ctx context.Context, expected, actual T, criteria strin }, }, } - response, err := j.client.CreateChatCompletion(ctx, request) + response, err := j.client.CreateChatCompletion(ctx, request, nil) if err != nil { return false, "", fmt.Errorf("failed to create chat completion request: %w", err) } diff --git a/pkg/tests/tester/runner.go b/pkg/tests/tester/runner.go index fa7f7683..1f59ea03 100644 --- a/pkg/tests/tester/runner.go +++ b/pkg/tests/tester/runner.go @@ -35,7 +35,7 @@ func (c *Client) ProxyInfo() (string, string, error) { return "test-auth", "test-url", nil } -func (c *Client) Call(_ context.Context, messageRequest types.CompletionRequest, _ chan<- types.CompletionStatus) (resp *types.CompletionMessage, respErr error) { +func (c *Client) Call(_ context.Context, messageRequest types.CompletionRequest, _ []string, _ chan<- types.CompletionStatus) (resp *types.CompletionMessage, respErr error) { msgData, err := json.MarshalIndent(messageRequest, "", " ") require.NoError(c.t, err)