diff --git a/pkg/certs/certs.go b/pkg/certs/certs.go new file mode 100644 index 00000000..f001532b --- /dev/null +++ b/pkg/certs/certs.go @@ -0,0 +1,63 @@ +package certs + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "fmt" + "math/big" + "net" + "time" +) + +// CertAndKey contains an x509 certificate (PEM format) and ECDSA private key (also PEM format) +type CertAndKey struct { + Cert []byte + Key []byte +} + +func GenerateGPTScriptCert() (CertAndKey, error) { + return GenerateSelfSignedCert("gptscript server") +} + +func GenerateSelfSignedCert(name string) (CertAndKey, error) { + privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return CertAndKey{}, fmt.Errorf("failed to generate ECDSA key: %v", err) + } + + marshalledPrivateKey, err := x509.MarshalECPrivateKey(privateKey) + if err != nil { + return CertAndKey{}, fmt.Errorf("failed to marshal ECDSA key: %v", err) + } + + marshalledPrivateKeyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: marshalledPrivateKey}) + + template := &x509.Certificate{ + SerialNumber: big.NewInt(time.Now().UnixNano()), + Subject: pkix.Name{ + CommonName: name, + }, + NotBefore: time.Now(), + NotAfter: time.Now().AddDate(1, 0, 0), // a year from now + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{ + x509.ExtKeyUsageServerAuth, + x509.ExtKeyUsageClientAuth, + }, + IsCA: false, + IPAddresses: []net.IP{net.ParseIP("127.0.0.1")}, + } + + cert, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey) + if err != nil { + return CertAndKey{}, fmt.Errorf("failed to create certificate: %v", err) + } + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert}) + + return CertAndKey{Cert: certPEM, Key: marshalledPrivateKeyPEM}, nil +} diff --git a/pkg/engine/daemon.go b/pkg/engine/daemon.go index b7877da3..59c93361 100644 --- a/pkg/engine/daemon.go +++ b/pkg/engine/daemon.go @@ -2,6 +2,9 @@ package engine import ( "context" + "crypto/tls" + "crypto/x509" + "encoding/base64" "fmt" "io" "math/rand" @@ -11,11 +14,13 @@ import ( "sync" "time" + "github.com/gptscript-ai/gptscript/pkg/certs" "github.com/gptscript-ai/gptscript/pkg/system" "github.com/gptscript-ai/gptscript/pkg/types" ) var ports Ports +var certificates Certs type Ports struct { daemonPorts map[string]int64 @@ -29,6 +34,35 @@ type Ports struct { daemonWG sync.WaitGroup } +type Certs struct { + daemonCerts map[string]certs.CertAndKey + clientCert certs.CertAndKey + lock sync.Mutex +} + +func GetClientCert() (certs.CertAndKey, error) { + certificates.lock.Lock() + defer certificates.lock.Unlock() + if len(certificates.clientCert.Cert) == 0 { + cert, err := certs.GenerateGPTScriptCert() + if err != nil { + return certs.CertAndKey{}, fmt.Errorf("failed to generate GPTScript certificate: %v", err) + } + certificates.clientCert = cert + } + return certificates.clientCert, nil +} + +func GetDaemonCert(toolID string) ([]byte, error) { + certificates.lock.Lock() + defer certificates.lock.Unlock() + cert, exists := certificates.daemonCerts[toolID] + if !exists { + return nil, fmt.Errorf("daemon certificate for [%s] not found", toolID) + } + return cert.Cert, nil +} + func IsDaemonRunning(url string) bool { ports.daemonLock.Lock() defer ports.daemonLock.Unlock() @@ -128,7 +162,7 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { tool.Instructions = types.CommandPrefix + instructions port, ok := ports.daemonPorts[tool.ID] - url := fmt.Sprintf("http://127.0.0.1:%d%s", port, path) + url := fmt.Sprintf("https://127.0.0.1:%d%s", port, path) if ok && ports.daemonsRunning[url] != nil { return url, nil } @@ -144,11 +178,40 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { ctx := ports.daemonCtx port = nextPort() - url = fmt.Sprintf("http://127.0.0.1:%d%s", port, path) + url = fmt.Sprintf("https://127.0.0.1:%d%s", port, path) + + // Generate a certificate for the daemon, unless one already exists. + certificates.lock.Lock() + defer certificates.lock.Unlock() + cert, exists := certificates.daemonCerts[tool.ID] + if !exists { + var err error + cert, err = certs.GenerateSelfSignedCert(tool.ID) + if err != nil { + return "", fmt.Errorf("failed to generate certificate for daemon: %v", err) + } + + if certificates.daemonCerts == nil { + certificates.daemonCerts = map[string]certs.CertAndKey{} + } + certificates.daemonCerts[tool.ID] = cert + } + + // Set the client certificate if there isn't one already. + if len(certificates.clientCert.Cert) == 0 { + gptscriptCert, err := certs.GenerateGPTScriptCert() + if err != nil { + return "", fmt.Errorf("failed to generate GPTScript certificate: %v", err) + } + certificates.clientCert = gptscriptCert + } cmd, stop, err := e.newCommand(ctx, []string{ fmt.Sprintf("PORT=%d", port), + fmt.Sprintf("CERT=%s", base64.StdEncoding.EncodeToString(cert.Cert)), + fmt.Sprintf("PRIVATE_KEY=%s", base64.StdEncoding.EncodeToString(cert.Key)), fmt.Sprintf("GPTSCRIPT_PORT=%d", port), + fmt.Sprintf("GPTSCRIPT_CERT=%s", base64.StdEncoding.EncodeToString(certificates.clientCert.Cert)), }, tool, "{}", @@ -210,8 +273,30 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) { ports.daemonWG.Done() }() + // Build HTTP client for checking the health of the daemon + tlsClientCert, err := tls.X509KeyPair(certificates.clientCert.Cert, certificates.clientCert.Key) + if err != nil { + return "", fmt.Errorf("failed to create client certificate: %v", err) + } + + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(cert.Cert) { + return "", fmt.Errorf("failed to append daemon certificate for [%s]", tool.ID) + } + + httpClient := &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + Certificates: []tls.Certificate{tlsClientCert}, + RootCAs: pool, + InsecureSkipVerify: false, + }, + }, + } + + // Check the health of the daemon for i := 0; i < 120; i++ { - resp, err := http.Get(url) + resp, err := httpClient.Get(url) if err == nil && resp.StatusCode == http.StatusOK { go func() { _, _ = io.ReadAll(resp.Body) diff --git a/pkg/engine/http.go b/pkg/engine/http.go index 109db559..222d6977 100644 --- a/pkg/engine/http.go +++ b/pkg/engine/http.go @@ -2,6 +2,8 @@ package engine import ( "context" + "crypto/tls" + "crypto/x509" "encoding/json" "fmt" "io" @@ -12,6 +14,7 @@ import ( "slices" "strings" + "github.com/gptscript-ai/gptscript/pkg/certs" "github.com/gptscript-ai/gptscript/pkg/types" ) @@ -40,6 +43,7 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too return nil, err } + var tlsConfigForDaemonRequest *tls.Config if strings.HasSuffix(parsed.Hostname(), DaemonURLSuffix) { referencedToolName := strings.TrimSuffix(parsed.Hostname(), DaemonURLSuffix) referencedToolRefs, ok := tool.ToolMapping[referencedToolName] @@ -60,6 +64,34 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too } parsed.Host = toolURLParsed.Host toolURL = parsed.String() + + // Find the certificate corresponding to this daemon tool + certificates.lock.Lock() + daemonCert, exists := certificates.daemonCerts[referencedTool.ID] + clientCert := certificates.clientCert + certificates.lock.Unlock() + + if !exists { + return nil, fmt.Errorf("missing daemon certificate for [%s]", referencedTool.ID) + } + + tlsConfigForDaemonRequest, err = getTLSConfig(clientCert, daemonCert.Cert) + if err != nil { + return nil, err + } + } else if isLocalhostHTTPS(toolURL) { + // This sometimes happens when talking to a model provider + certificates.lock.Lock() + daemonCert, exists := certificates.daemonCerts[tool.ID] + clientCert := certificates.clientCert + certificates.lock.Unlock() + + if exists { + tlsConfigForDaemonRequest, err = getTLSConfig(clientCert, daemonCert.Cert) + if err != nil { + return nil, err + } + } } if tool.Blocking { @@ -112,7 +144,18 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too req.Header.Set("Content-Type", "text/plain") } - resp, err := http.DefaultClient.Do(req) + var httpClient *http.Client + if tlsConfigForDaemonRequest != nil { + httpClient = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: tlsConfigForDaemonRequest, + }, + } + } else { + httpClient = http.DefaultClient + } + + resp, err := httpClient.Do(req) if err != nil { return nil, err } @@ -143,3 +186,30 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too Result: &s, }, nil } + +func isLocalhostHTTPS(u string) bool { + parsed, err := url.Parse(u) + if err != nil { + return false + } + + return parsed.Scheme == "https" && (parsed.Hostname() == "localhost" || parsed.Hostname() == "127.0.0.1") +} + +func getTLSConfig(clientCert certs.CertAndKey, daemonCert []byte) (*tls.Config, error) { + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(daemonCert) { + return nil, fmt.Errorf("failed to append daemon certificate") + } + + tlsClientCert, err := tls.X509KeyPair(clientCert.Cert, clientCert.Key) + if err != nil { + return nil, fmt.Errorf("failed to create client certificate: %v", err) + } + + return &tls.Config{ + Certificates: []tls.Certificate{tlsClientCert}, + RootCAs: pool, + InsecureSkipVerify: false, + }, nil +} diff --git a/pkg/openai/client.go b/pkg/openai/client.go index 1894bdda..0c660b4c 100644 --- a/pkg/openai/client.go +++ b/pkg/openai/client.go @@ -2,9 +2,12 @@ package openai import ( "context" + "crypto/tls" + "crypto/x509" "errors" "io" "log/slog" + "net/http" "os" "slices" "sort" @@ -13,6 +16,7 @@ import ( openai "github.com/gptscript-ai/chat-completion-client" "github.com/gptscript-ai/gptscript/pkg/cache" + "github.com/gptscript-ai/gptscript/pkg/certs" "github.com/gptscript-ai/gptscript/pkg/counter" "github.com/gptscript-ai/gptscript/pkg/credentials" "github.com/gptscript-ai/gptscript/pkg/hash" @@ -51,13 +55,15 @@ type Client struct { } type Options struct { - BaseURL string `usage:"OpenAI base URL" name:"openai-base-url" env:"OPENAI_BASE_URL"` - APIKey string `usage:"OpenAI API KEY" name:"openai-api-key" env:"OPENAI_API_KEY"` - OrgID string `usage:"OpenAI organization ID" name:"openai-org-id" env:"OPENAI_ORG_ID"` - DefaultModel string `usage:"Default LLM model to use" default:"gpt-4o"` - ConfigFile string `usage:"Path to GPTScript config file" name:"config"` - SetSeed bool `usage:"-"` - CacheKey string `usage:"-"` + BaseURL string `usage:"OpenAI base URL" name:"openai-base-url" env:"OPENAI_BASE_URL"` + APIKey string `usage:"OpenAI API KEY" name:"openai-api-key" env:"OPENAI_API_KEY"` + OrgID string `usage:"OpenAI organization ID" name:"openai-org-id" env:"OPENAI_ORG_ID"` + DefaultModel string `usage:"Default LLM model to use" default:"gpt-4o"` + ConfigFile string `usage:"Path to GPTScript config file" name:"config"` + SetSeed bool `usage:"-"` + CacheKey string `usage:"-"` + ClientCert certs.CertAndKey `usage:"-"` + ServerCert []byte `usage:"-"` Cache *cache.Client } @@ -70,6 +76,14 @@ func Complete(opts ...Options) (result Options) { result.DefaultModel = types.FirstSet(opt.DefaultModel, result.DefaultModel) result.SetSeed = types.FirstSet(opt.SetSeed, result.SetSeed) result.CacheKey = types.FirstSet(opt.CacheKey, result.CacheKey) + + if len(opt.ClientCert.Cert) > 0 { + result.ClientCert = opt.ClientCert + } + + if len(opt.ServerCert) > 0 { + result.ServerCert = opt.ServerCert + } } return result @@ -116,6 +130,29 @@ func NewClient(ctx context.Context, credStore credentials.CredentialStore, opts cfg.BaseURL = types.FirstSet(opt.BaseURL, cfg.BaseURL) cfg.OrgID = types.FirstSet(opt.OrgID, cfg.OrgID) + // Set up for mTLS, if configured. + if opt.ServerCert != nil && len(opt.ClientCert.Cert) > 0 { + pool := x509.NewCertPool() + if !pool.AppendCertsFromPEM(opt.ServerCert) { + return nil, errors.New("failed to append server cert to pool") + } + + clientCert, err := tls.X509KeyPair(opt.ClientCert.Cert, opt.ClientCert.Key) + if err != nil { + return nil, err + } + + cfg.HTTPClient = &http.Client{ + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + Certificates: []tls.Certificate{clientCert}, + RootCAs: pool, + InsecureSkipVerify: false, + }, + }, + } + } + cacheKeyBase := opt.CacheKey if cacheKeyBase == "" { cacheKeyBase = hash.ID(opt.APIKey, opt.BaseURL) diff --git a/pkg/remote/remote.go b/pkg/remote/remote.go index 5542372b..e5132d78 100644 --- a/pkg/remote/remote.go +++ b/pkg/remote/remote.go @@ -166,10 +166,22 @@ func (c *Client) load(ctx context.Context, toolName string, env ...string) (*ope return nil, err } + clientCert, err := engine.GetClientCert() + if err != nil { + return nil, err + } + + serverCert, err := engine.GetDaemonCert(prg.EntryToolID) + if err != nil { + return nil, err + } + oClient, err := openai.NewClient(ctx, c.credStore, openai.Options{ - BaseURL: strings.TrimSuffix(url, "/") + "/v1", - Cache: c.cache, - CacheKey: prg.EntryToolID, + BaseURL: strings.TrimSuffix(url, "/") + "/v1", + Cache: c.cache, + CacheKey: prg.EntryToolID, + ClientCert: clientCert, + ServerCert: serverCert, }) if err != nil { return nil, err