@@ -32,10 +32,13 @@ const (
3232 keepAlivePeriod = time .Minute
3333)
3434
35- // errNotCached is returned when the instance was not found in the Client's
36- // cache. It is an internal detail and is not actually ever returned to the
37- // user.
38- var errNotCached = errors .New ("instance was not found in cache" )
35+ var (
36+ // errNotCached is returned when the instance was not found in the Client's
37+ // cache. It is an internal detail and is not actually ever returned to the
38+ // user.
39+ errNotCached = errors .New ("instance was not found in cache" )
40+ refreshCertBuffer = 5 * time .Minute
41+ )
3942
4043// Conn represents a connection from a client to a specific instance.
4144type Conn struct {
@@ -78,9 +81,12 @@ type Client struct {
7881
7982 // The cfgCache holds the most recent connection configuration keyed by
8083 // instance. Relevant functions are refreshCfg and cachedCfg. It is
81- // protected by cfgL .
84+ // protected by cacheL .
8285 cfgCache map [string ]cacheEntry
83- cfgL sync.RWMutex
86+ cacheL sync.Mutex
87+
88+ // refreshCfgL prevents multiple goroutines from contacting the Cloud SQL API at once.
89+ refreshCfgL sync.Mutex
8490
8591 // MaxConnections is the maximum number of connections to establish
8692 // before refusing new connections. 0 means no limit.
@@ -150,32 +156,39 @@ func (c *Client) handleConn(conn Conn) {
150156// address as well as construct a new tls.Config to connect to the instance. It
151157// caches the result.
152158func (c * Client ) refreshCfg (instance string ) (addr string , cfg * tls.Config , err error ) {
153- c .cfgL .Lock ()
154- defer c .cfgL .Unlock ()
159+ c .refreshCfgL .Lock ()
160+ defer c .refreshCfgL .Unlock ()
155161
156162 throttle := c .RefreshCfgThrottle
157163 if throttle == 0 {
158164 throttle = DefaultRefreshCfgThrottle
159165 }
160166
161- if old := c .cfgCache [instance ]; time .Since (old .lastRefreshed ) < throttle {
167+ c .cacheL .Lock ()
168+ if c .cfgCache == nil {
169+ c .cfgCache = make (map [string ]cacheEntry )
170+ }
171+ old , oldok := c .cfgCache [instance ]
172+ c .cacheL .Unlock ()
173+
174+ if oldok && time .Since (old .lastRefreshed ) < throttle {
162175 logging .Errorf ("Throttling refreshCfg(%s): it was only called %v ago" , instance , time .Since (old .lastRefreshed ))
163176 // Refresh was called too recently, just reuse the result.
164177 return old .addr , old .cfg , old .err
165178 }
166179
167- if c .cfgCache == nil {
168- c .cfgCache = make (map [string ]cacheEntry )
169- }
170-
171180 defer func () {
181+ if err != nil && oldok {
182+ return
183+ }
184+ c .cacheL .Lock ()
172185 c .cfgCache [instance ] = cacheEntry {
173186 lastRefreshed : time .Now (),
174-
175- err : err ,
176- addr : addr ,
177- cfg : cfg ,
187+ err : err ,
188+ addr : addr ,
189+ cfg : cfg ,
178190 }
191+ c .cacheL .Unlock ()
179192 }()
180193
181194 mycert , err := c .Certs .Local (instance )
@@ -195,13 +208,26 @@ func (c *Client) refreshCfg(instance string) (addr string, cfg *tls.Config, err
195208 Certificates : []tls.Certificate {mycert },
196209 RootCAs : certs ,
197210 }
211+
212+ // Refresh cert 5 minutes before it expires.
213+ timeToRefresh := cfg .Certificates [0 ].Leaf .NotAfter .Sub (time .Now ()) - refreshCertBuffer
214+ if timeToRefresh > 0 {
215+ go func () {
216+ <- time .After (timeToRefresh )
217+ logging .Verbosef ("Cert for instance %s will expire soon, refreshing now." , instance )
218+ if _ , _ , err := c .refreshCfg (instance ); err != nil {
219+ logging .Errorf ("couldn't connect to %q: %v" , instance , err )
220+ }
221+ }()
222+ }
223+
198224 return fmt .Sprintf ("%s:%d" , addr , c .Port ), cfg , nil
199225}
200226
201227func (c * Client ) cachedCfg (instance string ) (string , * tls.Config ) {
202- c .cfgL . RLock ()
228+ c .cacheL . Lock ()
203229 ret , ok := c .cfgCache [instance ]
204- c .cfgL . RUnlock ()
230+ c .cacheL . Unlock ()
205231
206232 // Don't waste time returning an expired/invalid cert.
207233 if ! ok || ret .err != nil || time .Now ().After (ret .cfg .Certificates [0 ].Leaf .NotAfter ) {
@@ -225,6 +251,7 @@ func (c *Client) Dial(instance string) (net.Conn, error) {
225251 if err != nil {
226252 return nil , err
227253 }
254+
228255 return c .tryConnect (addr , cfg )
229256}
230257
0 commit comments