Skip to content

Commit 4c9bdd3

Browse files
Add separate cache lock to allow other goroutines to continue to read from cache while another updates the cert
1 parent acd195b commit 4c9bdd3

File tree

1 file changed

+31
-21
lines changed

1 file changed

+31
-21
lines changed

proxy/proxy/client.go

Lines changed: 31 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -81,9 +81,12 @@ type Client struct {
8181

8282
// The cfgCache holds the most recent connection configuration keyed by
8383
// instance. Relevant functions are refreshCfg and cachedCfg. It is
84-
// protected by cfgL.
84+
// protected by cacheL.
8585
cfgCache map[string]cacheEntry
86-
cfgL sync.RWMutex
86+
cacheL sync.RWMutex
87+
88+
// cfgL prevents multiple goroutines from contacting the Cloud SQL API at once.
89+
cfgL sync.Mutex
8790

8891
// MaxConnections is the maximum number of connections to establish
8992
// before refusing new connections. 0 means no limit.
@@ -161,24 +164,28 @@ func (c *Client) refreshCfg(instance string) (addr string, cfg *tls.Config, err
161164
throttle = DefaultRefreshCfgThrottle
162165
}
163166

164-
if old := c.cfgCache[instance]; time.Since(old.lastRefreshed) < throttle {
167+
c.cacheL.Lock()
168+
if c.cfgCache == nil {
169+
c.cfgCache = make(map[string]cacheEntry)
170+
}
171+
old, oldok := c.cfgCache[instance]
172+
c.cacheL.Unlock()
173+
174+
if oldok && time.Since(old.lastRefreshed) < throttle {
165175
logging.Errorf("Throttling refreshCfg(%s): it was only called %v ago", instance, time.Since(old.lastRefreshed))
166176
// Refresh was called too recently, just reuse the result.
167177
return old.addr, old.cfg, old.err
168178
}
169179

170-
if c.cfgCache == nil {
171-
c.cfgCache = make(map[string]cacheEntry)
172-
}
173-
174180
defer func() {
181+
c.cacheL.Lock()
175182
c.cfgCache[instance] = cacheEntry{
176183
lastRefreshed: time.Now(),
177-
178-
err: err,
179-
addr: addr,
180-
cfg: cfg,
184+
err: err,
185+
addr: addr,
186+
cfg: cfg,
181187
}
188+
c.cacheL.Unlock()
182189
}()
183190

184191
mycert, err := c.Certs.Local(instance)
@@ -198,13 +205,24 @@ func (c *Client) refreshCfg(instance string) (addr string, cfg *tls.Config, err
198205
Certificates: []tls.Certificate{mycert},
199206
RootCAs: certs,
200207
}
208+
209+
// Refresh cert 5 minutes before it expires.
210+
timeToRefresh := cfg.Certificates[0].Leaf.NotAfter.Sub(time.Now()) - refreshCertBuffer
211+
if timeToRefresh > 0 {
212+
go func() {
213+
<-time.After(timeToRefresh)
214+
logging.Verbosef("Cert for instance %s will expire soon, refreshing now.", instance)
215+
c.refreshCfg(instance)
216+
}()
217+
}
218+
201219
return fmt.Sprintf("%s:%d", addr, c.Port), cfg, nil
202220
}
203221

204222
func (c *Client) cachedCfg(instance string) (string, *tls.Config) {
205-
c.cfgL.RLock()
223+
c.cacheL.RLock()
206224
ret, ok := c.cfgCache[instance]
207-
c.cfgL.RUnlock()
225+
c.cacheL.RUnlock()
208226

209227
// Don't waste time returning an expired/invalid cert.
210228
if !ok || ret.err != nil || time.Now().After(ret.cfg.Certificates[0].Leaf.NotAfter) {
@@ -229,14 +247,6 @@ func (c *Client) Dial(instance string) (net.Conn, error) {
229247
return nil, err
230248
}
231249

232-
// Refresh cert 5 minutes before it expires.
233-
timer := time.NewTimer(cfg.Certificates[0].Leaf.NotAfter.Sub(time.Now()) - refreshCertBuffer)
234-
go func() {
235-
<-timer.C
236-
logging.Verbosef("Cert for instance %s will expire soon, refreshing now.", instance)
237-
c.refreshCfg(instance)
238-
}()
239-
240250
return c.tryConnect(addr, cfg)
241251
}
242252

0 commit comments

Comments
 (0)