Skip to content

Commit 697102f

Browse files
authored
Merge pull request #81 from netlify/safedial
Add SafeDialContext method
2 parents 5159788 + b301cf8 commit 697102f

File tree

4 files changed

+269
-0
lines changed

4 files changed

+269
-0
lines changed

graceful/graceful.go

Lines changed: 107 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,107 @@
1+
package graceful
2+
3+
import (
4+
"context"
5+
"errors"
6+
"net"
7+
"net/http"
8+
"os"
9+
"os/signal"
10+
"sync/atomic"
11+
"syscall"
12+
"time"
13+
14+
"github.com/sirupsen/logrus"
15+
)
16+
17+
var DefaultShutdownTimeout = time.Second * 60
18+
19+
const shutdown uint32 = 1
20+
21+
type GracefulServer struct {
22+
server *http.Server
23+
listener net.Listener
24+
log *logrus.Entry
25+
26+
exit chan struct{}
27+
28+
URL string
29+
state uint32
30+
ShutdownTimeout time.Duration
31+
ShutdownError error
32+
}
33+
34+
func NewGracefulServer(handler http.Handler, log *logrus.Entry) *GracefulServer {
35+
log.Warn("NewGracefulServer is deprecated, see https://github.com/netlify/netlify-commons/pull/72")
36+
return &GracefulServer{
37+
server: &http.Server{Handler: handler},
38+
log: log,
39+
listener: nil,
40+
exit: make(chan struct{}),
41+
ShutdownTimeout: DefaultShutdownTimeout,
42+
}
43+
}
44+
45+
func (svr *GracefulServer) Bind(addr string) error {
46+
l, err := net.Listen("tcp", addr)
47+
if err != nil {
48+
return err
49+
}
50+
svr.URL = "http://" + l.Addr().String()
51+
svr.listener = l
52+
return nil
53+
}
54+
55+
func (svr *GracefulServer) Listen() error {
56+
go svr.listenForShutdownSignal()
57+
serveErr := svr.server.Serve(svr.listener)
58+
if serveErr != http.ErrServerClosed {
59+
svr.log.WithError(serveErr).Warn("Error while running server")
60+
return serveErr
61+
}
62+
63+
<-svr.exit
64+
65+
return svr.ShutdownError
66+
}
67+
68+
func (svr *GracefulServer) listenForShutdownSignal() {
69+
c := make(chan os.Signal, 1)
70+
signal.Notify(c, os.Interrupt, syscall.SIGTERM, syscall.SIGINT)
71+
sig := <-c
72+
svr.log.Infof("Triggering shutdown from signal %s", sig)
73+
svr.Shutdown()
74+
}
75+
76+
func (svr *GracefulServer) ListenAndServe(addr string) error {
77+
if svr.listener != nil {
78+
return errors.New("The listener has already started, don't call Bind first")
79+
}
80+
if err := svr.Bind(addr); err != nil {
81+
return err
82+
}
83+
84+
return svr.Listen()
85+
}
86+
87+
func (svr *GracefulServer) Shutdown() error {
88+
if atomic.SwapUint32(&svr.state, shutdown) == shutdown {
89+
svr.log.Debug("Calling shutdown on already shutdown server, ignoring")
90+
return nil
91+
}
92+
93+
ctx, cancel := context.WithTimeout(context.Background(), svr.ShutdownTimeout)
94+
defer cancel()
95+
96+
svr.log.Infof("Triggering shutdown, in at most %s ", svr.ShutdownTimeout.String())
97+
shutErr := svr.server.Shutdown(ctx)
98+
if shutErr == context.DeadlineExceeded {
99+
svr.log.WithError(shutErr).Warnf("Forcing a shutdown after waiting %s", svr.ShutdownTimeout.String())
100+
shutErr = svr.server.Close()
101+
}
102+
103+
svr.ShutdownError = shutErr
104+
close(svr.exit)
105+
106+
return shutErr
107+
}

