@@ -15,9 +15,10 @@ import (
1515type HashBased struct {
1616 lock * sync.Mutex
1717
18- logger * slog.Logger
19- pool * EndpointPool
20- lastEndpoint * Endpoint
18+ logger * slog.Logger
19+ pool * EndpointPool
20+ lastEndpoint * Endpoint
21+ lastLookupTableIndex uint64
2122
2223 stickyEndpointID string
2324 mustBeSticky bool
@@ -47,45 +48,107 @@ func (h *HashBased) Next(attempt int) *Endpoint {
4748 h .lock .Lock ()
4849 defer h .lock .Unlock ()
4950
50- e := h .findEndpointIfStickySession ()
51- if e == nil && h .mustBeSticky {
51+ endpoint := h .findEndpointIfStickySession ()
52+ if endpoint == nil && h .mustBeSticky {
5253 return nil
5354 }
5455
55- if e != nil {
56- h .lastEndpoint = e
57- return e
56+ if endpoint != nil {
57+ h .lastEndpoint = endpoint
58+ return endpoint
5859 }
5960
6061 if h .pool .HashLookupTable == nil {
6162 h .logger .Error ("hash-based-routing-failed" , slog .String ("host" , h .pool .host ), log .ErrAttr (errors .New ("Lookup table is empty" )))
6263 return nil
6364 }
6465
65- id , err := h .pool .HashLookupTable .Get (h .HeaderValue )
66+ if attempt == 0 || h .lastLookupTableIndex == 0 {
67+ initialLookupTableIndex , _ , err := h .pool .HashLookupTable .GetInstanceForHashHeader (h .HeaderValue )
6668
67- if err != nil {
68- h .logger .Error (
69- "hash-based-routing-failed" ,
70- slog .String ("host" , h .pool .host ),
71- log .ErrAttr (err ),
72- )
73- return nil
69+ if err != nil {
70+ h .logger .Error (
71+ "hash-based-routing-failed" ,
72+ slog .String ("host" , h .pool .host ),
73+ log .ErrAttr (err ),
74+ )
75+ return nil
76+ }
77+
78+ endpoint = h .findEndpoint (initialLookupTableIndex , attempt )
79+ } else {
80+ // On retries, start looking from the next index in the lookup table
81+ nextIndex := (h .lastLookupTableIndex + 1 ) % h .pool .HashLookupTable .GetLookupTableSize ()
82+ endpoint = h .findEndpoint (nextIndex , attempt )
7483 }
7584
76- h . logger . Debug (
77- "hash-based-routing" ,
78- slog . String ( "hash header value" , h . HeaderValue ),
79- slog . String ( "endpoint-id" , id ),
80- )
85+ if endpoint != nil {
86+ h . lastEndpoint = endpoint
87+ }
88+ return endpoint
89+ }
8190
82- endpointElem := h . pool . findById ( id )
83- if endpointElem == nil {
84- h . logger . Error ( "hash-based-routing-failed" , slog . String ( "host" , h . pool . host ), log . ErrAttr ( errors . New ( "Endpoint not found in pool" )), slog . String ( "endpoint-id" , id ))
91+ func ( h * HashBased ) findEndpoint ( index uint64 , attempt int ) * Endpoint {
92+ maxIterations := len ( h . pool . endpoints )
93+ if maxIterations == 0 {
8594 return nil
8695 }
8796
88- return endpointElem .endpoint
97+ // Ensure we don't exceed the lookup table size
98+ lookupTableSize := h .pool .HashLookupTable .GetLookupTableSize ()
99+
100+ // Normalize index
101+ currentIndex := index % lookupTableSize
102+ // Keep track of endpoints already visited, to avoid visiting them twice
103+ visitedEndpoints := make (map [string ]bool )
104+
105+ numberOfEndpoints := len (h .pool .HashLookupTable .GetEndpointList ())
106+
107+ lastEndpointPrivateId := ""
108+ if attempt > 0 && h .lastEndpoint != nil {
109+ lastEndpointPrivateId = h .lastEndpoint .PrivateInstanceId
110+ }
111+
112+ // abort when we have visited all available endpoints unsuccessfully
113+ for len (visitedEndpoints ) < numberOfEndpoints {
114+ id := h .pool .HashLookupTable .GetEndpointId (currentIndex )
115+
116+ if visitedEndpoints [id ] || id == lastEndpointPrivateId {
117+ currentIndex = (currentIndex + 1 ) % lookupTableSize
118+ continue
119+ }
120+ visitedEndpoints [id ] = true
121+
122+ endpointElem := h .pool .findById (id )
123+ if endpointElem == nil {
124+ h .logger .Error ("hash-based-routing-failed" , slog .String ("host" , h .pool .host ), log .ErrAttr (errors .New ("Endpoint not found in pool" )), slog .String ("endpoint-id" , id ))
125+ currentIndex = (currentIndex + 1 ) % lookupTableSize
126+ continue
127+ }
128+
129+ lastEndpointPrivateId = id
130+
131+ e := endpointElem .endpoint
132+ if h .pool .HashRoutingProperties .BalanceFactor <= 0 || ! h .isOverloaded (e ) {
133+ h .lastLookupTableIndex = currentIndex
134+ return e
135+ }
136+
137+ currentIndex = (currentIndex + 1 ) % lookupTableSize
138+ }
139+ // All endpoints checked and overloaded or not found
140+ h .logger .Error ("hash-based-routing-failed" , slog .String ("host" , h .pool .host ), log .ErrAttr (errors .New ("All endpoints are overloaded" )))
141+ return nil
142+ }
143+
144+ func (h * HashBased ) isOverloaded (e * Endpoint ) bool {
145+ avgLoad := h .CalculateAverageLoad ()
146+ balanceFactor := h .pool .HashRoutingProperties .BalanceFactor
147+ if float64 (e .Stats .NumberConnections .Count ())/ avgLoad > balanceFactor {
148+ h .logger .Info ("hash-based-routing-endpoint-overloaded" , slog .String ("host" , h .pool .host ), slog .String ("endpoint-id" , e .PrivateInstanceId ), slog .Int64 ("endpoint-connections" , e .Stats .NumberConnections .Count ()), slog .Float64 ("average-load" , avgLoad ))
149+ return true
150+ }
151+ return false
89152}
90153
91154// findEndpointIfStickySession checks if there is a sticky session endpoint and returns it if available.
@@ -139,3 +202,18 @@ func (h *HashBased) PreRequest(e *Endpoint) {
139202func (h * HashBased ) PostRequest (e * Endpoint ) {
140203 e .Stats .NumberConnections .Decrement ()
141204}
205+
206+ func (h * HashBased ) CalculateAverageLoad () float64 {
207+ if len (h .pool .endpoints ) == 0 {
208+ return 0
209+ }
210+
211+ var currentInFlightRequestCount int64
212+ for _ , endpointElem := range h .pool .endpoints {
213+ endpointElem .RLock ()
214+ currentInFlightRequestCount += endpointElem .endpoint .Stats .NumberConnections .Count ()
215+ endpointElem .RUnlock ()
216+ }
217+
218+ return float64 (currentInFlightRequestCount ) / float64 (len (h .pool .endpoints ))
219+ }
0 commit comments