Skip to content

Commit b858be3

Browse files
authored
Merge pull request GoogleCloudPlatform#158 from GoogleCloudPlatform/revert-142-master
Revert "Automatically refresh SSL certs before they expire"
2 parents acb92ce + 544cf13 commit b858be3

File tree

2 files changed

+26
-113
lines changed

2 files changed

+26
-113
lines changed

proxy/proxy/client.go

Lines changed: 19 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,10 @@ const (
3232
keepAlivePeriod = time.Minute
3333
)
3434

35-
var (
36-
// errNotCached is returned when the instance was not found in the Client's
37-
// cache. It is an internal detail and is not actually ever returned to the
38-
// user.
39-
errNotCached = errors.New("instance was not found in cache")
40-
refreshCertBuffer = 5 * time.Minute
41-
)
35+
// errNotCached is returned when the instance was not found in the Client's
36+
// cache. It is an internal detail and is not actually ever returned to the
37+
// user.
38+
var errNotCached = errors.New("instance was not found in cache")
4239

4340
// Conn represents a connection from a client to a specific instance.
4441
type Conn struct {
@@ -81,12 +78,9 @@ type Client struct {
8178

8279
// The cfgCache holds the most recent connection configuration keyed by
8380
// instance. Relevant functions are refreshCfg and cachedCfg. It is
84-
// protected by cacheL.
81+
// protected by cfgL.
8582
cfgCache map[string]cacheEntry
86-
cacheL sync.Mutex
87-
88-
// refreshCfgL prevents multiple goroutines from contacting the Cloud SQL API at once.
89-
refreshCfgL sync.Mutex
83+
cfgL sync.RWMutex
9084

9185
// MaxConnections is the maximum number of connections to establish
9286
// before refusing new connections. 0 means no limit.
@@ -156,39 +150,32 @@ func (c *Client) handleConn(conn Conn) {
156150
// address as well as construct a new tls.Config to connect to the instance. It
157151
// caches the result.
158152
func (c *Client) refreshCfg(instance string) (addr string, cfg *tls.Config, err error) {
159-
c.refreshCfgL.Lock()
160-
defer c.refreshCfgL.Unlock()
153+
c.cfgL.Lock()
154+
defer c.cfgL.Unlock()
161155

162156
throttle := c.RefreshCfgThrottle
163157
if throttle == 0 {
164158
throttle = DefaultRefreshCfgThrottle
165159
}
166160

167-
c.cacheL.Lock()
168-
if c.cfgCache == nil {
169-
c.cfgCache = make(map[string]cacheEntry)
170-
}
171-
old, oldok := c.cfgCache[instance]
172-
c.cacheL.Unlock()
173-
174-
if oldok && time.Since(old.lastRefreshed) < throttle {
161+
if old := c.cfgCache[instance]; time.Since(old.lastRefreshed) < throttle {
175162
logging.Errorf("Throttling refreshCfg(%s): it was only called %v ago", instance, time.Since(old.lastRefreshed))
176163
// Refresh was called too recently, just reuse the result.
177164
return old.addr, old.cfg, old.err
178165
}
179166

167+
if c.cfgCache == nil {
168+
c.cfgCache = make(map[string]cacheEntry)
169+
}
170+
180171
defer func() {
181-
if err != nil && oldok {
182-
return
183-
}
184-
c.cacheL.Lock()
185172
c.cfgCache[instance] = cacheEntry{
186173
lastRefreshed: time.Now(),
187-
err: err,
188-
addr: addr,
189-
cfg: cfg,
174+
175+
err: err,
176+
addr: addr,
177+
cfg: cfg,
190178
}
191-
c.cacheL.Unlock()
192179
}()
193180

194181
mycert, err := c.Certs.Local(instance)
@@ -208,26 +195,13 @@ func (c *Client) refreshCfg(instance string) (addr string, cfg *tls.Config, err
208195
Certificates: []tls.Certificate{mycert},
209196
RootCAs: certs,
210197
}
211-
212-
// Refresh cert 5 minutes before it expires.
213-
timeToRefresh := cfg.Certificates[0].Leaf.NotAfter.Sub(time.Now()) - refreshCertBuffer
214-
if timeToRefresh > 0 {
215-
go func() {
216-
<-time.After(timeToRefresh)
217-
logging.Verbosef("Cert for instance %s will expire soon, refreshing now.", instance)
218-
if _, _, err := c.refreshCfg(instance); err != nil {
219-
logging.Errorf("couldn't connect to %q: %v", instance, err)
220-
}
221-
}()
222-
}
223-
224198
return fmt.Sprintf("%s:%d", addr, c.Port), cfg, nil
225199
}
226200

227201
func (c *Client) cachedCfg(instance string) (string, *tls.Config) {
228-
c.cacheL.Lock()
202+
c.cfgL.RLock()
229203
ret, ok := c.cfgCache[instance]
230-
c.cacheL.Unlock()
204+
c.cfgL.RUnlock()
231205

232206
// Don't waste time returning an expired/invalid cert.
233207
if !ok || ret.err != nil || time.Now().After(ret.cfg.Certificates[0].Leaf.NotAfter) {
@@ -251,7 +225,6 @@ func (c *Client) Dial(instance string) (net.Conn, error) {
251225
if err != nil {
252226
return nil, err
253227
}
254-
255228
return c.tryConnect(addr, cfg)
256229
}
257230

proxy/proxy/client_test.go

Lines changed: 7 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -28,19 +28,15 @@ import (
2828

2929
const instance = "instance-name"
3030

31-
var (
32-
errFakeDial = errors.New("this error is returned by the dialer")
33-
forever = time.Date(9999, 0, 0, 0, 0, 0, 0, time.UTC)
34-
)
31+
var errFakeDial = errors.New("this error is returned by the dialer")
3532

3633
type fakeCerts struct {
3734
sync.Mutex
3835
called int
3936
}
4037

4138
type blockingCertSource struct {
42-
values map[string]*fakeCerts
43-
validUntil time.Time
39+
values map[string]*fakeCerts
4440
}
4541

4642
func (cs *blockingCertSource) Local(instance string) (tls.Certificate, error) {
@@ -52,10 +48,11 @@ func (cs *blockingCertSource) Local(instance string) (tls.Certificate, error) {
5248
v.called++
5349
v.Unlock()
5450

51+
validUntil, _ := time.Parse("2006", "9999")
5552
// Returns a cert which is valid forever.
5653
return tls.Certificate{
5754
Leaf: &x509.Certificate{
58-
NotAfter: cs.validUntil,
55+
NotAfter: validUntil,
5956
},
6057
}, nil
6158
}
@@ -70,9 +67,7 @@ func TestClientCache(t *testing.T) {
7067
Certs: &blockingCertSource{
7168
map[string]*fakeCerts{
7269
instance: b,
73-
},
74-
forever,
75-
},
70+
}},
7671
Dialer: func(string, string) (net.Conn, error) {
7772
return nil, errFakeDial
7873
},
@@ -97,9 +92,7 @@ func TestConcurrentRefresh(t *testing.T) {
9792
Certs: &blockingCertSource{
9893
map[string]*fakeCerts{
9994
instance: b,
100-
},
101-
forever,
102-
},
95+
}},
10396
Dialer: func(string, string) (net.Conn, error) {
10497
return nil, errFakeDial
10598
},
@@ -138,9 +131,7 @@ func TestMaximumConnectionsCount(t *testing.T) {
138131

139132
b := &fakeCerts{}
140133
certSource := blockingCertSource{
141-
map[string]*fakeCerts{},
142-
forever,
143-
}
134+
map[string]*fakeCerts{}}
144135
firstDialExited := make(chan struct{})
145136
c := &Client{
146137
Certs: &certSource,
@@ -192,54 +183,3 @@ func TestMaximumConnectionsCount(t *testing.T) {
192183
t.Errorf("client should have dialed exactly the maximum of %d connections (%d connections, %d dials)", maxConnections, numConnections, dials)
193184
}
194185
}
195-
196-
func TestRefreshTimer(t *testing.T) {
197-
refreshCertBuffer = time.Millisecond * 10
198-
timeToExpire := time.Millisecond * 500
199-
b := &fakeCerts{}
200-
c := &Client{
201-
Certs: &blockingCertSource{
202-
map[string]*fakeCerts{
203-
instance: b,
204-
},
205-
time.Now().Add(timeToExpire),
206-
},
207-
Dialer: func(string, string) (net.Conn, error) {
208-
return nil, errFakeDial
209-
},
210-
RefreshCfgThrottle: 20 * time.Millisecond,
211-
}
212-
213-
// Call Dial to cache the cert.
214-
if _, err := c.Dial(instance); err != errFakeDial {
215-
t.Errorf("unexpected error: %v", err)
216-
}
217-
218-
c.cacheL.Lock()
219-
cached, ok := c.cfgCache[instance]
220-
c.cacheL.Unlock()
221-
if !ok {
222-
t.Error("expected instance to be cached")
223-
}
224-
waitTil := time.After(timeToExpire + (10 * time.Millisecond))
225-
loop:
226-
for {
227-
select {
228-
case <-waitTil:
229-
break loop
230-
default:
231-
time.Sleep(100 * time.Millisecond)
232-
}
233-
}
234-
235-
// Verify cert was refreshed in the background, without calling Dial again.
236-
c.cacheL.Lock()
237-
refreshed, ok := c.cfgCache[instance]
238-
c.cacheL.Unlock()
239-
if !ok {
240-
t.Error("expected instance to be cached")
241-
}
242-
if !refreshed.lastRefreshed.After(cached.lastRefreshed) {
243-
t.Error("expected cert to be refreshed.")
244-
}
245-
}

0 commit comments

Comments
 (0)