graceful/graceful_test.go

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
package graceful
2+
3+
import (
4+
"fmt"
5+
"net/http"
6+
"net/url"
7+
"reflect"
8+
"strings"
9+
"testing"
10+
"time"
11+
12+
"github.com/sirupsen/logrus"
13+
"github.com/stretchr/testify/assert"
14+
"github.com/stretchr/testify/require"
15+
)
16+
17+
func TestStartAndStop(t *testing.T) {
18+
orTimeout := func(c chan bool, i int, msg string) {
19+
select {
20+
case <-c:
21+
case <-time.After(time.Duration(i) * time.Second):
22+
assert.FailNow(t, msg)
23+
}
24+
}
25+
26+
gotRequest := make(chan bool)
27+
clearRequest := make(chan bool)
28+
stoppedServer := make(chan bool)
29+
finishedListening := make(chan bool)
30+
31+
var finished bool
32+
33+
oh := func(w http.ResponseWriter, r *http.Request) {
34+
// trigger that we got the request
35+
close(gotRequest)
36+
// wait for a clear on that request
37+
orTimeout(clearRequest, 2, "waiting for request to be cleared")
38+
finished = true
39+
}
40+
41+
svr := NewGracefulServer(http.HandlerFunc(oh), logrus.WithField("testing", true))
42+
require.NoError(t, svr.Bind("127.0.0.1:0"))
43+
44+
go func() {
45+
assert.NoError(t, svr.Listen())
46+
close(finishedListening)
47+
}()
48+
49+
// make a request
50+
go func() {
51+
rsp, err := http.Get(svr.URL + "/something")
52+
require.NoError(t, err)
53+
assert.Equal(t, http.StatusOK, rsp.StatusCode)
54+
}()
55+
56+
// wait for the origin to get the request
57+
orTimeout(gotRequest, 1, "didn't get the original request in time")
58+
59+
// initate a shutdown
60+
go func() {
61+
assert.NoError(t, svr.Shutdown())
62+
close(stoppedServer)
63+
}()
64+
65+
<-time.After(time.Second)
66+
67+
// make a second request ~ should be bounced
68+
rsp, err := http.Get(svr.URL + "/something")
69+
switch e := err.(type) {
70+
case *url.Error:
71+
assert.True(t, strings.Contains(e.Error(), "connection refused"))
72+
default:
73+
assert.Fail(t, fmt.Sprintf("unknown type: %v", reflect.TypeOf(err)))
74+
}
75+
assert.Nil(t, rsp)
76+
77+
// finish the first request
78+
close(clearRequest)
79+
80+
// wait for server to close
81+
orTimeout(stoppedServer, 1, "didn't stop server in time")
82+
83+
assert.True(t, finished)
84+
orTimeout(finishedListening, 1, "didn't actually stop the server in time")
85+
}
86+
87+
func TestDoubleShutdown(t *testing.T) {
88+
logrus.SetLevel(logrus.DebugLevel)
89+
oh := func(w http.ResponseWriter, r *http.Request) {}
90+
svr := NewGracefulServer(http.HandlerFunc(oh), logrus.WithField("testing", true))
91+
92+
assert.NoError(t, svr.Shutdown())
93+
assert.NoError(t, svr.Shutdown())
94+
}

http/http.go

Lines changed: 52 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,52 @@
1+
package http
2+
3+
import (
4+
"context"
5+
"errors"
6+
"net"
7+
)
8+
9+
var privateIPBlocks []*net.IPNet
10+
11+
func init() {
12+
for _, cidr := range []string{
13+
"127.0.0.0/8", // IPv4 loopback
14+
"10.0.0.0/8", // RFC1918
15+
"100.64.0.0/10", // RFC6598
16+
"172.16.0.0/12", // RFC1918
17+
"192.0.0.0/24", // RFC6890
18+
"192.168.0.0/16", // RFC1918
19+
"169.254.0.0/16", // RFC3927
20+
"::1/128", // IPv6 loopback
21+
"fe80::/10", // IPv6 link-local
22+
"fc00::/7", // IPv6 unique local addr
23+
} {
24+
_, block, _ := net.ParseCIDR(cidr)
25+
privateIPBlocks = append(privateIPBlocks, block)
26+
}
27+
}
28+
29+
func isPrivateIP(ip net.IP) bool {
30+
for _, block := range privateIPBlocks {
31+
if block.Contains(ip) {
32+
return true
33+
}
34+
}
35+
return false
36+
}
37+
38+
func isLocalAddress(addr string) bool {
39+
ip := net.ParseIP(addr)
40+
return isPrivateIP(ip)
41+
}
42+
43+
// SafeDialContext exchanges a DialContext for a SafeDialContext that will never dial a reserved IP range
44+
func SafeDialContext(dialContext func(ctx context.Context, network, addr string) (net.Conn, error)) func(ctx context.Context, network, addr string) (net.Conn, error) {
45+
return func(ctx context.Context, network, addr string) (net.Conn, error) {
46+
if isLocalAddress(addr) {
47+
return nil, errors.New("Connection to local network address denied")
48+
}
49+
50+
return dialContext(ctx, network, addr)
51+
}
52+
}

http/http_test.go

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
package http
2+
3+
import (
4+
"testing"
5+
6+
"github.com/stretchr/testify/assert"
7+
)
8+
9+
func TestIsLocalAddress(t *testing.T) {
10+
assert.False(t, isLocalAddress("216.58.194.206"))
11+
assert.True(t, isLocalAddress("127.0.0.1"))
12+
assert.True(t, isLocalAddress("10.0.0.1"))
13+
assert.True(t, isLocalAddress("192.168.0.1"))
14+
assert.True(t, isLocalAddress("172.16.0.0"))
15+
assert.True(t, isLocalAddress("169.254.169.254"))
16+
}

0 commit comments

Comments
 (0)