Skip to content

Commit a297bd0

Browse files
committed
enhance: add mTLS between gptscript and daemon tools
Signed-off-by: Grant Linville <grant@acorn.io>
1 parent c5d85f1 commit a297bd0

File tree

7 files changed

+183
-10
lines changed

7 files changed

+183
-10
lines changed

pkg/certs/certs.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package certs
2+
3+
import (
4+
"crypto/ecdsa"
5+
"crypto/elliptic"
6+
"crypto/rand"
7+
"crypto/x509"
8+
"crypto/x509/pkix"
9+
"encoding/pem"
10+
"fmt"
11+
"math/big"
12+
"net"
13+
"time"
14+
)
15+
16+
// CertAndKey contains an x509 certificate (PEM format) and ECDSA private key (also PEM format)
17+
type CertAndKey struct {
18+
Cert []byte
19+
Key []byte
20+
}
21+
22+
func GenerateGPTScriptCert() (CertAndKey, error) {
23+
return GenerateSelfSignedCert("gptscript server")
24+
}
25+
26+
func GenerateSelfSignedCert(name string) (CertAndKey, error) {
27+
privateKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
28+
if err != nil {
29+
return CertAndKey{}, fmt.Errorf("failed to generate ECDSA key: %v", err)
30+
}
31+
32+
marshalledPrivateKey, err := x509.MarshalECPrivateKey(privateKey)
33+
if err != nil {
34+
return CertAndKey{}, fmt.Errorf("failed to marshal ECDSA key: %v", err)
35+
}
36+
37+
marshalledPrivateKeyPEM := pem.EncodeToMemory(&pem.Block{Type: "EC PRIVATE KEY", Bytes: marshalledPrivateKey})
38+
39+
template := &x509.Certificate{
40+
SerialNumber: big.NewInt(time.Now().UnixNano()),
41+
Subject: pkix.Name{
42+
CommonName: name,
43+
},
44+
NotBefore: time.Now(),
45+
NotAfter: time.Now().AddDate(1, 0, 0), // a year from now
46+
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
47+
ExtKeyUsage: []x509.ExtKeyUsage{
48+
x509.ExtKeyUsageServerAuth,
49+
x509.ExtKeyUsageClientAuth,
50+
},
51+
IsCA: false,
52+
IPAddresses: []net.IP{net.ParseIP("127.0.0.1")},
53+
}
54+
55+
cert, err := x509.CreateCertificate(rand.Reader, template, template, &privateKey.PublicKey, privateKey)
56+
if err != nil {
57+
return CertAndKey{}, fmt.Errorf("failed to create certificate: %v", err)
58+
}
59+
60+
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: cert})
61+
62+
return CertAndKey{Cert: certPEM, Key: marshalledPrivateKeyPEM}, nil
63+
}

pkg/engine/daemon.go

Lines changed: 55 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,9 @@ package engine
22

