Skip to content

Commit acb92ce

Browse files
authored
Merge pull request GoogleCloudPlatform#142 from jellybeanfiend/master
Automatically refresh SSL certs before they expire
2 parents 1e456b1 + 6903e12 commit acb92ce

File tree

2 files changed

+113
-26
lines changed

2 files changed

+113
-26
lines changed

proxy/proxy/client.go

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,13 @@ const (
3232
keepAlivePeriod = time.Minute
3333
)
3434

35-
// errNotCached is returned when the instance was not found in the Client's
36-
// cache. It is an internal detail and is not actually ever returned to the
37-
// user.
38-
var errNotCached = errors.New("instance was not found in cache")
35+
var (
36+
// errNotCached is returned when the instance was not found in the Client's
37+
// cache. It is an internal detail and is not actually ever returned to the
38+
// user.
39+
errNotCached = errors.New("instance was not found in cache")
40+
refreshCertBuffer = 5 * time.Minute
41+
)
3942

4043
// Conn represents a connection from a client to a specific instance.
4144
type Conn struct {
@@ -78,9 +81,12 @@ type Client struct {
7881

7982
// The cfgCache holds the most recent connection configuration keyed by
8083
// instance. Relevant functions are refreshCfg and cachedCfg. It is
81-
// protected by cfgL.
84+
// protected by cacheL.
8285
cfgCache map[string]cacheEntry
83-
cfgL sync.RWMutex
86+
cacheL sync.Mutex
87+
88+
// refreshCfgL prevents multiple goroutines from contacting the Cloud SQL API at once.
89+
refreshCfgL sync.Mutex
8490

8591
// MaxConnections is the maximum number of connections to establish
8692
// before refusing new connections. 0 means no limit.
@@ -150,32 +156,39 @@ func (c *Client) handleConn(conn Conn) {
150156
// address as well as construct a new tls.Config to connect to the instance. It
151157
// caches the result.
152158
func (c *Client) refreshCfg(instance string) (addr string, cfg *tls.Config, err error) {
153-
c.cfgL.Lock()
154-
defer c.cfgL.Unlock()
159+
c.refreshCfgL.Lock()
160+
defer c.refreshCfgL.Unlock()
155161

156162
throttle := c.RefreshCfgThrottle
157163
if throttle == 0 {
158164
throttle = DefaultRefreshCfgThrottle
159165
}
160166

161-
if old := c.cfgCache[instance]; time.Since(old.lastRefreshed) < throttle {
167+
c.cacheL.Lock()
168+
if c.cfgCache == nil {
169+
c.cfgCache = make(map[string]cacheEntry)
170+
}
171+
old, oldok := c.cfgCache[instance]
172+
c.cacheL.Unlock()
173+
174+
if oldok && time.Since(old.lastRefreshed) < throttle {
162175
logging.Errorf("Throttling refreshCfg(%s): it was only called %v ago", instance, time.Since(old.lastRefreshed))
163176
// Refresh was called too recently, just reuse the result.
164177
return old.addr, old.cfg, old.err
165178
}
166179

167-
if c.cfgCache == nil {
168-
c.cfgCache = make(map[string]cacheEntry)
169-
}
170-
171180
defer func() {
181+
if err != nil && oldok {
182+
return
183+
}
184+
c.cacheL.Lock()
172185
c.cfgCache[instance] = cacheEntry{
173186
lastRefreshed: time.Now(),
174-
175-
err: err,
176-
addr: addr,
177-
cfg: cfg,
187+
err: err,
188+
addr: addr,
189+
cfg: cfg,
178190
}
191+
c.cacheL.Unlock()
179192
}()
180193

181194
mycert, err := c.Certs.Local(instance)
@@ -195,13 +208,26 @@ func (c *Client) refreshCfg(instance string) (addr string, cfg *tls.Config, err
195208
Certificates: []tls.Certificate{mycert},
196209
RootCAs: certs,
197210
}
211+
212+
// Refresh the cert refreshCertBuffer (5 minutes by default) before it expires.
213+
timeToRefresh := cfg.Certificates[0].Leaf.NotAfter.Sub(time.Now()) - refreshCertBuffer
214+
if timeToRefresh > 0 {
215+
go func() {
216+
<-time.After(timeToRefresh)
217+
logging.Verbosef("Cert for instance %s will expire soon, refreshing now.", instance)
218+
if _, _, err := c.refreshCfg(instance); err != nil {
219+
logging.Errorf("couldn't connect to %q: %v", instance, err)
220+
}
221+
}()
222+
}
223+
198224
return fmt.Sprintf("%s:%d", addr, c.Port), cfg, nil
199225
}
200226

201227
func (c *Client) cachedCfg(instance string) (string, *tls.Config) {
202-
c.cfgL.RLock()
228+
c.cacheL.Lock()
203229
ret, ok := c.cfgCache[instance]
204-
c.cfgL.RUnlock()
230+
c.cacheL.Unlock()
205231

206232
// Don't waste time returning an expired/invalid cert.
207233
if !ok || ret.err != nil || time.Now().After(ret.cfg.Certificates[0].Leaf.NotAfter) {
@@ -225,6 +251,7 @@ func (c *Client) Dial(instance string) (net.Conn, error) {
225251
if err != nil {
226252
return nil, err
227253
}
254+
228255
return c.tryConnect(addr, cfg)
229256
}
230257

proxy/proxy/client_test.go

Lines changed: 67 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,15 +28,19 @@ import (
2828

2929
const instance = "instance-name"
3030

31-
var errFakeDial = errors.New("this error is returned by the dialer")
31+
var (
32+
errFakeDial = errors.New("this error is returned by the dialer")
33+
forever = time.Date(9999, 0, 0, 0, 0, 0, 0, time.UTC)
34+
)
3235

3336
type fakeCerts struct {
3437
sync.Mutex
3538
called int
3639
}
3740

3841
type blockingCertSource struct {
39-
values map[string]*fakeCerts
42+
values map[string]*fakeCerts
43+
validUntil time.Time
4044
}
4145

4246
func (cs *blockingCertSource) Local(instance string) (tls.Certificate, error) {
@@ -48,11 +52,10 @@ func (cs *blockingCertSource) Local(instance string) (tls.Certificate, error) {
4852
v.called++
4953
v.Unlock()
5054

51-
validUntil, _ := time.Parse("2006", "9999")
5255
// Returns a cert which is valid until cs.validUntil.
5356
return tls.Certificate{
5457
Leaf: &x509.Certificate{
55-
NotAfter: validUntil,
58+
NotAfter: cs.validUntil,
5659
},
5760
}, nil
5861
}
@@ -67,7 +70,9 @@ func TestClientCache(t *testing.T) {
6770
Certs: &blockingCertSource{
6871
map[string]*fakeCerts{
6972
instance: b,
70-
}},
73+
},
74+
forever,
75+
},
7176
Dialer: func(string, string) (net.Conn, error) {
7277
return nil, errFakeDial
7378
},
@@ -92,7 +97,9 @@ func TestConcurrentRefresh(t *testing.T) {
9297
Certs: &blockingCertSource{
9398
map[string]*fakeCerts{
9499
instance: b,
95-
}},
100+
},
101+
forever,
102+
},
96103
Dialer: func(string, string) (net.Conn, error) {
97104
return nil, errFakeDial
98105
},
@@ -131,7 +138,9 @@ func TestMaximumConnectionsCount(t *testing.T) {
131138

132139
b := &fakeCerts{}
133140
certSource := blockingCertSource{
134-
map[string]*fakeCerts{}}
141+
map[string]*fakeCerts{},
142+
forever,
143+
}
135144
firstDialExited := make(chan struct{})
136145
c := &Client{
137146
Certs: &certSource,
@@ -183,3 +192,54 @@ func TestMaximumConnectionsCount(t *testing.T) {
183192
t.Errorf("client should have dialed exactly the maximum of %d connections (%d connections, %d dials)", maxConnections, numConnections, dials)
184193
}
185194
}
195+
196+
func TestRefreshTimer(t *testing.T) {
197+
refreshCertBuffer = time.Millisecond * 10
198+
timeToExpire := time.Millisecond * 500
199+
b := &fakeCerts{}
200+
c := &Client{
201+
Certs: &blockingCertSource{
202+
map[string]*fakeCerts{
203+
instance: b,
204+
},
205+
time.Now().Add(timeToExpire),
206+
},
207+
Dialer: func(string, string) (net.Conn, error) {
208+
return nil, errFakeDial
209+
},
210+
RefreshCfgThrottle: 20 * time.Millisecond,
211+
}
212+
213+
// Call Dial to cache the cert.
214+
if _, err := c.Dial(instance); err != errFakeDial {
215+
t.Errorf("unexpected error: %v", err)
216+
}
217+
218+
c.cacheL.Lock()
219+
cached, ok := c.cfgCache[instance]
220+
c.cacheL.Unlock()
221+
if !ok {
222+
t.Error("expected instance to be cached")
223+
}
224+
waitTil := time.After(timeToExpire + (10 * time.Millisecond))
225+
loop:
226+
for {
227+
select {
228+
case <-waitTil:
229+
break loop
230+
default:
231+
time.Sleep(100 * time.Millisecond)
232+
}
233+
}
234+
235+
// Verify cert was refreshed in the background, without calling Dial again.
236+
c.cacheL.Lock()
237+
refreshed, ok := c.cfgCache[instance]
238+
c.cacheL.Unlock()
239+
if !ok {
240+
t.Error("expected instance to be cached")
241+
}
242+
if !refreshed.lastRefreshed.After(cached.lastRefreshed) {
243+
t.Error("expected cert to be refreshed.")
244+
}
245+
}

0 commit comments

Comments
 (0)