Skip to content

Commit aa5ef57

Browse files
committed
fixes
Signed-off-by: Grant Linville <grant@acorn.io>
1 parent 2b7cb50 commit aa5ef57

File tree

4 files changed

+123
-23
lines changed

4 files changed

+123
-23
lines changed

pkg/engine/daemon.go

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,29 @@ type Certs struct {
4040
lock sync.Mutex
4141
}
4242

43+
func GetClientCert() (certs.CertAndKey, error) {
44+
certificates.lock.Lock()
45+
defer certificates.lock.Unlock()
46+
if len(certificates.clientCert.Cert) == 0 {
47+
cert, err := certs.GenerateGPTScriptCert()
48+
if err != nil {
49+
return certs.CertAndKey{}, fmt.Errorf("failed to generate GPTScript certificate: %v", err)
50+
}
51+
certificates.clientCert = cert
52+
}
53+
return certificates.clientCert, nil
54+
}
55+
56+
func GetDaemonCert(toolID string) ([]byte, error) {
57+
certificates.lock.Lock()
58+
defer certificates.lock.Unlock()
59+
cert, exists := certificates.daemonCerts[toolID]
60+
if !exists {
61+
return nil, fmt.Errorf("daemon certificate for [%s] not found", toolID)
62+
}
63+
return cert.Cert, nil
64+
}
65+
4366
func IsDaemonRunning(url string) bool {
4467
ports.daemonLock.Lock()
4568
defer ports.daemonLock.Unlock()

pkg/engine/http.go

Lines changed: 41 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414
"slices"
1515
"strings"
1616

17+
"github.com/gptscript-ai/gptscript/pkg/certs"
1718
"github.com/gptscript-ai/gptscript/pkg/types"
1819
)
1920

@@ -74,22 +75,22 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too
7475
return nil, fmt.Errorf("missing daemon certificate for [%s]", referencedTool.ID)
7576
}
7677

77-
// Create a pool for the certificate to treat as a CA
78-
pool := x509.NewCertPool()
79-
if !pool.AppendCertsFromPEM(daemonCert.Cert) {
80-
return nil, fmt.Errorf("failed to append daemon certificate for [%s]", referencedTool.ID)
81-
}
82-
83-
tlsClientCert, err := tls.X509KeyPair(clientCert.Cert, clientCert.Key)
78+
tlsConfigForDaemonRequest, err = getTLSConfig(clientCert, daemonCert.Cert)
8479
if err != nil {
85-
return nil, fmt.Errorf("failed to create client certificate: %v", err)
80+
return nil, err
8681
}
82+
} else if isLocalhostHTTPS(toolURL) {
83+
// This sometimes happens when talking to a model provider
84+
certificates.lock.Lock()
85+
daemonCert, exists := certificates.daemonCerts[tool.ID]
86+
clientCert := certificates.clientCert
87+
certificates.lock.Unlock()
8788

88-
// Create TLS config for use in the HTTP client later
89-
tlsConfigForDaemonRequest = &tls.Config{
90-
Certificates: []tls.Certificate{tlsClientCert},
91-
RootCAs: pool,
92-
InsecureSkipVerify: false,
89+
if exists {
90+
tlsConfigForDaemonRequest, err = getTLSConfig(clientCert, daemonCert.Cert)
91+
if err != nil {
92+
return nil, err
93+
}
9394
}
9495
}
9596

@@ -185,3 +186,30 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too
185186
Result: &s,
186187
}, nil
187188
}
189+
190+
func isLocalhostHTTPS(u string) bool {
191+
parsed, err := url.Parse(u)
192+
if err != nil {
193+
return false
194+
}
195+
196+
return parsed.Scheme == "https" && (parsed.Hostname() == "localhost" || parsed.Hostname() == "127.0.0.1")
197+
}
198+
199+
func getTLSConfig(clientCert certs.CertAndKey, daemonCert []byte) (*tls.Config, error) {
200+
pool := x509.NewCertPool()
201+
if !pool.AppendCertsFromPEM(daemonCert) {
202+
return nil, fmt.Errorf("failed to append daemon certificate")
203+
}
204+
205+
tlsClientCert, err := tls.X509KeyPair(clientCert.Cert, clientCert.Key)
206+
if err != nil {
207+
return nil, fmt.Errorf("failed to create client certificate: %v", err)
208+
}
209+
210+
return &tls.Config{
211+
Certificates: []tls.Certificate{tlsClientCert},
212+
RootCAs: pool,
213+
InsecureSkipVerify: false,
214+
}, nil
215+
}

pkg/openai/client.go

Lines changed: 44 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,12 @@ package openai
22