33
import (
44
"context"
5+
"crypto/tls"
6+
"crypto/x509"
7+
"encoding/base64"
58
"fmt"
69
"io"
710
"math/rand"
@@ -11,11 +14,13 @@ import (
1114
"sync"
1215
"time"
1316

17+
"github.com/gptscript-ai/gptscript/pkg/certs"
1418
"github.com/gptscript-ai/gptscript/pkg/system"
1519
"github.com/gptscript-ai/gptscript/pkg/types"
1620
)
1721

1822
var ports Ports
23+
var certificates Certs
1924

2025
type Ports struct {
2126
daemonPorts map[string]int64
@@ -29,6 +34,11 @@ type Ports struct {
2934
daemonWG sync.WaitGroup
3035
}
3136

37+
type Certs struct {
38+
daemonCerts map[string]certs.CertAndKey
39+
daemonLock sync.Mutex
40+
}
41+
3242
func IsDaemonRunning(url string) bool {
3343
ports.daemonLock.Lock()
3444
defer ports.daemonLock.Unlock()
@@ -128,7 +138,7 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {
128138
tool.Instructions = types.CommandPrefix + instructions
129139

130140
port, ok := ports.daemonPorts[tool.ID]
131-
url := fmt.Sprintf("http://127.0.0.1:%d%s", port, path)
141+
url := fmt.Sprintf("https://127.0.0.1:%d%s", port, path)
132142
if ok && ports.daemonsRunning[url] != nil {
133143
return url, nil
134144
}
@@ -144,11 +154,31 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {
144154

145155
ctx := ports.daemonCtx
146156
port = nextPort()
147-
url = fmt.Sprintf("http://127.0.0.1:%d%s", port, path)
157+
url = fmt.Sprintf("https://127.0.0.1:%d%s", port, path)
158+
159+
// Generate a certificate for the daemon, unless one already exists.
160+
certificates.daemonLock.Lock()
161+
defer certificates.daemonLock.Unlock()
162+
cert, exists := certificates.daemonCerts[tool.ID]
163+
if !exists {
164+
var err error
165+
cert, err = certs.GenerateSelfSignedCert(tool.ID)
166+
if err != nil {
167+
return "", fmt.Errorf("failed to generate certificate for daemon: %v", err)
168+
}
169+
170+
if certificates.daemonCerts == nil {
171+
certificates.daemonCerts = map[string]certs.CertAndKey{}
172+
}
173+
certificates.daemonCerts[tool.ID] = cert
174+
}
148175

149176
cmd, stop, err := e.newCommand(ctx, []string{
150177
fmt.Sprintf("PORT=%d", port),
178+
fmt.Sprintf("CERT=%s", base64.StdEncoding.EncodeToString(cert.Cert)),
179+
fmt.Sprintf("PRIVATE_KEY=%s", base64.StdEncoding.EncodeToString(cert.Key)),
151180
fmt.Sprintf("GPTSCRIPT_PORT=%d", port),
181+
fmt.Sprintf("GPTSCRIPT_CERT=%s", base64.StdEncoding.EncodeToString(e.GPTScriptCert.Cert)),
152182
},
153183
tool,
154184
"{}",
@@ -210,8 +240,30 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {
210240
ports.daemonWG.Done()
211241
}()
212242

243+
// Build HTTP client for checking the health of the daemon
244+
clientCert, err := tls.X509KeyPair(e.GPTScriptCert.Cert, e.GPTScriptCert.Key)
245+
if err != nil {
246+
return "", fmt.Errorf("failed to create client certificate: %v", err)
247+
}
248+
249+
pool := x509.NewCertPool()
250+
if !pool.AppendCertsFromPEM(cert.Cert) {
251+
return "", fmt.Errorf("failed to append daemon certificate for [%s]", tool.ID)
252+
}
253+
254+
httpClient := &http.Client{
255+
Transport: &http.Transport{
256+
TLSClientConfig: &tls.Config{
257+
Certificates: []tls.Certificate{clientCert},
258+
RootCAs: pool,
259+
InsecureSkipVerify: false,
260+
},
261+
},
262+
}
263+
264+
// Check the health of the daemon
213265
for i := 0; i < 120; i++ {
214-
resp, err := http.Get(url)
266+
resp, err := httpClient.Get(url)
215267
if err == nil && resp.StatusCode == http.StatusOK {
216268
go func() {
217269
_, _ = io.ReadAll(resp.Body)

pkg/engine/engine.go

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"strings"
88
"sync"
99

10+
"github.com/gptscript-ai/gptscript/pkg/certs"
1011
"github.com/gptscript-ai/gptscript/pkg/counter"
1112
"github.com/gptscript-ai/gptscript/pkg/types"
1213
"github.com/gptscript-ai/gptscript/pkg/version"
@@ -22,6 +23,7 @@ type RuntimeManager interface {
2223
}
2324

2425
type Engine struct {
26+
GPTScriptCert certs.CertAndKey
2527
Model Model
2628
RuntimeManager RuntimeManager
2729
Env []string

pkg/engine/http.go

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package engine
22

33
import (
44
"context"
5+
"crypto/tls"
6+
"crypto/x509"
57
"encoding/json"
68
"fmt"
79
"io"
@@ -40,6 +42,7 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too
4042
return nil, err
4143
}
4244

45+
var tlsConfigForDaemonRequest *tls.Config
4346
if strings.HasSuffix(parsed.Hostname(), DaemonURLSuffix) {
4447
referencedToolName := strings.TrimSuffix(parsed.Hostname(), DaemonURLSuffix)
4548
referencedToolRefs, ok := tool.ToolMapping[referencedToolName]
@@ -60,6 +63,33 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too
6063
}
6164
parsed.Host = toolURLParsed.Host
6265
toolURL = parsed.String()
66+
67+
// Find the certificate corresponding to this daemon tool
68+
certificates.daemonLock.Lock()
69+
daemonCert, exists := certificates.daemonCerts[referencedTool.ID]
70+
certificates.daemonLock.Unlock()
71+
72+
if !exists {
73+
return nil, fmt.Errorf("missing daemon certificate for [%s]", referencedTool.ID)
74+
}
75+
76+
// Create a pool for the certificate to treat as a CA
77+
pool := x509.NewCertPool()
78+
if !pool.AppendCertsFromPEM(daemonCert.Cert) {
79+
return nil, fmt.Errorf("failed to append daemon certificate for [%s]", referencedTool.ID)
80+
}
81+
82+
clientCert, err := tls.X509KeyPair(e.GPTScriptCert.Cert, e.GPTScriptCert.Key)
83+
if err != nil {
84+
return nil, fmt.Errorf("failed to create client certificate: %v", err)
85+
}
86+
87+
// Create TLS config for use in the HTTP client later
88+
tlsConfigForDaemonRequest = &tls.Config{
89+
Certificates: []tls.Certificate{clientCert},
90+
RootCAs: pool,
91+
InsecureSkipVerify: false,
92+
}
6393
}
6494

6595
if tool.Blocking {
@@ -112,7 +142,18 @@ func (e *Engine) runHTTP(ctx context.Context, prg *types.Program, tool types.Too
112142
req.Header.Set("Content-Type", "text/plain")
113143
}
114144

115-
resp, err := http.DefaultClient.Do(req)
145+
var httpClient *http.Client
146+
if tlsConfigForDaemonRequest != nil {
147+
httpClient = &http.Client{
148+
Transport: &http.Transport{
149+
TLSClientConfig: tlsConfigForDaemonRequest,
150+
},
151+
}
152+
} else {
153+
httpClient = http.DefaultClient
154+
}
155+
156+
resp, err := httpClient.Do(req)
116157
if err != nil {
117158
return nil, err
118159
}

pkg/gptscript/gptscript.go

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212

1313
"github.com/gptscript-ai/gptscript/pkg/builtin"
1414
"github.com/gptscript-ai/gptscript/pkg/cache"
15+
"github.com/gptscript-ai/gptscript/pkg/certs"
1516
"github.com/gptscript-ai/gptscript/pkg/config"
1617
context2 "github.com/gptscript-ai/gptscript/pkg/context"
1718
"github.com/gptscript-ai/gptscript/pkg/credentials"
@@ -107,7 +108,12 @@ func New(ctx context.Context, o ...Options) (*GPTScript, error) {
107108
opts.Runner.RuntimeManager = runtimes.Default(cacheClient.CacheDir(), opts.SystemToolsDir)
108109
}
109110

110-
simplerRunner, err := newSimpleRunner(cacheClient, opts.Runner.RuntimeManager, opts.Env)
111+
gptscriptCert, err := certs.GenerateGPTScriptCert()
112+
if err != nil {
113+
return nil, err
114+
}
115+
116+
simplerRunner, err := newSimpleRunner(cacheClient, opts.Runner.RuntimeManager, opts.Env, gptscriptCert)
111117
if err != nil {
112118
return nil, err
113119
}
@@ -140,7 +146,7 @@ func New(ctx context.Context, o ...Options) (*GPTScript, error) {
140146
opts.Runner.MonitorFactory = monitor.NewConsole(opts.Monitor, monitor.Options{DebugMessages: *opts.Quiet})
141147
}
142148

143-
runner, err := runner.New(registry, credStore, opts.Runner)
149+
runner, err := runner.New(registry, credStore, gptscriptCert, opts.Runner)
144150
if err != nil {
145151
return nil, err
146152
}
@@ -285,8 +291,8 @@ type simpleRunner struct {
285291
env []string
286292
}
287293

288-
func newSimpleRunner(cache *cache.Client, rm engine.RuntimeManager, env []string) (*simpleRunner, error) {
289-
runner, err := runner.New(noopModel{}, credentials.NoopStore{}, runner.Options{
294+
func newSimpleRunner(cache *cache.Client, rm engine.RuntimeManager, env []string, gptscriptCert certs.CertAndKey) (*simpleRunner, error) {
295+
runner, err := runner.New(noopModel{}, credentials.NoopStore{}, gptscriptCert, runner.Options{
290296
RuntimeManager: rm,
291297
MonitorFactory: simpleMonitorFactory{},
292298
})

pkg/runner/runner.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"time"
1212

1313
"github.com/gptscript-ai/gptscript/pkg/builtin"
14+
"github.com/gptscript-ai/gptscript/pkg/certs"
1415
context2 "github.com/gptscript-ai/gptscript/pkg/context"
1516
"github.com/gptscript-ai/gptscript/pkg/credentials"
1617
"github.com/gptscript-ai/gptscript/pkg/engine"
@@ -95,9 +96,10 @@ type Runner struct {
9596
credOverrides []string
9697
credStore credentials.CredentialStore
9798
sequential bool
99+
gptscriptCert certs.CertAndKey
98100
}
99101

100-
func New(client engine.Model, credStore credentials.CredentialStore, opts ...Options) (*Runner, error) {
102+
func New(client engine.Model, credStore credentials.CredentialStore, gptscriptCert certs.CertAndKey, opts ...Options) (*Runner, error) {
101103
opt := complete(opts...)
102104

103105
runner := &Runner{
@@ -109,6 +111,7 @@ func New(client engine.Model, credStore credentials.CredentialStore, opts ...Opt
109111
credStore: credStore,
110112
sequential: opt.Sequential,
111113
auth: opt.Authorizer,
114+
gptscriptCert: gptscriptCert,
112115
}
113116

114117
if opt.StartPort != 0 {
@@ -411,6 +414,7 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
411414
RuntimeManager: runtimeWithLogger(callCtx, monitor, r.runtimeManager),
412415
Progress: progress,
413416
Env: env,
417+
GPTScriptCert: r.gptscriptCert,
414418
}
415419

416420
callCtx.Ctx = context2.AddPauseFuncToCtx(callCtx.Ctx, monitor.Pause)
@@ -593,6 +597,7 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
593597
RuntimeManager: runtimeWithLogger(callCtx, monitor, r.runtimeManager),
594598
Progress: progress,
595599
Env: env,
600+
GPTScriptCert: r.gptscriptCert,
596601
}
597602

598603
var contentInput string

pkg/tests/tester/runner.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"testing"
1010

1111
"github.com/adrg/xdg"
12+
"github.com/gptscript-ai/gptscript/pkg/certs"
1213
"github.com/gptscript-ai/gptscript/pkg/credentials"
1314
"github.com/gptscript-ai/gptscript/pkg/loader"
1415
"github.com/gptscript-ai/gptscript/pkg/repos/runtimes"
@@ -198,7 +199,10 @@ func NewRunner(t *testing.T) *Runner {
198199

199200
rm := runtimes.Default(cacheDir, "")
200201

201-
run, err := runner.New(c, credentials.NoopStore{}, runner.Options{
202+
gptscriptCert, err := certs.GenerateGPTScriptCert()
203+
require.NoError(t, err)
204+
205+
run, err := runner.New(c, credentials.NoopStore{}, gptscriptCert, runner.Options{
202206
Sequential: true,
203207
RuntimeManager: rm,
204208
})

0 commit comments

Comments
 (0)