From 0ee521dd450ce312d491e4c67caa012e835ef981 Mon Sep 17 00:00:00 2001 From: Alexander Nicke Date: Fri, 29 Aug 2025 12:13:37 +0200 Subject: [PATCH 1/2] Implement hash-based routing (#505) This commit provides the basic implementation for hash-based routing. It does not consider the balance factor yet. Co-authored-by: Clemens Hoffmann Co-authored-by: Tamara Boehm Co-authored-by: Soha Alboghdady --- docs/03-how-to-add-new-route-option.md | 5 + .../round_tripper/proxy_round_tripper.go | 14 + .../round_tripper/proxy_round_tripper_test.go | 162 +++++++++++ .../gorouter/route/hash_based.go | 141 ++++++++++ .../gorouter/route/hash_based_test.go | 155 +++++++++++ .../gorouter/route/maglev.go | 212 +++++++++++++++ .../gorouter/route/maglev_test.go | 257 ++++++++++++++++++ .../gorouter/route/pool.go | 79 +++++- .../gorouter/route/pool_test.go | 40 +++ 9 files changed, 1061 insertions(+), 4 deletions(-) create mode 100644 src/code.cloudfoundry.org/gorouter/route/hash_based.go create mode 100644 src/code.cloudfoundry.org/gorouter/route/hash_based_test.go create mode 100644 src/code.cloudfoundry.org/gorouter/route/maglev.go create mode 100644 src/code.cloudfoundry.org/gorouter/route/maglev_test.go diff --git a/docs/03-how-to-add-new-route-option.md b/docs/03-how-to-add-new-route-option.md index 39ce25447..308213fcc 100644 --- a/docs/03-how-to-add-new-route-option.md +++ b/docs/03-how-to-add-new-route-option.md @@ -22,6 +22,11 @@ applications: - route: example2.com options: loadbalancing: least-connection + - route: example3.com + options: + loadbalancing: hash + hash_header: tenant-id + hash_balance: 1.25 ``` **NOTE**: In the implementation, the `options` property of a route represents per-route features. diff --git a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go index 88cfd20a5..84252262f 100644 --- a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go +++ b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper.go @@ -127,6 +127,20 @@ func (rt *roundTripper) RoundTrip(originalRequest *http.Request) (*http.Response stickyEndpointID, mustBeSticky := handlers.GetStickySession(request, rt.config.StickySessionCookieNames, rt.config.StickySessionsForAuthNegotiate) numberOfEndpoints := reqInfo.RoutePool.NumEndpoints() iter := reqInfo.RoutePool.Endpoints(rt.logger, stickyEndpointID, mustBeSticky, rt.config.LoadBalanceAZPreference, rt.config.Zone) + if reqInfo.RoutePool.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB { + if reqInfo.RoutePool.HashRoutingProperties == nil { + rt.logger.Error("hash-routing-properties-nil", slog.String("host", reqInfo.RoutePool.Host())) + + } else { + headerName := reqInfo.RoutePool.HashRoutingProperties.Header + headerValue := request.Header.Get(headerName) + if headerValue != "" { + iter.(*route.HashBased).HeaderValue = headerValue + } else { + iter = reqInfo.RoutePool.FallBackToDefaultLoadBalancing(rt.config.LoadBalance, rt.logger, stickyEndpointID, mustBeSticky, rt.config.LoadBalanceAZPreference, rt.config.Zone) + } + } + } // The selectEndpointErr needs to be tracked separately. If we get an error // while selecting an endpoint we might just have run out of routes. 
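For reference, the branch added to RoundTrip above boils down to: use the hash-based iterator only when the configured hash header is actually present on the request, and otherwise fall back to the gorouter-wide default algorithm via FallBackToDefaultLoadBalancing. Below is a minimal, self-contained sketch of that decision; the names poolConfig and pickStrategy are illustrative stand-ins and not identifiers from the patch, which works on *route.EndpointPool and route.EndpointIterator.

```go
// Illustrative only: pickStrategy condenses the branch added to RoundTrip above.
package main

import (
	"fmt"
	"net/http"
)

type poolConfig struct {
	algorithm   string // per-route algorithm, e.g. "hash"
	hashHeader  string // header configured via hash_header
	defaultAlgo string // gorouter-wide default, e.g. "round-robin"
}

// pickStrategy returns which iterator the proxy would build for a request.
func pickStrategy(cfg poolConfig, r *http.Request) string {
	if cfg.algorithm != "hash" {
		return cfg.algorithm
	}
	if v := r.Header.Get(cfg.hashHeader); v != "" {
		return "hash-based (HeaderValue=" + v + ")"
	}
	// Header missing: fall back to the default load-balancing algorithm.
	return cfg.defaultAlgo
}

func main() {
	cfg := poolConfig{algorithm: "hash", hashHeader: "tenant-id", defaultAlgo: "round-robin"}

	withHeader, _ := http.NewRequest("GET", "http://example3.com", nil)
	withHeader.Header.Set("tenant-id", "tenant-42")
	fmt.Println(pickStrategy(cfg, withHeader)) // hash-based (HeaderValue=tenant-42)

	withoutHeader, _ := http.NewRequest("GET", "http://example3.com", nil)
	fmt.Println(pickStrategy(cfg, withoutHeader)) // round-robin
}
```

Sticky-session handling is resolved before this check via GetStickySession, so a __VCAP_ID__ cookie still takes precedence over the hash header, as the tests below verify.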
In diff --git a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go index 9d270867c..6abe4d218 100644 --- a/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go +++ b/src/code.cloudfoundry.org/gorouter/proxy/round_tripper/proxy_round_tripper_test.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "io" + "math/rand" "net" "net/http" "net/http/httptest" @@ -1700,6 +1701,167 @@ var _ = Describe("ProxyRoundTripper", func() { }) }) + Context("when load-balancing strategy is set to hash-based routing", func() { + JustBeforeEach(func() { + for i := 1; i <= 3; i++ { + endpoint = route.NewEndpoint(&route.EndpointOpts{ + AppId: fmt.Sprintf("appID%d", i), + Host: fmt.Sprintf("%d.%d.%d.%d", i, i, i, i), + Port: 9090, + PrivateInstanceId: fmt.Sprintf("instanceID%d", i), + PrivateInstanceIndex: fmt.Sprintf("%d", i), + AvailabilityZone: AZ, + LoadBalancingAlgorithm: config.LOAD_BALANCE_HB, + HashHeaderName: "X-Hash", + }) + + _ = routePool.Put(endpoint) + Expect(routePool.HashLookupTable).ToNot(BeNil()) + + } + }) + + It("routes requests with same hash header value to the same endpoint", func() { + req.Header.Set("X-Hash", "value") + reqInfo, err := handlers.ContextRequestInfo(req) + Expect(err).ToNot(HaveOccurred()) + reqInfo.RoutePool = routePool + + var selectedEndpoints []*route.Endpoint + + // Make multiple requests with the same hash value + for i := 0; i < 5; i++ { + _, err = proxyRoundTripper.RoundTrip(req) + Expect(err).NotTo(HaveOccurred()) + selectedEndpoints = append(selectedEndpoints, reqInfo.RouteEndpoint) + } + + // All requests should go to the same endpoint + firstEndpoint := selectedEndpoints[0] + for _, ep := range selectedEndpoints[1:] { + Expect(ep.PrivateInstanceId).To(Equal(firstEndpoint.PrivateInstanceId)) + } + }) + + It("routes requests with different hash header values to potentially different endpoints", func() { + reqInfo, err := handlers.ContextRequestInfo(req) + Expect(err).ToNot(HaveOccurred()) + reqInfo.RoutePool = routePool + + endpointDistribution := make(map[string]int) + + // Make requests with different hash values + for i := 0; i < 10; i++ { + req.Header.Set("X-Hash", fmt.Sprintf("value-%d", i)) + _, err = proxyRoundTripper.RoundTrip(req) + Expect(err).NotTo(HaveOccurred()) + endpointDistribution[reqInfo.RouteEndpoint.PrivateInstanceId]++ + } + + // Should distribute across multiple endpoints (not all to one) + Expect(len(endpointDistribution)).To(BeNumerically(">", 1)) + }) + + It("falls back to default load balancing algorithm when hash header is missing", func() { + reqInfo, err := handlers.ContextRequestInfo(req) + Expect(err).ToNot(HaveOccurred()) + + reqInfo.RoutePool = routePool + + _, err = proxyRoundTripper.RoundTrip(req) + Expect(err).NotTo(HaveOccurred()) + + infoLogs := logger.Lines(zap.InfoLevel) + count := 0 + for i := 0; i < len(infoLogs); i++ { + if strings.Contains(infoLogs[i], "hash-based-routing-header-not-found") { + count++ + } + } + Expect(count).To(Equal(1)) + // Verify it still selects an endpoint + Expect(reqInfo.RouteEndpoint).ToNot(BeNil()) + }) + + Context("when sticky session cookies (JSESSIONID and VCAP_ID) are on the request", func() { + var ( + sessionCookie *http.Cookie + cookies []*http.Cookie + ) + + JustBeforeEach(func() { + sessionCookie = &http.Cookie{ + Name: StickyCookieKey, //JSESSIONID + } + transport.RoundTripStub = func(req *http.Request) (*http.Response, error) { + resp := 
&http.Response{StatusCode: http.StatusTeapot, Header: make(map[string][]string)} + //Attach the same JSESSIONID on to the response if it exists on the request + + if len(req.Cookies()) > 0 { + for _, cookie := range req.Cookies() { + if cookie.Name == StickyCookieKey { + resp.Header.Add(round_tripper.CookieHeader, cookie.String()) + return resp, nil + } + } + } + + sessionCookie.Value, _ = uuid.GenerateUUID() + resp.Header.Add(round_tripper.CookieHeader, sessionCookie.String()) + return resp, nil + } + resp, err := proxyRoundTripper.RoundTrip(req) + Expect(err).ToNot(HaveOccurred()) + + cookies = resp.Cookies() + Expect(cookies).To(HaveLen(2)) + + }) + + Context("when there is a JSESSIONID and __VCAP_ID__ set on the request", func() { + It("will always route to the instance specified with the __VCAP_ID__ cookie", func() { + + // Generate 20 random values for the hash header, so chance that all go to instanceID1 + // by accident is 0.33^20 + for i := 0; i < 20; i++ { + randomStr := make([]byte, 8) + for j := range randomStr { + randomStr[j] = byte('a' + rand.Intn(26)) + } + + req.Header.Set("X-Hash", string(randomStr)) + reqInfo, err := handlers.ContextRequestInfo(req) + req.AddCookie(&http.Cookie{Name: round_tripper.VcapCookieId, Value: "instanceID1"}) + req.AddCookie(&http.Cookie{Name: StickyCookieKey, Value: "abc"}) + + Expect(err).ToNot(HaveOccurred()) + reqInfo.RoutePool = routePool + + resp, err := proxyRoundTripper.RoundTrip(req) + Expect(err).ToNot(HaveOccurred()) + + new_cookies := resp.Cookies() + Expect(new_cookies).To(HaveLen(2)) + + for _, cookie := range new_cookies { + Expect(cookie.Name).To(SatisfyAny( + Equal(StickyCookieKey), + Equal(round_tripper.VcapCookieId), + )) + if cookie.Name == StickyCookieKey { + Expect(cookie.Value).To(Equal("abc")) + } else { + Expect(cookie.Value).To(Equal("instanceID1")) + } + } + + } + + }) + }) + }) + }) + Context("when endpoint timeout is not 0", func() { var reqCh chan *http.Request BeforeEach(func() { diff --git a/src/code.cloudfoundry.org/gorouter/route/hash_based.go b/src/code.cloudfoundry.org/gorouter/route/hash_based.go new file mode 100644 index 000000000..39d551eb8 --- /dev/null +++ b/src/code.cloudfoundry.org/gorouter/route/hash_based.go @@ -0,0 +1,141 @@ +package route + +import ( + "context" + "errors" + "log/slog" + "sync" + + log "code.cloudfoundry.org/gorouter/logger" +) + +// HashBased load balancing algorithm distributes requests based on a hash of a specific header value. +// The sticky session cookie has precedence over hash-based routing and the request should be routed to the instance stored in the cookie. +// If requests do not contain the hash-related header set configured for the hash-based route option, use the default load-balancing algorithm. +type HashBased struct { + lock *sync.Mutex + + logger *slog.Logger + pool *EndpointPool + lastEndpoint *Endpoint + + stickyEndpointID string + mustBeSticky bool + + HeaderValue string +} + +// NewHashBased initializes an endpoint iterator that selects endpoints based on a hash of a header value. +// The global properties locallyOptimistic and localAvailabilityZone will be ignored when using Hash-Based Routing. 
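+// Callers are expected to set HeaderValue on the returned *HashBased before calling Next;
+// the proxy round tripper assigns it from the configured hash header, or falls back to the
+// default load-balancing algorithm when that header is missing from the request.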
+func NewHashBased(logger *slog.Logger, p *EndpointPool, initial string, mustBeSticky bool, locallyOptimistic bool, localAvailabilityZone string) EndpointIterator { + return &HashBased{ + logger: logger, + pool: p, + lock: &sync.Mutex{}, + stickyEndpointID: initial, + mustBeSticky: mustBeSticky, + } +} + +// Next selects the next endpoint based on the hash of the header value. +// If a sticky session endpoint is available and not overloaded, it will be returned. +// If the request must be sticky and the sticky endpoint is unavailable or overloaded, nil will be returned. +// If no sticky session is present, the endpoint will be selected based on the hash of the header value. +// It returns the same endpoint for the same header value consistently. +// If the hash lookup fails or the endpoint is not found, nil will be returned. +func (h *HashBased) Next(attempt int) *Endpoint { + h.lock.Lock() + defer h.lock.Unlock() + + e := h.findEndpointIfStickySession() + if e == nil && h.mustBeSticky { + return nil + } + + if e != nil { + h.lastEndpoint = e + return e + } + + if h.pool.HashLookupTable == nil { + h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("Lookup table is empty"))) + return nil + } + + id, err := h.pool.HashLookupTable.Get(h.HeaderValue) + + if err != nil { + h.logger.Error( + "hash-based-routing-failed", + slog.String("host", h.pool.host), + log.ErrAttr(err), + ) + return nil + } + + h.logger.Debug( + "hash-based-routing", + slog.String("hash header value", h.HeaderValue), + slog.String("endpoint-id", id), + ) + + endpointElem := h.pool.findById(id) + if endpointElem == nil { + h.logger.Error("hash-based-routing-failed", slog.String("host", h.pool.host), log.ErrAttr(errors.New("Endpoint not found in pool")), slog.String("endpoint-id", id)) + return nil + } + + return endpointElem.endpoint +} + +// findEndpointIfStickySession checks if there is a sticky session endpoint and returns it if available. +// If the sticky session endpoint is overloaded, returns nil. +func (h *HashBased) findEndpointIfStickySession() *Endpoint { + var e *endpointElem + if h.stickyEndpointID != "" { + e = h.pool.findById(h.stickyEndpointID) + if e != nil && e.isOverloaded() { + if h.mustBeSticky { + if h.logger.Enabled(context.Background(), slog.LevelDebug) { + h.logger.Debug("endpoint-overloaded-but-request-must-be-sticky", e.endpoint.ToLogData()...) + } + return nil + } + e = nil + } + + if e == nil && h.mustBeSticky { + h.logger.Debug("endpoint-missing-but-request-must-be-sticky", slog.String("requested-endpoint", h.stickyEndpointID)) + return nil + } + + if !h.mustBeSticky { + h.logger.Debug("endpoint-missing-choosing-alternate", slog.String("requested-endpoint", h.stickyEndpointID)) + h.stickyEndpointID = "" + } + } + + if e != nil { + e.RLock() + defer e.RUnlock() + return e.endpoint + } + return nil +} + +// EndpointFailed notifies the endpoint pool that the last selected endpoint has failed. +func (h *HashBased) EndpointFailed(err error) { + if h.lastEndpoint != nil { + h.pool.EndpointFailed(h.lastEndpoint, err) + } +} + +// PreRequest increments the in-flight request count for the selected endpoint from current Gorouter. +func (h *HashBased) PreRequest(e *Endpoint) { + e.Stats.NumberConnections.Increment() +} + +// PostRequest decrements the in-flight request count for the selected endpoint from current Gorouter. 
+func (h *HashBased) PostRequest(e *Endpoint) { + e.Stats.NumberConnections.Decrement() +} diff --git a/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go b/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go new file mode 100644 index 000000000..1caaed19c --- /dev/null +++ b/src/code.cloudfoundry.org/gorouter/route/hash_based_test.go @@ -0,0 +1,155 @@ +package route_test + +import ( + "code.cloudfoundry.org/gorouter/config" + _ "errors" + "time" + + "code.cloudfoundry.org/gorouter/route" + "code.cloudfoundry.org/gorouter/test_util" + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("HashBased", func() { + var ( + pool *route.EndpointPool + logger *test_util.TestLogger + ) + + BeforeEach(func() { + logger = test_util.NewTestLogger("test") + pool = route.NewPool(&route.PoolOpts{ + Logger: logger.Logger, + RetryAfterFailure: 2 * time.Minute, + Host: "", + ContextPath: "", + MaxConnsPerBackend: 0, + LoadBalancingAlgorithm: config.LOAD_BALANCE_HB, + }) + }) + + Describe("Next", func() { + + Context("when pool is empty", func() { + It("does not select an endpoint", func() { + iter := route.NewHashBased(logger.Logger, pool, "", false, false, "") + Expect(iter.Next(0)).To(BeNil()) + }) + }) + + Context("when pool has endpoints", func() { + var ( + endpoints []*route.Endpoint + ) + BeforeEach(func() { + e1 := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID1"}) + e2 := route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID2"}) + endpoints = []*route.Endpoint{e1, e2} + for _, e := range endpoints { + pool.Put(e) + } + + }) + It("It returns the same endpoint for the same header value", func() { + iter := route.NewHashBased(logger.Logger, pool, "", false, false, "") + iter.(*route.HashBased).HeaderValue = "tenant-1" + first := iter.Next(0) + second := iter.Next(0) + Expect(first).NotTo(BeNil()) + Expect(second).NotTo(BeNil()) + Expect(first).To(Equal(second)) + }) + + It("It selects another instance for other hash header value", func() { + iter := route.NewHashBased(logger.Logger, pool, "", false, false, "") + iter.(*route.HashBased).HeaderValue = "example.com" + Expect(iter.Next(0)).NotTo(BeNil()) + Expect(iter.Next(0)).To(Equal(endpoints[1])) + Expect(iter.Next(0)).To(Equal(endpoints[1])) + Expect(iter.Next(0)).To(Equal(endpoints[1])) + }) + }) + + Context("when using sticky sessions", func() { + var ( + endpoints []*route.Endpoint + iter route.EndpointIterator + ) + + BeforeEach(func() { + e1 := route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID1"}) + e2 := route.NewEndpoint(&route.EndpointOpts{Host: "2.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID2"}) + e3 := route.NewEndpoint(&route.EndpointOpts{Host: "3.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", HashHeaderName: "tenant-id", PrivateInstanceId: "ID3"}) + endpoints = []*route.Endpoint{e1, e2, e3} + for _, e := range endpoints { + pool.Put(e) + } + }) + + Context("when mustBeSticky is true", func() { + BeforeEach(func() { + iter = route.NewHashBased(logger.Logger, pool, "ID1", true, false, "") + }) + + It("returns the sticky endpoint when it exists", func() { + endpoint := iter.Next(0) + Expect(endpoint).NotTo(BeNil()) + Expect(endpoint.PrivateInstanceId).To(Equal("ID1")) + }) + + It("returns nil 
when sticky endpoint doesn't exist", func() { + iter = route.NewHashBased(logger.Logger, pool, "nonexistent-id", true, false, "") + Expect(iter.Next(0)).To(BeNil()) + }) + }) + + Context("when mustBeSticky is false", func() { + BeforeEach(func() { + iter = route.NewHashBased(logger.Logger, pool, "ID1", false, false, "") + }) + + It("returns the sticky endpoint when it exists", func() { + endpoint := iter.Next(0) + Expect(endpoint).NotTo(BeNil()) + Expect(endpoint.PrivateInstanceId).To(Equal("ID1")) + }) + + It("falls back to hash-based routing when sticky endpoint doesn't exist", func() { + iter = route.NewHashBased(logger.Logger, pool, "nonexistent-id", false, false, "") + hashIter := iter.(*route.HashBased) + hashIter.HeaderValue = "some-value" + endpoint := iter.Next(0) + Expect(endpoint).NotTo(BeNil()) + }) + }) + }) + }) + + Context("when testing PreRequest and PostRequest", func() { + var ( + endpoint *route.Endpoint + iter route.EndpointIterator + ) + + BeforeEach(func() { + endpoint = route.NewEndpoint(&route.EndpointOpts{Host: "1.2.3.4", Port: 5678, LoadBalancingAlgorithm: "hash", PrivateInstanceId: "ID1"}) + pool.Put(endpoint) + iter = route.NewHashBased(logger.Logger, pool, "", false, false, "") + }) + + It("increments connection count on PreRequest", func() { + initialCount := endpoint.Stats.NumberConnections.Count() + iter.PreRequest(endpoint) + Expect(endpoint.Stats.NumberConnections.Count()).To(Equal(initialCount + 1)) + }) + + It("decrements connection count on PostRequest", func() { + iter.PreRequest(endpoint) + initialCount := endpoint.Stats.NumberConnections.Count() + iter.PostRequest(endpoint) + Expect(endpoint.Stats.NumberConnections.Count()).To(Equal(initialCount - 1)) + }) + }) + +}) diff --git a/src/code.cloudfoundry.org/gorouter/route/maglev.go b/src/code.cloudfoundry.org/gorouter/route/maglev.go new file mode 100644 index 000000000..9b70aaa09 --- /dev/null +++ b/src/code.cloudfoundry.org/gorouter/route/maglev.go @@ -0,0 +1,212 @@ +package route + +import ( + "errors" + "fmt" + "hash/fnv" + "log/slog" + "sort" + "strconv" + "strings" + "sync" +) + +const ( + // lookupTableSize is prime number for the size of the maglev lookup table, which should be approximately 100x + // the number of expected endpoints + lookupTableSize uint64 = 1801 +) + +// Maglev implementation of consistent hashing algorithm described in "Maglev: A Fast and Reliable Software Network +// Load Balancer" (https://storage.googleapis.com/gweb-research2023-media/pubtools/2904.pdf) +type Maglev struct { + logger *slog.Logger + permutationTable [][]uint64 + lookupTable []int + endpointList []string + lock *sync.RWMutex +} + +// NewMaglev initializes an empty maglev lookupTable table +func NewMaglev(logger *slog.Logger) *Maglev { + return &Maglev{ + lock: &sync.RWMutex{}, + lookupTable: make([]int, lookupTableSize), + endpointList: make([]string, 0, 2), + permutationTable: make([][]uint64, 0, 2), + logger: logger, + } +} + +// Add a new endpoint to lookupTable if it's not already contained. 
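+// The endpoint list is kept sorted so that an endpoint's index in endpointList stays aligned
+// with its row in permutationTable and with the indices stored in lookupTable.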
+func (m *Maglev) Add(endpoint string) { + m.lock.Lock() + defer m.lock.Unlock() + + if lookupTableSize == uint64(len(m.endpointList)) { + m.logger.Warn("maglev-add-lookuptable-capacity-exceeded", slog.String("endpoint-id", endpoint)) + return + } + + index := sort.SearchStrings(m.endpointList, endpoint) + if index < len(m.endpointList) && m.endpointList[index] == endpoint { + m.logger.Debug("maglev-add-lookuptable-endpoint-exists", slog.String("endpoint-id", endpoint), slog.Int("current-endpoints", len(m.endpointList))) + return + } + + m.endpointList = append(m.endpointList, "") + copy(m.endpointList[index+1:], m.endpointList[index:]) + m.endpointList[index] = endpoint + + m.generatePermutation(endpoint) + m.fillLookupTable() +} + +// Remove an endpoint from lookupTable if it's contained. +func (m *Maglev) Remove(endpoint string) { + m.lock.Lock() + defer m.lock.Unlock() + + index := sort.SearchStrings(m.endpointList, endpoint) + if index >= len(m.endpointList) || m.endpointList[index] != endpoint { + m.logger.Debug("maglev-remove-endpoint-not-found", slog.String("endpoint-id", endpoint)) + return + } + + m.endpointList = append(m.endpointList[:index], m.endpointList[index+1:]...) + m.permutationTable = append(m.permutationTable[:index], m.permutationTable[index+1:]...) + + m.fillLookupTable() +} + +// Get endpoint by specified request header value +// Todo: Overload scenario: Get should return an index rather than an instance, +// so that we can iterate to the next endpoint in case it is overloaded (e.g. via another +// helper function that resolves the endpoint via the index) +func (m *Maglev) Get(headerValue string) (string, error) { + m.lock.RLock() + defer m.lock.RUnlock() + + if len(m.endpointList) == 0 { + return "", errors.New("maglev-get-endpoint-no-endpoints") + } + key := m.hashKey(headerValue) + return m.endpointList[m.lookupTable[key%lookupTableSize]], nil +} + +func (m *Maglev) hashKey(headerValue string) uint64 { + return m.calculateFNVHash64(headerValue) +} + +// generatePermutation creates a permutationTable of the lookup table for each endpoint +func (m *Maglev) generatePermutation(endpoint string) { + pos := sort.SearchStrings(m.endpointList, endpoint) + if pos == len(m.endpointList) { + m.logger.Debug("maglev-permutation-no-endpoints") + return + } + + endpointHash := m.calculateFNVHash64(endpoint) + offset := endpointHash % lookupTableSize + skip := (endpointHash % (lookupTableSize - 1)) + 1 + + permutationForEndpoint := make([]uint64, lookupTableSize) + for j := uint64(0); j < lookupTableSize; j++ { + permutationForEndpoint[j] = (offset + j*skip) % lookupTableSize + } + + // insert permutationForEndpoint at position pos, shifting the rest to the right + m.permutationTable = append(m.permutationTable, nil) + copy(m.permutationTable[pos+1:], m.permutationTable[pos:]) + m.permutationTable[pos] = permutationForEndpoint + +} + +func (m *Maglev) fillLookupTable() { + if len(m.endpointList) == 0 { + return + } + + numberOfEndpoints := len(m.endpointList) + next := make([]int, numberOfEndpoints) + entry := make([]int, lookupTableSize) + for j := range entry { + entry[j] = -1 + } + + for n := uint64(0); n <= lookupTableSize; { + for i := 0; i < numberOfEndpoints; i++ { + candidate := m.findNextAvailableSlot(i, next, entry) + entry[candidate] = int(i) + next[i] = next[i] + 1 + n++ + + if n == lookupTableSize { + m.lookupTable = entry + return + } + } + } +} + +func (m *Maglev) findNextAvailableSlot(i int, next []int, entry []int) uint64 { + candidate := 
m.permutationTable[i][next[i]] + for entry[candidate] >= 0 { + next[i]++ + if next[i] >= len(m.permutationTable[i]) { + // This should not happen in a properly functioning Maglev algorithm, + // but we add this safety check to prevent panic + m.logger.Error("maglev-permutation-table-exhausted", + slog.Int("endpoint-index", i), + slog.Int("next-value", next[i]), + slog.Int("table-size", len(m.permutationTable[i]))) + // Reset to beginning of permutation table as fallback + next[i] = 0 + } + candidate = m.permutationTable[i][next[i]] + } + return candidate +} + +// Getters for unit tests +func (m *Maglev) GetEndpointList() []string { + m.lock.RLock() + defer m.lock.RUnlock() + return append([]string(nil), m.endpointList...) +} + +func (m *Maglev) GetLookupTable() []int { + m.lock.RLock() + defer m.lock.RUnlock() + return append([]int(nil), m.lookupTable...) +} + +func (m *Maglev) GetPermutationTable() [][]uint64 { + m.lock.RLock() + defer m.lock.RUnlock() + copied := make([][]uint64, len(m.permutationTable)) + for i, v := range m.permutationTable { + copied[i] = append([]uint64(nil), v...) + } + return copied +} + +func (m *Maglev) GetLookupTableSize() uint64 { + return lookupTableSize +} + +// TODO: Remove in final version +func (m *Maglev) PrintLookupTable() string { + strArr := make([]string, len(m.lookupTable)) + for i, value := range m.lookupTable { + strArr[i] = strconv.Itoa(value) + } + return fmt.Sprintf("[%s]", strings.Join(strArr, ", ")) +} + +// calculateFNVHash64 computes a hash using the non-cryptographic FNV hash algorithm. +func (m *Maglev) calculateFNVHash64(key string) uint64 { + h := fnv.New64a() + _, _ = h.Write([]byte(key)) + return h.Sum64() +} diff --git a/src/code.cloudfoundry.org/gorouter/route/maglev_test.go b/src/code.cloudfoundry.org/gorouter/route/maglev_test.go new file mode 100644 index 000000000..ae8af9d07 --- /dev/null +++ b/src/code.cloudfoundry.org/gorouter/route/maglev_test.go @@ -0,0 +1,257 @@ +package route_test + +import ( + "fmt" + "strconv" + + . "github.com/onsi/ginkgo/v2" + . 
"github.com/onsi/gomega" + + "code.cloudfoundry.org/gorouter/route" + "code.cloudfoundry.org/gorouter/test_util" +) + +var _ = Describe("Maglev", func() { + var ( + logger *test_util.TestLogger + maglev *route.Maglev + ) + + BeforeEach(func() { + logger = test_util.NewTestLogger("test") + + maglev = route.NewMaglev(logger.Logger) + }) + + Describe("NewMaglev", func() { + It("should create a new Maglev instance", func() { + Expect(maglev).NotTo(BeNil()) + }) + }) + + Describe("Add", func() { + Context("when adding a new backend", func() { + It("should add the backend successfully", func() { + maglev.Add("backend1") + + Expect(maglev.GetEndpointList()).To(HaveLen(1)) + Expect(maglev.GetLookupTable()).To(HaveLen(int(maglev.GetLookupTableSize()))) + Expect(maglev.GetPermutationTable()).To(HaveLen(1)) + Expect(maglev.GetPermutationTable()[0]).To(HaveLen(int(maglev.GetLookupTableSize()))) + + result, err := maglev.Get("test-key") + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal("backend1")) + }) + }) + + Context("when adding a backend twice", func() { + It("should skip adding subsequent adds", func() { + maglev.Add("backend1") + maglev.Add("backend1") + + Expect(maglev.GetEndpointList()).To(HaveLen(1)) + Expect(maglev.GetLookupTable()).To(HaveLen(int(maglev.GetLookupTableSize()))) + Expect(maglev.GetPermutationTable()).To(HaveLen(1)) + Expect(maglev.GetPermutationTable()[0]).To(HaveLen(int(maglev.GetLookupTableSize()))) + + result, err := maglev.Get("test-key") + Expect(err).NotTo(HaveOccurred()) + Expect(result).To(Equal("backend1")) + }) + }) + + Context("when adding multiple backends", func() { + It("should make all backends reachable", func() { + maglev.Add("backend1") + maglev.Add("backend2") + maglev.Add("backend3") + + Expect(maglev.GetEndpointList()).To(HaveLen(3)) + Expect(maglev.GetLookupTable()).To(HaveLen(int(maglev.GetLookupTableSize()))) + Expect(maglev.GetPermutationTable()).To(HaveLen(len(maglev.GetEndpointList()))) + for i := range len(maglev.GetEndpointList()) { + Expect(maglev.GetPermutationTable()[i]).To(HaveLen(int(maglev.GetLookupTableSize()))) + } + + backends := make(map[string]bool) + for i := 0; i < 1000; i++ { + result, err := maglev.Get(string(rune(i))) + Expect(err).NotTo(HaveOccurred()) + backends[result] = true + } + + Expect(backends["backend1"]).To(BeTrue()) + Expect(backends["backend2"]).To(BeTrue()) + Expect(backends["backend3"]).To(BeTrue()) + }) + }) + }) + + Describe("Remove", func() { + Context("when removing an existing backend", func() { + It("should remove the backend successfully", func() { + maglev.Add("backend1") + maglev.Add("backend2") + + maglev.Remove("backend1") + + Expect(maglev.GetEndpointList()).To(HaveLen(1)) + Expect(maglev.GetLookupTable()).To(HaveLen(int(maglev.GetLookupTableSize()))) + Expect(maglev.GetPermutationTable()).To(HaveLen(1)) + Expect(maglev.GetPermutationTable()[0]).To(HaveLen(int(maglev.GetLookupTableSize()))) + + }) + }) + + Context("when removing a non-existent backend", func() { + It("should handle gracefully without error", func() { + maglev.Add("backend1") + + Expect(func() { maglev.Remove("non-existent") }).NotTo(Panic()) + + Expect(maglev.GetEndpointList()).To(HaveLen(1)) + Expect(maglev.GetLookupTable()).To(HaveLen(int(maglev.GetLookupTableSize()))) + Expect(maglev.GetPermutationTable()).To(HaveLen(1)) + Expect(maglev.GetPermutationTable()[0]).To(HaveLen(int(maglev.GetLookupTableSize()))) + }) + }) + }) + + Describe("Get", func() { + Context("when no backends were added", func() { + It("should return an 
error", func() { + _, err := maglev.Get("test-key") + Expect(err).To(HaveOccurred()) + }) + }) + + Context("when backends are added", func() { + BeforeEach(func() { + maglev.Add("backend1") + maglev.Add("backend2") + }) + + It("should return consistent results for the same key", func() { + var counter = make(map[string]int) + var result1 string + var err error + for _ = range 100 { + result1, err = maglev.Get("consistent-key") + Expect(err).NotTo(HaveOccurred()) + counter[result1]++ + } + + Expect(counter[result1]).To(Equal(100)) + }) + + It("should distribute keys across backends", func() { + maglev.Add("backend1") + maglev.Add("backend2") + maglev.Add("backend3") + + distribution := make(map[string]int) + for i := range 1000 { + result, err := maglev.Get(string(rune(i))) + Expect(err).NotTo(HaveOccurred()) + distribution[result]++ + } + + Expect(distribution["backend1"]).To(BeNumerically(">", 0)) + Expect(distribution["backend2"]).To(BeNumerically(">", 0)) + Expect(distribution["backend3"]).To(BeNumerically(">", 0)) + }) + }) + + Context("when backends are removed", func() { + BeforeEach(func() { + maglev.Add("backend1") + maglev.Add("backend2") + maglev.Remove("backend1") + }) + + It("should not return the removed backend", func() { + for _ = range 100 { + endpoint, err := maglev.Get("consistent-key") + Expect(err).NotTo(HaveOccurred()) + Expect(endpoint).To(Equal("backend2")) + } + }) + }) + }) + + Describe("Consistency", func() { + // We test that at most half the keys are reassigned to new backends, when one backend is added. + // This ensures a minimal level of consistency. + It("should minimize disruption when adding backends", func() { + for i := range 10 { + maglev.Add(fmt.Sprintf("backend%d", i+1)) + } + keys := make([]string, 1000) + for i := range keys { + keys[i] = fmt.Sprintf("key%d", i+1) + } + + initialMappings := make(map[string]string) + + for _, key := range keys { + backend, err := maglev.Get(key) + Expect(err).NotTo(HaveOccurred()) + initialMappings[key] = backend + } + + maglev.Add("newbackend") + + changedMappings := 0 + for _, key := range keys { + backend, err := maglev.Get(key) + Expect(err).NotTo(HaveOccurred()) + if initialMappings[key] != backend { + changedMappings++ + } + } + + Expect(changedMappings).To(BeNumerically("<=", len(keys)/2)) + }) + }) + + Describe("Concurrency", func() { + It("should handle concurrent reads safely", func() { + maglev.Add("backend1") + + done := make(chan bool, 10) + for i := 0; i < 10; i++ { + go func() { + defer GinkgoRecover() + for j := 0; j < 100; j++ { + _, err := maglev.Get("test-key") + Expect(err).NotTo(HaveOccurred()) + } + done <- true + }() + } + + for i := 0; i < 10; i++ { + Eventually(done).Should(Receive()) + } + }) + It("should handle concurrent endpoint registrations safely", func() { + done := make(chan bool, 10) + for i := 0; i < 10; i++ { + go func() { + defer GinkgoRecover() + for j := 0; j < 100; j++ { + Expect(func() { maglev.Add("endpoint" + strconv.Itoa(j)) }).NotTo(Panic()) + } + done <- true + }() + } + + for i := 0; i < 10; i++ { + Eventually(done).Should(Receive()) + } + Expect(len(maglev.GetEndpointList())).To(Equal(100)) + }) + + }) +}) diff --git a/src/code.cloudfoundry.org/gorouter/route/pool.go b/src/code.cloudfoundry.org/gorouter/route/pool.go index f089fc15b..b9f491798 100644 --- a/src/code.cloudfoundry.org/gorouter/route/pool.go +++ b/src/code.cloudfoundry.org/gorouter/route/pool.go @@ -74,6 +74,21 @@ type ProxyRoundTripper interface { CancelRequest(*http.Request) } +type HashRoutingProperties 
struct { + Header string + BalanceFactor float64 +} + +func (hrp *HashRoutingProperties) Equal(hrp2 *HashRoutingProperties) bool { + if hrp == nil && hrp2 == nil { + return true + } + if hrp == nil || hrp2 == nil { + return false + } + return hrp.Header == hrp2.Header && hrp.BalanceFactor == hrp2.BalanceFactor +} + type Endpoint struct { ApplicationId string AvailabilityZone string @@ -186,6 +201,8 @@ type EndpointPool struct { logger *slog.Logger updatedAt time.Time LoadBalancingAlgorithm string + HashRoutingProperties *HashRoutingProperties + HashLookupTable *Maglev } type EndpointOpts struct { @@ -248,10 +265,12 @@ type PoolOpts struct { MaxConnsPerBackend int64 Logger *slog.Logger LoadBalancingAlgorithm string + HashHeader string + HashBalanceFactor float64 } func NewPool(opts *PoolOpts) *EndpointPool { - return &EndpointPool{ + pool := &EndpointPool{ endpoints: make([]*endpointElem, 0, 1), index: make(map[string]*endpointElem), retryAfterFailure: opts.RetryAfterFailure, @@ -264,6 +283,14 @@ func NewPool(opts *PoolOpts) *EndpointPool { updatedAt: time.Now(), LoadBalancingAlgorithm: opts.LoadBalancingAlgorithm, } + if pool.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB { + pool.HashLookupTable = NewMaglev(opts.Logger) + pool.HashRoutingProperties = &HashRoutingProperties{ + Header: opts.HashHeader, + BalanceFactor: opts.HashBalanceFactor, + } + } + return pool } func PoolsMatch(p1, p2 *EndpointPool) bool { @@ -320,7 +347,6 @@ func (p *EndpointPool) Put(endpoint *Endpoint) PoolPutResult { // new one. e.Lock() defer e.Unlock() - oldEndpoint := e.endpoint e.endpoint = endpoint @@ -336,6 +362,9 @@ func (p *EndpointPool) Put(endpoint *Endpoint) PoolPutResult { p.RouteSvcUrl = e.endpoint.RouteServiceUrl p.setPoolLoadBalancingAlgorithm(e.endpoint) e.updated = time.Now() + if p.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB { + p.HashLookupTable.Add(e.endpoint.PrivateInstanceId) + } p.Update() return EndpointUpdated @@ -348,7 +377,6 @@ func (p *EndpointPool) Put(endpoint *Endpoint) PoolPutResult { updated: time.Now(), maxConnsPerBackend: p.maxConnsPerBackend, } - p.endpoints = append(p.endpoints, e) p.index[endpoint.CanonicalAddr()] = e @@ -356,6 +384,9 @@ func (p *EndpointPool) Put(endpoint *Endpoint) PoolPutResult { p.RouteSvcUrl = e.endpoint.RouteServiceUrl p.setPoolLoadBalancingAlgorithm(e.endpoint) + if p.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB { + p.HashLookupTable.Add(e.endpoint.PrivateInstanceId) + } p.Update() return EndpointAdded @@ -433,6 +464,11 @@ func (p *EndpointPool) removeEndpoint(e *endpointElem) { delete(p.index, e.endpoint.CanonicalAddr()) delete(p.index, e.endpoint.PrivateInstanceId) p.Update() + + if p.LoadBalancingAlgorithm == config.LOAD_BALANCE_HB { + p.HashLookupTable.Remove(e.endpoint.PrivateInstanceId) + } + } func (p *EndpointPool) Endpoints(logger *slog.Logger, initial string, mustBeSticky bool, azPreference string, az string) EndpointIterator { @@ -443,6 +479,9 @@ func (p *EndpointPool) Endpoints(logger *slog.Logger, initial string, mustBeStic case config.LOAD_BALANCE_RR: logger.Debug("endpoint-iterator-with-round-robin-lb-algo") return NewRoundRobin(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) + case config.LOAD_BALANCE_HB: + logger.Debug("endpoint-iterator-with-hash-based-lb-algo") + return NewHashBased(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) default: logger.Error("invalid-pool-load-balancing-algorithm", slog.String("poolLBAlgorithm", p.LoadBalancingAlgorithm), @@ -452,6 +491,23 @@ 
func (p *EndpointPool) Endpoints(logger *slog.Logger, initial string, mustBeStic } } +func (p *EndpointPool) FallBackToDefaultLoadBalancing(defaultLBAlgo string, logger *slog.Logger, initial string, mustBeSticky bool, azPreference string, az string) EndpointIterator { + logger.Info("hash-based-routing-header-not-found", + slog.String("poolLBAlgorithm", p.LoadBalancingAlgorithm), + slog.String("Host", p.host), + slog.String("Path", p.contextPath)) + + switch defaultLBAlgo { + case config.LOAD_BALANCE_LC: + logger.Debug("endpoint-iterator-with-least-connection-lb-algo") + return NewLeastConnection(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) + case config.LOAD_BALANCE_RR: + logger.Debug("endpoint-iterator-with-round-robin-lb-algo") + return NewRoundRobin(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) + } + return NewRoundRobin(logger, p, initial, mustBeSticky, azPreference == config.AZ_PREF_LOCAL, az) +} + func (p *EndpointPool) NumEndpoints() int { p.Lock() defer p.Unlock() @@ -561,12 +617,13 @@ func (p *EndpointPool) MarshalJSON() ([]byte, error) { // setPoolLoadBalancingAlgorithm overwrites the load balancing algorithm of a pool by that of a specified endpoint, if that is valid. func (p *EndpointPool) setPoolLoadBalancingAlgorithm(endpoint *Endpoint) { - if len(endpoint.LoadBalancingAlgorithm) > 0 && endpoint.LoadBalancingAlgorithm != p.LoadBalancingAlgorithm { + if endpoint.LoadBalancingAlgorithm != "" && endpoint.LoadBalancingAlgorithm != p.LoadBalancingAlgorithm { if config.IsLoadBalancingAlgorithmValid(endpoint.LoadBalancingAlgorithm) { p.LoadBalancingAlgorithm = endpoint.LoadBalancingAlgorithm p.logger.Debug("setting-pool-load-balancing-algorithm-to-that-of-an-endpoint", slog.String("endpointLBAlgorithm", endpoint.LoadBalancingAlgorithm), slog.String("poolLBAlgorithm", p.LoadBalancingAlgorithm)) + p.prepareHashBasedRouting(endpoint) } else { p.logger.Error("invalid-endpoint-load-balancing-algorithm-provided-keeping-pool-lb-algo", slog.String("endpointLBAlgorithm", endpoint.LoadBalancingAlgorithm), @@ -575,6 +632,20 @@ func (p *EndpointPool) setPoolLoadBalancingAlgorithm(endpoint *Endpoint) { } } +func (p *EndpointPool) prepareHashBasedRouting(endpoint *Endpoint) { + if p.LoadBalancingAlgorithm != config.LOAD_BALANCE_HB { + return + } + if p.HashLookupTable == nil { + p.HashLookupTable = NewMaglev(p.logger) + } + p.HashRoutingProperties = &HashRoutingProperties{ + Header: endpoint.HashHeaderName, + BalanceFactor: endpoint.HashBalanceFactor, + } + +} + func (e *endpointElem) failed() { t := time.Now() e.failedAt = &t diff --git a/src/code.cloudfoundry.org/gorouter/route/pool_test.go b/src/code.cloudfoundry.org/gorouter/route/pool_test.go index 31da6c8d7..7709a1d8b 100644 --- a/src/code.cloudfoundry.org/gorouter/route/pool_test.go +++ b/src/code.cloudfoundry.org/gorouter/route/pool_test.go @@ -428,6 +428,46 @@ var _ = Describe("EndpointPool", func() { Expect(pool.LoadBalancingAlgorithm).To(Equal(config.LOAD_BALANCE_RR)) }) }) + + Context("When switching to hash-based routing", func() { + It("will create the maglev table and add the endpoint", func() { + pool := route.NewPool(&route.PoolOpts{ + Logger: logger.Logger, + LoadBalancingAlgorithm: config.LOAD_BALANCE_RR, + }) + + endpointOpts := route.EndpointOpts{ + Host: "host-1", + Port: 1234, + RouteServiceUrl: "url", + LoadBalancingAlgorithm: config.LOAD_BALANCE_RR, + } + + initalEndpoint := route.NewEndpoint(&endpointOpts) + + pool.Put(initalEndpoint) + 
Expect(pool.LoadBalancingAlgorithm).To(Equal(config.LOAD_BALANCE_RR)) + + endpointOptsHash := route.EndpointOpts{ + Host: "host-1", + Port: 1234, + RouteServiceUrl: "url", + LoadBalancingAlgorithm: config.LOAD_BALANCE_HB, + HashBalanceFactor: 1.25, + HashHeaderName: "X-Tenant", + } + + hashEndpoint := route.NewEndpoint(&endpointOptsHash) + + pool.Put(hashEndpoint) + Expect(pool.LoadBalancingAlgorithm).To(Equal(config.LOAD_BALANCE_HB)) + Expect(pool.HashLookupTable).ToNot(BeNil()) + Expect(pool.HashLookupTable.GetEndpointList()).To(HaveLen(1)) + Expect(pool.HashLookupTable.GetEndpointList()[0]).To(Equal(hashEndpoint.PrivateInstanceId)) + }) + + }) + }) Context("RouteServiceUrl", func() { From 010c4e0b762e11b347ae2149f6495527457f669f Mon Sep 17 00:00:00 2001 From: Clemens Hoffmann Date: Wed, 29 Oct 2025 14:32:37 +0100 Subject: [PATCH 2/2] Add LICENSE information for maglev.go --- src/code.cloudfoundry.org/gorouter/route/maglev.go | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/src/code.cloudfoundry.org/gorouter/route/maglev.go b/src/code.cloudfoundry.org/gorouter/route/maglev.go index 9b70aaa09..6085c7d0f 100644 --- a/src/code.cloudfoundry.org/gorouter/route/maglev.go +++ b/src/code.cloudfoundry.org/gorouter/route/maglev.go @@ -1,5 +1,15 @@ package route +/****************************************************************************** + * Original github.com/kkdai/maglev/maglev.go + * + * Copyright (c) 2019 Evan Lin (github.com/kkdai) + * + * This program and the accompanying materials are made available under + * the terms of the Apache License, Version 2.0 which is available at + * http://www.apache.org/licenses/LICENSE-2.0. + ******************************************************************************/ + import ( "errors" "fmt"
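As background for the table construction adapted above from github.com/kkdai/maglev: the Maglev scheme gives every endpoint a deterministic preference order over the table slots, derived from an offset and a skip computed from the endpoint's hash, and then lets the endpoints claim slots in turn until the table is full. The following is a small, self-contained sketch assuming a reduced table size of 13 instead of the 1801 used in the patch; function and variable names are simplified and differ from route/maglev.go.

```go
// Minimal Maglev lookup-table sketch (not the patch's code).
package main

import (
	"fmt"
	"hash/fnv"
)

const tableSize uint64 = 13 // must be prime; the patch uses 1801

func hash64(s string) uint64 {
	h := fnv.New64a()
	h.Write([]byte(s))
	return h.Sum64()
}

// permutation returns the preference list of table slots for one endpoint.
func permutation(endpoint string) []uint64 {
	h := hash64(endpoint)
	offset := h % tableSize
	skip := h%(tableSize-1) + 1
	p := make([]uint64, tableSize)
	for j := uint64(0); j < tableSize; j++ {
		p[j] = (offset + j*skip) % tableSize
	}
	return p
}

// buildTable round-robins over the endpoints, letting each one claim its next
// still-free preferred slot, until every slot maps to an endpoint index.
func buildTable(endpoints []string) []int {
	perms := make([][]uint64, len(endpoints))
	for i, e := range endpoints {
		perms[i] = permutation(e)
	}
	next := make([]int, len(endpoints))
	table := make([]int, tableSize)
	for i := range table {
		table[i] = -1 // -1 marks an unclaimed slot
	}
	filled := uint64(0)
	for filled < tableSize {
		for i := range endpoints {
			c := perms[i][next[i]]
			for table[c] >= 0 { // slot taken, try this endpoint's next preference
				next[i]++
				c = perms[i][next[i]]
			}
			table[c] = i
			next[i]++
			filled++
			if filled == tableSize {
				break
			}
		}
	}
	return table
}

func main() {
	endpoints := []string{"instance-a", "instance-b", "instance-c"}
	table := buildTable(endpoints)
	key := "tenant-42"
	slot := hash64(key) % tableSize
	fmt.Println("lookup table:", table)
	fmt.Println(key, "->", endpoints[table[slot]])
}
```

Because the table size is prime and skip lies in [1, size-1], each preference list is a full permutation of the slots, so an endpoint always finds a free slot while any remain; that is why the safety check in findNextAvailableSlot is not expected to trigger in practice.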