33
import (
44
"context"
5+
"crypto/tls"
6+
"crypto/x509"
57
"errors"
68
"io"
79
"log/slog"
10+
"net/http"
811
"os"
912
"slices"
1013
"sort"
@@ -13,6 +16,7 @@ import (
1316

1417
openai "github.com/gptscript-ai/chat-completion-client"
1518
"github.com/gptscript-ai/gptscript/pkg/cache"
19+
"github.com/gptscript-ai/gptscript/pkg/certs"
1620
"github.com/gptscript-ai/gptscript/pkg/counter"
1721
"github.com/gptscript-ai/gptscript/pkg/credentials"
1822
"github.com/gptscript-ai/gptscript/pkg/hash"
@@ -51,13 +55,15 @@ type Client struct {
5155
}
5256

5357
type Options struct {
54-
BaseURL string `usage:"OpenAI base URL" name:"openai-base-url" env:"OPENAI_BASE_URL"`
55-
APIKey string `usage:"OpenAI API KEY" name:"openai-api-key" env:"OPENAI_API_KEY"`
56-
OrgID string `usage:"OpenAI organization ID" name:"openai-org-id" env:"OPENAI_ORG_ID"`
57-
DefaultModel string `usage:"Default LLM model to use" default:"gpt-4o"`
58-
ConfigFile string `usage:"Path to GPTScript config file" name:"config"`
59-
SetSeed bool `usage:"-"`
60-
CacheKey string `usage:"-"`
58+
BaseURL string `usage:"OpenAI base URL" name:"openai-base-url" env:"OPENAI_BASE_URL"`
59+
APIKey string `usage:"OpenAI API KEY" name:"openai-api-key" env:"OPENAI_API_KEY"`
60+
OrgID string `usage:"OpenAI organization ID" name:"openai-org-id" env:"OPENAI_ORG_ID"`
61+
DefaultModel string `usage:"Default LLM model to use" default:"gpt-4o"`
62+
ConfigFile string `usage:"Path to GPTScript config file" name:"config"`
63+
SetSeed bool `usage:"-"`
64+
CacheKey string `usage:"-"`
65+
ClientCert certs.CertAndKey `usage:"-"`
66+
ServerCert []byte `usage:"-"`
6167
Cache *cache.Client
6268
}
6369

@@ -70,6 +76,14 @@ func Complete(opts ...Options) (result Options) {
7076
result.DefaultModel = types.FirstSet(opt.DefaultModel, result.DefaultModel)
7177
result.SetSeed = types.FirstSet(opt.SetSeed, result.SetSeed)
7278
result.CacheKey = types.FirstSet(opt.CacheKey, result.CacheKey)
79+
80+
if len(opt.ClientCert.Cert) > 0 {
81+
result.ClientCert = opt.ClientCert
82+
}
83+
84+
if len(opt.ServerCert) > 0 {
85+
result.ServerCert = opt.ServerCert
86+
}
7387
}
7488

7589
return result
@@ -116,6 +130,29 @@ func NewClient(ctx context.Context, credStore credentials.CredentialStore, opts
116130
cfg.BaseURL = types.FirstSet(opt.BaseURL, cfg.BaseURL)
117131
cfg.OrgID = types.FirstSet(opt.OrgID, cfg.OrgID)
118132

133+
// Set up for mTLS, if configured.
134+
if opt.ServerCert != nil && len(opt.ClientCert.Cert) > 0 {
135+
pool := x509.NewCertPool()
136+
if !pool.AppendCertsFromPEM(opt.ServerCert) {
137+
return nil, errors.New("failed to append server cert to pool")
138+
}
139+
140+
clientCert, err := tls.X509KeyPair(opt.ClientCert.Cert, opt.ClientCert.Key)
141+
if err != nil {
142+
return nil, err
143+
}
144+
145+
cfg.HTTPClient = &http.Client{
146+
Transport: &http.Transport{
147+
TLSClientConfig: &tls.Config{
148+
Certificates: []tls.Certificate{clientCert},
149+
RootCAs: pool,
150+
InsecureSkipVerify: false,
151+
},
152+
},
153+
}
154+
}
155+
119156
cacheKeyBase := opt.CacheKey
120157
if cacheKeyBase == "" {
121158
cacheKeyBase = hash.ID(opt.APIKey, opt.BaseURL)

pkg/remote/remote.go

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -166,10 +166,22 @@ func (c *Client) load(ctx context.Context, toolName string, env ...string) (*ope
166166
return nil, err
167167
}
168168

169+
clientCert, err := engine.GetClientCert()
170+
if err != nil {
171+
return nil, err
172+
}
173+
174+
serverCert, err := engine.GetDaemonCert(prg.EntryToolID)
175+
if err != nil {
176+
return nil, err
177+
}
178+
169179
oClient, err := openai.NewClient(ctx, c.credStore, openai.Options{
170-
BaseURL: strings.TrimSuffix(url, "/") + "/v1",
171-
Cache: c.cache,
172-
CacheKey: prg.EntryToolID,
180+
BaseURL: strings.TrimSuffix(url, "/") + "/v1",
181+
Cache: c.cache,
182+
CacheKey: prg.EntryToolID,
183+
ClientCert: clientCert,
184+
ServerCert: serverCert,
173185
})
174186
if err != nil {
175187
return nil, err

0 commit comments

Comments
 (0)