Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/03-how-to-add-new-route-option.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,11 @@ applications:
- route: example2.com
options:
loadbalancing: least-connection
- route: example3.com
options:
loadbalancing: hash
hash_header: tenant-id
hash_balance: 1.25
```

**NOTE**: In the implementation, the `options` property of a route represents per-route features.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,20 @@ func (rt *roundTripper) RoundTrip(originalRequest *http.Request) (*http.Response
stickyEndpointID, mustBeSticky := handlers.GetStickySession(request, rt.config.StickySessionCookieNames, rt.config.StickySessionsForAuthNegotiate)
numberOfEndpoints := reqInfo.RoutePool.NumEndpoints()
iter := reqInfo.RoutePool.Endpoints(rt.logger, stickyEndpointID, mustBeSticky, rt.config.LoadBalanceAZPreference, rt.config.Zone)
if reqInfo.RoutePool.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB {
if reqInfo.RoutePool.HashRoutingProperties == nil {
rt.logger.Error("hash-routing-properties-nil", slog.String("host", reqInfo.RoutePool.Host()))

} else {
headerName := reqInfo.RoutePool.HashRoutingProperties.Header
headerValue := request.Header.Get(headerName)
if headerValue != "" {
iter.(*route.HashBased).HeaderValue = headerValue
} else {
iter = reqInfo.RoutePool.FallBackToDefaultLoadBalancing(rt.config.LoadBalance, rt.logger, stickyEndpointID, mustBeSticky, rt.config.LoadBalanceAZPreference, rt.config.Zone)
}
}
}

// The selectEndpointErr needs to be tracked separately. If we get an error
// while selecting an endpoint we might just have run out of routes. In
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"io"
"math/rand"
"net"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -1700,6 +1701,167 @@ var _ = Describe("ProxyRoundTripper", func() {
})
})

Context("when load-balancing strategy is set to hash-based routing", func() {
JustBeforeEach(func() {
for i := 1; i <= 3; i++ {
endpoint = route.NewEndpoint(&route.EndpointOpts{
AppId: fmt.Sprintf("appID%d", i),
Host: fmt.Sprintf("%d.%d.%d.%d", i, i, i, i),
Port: 9090,
PrivateInstanceId: fmt.Sprintf("instanceID%d", i),
PrivateInstanceIndex: fmt.Sprintf("%d", i),
AvailabilityZone: AZ,
LoadBalancingAlgorithm: config.LOAD_BALANCE_HB,
HashHeaderName: "X-Hash",
})

_ = routePool.Put(endpoint)
Expect(routePool.HashLookupTable).ToNot(BeNil())

}
})

It("routes requests with same hash header value to the same endpoint", func() {
req.Header.Set("X-Hash", "value")
reqInfo, err := handlers.ContextRequestInfo(req)
Expect(err).ToNot(HaveOccurred())
reqInfo.RoutePool = routePool

var selectedEndpoints []*route.Endpoint

// Make multiple requests with the same hash value
for i := 0; i < 5; i++ {
_, err = proxyRoundTripper.RoundTrip(req)
Expect(err).NotTo(HaveOccurred())
selectedEndpoints = append(selectedEndpoints, reqInfo.RouteEndpoint)
}

// All requests should go to the same endpoint
firstEndpoint := selectedEndpoints[0]
for _, ep := range selectedEndpoints[1:] {
Expect(ep.PrivateInstanceId).To(Equal(firstEndpoint.PrivateInstanceId))
}
})

It("routes requests with different hash header values to potentially different endpoints", func() {
reqInfo, err := handlers.ContextRequestInfo(req)
Expect(err).ToNot(HaveOccurred())
reqInfo.RoutePool = routePool

endpointDistribution := make(map[string]int)

// Make requests with different hash values
for i := 0; i < 10; i++ {
req.Header.Set("X-Hash", fmt.Sprintf("value-%d", i))
_, err = proxyRoundTripper.RoundTrip(req)
Expect(err).NotTo(HaveOccurred())
endpointDistribution[reqInfo.RouteEndpoint.PrivateInstanceId]++
}

// Should distribute across multiple endpoints (not all to one)
Expect(len(endpointDistribution)).To(BeNumerically(">", 1))
})

It("falls back to default load balancing algorithm when hash header is missing", func() {
reqInfo, err := handlers.ContextRequestInfo(req)
Expect(err).ToNot(HaveOccurred())

reqInfo.RoutePool = routePool

_, err = proxyRoundTripper.RoundTrip(req)
Expect(err).NotTo(HaveOccurred())

infoLogs := logger.Lines(zap.InfoLevel)
count := 0
for i := 0; i < len(infoLogs); i++ {
if strings.Contains(infoLogs[i], "hash-based-routing-header-not-found") {
count++
}
}
Expect(count).To(Equal(1))
// Verify it still selects an endpoint
Expect(reqInfo.RouteEndpoint).ToNot(BeNil())
})

Context("when sticky session cookies (JSESSIONID and VCAP_ID) are on the request", func() {
var (
sessionCookie *http.Cookie
cookies []*http.Cookie
)

JustBeforeEach(func() {
sessionCookie = &http.Cookie{
Name: StickyCookieKey, //JSESSIONID
}
transport.RoundTripStub = func(req *http.Request) (*http.Response, error) {
resp := &http.Response{StatusCode: http.StatusTeapot, Header: make(map[string][]string)}
//Attach the same JSESSIONID on to the response if it exists on the request

if len(req.Cookies()) > 0 {
for _, cookie := range req.Cookies() {
if cookie.Name == StickyCookieKey {
resp.Header.Add(round_tripper.CookieHeader, cookie.String())
return resp, nil
}
}
}

sessionCookie.Value, _ = uuid.GenerateUUID()
resp.Header.Add(round_tripper.CookieHeader, sessionCookie.String())
return resp, nil
}
resp, err := proxyRoundTripper.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())

cookies = resp.Cookies()
Expect(cookies).To(HaveLen(2))

})

Context("when there is a JSESSIONID and __VCAP_ID__ set on the request", func() {
It("will always route to the instance specified with the __VCAP_ID__ cookie", func() {

// Generate 20 random values for the hash header, so chance that all go to instanceID1
// by accident is 0.33^20
for i := 0; i < 20; i++ {
randomStr := make([]byte, 8)
for j := range randomStr {
randomStr[j] = byte('a' + rand.Intn(26))
}

req.Header.Set("X-Hash", string(randomStr))
reqInfo, err := handlers.ContextRequestInfo(req)
req.AddCookie(&http.Cookie{Name: round_tripper.VcapCookieId, Value: "instanceID1"})
req.AddCookie(&http.Cookie{Name: StickyCookieKey, Value: "abc"})

Expect(err).ToNot(HaveOccurred())
reqInfo.RoutePool = routePool

resp, err := proxyRoundTripper.RoundTrip(req)
Expect(err).ToNot(HaveOccurred())

new_cookies := resp.Cookies()
Expect(new_cookies).To(HaveLen(2))

for _, cookie := range new_cookies {
Expect(cookie.Name).To(SatisfyAny(
Equal(StickyCookieKey),
Equal(round_tripper.VcapCookieId),
))
if cookie.Name == StickyCookieKey {
Expect(cookie.Value).To(Equal("abc"))
} else {
Expect(cookie.Value).To(Equal("instanceID1"))
}
}

}

})
})
})
})

Context("when endpoint timeout is not 0", func() {
var reqCh chan *http.Request
BeforeEach(func() {
Expand Down
141 changes: 141 additions & 0 deletions src/code.cloudfoundry.org/gorouter/route/hash_based.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
package route

import (
"context"
"errors"
"log/slog"
"sync"

log "code.cloudfoundry.org/gorouter/logger"
)

// HashBased load balancing algorithm distributes requests based on a hash of a specific header value.
// The sticky session cookie has precedence over hash-based routing and the request should be routed to the instance stored in the cookie.
// If requests do not contain the hash-related header set configured for the hash-based route option, use the default load-balancing algorithm.
type HashBased struct {
lock *sync.Mutex

logger *slog.Logger
pool *EndpointPool
lastEndpoint *Endpoint

stickyEndpointID string
mustBeSticky bool

HeaderValue string
}

// NewHashBased initializes an endpoint iterator that selects endpoints based on a hash of a header value.
// The global properties locallyOptimistic and localAvailabilityZone will be ignored when using Hash-Based Routing.
func NewHashBased(logger *slog.Logger, p *EndpointPool, initial string, mustBeSticky bool, locallyOptimistic bool, localAvailabilityZone string) EndpointIterator {
return &HashBased{
logger: logger,
pool: p,
lock: &sync.Mutex{},
stickyEndpointID: initial,
mustBeSticky: mustBeSticky,
}
}

// Next selects the next endpoint based on the hash of the header value.
// If a sticky session endpoint is available and not overloaded, it will be returned.
// If the request must be sticky and the sticky endpoint is unavailable or overloaded, nil will be returned.
// If no sticky session is present, the endpoint will be selected based on the hash of the header value.
// It returns the same endpoint for the same header value consistently.
// If the hash lookup fails or the endpoint is not found, nil will be returned.
func (h *HashBased) Next(attempt int) *Endpoint {
h.lock.Lock()
defer h.lock.Unlock()

e := h.findEndpointIfStickySession()
if e == nil && h.mustBeSticky {
return nil
}

if e != nil {
h.lastEndpoint = e
return e
}

if h.pool.HashLookupTable == nil {
h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("Lookup table is empty")))
return nil
}

id, err := h.pool.HashLookupTable.Get(h.HeaderValue)

if err != nil {
h.logger.Error(
"hash-based-routing-failed",
slog.String("host", h.pool.host),
log.ErrAttr(err),
)
return nil
}

h.logger.Debug(
"hash-based-routing",
slog.String("hash header value", h.HeaderValue),
slog.String("endpoint-id", id),
)

endpointElem := h.pool.findById(id)
if endpointElem == nil {
h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("Endpoint not found in pool")), slog.String("endpoint-id", id))
return nil
}

return endpointElem.endpoint
}

// findEndpointIfStickySession checks if there is a sticky session endpoint and returns it if available.
// If the sticky session endpoint is overloaded, returns nil.
func (h *HashBased) findEndpointIfStickySession() *Endpoint {
var e *endpointElem
if h.stickyEndpointID != "" {
e = h.pool.findById(h.stickyEndpointID)
if e != nil && e.isOverloaded() {
if h.mustBeSticky {
if h.logger.Enabled(context.Background(), slog.LevelDebug) {
h.logger.Debug("endpoint-overloaded-but-request-must-be-sticky", e.endpoint.ToLogData()...)
}
return nil
}
e = nil
}

if e == nil && h.mustBeSticky {
h.logger.Debug("endpoint-missing-but-request-must-be-sticky", slog.String("requested-endpoint", h.stickyEndpointID))
return nil
}

if !h.mustBeSticky {
h.logger.Debug("endpoint-missing-choosing-alternate", slog.String("requested-endpoint", h.stickyEndpointID))
h.stickyEndpointID = ""
}
}

if e != nil {
e.RLock()
defer e.RUnlock()
return e.endpoint
}
return nil
}

// EndpointFailed notifies the endpoint pool that the last selected endpoint has failed.
func (h *HashBased) EndpointFailed(err error) {
if h.lastEndpoint != nil {
h.pool.EndpointFailed(h.lastEndpoint, err)
}
}

// PreRequest increments the in-flight request count for the selected endpoint from current Gorouter.
func (h *HashBased) PreRequest(e *Endpoint) {
e.Stats.NumberConnections.Increment()
}

// PostRequest decrements the in-flight request count for the selected endpoint from current Gorouter.
func (h *HashBased) PostRequest(e *Endpoint) {
e.Stats.NumberConnections.Decrement()
}
Loading